diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/jpeg-xl/lib/jxl | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl')
406 files changed, 111192 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/ac_context.h b/third_party/jpeg-xl/lib/jxl/ac_context.h new file mode 100644 index 0000000000..a2b9e046d1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ac_context.h @@ -0,0 +1,149 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_AC_CONTEXT_H_ +#define LIB_JXL_AC_CONTEXT_H_ + +#include <algorithm> +#include <vector> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" + +namespace jxl { + +// Block context used for scanning order, number of non-zeros, AC coefficients. +// Equal to the channel. +constexpr uint32_t kDCTOrderContextStart = 0; + +// The number of predicted nonzeros goes from 0 to 1008. We use +// ceil(log2(predicted+1)) as a context for the number of nonzeros, so from 0 to +// 10, inclusive. +constexpr uint32_t kNonZeroBuckets = 37; + +static const uint16_t kCoeffFreqContext[64] = { + 0xBAD, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, + 23, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, + 27, 27, 27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, +}; + +static const uint16_t kCoeffNumNonzeroContext[64] = { + 0xBAD, 0, 31, 62, 62, 93, 93, 93, 93, 123, 123, 123, 123, + 152, 152, 152, 152, 152, 152, 152, 152, 180, 180, 180, 180, 180, + 180, 180, 180, 180, 180, 180, 180, 206, 206, 206, 206, 206, 206, + 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, + 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, 206, +}; + +// Supremum of ZeroDensityContext(x, y) + 1, when x + y < 64. +constexpr int kZeroDensityContextCount = 458; +// Supremum of ZeroDensityContext(x, y) + 1. +constexpr int kZeroDensityContextLimit = 474; + +/* This function is used for entropy-sources pre-clustering. + * + * Ideally, each combination of |nonzeros_left| and |k| should go to its own + * bucket; but it implies (64 * 63 / 2) == 2016 buckets. If there is other + * dimension (e.g. block context), then number of primary clusters becomes too + * big. + * + * To solve this problem, |nonzeros_left| and |k| values are clustered. It is + * known that their sum is at most 64, consequently, the total number buckets + * is at most A(64) * B(64). + */ +// TODO(user): investigate, why disabling pre-clustering makes entropy code +// less dense. Perhaps we would need to add HQ clustering algorithm that would +// be able to squeeze better by spending more CPU cycles. +static JXL_INLINE size_t ZeroDensityContext(size_t nonzeros_left, size_t k, + size_t covered_blocks, + size_t log2_covered_blocks, + size_t prev) { + JXL_DASSERT((1u << log2_covered_blocks) == covered_blocks); + nonzeros_left = (nonzeros_left + covered_blocks - 1) >> log2_covered_blocks; + k >>= log2_covered_blocks; + JXL_DASSERT(k > 0); + JXL_DASSERT(k < 64); + JXL_DASSERT(nonzeros_left > 0); + // Asserting nonzeros_left + k < 65 here causes crashes in debug mode with + // invalid input, since the (hot) decoding loop does not check this condition. + // As no out-of-bound memory reads are issued even if that condition is + // broken, we check this simpler condition which holds anyway. The decoder + // will still mark a file in which that condition happens as not valid at the + // end of the decoding loop, as `nzeros` will not be `0`. + JXL_DASSERT(nonzeros_left < 64); + return (kCoeffNumNonzeroContext[nonzeros_left] + kCoeffFreqContext[k]) * 2 + + prev; +} + +struct BlockCtxMap { + std::vector<int> dc_thresholds[3]; + std::vector<uint32_t> qf_thresholds; + std::vector<uint8_t> ctx_map; + size_t num_ctxs, num_dc_ctxs; + + static constexpr uint8_t kDefaultCtxMap[] = { + // Default ctx map clusters all the large transforms together. + 0, 1, 2, 2, 3, 3, 4, 5, 6, 6, 6, 6, 6, // + 7, 8, 9, 9, 10, 11, 12, 13, 14, 14, 14, 14, 14, // + 7, 8, 9, 9, 10, 11, 12, 13, 14, 14, 14, 14, 14, // + }; + static_assert(3 * kNumOrders == + sizeof(kDefaultCtxMap) / sizeof *kDefaultCtxMap, + "Update default context map"); + + size_t Context(int dc_idx, uint32_t qf, size_t ord, size_t c) const { + size_t qf_idx = 0; + for (uint32_t t : qf_thresholds) { + if (qf > t) qf_idx++; + } + size_t idx = c < 2 ? c ^ 1 : 2; + idx = idx * kNumOrders + ord; + idx = idx * (qf_thresholds.size() + 1) + qf_idx; + idx = idx * num_dc_ctxs + dc_idx; + return ctx_map[idx]; + } + // Non-zero context is based on number of non-zeros and block context. + // For better clustering, contexts with same number of non-zeros are grouped. + constexpr uint32_t ZeroDensityContextsOffset(uint32_t block_ctx) const { + return num_ctxs * kNonZeroBuckets + kZeroDensityContextCount * block_ctx; + } + + // Context map for AC coefficients consists of 2 blocks: + // |num_ctxs x : context for number of non-zeros in the block + // kNonZeroBuckets| computed from block context and predicted + // value (based top and left values) + // |num_ctxs x : context for AC coefficient symbols, + // kZeroDensityContextCount| computed from block context, + // number of non-zeros left and + // index in scan order + constexpr uint32_t NumACContexts() const { + return num_ctxs * (kNonZeroBuckets + kZeroDensityContextCount); + } + + // Non-zero context is based on number of non-zeros and block context. + // For better clustering, contexts with same number of non-zeros are grouped. + inline uint32_t NonZeroContext(uint32_t non_zeros, uint32_t block_ctx) const { + uint32_t ctx; + if (non_zeros >= 64) non_zeros = 64; + if (non_zeros < 8) { + ctx = non_zeros; + } else { + ctx = 4 + non_zeros / 2; + } + return ctx * num_ctxs + block_ctx; + } + + BlockCtxMap() { + ctx_map.assign(std::begin(kDefaultCtxMap), std::end(kDefaultCtxMap)); + num_ctxs = *std::max_element(ctx_map.begin(), ctx_map.end()) + 1; + num_dc_ctxs = 1; + } +}; + +} // namespace jxl + +#endif // LIB_JXL_AC_CONTEXT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/ac_strategy.cc b/third_party/jpeg-xl/lib/jxl/ac_strategy.cc new file mode 100644 index 0000000000..3de477f71c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ac_strategy.cc @@ -0,0 +1,106 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/ac_strategy.h" + +#include <string.h> + +#include <algorithm> +#include <numeric> // iota +#include <type_traits> +#include <utility> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +// Tries to generalize zig-zag order to non-square blocks. Surprisingly, in +// square block frequency along the (i + j == const) diagonals is roughly the +// same. For historical reasons, consecutive diagonals are traversed +// in alternating directions - so called "zig-zag" (or "snake") order. +template <bool is_lut> +static void CoeffOrderAndLut(AcStrategy acs, coeff_order_t* out) { + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + CoefficientLayout(&cy, &cx); + + // CoefficientLayout ensures cx >= cy. + // We compute the zigzag order for a cx x cx block, then discard all the + // lines that are not multiple of the ratio between cx and cy. + size_t xs = cx / cy; + size_t xsm = xs - 1; + size_t xss = CeilLog2Nonzero(xs); + // First half of the block + size_t cur = cx * cy; + for (size_t i = 0; i < cx * kBlockDim; i++) { + for (size_t j = 0; j <= i; j++) { + size_t x = j; + size_t y = i - j; + if (i % 2) std::swap(x, y); + if ((y & xsm) != 0) continue; + y >>= xss; + size_t val = 0; + if (x < cx && y < cy) { + val = y * cx + x; + } else { + val = cur++; + } + if (is_lut) { + out[y * cx * kBlockDim + x] = val; + } else { + out[val] = y * cx * kBlockDim + x; + } + } + } + // Second half + for (size_t ip = cx * kBlockDim - 1; ip > 0; ip--) { + size_t i = ip - 1; + for (size_t j = 0; j <= i; j++) { + size_t x = cx * kBlockDim - 1 - (i - j); + size_t y = cx * kBlockDim - 1 - j; + if (i % 2) std::swap(x, y); + if ((y & xsm) != 0) continue; + y >>= xss; + size_t val = cur++; + if (is_lut) { + out[y * cx * kBlockDim + x] = val; + } else { + out[val] = y * cx * kBlockDim + x; + } + } + } +} + +void AcStrategy::ComputeNaturalCoeffOrder(coeff_order_t* order) const { + CoeffOrderAndLut</*is_lut=*/false>(*this, order); +} +void AcStrategy::ComputeNaturalCoeffOrderLut(coeff_order_t* lut) const { + CoeffOrderAndLut</*is_lut=*/true>(*this, lut); +} + +// These definitions are needed before C++17. +constexpr size_t AcStrategy::kMaxCoeffBlocks; +constexpr size_t AcStrategy::kMaxBlockDim; +constexpr size_t AcStrategy::kMaxCoeffArea; + +AcStrategyImage::AcStrategyImage(size_t xsize, size_t ysize) + : layers_(xsize, ysize) { + row_ = layers_.Row(0); + stride_ = layers_.PixelsPerRow(); +} + +size_t AcStrategyImage::CountBlocks(AcStrategy::Type type) const { + size_t ret = 0; + for (size_t y = 0; y < layers_.ysize(); y++) { + const uint8_t* JXL_RESTRICT row = layers_.ConstRow(y); + for (size_t x = 0; x < layers_.xsize(); x++) { + if (row[x] == ((static_cast<uint8_t>(type) << 1) | 1)) ret++; + } + } + return ret; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/ac_strategy.h b/third_party/jpeg-xl/lib/jxl/ac_strategy.h new file mode 100644 index 0000000000..ecdcbbbd32 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ac_strategy.h @@ -0,0 +1,261 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_AC_STRATEGY_H_ +#define LIB_JXL_AC_STRATEGY_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <hwy/base.h> // kMaxVectorSize + +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/image_ops.h" + +// Defines the different kinds of transforms, and heuristics to choose between +// them. +// `AcStrategy` represents what transform should be used, and which sub-block of +// that transform we are currently in. Note that DCT4x4 is applied on all four +// 4x4 sub-blocks of an 8x8 block. +// `AcStrategyImage` defines which strategy should be used for each 8x8 block +// of the image. The highest 4 bits represent the strategy to be used, the +// lowest 4 represent the index of the block inside that strategy. + +namespace jxl { + +class AcStrategy { + public: + // Extremal values for the number of blocks/coefficients of a single strategy. + static constexpr size_t kMaxCoeffBlocks = 32; + static constexpr size_t kMaxBlockDim = kBlockDim * kMaxCoeffBlocks; + // Maximum number of coefficients in a block. Guaranteed to be a multiple of + // the vector size. + static constexpr size_t kMaxCoeffArea = kMaxBlockDim * kMaxBlockDim; + static_assert((kMaxCoeffArea * sizeof(float)) % hwy::kMaxVectorSize == 0, + "Coefficient area is not a multiple of vector size"); + + // Raw strategy types. + enum Type : uint32_t { + // Regular block size DCT + DCT = 0, + // Encode pixels without transforming + IDENTITY = 1, + // Use 2-by-2 DCT + DCT2X2 = 2, + // Use 4-by-4 DCT + DCT4X4 = 3, + // Use 16-by-16 DCT + DCT16X16 = 4, + // Use 32-by-32 DCT + DCT32X32 = 5, + // Use 16-by-8 DCT + DCT16X8 = 6, + // Use 8-by-16 DCT + DCT8X16 = 7, + // Use 32-by-8 DCT + DCT32X8 = 8, + // Use 8-by-32 DCT + DCT8X32 = 9, + // Use 32-by-16 DCT + DCT32X16 = 10, + // Use 16-by-32 DCT + DCT16X32 = 11, + // 4x8 and 8x4 DCT + DCT4X8 = 12, + DCT8X4 = 13, + // Corner-DCT. + AFV0 = 14, + AFV1 = 15, + AFV2 = 16, + AFV3 = 17, + // Larger DCTs + DCT64X64 = 18, + DCT64X32 = 19, + DCT32X64 = 20, + DCT128X128 = 21, + DCT128X64 = 22, + DCT64X128 = 23, + DCT256X256 = 24, + DCT256X128 = 25, + DCT128X256 = 26, + // Marker for num of valid strategies. + kNumValidStrategies + }; + + static constexpr uint32_t TypeBit(const Type type) { + return 1u << static_cast<uint32_t>(type); + } + + // Returns true if this block is the first 8x8 block (i.e. top-left) of a + // possibly multi-block strategy. + JXL_INLINE bool IsFirstBlock() const { return is_first_; } + + JXL_INLINE bool IsMultiblock() const { + constexpr uint32_t bits = + TypeBit(Type::DCT16X16) | TypeBit(Type::DCT32X32) | + TypeBit(Type::DCT16X8) | TypeBit(Type::DCT8X16) | + TypeBit(Type::DCT32X8) | TypeBit(Type::DCT8X32) | + TypeBit(Type::DCT16X32) | TypeBit(Type::DCT32X16) | + TypeBit(Type::DCT32X64) | TypeBit(Type::DCT64X32) | + TypeBit(Type::DCT64X64) | TypeBit(DCT64X128) | TypeBit(DCT128X64) | + TypeBit(DCT128X128) | TypeBit(DCT128X256) | TypeBit(DCT256X128) | + TypeBit(DCT256X256); + JXL_DASSERT(Strategy() < kNumValidStrategies); + return ((1u << static_cast<uint32_t>(Strategy())) & bits) != 0; + } + + // Returns the raw strategy value. Should only be used for tokenization. + JXL_INLINE uint8_t RawStrategy() const { + return static_cast<uint8_t>(strategy_); + } + + JXL_INLINE Type Strategy() const { return strategy_; } + + // Inverse check + static JXL_INLINE constexpr bool IsRawStrategyValid(int raw_strategy) { + return raw_strategy < static_cast<int32_t>(kNumValidStrategies) && + raw_strategy >= 0; + } + static JXL_INLINE AcStrategy FromRawStrategy(uint8_t raw_strategy) { + return FromRawStrategy(static_cast<Type>(raw_strategy)); + } + static JXL_INLINE AcStrategy FromRawStrategy(Type raw_strategy) { + JXL_DASSERT(IsRawStrategyValid(static_cast<uint32_t>(raw_strategy))); + return AcStrategy(raw_strategy, /*is_first=*/true); + } + + // "Natural order" means the order of increasing of "anisotropic" frequency of + // continuous version of DCT basis. + // Round-trip, for any given strategy s: + // X = NaturalCoeffOrder(s)[NaturalCoeffOrderLutN(s)[X]] + // X = NaturalCoeffOrderLut(s)[NaturalCoeffOrderN(s)[X]] + void ComputeNaturalCoeffOrder(coeff_order_t* order) const; + void ComputeNaturalCoeffOrderLut(coeff_order_t* lut) const; + + // Number of 8x8 blocks that this strategy will cover. 0 for non-top-left + // blocks inside a multi-block transform. + JXL_INLINE size_t covered_blocks_x() const { + static constexpr uint8_t kLut[] = {1, 1, 1, 1, 2, 4, 1, 2, 1, + 4, 2, 4, 1, 1, 1, 1, 1, 1, + 8, 4, 8, 16, 8, 16, 32, 16, 32}; + static_assert(sizeof(kLut) / sizeof(*kLut) == kNumValidStrategies, + "Update LUT"); + return kLut[size_t(strategy_)]; + } + + JXL_INLINE size_t covered_blocks_y() const { + static constexpr uint8_t kLut[] = {1, 1, 1, 1, 2, 4, 2, 1, 4, + 1, 4, 2, 1, 1, 1, 1, 1, 1, + 8, 8, 4, 16, 16, 8, 32, 32, 16}; + static_assert(sizeof(kLut) / sizeof(*kLut) == kNumValidStrategies, + "Update LUT"); + return kLut[size_t(strategy_)]; + } + + JXL_INLINE size_t log2_covered_blocks() const { + static constexpr uint8_t kLut[] = {0, 0, 0, 0, 2, 4, 1, 1, 2, + 2, 3, 3, 0, 0, 0, 0, 0, 0, + 6, 5, 5, 8, 7, 7, 10, 9, 9}; + static_assert(sizeof(kLut) / sizeof(*kLut) == kNumValidStrategies, + "Update LUT"); + return kLut[size_t(strategy_)]; + } + + private: + friend class AcStrategyRow; + JXL_INLINE AcStrategy(Type strategy, bool is_first) + : strategy_(strategy), is_first_(is_first) { + JXL_DASSERT(IsMultiblock() || is_first == true); + } + + Type strategy_; + bool is_first_; +}; + +// Class to use a certain row of the AC strategy. +class AcStrategyRow { + public: + explicit AcStrategyRow(const uint8_t* row) : row_(row) {} + AcStrategy operator[](size_t x) const { + return AcStrategy(static_cast<AcStrategy::Type>(row_[x] >> 1), row_[x] & 1); + } + + private: + const uint8_t* JXL_RESTRICT row_; +}; + +class AcStrategyImage { + public: + AcStrategyImage() = default; + AcStrategyImage(size_t xsize, size_t ysize); + AcStrategyImage(AcStrategyImage&&) = default; + AcStrategyImage& operator=(AcStrategyImage&&) = default; + + void FillDCT8(const Rect& rect) { + FillPlane<uint8_t>((static_cast<uint8_t>(AcStrategy::Type::DCT) << 1) | 1, + &layers_, rect); + } + void FillDCT8() { FillDCT8(Rect(layers_)); } + + void FillInvalid() { FillImage(INVALID, &layers_); } + + void Set(size_t x, size_t y, AcStrategy::Type type) { +#if JXL_ENABLE_ASSERT + AcStrategy acs = AcStrategy::FromRawStrategy(type); +#endif // JXL_ENABLE_ASSERT + JXL_ASSERT(y + acs.covered_blocks_y() <= layers_.ysize()); + JXL_ASSERT(x + acs.covered_blocks_x() <= layers_.xsize()); + JXL_CHECK(SetNoBoundsCheck(x, y, type, /*check=*/false)); + } + + Status SetNoBoundsCheck(size_t x, size_t y, AcStrategy::Type type, + bool check = true) { + AcStrategy acs = AcStrategy::FromRawStrategy(type); + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + size_t pos = (y + iy) * stride_ + x + ix; + if (check && row_[pos] != INVALID) { + return JXL_FAILURE("Invalid AC strategy: block overlap"); + } + row_[pos] = + (static_cast<uint8_t>(type) << 1) | ((iy | ix) == 0 ? 1 : 0); + } + } + return true; + } + + bool IsValid(size_t x, size_t y) { return row_[y * stride_ + x] != INVALID; } + + AcStrategyRow ConstRow(size_t y, size_t x_prefix = 0) const { + return AcStrategyRow(layers_.ConstRow(y) + x_prefix); + } + + AcStrategyRow ConstRow(const Rect& rect, size_t y) const { + return ConstRow(rect.y0() + y, rect.x0()); + } + + size_t PixelsPerRow() const { return layers_.PixelsPerRow(); } + + size_t xsize() const { return layers_.xsize(); } + size_t ysize() const { return layers_.ysize(); } + + // Count the number of blocks of a given type. + size_t CountBlocks(AcStrategy::Type type) const; + + private: + ImageB layers_; + uint8_t* JXL_RESTRICT row_; + size_t stride_; + + // A value that does not represent a valid combined AC strategy + // value. Used as a sentinel. + static constexpr uint8_t INVALID = 0xFF; +}; + +} // namespace jxl + +#endif // LIB_JXL_AC_STRATEGY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/ac_strategy_test.cc b/third_party/jpeg-xl/lib/jxl/ac_strategy_test.cc new file mode 100644 index 0000000000..3745db2b32 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ac_strategy_test.cc @@ -0,0 +1,256 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/ac_strategy.h" + +#include <string.h> + +#include <cmath> +#include <hwy/aligned_allocator.h> +#include <hwy/base.h> // HWY_ALIGN_MAX +#include <hwy/tests/hwy_gtest.h> +#include <utility> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_transforms_testonly.h" +#include "lib/jxl/enc_transforms.h" +#include "lib/jxl/simd_util.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +// Test that DCT -> IDCT is a noop. +class AcStrategyRoundtrip : public ::hwy::TestWithParamTargetAndT<int> { + protected: + void Run() { + const AcStrategy::Type type = static_cast<AcStrategy::Type>(GetParam()); + const AcStrategy acs = AcStrategy::FromRawStrategy(type); + const size_t dct_scratch_size = + 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; + + auto mem = hwy::AllocateAligned<float>(4 * AcStrategy::kMaxCoeffArea + + dct_scratch_size); + float* coeffs = mem.get(); + float* idct = coeffs + AcStrategy::kMaxCoeffArea; + float* input = idct + AcStrategy::kMaxCoeffArea; + float* scratch_space = input + AcStrategy::kMaxCoeffArea; + + Rng rng(type * 65537 + 13); + + for (size_t j = 0; j < 64; j++) { + size_t i = (acs.log2_covered_blocks() + ? rng.UniformU(0, 64u << acs.log2_covered_blocks()) + : j); + std::fill_n(input, AcStrategy::kMaxCoeffArea, 0); + input[i] = 0.2f; + TransformFromPixels(type, input, acs.covered_blocks_x() * 8, coeffs, + scratch_space); + ASSERT_NEAR(coeffs[0], 0.2 / (64 << acs.log2_covered_blocks()), 1e-6) + << " i = " << i; + TransformToPixels(type, coeffs, idct, acs.covered_blocks_x() * 8, + scratch_space); + for (size_t j = 0; j < 64u << acs.log2_covered_blocks(); j++) { + ASSERT_NEAR(idct[j], j == i ? 0.2f : 0, 2e-6) + << "j = " << j << " i = " << i << " acs " << type; + } + } + // Test DC. + std::fill_n(idct, AcStrategy::kMaxCoeffArea, 0); + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + float* dc = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2; + LowestFrequenciesFromDC(type, dc, acs.covered_blocks_x() * 8, coeffs, + scratch_space); + DCFromLowestFrequencies(type, coeffs, idct, acs.covered_blocks_x() * 8); + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2; + for (size_t j = 0; j < 64u << acs.log2_covered_blocks(); j++) { + ASSERT_NEAR(idct[j], dc[j], 1e-6) + << "j = " << j << " x = " << x << " y = " << y << " acs " << type; + } + } + } + } +}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T( + AcStrategyRoundtrip, + ::testing::Range(0, int(AcStrategy::Type::kNumValidStrategies))); + +TEST_P(AcStrategyRoundtrip, Test) { Run(); } + +// Test that DC(2x2) -> DCT coefficients -> IDCT -> downsampled IDCT is a noop. +class AcStrategyRoundtripDownsample + : public ::hwy::TestWithParamTargetAndT<int> { + protected: + void Run() { + const AcStrategy::Type type = static_cast<AcStrategy::Type>(GetParam()); + const AcStrategy acs = AcStrategy::FromRawStrategy(type); + const size_t dct_scratch_size = + 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; + + auto mem = hwy::AllocateAligned<float>(4 * AcStrategy::kMaxCoeffArea + + dct_scratch_size); + float* coeffs = mem.get(); + float* idct = coeffs + AcStrategy::kMaxCoeffArea; + float* dc = idct + AcStrategy::kMaxCoeffArea; + float* scratch_space = dc + AcStrategy::kMaxCoeffArea; + + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0.0f); + Rng rng(type * 65537 + 13); + + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + if (x > 4 || y > 4) { + if (rng.Bernoulli(0.9f)) continue; + } + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2f; + LowestFrequenciesFromDC(type, dc, acs.covered_blocks_x() * 8, coeffs, + scratch_space); + TransformToPixels(type, coeffs, idct, acs.covered_blocks_x() * 8, + scratch_space); + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0.0f); + std::fill_n(dc, AcStrategy::kMaxCoeffArea, 0); + dc[y * acs.covered_blocks_x() * 8 + x] = 0.2f; + // Downsample + for (size_t dy = 0; dy < acs.covered_blocks_y(); dy++) { + for (size_t dx = 0; dx < acs.covered_blocks_x(); dx++) { + float sum = 0; + for (size_t iy = 0; iy < 8; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + sum += idct[(dy * 8 + iy) * 8 * acs.covered_blocks_x() + + dx * 8 + ix]; + } + } + sum /= 64.0f; + ASSERT_NEAR(sum, dc[dy * 8 * acs.covered_blocks_x() + dx], 1e-6) + << "acs " << type; + } + } + } + } + } +}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T( + AcStrategyRoundtripDownsample, + ::testing::Range(0, int(AcStrategy::Type::kNumValidStrategies))); + +TEST_P(AcStrategyRoundtripDownsample, Test) { Run(); } + +// Test that IDCT(block with zeros in the non-topleft corner) -> downsampled +// IDCT is the same as IDCT -> DC(2x2) of the same block. +class AcStrategyDownsample : public ::hwy::TestWithParamTargetAndT<int> { + protected: + void Run() { + const AcStrategy::Type type = static_cast<AcStrategy::Type>(GetParam()); + const AcStrategy acs = AcStrategy::FromRawStrategy(type); + const size_t dct_scratch_size = + 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; + size_t cx = acs.covered_blocks_y(); + size_t cy = acs.covered_blocks_x(); + CoefficientLayout(&cy, &cx); + + auto mem = hwy::AllocateAligned<float>(4 * AcStrategy::kMaxCoeffArea + + dct_scratch_size); + float* idct = mem.get(); + float* idct_acs_downsampled = idct + AcStrategy::kMaxCoeffArea; + float* coeffs = idct + AcStrategy::kMaxCoeffArea; + float* scratch_space = coeffs + AcStrategy::kMaxCoeffArea; + + Rng rng(type * 65537 + 13); + + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx; x++) { + if (x > 4 || y > 4) { + if (rng.Bernoulli(0.9f)) continue; + } + float* coeffs = idct + AcStrategy::kMaxCoeffArea; + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0); + coeffs[y * cx * 8 + x] = 0.2f; + TransformToPixels(type, coeffs, idct, acs.covered_blocks_x() * 8, + scratch_space); + std::fill_n(coeffs, AcStrategy::kMaxCoeffArea, 0); + coeffs[y * cx * 8 + x] = 0.2f; + DCFromLowestFrequencies(type, coeffs, idct_acs_downsampled, + acs.covered_blocks_x() * 8); + // Downsample + for (size_t dy = 0; dy < acs.covered_blocks_y(); dy++) { + for (size_t dx = 0; dx < acs.covered_blocks_x(); dx++) { + float sum = 0; + for (size_t iy = 0; iy < 8; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + sum += idct[(dy * 8 + iy) * 8 * acs.covered_blocks_x() + + dx * 8 + ix]; + } + } + sum /= 64; + ASSERT_NEAR( + sum, idct_acs_downsampled[dy * 8 * acs.covered_blocks_x() + dx], + 1e-6) + << " acs " << type; + } + } + } + } + } +}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T( + AcStrategyDownsample, + ::testing::Range(0, int(AcStrategy::Type::kNumValidStrategies))); + +TEST_P(AcStrategyDownsample, Test) { Run(); } + +class AcStrategyTargetTest : public ::hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(AcStrategyTargetTest); + +TEST_P(AcStrategyTargetTest, RoundtripAFVDCT) { + HWY_ALIGN_MAX float idct[16]; + for (size_t i = 0; i < 16; i++) { + HWY_ALIGN_MAX float pixels[16] = {}; + pixels[i] = 1; + HWY_ALIGN_MAX float coeffs[16] = {}; + + AFVDCT4x4(pixels, coeffs); + AFVIDCT4x4(coeffs, idct); + for (size_t j = 0; j < 16; j++) { + EXPECT_NEAR(idct[j], pixels[j], 1e-6); + } + } +} + +TEST_P(AcStrategyTargetTest, BenchmarkAFV) { + const AcStrategy::Type type = AcStrategy::Type::AFV0; + HWY_ALIGN_MAX float pixels[64] = {1}; + HWY_ALIGN_MAX float coeffs[64] = {}; + const size_t dct_scratch_size = + 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; + auto mem = hwy::AllocateAligned<float>(64 + dct_scratch_size); + float* scratch_space = mem.get(); + for (size_t i = 0; i < 1 << 14; i++) { + TransformToPixels(type, coeffs, pixels, 8, scratch_space); + TransformFromPixels(type, pixels, 8, coeffs, scratch_space); + } + EXPECT_NEAR(pixels[0], 0.0, 1E-6); +} + +TEST_P(AcStrategyTargetTest, BenchmarkAFVDCT) { + HWY_ALIGN_MAX float pixels[64] = {1}; + HWY_ALIGN_MAX float coeffs[64] = {}; + for (size_t i = 0; i < 1 << 14; i++) { + AFVDCT4x4(pixels, coeffs); + AFVIDCT4x4(coeffs, pixels); + } + EXPECT_NEAR(pixels[0], 1.0, 1E-6); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/alpha.cc b/third_party/jpeg-xl/lib/jxl/alpha.cc new file mode 100644 index 0000000000..48d7e7ee92 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/alpha.cc @@ -0,0 +1,115 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/alpha.h" + +#include <string.h> + +#include <algorithm> + +namespace jxl { + +static float Clamp(float x) { return std::max(std::min(1.0f, x), 0.0f); } + +void PerformAlphaBlending(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, size_t num_pixels, + bool alpha_is_premultiplied, bool clamp) { + if (alpha_is_premultiplied) { + for (size_t x = 0; x < num_pixels; ++x) { + float fga = clamp ? Clamp(fg.a[x]) : fg.a[x]; + out.r[x] = (fg.r[x] + bg.r[x] * (1.f - fga)); + out.g[x] = (fg.g[x] + bg.g[x] * (1.f - fga)); + out.b[x] = (fg.b[x] + bg.b[x] * (1.f - fga)); + out.a[x] = (1.f - (1.f - fga) * (1.f - bg.a[x])); + } + } else { + for (size_t x = 0; x < num_pixels; ++x) { + float fga = clamp ? Clamp(fg.a[x]) : fg.a[x]; + const float new_a = 1.f - (1.f - fga) * (1.f - bg.a[x]); + const float rnew_a = (new_a > 0 ? 1.f / new_a : 0.f); + out.r[x] = (fg.r[x] * fga + bg.r[x] * bg.a[x] * (1.f - fga)) * rnew_a; + out.g[x] = (fg.g[x] * fga + bg.g[x] * bg.a[x] * (1.f - fga)) * rnew_a; + out.b[x] = (fg.b[x] * fga + bg.b[x] * bg.a[x] * (1.f - fga)) * rnew_a; + out.a[x] = new_a; + } + } +} +void PerformAlphaBlending(const float* bg, const float* bga, const float* fg, + const float* fga, float* out, size_t num_pixels, + bool alpha_is_premultiplied, bool clamp) { + if (bg == bga && fg == fga) { + for (size_t x = 0; x < num_pixels; ++x) { + float fa = clamp ? fga[x] : Clamp(fga[x]); + out[x] = (1.f - (1.f - fa) * (1.f - bga[x])); + } + } else { + if (alpha_is_premultiplied) { + for (size_t x = 0; x < num_pixels; ++x) { + float fa = clamp ? fga[x] : Clamp(fga[x]); + out[x] = (fg[x] + bg[x] * (1.f - fa)); + } + } else { + for (size_t x = 0; x < num_pixels; ++x) { + float fa = clamp ? fga[x] : Clamp(fga[x]); + const float new_a = 1.f - (1.f - fa) * (1.f - bga[x]); + const float rnew_a = (new_a > 0 ? 1.f / new_a : 0.f); + out[x] = (fg[x] * fa + bg[x] * bga[x] * (1.f - fa)) * rnew_a; + } + } + } +} + +void PerformAlphaWeightedAdd(const float* bg, const float* fg, const float* fga, + float* out, size_t num_pixels, bool clamp) { + if (fg == fga) { + memcpy(out, bg, num_pixels * sizeof(*out)); + } else if (clamp) { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = bg[x] + fg[x] * Clamp(fga[x]); + } + } else { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = bg[x] + fg[x] * fga[x]; + } + } +} + +void PerformMulBlending(const float* bg, const float* fg, float* out, + size_t num_pixels, bool clamp) { + if (clamp) { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = bg[x] * Clamp(fg[x]); + } + } else { + for (size_t x = 0; x < num_pixels; ++x) { + out[x] = bg[x] * fg[x]; + } + } +} + +void PremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + const float multiplier = std::max(kSmallAlpha, a[x]); + r[x] *= multiplier; + g[x] *= multiplier; + b[x] *= multiplier; + } +} + +void UnpremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels) { + for (size_t x = 0; x < num_pixels; ++x) { + const float multiplier = 1.f / std::max(kSmallAlpha, a[x]); + r[x] *= multiplier; + g[x] *= multiplier; + b[x] *= multiplier; + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/alpha.h b/third_party/jpeg-xl/lib/jxl/alpha.h new file mode 100644 index 0000000000..efb76c800f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/alpha.h @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ALPHA_H_ +#define LIB_JXL_ALPHA_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <limits> + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// A very small value to avoid divisions by zero when converting to +// unpremultiplied alpha. Page 21 of the technical introduction to OpenEXR +// (https://www.openexr.com/documentation/TechnicalIntroduction.pdf) recommends +// "a power of two" that is "less than half of the smallest positive 16-bit +// floating-point value". That smallest value happens to be the denormal number +// 2^-24, so 2^-26 should be a good choice. +static constexpr float kSmallAlpha = 1.f / (1u << 26u); + +struct AlphaBlendingInputLayer { + const float* r; + const float* g; + const float* b; + const float* a; +}; + +struct AlphaBlendingOutput { + float* r; + float* g; + float* b; + float* a; +}; + +// Note: The pointers in `out` are allowed to alias those in `bg` or `fg`. +// No pointer shall be null. +void PerformAlphaBlending(const AlphaBlendingInputLayer& bg, + const AlphaBlendingInputLayer& fg, + const AlphaBlendingOutput& out, size_t num_pixels, + bool alpha_is_premultiplied, bool clamp); +// Single plane alpha blending +void PerformAlphaBlending(const float* bg, const float* bga, const float* fg, + const float* fga, float* out, size_t num_pixels, + bool alpha_is_premultiplied, bool clamp); + +void PerformAlphaWeightedAdd(const float* bg, const float* fg, const float* fga, + float* out, size_t num_pixels, bool clamp); + +void PerformMulBlending(const float* bg, const float* fg, float* out, + size_t num_pixels, bool clamp); + +void PremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels); +void UnpremultiplyAlpha(float* JXL_RESTRICT r, float* JXL_RESTRICT g, + float* JXL_RESTRICT b, const float* JXL_RESTRICT a, + size_t num_pixels); + +} // namespace jxl + +#endif // LIB_JXL_ALPHA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/alpha_test.cc b/third_party/jpeg-xl/lib/jxl/alpha_test.cc new file mode 100644 index 0000000000..ddafd829ec --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/alpha_test.cc @@ -0,0 +1,134 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/alpha.h" + +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::FloatNear; + +TEST(AlphaTest, BlendingWithNonPremultiplied) { + const float bg_rgb[3] = {100, 110, 120}; + const float bg_a = 180.f / 255; + const float fg_rgb[3] = {25, 21, 23}; + const float fg_a = 15420.f / 65535; + const float fg_a2 = 2.0f; + float out_rgb[3]; + float out_a; + PerformAlphaBlending( + /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a}, + /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a}, + /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1, + /*alpha_is_premultiplied=*/false, /*clamp=*/false); + EXPECT_THAT(out_rgb, + ElementsAre(FloatNear(77.2f, .05f), FloatNear(83.0f, .05f), + FloatNear(90.6f, .05f))); + EXPECT_NEAR(out_a, 3174.f / 4095, 1e-5); + PerformAlphaBlending( + /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a}, + /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a2}, + /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1, + /*alpha_is_premultiplied=*/false, /*clamp=*/true); + EXPECT_THAT(out_rgb, ElementsAre(FloatNear(fg_rgb[0], .05f), + FloatNear(fg_rgb[1], .05f), + FloatNear(fg_rgb[2], .05f))); + EXPECT_NEAR(out_a, 1.0f, 1e-5); +} + +TEST(AlphaTest, BlendingWithPremultiplied) { + const float bg_rgb[3] = {100, 110, 120}; + const float bg_a = 180.f / 255; + const float fg_rgb[3] = {25, 21, 23}; + const float fg_a = 15420.f / 65535; + const float fg_a2 = 2.0f; + float out_rgb[3]; + float out_a; + PerformAlphaBlending( + /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a}, + /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a}, + /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1, + /*alpha_is_premultiplied=*/true, /*clamp=*/false); + EXPECT_THAT(out_rgb, + ElementsAre(FloatNear(101.5f, .05f), FloatNear(105.1f, .05f), + FloatNear(114.8f, .05f))); + EXPECT_NEAR(out_a, 3174.f / 4095, 1e-5); + PerformAlphaBlending( + /*bg=*/{&bg_rgb[0], &bg_rgb[1], &bg_rgb[2], &bg_a}, + /*fg=*/{&fg_rgb[0], &fg_rgb[1], &fg_rgb[2], &fg_a2}, + /*out=*/{&out_rgb[0], &out_rgb[1], &out_rgb[2], &out_a}, 1, + /*alpha_is_premultiplied=*/true, /*clamp=*/true); + EXPECT_THAT(out_rgb, ElementsAre(FloatNear(fg_rgb[0], .05f), + FloatNear(fg_rgb[1], .05f), + FloatNear(fg_rgb[2], .05f))); + EXPECT_NEAR(out_a, 1.0f, 1e-5); +} + +TEST(AlphaTest, Mul) { + const float bg = 100; + const float fg = 25; + float out; + PerformMulBlending(&bg, &fg, &out, 1, /*clamp=*/false); + EXPECT_THAT(out, FloatNear(fg * bg, .05f)); + PerformMulBlending(&bg, &fg, &out, 1, /*clamp=*/true); + EXPECT_THAT(out, FloatNear(bg, .05f)); +} + +TEST(AlphaTest, PremultiplyAndUnpremultiply) { + const float alpha[] = {0.f, 63.f / 255, 127.f / 255, 1.f}; + float r[] = {120, 130, 140, 150}; + float g[] = {124, 134, 144, 154}; + float b[] = {127, 137, 147, 157}; + + PremultiplyAlpha(r, g, b, alpha, 4); + EXPECT_THAT( + r, ElementsAre(FloatNear(0.f, 1e-5f), FloatNear(130 * 63.f / 255, 1e-5f), + FloatNear(140 * 127.f / 255, 1e-5f), 150)); + EXPECT_THAT( + g, ElementsAre(FloatNear(0.f, 1e-5f), FloatNear(134 * 63.f / 255, 1e-5f), + FloatNear(144 * 127.f / 255, 1e-5f), 154)); + EXPECT_THAT( + b, ElementsAre(FloatNear(0.f, 1e-5f), FloatNear(137 * 63.f / 255, 1e-5f), + FloatNear(147 * 127.f / 255, 1e-5f), 157)); + + UnpremultiplyAlpha(r, g, b, alpha, 4); + EXPECT_THAT(r, ElementsAre(FloatNear(120, 1e-4f), FloatNear(130, 1e-4f), + FloatNear(140, 1e-4f), FloatNear(150, 1e-4f))); + EXPECT_THAT(g, ElementsAre(FloatNear(124, 1e-4f), FloatNear(134, 1e-4f), + FloatNear(144, 1e-4f), FloatNear(154, 1e-4f))); + EXPECT_THAT(b, ElementsAre(FloatNear(127, 1e-4f), FloatNear(137, 1e-4f), + FloatNear(147, 1e-4f), FloatNear(157, 1e-4f))); +} + +TEST(AlphaTest, UnpremultiplyAndPremultiply) { + const float alpha[] = {0.f, 63.f / 255, 127.f / 255, 1.f}; + float r[] = {50, 60, 70, 80}; + float g[] = {54, 64, 74, 84}; + float b[] = {57, 67, 77, 87}; + + UnpremultiplyAlpha(r, g, b, alpha, 4); + EXPECT_THAT(r, ElementsAre(_, FloatNear(60 * 255.f / 63, 1e-4f), + FloatNear(70 * 255.f / 127, 1e-4f), 80)); + EXPECT_THAT(g, ElementsAre(_, FloatNear(64 * 255.f / 63, 1e-4f), + FloatNear(74 * 255.f / 127, 1e-4f), 84)); + EXPECT_THAT(b, ElementsAre(_, FloatNear(67 * 255.f / 63, 1e-4f), + FloatNear(77 * 255.f / 127, 1e-4f), 87)); + + PremultiplyAlpha(r, g, b, alpha, 4); + EXPECT_THAT(r, ElementsAre(FloatNear(50, 1e-4f), FloatNear(60, 1e-4f), + FloatNear(70, 1e-4f), FloatNear(80, 1e-4f))); + EXPECT_THAT(g, ElementsAre(FloatNear(54, 1e-4f), FloatNear(64, 1e-4f), + FloatNear(74, 1e-4f), FloatNear(84, 1e-4f))); + EXPECT_THAT(b, ElementsAre(FloatNear(57, 1e-4f), FloatNear(67, 1e-4f), + FloatNear(77, 1e-4f), FloatNear(87, 1e-4f))); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/ans_common.cc b/third_party/jpeg-xl/lib/jxl/ans_common.cc new file mode 100644 index 0000000000..8e52cad0e8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ans_common.cc @@ -0,0 +1,148 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/ans_common.h" + +#include <numeric> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +std::vector<int32_t> CreateFlatHistogram(int length, int total_count) { + JXL_ASSERT(length > 0); + JXL_ASSERT(length <= total_count); + const int count = total_count / length; + std::vector<int32_t> result(length, count); + const int rem_counts = total_count % length; + for (int i = 0; i < rem_counts; ++i) { + ++result[i]; + } + return result; +} + +// First, all trailing non-occurring symbols are removed from the distribution; +// if this leaves the distribution empty, a placeholder symbol with max weight +// is added. This ensures that the resulting distribution sums to total table +// size. Then, `entry_size` is chosen to be the largest power of two so that +// `table_size` = ANS_TAB_SIZE/`entry_size` is at least as big as the +// distribution size. +// Note that each entry will only ever contain two different symbols, and +// consecutive ranges of offsets, which allows us to use a compact +// representation. +// Each entry is initialized with only the (symbol=i, offset) pairs; then +// positions for which the entry overflows (i.e. distribution[i] > entry_size) +// or is not full are computed, and put into a stack in increasing order. +// Missing symbols in the distribution are padded with 0 (because `table_size` +// >= number of symbols). The `cutoff` value for each entry is initialized to +// the number of occupied slots in that entry (i.e. `distributions[i]`). While +// the overflowing-symbol stack is not empty (which implies that the +// underflowing-symbol stack also is not), the top overfull and underfull +// positions are popped from the stack; the empty slots in the underfull entry +// are then filled with as many slots as needed from the overfull entry; such +// slots are placed after the slots in the overfull entry, and `offsets[1]` is +// computed accordingly. The formerly underfull entry is thus now neither +// underfull nor overfull, and represents exactly two symbols. The overfull +// entry might be either overfull or underfull, and is pushed into the +// corresponding stack. +void InitAliasTable(std::vector<int32_t> distribution, uint32_t range, + size_t log_alpha_size, AliasTable::Entry* JXL_RESTRICT a) { + while (!distribution.empty() && distribution.back() == 0) { + distribution.pop_back(); + } + // Ensure that a valid table is always returned, even for an empty + // alphabet. Otherwise, a specially-crafted stream might crash the + // decoder. + if (distribution.empty()) { + distribution.emplace_back(range); + } + const size_t table_size = 1 << log_alpha_size; +#if JXL_ENABLE_ASSERT + int sum = std::accumulate(distribution.begin(), distribution.end(), 0); +#endif // JXL_ENABLE_ASSERT + JXL_ASSERT(static_cast<uint32_t>(sum) == range); + // range must be a power of two + JXL_ASSERT((range & (range - 1)) == 0); + JXL_ASSERT(distribution.size() <= table_size); + JXL_ASSERT(table_size <= range); + const uint32_t entry_size = range >> log_alpha_size; // this is exact + // Special case for single-symbol distributions, that ensures that the state + // does not change when decoding from such a distribution. Note that, since we + // hardcode offset0 == 0, it is not straightforward (if at all possible) to + // fix the general case to produce this result. + for (size_t sym = 0; sym < distribution.size(); sym++) { + if (distribution[sym] == ANS_TAB_SIZE) { + for (size_t i = 0; i < table_size; i++) { + a[i].right_value = sym; + a[i].cutoff = 0; + a[i].offsets1 = entry_size * i; + a[i].freq0 = 0; + a[i].freq1_xor_freq0 = ANS_TAB_SIZE; + } + return; + } + } + std::vector<uint32_t> underfull_posn; + std::vector<uint32_t> overfull_posn; + std::vector<uint32_t> cutoffs(1 << log_alpha_size); + // Initialize entries. + for (size_t i = 0; i < distribution.size(); i++) { + cutoffs[i] = distribution[i]; + if (cutoffs[i] > entry_size) { + overfull_posn.push_back(i); + } else if (cutoffs[i] < entry_size) { + underfull_posn.push_back(i); + } + } + for (uint32_t i = distribution.size(); i < table_size; i++) { + cutoffs[i] = 0; + underfull_posn.push_back(i); + } + // Reassign overflow/underflow values. + while (!overfull_posn.empty()) { + uint32_t overfull_i = overfull_posn.back(); + overfull_posn.pop_back(); + JXL_ASSERT(!underfull_posn.empty()); + uint32_t underfull_i = underfull_posn.back(); + underfull_posn.pop_back(); + uint32_t underfull_by = entry_size - cutoffs[underfull_i]; + cutoffs[overfull_i] -= underfull_by; + // overfull positions have their original symbols + a[underfull_i].right_value = overfull_i; + a[underfull_i].offsets1 = cutoffs[overfull_i]; + // Slots in the right part of entry underfull_i were taken from the end + // of the symbols in entry overfull_i. + if (cutoffs[overfull_i] < entry_size) { + underfull_posn.push_back(overfull_i); + } else if (cutoffs[overfull_i] > entry_size) { + overfull_posn.push_back(overfull_i); + } + } + for (uint32_t i = 0; i < table_size; i++) { + // cutoffs[i] is properly initialized but the clang-analyzer doesn't infer + // it since it is partially initialized across two for-loops. + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + if (cutoffs[i] == entry_size) { + a[i].right_value = i; + a[i].offsets1 = 0; + a[i].cutoff = 0; + } else { + // Note that, if cutoff is not equal to entry_size, + // a[i].offsets1 was initialized with (overfull cutoff) - + // (entry_size - a[i].cutoff). Thus, subtracting + // a[i].cutoff cannot make it negative. + a[i].offsets1 -= cutoffs[i]; + a[i].cutoff = cutoffs[i]; + } + const size_t freq0 = i < distribution.size() ? distribution[i] : 0; + const size_t i1 = a[i].right_value; + const size_t freq1 = i1 < distribution.size() ? distribution[i1] : 0; + a[i].freq0 = static_cast<uint16_t>(freq0); + a[i].freq1_xor_freq0 = static_cast<uint16_t>(freq1 ^ freq0); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/ans_common.h b/third_party/jpeg-xl/lib/jxl/ans_common.h new file mode 100644 index 0000000000..fb5058e310 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ans_common.h @@ -0,0 +1,143 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ANS_COMMON_H_ +#define LIB_JXL_ANS_COMMON_H_ + +#include <stdint.h> + +#include <algorithm> +#include <hwy/cache_control.h> // Prefetch +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Returns the precision (number of bits) that should be used to store +// a histogram count such that Log2Floor(count) == logcount. +static JXL_INLINE uint32_t GetPopulationCountPrecision(uint32_t logcount, + uint32_t shift) { + int32_t r = std::min<int>( + logcount, int(shift) - int((ANS_LOG_TAB_SIZE - logcount) >> 1)); + if (r < 0) return 0; + return r; +} + +// Returns a histogram where the counts are positive, differ by at most 1, +// and add up to total_count. The bigger counts (if any) are at the beginning +// of the histogram. +std::vector<int32_t> CreateFlatHistogram(int length, int total_count); + +// An alias table implements a mapping from the [0, ANS_TAB_SIZE) range into +// the [0, ANS_MAX_ALPHABET_SIZE) range, satisfying the following conditions: +// - each symbol occurs as many times as specified by any valid distribution +// of frequencies of the symbols. A valid distribution here is an array of +// ANS_MAX_ALPHABET_SIZE that contains numbers in the range [0, ANS_TAB_SIZE], +// and whose sum is ANS_TAB_SIZE. +// - lookups can be done in constant time, and also return how many smaller +// input values map into the same symbol, according to some well-defined order +// of input values. +// - the space used by the alias table is given by a small constant times the +// index of the largest symbol with nonzero probability in the distribution. +// Each of the entries in the table covers a range of `entry_size` values in the +// [0, ANS_TAB_SIZE) range; consecutive entries represent consecutive +// sub-ranges. In the range covered by entry `i`, the first `cutoff` values map +// to symbol `i`, while the others map to symbol `right_value`. +// +// TODO(veluca): consider making the order used for computing offsets easier to +// define - it is currently defined by the algorithm to compute the alias table. +// Beware of breaking the implicit assumption that symbols that come after the +// cutoff value should have an offset at least as big as the cutoff. + +struct AliasTable { + struct Symbol { + size_t value; + size_t offset; + size_t freq; + }; + +// Working set size matters here (~64 tables x 256 entries). +// offsets0 is always zero (beginning of [0] side among the same symbol). +// offsets1 is an offset of (pos >= cutoff) side decremented by cutoff. +#pragma pack(push, 1) + struct Entry { + uint8_t cutoff; // < kEntrySizeMinus1 when used by ANS. + uint8_t right_value; // < alphabet size. + uint16_t freq0; + + // Only used if `greater` (see Lookup) + uint16_t offsets1; // <= ANS_TAB_SIZE + uint16_t freq1_xor_freq0; // for branchless ternary in Lookup + }; +#pragma pack(pop) + + // Dividing `value` by `entry_size` determines `i`, the entry which is + // responsible for the input. If the remainder is below `cutoff`, then the + // mapped symbol is `i`; since `offsets[0]` stores the number of occurrences + // of `i` "before" the start of this entry, the offset of the input will be + // `offsets[0] + remainder`. If the remainder is above cutoff, the mapped + // symbol is `right_value`; since `offsets[1]` stores the number of + // occurrences of `right_value` "before" this entry, minus the `cutoff` value, + // the input offset is then `remainder + offsets[1]`. + static JXL_INLINE Symbol Lookup(const Entry* JXL_RESTRICT table, size_t value, + size_t log_entry_size, + size_t entry_size_minus_1) { + const size_t i = value >> log_entry_size; + const size_t pos = value & entry_size_minus_1; + +#if JXL_BYTE_ORDER_LITTLE + uint64_t entry; + memcpy(&entry, &table[i].cutoff, sizeof(entry)); + const size_t cutoff = entry & 0xFF; // = MOVZX + const size_t right_value = (entry >> 8) & 0xFF; // = MOVZX + const size_t freq0 = (entry >> 16) & 0xFFFF; +#else + // Generates multiple loads with complex addressing. + const size_t cutoff = table[i].cutoff; + const size_t right_value = table[i].right_value; + const size_t freq0 = table[i].freq0; +#endif + + const bool greater = pos >= cutoff; + +#if JXL_BYTE_ORDER_LITTLE + const uint64_t conditional = greater ? entry : 0; // = CMOV + const size_t offsets1_or_0 = (conditional >> 32) & 0xFFFF; + const size_t freq1_xor_freq0_or_0 = conditional >> 48; +#else + const size_t offsets1_or_0 = greater ? table[i].offsets1 : 0; + const size_t freq1_xor_freq0_or_0 = greater ? table[i].freq1_xor_freq0 : 0; +#endif + + // WARNING: moving this code may interfere with CMOV heuristics. + Symbol s; + s.value = greater ? right_value : i; + s.offset = offsets1_or_0 + pos; + s.freq = freq0 ^ freq1_xor_freq0_or_0; // = greater ? freq1 : freq0 + // XOR avoids implementation-defined conversion from unsigned to signed. + // Alternatives considered: BEXTR is 2 cycles on HSW, SET+shift causes + // spills, simple ternary has a long dependency chain. + + return s; + } + + static HWY_INLINE void Prefetch(const Entry* JXL_RESTRICT table, size_t value, + size_t log_entry_size) { + const size_t i = value >> log_entry_size; + hwy::Prefetch(table + i); + } +}; + +// Computes an alias table for a given distribution. +void InitAliasTable(std::vector<int32_t> distribution, uint32_t range, + size_t log_alpha_size, AliasTable::Entry* JXL_RESTRICT a); + +} // namespace jxl + +#endif // LIB_JXL_ANS_COMMON_H_ diff --git a/third_party/jpeg-xl/lib/jxl/ans_common_test.cc b/third_party/jpeg-xl/lib/jxl/ans_common_test.cc new file mode 100644 index 0000000000..487b6cf5bd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ans_common_test.cc @@ -0,0 +1,43 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/ans_common.h" + +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +void VerifyAliasDistribution(const std::vector<int>& distribution, + uint32_t range) { + constexpr size_t log_alpha_size = 8; + AliasTable::Entry table[1 << log_alpha_size]; + InitAliasTable(distribution, range, log_alpha_size, table); + std::vector<std::vector<uint32_t>> offsets(distribution.size()); + for (uint32_t i = 0; i < range; i++) { + AliasTable::Symbol s = AliasTable::Lookup( + table, i, ANS_LOG_TAB_SIZE - 8, (1 << (ANS_LOG_TAB_SIZE - 8)) - 1); + offsets[s.value].push_back(s.offset); + } + for (uint32_t i = 0; i < distribution.size(); i++) { + ASSERT_EQ(static_cast<size_t>(distribution[i]), offsets[i].size()); + std::sort(offsets[i].begin(), offsets[i].end()); + for (uint32_t j = 0; j < offsets[i].size(); j++) { + ASSERT_EQ(offsets[i][j], j); + } + } +} + +TEST(ANSCommonTest, AliasDistributionSmoke) { + VerifyAliasDistribution({ANS_TAB_SIZE / 2, ANS_TAB_SIZE / 2}, ANS_TAB_SIZE); + VerifyAliasDistribution({ANS_TAB_SIZE}, ANS_TAB_SIZE); + VerifyAliasDistribution({0, 0, 0, ANS_TAB_SIZE, 0}, ANS_TAB_SIZE); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/ans_params.h b/third_party/jpeg-xl/lib/jxl/ans_params.h new file mode 100644 index 0000000000..4bbc284c0b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ans_params.h @@ -0,0 +1,36 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ANS_PARAMS_H_ +#define LIB_JXL_ANS_PARAMS_H_ + +// Common parameters that are needed for both the ANS entropy encoding and +// decoding methods. + +#include <stdint.h> +#include <stdlib.h> + +namespace jxl { + +// TODO(veluca): decide if 12 is the best constant here (valid range is up to +// 16). This requires recomputing the Huffman tables in {enc,dec}_ans.cc +// 14 gives a 0.2% improvement at d1 and makes d8 slightly worse. This is +// likely not worth the increase in encoder complexity. +#define ANS_LOG_TAB_SIZE 12u +#define ANS_TAB_SIZE (1 << ANS_LOG_TAB_SIZE) +#define ANS_TAB_MASK (ANS_TAB_SIZE - 1) + +// Largest possible symbol to be encoded by either ANS or prefix coding. +#define PREFIX_MAX_ALPHABET_SIZE 4096 +#define ANS_MAX_ALPHABET_SIZE 256 + +// Max number of bits for prefix coding. +#define PREFIX_MAX_BITS 15 + +#define ANS_SIGNATURE 0x13 // Initial state, used as CRC. + +} // namespace jxl + +#endif // LIB_JXL_ANS_PARAMS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/ans_test.cc b/third_party/jpeg-xl/lib/jxl/ans_test.cc new file mode 100644 index 0000000000..c28daf7b85 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/ans_test.cc @@ -0,0 +1,279 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +void RoundtripTestcase(int n_histograms, int alphabet_size, + const std::vector<Token>& input_values) { + constexpr uint16_t kMagic1 = 0x9e33; + constexpr uint16_t kMagic2 = 0x8b04; + + BitWriter writer; + // Space for magic bytes. + BitWriter::Allotment allotment_magic1(&writer, 16); + writer.Write(16, kMagic1); + allotment_magic1.ReclaimAndCharge(&writer, 0, nullptr); + + std::vector<uint8_t> context_map; + EntropyEncodingData codes; + std::vector<std::vector<Token>> input_values_vec; + input_values_vec.push_back(input_values); + + BuildAndEncodeHistograms(HistogramParams(), n_histograms, input_values_vec, + &codes, &context_map, &writer, 0, nullptr); + WriteTokens(input_values_vec[0], codes, context_map, 0, &writer, 0, nullptr); + + // Magic bytes + padding + BitWriter::Allotment allotment_magic2(&writer, 24); + writer.Write(16, kMagic2); + writer.ZeroPadToByte(); + allotment_magic2.ReclaimAndCharge(&writer, 0, nullptr); + + // We do not truncate the output. Reading past the end reads out zeroes + // anyway. + BitReader br(writer.GetSpan()); + + ASSERT_EQ(br.ReadBits(16), kMagic1); + + std::vector<uint8_t> dec_context_map; + ANSCode decoded_codes; + ASSERT_TRUE( + DecodeHistograms(&br, n_histograms, &decoded_codes, &dec_context_map)); + ASSERT_EQ(dec_context_map, context_map); + ANSSymbolReader reader(&decoded_codes, &br); + + for (const Token& symbol : input_values) { + uint32_t read_symbol = + reader.ReadHybridUint(symbol.context, &br, dec_context_map); + ASSERT_EQ(read_symbol, symbol.value); + } + ASSERT_TRUE(reader.CheckANSFinalState()); + + ASSERT_EQ(br.ReadBits(16), kMagic2); + EXPECT_TRUE(br.Close()); +} + +TEST(ANSTest, EmptyRoundtrip) { + RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, std::vector<Token>()); +} + +TEST(ANSTest, SingleSymbolRoundtrip) { + for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) { + RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, {{0, i}}); + } + for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) { + RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, + std::vector<Token>(1024, {0, i})); + } +} + +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) +constexpr size_t kReps = 3; +#else +constexpr size_t kReps = 10; +#endif + +void RoundtripRandomStream(int alphabet_size, size_t reps = kReps, + size_t num = 1 << 18) { + constexpr int kNumHistograms = 3; + Rng rng(0); + for (size_t i = 0; i < reps; i++) { + std::vector<Token> symbols; + for (size_t j = 0; j < num; j++) { + int context = rng.UniformI(0, kNumHistograms); + int value = rng.UniformU(0, alphabet_size); + symbols.emplace_back(context, value); + } + RoundtripTestcase(kNumHistograms, alphabet_size, symbols); + } +} + +void RoundtripRandomUnbalancedStream(int alphabet_size) { + constexpr int kNumHistograms = 3; + constexpr int kPrecision = 1 << 10; + Rng rng(0); + for (size_t i = 0; i < kReps; i++) { + std::vector<int> distributions[kNumHistograms] = {}; + for (int j = 0; j < kNumHistograms; j++) { + distributions[j].resize(kPrecision); + int symbol = 0; + int remaining = 1; + for (int k = 0; k < kPrecision; k++) { + if (remaining == 0) { + if (symbol < alphabet_size - 1) symbol++; + // There is no meaning behind this distribution: it's anything that + // will create a nonuniform distribution and won't have too few + // symbols usually. Also we want different distributions we get to be + // sufficiently dissimilar. + remaining = rng.UniformU(0, kPrecision - k + 1); + } + distributions[j][k] = symbol; + remaining--; + } + } + std::vector<Token> symbols; + for (int j = 0; j < 1 << 18; j++) { + int context = rng.UniformI(0, kNumHistograms); + int value = rng.UniformU(0, kPrecision); + symbols.emplace_back(context, value); + } + RoundtripTestcase(kNumHistograms + 1, alphabet_size, symbols); + } +} + +TEST(ANSTest, RandomStreamRoundtrip3Small) { RoundtripRandomStream(3, 1, 16); } + +TEST(ANSTest, RandomStreamRoundtrip3) { RoundtripRandomStream(3); } + +TEST(ANSTest, RandomStreamRoundtripBig) { + RoundtripRandomStream(ANS_MAX_ALPHABET_SIZE); +} + +TEST(ANSTest, RandomUnbalancedStreamRoundtrip3) { + RoundtripRandomUnbalancedStream(3); +} + +TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) { + RoundtripRandomUnbalancedStream(ANS_MAX_ALPHABET_SIZE); +} + +TEST(ANSTest, UintConfigRoundtrip) { + for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) { + std::vector<HybridUintConfig> uint_config, uint_config_dec; + for (size_t i = 0; i < log_alpha_size; i++) { + for (size_t j = 0; j <= i; j++) { + for (size_t k = 0; k <= i - j; k++) { + uint_config.emplace_back(i, j, k); + } + } + } + uint_config.emplace_back(log_alpha_size, 0, 0); + uint_config_dec.resize(uint_config.size()); + BitWriter writer; + BitWriter::Allotment allotment(&writer, 10 * uint_config.size()); + EncodeUintConfigs(uint_config, &writer, log_alpha_size); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + writer.ZeroPadToByte(); + BitReader br(writer.GetSpan()); + EXPECT_TRUE(DecodeUintConfigs(log_alpha_size, &uint_config_dec, &br)); + EXPECT_TRUE(br.Close()); + for (size_t i = 0; i < uint_config.size(); i++) { + EXPECT_EQ(uint_config[i].split_token, uint_config_dec[i].split_token); + EXPECT_EQ(uint_config[i].msb_in_token, uint_config_dec[i].msb_in_token); + EXPECT_EQ(uint_config[i].lsb_in_token, uint_config_dec[i].lsb_in_token); + } + } +} + +void TestCheckpointing(bool ans, bool lz77) { + std::vector<std::vector<Token>> input_values(1); + for (size_t i = 0; i < 1024; i++) { + input_values[0].push_back(Token(0, i % 4)); + } + // up to lz77 window size. + for (size_t i = 0; i < (1 << 20) - 1022; i++) { + input_values[0].push_back(Token(0, (i % 5) + 4)); + } + // Ensure that when the window wraps around, new values are different. + input_values[0].push_back(Token(0, 0)); + for (size_t i = 0; i < 1024; i++) { + input_values[0].push_back(Token(0, i % 4)); + } + + std::vector<uint8_t> context_map; + EntropyEncodingData codes; + HistogramParams params; + params.lz77_method = lz77 ? HistogramParams::LZ77Method::kLZ77 + : HistogramParams::LZ77Method::kNone; + params.force_huffman = !ans; + + BitWriter writer; + { + auto input_values_copy = input_values; + BuildAndEncodeHistograms(params, 1, input_values_copy, &codes, &context_map, + &writer, 0, nullptr); + WriteTokens(input_values_copy[0], codes, context_map, 0, &writer, 0, + nullptr); + writer.ZeroPadToByte(); + } + + // We do not truncate the output. Reading past the end reads out zeroes + // anyway. + BitReader br(writer.GetSpan()); + Status status = true; + { + BitReaderScopedCloser bc(&br, &status); + + std::vector<uint8_t> dec_context_map; + ANSCode decoded_codes; + ASSERT_TRUE(DecodeHistograms(&br, 1, &decoded_codes, &dec_context_map)); + ASSERT_EQ(dec_context_map, context_map); + ANSSymbolReader reader(&decoded_codes, &br); + + ANSSymbolReader::Checkpoint checkpoint; + size_t br_pos = 0; + constexpr size_t kInterval = ANSSymbolReader::kMaxCheckpointInterval - 2; + for (size_t i = 0; i < input_values[0].size(); i++) { + if (i % kInterval == 0 && i > 0) { + reader.Restore(checkpoint); + ASSERT_TRUE(br.Close()); + br = BitReader(writer.GetSpan()); + br.SkipBits(br_pos); + for (size_t j = i - kInterval; j < i; j++) { + Token symbol = input_values[0][j]; + uint32_t read_symbol = + reader.ReadHybridUint(symbol.context, &br, dec_context_map); + ASSERT_EQ(read_symbol, symbol.value) << "j = " << j; + } + } + if (i % kInterval == 0) { + reader.Save(&checkpoint); + br_pos = br.TotalBitsConsumed(); + } + Token symbol = input_values[0][i]; + uint32_t read_symbol = + reader.ReadHybridUint(symbol.context, &br, dec_context_map); + ASSERT_EQ(read_symbol, symbol.value) << "i = " << i; + } + ASSERT_TRUE(reader.CheckANSFinalState()); + } + EXPECT_TRUE(status); +} + +TEST(ANSTest, TestCheckpointingANS) { + TestCheckpointing(/*ans=*/true, /*lz77=*/false); +} + +TEST(ANSTest, TestCheckpointingPrefix) { + TestCheckpointing(/*ans=*/false, /*lz77=*/false); +} + +TEST(ANSTest, TestCheckpointingANSLZ77) { + TestCheckpointing(/*ans=*/true, /*lz77=*/true); +} + +TEST(ANSTest, TestCheckpointingPrefixLZ77) { + TestCheckpointing(/*ans=*/false, /*lz77=*/true); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/base/arch_macros.h b/third_party/jpeg-xl/lib/jxl/base/arch_macros.h new file mode 100644 index 0000000000..a98301915e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/arch_macros.h @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_ARCH_MACROS_H_ +#define LIB_JXL_BASE_ARCH_MACROS_H_ + +// Defines the JXL_ARCH_* macros. + +namespace jxl { + +#if defined(__x86_64__) || defined(_M_X64) +#define JXL_ARCH_X64 1 +#else +#define JXL_ARCH_X64 0 +#endif + +#if defined(__powerpc64__) || defined(_M_PPC) +#define JXL_ARCH_PPC 1 +#else +#define JXL_ARCH_PPC 0 +#endif + +#if defined(__aarch64__) || defined(__arm__) +#define JXL_ARCH_ARM 1 +#else +#define JXL_ARCH_ARM 0 +#endif + +} // namespace jxl + +#endif // LIB_JXL_BASE_ARCH_MACROS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/bits.h b/third_party/jpeg-xl/lib/jxl/base/bits.h new file mode 100644 index 0000000000..9f86118e72 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/bits.h @@ -0,0 +1,147 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_BITS_H_ +#define LIB_JXL_BASE_BITS_H_ + +// Specialized instructions for processing register-sized bit arrays. + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +#if JXL_COMPILER_MSVC +#include <intrin.h> +#endif + +#include <stddef.h> +#include <stdint.h> + +namespace jxl { + +// Empty struct used as a size tag type. +template <size_t N> +struct SizeTag {}; + +template <typename T> +constexpr bool IsSigned() { + return T(0) > T(-1); +} + +// Undefined results for x == 0. +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsAboveMS1Bit_Nonzero(SizeTag<4> /* tag */, const uint32_t x) { + JXL_DASSERT(x != 0); +#if JXL_COMPILER_MSVC + unsigned long index; + _BitScanReverse(&index, x); + return 31 - index; +#else + return static_cast<size_t>(__builtin_clz(x)); +#endif +} +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsAboveMS1Bit_Nonzero(SizeTag<8> /* tag */, const uint64_t x) { + JXL_DASSERT(x != 0); +#if JXL_COMPILER_MSVC +#if JXL_ARCH_X64 + unsigned long index; + _BitScanReverse64(&index, x); + return 63 - index; +#else // JXL_ARCH_X64 + // _BitScanReverse64 not available + uint32_t msb = static_cast<uint32_t>(x >> 32u); + unsigned long index; + if (msb == 0) { + uint32_t lsb = static_cast<uint32_t>(x & 0xFFFFFFFF); + _BitScanReverse(&index, lsb); + return 63 - index; + } else { + _BitScanReverse(&index, msb); + return 31 - index; + } +#endif // JXL_ARCH_X64 +#else + return static_cast<size_t>(__builtin_clzll(x)); +#endif +} +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsAboveMS1Bit_Nonzero(const T x) { + static_assert(!IsSigned<T>(), "Num0BitsAboveMS1Bit_Nonzero: use unsigned"); + return Num0BitsAboveMS1Bit_Nonzero(SizeTag<sizeof(T)>(), x); +} + +// Undefined results for x == 0. +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsBelowLS1Bit_Nonzero(SizeTag<4> /* tag */, const uint32_t x) { + JXL_DASSERT(x != 0); +#if JXL_COMPILER_MSVC + unsigned long index; + _BitScanForward(&index, x); + return index; +#else + return static_cast<size_t>(__builtin_ctz(x)); +#endif +} +static JXL_INLINE JXL_MAYBE_UNUSED size_t +Num0BitsBelowLS1Bit_Nonzero(SizeTag<8> /* tag */, const uint64_t x) { + JXL_DASSERT(x != 0); +#if JXL_COMPILER_MSVC +#if JXL_ARCH_X64 + unsigned long index; + _BitScanForward64(&index, x); + return index; +#else // JXL_ARCH_64 + // _BitScanForward64 not available + uint32_t lsb = static_cast<uint32_t>(x & 0xFFFFFFFF); + unsigned long index; + if (lsb == 0) { + uint32_t msb = static_cast<uint32_t>(x >> 32u); + _BitScanForward(&index, msb); + return 32 + index; + } else { + _BitScanForward(&index, lsb); + return index; + } +#endif // JXL_ARCH_X64 +#else + return static_cast<size_t>(__builtin_ctzll(x)); +#endif +} +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED size_t Num0BitsBelowLS1Bit_Nonzero(T x) { + static_assert(!IsSigned<T>(), "Num0BitsBelowLS1Bit_Nonzero: use unsigned"); + return Num0BitsBelowLS1Bit_Nonzero(SizeTag<sizeof(T)>(), x); +} + +// Returns bit width for x == 0. +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED size_t Num0BitsAboveMS1Bit(const T x) { + return (x == 0) ? sizeof(T) * 8 : Num0BitsAboveMS1Bit_Nonzero(x); +} + +// Returns bit width for x == 0. +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED size_t Num0BitsBelowLS1Bit(const T x) { + return (x == 0) ? sizeof(T) * 8 : Num0BitsBelowLS1Bit_Nonzero(x); +} + +// Returns base-2 logarithm, rounded down. +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED size_t FloorLog2Nonzero(const T x) { + return (sizeof(T) * 8 - 1) ^ Num0BitsAboveMS1Bit_Nonzero(x); +} + +// Returns base-2 logarithm, rounded up. +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED size_t CeilLog2Nonzero(const T x) { + const size_t floor_log2 = FloorLog2Nonzero(x); + if ((x & (x - 1)) == 0) return floor_log2; // power of two + return floor_log2 + 1; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_BITS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/byte_order.h b/third_party/jpeg-xl/lib/jxl/base/byte_order.h new file mode 100644 index 0000000000..8966834e08 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/byte_order.h @@ -0,0 +1,274 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_BYTE_ORDER_H_ +#define LIB_JXL_BASE_BYTE_ORDER_H_ + +#include <jxl/types.h> +#include <stdint.h> +#include <string.h> // memcpy + +#include "lib/jxl/base/compiler_specific.h" + +#if JXL_COMPILER_MSVC +#include <intrin.h> // _byteswap_* +#endif + +#if (defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) +#define JXL_BYTE_ORDER_LITTLE 1 +#else +// This means that we don't know that the byte order is little endian, in +// this case we use endian-neutral code that works for both little- and +// big-endian. +#define JXL_BYTE_ORDER_LITTLE 0 +#endif + +// Returns whether the system is little-endian (least-significant byte first). +#if JXL_BYTE_ORDER_LITTLE +static constexpr bool IsLittleEndian() { return true; } +#else +static inline bool IsLittleEndian() { + const uint32_t multibyte = 1; + uint8_t byte; + memcpy(&byte, &multibyte, 1); + return byte == 1; +} +#endif + +static inline bool SwapEndianness(JxlEndianness endianness) { + return ((endianness == JXL_BIG_ENDIAN && IsLittleEndian()) || + (endianness == JXL_LITTLE_ENDIAN && !IsLittleEndian())); +} + +#if JXL_COMPILER_MSVC +#define JXL_BSWAP16(x) _byteswap_ushort(x) +#define JXL_BSWAP32(x) _byteswap_ulong(x) +#define JXL_BSWAP64(x) _byteswap_uint64(x) +#else +#define JXL_BSWAP16(x) __builtin_bswap16(x) +#define JXL_BSWAP32(x) __builtin_bswap32(x) +#define JXL_BSWAP64(x) __builtin_bswap64(x) +#endif + +static JXL_INLINE uint32_t LoadBE16(const uint8_t* p) { + const uint32_t byte1 = p[0]; + const uint32_t byte0 = p[1]; + return (byte1 << 8) | byte0; +} + +static JXL_INLINE uint32_t LoadLE16(const uint8_t* p) { + const uint32_t byte0 = p[0]; + const uint32_t byte1 = p[1]; + return (byte1 << 8) | byte0; +} + +static JXL_INLINE uint32_t LoadBE32(const uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + uint32_t big; + memcpy(&big, p, 4); + return JXL_BSWAP32(big); +#else + // Byte-order-independent - can't assume this machine is big endian. + const uint32_t byte3 = p[0]; + const uint32_t byte2 = p[1]; + const uint32_t byte1 = p[2]; + const uint32_t byte0 = p[3]; + return (byte3 << 24) | (byte2 << 16) | (byte1 << 8) | byte0; +#endif +} + +static JXL_INLINE uint64_t LoadBE64(const uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + uint64_t big; + memcpy(&big, p, 8); + return JXL_BSWAP64(big); +#else + // Byte-order-independent - can't assume this machine is big endian. + const uint64_t byte7 = p[0]; + const uint64_t byte6 = p[1]; + const uint64_t byte5 = p[2]; + const uint64_t byte4 = p[3]; + const uint64_t byte3 = p[4]; + const uint64_t byte2 = p[5]; + const uint64_t byte1 = p[6]; + const uint64_t byte0 = p[7]; + return (byte7 << 56ull) | (byte6 << 48ull) | (byte5 << 40ull) | + (byte4 << 32ull) | (byte3 << 24ull) | (byte2 << 16ull) | + (byte1 << 8ull) | byte0; +#endif +} + +static JXL_INLINE uint32_t LoadLE32(const uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + uint32_t little; + memcpy(&little, p, 4); + return little; +#else + // Byte-order-independent - can't assume this machine is big endian. + const uint32_t byte0 = p[0]; + const uint32_t byte1 = p[1]; + const uint32_t byte2 = p[2]; + const uint32_t byte3 = p[3]; + return (byte3 << 24) | (byte2 << 16) | (byte1 << 8) | byte0; +#endif +} + +static JXL_INLINE uint64_t LoadLE64(const uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + uint64_t little; + memcpy(&little, p, 8); + return little; +#else + // Byte-order-independent - can't assume this machine is big endian. + const uint64_t byte0 = p[0]; + const uint64_t byte1 = p[1]; + const uint64_t byte2 = p[2]; + const uint64_t byte3 = p[3]; + const uint64_t byte4 = p[4]; + const uint64_t byte5 = p[5]; + const uint64_t byte6 = p[6]; + const uint64_t byte7 = p[7]; + return (byte7 << 56) | (byte6 << 48) | (byte5 << 40) | (byte4 << 32) | + (byte3 << 24) | (byte2 << 16) | (byte1 << 8) | byte0; +#endif +} + +// Loads a Big-Endian float +static JXL_INLINE float LoadBEFloat(const uint8_t* p) { + uint32_t u = LoadBE32(p); + float result; + memcpy(&result, &u, 4); + return result; +} + +// Loads a Little-Endian float +static JXL_INLINE float LoadLEFloat(const uint8_t* p) { + uint32_t u = LoadLE32(p); + float result; + memcpy(&result, &u, 4); + return result; +} + +static JXL_INLINE void StoreBE16(const uint32_t native, uint8_t* p) { + p[0] = (native >> 8) & 0xFF; + p[1] = native & 0xFF; +} + +static JXL_INLINE void StoreLE16(const uint32_t native, uint8_t* p) { + p[1] = (native >> 8) & 0xFF; + p[0] = native & 0xFF; +} + +static JXL_INLINE void StoreBE32(const uint32_t native, uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + const uint32_t big = JXL_BSWAP32(native); + memcpy(p, &big, 4); +#else + // Byte-order-independent - can't assume this machine is big endian. + p[0] = native >> 24; + p[1] = (native >> 16) & 0xFF; + p[2] = (native >> 8) & 0xFF; + p[3] = native & 0xFF; +#endif +} + +static JXL_INLINE void StoreBE64(const uint64_t native, uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + const uint64_t big = JXL_BSWAP64(native); + memcpy(p, &big, 8); +#else + // Byte-order-independent - can't assume this machine is big endian. + p[0] = native >> 56ull; + p[1] = (native >> 48ull) & 0xFF; + p[2] = (native >> 40ull) & 0xFF; + p[3] = (native >> 32ull) & 0xFF; + p[4] = (native >> 24ull) & 0xFF; + p[5] = (native >> 16ull) & 0xFF; + p[6] = (native >> 8ull) & 0xFF; + p[7] = native & 0xFF; +#endif +} + +static JXL_INLINE void StoreLE32(const uint32_t native, uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + const uint32_t little = native; + memcpy(p, &little, 4); +#else + // Byte-order-independent - can't assume this machine is big endian. + p[3] = native >> 24; + p[2] = (native >> 16) & 0xFF; + p[1] = (native >> 8) & 0xFF; + p[0] = native & 0xFF; +#endif +} + +static JXL_INLINE void StoreLE64(const uint64_t native, uint8_t* p) { +#if JXL_BYTE_ORDER_LITTLE + const uint64_t little = native; + memcpy(p, &little, 8); +#else + // Byte-order-independent - can't assume this machine is big endian. + p[7] = native >> 56; + p[6] = (native >> 48) & 0xFF; + p[5] = (native >> 40) & 0xFF; + p[4] = (native >> 32) & 0xFF; + p[3] = (native >> 24) & 0xFF; + p[2] = (native >> 16) & 0xFF; + p[1] = (native >> 8) & 0xFF; + p[0] = native & 0xFF; +#endif +} + +static JXL_INLINE float BSwapFloat(float x) { + uint32_t u; + memcpy(&u, &x, 4); + uint32_t uswap = JXL_BSWAP32(u); + float xswap; + memcpy(&xswap, &uswap, 4); + return xswap; +} + +// Big/Little Endian order. +struct OrderBE {}; +struct OrderLE {}; + +// Wrappers for calling from generic code. +static JXL_INLINE void Store16(OrderBE /*tag*/, const uint32_t native, + uint8_t* p) { + return StoreBE16(native, p); +} + +static JXL_INLINE void Store16(OrderLE /*tag*/, const uint32_t native, + uint8_t* p) { + return StoreLE16(native, p); +} + +static JXL_INLINE void Store32(OrderBE /*tag*/, const uint32_t native, + uint8_t* p) { + return StoreBE32(native, p); +} + +static JXL_INLINE void Store32(OrderLE /*tag*/, const uint32_t native, + uint8_t* p) { + return StoreLE32(native, p); +} + +static JXL_INLINE uint32_t Load16(OrderBE /*tag*/, const uint8_t* p) { + return LoadBE16(p); +} + +static JXL_INLINE uint32_t Load16(OrderLE /*tag*/, const uint8_t* p) { + return LoadLE16(p); +} + +static JXL_INLINE uint32_t Load32(OrderBE /*tag*/, const uint8_t* p) { + return LoadBE32(p); +} + +static JXL_INLINE uint32_t Load32(OrderLE /*tag*/, const uint8_t* p) { + return LoadLE32(p); +} + +#endif // LIB_JXL_BASE_BYTE_ORDER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/c_callback_support.h b/third_party/jpeg-xl/lib/jxl/base/c_callback_support.h new file mode 100644 index 0000000000..aee0ce5346 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/c_callback_support.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_C_CALLBACK_SUPPORT_H_ +#define LIB_JXL_BASE_C_CALLBACK_SUPPORT_H_ + +#include <utility> + +namespace jxl { +namespace detail { + +template <typename T> +struct MethodToCCallbackHelper {}; + +template <typename T, typename R, typename... Args> +struct MethodToCCallbackHelper<R (T::*)(Args...)> { + template <R (T::*method)(Args...)> + static R Call(void *opaque, Args... args) { + return (reinterpret_cast<T *>(opaque)->*method)( + std::forward<Args>(args)...); + } +}; + +} // namespace detail +} // namespace jxl + +#define METHOD_TO_C_CALLBACK(method) \ + ::jxl::detail::MethodToCCallbackHelper<decltype(method)>::Call<method> + +#endif // LIB_JXL_BASE_C_CALLBACK_SUPPORT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/common.h b/third_party/jpeg-xl/lib/jxl/base/common.h new file mode 100644 index 0000000000..b7fe6ab0bc --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/common.h @@ -0,0 +1,95 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_COMMON_H_ +#define LIB_JXL_BASE_COMMON_H_ + +// Shared constants and helper functions. + +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <memory> +#include <string> + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { +// Some enums and typedefs used by more than one header file. + +constexpr size_t kBitsPerByte = 8; // more clear than CHAR_BIT + +constexpr inline size_t RoundUpBitsToByteMultiple(size_t bits) { + return (bits + 7) & ~size_t(7); +} + +constexpr inline size_t RoundUpToBlockDim(size_t dim) { + return (dim + 7) & ~size_t(7); +} + +static inline bool JXL_MAYBE_UNUSED SafeAdd(const uint64_t a, const uint64_t b, + uint64_t& sum) { + sum = a + b; + return sum >= a; // no need to check b - either sum >= both or < both. +} + +template <typename T1, typename T2> +constexpr inline T1 DivCeil(T1 a, T2 b) { + return (a + b - 1) / b; +} + +// Works for any `align`; if a power of two, compiler emits ADD+AND. +constexpr inline size_t RoundUpTo(size_t what, size_t align) { + return DivCeil(what, align) * align; +} + +constexpr double kPi = 3.14159265358979323846264338327950288; + +// Reasonable default for sRGB, matches common monitors. We map white to this +// many nits (cd/m^2) by default. Butteraugli was tuned for 250 nits, which is +// very close. +// NB: This constant is not very "base", but it is shared between modules. +static constexpr float kDefaultIntensityTarget = 255; + +template <typename T> +constexpr T Pi(T multiplier) { + return static_cast<T>(multiplier * kPi); +} + +// Prior to C++14 (i.e. C++11): provide our own make_unique +#if __cplusplus < 201402L +template <typename T, typename... Args> +std::unique_ptr<T> make_unique(Args&&... args) { + return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); +} +#else +using std::make_unique; +#endif + +template <typename T> +JXL_INLINE T Clamp1(T val, T low, T hi) { + return val < low ? low : val > hi ? hi : val; +} + +// conversion from integer to string. +template <typename T> +std::string ToString(T n) { + char data[32] = {}; + if (T(0.1) != T(0)) { + // float + snprintf(data, sizeof(data), "%g", static_cast<double>(n)); + } else if (T(-1) > T(0)) { + // unsigned + snprintf(data, sizeof(data), "%llu", static_cast<unsigned long long>(n)); + } else { + // signed + snprintf(data, sizeof(data), "%lld", static_cast<long long>(n)); + } + return data; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_COMMON_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/compiler_specific.h b/third_party/jpeg-xl/lib/jxl/base/compiler_specific.h new file mode 100644 index 0000000000..702ff8e058 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/compiler_specific.h @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_COMPILER_SPECIFIC_H_ +#define LIB_JXL_BASE_COMPILER_SPECIFIC_H_ + +// Macros for compiler version + nonstandard keywords, e.g. __builtin_expect. + +#include <stdint.h> +#include <sys/types.h> + +#include "lib/jxl/base/sanitizer_definitions.h" + +// #if is shorter and safer than #ifdef. *_VERSION are zero if not detected, +// otherwise 100 * major + minor version. Note that other packages check for +// #ifdef COMPILER_MSVC, so we cannot use that same name. + +#ifdef _MSC_VER +#define JXL_COMPILER_MSVC _MSC_VER +#else +#define JXL_COMPILER_MSVC 0 +#endif + +#ifdef __GNUC__ +#define JXL_COMPILER_GCC (__GNUC__ * 100 + __GNUC_MINOR__) +#else +#define JXL_COMPILER_GCC 0 +#endif + +#ifdef __clang__ +#define JXL_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__) +// Clang pretends to be GCC for compatibility. +#undef JXL_COMPILER_GCC +#define JXL_COMPILER_GCC 0 +#else +#define JXL_COMPILER_CLANG 0 +#endif + +#if JXL_COMPILER_MSVC +#define JXL_RESTRICT __restrict +#elif JXL_COMPILER_GCC || JXL_COMPILER_CLANG +#define JXL_RESTRICT __restrict__ +#else +#define JXL_RESTRICT +#endif + +#if JXL_COMPILER_MSVC +#define JXL_INLINE __forceinline +#define JXL_NOINLINE __declspec(noinline) +#else +#define JXL_INLINE inline __attribute__((always_inline)) +#define JXL_NOINLINE __attribute__((noinline)) +#endif + +#if JXL_COMPILER_MSVC +#define JXL_NORETURN __declspec(noreturn) +#elif JXL_COMPILER_GCC || JXL_COMPILER_CLANG +#define JXL_NORETURN __attribute__((noreturn)) +#else +#define JXL_NORETURN +#endif + +#if JXL_COMPILER_MSVC +#define JXL_UNREACHABLE_BUILTIN __assume(false) +#elif JXL_COMPILER_CLANG || JXL_COMPILER_GCC >= 405 +#define JXL_UNREACHABLE_BUILTIN __builtin_unreachable() +#else +#define JXL_UNREACHABLE_BUILTIN +#endif + +#if JXL_COMPILER_MSVC +#define JXL_MAYBE_UNUSED +#else +// Encountered "attribute list cannot appear here" when using the C++17 +// [[maybe_unused]], so only use the old style attribute for now. +#define JXL_MAYBE_UNUSED __attribute__((unused)) +#endif + +// MSAN execution won't hurt if some code it not inlined, but this can greatly +// improve compilation time. Unfortunately this macro can not be used just +// everywhere - inside header files it leads to "multiple definition" error; +// though it would be better not to have JXL_INLINE in header overall. +#if JXL_MEMORY_SANITIZER || JXL_ADDRESS_SANITIZER || JXL_THREAD_SANITIZER +#define JXL_MAYBE_INLINE JXL_MAYBE_UNUSED +#else +#define JXL_MAYBE_INLINE JXL_INLINE +#endif + +#if JXL_COMPILER_MSVC +// Unsupported, __assume is not the same. +#define JXL_LIKELY(expr) expr +#define JXL_UNLIKELY(expr) expr +#else +#define JXL_LIKELY(expr) __builtin_expect(!!(expr), 1) +#define JXL_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#endif + +// Returns a void* pointer which the compiler then assumes is N-byte aligned. +// Example: float* JXL_RESTRICT aligned = (float*)JXL_ASSUME_ALIGNED(in, 32); +// +// The assignment semantics are required by GCC/Clang. ICC provides an in-place +// __assume_aligned, whereas MSVC's __assume appears unsuitable. +#if JXL_COMPILER_CLANG +// Early versions of Clang did not support __builtin_assume_aligned. +#define JXL_HAS_ASSUME_ALIGNED __has_builtin(__builtin_assume_aligned) +#elif JXL_COMPILER_GCC +#define JXL_HAS_ASSUME_ALIGNED 1 +#else +#define JXL_HAS_ASSUME_ALIGNED 0 +#endif + +#if JXL_HAS_ASSUME_ALIGNED +#define JXL_ASSUME_ALIGNED(ptr, align) __builtin_assume_aligned((ptr), (align)) +#else +#define JXL_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */ +#endif + +#ifdef __has_attribute +#define JXL_HAVE_ATTRIBUTE(x) __has_attribute(x) +#else +#define JXL_HAVE_ATTRIBUTE(x) 0 +#endif + +// Raises warnings if the function return value is unused. Should appear as the +// first part of a function definition/declaration. +#if JXL_HAVE_ATTRIBUTE(nodiscard) +#define JXL_MUST_USE_RESULT [[nodiscard]] +#elif JXL_COMPILER_CLANG && JXL_HAVE_ATTRIBUTE(warn_unused_result) +#define JXL_MUST_USE_RESULT __attribute__((warn_unused_result)) +#else +#define JXL_MUST_USE_RESULT +#endif + +// Disable certain -fsanitize flags for functions that are expected to include +// things like unsigned integer overflow. For example use in the function +// declaration JXL_NO_SANITIZE("unsigned-integer-overflow") to silence unsigned +// integer overflow ubsan messages. +#if JXL_COMPILER_CLANG && JXL_HAVE_ATTRIBUTE(no_sanitize) +#define JXL_NO_SANITIZE(X) __attribute__((no_sanitize(X))) +#else +#define JXL_NO_SANITIZE(X) +#endif + +#if JXL_HAVE_ATTRIBUTE(__format__) +#define JXL_FORMAT(idx_fmt, idx_arg) \ + __attribute__((__format__(__printf__, idx_fmt, idx_arg))) +#else +#define JXL_FORMAT(idx_fmt, idx_arg) +#endif + +#if JXL_COMPILER_MSVC +using ssize_t = intptr_t; +#endif + +#endif // LIB_JXL_BASE_COMPILER_SPECIFIC_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/data_parallel.h b/third_party/jpeg-xl/lib/jxl/base/data_parallel.h new file mode 100644 index 0000000000..a7f977b7ce --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/data_parallel.h @@ -0,0 +1,124 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_DATA_PARALLEL_H_ +#define LIB_JXL_BASE_DATA_PARALLEL_H_ + +// Portable, low-overhead C++11 ThreadPool alternative to OpenMP for +// data-parallel computations. + +#include <jxl/parallel_runner.h> +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#if JXL_COMPILER_MSVC +// suppress warnings about the const & applied to function types +#pragma warning(disable : 4180) +#endif + +namespace jxl { + +class ThreadPool { + public: + ThreadPool(JxlParallelRunner runner, void* runner_opaque) + : runner_(runner), + runner_opaque_(runner ? runner_opaque : static_cast<void*>(this)) {} + + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator&(const ThreadPool&) = delete; + + JxlParallelRunner runner() const { return runner_; } + void* runner_opaque() const { return runner_opaque_; } + + // Runs init_func(num_threads) followed by data_func(task, thread) on worker + // thread(s) for every task in [begin, end). init_func() must return a Status + // indicating whether the initialization succeeded. + // "thread" is an integer smaller than num_threads. + // Not thread-safe - no two calls to Run may overlap. + // Subsequent calls will reuse the same threads. + // + // Precondition: begin <= end. + template <class InitFunc, class DataFunc> + Status Run(uint32_t begin, uint32_t end, const InitFunc& init_func, + const DataFunc& data_func, const char* caller = "") { + JXL_ASSERT(begin <= end); + if (begin == end) return true; + RunCallState<InitFunc, DataFunc> call_state(init_func, data_func); + // The runner_ uses the C convention and returns 0 in case of error, so we + // convert it to a Status. + if (!runner_) { + void* jpegxl_opaque = static_cast<void*>(&call_state); + if (call_state.CallInitFunc(jpegxl_opaque, 1) != 0) { + return JXL_FAILURE("Failed to initialize thread"); + } + for (uint32_t i = begin; i < end; i++) { + call_state.CallDataFunc(jpegxl_opaque, i, 0); + } + return true; + } + return (*runner_)(runner_opaque_, static_cast<void*>(&call_state), + &call_state.CallInitFunc, &call_state.CallDataFunc, begin, + end) == 0; + } + + // Use this as init_func when no initialization is needed. + static Status NoInit(size_t num_threads) { return true; } + + private: + // class holding the state of a Run() call to pass to the runner_ as an + // opaque_jpegxl pointer. + template <class InitFunc, class DataFunc> + class RunCallState final { + public: + RunCallState(const InitFunc& init_func, const DataFunc& data_func) + : init_func_(init_func), data_func_(data_func) {} + + // JxlParallelRunInit interface. + static int CallInitFunc(void* jpegxl_opaque, size_t num_threads) { + const auto* self = + static_cast<RunCallState<InitFunc, DataFunc>*>(jpegxl_opaque); + // Returns -1 when the internal init function returns false Status to + // indicate an error. + return self->init_func_(num_threads) ? 0 : -1; + } + + // JxlParallelRunFunction interface. + static void CallDataFunc(void* jpegxl_opaque, uint32_t value, + size_t thread_id) { + const auto* self = + static_cast<RunCallState<InitFunc, DataFunc>*>(jpegxl_opaque); + return self->data_func_(value, thread_id); + } + + private: + const InitFunc& init_func_; + const DataFunc& data_func_; + }; + + // The caller supplied runner function and its opaque void*. + const JxlParallelRunner runner_; + void* const runner_opaque_; +}; + +template <class InitFunc, class DataFunc> +Status RunOnPool(ThreadPool* pool, const uint32_t begin, const uint32_t end, + const InitFunc& init_func, const DataFunc& data_func, + const char* caller) { + if (pool == nullptr) { + ThreadPool default_pool(nullptr, nullptr); + return default_pool.Run(begin, end, init_func, data_func, caller); + } else { + return pool->Run(begin, end, init_func, data_func, caller); + } +} + +} // namespace jxl +#if JXL_COMPILER_MSVC +#pragma warning(default : 4180) +#endif + +#endif // LIB_JXL_BASE_DATA_PARALLEL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/exif.h b/third_party/jpeg-xl/lib/jxl/base/exif.h new file mode 100644 index 0000000000..2caafddc04 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/exif.h @@ -0,0 +1,91 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_EXIF_H_ +#define LIB_JXL_EXIF_H_ + +// Basic parsing of Exif (just enough for the render-impacting things +// like orientation) + +#include <jxl/codestream_header.h> + +#include <cstddef> +#include <cstdint> +#include <vector> + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +constexpr uint16_t kExifOrientationTag = 274; + +// Checks if a blob looks like Exif, and if so, sets bigendian +// according to the tiff endianness +JXL_INLINE bool IsExif(const std::vector<uint8_t>& exif, bool* bigendian) { + if (exif.size() < 12) return false; // not enough bytes for a valid exif blob + const uint8_t* t = exif.data(); + if (LoadLE32(t) == 0x2A004D4D) { + *bigendian = true; + return true; + } else if (LoadLE32(t) == 0x002A4949) { + *bigendian = false; + return true; + } + return false; // not a valid tiff header +} + +// Finds the position of an Exif tag, or 0 if it is not found +JXL_INLINE size_t FindExifTagPosition(const std::vector<uint8_t>& exif, + uint16_t tagname) { + bool bigendian; + if (!IsExif(exif, &bigendian)) return 0; + const uint8_t* t = exif.data() + 4; + uint64_t offset = (bigendian ? LoadBE32(t) : LoadLE32(t)); + if (exif.size() < 12 + offset + 2 || offset < 8) return 0; + t += offset - 4; + if (offset + 2 >= exif.size()) return 0; + uint16_t nb_tags = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 2; + while (nb_tags > 0) { + if (t + 12 >= exif.data() + exif.size()) return 0; + uint16_t tag = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 2; + if (tag == tagname) return static_cast<size_t>(t - exif.data()); + t += 10; + nb_tags--; + } + return 0; +} + +// TODO(jon): tag 1 can be used to represent Adobe RGB 1998 if it has value +// "R03" +// TODO(jon): set intrinsic dimensions according to +// https://discourse.wicg.io/t/proposal-exif-image-resolution-auto-and-from-image/4326/24 +// Parses the Exif data just enough to extract any render-impacting info. +// If the Exif data is invalid or could not be parsed, then it is treated +// as a no-op. +JXL_INLINE void InterpretExif(const std::vector<uint8_t>& exif, + JxlOrientation* orientation) { + bool bigendian; + if (!IsExif(exif, &bigendian)) return; + size_t o_pos = FindExifTagPosition(exif, kExifOrientationTag); + if (o_pos) { + const uint8_t* t = exif.data() + o_pos; + uint16_t type = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 2; + uint32_t count = (bigendian ? LoadBE32(t) : LoadLE32(t)); + t += 4; + uint16_t value = (bigendian ? LoadBE16(t) : LoadLE16(t)); + t += 4; + if (type == 3 && count == 1 && value >= 1 && value <= 8) { + *orientation = static_cast<JxlOrientation>(value); + } + } +} + +} // namespace jxl + +#endif // LIB_JXL_EXIF_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/fast_math-inl.h b/third_party/jpeg-xl/lib/jxl/base/fast_math-inl.h new file mode 100644 index 0000000000..fa749cc257 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/fast_math-inl.h @@ -0,0 +1,236 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Fast SIMD math ops (log2, encoder only, cos, erf for splines) + +#if defined(LIB_JXL_BASE_FAST_MATH_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_BASE_FAST_MATH_INL_H_ +#undef LIB_JXL_BASE_FAST_MATH_INL_H_ +#else +#define LIB_JXL_BASE_FAST_MATH_INL_H_ +#endif + +#include <hwy/highway.h> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/rational_polynomial-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::Floor; +using hwy::HWY_NAMESPACE::Ge; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenZeroElse; +using hwy::HWY_NAMESPACE::Le; +using hwy::HWY_NAMESPACE::Min; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::NegMulAdd; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::Xor; + +// Computes base-2 logarithm like std::log2. Undefined if negative / NaN. +// L1 error ~3.9E-6 +template <class DF, class V> +V FastLog2f(const DF df, V x) { + // 2,2 rational polynomial approximation of std::log1p(x) / std::log(2). + HWY_ALIGN const float p[4 * (2 + 1)] = {HWY_REP4(-1.8503833400518310E-06f), + HWY_REP4(1.4287160470083755E+00f), + HWY_REP4(7.4245873327820566E-01f)}; + HWY_ALIGN const float q[4 * (2 + 1)] = {HWY_REP4(9.9032814277590719E-01f), + HWY_REP4(1.0096718572241148E+00f), + HWY_REP4(1.7409343003366853E-01f)}; + + const Rebind<int32_t, DF> di; + const auto x_bits = BitCast(di, x); + + // Range reduction to [-1/3, 1/3] - 3 integer, 2 float ops + const auto exp_bits = Sub(x_bits, Set(di, 0x3f2aaaab)); // = 2/3 + // Shifted exponent = log2; also used to clear mantissa. + const auto exp_shifted = ShiftRight<23>(exp_bits); + const auto mantissa = BitCast(df, Sub(x_bits, ShiftLeft<23>(exp_shifted))); + const auto exp_val = ConvertTo(df, exp_shifted); + return Add(EvalRationalPolynomial(df, Sub(mantissa, Set(df, 1.0f)), p, q), + exp_val); +} + +// max relative error ~3e-7 +template <class DF, class V> +V FastPow2f(const DF df, V x) { + const Rebind<int32_t, DF> di; + auto floorx = Floor(x); + auto exp = + BitCast(df, ShiftLeft<23>(Add(ConvertTo(di, floorx), Set(di, 127)))); + auto frac = Sub(x, floorx); + auto num = Add(frac, Set(df, 1.01749063e+01)); + num = MulAdd(num, frac, Set(df, 4.88687798e+01)); + num = MulAdd(num, frac, Set(df, 9.85506591e+01)); + num = Mul(num, exp); + auto den = MulAdd(frac, Set(df, 2.10242958e-01), Set(df, -2.22328856e-02)); + den = MulAdd(den, frac, Set(df, -1.94414990e+01)); + den = MulAdd(den, frac, Set(df, 9.85506633e+01)); + return Div(num, den); +} + +// max relative error ~3e-5 +template <class DF, class V> +V FastPowf(const DF df, V base, V exponent) { + return FastPow2f(df, Mul(FastLog2f(df, base), exponent)); +} + +// Computes cosine like std::cos. +// L1 error 7e-5. +template <class DF, class V> +V FastCosf(const DF df, V x) { + // Step 1: range reduction to [0, 2pi) + const auto pi2 = Set(df, kPi * 2.0f); + const auto pi2_inv = Set(df, 0.5f / kPi); + const auto npi2 = Mul(Floor(Mul(x, pi2_inv)), pi2); + const auto xmodpi2 = Sub(x, npi2); + // Step 2: range reduction to [0, pi] + const auto x_pi = Min(xmodpi2, Sub(pi2, xmodpi2)); + // Step 3: range reduction to [0, pi/2] + const auto above_pihalf = Ge(x_pi, Set(df, kPi / 2.0f)); + const auto x_pihalf = IfThenElse(above_pihalf, Sub(Set(df, kPi), x_pi), x_pi); + // Step 4: Taylor-like approximation, scaled by 2**0.75 to make angle + // duplication steps faster, on x/4. + const auto xs = Mul(x_pihalf, Set(df, 0.25f)); + const auto x2 = Mul(xs, xs); + const auto x4 = Mul(x2, x2); + const auto cosx_prescaling = + MulAdd(x4, Set(df, 0.06960438), + MulAdd(x2, Set(df, -0.84087373), Set(df, 1.68179268))); + // Step 5: angle duplication. + const auto cosx_scale1 = + MulAdd(cosx_prescaling, cosx_prescaling, Set(df, -1.414213562)); + const auto cosx_scale2 = MulAdd(cosx_scale1, cosx_scale1, Set(df, -1)); + // Step 6: change sign if needed. + const Rebind<uint32_t, DF> du; + auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(df, above_pihalf))); + return BitCast(df, Xor(signbit, BitCast(du, cosx_scale2))); +} + +// Computes the error function like std::erf. +// L1 error 7e-4. +template <class DF, class V> +V FastErff(const DF df, V x) { + // Formula from + // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations + // but constants have been recomputed. + const auto xle0 = Le(x, Zero(df)); + const auto absx = Abs(x); + // Compute 1 - 1 / ((((x * a + b) * x + c) * x + d) * x + 1)**4 + const auto denom1 = + MulAdd(absx, Set(df, 7.77394369e-02), Set(df, 2.05260015e-04)); + const auto denom2 = MulAdd(denom1, absx, Set(df, 2.32120216e-01)); + const auto denom3 = MulAdd(denom2, absx, Set(df, 2.77820801e-01)); + const auto denom4 = MulAdd(denom3, absx, Set(df, 1.0f)); + const auto denom5 = Mul(denom4, denom4); + const auto inv_denom5 = Div(Set(df, 1.0f), denom5); + const auto result = NegMulAdd(inv_denom5, inv_denom5, Set(df, 1.0f)); + // Change sign if needed. + const Rebind<uint32_t, DF> du; + auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(df, xle0))); + return BitCast(df, Xor(signbit, BitCast(du, result))); +} + +inline float FastLog2f(float f) { + HWY_CAPPED(float, 1) D; + return GetLane(FastLog2f(D, Set(D, f))); +} + +inline float FastPow2f(float f) { + HWY_CAPPED(float, 1) D; + return GetLane(FastPow2f(D, Set(D, f))); +} + +inline float FastPowf(float b, float e) { + HWY_CAPPED(float, 1) D; + return GetLane(FastPowf(D, Set(D, b), Set(D, e))); +} + +inline float FastCosf(float f) { + HWY_CAPPED(float, 1) D; + return GetLane(FastCosf(D, Set(D, f))); +} + +inline float FastErff(float f) { + HWY_CAPPED(float, 1) D; + return GetLane(FastErff(D, Set(D, f))); +} + +// Returns cbrt(x) + add with 6 ulp max error. +// Modified from vectormath_exp.h, Apache 2 license. +// https://www.agner.org/optimize/vectorclass.zip +template <class V> +V CubeRootAndAdd(const V x, const V add) { + const HWY_FULL(float) df; + const HWY_FULL(int32_t) di; + + const auto kExpBias = Set(di, 0x54800000); // cast(1.) + cast(1.) / 3 + const auto kExpMul = Set(di, 0x002AAAAA); // shifted 1/3 + const auto k1_3 = Set(df, 1.0f / 3); + const auto k4_3 = Set(df, 4.0f / 3); + + const auto xa = x; // assume inputs never negative + const auto xa_3 = Mul(k1_3, xa); + + // Multiply exponent by -1/3 + const auto m1 = BitCast(di, xa); + // Special case for 0. 0 is represented with an exponent of 0, so the + // "kExpBias - 1/3 * exp" below gives the wrong result. The IfThenZeroElse() + // sets those values as 0, which prevents having NaNs in the computations + // below. + // TODO(eustas): use fused op + const auto m2 = IfThenZeroElse( + Eq(m1, Zero(di)), Sub(kExpBias, Mul((ShiftRight<23>(m1)), kExpMul))); + auto r = BitCast(df, m2); + + // Newton-Raphson iterations + for (int i = 0; i < 3; i++) { + const auto r2 = Mul(r, r); + r = NegMulAdd(xa_3, Mul(r2, r2), Mul(k4_3, r)); + } + // Final iteration + auto r2 = Mul(r, r); + r = MulAdd(k1_3, NegMulAdd(xa, Mul(r2, r2), r), r); + r2 = Mul(r, r); + r = MulAdd(r2, x, add); + + return r; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_BASE_FAST_MATH_INL_H_ + +#if HWY_ONCE +#ifndef LIB_JXL_BASE_FAST_MATH_ONCE +#define LIB_JXL_BASE_FAST_MATH_ONCE + +namespace jxl { +inline float FastLog2f(float f) { return HWY_STATIC_DISPATCH(FastLog2f)(f); } +inline float FastPow2f(float f) { return HWY_STATIC_DISPATCH(FastPow2f)(f); } +inline float FastPowf(float b, float e) { + return HWY_STATIC_DISPATCH(FastPowf)(b, e); +} +inline float FastCosf(float f) { return HWY_STATIC_DISPATCH(FastCosf)(f); } +inline float FastErff(float f) { return HWY_STATIC_DISPATCH(FastErff)(f); } +} // namespace jxl + +#endif // LIB_JXL_BASE_FAST_MATH_ONCE +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/base/float.h b/third_party/jpeg-xl/lib/jxl/base/float.h new file mode 100644 index 0000000000..00e112bb34 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/float.h @@ -0,0 +1,102 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_FLOAT_H_ +#define LIB_JXL_BASE_FLOAT_H_ + +#include <jxl/types.h> +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +namespace { +// Based on highway scalar implementation, for testing +float LoadFloat16(uint16_t bits16) { + const uint32_t sign = bits16 >> 15; + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + // Subnormal or zero + if (biased_exp == 0) { + const float subnormal = + (1.0f / 16384) * (static_cast<float>(mantissa) * (1.0f / 1024)); + return sign ? -subnormal : subnormal; + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + + float result; + memcpy(&result, &bits32, 4); + return result; +} +} // namespace + +template <typename SaveFloatAtFn> +static Status JXL_INLINE LoadFloatRow(const uint8_t* src, size_t count, + size_t stride, JxlDataType type, + bool little_endian, float scale, + SaveFloatAtFn callback) { + switch (type) { + case JXL_TYPE_FLOAT: + if (little_endian) { + for (size_t i = 0; i < count; ++i) { + callback(i, LoadLEFloat(src + stride * i)); + } + } else { + for (size_t i = 0; i < count; ++i) { + callback(i, LoadBEFloat(src + stride * i)); + } + } + return true; + + case JXL_TYPE_UINT8: + for (size_t i = 0; i < count; ++i) { + // Integer multiply uint8 value before scaling so that the UINT8 value + // and the corresponding UINT16 value convert to the same float + callback(i, (src[stride * i] * 257) * scale); + } + return true; + + case JXL_TYPE_UINT16: + if (little_endian) { + for (size_t i = 0; i < count; ++i) { + callback(i, LoadLE16(src + stride * i) * scale); + } + } else { + for (size_t i = 0; i < count; ++i) { + callback(i, LoadBE16(src + stride * i) * scale); + } + } + return true; + + case JXL_TYPE_FLOAT16: + if (little_endian) { + for (size_t i = 0; i < count; ++i) { + callback(i, LoadFloat16(LoadLE16(src + stride * i))); + } + } else { + for (size_t i = 0; i < count; ++i) { + callback(i, LoadFloat16(LoadBE16(src + stride * i))); + } + } + return true; + + default: + return JXL_FAILURE("Unsupported sample format"); + } +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_FLOAT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/iaca.h b/third_party/jpeg-xl/lib/jxl/base/iaca.h new file mode 100644 index 0000000000..e5732dae5c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/iaca.h @@ -0,0 +1,65 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_IACA_H_ +#define LIB_JXL_BASE_IACA_H_ + +#include "lib/jxl/base/compiler_specific.h" + +// IACA (Intel's Code Analyzer) analyzes instruction latencies, but only for +// code between special markers. These functions embed such markers in an +// executable, but only for reading via IACA - they deliberately trigger a +// crash if executed to ensure they are removed in normal builds. + +#ifndef JXL_IACA_ENABLED +#define JXL_IACA_ENABLED 0 +#endif + +namespace jxl { + +// Call before the region of interest. +static JXL_INLINE void BeginIACA() { +#if JXL_IACA_ENABLED && (JXL_COMPILER_GCC || JXL_COMPILER_CLANG) + asm volatile( + // UD2 "instruction" raises an invalid opcode exception. + ".byte 0x0F, 0x0B\n\t" + // Magic sequence recognized by IACA (MOV + addr32 fs:NOP). This actually + // clobbers EBX, but we don't care because the code won't be run, and we + // want IACA to observe the same code the compiler would have generated + // without this marker. + "movl $111, %%ebx\n\t" + ".byte 0x64, 0x67, 0x90\n\t" + : + : + // (Allegedly) clobbering memory may prevent reordering. + : "memory"); +#endif +} + +// Call after the region of interest. +static JXL_INLINE void EndIACA() { +#if JXL_IACA_ENABLED && (JXL_COMPILER_GCC || JXL_COMPILER_CLANG) + asm volatile( + // See above. + "movl $222, %%ebx\n\t" + ".byte 0x64, 0x67, 0x90\n\t" + // UD2 + ".byte 0x0F, 0x0B\n\t" + : + : + // (Allegedly) clobbering memory may prevent reordering. + : "memory"); +#endif +} + +// Add to a scope to mark a region. +struct ScopeIACA { + JXL_INLINE ScopeIACA() { BeginIACA(); } + JXL_INLINE ~ScopeIACA() { EndIACA(); } +}; + +} // namespace jxl + +#endif // LIB_JXL_BASE_IACA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h b/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h new file mode 100644 index 0000000000..1a969bd4f0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h @@ -0,0 +1,84 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MATRIX_OPS_H_ +#define LIB_JXL_MATRIX_OPS_H_ + +// 3x3 matrix operations. + +#include <cmath> // abs +#include <cstddef> + +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Computes C = A * B, where A, B, C are 3x3 matrices. +template <typename T> +void Mul3x3Matrix(const T* a, const T* b, T* c) { + alignas(16) T temp[3]; // For transposed column + for (size_t x = 0; x < 3; x++) { + for (size_t z = 0; z < 3; z++) { + temp[z] = b[z * 3 + x]; + } + for (size_t y = 0; y < 3; y++) { + double e = 0; + for (size_t z = 0; z < 3; z++) { + e += a[y * 3 + z] * temp[z]; + } + c[y * 3 + x] = e; + } + } +} + +// Computes C = A * B, where A is 3x3 matrix and B is vector. +template <typename T> +void Mul3x3Vector(const T* a, const T* b, T* c) { + for (size_t y = 0; y < 3; y++) { + double e = 0; + for (size_t x = 0; x < 3; x++) { + e += a[y * 3 + x] * b[x]; + } + c[y] = e; + } +} + +// Inverts a 3x3 matrix in place. +template <typename T> +Status Inv3x3Matrix(T* matrix) { + // Intermediate computation is done in double precision. + double temp[9]; + temp[0] = static_cast<double>(matrix[4]) * matrix[8] - + static_cast<double>(matrix[5]) * matrix[7]; + temp[1] = static_cast<double>(matrix[2]) * matrix[7] - + static_cast<double>(matrix[1]) * matrix[8]; + temp[2] = static_cast<double>(matrix[1]) * matrix[5] - + static_cast<double>(matrix[2]) * matrix[4]; + temp[3] = static_cast<double>(matrix[5]) * matrix[6] - + static_cast<double>(matrix[3]) * matrix[8]; + temp[4] = static_cast<double>(matrix[0]) * matrix[8] - + static_cast<double>(matrix[2]) * matrix[6]; + temp[5] = static_cast<double>(matrix[2]) * matrix[3] - + static_cast<double>(matrix[0]) * matrix[5]; + temp[6] = static_cast<double>(matrix[3]) * matrix[7] - + static_cast<double>(matrix[4]) * matrix[6]; + temp[7] = static_cast<double>(matrix[1]) * matrix[6] - + static_cast<double>(matrix[0]) * matrix[7]; + temp[8] = static_cast<double>(matrix[0]) * matrix[4] - + static_cast<double>(matrix[1]) * matrix[3]; + double det = matrix[0] * temp[0] + matrix[1] * temp[3] + matrix[2] * temp[6]; + if (std::abs(det) < 1e-10) { + return JXL_FAILURE("Matrix determinant is too close to 0"); + } + double idet = 1.0 / det; + for (size_t i = 0; i < 9; i++) { + matrix[i] = temp[i] * idet; + } + return true; +} + +} // namespace jxl + +#endif // LIB_JXL_MATRIX_OPS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/os_macros.h b/third_party/jpeg-xl/lib/jxl/base/os_macros.h new file mode 100644 index 0000000000..84d0b82bf5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/os_macros.h @@ -0,0 +1,50 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_OS_MACROS_H_ +#define LIB_JXL_BASE_OS_MACROS_H_ + +// Defines the JXL_OS_* macros. + +#if defined(_WIN32) || defined(_WIN64) +#define JXL_OS_WIN 1 +#else +#define JXL_OS_WIN 0 +#endif + +#ifdef __linux__ +#define JXL_OS_LINUX 1 +#else +#define JXL_OS_LINUX 0 +#endif + +#ifdef __APPLE__ +#define JXL_OS_MAC 1 +#else +#define JXL_OS_MAC 0 +#endif + +#define JXL_OS_IOS 0 +#ifdef __APPLE__ +#include <TargetConditionals.h> +#if TARGET_OS_IPHONE +#undef JXL_OS_IOS +#define JXL_OS_IOS 1 +#endif +#endif + +#ifdef __FreeBSD__ +#define JXL_OS_FREEBSD 1 +#else +#define JXL_OS_FREEBSD 0 +#endif + +#ifdef __HAIKU__ +#define JXL_OS_HAIKU 1 +#else +#define JXL_OS_HAIKU 0 +#endif + +#endif // LIB_JXL_BASE_OS_MACROS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/override.h b/third_party/jpeg-xl/lib/jxl/base/override.h new file mode 100644 index 0000000000..1f8b657974 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/override.h @@ -0,0 +1,29 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_OVERRIDE_H_ +#define LIB_JXL_BASE_OVERRIDE_H_ + +// 'Trool' for command line arguments: force enable/disable, or use default. + +namespace jxl { + +// No effect if kDefault, otherwise forces a feature (typically a FrameHeader +// flag) on or off. +enum class Override : int { kOn = 1, kOff = 0, kDefault = -1 }; + +static inline Override OverrideFromBool(bool flag) { + return flag ? Override::kOn : Override::kOff; +} + +static inline bool ApplyOverride(Override o, bool default_condition) { + if (o == Override::kOn) return true; + if (o == Override::kOff) return false; + return default_condition; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_OVERRIDE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/printf_macros.h b/third_party/jpeg-xl/lib/jxl/base/printf_macros.h new file mode 100644 index 0000000000..3215052afd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/printf_macros.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_PRINTF_MACROS_H_ +#define LIB_JXL_BASE_PRINTF_MACROS_H_ + +// Format string macros. These should be included after any other system +// library since those may unconditionally define these, depending on the +// platform. + +// PRIuS and PRIdS macros to print size_t and ssize_t respectively. +#if !defined(PRIdS) +#if defined(_WIN64) +#define PRIdS "lld" +#elif defined(_WIN32) +#define PRIdS "d" +#else +#define PRIdS "zd" +#endif +#endif // PRIdS + +#if !defined(PRIuS) +#if defined(_WIN64) +#define PRIuS "llu" +#elif defined(_WIN32) +#define PRIuS "u" +#else +#define PRIuS "zu" +#endif +#endif // PRIuS + +#endif // LIB_JXL_BASE_PRINTF_MACROS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/random.h b/third_party/jpeg-xl/lib/jxl/base/random.h new file mode 100644 index 0000000000..b27815bf00 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/random.h @@ -0,0 +1,99 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_RANDOM_ +#define LIB_JXL_BASE_RANDOM_ + +// Random number generator + distributions. +// We don't use <random> because the implementation (and thus results) differs +// between libstdc++ and libc++. + +#include <stdint.h> +#include <string.h> + +#include <algorithm> +#include <cmath> + +#include "lib/jxl/base/status.h" + +namespace jxl { +struct Rng { + explicit Rng(size_t seed) + : s{static_cast<uint64_t>(0x94D049BB133111EBull), + static_cast<uint64_t>(0xBF58476D1CE4E5B9ull) + seed} {} + + // Xorshift128+ adapted from xorshift128+-inl.h + uint64_t operator()() { + uint64_t s1 = s[0]; + const uint64_t s0 = s[1]; + const uint64_t bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return bits; + } + + // Uniformly distributed int64_t in [begin, end), under the assumption that + // `end-begin` is significantly smaller than 1<<64, otherwise there is some + // bias. + int64_t UniformI(int64_t begin, int64_t end) { + JXL_DASSERT(end > begin); + return static_cast<int64_t>((*this)() % + static_cast<uint64_t>(end - begin)) + + begin; + } + + // Same as UniformI, but for uint64_t. + uint64_t UniformU(uint64_t begin, uint64_t end) { + JXL_DASSERT(end > begin); + return (*this)() % (end - begin) + begin; + } + + // Uniformly distributed float in [begin, end) range. Note: only 23 bits of + // randomness. + float UniformF(float begin, float end) { + float f; + // Bits of a random [1, 2) float. + uint32_t u = ((*this)() >> (64 - 23)) | 0x3F800000; + static_assert(sizeof(f) == sizeof(u), + "Float and U32 must have the same size"); + memcpy(&f, &u, sizeof(f)); + // Note: (end-begin) * f + (2*begin-end) may fail to return a number >= + // begin. + return (end - begin) * (f - 1.0f) + begin; + } + + // Bernoulli trial + bool Bernoulli(float p) { return UniformF(0, 1) < p; } + + // State for geometric distributions. + // The stored value is inv_log_1mp + using GeometricDistribution = float; + static GeometricDistribution MakeGeometric(float p) { + return 1.0 / std::log(1 - p); + } + + uint32_t Geometric(const GeometricDistribution& dist) { + float f = UniformF(0, 1); + float inv_log_1mp = dist; + float log = std::log(1 - f) * inv_log_1mp; + return static_cast<uint32_t>(log); + } + + template <typename T> + void Shuffle(T* t, size_t n) { + for (size_t i = 0; i + 1 < n; i++) { + size_t a = UniformU(i, n); + std::swap(t[a], t[i]); + } + } + + private: + uint64_t s[2]; +}; + +} // namespace jxl +#endif // LIB_JXL_BASE_RANDOM_ diff --git a/third_party/jpeg-xl/lib/jxl/base/rational_polynomial-inl.h b/third_party/jpeg-xl/lib/jxl/base/rational_polynomial-inl.h new file mode 100644 index 0000000000..e073937675 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/rational_polynomial-inl.h @@ -0,0 +1,102 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Fast SIMD evaluation of rational polynomials for approximating functions. + +#if defined(LIB_JXL_BASE_RATIONAL_POLYNOMIAL_INL_H_) == \ + defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_BASE_RATIONAL_POLYNOMIAL_INL_H_ +#undef LIB_JXL_BASE_RATIONAL_POLYNOMIAL_INL_H_ +#else +#define LIB_JXL_BASE_RATIONAL_POLYNOMIAL_INL_H_ +#endif + +#include <stddef.h> + +#include <hwy/highway.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::MulAdd; + +// Primary template: default to actual division. +template <typename T, class V> +struct FastDivision { + HWY_INLINE V operator()(const V n, const V d) const { return n / d; } +}; +// Partial specialization for float vectors. +template <class V> +struct FastDivision<float, V> { + // One Newton-Raphson iteration. + static HWY_INLINE V ReciprocalNR(const V x) { + const auto rcp = ApproximateReciprocal(x); + const auto sum = Add(rcp, rcp); + const auto x_rcp = Mul(x, rcp); + return NegMulAdd(x_rcp, rcp, sum); + } + + V operator()(const V n, const V d) const { +#if 1 // Faster on SKX + return Div(n, d); +#else + return n * ReciprocalNR(d); +#endif + } +}; + +// Approximates smooth functions via rational polynomials (i.e. dividing two +// polynomials). Evaluates polynomials via Horner's scheme, which is faster than +// Clenshaw recurrence for Chebyshev polynomials. LoadDup128 allows us to +// specify constants (replicated 4x) independently of the lane count. +template <size_t NP, size_t NQ, class D, class V, typename T> +HWY_INLINE HWY_MAYBE_UNUSED V EvalRationalPolynomial(const D d, const V x, + const T (&p)[NP], + const T (&q)[NQ]) { + constexpr size_t kDegP = NP / 4 - 1; + constexpr size_t kDegQ = NQ / 4 - 1; + auto yp = LoadDup128(d, &p[kDegP * 4]); + auto yq = LoadDup128(d, &q[kDegQ * 4]); + // We use pointer arithmetic to refer to &p[(kDegP - n) * 4] to avoid a + // compiler warning that the index is out of bounds since we are already + // checking that it is not out of bounds with (kDegP >= n) and the access + // will be optimized away. Similarly with q and kDegQ. + HWY_FENCE; + if (kDegP >= 1) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 1) * 4))); + if (kDegQ >= 1) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 1) * 4))); + HWY_FENCE; + if (kDegP >= 2) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 2) * 4))); + if (kDegQ >= 2) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 2) * 4))); + HWY_FENCE; + if (kDegP >= 3) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 3) * 4))); + if (kDegQ >= 3) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 3) * 4))); + HWY_FENCE; + if (kDegP >= 4) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 4) * 4))); + if (kDegQ >= 4) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 4) * 4))); + HWY_FENCE; + if (kDegP >= 5) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 5) * 4))); + if (kDegQ >= 5) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 5) * 4))); + HWY_FENCE; + if (kDegP >= 6) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 6) * 4))); + if (kDegQ >= 6) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 6) * 4))); + HWY_FENCE; + if (kDegP >= 7) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 7) * 4))); + if (kDegQ >= 7) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 7) * 4))); + + static_assert(kDegP < 8, "Polynomial degree is too high"); + static_assert(kDegQ < 8, "Polynomial degree is too high"); + + return FastDivision<T, V>()(yp, yq); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); +#endif // LIB_JXL_BASE_RATIONAL_POLYNOMIAL_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/sanitizer_definitions.h b/third_party/jpeg-xl/lib/jxl/base/sanitizer_definitions.h new file mode 100644 index 0000000000..315f3bd003 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/sanitizer_definitions.h @@ -0,0 +1,44 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef LIB_JXL_BASE_SANITIZER_DEFINITIONS_H_
+#define LIB_JXL_BASE_SANITIZER_DEFINITIONS_H_
+
+#ifdef MEMORY_SANITIZER
+#define JXL_MEMORY_SANITIZER 1
+#elif defined(__has_feature)
+#if __has_feature(memory_sanitizer)
+#define JXL_MEMORY_SANITIZER 1
+#else
+#define JXL_MEMORY_SANITIZER 0
+#endif
+#else
+#define JXL_MEMORY_SANITIZER 0
+#endif
+
+#ifdef ADDRESS_SANITIZER
+#define JXL_ADDRESS_SANITIZER 1
+#elif defined(__has_feature)
+#if __has_feature(address_sanitizer)
+#define JXL_ADDRESS_SANITIZER 1
+#else
+#define JXL_ADDRESS_SANITIZER 0
+#endif
+#else
+#define JXL_ADDRESS_SANITIZER 0
+#endif
+
+#ifdef THREAD_SANITIZER
+#define JXL_THREAD_SANITIZER 1
+#elif defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define JXL_THREAD_SANITIZER 1
+#else
+#define JXL_THREAD_SANITIZER 0
+#endif
+#else
+#define JXL_THREAD_SANITIZER 0
+#endif
+#endif // LIB_JXL_BASE_SANITIZER_DEFINITIONS_H
diff --git a/third_party/jpeg-xl/lib/jxl/base/scope_guard.h b/third_party/jpeg-xl/lib/jxl/base/scope_guard.h new file mode 100644 index 0000000000..a18a44cb79 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/scope_guard.h @@ -0,0 +1,48 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_SCOPE_GUARD_H_ +#define LIB_JXL_BASE_SCOPE_GUARD_H_ + +#include <utility> + +namespace jxl { + +template <typename Callback> +class ScopeGuard { + public: + // Discourage unnecessary moves / copies. + ScopeGuard(const ScopeGuard &) = delete; + ScopeGuard &operator=(const ScopeGuard &) = delete; + ScopeGuard &operator=(ScopeGuard &&) = delete; + + // Pre-C++17 does not guarantee RVO -> require move constructor. + ScopeGuard(ScopeGuard &&other) : callback_(std::move(other.callback_)) { + other.armed_ = false; + } + + template <typename CallbackParam> + explicit ScopeGuard(CallbackParam &&callback) + : callback_(std::forward<CallbackParam>(callback)), armed_(true) {} + + ~ScopeGuard() { + if (armed_) callback_(); + } + + void Disarm() { armed_ = false; } + + private: + Callback callback_; + bool armed_; +}; + +template <typename Callback> +ScopeGuard<Callback> MakeScopeGuard(Callback &&callback) { + return ScopeGuard<Callback>{std::forward<Callback>(callback)}; +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_SCOPE_GUARD_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/span.h b/third_party/jpeg-xl/lib/jxl/base/span.h new file mode 100644 index 0000000000..dc1c781b9d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/span.h @@ -0,0 +1,80 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_SPAN_H_ +#define LIB_JXL_BASE_SPAN_H_ + +// Span (array view) is a non-owning container that provides cheap "cut" +// operations and could be used as "ArrayLike" data source for PaddedBytes. + +#include <cstddef> +#include <cstdint> +#include <vector> + +#include "lib/jxl/base/status.h" + +namespace jxl { + +template <typename T> +class Span { + public: + constexpr Span() noexcept : Span(nullptr, 0) {} + + constexpr Span(T* array, size_t length) noexcept + : ptr_(array), len_(length) {} + + template <size_t N> + explicit constexpr Span(T (&a)[N]) noexcept : Span(a, N) {} + + template <typename U> + constexpr Span(U* array, size_t length) noexcept + : ptr_(reinterpret_cast<T*>(array)), len_(length) { + static_assert(sizeof(U) == sizeof(T), "Incompatible type of source."); + } + + template <typename ArrayLike> + explicit constexpr Span(const ArrayLike& other) noexcept + : Span(reinterpret_cast<T*>(other.data()), other.size()) { + static_assert(sizeof(*other.data()) == sizeof(T), + "Incompatible type of source."); + } + + constexpr T* data() const noexcept { return ptr_; } + + constexpr size_t size() const noexcept { return len_; } + + constexpr bool empty() const noexcept { return len_ == 0; } + + constexpr T* begin() const noexcept { return data(); } + + constexpr T* end() const noexcept { return data() + size(); } + + constexpr T& operator[](size_t i) const noexcept { + // MSVC 2015 accepts this as constexpr, but not ptr_[i] + return *(data() + i); + } + + void remove_prefix(size_t n) noexcept { + JXL_ASSERT(size() >= n); + ptr_ += n; + len_ -= n; + } + + // NCT == non-const-T; compiler will complain if NCT is not compatible with T. + template <typename NCT> + void AppendTo(std::vector<NCT>* dst) const { + dst->insert(dst->end(), begin(), end()); + } + + private: + T* ptr_; + size_t len_; +}; + +typedef Span<const uint8_t> Bytes; + +} // namespace jxl + +#endif // LIB_JXL_BASE_SPAN_H_ diff --git a/third_party/jpeg-xl/lib/jxl/base/status.h b/third_party/jpeg-xl/lib/jxl/base/status.h new file mode 100644 index 0000000000..b33bd64fc3 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/base/status.h @@ -0,0 +1,456 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_STATUS_H_ +#define LIB_JXL_BASE_STATUS_H_ + +// Error handling: Status return type + helper macros. + +#include <stdarg.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> + +#include <type_traits> +#include <utility> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/sanitizer_definitions.h" + +#if JXL_ADDRESS_SANITIZER || JXL_MEMORY_SANITIZER || JXL_THREAD_SANITIZER +#include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace +#endif // defined(*_SANITIZER) + +namespace jxl { + +// Uncomment to abort when JXL_FAILURE or JXL_STATUS with a fatal error is +// reached: +// #define JXL_CRASH_ON_ERROR + +#ifndef JXL_ENABLE_ASSERT +#define JXL_ENABLE_ASSERT 1 +#endif + +#ifndef JXL_ENABLE_CHECK +#define JXL_ENABLE_CHECK 1 +#endif + +// Pass -DJXL_DEBUG_ON_ERROR at compile time to print debug messages when a +// function returns JXL_FAILURE or calls JXL_NOTIFY_ERROR. Note that this is +// irrelevant if you also pass -DJXL_CRASH_ON_ERROR. +#if defined(JXL_DEBUG_ON_ERROR) || defined(JXL_CRASH_ON_ERROR) +#undef JXL_DEBUG_ON_ERROR +#define JXL_DEBUG_ON_ERROR 1 +#else // JXL_DEBUG_ON_ERROR || JXL_CRASH_ON_ERROR +#ifdef NDEBUG +#define JXL_DEBUG_ON_ERROR 0 +#else // NDEBUG +#define JXL_DEBUG_ON_ERROR 1 +#endif // NDEBUG +#endif // JXL_DEBUG_ON_ERROR || JXL_CRASH_ON_ERROR + +// Pass -DJXL_DEBUG_ON_ALL_ERROR at compile time to print debug messages on +// all error (fatal and non-fatal) status. This implies JXL_DEBUG_ON_ERROR. +#if defined(JXL_DEBUG_ON_ALL_ERROR) +#undef JXL_DEBUG_ON_ALL_ERROR +#define JXL_DEBUG_ON_ALL_ERROR 1 +// JXL_DEBUG_ON_ALL_ERROR implies JXL_DEBUG_ON_ERROR too. +#undef JXL_DEBUG_ON_ERROR +#define JXL_DEBUG_ON_ERROR 1 +#else // JXL_DEBUG_ON_ALL_ERROR +#define JXL_DEBUG_ON_ALL_ERROR 0 +#endif // JXL_DEBUG_ON_ALL_ERROR + +// The Verbose level for the library +#ifndef JXL_DEBUG_V_LEVEL +#define JXL_DEBUG_V_LEVEL 0 +#endif // JXL_DEBUG_V_LEVEL + +// Pass -DJXL_DEBUG_ON_ABORT={0,1} to force disable/enable the debug messages on +// JXL_ASSERT, JXL_CHECK and JXL_ABORT. +#ifndef JXL_DEBUG_ON_ABORT +#define JXL_DEBUG_ON_ABORT JXL_DEBUG_ON_ERROR +#endif // JXL_DEBUG_ON_ABORT + +#ifdef USE_ANDROID_LOGGER +#include <android/log.h> +#define LIBJXL_ANDROID_LOG_TAG ("libjxl") +inline void android_vprintf(const char* format, va_list args) { + char* message = nullptr; + int res = vasprintf(&message, format, args); + if (res != -1) { + __android_log_write(ANDROID_LOG_DEBUG, LIBJXL_ANDROID_LOG_TAG, message); + free(message); + } +} +#endif + +// Print a debug message on standard error or android logs. You should use the +// JXL_DEBUG macro instead of calling Debug directly. This function returns +// false, so it can be used as a return value in JXL_FAILURE. +JXL_FORMAT(1, 2) +inline JXL_NOINLINE bool Debug(const char* format, ...) { + va_list args; + va_start(args, format); +#ifdef USE_ANDROID_LOGGER + android_vprintf(format, args); +#else + vfprintf(stderr, format, args); +#endif + va_end(args); + return false; +} + +// Print a debug message on standard error if "enabled" is true. "enabled" is +// normally a macro that evaluates to 0 or 1 at compile time, so the Debug +// function is never called and optimized out in release builds. Note that the +// arguments are compiled but not evaluated when enabled is false. The format +// string must be a explicit string in the call, for example: +// JXL_DEBUG(JXL_DEBUG_MYMODULE, "my module message: %d", some_var); +// Add a header at the top of your module's .cc or .h file (depending on whether +// you have JXL_DEBUG calls from the .h as well) like this: +// #ifndef JXL_DEBUG_MYMODULE +// #define JXL_DEBUG_MYMODULE 0 +// #endif JXL_DEBUG_MYMODULE +#define JXL_DEBUG_TMP(format, ...) \ + ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__) + +#define JXL_DEBUG(enabled, format, ...) \ + do { \ + if (enabled) { \ + JXL_DEBUG_TMP(format, ##__VA_ARGS__); \ + } \ + } while (0) + +// JXL_DEBUG version that prints the debug message if the global verbose level +// defined at compile time by JXL_DEBUG_V_LEVEL is greater or equal than the +// passed level. +#if JXL_DEBUG_V_LEVEL > 0 +#define JXL_DEBUG_V(level, format, ...) \ + JXL_DEBUG(level <= JXL_DEBUG_V_LEVEL, format, ##__VA_ARGS__) +#else +#define JXL_DEBUG_V(level, format, ...) +#endif + +// Warnings (via JXL_WARNING) are enabled by default in debug builds (opt and +// debug). +#ifdef JXL_DEBUG_WARNING +#undef JXL_DEBUG_WARNING +#define JXL_DEBUG_WARNING 1 +#else // JXL_DEBUG_WARNING +#ifdef NDEBUG +#define JXL_DEBUG_WARNING 0 +#else // JXL_DEBUG_WARNING +#define JXL_DEBUG_WARNING 1 +#endif // NDEBUG +#endif // JXL_DEBUG_WARNING +#define JXL_WARNING(format, ...) \ + JXL_DEBUG(JXL_DEBUG_WARNING, format, ##__VA_ARGS__) + +// Exits the program after printing a stack trace when possible. +JXL_NORETURN inline JXL_NOINLINE bool Abort() { +#if JXL_ADDRESS_SANITIZER || JXL_MEMORY_SANITIZER || JXL_THREAD_SANITIZER + // If compiled with any sanitizer print a stack trace. This call doesn't crash + // the program, instead the trap below will crash it also allowing gdb to + // break there. + __sanitizer_print_stack_trace(); +#endif // *_SANITIZER) + +#if JXL_COMPILER_MSVC + __debugbreak(); + abort(); +#else + __builtin_trap(); +#endif +} + +// Exits the program after printing file/line plus a formatted string. +#define JXL_ABORT(format, ...) \ + ((JXL_DEBUG_ON_ABORT) && ::jxl::Debug(("%s:%d: JXL_ABORT: " format "\n"), \ + __FILE__, __LINE__, ##__VA_ARGS__), \ + ::jxl::Abort()) + +// Use this for code paths that are unreachable unless the code would change +// to make it reachable, in which case it will print a warning and abort in +// debug builds. In release builds no code is produced for this, so only use +// this if this path is really unreachable. +#define JXL_UNREACHABLE(format, ...) \ + do { \ + if (JXL_DEBUG_WARNING) { \ + ::jxl::Debug(("%s:%d: JXL_UNREACHABLE: " format "\n"), __FILE__, \ + __LINE__, ##__VA_ARGS__); \ + ::jxl::Abort(); \ + } else { \ + JXL_UNREACHABLE_BUILTIN; \ + } \ + } while (0) + +// Does not guarantee running the code, use only for debug mode checks. +#if JXL_ENABLE_ASSERT +#define JXL_ASSERT(condition) \ + do { \ + if (!(condition)) { \ + JXL_DEBUG(JXL_DEBUG_ON_ABORT, "JXL_ASSERT: %s", #condition); \ + ::jxl::Abort(); \ + } \ + } while (0) +#else +#define JXL_ASSERT(condition) \ + do { \ + } while (0) +#endif + +// Define JXL_IS_DEBUG_BUILD that denotes asan, msan and other debug builds, +// but not opt or release. +#ifndef JXL_IS_DEBUG_BUILD +#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) || \ + defined(MEMORY_SANITIZER) || defined(THREAD_SANITIZER) || \ + defined(__clang_analyzer__) +#define JXL_IS_DEBUG_BUILD 1 +#else +#define JXL_IS_DEBUG_BUILD 0 +#endif +#endif // JXL_IS_DEBUG_BUILD + +// Same as above, but only runs in debug builds (builds where NDEBUG is not +// defined). This is useful for slower asserts that we want to run more rarely +// than usual. These will run on asan, msan and other debug builds, but not in +// opt or release. +#if JXL_IS_DEBUG_BUILD +#define JXL_DASSERT(condition) \ + do { \ + if (!(condition)) { \ + JXL_DEBUG(JXL_DEBUG_ON_ABORT, "JXL_DASSERT: %s", #condition); \ + ::jxl::Abort(); \ + } \ + } while (0) +#else +#define JXL_DASSERT(condition) \ + do { \ + } while (0) +#endif + +// Always runs the condition, so can be used for non-debug calls. +#if JXL_ENABLE_CHECK +#define JXL_CHECK(condition) \ + do { \ + if (!(condition)) { \ + JXL_DEBUG(JXL_DEBUG_ON_ABORT, "JXL_CHECK: %s", #condition); \ + ::jxl::Abort(); \ + } \ + } while (0) +#else +#define JXL_CHECK(condition) \ + do { \ + (void)(condition); \ + } while (0) +#endif + +// A jxl::Status value from a StatusCode or Status which prints a debug message +// when enabled. +#define JXL_STATUS(status, format, ...) \ + ::jxl::StatusMessage(::jxl::Status(status), "%s:%d: " format "\n", __FILE__, \ + __LINE__, ##__VA_ARGS__) + +// Notify of an error but discard the resulting Status value. This is only +// useful for debug builds or when building with JXL_CRASH_ON_ERROR. +#define JXL_NOTIFY_ERROR(format, ...) \ + (void)JXL_STATUS(::jxl::StatusCode::kGenericError, "JXL_ERROR: " format, \ + ##__VA_ARGS__) + +// An error Status with a message. The JXL_STATUS() macro will return a Status +// object with a kGenericError code, but the comma operator helps with +// clang-tidy inference and potentially with optimizations. +#define JXL_FAILURE(format, ...) \ + ((void)JXL_STATUS(::jxl::StatusCode::kGenericError, "JXL_FAILURE: " format, \ + ##__VA_ARGS__), \ + ::jxl::Status(::jxl::StatusCode::kGenericError)) + +// Always evaluates the status exactly once, so can be used for non-debug calls. +// Returns from the current context if the passed Status expression is an error +// (fatal or non-fatal). The return value is the passed Status. +#define JXL_RETURN_IF_ERROR(status) \ + do { \ + ::jxl::Status jxl_return_if_error_status = (status); \ + if (!jxl_return_if_error_status) { \ + (void)::jxl::StatusMessage( \ + jxl_return_if_error_status, \ + "%s:%d: JXL_RETURN_IF_ERROR code=%d: %s\n", __FILE__, __LINE__, \ + static_cast<int>(jxl_return_if_error_status.code()), #status); \ + return jxl_return_if_error_status; \ + } \ + } while (0) + +// As above, but without calling StatusMessage. Intended for bundles (see +// fields.h), which have numerous call sites (-> relevant for code size) and do +// not want to generate excessive messages when decoding partial headers. +#define JXL_QUIET_RETURN_IF_ERROR(status) \ + do { \ + ::jxl::Status jxl_return_if_error_status = (status); \ + if (!jxl_return_if_error_status) { \ + return jxl_return_if_error_status; \ + } \ + } while (0) + +enum class StatusCode : int32_t { + // Non-fatal errors (negative values). + kNotEnoughBytes = -1, + + // The only non-error status code. + kOk = 0, + + // Fatal-errors (positive values) + kGenericError = 1, +}; + +// Drop-in replacement for bool that raises compiler warnings if not used +// after being returned from a function. Example: +// Status LoadFile(...) { return true; } is more compact than +// bool JXL_MUST_USE_RESULT LoadFile(...) { return true; } +// In case of error, the status can carry an extra error code in its value which +// is split between fatal and non-fatal error codes. +class JXL_MUST_USE_RESULT Status { + public: + // We want implicit constructor from bool to allow returning "true" or "false" + // on a function when using Status. "true" means kOk while "false" means a + // generic fatal error. + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Status(bool ok) + : code_(ok ? StatusCode::kOk : StatusCode::kGenericError) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr Status(StatusCode code) : code_(code) {} + + // We also want implicit cast to bool to check for return values of functions. + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr operator bool() const { return code_ == StatusCode::kOk; } + + constexpr StatusCode code() const { return code_; } + + // Returns whether the status code is a fatal error. + constexpr bool IsFatalError() const { + return static_cast<int32_t>(code_) > 0; + } + + private: + StatusCode code_; +}; + +static constexpr Status OkStatus() { return Status(StatusCode::kOk); } + +// Helper function to create a Status and print the debug message or abort when +// needed. +inline JXL_FORMAT(2, 3) Status + StatusMessage(const Status status, const char* format, ...) { + // This block will be optimized out when JXL_DEBUG_ON_ERROR and + // JXL_DEBUG_ON_ALL_ERROR are both disabled. + if ((JXL_DEBUG_ON_ERROR && status.IsFatalError()) || + (JXL_DEBUG_ON_ALL_ERROR && !status)) { + va_list args; + va_start(args, format); +#ifdef USE_ANDROID_LOGGER + android_vprintf(format, args); +#else + vfprintf(stderr, format, args); +#endif + va_end(args); + } +#ifdef JXL_CRASH_ON_ERROR + // JXL_CRASH_ON_ERROR means to Abort() only on non-fatal errors. + if (status.IsFatalError()) { + Abort(); + } +#endif // JXL_CRASH_ON_ERROR + return status; +} + +template <typename T> +class JXL_MUST_USE_RESULT StatusOr { + static_assert(!std::is_convertible<StatusCode, T>::value && + !std::is_convertible<T, StatusCode>::value, + "You cannot make a StatusOr with a type convertible from or to " + "StatusCode"); + static_assert(std::is_move_constructible<T>::value && + std::is_move_assignable<T>::value, + "T must be move constructible and move assignable"); + + public: + // NOLINTNEXTLINE(google-explicit-constructor) + StatusOr(StatusCode code) : code_(code) { + JXL_ASSERT(code_ != StatusCode::kOk); + } + + // NOLINTNEXTLINE(google-explicit-constructor) + StatusOr(Status status) : StatusOr(status.code()) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + StatusOr(T&& value) : code_(StatusCode::kOk) { + new (&storage_.data_) T(std::move(value)); + } + + StatusOr(StatusOr&& other) noexcept { + if (other.ok()) { + new (&storage_.data_) T(std::move(other.storage_.data_)); + } + code_ = other.code_; + } + + StatusOr& operator=(StatusOr&& other) noexcept { + if (this == &other) return *this; + if (ok() && other.ok()) { + storage_.data_ = std::move(other.storage_.data_); + } else if (other.ok()) { + new (&storage_.data_) T(std::move(other.storage_.data_)); + } else if (ok()) { + storage_.data_.~T(); + } + code_ = other.code_; + return *this; + } + + StatusOr(const StatusOr&) = delete; + StatusOr operator=(const StatusOr&) = delete; + + bool ok() const { return code_ == StatusCode::kOk; } + Status status() const { return code_; } + + // Only call this if you are absolutely sure that `ok()` is true. + // Ideally, never call this manually and rely on JXL_ASSIGN_OR_RETURN. + T value() && { + JXL_ASSERT(ok()); + return std::move(storage_.data_); + } + + ~StatusOr() { + if (code_ == StatusCode::kOk) { + storage_.data_.~T(); + } + } + + private: + union Storage { + char placeholder_; + T data_; + Storage() {} + ~Storage() {} + } storage_; + + StatusCode code_; +}; + +#define JXL_ASSIGN_OR_RETURN(lhs, statusor) \ + PRIVATE_JXL_ASSIGN_OR_RETURN_IMPL( \ + assign_or_return_temporary_variable##__LINE__, lhs, statusor) + +// NOLINTBEGIN(bugprone-macro-parentheses) +#define PRIVATE_JXL_ASSIGN_OR_RETURN_IMPL(name, lhs, statusor) \ + auto name = statusor; \ + JXL_RETURN_IF_ERROR(name.status()); \ + lhs = std::move(name).value(); +// NOLINTEND(bugprone-macro-parentheses) + +} // namespace jxl + +#endif // LIB_JXL_BASE_STATUS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/bit_reader_test.cc b/third_party/jpeg-xl/lib/jxl/bit_reader_test.cc new file mode 100644 index 0000000000..22a20649e0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/bit_reader_test.cc @@ -0,0 +1,262 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <vector> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +TEST(BitReaderTest, ExtendsWithZeroes) { + for (size_t size = 4; size < 32; ++size) { + std::vector<uint8_t> data(size, 0xff); + + for (size_t n_bytes = 0; n_bytes < size; n_bytes++) { + BitReader br(Bytes(data.data(), n_bytes)); + // Read all the bits + for (size_t i = 0; i < n_bytes * kBitsPerByte; i++) { + ASSERT_EQ(br.ReadBits(1), 1u) << "n_bytes=" << n_bytes << " i=" << i; + } + + // PEEK more than the declared size - all will be zero. Cannot consume. + for (size_t i = 0; i < BitReader::kMaxBitsPerCall; i++) { + ASSERT_EQ(br.PeekBits(i), 0u) + << "size=" << size << "n_bytes=" << n_bytes << " i=" << i; + } + + EXPECT_TRUE(br.Close()); + } + } +} + +struct Symbol { + uint32_t num_bits; + uint32_t value; +}; + +// Reading from output gives the same values. +TEST(BitReaderTest, TestRoundTrip) { + test::ThreadPoolForTests pool(8); + EXPECT_TRUE(RunOnPool( + &pool, 0, 1000, ThreadPool::NoInit, + [](const uint32_t task, size_t /* thread */) { + constexpr size_t kMaxBits = 8000; + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + + std::vector<Symbol> symbols; + symbols.reserve(1000); + + Rng rng(55537 + 129 * task); + + for (;;) { + const uint32_t num_bits = rng.UniformU(1, 33); + if (writer.BitsWritten() + num_bits > kMaxBits) break; + const uint32_t value = rng.UniformU(0, 1ULL << num_bits); + symbols.push_back({num_bits, value}); + writer.Write(num_bits, value); + } + + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + BitReader reader(writer.GetSpan()); + for (const Symbol& s : symbols) { + EXPECT_EQ(s.value, reader.ReadBits(s.num_bits)); + } + EXPECT_TRUE(reader.Close()); + }, + "TestTBitReaderRoundTrip")); +} + +// SkipBits is the same as reading that many bits. +TEST(BitReaderTest, TestSkip) { + test::ThreadPoolForTests pool(8); + EXPECT_TRUE(RunOnPool( + &pool, 0, 96, ThreadPool::NoInit, + [](const uint32_t task, size_t /* thread */) { + constexpr size_t kSize = 100; + + for (size_t skip = 0; skip < 128; ++skip) { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kSize * kBitsPerByte); + // Start with "task" 1-bits. + for (size_t i = 0; i < task; ++i) { + writer.Write(1, 1); + } + + // Write 0-bits that we will skip over + for (size_t i = 0; i < skip; ++i) { + writer.Write(1, 0); + } + + // Write terminator bits '101' + writer.Write(3, 5); + EXPECT_EQ(task + skip + 3, writer.BitsWritten()); + writer.ZeroPadToByte(); + AuxOut aux_out; + allotment.ReclaimAndCharge(&writer, 0, &aux_out); + EXPECT_LT(aux_out.layers[0].total_bits, kSize * 8); + + BitReader reader1(writer.GetSpan()); + BitReader reader2(writer.GetSpan()); + // Verify initial 1-bits + for (size_t i = 0; i < task; ++i) { + EXPECT_EQ(1u, reader1.ReadBits(1)); + EXPECT_EQ(1u, reader2.ReadBits(1)); + } + + // SkipBits or manually read "skip" bits + reader1.SkipBits(skip); + for (size_t i = 0; i < skip; ++i) { + EXPECT_EQ(0u, reader2.ReadBits(1)) + << " skip=" << skip << " i=" << i; + } + EXPECT_EQ(reader1.TotalBitsConsumed(), reader2.TotalBitsConsumed()); + + // Ensure both readers see the terminator bits. + EXPECT_EQ(5u, reader1.ReadBits(3)); + EXPECT_EQ(5u, reader2.ReadBits(3)); + + EXPECT_TRUE(reader1.Close()); + EXPECT_TRUE(reader2.Close()); + } + }, + "TestSkip")); +} + +// Verifies byte order and different groupings of bits. +TEST(BitReaderTest, TestOrder) { + constexpr size_t kMaxBits = 16; + + // u(1) - bits written into LSBs of first byte + { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + for (size_t i = 0; i < 5; ++i) { + writer.Write(1, 1); + } + for (size_t i = 0; i < 5; ++i) { + writer.Write(1, 0); + } + for (size_t i = 0; i < 6; ++i) { + writer.Write(1, 1); + } + + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + BitReader reader(writer.GetSpan()); + EXPECT_EQ(0x1Fu, reader.ReadFixedBits<8>()); + EXPECT_EQ(0xFCu, reader.ReadFixedBits<8>()); + EXPECT_TRUE(reader.Close()); + } + + // u(8) - get bytes in the same order + { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + writer.Write(8, 0xF8); + writer.Write(8, 0x3F); + + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + BitReader reader(writer.GetSpan()); + EXPECT_EQ(0xF8u, reader.ReadFixedBits<8>()); + EXPECT_EQ(0x3Fu, reader.ReadFixedBits<8>()); + EXPECT_TRUE(reader.Close()); + } + + // u(16) - little-endian bytes + { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + writer.Write(16, 0xF83F); + + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + BitReader reader(writer.GetSpan()); + EXPECT_EQ(0x3Fu, reader.ReadFixedBits<8>()); + EXPECT_EQ(0xF8u, reader.ReadFixedBits<8>()); + EXPECT_TRUE(reader.Close()); + } + + // Non-byte-aligned, mixed sizes + { + BitWriter writer; + BitWriter::Allotment allotment(&writer, kMaxBits); + writer.Write(1, 1); + writer.Write(3, 6); + writer.Write(8, 0xDB); + writer.Write(4, 8); + + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + BitReader reader(writer.GetSpan()); + EXPECT_EQ(0xBDu, reader.ReadFixedBits<8>()); + EXPECT_EQ(0x8Du, reader.ReadFixedBits<8>()); + EXPECT_TRUE(reader.Close()); + } +} + +TEST(BitReaderTest, TotalCountersTest) { + uint8_t buf[8] = {1, 2, 3, 4}; + BitReader reader(Bytes(buf, sizeof(buf))); + + EXPECT_EQ(sizeof(buf), reader.TotalBytes()); + EXPECT_EQ(0u, reader.TotalBitsConsumed()); + reader.ReadFixedBits<1>(); + EXPECT_EQ(1u, reader.TotalBitsConsumed()); + + reader.ReadFixedBits<10>(); + EXPECT_EQ(11u, reader.TotalBitsConsumed()); + + reader.ReadFixedBits<4>(); + EXPECT_EQ(15u, reader.TotalBitsConsumed()); + + reader.ReadFixedBits<1>(); + EXPECT_EQ(16u, reader.TotalBitsConsumed()); + + reader.ReadFixedBits<16>(); + EXPECT_EQ(32u, reader.TotalBitsConsumed()); + + EXPECT_TRUE(reader.Close()); +} + +TEST(BitReaderTest, MoveTest) { + uint8_t buf[8] = {1, 2, 3, 4}; + BitReader reader2; + { + BitReader reader1(Bytes(buf, sizeof(buf))); + + EXPECT_EQ(0u, reader1.TotalBitsConsumed()); + reader1.ReadFixedBits<16>(); + EXPECT_EQ(16u, reader1.TotalBitsConsumed()); + + reader2 = std::move(reader1); + // From this point reader1 is invalid, but can continue to access reader2 + // and we don't need to call Close() on reader1. + } + + EXPECT_EQ(16u, reader2.TotalBitsConsumed()); + EXPECT_EQ(3U, reader2.ReadFixedBits<8>()); + EXPECT_EQ(24u, reader2.TotalBitsConsumed()); + + EXPECT_TRUE(reader2.Close()); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/bits_test.cc b/third_party/jpeg-xl/lib/jxl/bits_test.cc new file mode 100644 index 0000000000..bd7aa548c8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/bits_test.cc @@ -0,0 +1,87 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/bits.h" + +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +TEST(BitsTest, TestNumZeroBits) { + // Zero input is well-defined. + EXPECT_EQ(32u, Num0BitsAboveMS1Bit(0u)); + EXPECT_EQ(64u, Num0BitsAboveMS1Bit(0ull)); + EXPECT_EQ(32u, Num0BitsBelowLS1Bit(0u)); + EXPECT_EQ(64u, Num0BitsBelowLS1Bit(0ull)); + + EXPECT_EQ(31u, Num0BitsAboveMS1Bit(1u)); + EXPECT_EQ(30u, Num0BitsAboveMS1Bit(2u)); + EXPECT_EQ(63u, Num0BitsAboveMS1Bit(1ull)); + EXPECT_EQ(62u, Num0BitsAboveMS1Bit(2ull)); + + EXPECT_EQ(0u, Num0BitsBelowLS1Bit(1u)); + EXPECT_EQ(0u, Num0BitsBelowLS1Bit(1ull)); + EXPECT_EQ(1u, Num0BitsBelowLS1Bit(2u)); + EXPECT_EQ(1u, Num0BitsBelowLS1Bit(2ull)); + + EXPECT_EQ(0u, Num0BitsAboveMS1Bit(0x80000000u)); + EXPECT_EQ(0u, Num0BitsAboveMS1Bit(0x8000000000000000ull)); + EXPECT_EQ(31u, Num0BitsBelowLS1Bit(0x80000000u)); + EXPECT_EQ(63u, Num0BitsBelowLS1Bit(0x8000000000000000ull)); +} + +TEST(BitsTest, TestFloorLog2) { + // for input = [1, 7] + const size_t expected[7] = {0, 1, 1, 2, 2, 2, 2}; + for (uint32_t i = 1; i <= 7; ++i) { + EXPECT_EQ(expected[i - 1], FloorLog2Nonzero(i)) << " " << i; + EXPECT_EQ(expected[i - 1], FloorLog2Nonzero(uint64_t(i))) << " " << i; + } + + EXPECT_EQ(11u, FloorLog2Nonzero(0x00000fffu)); // 4095 + EXPECT_EQ(12u, FloorLog2Nonzero(0x00001000u)); // 4096 + EXPECT_EQ(12u, FloorLog2Nonzero(0x00001001u)); // 4097 + + EXPECT_EQ(31u, FloorLog2Nonzero(0x80000000u)); + EXPECT_EQ(31u, FloorLog2Nonzero(0x80000001u)); + EXPECT_EQ(31u, FloorLog2Nonzero(0xFFFFFFFFu)); + + EXPECT_EQ(31u, FloorLog2Nonzero(0x80000000ull)); + EXPECT_EQ(31u, FloorLog2Nonzero(0x80000001ull)); + EXPECT_EQ(31u, FloorLog2Nonzero(0xFFFFFFFFull)); + + EXPECT_EQ(63u, FloorLog2Nonzero(0x8000000000000000ull)); + EXPECT_EQ(63u, FloorLog2Nonzero(0x8000000000000001ull)); + EXPECT_EQ(63u, FloorLog2Nonzero(0xFFFFFFFFFFFFFFFFull)); +} + +TEST(BitsTest, TestCeilLog2) { + // for input = [1, 7] + const size_t expected[7] = {0, 1, 2, 2, 3, 3, 3}; + for (uint32_t i = 1; i <= 7; ++i) { + EXPECT_EQ(expected[i - 1], CeilLog2Nonzero(i)) << " " << i; + EXPECT_EQ(expected[i - 1], CeilLog2Nonzero(uint64_t(i))) << " " << i; + } + + EXPECT_EQ(12u, CeilLog2Nonzero(0x00000fffu)); // 4095 + EXPECT_EQ(12u, CeilLog2Nonzero(0x00001000u)); // 4096 + EXPECT_EQ(13u, CeilLog2Nonzero(0x00001001u)); // 4097 + + EXPECT_EQ(31u, CeilLog2Nonzero(0x80000000u)); + EXPECT_EQ(32u, CeilLog2Nonzero(0x80000001u)); + EXPECT_EQ(32u, CeilLog2Nonzero(0xFFFFFFFFu)); + + EXPECT_EQ(31u, CeilLog2Nonzero(0x80000000ull)); + EXPECT_EQ(32u, CeilLog2Nonzero(0x80000001ull)); + EXPECT_EQ(32u, CeilLog2Nonzero(0xFFFFFFFFull)); + + EXPECT_EQ(63u, CeilLog2Nonzero(0x8000000000000000ull)); + EXPECT_EQ(64u, CeilLog2Nonzero(0x8000000000000001ull)); + EXPECT_EQ(64u, CeilLog2Nonzero(0xFFFFFFFFFFFFFFFFull)); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/blending.cc b/third_party/jpeg-xl/lib/jxl/blending.cc new file mode 100644 index 0000000000..ccb168ee45 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/blending.cc @@ -0,0 +1,151 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/blending.h" + +#include "lib/jxl/alpha.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +bool NeedsBlending(const FrameHeader& frame_header) { + if (!(frame_header.frame_type == FrameType::kRegularFrame || + frame_header.frame_type == FrameType::kSkipProgressive)) { + return false; + } + const auto& info = frame_header.blending_info; + bool replace_all = (info.mode == BlendMode::kReplace); + for (const auto& ec_i : frame_header.extra_channel_blending_info) { + if (ec_i.mode != BlendMode::kReplace) { + replace_all = false; + } + } + // Replace the full frame: nothing to do. + if (!frame_header.custom_size_or_origin && replace_all) { + return false; + } + return true; +} + +void PerformBlending(const float* const* bg, const float* const* fg, + float* const* out, size_t x0, size_t xsize, + const PatchBlending& color_blending, + const PatchBlending* ec_blending, + const std::vector<ExtraChannelInfo>& extra_channel_info) { + bool has_alpha = false; + size_t num_ec = extra_channel_info.size(); + for (size_t i = 0; i < num_ec; i++) { + if (extra_channel_info[i].type == jxl::ExtraChannel::kAlpha) { + has_alpha = true; + break; + } + } + ImageF tmp(xsize, 3 + num_ec); + // Blend extra channels first so that we use the pre-blending alpha. + for (size_t i = 0; i < num_ec; i++) { + if (ec_blending[i].mode == PatchBlendMode::kAdd) { + for (size_t x = 0; x < xsize; x++) { + tmp.Row(3 + i)[x] = bg[3 + i][x + x0] + fg[3 + i][x + x0]; + } + } else if (ec_blending[i].mode == PatchBlendMode::kBlendAbove) { + size_t alpha = ec_blending[i].alpha_channel; + bool is_premultiplied = extra_channel_info[alpha].alpha_associated; + PerformAlphaBlending(bg[3 + i] + x0, bg[3 + alpha] + x0, fg[3 + i] + x0, + fg[3 + alpha] + x0, tmp.Row(3 + i), xsize, + is_premultiplied, ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kBlendBelow) { + size_t alpha = ec_blending[i].alpha_channel; + bool is_premultiplied = extra_channel_info[alpha].alpha_associated; + PerformAlphaBlending(fg[3 + i] + x0, fg[3 + alpha] + x0, bg[3 + i] + x0, + bg[3 + alpha] + x0, tmp.Row(3 + i), xsize, + is_premultiplied, ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kAlphaWeightedAddAbove) { + size_t alpha = ec_blending[i].alpha_channel; + PerformAlphaWeightedAdd(bg[3 + i] + x0, fg[3 + i] + x0, + fg[3 + alpha] + x0, tmp.Row(3 + i), xsize, + ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kAlphaWeightedAddBelow) { + size_t alpha = ec_blending[i].alpha_channel; + PerformAlphaWeightedAdd(fg[3 + i] + x0, bg[3 + i] + x0, + bg[3 + alpha] + x0, tmp.Row(3 + i), xsize, + ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kMul) { + PerformMulBlending(bg[3 + i] + x0, fg[3 + i] + x0, tmp.Row(3 + i), xsize, + ec_blending[i].clamp); + } else if (ec_blending[i].mode == PatchBlendMode::kReplace) { + memcpy(tmp.Row(3 + i), fg[3 + i] + x0, xsize * sizeof(**fg)); + } else if (ec_blending[i].mode == PatchBlendMode::kNone) { + if (xsize) memcpy(tmp.Row(3 + i), bg[3 + i] + x0, xsize * sizeof(**fg)); + } else { + JXL_UNREACHABLE("new PatchBlendMode?"); + } + } + size_t alpha = color_blending.alpha_channel; + + if (color_blending.mode == PatchBlendMode::kAdd || + (color_blending.mode == PatchBlendMode::kAlphaWeightedAddAbove && + !has_alpha) || + (color_blending.mode == PatchBlendMode::kAlphaWeightedAddBelow && + !has_alpha)) { + for (int p = 0; p < 3; p++) { + float* out = tmp.Row(p); + for (size_t x = 0; x < xsize; x++) { + out[x] = bg[p][x + x0] + fg[p][x + x0]; + } + } + } else if (color_blending.mode == PatchBlendMode::kBlendAbove + // blend without alpha is just replace + && has_alpha) { + bool is_premultiplied = extra_channel_info[alpha].alpha_associated; + PerformAlphaBlending( + {bg[0] + x0, bg[1] + x0, bg[2] + x0, bg[3 + alpha] + x0}, + {fg[0] + x0, fg[1] + x0, fg[2] + x0, fg[3 + alpha] + x0}, + {tmp.Row(0), tmp.Row(1), tmp.Row(2), tmp.Row(3 + alpha)}, xsize, + is_premultiplied, color_blending.clamp); + } else if (color_blending.mode == PatchBlendMode::kBlendBelow + // blend without alpha is just replace + && has_alpha) { + bool is_premultiplied = extra_channel_info[alpha].alpha_associated; + PerformAlphaBlending( + {fg[0] + x0, fg[1] + x0, fg[2] + x0, fg[3 + alpha] + x0}, + {bg[0] + x0, bg[1] + x0, bg[2] + x0, bg[3 + alpha] + x0}, + {tmp.Row(0), tmp.Row(1), tmp.Row(2), tmp.Row(3 + alpha)}, xsize, + is_premultiplied, color_blending.clamp); + } else if (color_blending.mode == PatchBlendMode::kAlphaWeightedAddAbove) { + JXL_DASSERT(has_alpha); + for (size_t c = 0; c < 3; c++) { + PerformAlphaWeightedAdd(bg[c] + x0, fg[c] + x0, fg[3 + alpha] + x0, + tmp.Row(c), xsize, color_blending.clamp); + } + } else if (color_blending.mode == PatchBlendMode::kAlphaWeightedAddBelow) { + JXL_DASSERT(has_alpha); + for (size_t c = 0; c < 3; c++) { + PerformAlphaWeightedAdd(fg[c] + x0, bg[c] + x0, bg[3 + alpha] + x0, + tmp.Row(c), xsize, color_blending.clamp); + } + } else if (color_blending.mode == PatchBlendMode::kMul) { + for (int p = 0; p < 3; p++) { + PerformMulBlending(bg[p] + x0, fg[p] + x0, tmp.Row(p), xsize, + color_blending.clamp); + } + } else if (color_blending.mode == PatchBlendMode::kReplace || + color_blending.mode == PatchBlendMode::kBlendAbove || + color_blending.mode == PatchBlendMode::kBlendBelow) { // kReplace + for (size_t p = 0; p < 3; p++) { + memcpy(tmp.Row(p), fg[p] + x0, xsize * sizeof(**fg)); + } + } else if (color_blending.mode == PatchBlendMode::kNone) { + for (size_t p = 0; p < 3; p++) { + memcpy(tmp.Row(p), bg[p] + x0, xsize * sizeof(**fg)); + } + } else { + JXL_UNREACHABLE("new PatchBlendMode?"); + } + for (size_t i = 0; i < 3 + num_ec; i++) { + if (xsize != 0) memcpy(out[i] + x0, tmp.Row(i), xsize * sizeof(**out)); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/blending.h b/third_party/jpeg-xl/lib/jxl/blending.h new file mode 100644 index 0000000000..3f23297f1d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/blending.h @@ -0,0 +1,27 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BLENDING_H_ +#define LIB_JXL_BLENDING_H_ + +#include <vector> + +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image_metadata.h" + +namespace jxl { + +bool NeedsBlending(const FrameHeader& frame_header); + +void PerformBlending(const float* const* bg, const float* const* fg, + float* const* out, size_t x0, size_t xsize, + const PatchBlending& color_blending, + const PatchBlending* ec_blending, + const std::vector<ExtraChannelInfo>& extra_channel_info); + +} // namespace jxl + +#endif // LIB_JXL_BLENDING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/blending_test.cc b/third_party/jpeg-xl/lib/jxl/blending_test.cc new file mode 100644 index 0000000000..c34ab5c7ca --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/blending_test.cc @@ -0,0 +1,57 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/types.h> + +#include <cstdint> +#include <sstream> +#include <utility> +#include <vector> + +#include "lib/extras/dec/decode.h" +#include "lib/extras/dec/jxl.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +using ::testing::SizeIs; + +TEST(BlendingTest, Crops) { + const std::vector<uint8_t> compressed = + jxl::test::ReadTestData("jxl/blending/cropped_traffic_light.jxl"); + extras::JXLDecompressParams dparams; + dparams.accepted_formats = {{3, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}}; + extras::PackedPixelFile decoded; + ASSERT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), dparams, + /*decoded_bytes=*/nullptr, &decoded)); + ASSERT_THAT(decoded.frames, SizeIs(4)); + + int i = 0; + for (auto&& decoded_frame : decoded.frames) { + std::ostringstream filename; + filename << "jxl/blending/cropped_traffic_light_frame-" << i << ".png"; + const std::vector<uint8_t> compressed_frame = + jxl::test::ReadTestData(filename.str()); + extras::PackedPixelFile decoded_frame_ppf; + decoded_frame_ppf.info = decoded.info; + decoded_frame_ppf.icc = decoded.icc; + decoded_frame_ppf.color_encoding = decoded.color_encoding; + decoded_frame_ppf.extra_channels_info = decoded.extra_channels_info; + decoded_frame_ppf.frames.emplace_back(std::move(decoded_frame)); + extras::PackedPixelFile expected_frame_ppf; + ASSERT_TRUE(extras::DecodeBytes(Bytes(compressed_frame), + extras::ColorHints(), &expected_frame_ppf)); + EXPECT_EQ(0.0f, + test::ComputeDistance2(decoded_frame_ppf, expected_frame_ppf)); + ++i; + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/box_content_decoder.cc b/third_party/jpeg-xl/lib/jxl/box_content_decoder.cc new file mode 100644 index 0000000000..c4cba3a31a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/box_content_decoder.cc @@ -0,0 +1,101 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/box_content_decoder.h" + +#include "lib/jxl/sanitizers.h" + +namespace jxl { + +JxlBoxContentDecoder::JxlBoxContentDecoder() {} + +JxlBoxContentDecoder::~JxlBoxContentDecoder() { + if (brotli_dec) { + BrotliDecoderDestroyInstance(brotli_dec); + } +} + +void JxlBoxContentDecoder::StartBox(bool brob_decode, bool box_until_eof, + size_t contents_size) { + if (brotli_dec) { + BrotliDecoderDestroyInstance(brotli_dec); + brotli_dec = nullptr; + } + header_done_ = false; + brob_decode_ = brob_decode; + box_until_eof_ = box_until_eof; + remaining_ = box_until_eof ? 0 : contents_size; + pos_ = 0; +} + +JxlDecoderStatus JxlBoxContentDecoder::Process(const uint8_t* next_in, + size_t avail_in, size_t box_pos, + uint8_t** next_out, + size_t* avail_out) { + next_in += pos_ - box_pos; + avail_in -= pos_ - box_pos; + + if (brob_decode_) { + if (!header_done_) { + if (avail_in < 4) return JXL_DEC_NEED_MORE_INPUT; + if (!box_until_eof_) { + if (remaining_ < 4) return JXL_DEC_ERROR; + remaining_ -= 4; + } + next_in += 4; + avail_in -= 4; + pos_ += 4; + header_done_ = true; + } + + if (!brotli_dec) { + brotli_dec = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + } + + const uint8_t* next_in_before = next_in; + uint8_t* next_out_before = *next_out; + msan::MemoryIsInitialized(next_in, avail_in); + BrotliDecoderResult res = BrotliDecoderDecompressStream( + brotli_dec, &avail_in, &next_in, avail_out, next_out, nullptr); + size_t consumed = next_in - next_in_before; + size_t produced = *next_out - next_out_before; + if (res == BROTLI_DECODER_RESULT_ERROR) { + return JXL_DEC_ERROR; + } + msan::UnpoisonMemory(next_out_before, produced); + pos_ += consumed; + if (!box_until_eof_) remaining_ -= consumed; + if (res == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT) { + return JXL_DEC_NEED_MORE_INPUT; + } + if (res == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + return JXL_DEC_BOX_NEED_MORE_OUTPUT; + } + if (res == BROTLI_DECODER_RESULT_SUCCESS) { + return JXL_DEC_SUCCESS; + } + // unknown Brotli result + return JXL_DEC_ERROR; + } else { + // remaining box bytes as seen from dec->file_pos + size_t can_read = avail_in; + if (!box_until_eof_) can_read = std::min<size_t>(can_read, remaining_); + size_t to_write = std::min<size_t>(can_read, *avail_out); + memcpy(*next_out, next_in, to_write); + + *next_out += to_write; + *avail_out -= to_write; + if (!box_until_eof_) remaining_ -= to_write; + pos_ += to_write; + + if (to_write < can_read) return JXL_DEC_BOX_NEED_MORE_OUTPUT; + + if (!box_until_eof_ && remaining_ > 0) return JXL_DEC_NEED_MORE_INPUT; + + return JXL_DEC_SUCCESS; + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/box_content_decoder.h b/third_party/jpeg-xl/lib/jxl/box_content_decoder.h new file mode 100644 index 0000000000..17db7faa3b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/box_content_decoder.h @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BOX_CONTENT_DECODER_H_ +#define LIB_JXL_BOX_CONTENT_DECODER_H_ + +#include <brotli/decode.h> +#include <jxl/decode.h> +#include <stdint.h> +#include <stdlib.h> + +namespace jxl { + +/** Outputs the contents of a box in a streaming fashion, either directly, or + * optionally decoding with Brotli, in case of a brob box. The input must be + * the contents of a box, excluding the box header. + */ +class JxlBoxContentDecoder { + public: + JxlBoxContentDecoder(); + ~JxlBoxContentDecoder(); + + void StartBox(bool brob_decode, bool box_until_eof, size_t contents_size); + + // Outputs decoded bytes from the box, decoding with brotli if needed. + // box_pos is the position in the box content which next_in points to. + // Returns success, whether more input or output bytes are needed, or error. + JxlDecoderStatus Process(const uint8_t* next_in, size_t avail_in, + size_t box_pos, uint8_t** next_out, + size_t* avail_out); + + private: + BrotliDecoderState* brotli_dec; + + bool header_done_; + bool brob_decode_; + bool box_until_eof_; + size_t remaining_; + size_t pos_; +}; + +} // namespace jxl + +#endif // LIB_JXL_BOX_CONTENT_DECODER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.cc b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.cc new file mode 100644 index 0000000000..66dde9afb1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.cc @@ -0,0 +1,2090 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// Author: Jyrki Alakuijala (jyrki.alakuijala@gmail.com) +// +// The physical architecture of butteraugli is based on the following naming +// convention: +// * Opsin - dynamics of the photosensitive chemicals in the retina +// with their immediate electrical processing +// * Xyb - hybrid opponent/trichromatic color space +// x is roughly red-subtract-green. +// y is yellow. +// b is blue. +// Xyb values are computed from Opsin mixing, not directly from rgb. +// * Mask - for visual masking +// * Hf - color modeling for spatially high-frequency features +// * Lf - color modeling for spatially low-frequency features +// * Diffmap - to cluster and build an image of error between the images +// * Blur - to hold the smoothing code + +#include "lib/jxl/butteraugli/butteraugli.h" + +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include <algorithm> +#include <array> +#include <cmath> +#include <new> +#include <vector> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/butteraugli/butteraugli.cc" +#include <hwy/foreach_target.h> + +#include "lib/jxl/base/fast_math-inl.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/image_ops.h" + +#ifndef JXL_BUTTERAUGLI_ONCE +#define JXL_BUTTERAUGLI_ONCE + +namespace jxl { + +static const double wMfMalta = 37.0819870399; +static const double norm1Mf = 130262059.556; +static const double wMfMaltaX = 8246.75321353; +static const double norm1MfX = 1009002.70582; +static const double wHfMalta = 18.7237414387; +static const double norm1Hf = 4498534.45232; +static const double wHfMaltaX = 6923.99476109; +static const double norm1HfX = 8051.15833247; +static const double wUhfMalta = 1.10039032555; +static const double norm1Uhf = 71.7800275169; +static const double wUhfMaltaX = 173.5; +static const double norm1UhfX = 5.0; +static const double wmul[9] = { + 400.0, 1.50815703118, 0, + 2150.0, 10.6195433239, 16.2176043152, + 29.2353797994, 0.844626970982, 0.703646627719, +}; + +std::vector<float> ComputeKernel(float sigma) { + const float m = 2.25; // Accuracy increases when m is increased. + const double scaler = -1.0 / (2.0 * sigma * sigma); + const int diff = std::max<int>(1, m * std::fabs(sigma)); + std::vector<float> kernel(2 * diff + 1); + for (int i = -diff; i <= diff; ++i) { + kernel[i + diff] = std::exp(scaler * i * i); + } + return kernel; +} + +void ConvolveBorderColumn(const ImageF& in, const std::vector<float>& kernel, + const size_t x, float* BUTTERAUGLI_RESTRICT row_out) { + const size_t offset = kernel.size() / 2; + int minx = x < offset ? 0 : x - offset; + int maxx = std::min<int>(in.xsize() - 1, x + offset); + float weight = 0.0f; + for (int j = minx; j <= maxx; ++j) { + weight += kernel[j - x + offset]; + } + float scale = 1.0f / weight; + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y); + float sum = 0.0f; + for (int j = minx; j <= maxx; ++j) { + sum += row_in[j] * kernel[j - x + offset]; + } + row_out[y] = sum * scale; + } +} + +// Computes a horizontal convolution and transposes the result. +void ConvolutionWithTranspose(const ImageF& in, + const std::vector<float>& kernel, + ImageF* BUTTERAUGLI_RESTRICT out) { + JXL_CHECK(out->xsize() == in.ysize()); + JXL_CHECK(out->ysize() == in.xsize()); + const size_t len = kernel.size(); + const size_t offset = len / 2; + float weight_no_border = 0.0f; + for (size_t j = 0; j < len; ++j) { + weight_no_border += kernel[j]; + } + const float scale_no_border = 1.0f / weight_no_border; + const size_t border1 = std::min(in.xsize(), offset); + const size_t border2 = in.xsize() > offset ? in.xsize() - offset : 0; + std::vector<float> scaled_kernel(len / 2 + 1); + for (size_t i = 0; i <= len / 2; ++i) { + scaled_kernel[i] = kernel[i] * scale_no_border; + } + + // middle + switch (len) { + case 7: { + const float sk0 = scaled_kernel[0]; + const float sk1 = scaled_kernel[1]; + const float sk2 = scaled_kernel[2]; + const float sk3 = scaled_kernel[3]; + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + const float sum0 = (row_in[0] + row_in[6]) * sk0; + const float sum1 = (row_in[1] + row_in[5]) * sk1; + const float sum2 = (row_in[2] + row_in[4]) * sk2; + const float sum = (row_in[3]) * sk3 + sum0 + sum1 + sum2; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum; + } + } + } break; + case 13: { + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + float sum0 = (row_in[0] + row_in[12]) * scaled_kernel[0]; + float sum1 = (row_in[1] + row_in[11]) * scaled_kernel[1]; + float sum2 = (row_in[2] + row_in[10]) * scaled_kernel[2]; + float sum3 = (row_in[3] + row_in[9]) * scaled_kernel[3]; + sum0 += (row_in[4] + row_in[8]) * scaled_kernel[4]; + sum1 += (row_in[5] + row_in[7]) * scaled_kernel[5]; + const float sum = (row_in[6]) * scaled_kernel[6]; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum + sum0 + sum1 + sum2 + sum3; + } + } + break; + } + case 15: { + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + float sum0 = (row_in[0] + row_in[14]) * scaled_kernel[0]; + float sum1 = (row_in[1] + row_in[13]) * scaled_kernel[1]; + float sum2 = (row_in[2] + row_in[12]) * scaled_kernel[2]; + float sum3 = (row_in[3] + row_in[11]) * scaled_kernel[3]; + sum0 += (row_in[4] + row_in[10]) * scaled_kernel[4]; + sum1 += (row_in[5] + row_in[9]) * scaled_kernel[5]; + sum2 += (row_in[6] + row_in[8]) * scaled_kernel[6]; + const float sum = (row_in[7]) * scaled_kernel[7]; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum + sum0 + sum1 + sum2 + sum3; + } + } + break; + } + case 33: { + for (size_t y = 0; y < in.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = in.Row(y) + border1 - offset; + for (size_t x = border1; x < border2; ++x, ++row_in) { + float sum0 = (row_in[0] + row_in[32]) * scaled_kernel[0]; + float sum1 = (row_in[1] + row_in[31]) * scaled_kernel[1]; + float sum2 = (row_in[2] + row_in[30]) * scaled_kernel[2]; + float sum3 = (row_in[3] + row_in[29]) * scaled_kernel[3]; + sum0 += (row_in[4] + row_in[28]) * scaled_kernel[4]; + sum1 += (row_in[5] + row_in[27]) * scaled_kernel[5]; + sum2 += (row_in[6] + row_in[26]) * scaled_kernel[6]; + sum3 += (row_in[7] + row_in[25]) * scaled_kernel[7]; + sum0 += (row_in[8] + row_in[24]) * scaled_kernel[8]; + sum1 += (row_in[9] + row_in[23]) * scaled_kernel[9]; + sum2 += (row_in[10] + row_in[22]) * scaled_kernel[10]; + sum3 += (row_in[11] + row_in[21]) * scaled_kernel[11]; + sum0 += (row_in[12] + row_in[20]) * scaled_kernel[12]; + sum1 += (row_in[13] + row_in[19]) * scaled_kernel[13]; + sum2 += (row_in[14] + row_in[18]) * scaled_kernel[14]; + sum3 += (row_in[15] + row_in[17]) * scaled_kernel[15]; + const float sum = (row_in[16]) * scaled_kernel[16]; + float* BUTTERAUGLI_RESTRICT row_out = out->Row(x); + row_out[y] = sum + sum0 + sum1 + sum2 + sum3; + } + } + break; + } + default: + JXL_UNREACHABLE("Kernel size %" PRIuS " not implemented", len); + } + // left border + for (size_t x = 0; x < border1; ++x) { + ConvolveBorderColumn(in, kernel, x, out->Row(x)); + } + + // right border + for (size_t x = border2; x < in.xsize(); ++x) { + ConvolveBorderColumn(in, kernel, x, out->Row(x)); + } +} + +// A blur somewhat similar to a 2D Gaussian blur. +// See: https://en.wikipedia.org/wiki/Gaussian_blur +// +// This is a bottleneck because the sigma can be quite large (>7). We can use +// gauss_blur.cc (runtime independent of sigma, closer to a 4*sigma truncated +// Gaussian and our 2.25 in ComputeKernel), but its boundary conditions are +// zero-valued. This leads to noticeable differences at the edges of diffmaps. +// We retain a special case for 5x5 kernels (even faster than gauss_blur), +// optionally use gauss_blur followed by fixup of the borders for large images, +// or fall back to the previous truncated FIR followed by a transpose. +void Blur(const ImageF& in, float sigma, const ButteraugliParams& params, + BlurTemp* temp, ImageF* out) { + std::vector<float> kernel = ComputeKernel(sigma); + // Separable5 does an in-place convolution, so this fast path is not safe if + // in aliases out. + if (kernel.size() == 5 && &in != out) { + float sum_weights = 0.0f; + for (const float w : kernel) { + sum_weights += w; + } + const float scale = 1.0f / sum_weights; + const float w0 = kernel[2] * scale; + const float w1 = kernel[1] * scale; + const float w2 = kernel[0] * scale; + const WeightsSeparable5 weights = { + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + }; + Separable5(in, Rect(in), weights, /*pool=*/nullptr, out); + return; + } + + ImageF* JXL_RESTRICT temp_t = temp->GetTransposed(in); + ConvolutionWithTranspose(in, kernel, temp_t); + ConvolutionWithTranspose(*temp_t, kernel, out); +} + +// Allows PaddedMaltaUnit to call either function via overloading. +struct MaltaTagLF {}; +struct MaltaTag {}; + +} // namespace jxl + +#endif // JXL_BUTTERAUGLI_ONCE + +#include <hwy/highway.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::Gt; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenElseZero; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::MulSub; +using hwy::HWY_NAMESPACE::Neg; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +template <class D, class V> +HWY_INLINE V MaximumClamp(D d, V v, double kMaxVal) { + static const double kMul = 0.724216145665; + const V mul = Set(d, kMul); + const V maxval = Set(d, kMaxVal); + // If greater than maxval or less than -maxval, replace with if_*. + const V if_pos = MulAdd(Sub(v, maxval), mul, maxval); + const V if_neg = MulSub(Add(v, maxval), mul, maxval); + const V pos_or_v = IfThenElse(Ge(v, maxval), if_pos, v); + return IfThenElse(Lt(v, Neg(maxval)), if_neg, pos_or_v); +} + +// Make area around zero less important (remove it). +template <class D, class V> +HWY_INLINE V RemoveRangeAroundZero(const D d, const double kw, const V x) { + const auto w = Set(d, kw); + return IfThenElse(Gt(x, w), Sub(x, w), + IfThenElseZero(Lt(x, Neg(w)), Add(x, w))); +} + +// Make area around zero more important (2x it until the limit). +template <class D, class V> +HWY_INLINE V AmplifyRangeAroundZero(const D d, const double kw, const V x) { + const auto w = Set(d, kw); + return IfThenElse(Gt(x, w), Add(x, w), + IfThenElse(Lt(x, Neg(w)), Sub(x, w), Add(x, x))); +} + +// XybLowFreqToVals converts from low-frequency XYB space to the 'vals' space. +// Vals space can be converted to L2-norm space (Euclidean and normalized) +// through visual masking. +template <class D, class V> +HWY_INLINE void XybLowFreqToVals(const D d, const V& x, const V& y, + const V& b_arg, V* HWY_RESTRICT valx, + V* HWY_RESTRICT valy, V* HWY_RESTRICT valb) { + static const double xmul_scalar = 33.832837186260; + static const double ymul_scalar = 14.458268100570; + static const double bmul_scalar = 49.87984651440; + static const double y_to_b_mul_scalar = -0.362267051518; + const V xmul = Set(d, xmul_scalar); + const V ymul = Set(d, ymul_scalar); + const V bmul = Set(d, bmul_scalar); + const V y_to_b_mul = Set(d, y_to_b_mul_scalar); + const V b = MulAdd(y_to_b_mul, y, b_arg); + *valb = Mul(b, bmul); + *valx = Mul(x, xmul); + *valy = Mul(y, ymul); +} + +void XybLowFreqToVals(Image3F* xyb_lf) { + // Modify range around zero code only concerns the high frequency + // planes and only the X and Y channels. + // Convert low freq xyb to vals space so that we can do a simple squared sum + // diff on the low frequencies later. + const HWY_FULL(float) d; + for (size_t y = 0; y < xyb_lf->ysize(); ++y) { + float* BUTTERAUGLI_RESTRICT row_x = xyb_lf->PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_y = xyb_lf->PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_b = xyb_lf->PlaneRow(2, y); + for (size_t x = 0; x < xyb_lf->xsize(); x += Lanes(d)) { + auto valx = Undefined(d); + auto valy = Undefined(d); + auto valb = Undefined(d); + XybLowFreqToVals(d, Load(d, row_x + x), Load(d, row_y + x), + Load(d, row_b + x), &valx, &valy, &valb); + Store(valx, d, row_x + x); + Store(valy, d, row_y + x); + Store(valb, d, row_b + x); + } + } +} + +void SuppressXByY(const ImageF& in_y, ImageF* HWY_RESTRICT inout_x) { + JXL_DASSERT(SameSize(*inout_x, in_y)); + const size_t xsize = in_y.xsize(); + const size_t ysize = in_y.ysize(); + const HWY_FULL(float) d; + static const double suppress = 46.0; + static const double s = 0.653020556257; + const auto sv = Set(d, s); + const auto one_minus_s = Set(d, 1.0 - s); + const auto ywv = Set(d, suppress); + + for (size_t y = 0; y < ysize; ++y) { + const float* HWY_RESTRICT row_y = in_y.ConstRow(y); + float* HWY_RESTRICT row_x = inout_x->Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto vx = Load(d, row_x + x); + const auto vy = Load(d, row_y + x); + const auto scaler = + MulAdd(Div(ywv, MulAdd(vy, vy, ywv)), one_minus_s, sv); + Store(Mul(scaler, vx), d, row_x + x); + } + } +} + +void Subtract(const ImageF& a, const ImageF& b, ImageF* c) { + const HWY_FULL(float) d; + for (size_t y = 0; y < a.ysize(); ++y) { + const float* row_a = a.ConstRow(y); + const float* row_b = b.ConstRow(y); + float* row_c = c->Row(y); + for (size_t x = 0; x < a.xsize(); x += Lanes(d)) { + Store(Sub(Load(d, row_a + x), Load(d, row_b + x)), d, row_c + x); + } + } +} + +void SeparateLFAndMF(const ButteraugliParams& params, const Image3F& xyb, + Image3F* lf, Image3F* mf, BlurTemp* blur_temp) { + static const double kSigmaLf = 7.15593339443; + for (int i = 0; i < 3; ++i) { + // Extract lf ... + Blur(xyb.Plane(i), kSigmaLf, params, blur_temp, &lf->Plane(i)); + // ... and keep everything else in mf. + Subtract(xyb.Plane(i), lf->Plane(i), &mf->Plane(i)); + } + XybLowFreqToVals(lf); +} + +void SeparateMFAndHF(const ButteraugliParams& params, Image3F* mf, ImageF* hf, + BlurTemp* blur_temp) { + const HWY_FULL(float) d; + static const double kSigmaHf = 3.22489901262; + const size_t xsize = mf->xsize(); + const size_t ysize = mf->ysize(); + hf[0] = ImageF(xsize, ysize); + hf[1] = ImageF(xsize, ysize); + for (int i = 0; i < 3; ++i) { + if (i == 2) { + Blur(mf->Plane(i), kSigmaHf, params, blur_temp, &mf->Plane(i)); + break; + } + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_mf = mf->PlaneRow(i, y); + float* BUTTERAUGLI_RESTRICT row_hf = hf[i].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + Store(Load(d, row_mf + x), d, row_hf + x); + } + } + Blur(mf->Plane(i), kSigmaHf, params, blur_temp, &mf->Plane(i)); + static const double kRemoveMfRange = 0.29; + static const double kAddMfRange = 0.1; + if (i == 0) { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_mf = mf->PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_hf = hf[0].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto mf = Load(d, row_mf + x); + auto hf = Sub(Load(d, row_hf + x), mf); + mf = RemoveRangeAroundZero(d, kRemoveMfRange, mf); + Store(mf, d, row_mf + x); + Store(hf, d, row_hf + x); + } + } + } else { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_mf = mf->PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_hf = hf[1].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto mf = Load(d, row_mf + x); + auto hf = Sub(Load(d, row_hf + x), mf); + + mf = AmplifyRangeAroundZero(d, kAddMfRange, mf); + Store(mf, d, row_mf + x); + Store(hf, d, row_hf + x); + } + } + } + } + // Suppress red-green by intensity change in the high freq channels. + SuppressXByY(hf[1], &hf[0]); +} + +void SeparateHFAndUHF(const ButteraugliParams& params, ImageF* hf, ImageF* uhf, + BlurTemp* blur_temp) { + const HWY_FULL(float) d; + const size_t xsize = hf[0].xsize(); + const size_t ysize = hf[0].ysize(); + static const double kSigmaUhf = 1.56416327805; + uhf[0] = ImageF(xsize, ysize); + uhf[1] = ImageF(xsize, ysize); + for (int i = 0; i < 2; ++i) { + // Divide hf into hf and uhf. + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_uhf = uhf[i].Row(y); + float* BUTTERAUGLI_RESTRICT row_hf = hf[i].Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_uhf[x] = row_hf[x]; + } + } + Blur(hf[i], kSigmaUhf, params, blur_temp, &hf[i]); + static const double kRemoveHfRange = 1.5; + static const double kAddHfRange = 0.132; + static const double kRemoveUhfRange = 0.04; + static const double kMaxclampHf = 28.4691806922; + static const double kMaxclampUhf = 5.19175294647; + static double kMulYHf = 2.155; + static double kMulYUhf = 2.69313763794; + if (i == 0) { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_uhf = uhf[0].Row(y); + float* BUTTERAUGLI_RESTRICT row_hf = hf[0].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto hf = Load(d, row_hf + x); + auto uhf = Sub(Load(d, row_uhf + x), hf); + hf = RemoveRangeAroundZero(d, kRemoveHfRange, hf); + uhf = RemoveRangeAroundZero(d, kRemoveUhfRange, uhf); + Store(hf, d, row_hf + x); + Store(uhf, d, row_uhf + x); + } + } + } else { + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_uhf = uhf[1].Row(y); + float* BUTTERAUGLI_RESTRICT row_hf = hf[1].Row(y); + for (size_t x = 0; x < xsize; x += Lanes(d)) { + auto hf = Load(d, row_hf + x); + hf = MaximumClamp(d, hf, kMaxclampHf); + + auto uhf = Sub(Load(d, row_uhf + x), hf); + uhf = MaximumClamp(d, uhf, kMaxclampUhf); + uhf = Mul(uhf, Set(d, kMulYUhf)); + Store(uhf, d, row_uhf + x); + + hf = Mul(hf, Set(d, kMulYHf)); + hf = AmplifyRangeAroundZero(d, kAddHfRange, hf); + Store(hf, d, row_hf + x); + } + } + } + } +} + +void DeallocateHFAndUHF(ImageF* hf, ImageF* uhf) { + for (int i = 0; i < 2; ++i) { + hf[i] = ImageF(); + uhf[i] = ImageF(); + } +} + +static void SeparateFrequencies(size_t xsize, size_t ysize, + const ButteraugliParams& params, + BlurTemp* blur_temp, const Image3F& xyb, + PsychoImage& ps) { + ps.lf = Image3F(xyb.xsize(), xyb.ysize()); + ps.mf = Image3F(xyb.xsize(), xyb.ysize()); + SeparateLFAndMF(params, xyb, &ps.lf, &ps.mf, blur_temp); + SeparateMFAndHF(params, &ps.mf, &ps.hf[0], blur_temp); + SeparateHFAndUHF(params, &ps.hf[0], &ps.uhf[0], blur_temp); +} + +namespace { +template <typename V> +BUTTERAUGLI_INLINE V Sum(V a, V b, V c, V d) { + return Add(Add(a, b), Add(c, d)); +} +template <typename V> +BUTTERAUGLI_INLINE V Sum(V a, V b, V c, V d, V e) { + return Sum(a, b, c, Add(d, e)); +} +template <typename V> +BUTTERAUGLI_INLINE V Sum(V a, V b, V c, V d, V e, V f, V g) { + return Sum(a, b, c, Sum(d, e, f, g)); +} +template <typename V> +BUTTERAUGLI_INLINE V Sum(V a, V b, V c, V d, V e, V f, V g, V h, V i) { + return Add(Add(Sum(a, b, c, d), Sum(e, f, g, h)), i); +} +} // namespace + +template <class D> +Vec<D> MaltaUnit(MaltaTagLF /*tag*/, const D df, + const float* BUTTERAUGLI_RESTRICT d, const intptr_t xs) { + const intptr_t xs3 = 3 * xs; + + const auto center = LoadU(df, d); + + // x grows, y constant + const auto sum_yconst = Sum(LoadU(df, d - 4), LoadU(df, d - 2), center, + LoadU(df, d + 2), LoadU(df, d + 4)); + // Will return this, sum of all line kernels + auto retval = Mul(sum_yconst, sum_yconst); + { + // y grows, x constant + auto sum = Sum(LoadU(df, d - xs3 - xs), LoadU(df, d - xs - xs), center, + LoadU(df, d + xs + xs), LoadU(df, d + xs3 + xs)); + retval = MulAdd(sum, sum, retval); + } + { + // both grow + auto sum = Sum(LoadU(df, d - xs3 - 3), LoadU(df, d - xs - xs - 2), center, + LoadU(df, d + xs + xs + 2), LoadU(df, d + xs3 + 3)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows, x shrinks + auto sum = Sum(LoadU(df, d - xs3 + 3), LoadU(df, d - xs - xs + 2), center, + LoadU(df, d + xs + xs - 2), LoadU(df, d + xs3 - 3)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x shrinks 1 -> -1 + auto sum = + Sum(LoadU(df, d - xs3 - xs + 1), LoadU(df, d - xs - xs + 1), center, + LoadU(df, d + xs + xs - 1), LoadU(df, d + xs3 + xs - 1)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x grows -1 -> 1 + auto sum = + Sum(LoadU(df, d - xs3 - xs - 1), LoadU(df, d - xs - xs - 1), center, + LoadU(df, d + xs + xs + 1), LoadU(df, d + xs3 + xs + 1)); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y grows -1 to 1 + auto sum = Sum(LoadU(df, d - 4 - xs), LoadU(df, d - 2 - xs), center, + LoadU(df, d + 2 + xs), LoadU(df, d + 4 + xs)); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y shrinks 1 to -1 + auto sum = Sum(LoadU(df, d - 4 + xs), LoadU(df, d - 2 + xs), center, + LoadU(df, d + 2 - xs), LoadU(df, d + 4 - xs)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1__*______ + 2___*_____ + 3_________ + 4____0____ + 5_________ + 6_____*___ + 7______*__ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs3 - 2), LoadU(df, d - xs - xs - 1), center, + LoadU(df, d + xs + xs + 1), LoadU(df, d + xs3 + 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1______*__ + 2_____*___ + 3_________ + 4____0____ + 5_________ + 6___*_____ + 7__*______ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs3 + 2), LoadU(df, d - xs - xs + 1), center, + LoadU(df, d + xs + xs - 1), LoadU(df, d + xs3 - 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_*_______ + 3__*______ + 4____0____ + 5______*__ + 6_______*_ + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs - 3), LoadU(df, d - xs - 2), center, + LoadU(df, d + xs + 2), LoadU(df, d + xs + xs + 3)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_______*_ + 3______*__ + 4____0____ + 5__*______ + 6_*_______ + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs + 3), LoadU(df, d - xs + 2), center, + LoadU(df, d + xs - 2), LoadU(df, d + xs + xs - 3)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2________* + 3______*__ + 4____0____ + 5__*______ + 6*________ + 7_________ + 8_________ */ + + auto sum = Sum(LoadU(df, d + xs + xs - 4), LoadU(df, d + xs - 2), center, + LoadU(df, d - xs + 2), LoadU(df, d - xs - xs + 4)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2*________ + 3__*______ + 4____0____ + 5______*__ + 6________* + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs - 4), LoadU(df, d - xs - 2), center, + LoadU(df, d + xs + 2), LoadU(df, d + xs + xs + 4)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0__*______ + 1_________ + 2___*_____ + 3_________ + 4____0____ + 5_________ + 6_____*___ + 7_________ + 8______*__ */ + auto sum = + Sum(LoadU(df, d - xs3 - xs - 2), LoadU(df, d - xs - xs - 1), center, + LoadU(df, d + xs + xs + 1), LoadU(df, d + xs3 + xs + 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0______*__ + 1_________ + 2_____*___ + 3_________ + 4____0____ + 5_________ + 6___*_____ + 7_________ + 8__*______ */ + auto sum = + Sum(LoadU(df, d - xs3 - xs + 2), LoadU(df, d - xs - xs + 1), center, + LoadU(df, d + xs + xs - 1), LoadU(df, d + xs3 + xs - 2)); + retval = MulAdd(sum, sum, retval); + } + return retval; +} + +template <class D> +Vec<D> MaltaUnit(MaltaTag /*tag*/, const D df, + const float* BUTTERAUGLI_RESTRICT d, const intptr_t xs) { + const intptr_t xs3 = 3 * xs; + + const auto center = LoadU(df, d); + + // x grows, y constant + const auto sum_yconst = + Sum(LoadU(df, d - 4), LoadU(df, d - 3), LoadU(df, d - 2), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d + 2), + LoadU(df, d + 3), LoadU(df, d + 4)); + // Will return this, sum of all line kernels + auto retval = Mul(sum_yconst, sum_yconst); + + { + // y grows, x constant + auto sum = Sum(LoadU(df, d - xs3 - xs), LoadU(df, d - xs3), + LoadU(df, d - xs - xs), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs), + LoadU(df, d + xs3), LoadU(df, d + xs3 + xs)); + retval = MulAdd(sum, sum, retval); + } + { + // both grow + auto sum = Sum(LoadU(df, d - xs3 - 3), LoadU(df, d - xs - xs - 2), + LoadU(df, d - xs - 1), center, LoadU(df, d + xs + 1), + LoadU(df, d + xs + xs + 2), LoadU(df, d + xs3 + 3)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows, x shrinks + auto sum = Sum(LoadU(df, d - xs3 + 3), LoadU(df, d - xs - xs + 2), + LoadU(df, d - xs + 1), center, LoadU(df, d + xs - 1), + LoadU(df, d + xs + xs - 2), LoadU(df, d + xs3 - 3)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x shrinks 1 -> -1 + auto sum = Sum(LoadU(df, d - xs3 - xs + 1), LoadU(df, d - xs3 + 1), + LoadU(df, d - xs - xs + 1), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs - 1), + LoadU(df, d + xs3 - 1), LoadU(df, d + xs3 + xs - 1)); + retval = MulAdd(sum, sum, retval); + } + { + // y grows -4 to 4, x grows -1 -> 1 + auto sum = Sum(LoadU(df, d - xs3 - xs - 1), LoadU(df, d - xs3 - 1), + LoadU(df, d - xs - xs - 1), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs + 1), + LoadU(df, d + xs3 + 1), LoadU(df, d + xs3 + xs + 1)); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y grows -1 to 1 + auto sum = + Sum(LoadU(df, d - 4 - xs), LoadU(df, d - 3 - xs), LoadU(df, d - 2 - xs), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d + 2 + xs), + LoadU(df, d + 3 + xs), LoadU(df, d + 4 + xs)); + retval = MulAdd(sum, sum, retval); + } + { + // x grows -4 to 4, y shrinks 1 to -1 + auto sum = + Sum(LoadU(df, d - 4 + xs), LoadU(df, d - 3 + xs), LoadU(df, d - 2 + xs), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d + 2 - xs), + LoadU(df, d + 3 - xs), LoadU(df, d + 4 - xs)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1__*______ + 2___*_____ + 3___*_____ + 4____0____ + 5_____*___ + 6_____*___ + 7______*__ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs3 - 2), LoadU(df, d - xs - xs - 1), + LoadU(df, d - xs - 1), center, LoadU(df, d + xs + 1), + LoadU(df, d + xs + xs + 1), LoadU(df, d + xs3 + 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1______*__ + 2_____*___ + 3_____*___ + 4____0____ + 5___*_____ + 6___*_____ + 7__*______ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs3 + 2), LoadU(df, d - xs - xs + 1), + LoadU(df, d - xs + 1), center, LoadU(df, d + xs - 1), + LoadU(df, d + xs + xs - 1), LoadU(df, d + xs3 - 2)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_*_______ + 3__**_____ + 4____0____ + 5_____**__ + 6_______*_ + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs - 3), LoadU(df, d - xs - 2), + LoadU(df, d - xs - 1), center, LoadU(df, d + xs + 1), + LoadU(df, d + xs + 2), LoadU(df, d + xs + xs + 3)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_______*_ + 3_____**__ + 4____0____ + 5__**_____ + 6_*_______ + 7_________ + 8_________ */ + auto sum = Sum(LoadU(df, d - xs - xs + 3), LoadU(df, d - xs + 2), + LoadU(df, d - xs + 1), center, LoadU(df, d + xs - 1), + LoadU(df, d + xs - 2), LoadU(df, d + xs + xs - 3)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_________ + 3______*** + 4___*0*___ + 5***______ + 6_________ + 7_________ + 8_________ */ + + auto sum = + Sum(LoadU(df, d + xs - 4), LoadU(df, d + xs - 3), LoadU(df, d + xs - 2), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d - xs + 2), + LoadU(df, d - xs + 3), LoadU(df, d - xs + 4)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_________ + 1_________ + 2_________ + 3***______ + 4___*0*___ + 5______*** + 6_________ + 7_________ + 8_________ */ + auto sum = + Sum(LoadU(df, d - xs - 4), LoadU(df, d - xs - 3), LoadU(df, d - xs - 2), + LoadU(df, d - 1), center, LoadU(df, d + 1), LoadU(df, d + xs + 2), + LoadU(df, d + xs + 3), LoadU(df, d + xs + 4)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0___*_____ + 1___*_____ + 2___*_____ + 3____*____ + 4____0____ + 5____*____ + 6_____*___ + 7_____*___ + 8_____*___ */ + auto sum = Sum(LoadU(df, d - xs3 - xs - 1), LoadU(df, d - xs3 - 1), + LoadU(df, d - xs - xs - 1), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs + 1), + LoadU(df, d + xs3 + 1), LoadU(df, d + xs3 + xs + 1)); + retval = MulAdd(sum, sum, retval); + } + { + /* 0_____*___ + 1_____*___ + 2____ *___ + 3____*____ + 4____0____ + 5____*____ + 6___*_____ + 7___*_____ + 8___*_____ */ + auto sum = Sum(LoadU(df, d - xs3 - xs + 1), LoadU(df, d - xs3 + 1), + LoadU(df, d - xs - xs + 1), LoadU(df, d - xs), center, + LoadU(df, d + xs), LoadU(df, d + xs + xs - 1), + LoadU(df, d + xs3 - 1), LoadU(df, d + xs3 + xs - 1)); + retval = MulAdd(sum, sum, retval); + } + return retval; +} + +// Returns MaltaUnit. Avoids bounds-checks when x0 and y0 are known +// to be far enough from the image borders. "diffs" is a packed image. +template <class Tag> +static BUTTERAUGLI_INLINE float PaddedMaltaUnit(const ImageF& diffs, + const size_t x0, + const size_t y0) { + const float* BUTTERAUGLI_RESTRICT d = diffs.ConstRow(y0) + x0; + const HWY_CAPPED(float, 1) df; + if ((x0 >= 4 && y0 >= 4 && x0 < (diffs.xsize() - 4) && + y0 < (diffs.ysize() - 4))) { + return GetLane(MaltaUnit(Tag(), df, d, diffs.PixelsPerRow())); + } + + float borderimage[12 * 9]; // round up to 4 + for (int dy = 0; dy < 9; ++dy) { + int y = y0 + dy - 4; + if (y < 0 || static_cast<size_t>(y) >= diffs.ysize()) { + for (int dx = 0; dx < 12; ++dx) { + borderimage[dy * 12 + dx] = 0.0f; + } + continue; + } + + const float* row_diffs = diffs.ConstRow(y); + for (int dx = 0; dx < 9; ++dx) { + int x = x0 + dx - 4; + if (x < 0 || static_cast<size_t>(x) >= diffs.xsize()) { + borderimage[dy * 12 + dx] = 0.0f; + } else { + borderimage[dy * 12 + dx] = row_diffs[x]; + } + } + std::fill(borderimage + dy * 12 + 9, borderimage + dy * 12 + 12, 0.0f); + } + return GetLane(MaltaUnit(Tag(), df, &borderimage[4 * 12 + 4], 12)); +} + +template <class Tag> +static void MaltaDiffMapT(const Tag tag, const ImageF& lum0, const ImageF& lum1, + const double w_0gt1, const double w_0lt1, + const double norm1, const double len, + const double mulli, ImageF* HWY_RESTRICT diffs, + ImageF* HWY_RESTRICT block_diff_ac) { + JXL_DASSERT(SameSize(lum0, lum1) && SameSize(lum0, *diffs)); + const size_t xsize_ = lum0.xsize(); + const size_t ysize_ = lum0.ysize(); + + const float kWeight0 = 0.5; + const float kWeight1 = 0.33; + + const double w_pre0gt1 = mulli * std::sqrt(kWeight0 * w_0gt1) / (len * 2 + 1); + const double w_pre0lt1 = mulli * std::sqrt(kWeight1 * w_0lt1) / (len * 2 + 1); + const float norm2_0gt1 = w_pre0gt1 * norm1; + const float norm2_0lt1 = w_pre0lt1 * norm1; + + for (size_t y = 0; y < ysize_; ++y) { + const float* HWY_RESTRICT row0 = lum0.ConstRow(y); + const float* HWY_RESTRICT row1 = lum1.ConstRow(y); + float* HWY_RESTRICT row_diffs = diffs->Row(y); + for (size_t x = 0; x < xsize_; ++x) { + const float absval = 0.5f * (std::abs(row0[x]) + std::abs(row1[x])); + const float diff = row0[x] - row1[x]; + const float scaler = norm2_0gt1 / (static_cast<float>(norm1) + absval); + + // Primary symmetric quadratic objective. + row_diffs[x] = scaler * diff; + + const float scaler2 = norm2_0lt1 / (static_cast<float>(norm1) + absval); + const double fabs0 = std::fabs(row0[x]); + + // Secondary half-open quadratic objectives. + const double too_small = 0.55 * fabs0; + const double too_big = 1.05 * fabs0; + + if (row0[x] < 0) { + if (row1[x] > -too_small) { + double impact = scaler2 * (row1[x] + too_small); + row_diffs[x] -= impact; + } else if (row1[x] < -too_big) { + double impact = scaler2 * (-row1[x] - too_big); + row_diffs[x] += impact; + } + } else { + if (row1[x] < too_small) { + double impact = scaler2 * (too_small - row1[x]); + row_diffs[x] += impact; + } else if (row1[x] > too_big) { + double impact = scaler2 * (row1[x] - too_big); + row_diffs[x] -= impact; + } + } + } + } + + size_t y0 = 0; + // Top + for (; y0 < 4; ++y0) { + float* BUTTERAUGLI_RESTRICT row_diff = block_diff_ac->Row(y0); + for (size_t x0 = 0; x0 < xsize_; ++x0) { + row_diff[x0] += PaddedMaltaUnit<Tag>(*diffs, x0, y0); + } + } + + const HWY_FULL(float) df; + const size_t aligned_x = std::max(size_t(4), Lanes(df)); + const intptr_t stride = diffs->PixelsPerRow(); + + // Middle + for (; y0 < ysize_ - 4; ++y0) { + const float* BUTTERAUGLI_RESTRICT row_in = diffs->ConstRow(y0); + float* BUTTERAUGLI_RESTRICT row_diff = block_diff_ac->Row(y0); + size_t x0 = 0; + for (; x0 < aligned_x; ++x0) { + row_diff[x0] += PaddedMaltaUnit<Tag>(*diffs, x0, y0); + } + for (; x0 + Lanes(df) + 4 <= xsize_; x0 += Lanes(df)) { + auto diff = Load(df, row_diff + x0); + diff = Add(diff, MaltaUnit(Tag(), df, row_in + x0, stride)); + Store(diff, df, row_diff + x0); + } + + for (; x0 < xsize_; ++x0) { + row_diff[x0] += PaddedMaltaUnit<Tag>(*diffs, x0, y0); + } + } + + // Bottom + for (; y0 < ysize_; ++y0) { + float* BUTTERAUGLI_RESTRICT row_diff = block_diff_ac->Row(y0); + for (size_t x0 = 0; x0 < xsize_; ++x0) { + row_diff[x0] += PaddedMaltaUnit<Tag>(*diffs, x0, y0); + } + } +} + +// Need non-template wrapper functions for HWY_EXPORT. +void MaltaDiffMap(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, + ImageF* HWY_RESTRICT diffs, + ImageF* HWY_RESTRICT block_diff_ac) { + const double len = 3.75; + static const double mulli = 0.39905817637; + MaltaDiffMapT(MaltaTag(), lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, + diffs, block_diff_ac); +} + +void MaltaDiffMapLF(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, + ImageF* HWY_RESTRICT diffs, + ImageF* HWY_RESTRICT block_diff_ac) { + const double len = 3.75; + static const double mulli = 0.611612573796; + MaltaDiffMapT(MaltaTagLF(), lum0, lum1, w_0gt1, w_0lt1, norm1, len, mulli, + diffs, block_diff_ac); +} + +void CombineChannelsForMasking(const ImageF* hf, const ImageF* uhf, + ImageF* out) { + // Only X and Y components are involved in masking. B's influence + // is considered less important in the high frequency area, and we + // don't model masking from lower frequency signals. + static const float muls[3] = { + 2.5f, + 0.4f, + 0.4f, + }; + // Silly and unoptimized approach here. TODO(jyrki): rework this. + for (size_t y = 0; y < hf[0].ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_y_hf = hf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_y_uhf = uhf[1].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_hf = hf[0].Row(y); + const float* BUTTERAUGLI_RESTRICT row_x_uhf = uhf[0].Row(y); + float* BUTTERAUGLI_RESTRICT row = out->Row(y); + for (size_t x = 0; x < hf[0].xsize(); ++x) { + float xdiff = (row_x_uhf[x] + row_x_hf[x]) * muls[0]; + float ydiff = row_y_uhf[x] * muls[1] + row_y_hf[x] * muls[2]; + row[x] = xdiff * xdiff + ydiff * ydiff; + row[x] = sqrt(row[x]); + } + } +} + +void DiffPrecompute(const ImageF& xyb, float mul, float bias_arg, ImageF* out) { + const size_t xsize = xyb.xsize(); + const size_t ysize = xyb.ysize(); + const float bias = mul * bias_arg; + const float sqrt_bias = sqrt(bias); + for (size_t y = 0; y < ysize; ++y) { + const float* BUTTERAUGLI_RESTRICT row_in = xyb.Row(y); + float* BUTTERAUGLI_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + // kBias makes sqrt behave more linearly. + row_out[x] = sqrt(mul * std::abs(row_in[x]) + bias) - sqrt_bias; + } + } +} + +// std::log(80.0) / std::log(255.0); +constexpr float kIntensityTargetNormalizationHack = 0.79079917404f; +static const float kInternalGoodQualityThreshold = + 17.83f * kIntensityTargetNormalizationHack; +static const float kGlobalScale = 1.0 / kInternalGoodQualityThreshold; + +void StoreMin3(const float v, float& min0, float& min1, float& min2) { + if (v < min2) { + if (v < min0) { + min2 = min1; + min1 = min0; + min0 = v; + } else if (v < min1) { + min2 = min1; + min1 = v; + } else { + min2 = v; + } + } +} + +// Look for smooth areas near the area of degradation. +// If the areas area generally smooth, don't do masking. +void FuzzyErosion(const ImageF& from, ImageF* to) { + const size_t xsize = from.xsize(); + const size_t ysize = from.ysize(); + static const int kStep = 3; + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + float min0 = from.Row(y)[x]; + float min1 = 2 * min0; + float min2 = min1; + if (x >= kStep) { + float v = from.Row(y)[x - kStep]; + StoreMin3(v, min0, min1, min2); + if (y >= kStep) { + float v = from.Row(y - kStep)[x - kStep]; + StoreMin3(v, min0, min1, min2); + } + if (y < ysize - kStep) { + float v = from.Row(y + kStep)[x - kStep]; + StoreMin3(v, min0, min1, min2); + } + } + if (x < xsize - kStep) { + float v = from.Row(y)[x + kStep]; + StoreMin3(v, min0, min1, min2); + if (y >= kStep) { + float v = from.Row(y - kStep)[x + kStep]; + StoreMin3(v, min0, min1, min2); + } + if (y < ysize - kStep) { + float v = from.Row(y + kStep)[x + kStep]; + StoreMin3(v, min0, min1, min2); + } + } + if (y >= kStep) { + float v = from.Row(y - kStep)[x]; + StoreMin3(v, min0, min1, min2); + } + if (y < ysize - kStep) { + float v = from.Row(y + kStep)[x]; + StoreMin3(v, min0, min1, min2); + } + to->Row(y)[x] = (0.45f * min0 + 0.3f * min1 + 0.25f * min2); + } + } +} + +// Compute values of local frequency and dc masking based on the activity +// in the two images. img_diff_ac may be null. +void Mask(const ImageF& mask0, const ImageF& mask1, + const ButteraugliParams& params, BlurTemp* blur_temp, + ImageF* BUTTERAUGLI_RESTRICT mask, + ImageF* BUTTERAUGLI_RESTRICT diff_ac) { + const size_t xsize = mask0.xsize(); + const size_t ysize = mask0.ysize(); + *mask = ImageF(xsize, ysize); + static const float kMul = 6.19424080439; + static const float kBias = 12.61050594197; + static const float kRadius = 2.7; + ImageF diff0(xsize, ysize); + ImageF diff1(xsize, ysize); + ImageF blurred0(xsize, ysize); + ImageF blurred1(xsize, ysize); + DiffPrecompute(mask0, kMul, kBias, &diff0); + DiffPrecompute(mask1, kMul, kBias, &diff1); + Blur(diff0, kRadius, params, blur_temp, &blurred0); + FuzzyErosion(blurred0, &diff0); + Blur(diff1, kRadius, params, blur_temp, &blurred1); + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + mask->Row(y)[x] = diff0.Row(y)[x]; + if (diff_ac != nullptr) { + static const float kMaskToErrorMul = 10.0; + float diff = blurred0.Row(y)[x] - blurred1.Row(y)[x]; + diff_ac->Row(y)[x] += kMaskToErrorMul * diff * diff; + } + } + } +} + +// `diff_ac` may be null. +void MaskPsychoImage(const PsychoImage& pi0, const PsychoImage& pi1, + const size_t xsize, const size_t ysize, + const ButteraugliParams& params, BlurTemp* blur_temp, + ImageF* BUTTERAUGLI_RESTRICT mask, + ImageF* BUTTERAUGLI_RESTRICT diff_ac) { + ImageF mask0(xsize, ysize); + ImageF mask1(xsize, ysize); + CombineChannelsForMasking(&pi0.hf[0], &pi0.uhf[0], &mask0); + CombineChannelsForMasking(&pi1.hf[0], &pi1.uhf[0], &mask1); + Mask(mask0, mask1, params, blur_temp, mask, diff_ac); +} + +double MaskY(double delta) { + static const double offset = 0.829591754942; + static const double scaler = 0.451936922203; + static const double mul = 2.5485944793; + const double c = mul / ((scaler * delta) + offset); + const double retval = kGlobalScale * (1.0 + c); + return retval * retval; +} + +double MaskDcY(double delta) { + static const double offset = 0.20025578522; + static const double scaler = 3.87449418804; + static const double mul = 0.505054525019; + const double c = mul / ((scaler * delta) + offset); + const double retval = kGlobalScale * (1.0 + c); + return retval * retval; +} + +inline float MaskColor(const float color[3], const float mask) { + return color[0] * mask + color[1] * mask + color[2] * mask; +} + +// Diffmap := sqrt of sum{diff images by multiplied by X and Y/B masks} +void CombineChannelsToDiffmap(const ImageF& mask, const Image3F& block_diff_dc, + const Image3F& block_diff_ac, float xmul, + ImageF* result) { + JXL_CHECK(SameSize(mask, *result)); + size_t xsize = mask.xsize(); + size_t ysize = mask.ysize(); + for (size_t y = 0; y < ysize; ++y) { + float* BUTTERAUGLI_RESTRICT row_out = result->Row(y); + for (size_t x = 0; x < xsize; ++x) { + float val = mask.Row(y)[x]; + float maskval = MaskY(val); + float dc_maskval = MaskDcY(val); + float diff_dc[3]; + float diff_ac[3]; + for (int i = 0; i < 3; ++i) { + diff_dc[i] = block_diff_dc.PlaneRow(i, y)[x]; + diff_ac[i] = block_diff_ac.PlaneRow(i, y)[x]; + } + diff_ac[0] *= xmul; + diff_dc[0] *= xmul; + row_out[x] = + sqrt(MaskColor(diff_dc, dc_maskval) + MaskColor(diff_ac, maskval)); + } + } +} + +// Adds weighted L2 difference between i0 and i1 to diffmap. +static void L2Diff(const ImageF& i0, const ImageF& i1, const float w, + ImageF* BUTTERAUGLI_RESTRICT diffmap) { + if (w == 0) return; + + const HWY_FULL(float) d; + const auto weight = Set(d, w); + + for (size_t y = 0; y < i0.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row0 = i0.ConstRow(y); + const float* BUTTERAUGLI_RESTRICT row1 = i1.ConstRow(y); + float* BUTTERAUGLI_RESTRICT row_diff = diffmap->Row(y); + + for (size_t x = 0; x < i0.xsize(); x += Lanes(d)) { + const auto diff = Sub(Load(d, row0 + x), Load(d, row1 + x)); + const auto diff2 = Mul(diff, diff); + const auto prev = Load(d, row_diff + x); + Store(MulAdd(diff2, weight, prev), d, row_diff + x); + } + } +} + +// Initializes diffmap to the weighted L2 difference between i0 and i1. +static void SetL2Diff(const ImageF& i0, const ImageF& i1, const float w, + ImageF* BUTTERAUGLI_RESTRICT diffmap) { + if (w == 0) return; + + const HWY_FULL(float) d; + const auto weight = Set(d, w); + + for (size_t y = 0; y < i0.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row0 = i0.ConstRow(y); + const float* BUTTERAUGLI_RESTRICT row1 = i1.ConstRow(y); + float* BUTTERAUGLI_RESTRICT row_diff = diffmap->Row(y); + + for (size_t x = 0; x < i0.xsize(); x += Lanes(d)) { + const auto diff = Sub(Load(d, row0 + x), Load(d, row1 + x)); + const auto diff2 = Mul(diff, diff); + Store(Mul(diff2, weight), d, row_diff + x); + } + } +} + +// i0 is the original image. +// i1 is the deformed copy. +static void L2DiffAsymmetric(const ImageF& i0, const ImageF& i1, float w_0gt1, + float w_0lt1, + ImageF* BUTTERAUGLI_RESTRICT diffmap) { + if (w_0gt1 == 0 && w_0lt1 == 0) { + return; + } + + const HWY_FULL(float) d; + const auto vw_0gt1 = Set(d, w_0gt1 * 0.8); + const auto vw_0lt1 = Set(d, w_0lt1 * 0.8); + + for (size_t y = 0; y < i0.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row0 = i0.Row(y); + const float* BUTTERAUGLI_RESTRICT row1 = i1.Row(y); + float* BUTTERAUGLI_RESTRICT row_diff = diffmap->Row(y); + + for (size_t x = 0; x < i0.xsize(); x += Lanes(d)) { + const auto val0 = Load(d, row0 + x); + const auto val1 = Load(d, row1 + x); + + // Primary symmetric quadratic objective. + const auto diff = Sub(val0, val1); + auto total = MulAdd(Mul(diff, diff), vw_0gt1, Load(d, row_diff + x)); + + // Secondary half-open quadratic objectives. + const auto fabs0 = Abs(val0); + const auto too_small = Mul(Set(d, 0.4), fabs0); + const auto too_big = fabs0; + + const auto if_neg = IfThenElse( + Gt(val1, Neg(too_small)), Add(val1, too_small), + IfThenElseZero(Lt(val1, Neg(too_big)), Sub(Neg(val1), too_big))); + const auto if_pos = + IfThenElse(Lt(val1, too_small), Sub(too_small, val1), + IfThenElseZero(Gt(val1, too_big), Sub(val1, too_big))); + const auto v = IfThenElse(Lt(val0, Zero(d)), if_neg, if_pos); + total = MulAdd(vw_0lt1, Mul(v, v), total); + Store(total, d, row_diff + x); + } + } +} + +// A simple HDR compatible gamma function. +template <class DF, class V> +V Gamma(const DF df, V v) { + // ln(2) constant folded in because we want std::log but have FastLog2f. + const auto kRetMul = Set(df, 19.245013259874995f * 0.693147180559945f); + const auto kRetAdd = Set(df, -23.16046239805755); + // This should happen rarely, but may lead to a NaN in log, which is + // undesirable. Since negative photons don't exist we solve the NaNs by + // clamping here. + v = ZeroIfNegative(v); + + const auto biased = Add(v, Set(df, 9.9710635769299145)); + const auto log = FastLog2f(df, biased); + // We could fold this into a custom Log2 polynomial, but there would be + // relatively little gain. + return MulAdd(kRetMul, log, kRetAdd); +} + +template <bool Clamp, class DF, class V> +BUTTERAUGLI_INLINE void OpsinAbsorbance(const DF df, const V& in0, const V& in1, + const V& in2, V* JXL_RESTRICT out0, + V* JXL_RESTRICT out1, + V* JXL_RESTRICT out2) { + // https://en.wikipedia.org/wiki/Photopsin absorbance modeling. + static const double mixi0 = 0.29956550340058319; + static const double mixi1 = 0.63373087833825936; + static const double mixi2 = 0.077705617820981968; + static const double mixi3 = 1.7557483643287353; + static const double mixi4 = 0.22158691104574774; + static const double mixi5 = 0.69391388044116142; + static const double mixi6 = 0.0987313588422; + static const double mixi7 = 1.7557483643287353; + static const double mixi8 = 0.02; + static const double mixi9 = 0.02; + static const double mixi10 = 0.20480129041026129; + static const double mixi11 = 12.226454707163354; + + const V mix0 = Set(df, mixi0); + const V mix1 = Set(df, mixi1); + const V mix2 = Set(df, mixi2); + const V mix3 = Set(df, mixi3); + const V mix4 = Set(df, mixi4); + const V mix5 = Set(df, mixi5); + const V mix6 = Set(df, mixi6); + const V mix7 = Set(df, mixi7); + const V mix8 = Set(df, mixi8); + const V mix9 = Set(df, mixi9); + const V mix10 = Set(df, mixi10); + const V mix11 = Set(df, mixi11); + + *out0 = MulAdd(mix0, in0, MulAdd(mix1, in1, MulAdd(mix2, in2, mix3))); + *out1 = MulAdd(mix4, in0, MulAdd(mix5, in1, MulAdd(mix6, in2, mix7))); + *out2 = MulAdd(mix8, in0, MulAdd(mix9, in1, MulAdd(mix10, in2, mix11))); + + if (Clamp) { + *out0 = Max(*out0, mix3); + *out1 = Max(*out1, mix7); + *out2 = Max(*out2, mix11); + } +} + +// `blurred` is a temporary image used inside this function and not returned. +void OpsinDynamicsImage(const Image3F& rgb, const ButteraugliParams& params, + Image3F* blurred, BlurTemp* blur_temp, Image3F* xyb) { + const double kSigma = 1.2; + Blur(rgb.Plane(0), kSigma, params, blur_temp, &blurred->Plane(0)); + Blur(rgb.Plane(1), kSigma, params, blur_temp, &blurred->Plane(1)); + Blur(rgb.Plane(2), kSigma, params, blur_temp, &blurred->Plane(2)); + const HWY_FULL(float) df; + const auto intensity_target_multiplier = Set(df, params.intensity_target); + for (size_t y = 0; y < rgb.ysize(); ++y) { + const float* row_r = rgb.ConstPlaneRow(0, y); + const float* row_g = rgb.ConstPlaneRow(1, y); + const float* row_b = rgb.ConstPlaneRow(2, y); + const float* row_blurred_r = blurred->ConstPlaneRow(0, y); + const float* row_blurred_g = blurred->ConstPlaneRow(1, y); + const float* row_blurred_b = blurred->ConstPlaneRow(2, y); + float* row_out_x = xyb->PlaneRow(0, y); + float* row_out_y = xyb->PlaneRow(1, y); + float* row_out_b = xyb->PlaneRow(2, y); + const auto min = Set(df, 1e-4f); + for (size_t x = 0; x < rgb.xsize(); x += Lanes(df)) { + auto sensitivity0 = Undefined(df); + auto sensitivity1 = Undefined(df); + auto sensitivity2 = Undefined(df); + { + // Calculate sensitivity based on the smoothed image gamma derivative. + auto pre_mixed0 = Undefined(df); + auto pre_mixed1 = Undefined(df); + auto pre_mixed2 = Undefined(df); + OpsinAbsorbance<true>( + df, Mul(Load(df, row_blurred_r + x), intensity_target_multiplier), + Mul(Load(df, row_blurred_g + x), intensity_target_multiplier), + Mul(Load(df, row_blurred_b + x), intensity_target_multiplier), + &pre_mixed0, &pre_mixed1, &pre_mixed2); + pre_mixed0 = Max(pre_mixed0, min); + pre_mixed1 = Max(pre_mixed1, min); + pre_mixed2 = Max(pre_mixed2, min); + sensitivity0 = Div(Gamma(df, pre_mixed0), pre_mixed0); + sensitivity1 = Div(Gamma(df, pre_mixed1), pre_mixed1); + sensitivity2 = Div(Gamma(df, pre_mixed2), pre_mixed2); + sensitivity0 = Max(sensitivity0, min); + sensitivity1 = Max(sensitivity1, min); + sensitivity2 = Max(sensitivity2, min); + } + auto cur_mixed0 = Undefined(df); + auto cur_mixed1 = Undefined(df); + auto cur_mixed2 = Undefined(df); + OpsinAbsorbance<false>( + df, Mul(Load(df, row_r + x), intensity_target_multiplier), + Mul(Load(df, row_g + x), intensity_target_multiplier), + Mul(Load(df, row_b + x), intensity_target_multiplier), &cur_mixed0, + &cur_mixed1, &cur_mixed2); + cur_mixed0 = Mul(cur_mixed0, sensitivity0); + cur_mixed1 = Mul(cur_mixed1, sensitivity1); + cur_mixed2 = Mul(cur_mixed2, sensitivity2); + // This is a kludge. The negative values should be zeroed away before + // blurring. Ideally there would be no negative values in the first place. + const auto min01 = Set(df, 1.7557483643287353f); + const auto min2 = Set(df, 12.226454707163354f); + cur_mixed0 = Max(cur_mixed0, min01); + cur_mixed1 = Max(cur_mixed1, min01); + cur_mixed2 = Max(cur_mixed2, min2); + + Store(Sub(cur_mixed0, cur_mixed1), df, row_out_x + x); + Store(Add(cur_mixed0, cur_mixed1), df, row_out_y + x); + Store(cur_mixed2, df, row_out_b + x); + } + } +} + +void ButteraugliDiffmapInPlace(Image3F& image0, Image3F& image1, + const ButteraugliParams& params, + ImageF& diffmap) { + // image0 and image1 are in linear sRGB color space + const size_t xsize = image0.xsize(); + const size_t ysize = image0.ysize(); + BlurTemp blur_temp; + { + // Convert image0 and image1 to XYB in-place + Image3F temp(xsize, ysize); + OpsinDynamicsImage(image0, params, &temp, &blur_temp, &image0); + OpsinDynamicsImage(image1, params, &temp, &blur_temp, &image1); + } + // image0 and image1 are in XYB color space + ImageF block_diff_dc(xsize, ysize); + ZeroFillImage(&block_diff_dc); + { + // separate out LF components from image0 and image1 and compute the dc + // diff image from them + Image3F lf0 = Image3F(xsize, ysize); + Image3F lf1 = Image3F(xsize, ysize); + SeparateLFAndMF(params, image0, &lf0, &image0, &blur_temp); + SeparateLFAndMF(params, image1, &lf1, &image1, &blur_temp); + for (size_t c = 0; c < 3; ++c) { + L2Diff(lf0.Plane(c), lf1.Plane(c), wmul[6 + c], &block_diff_dc); + } + } + // image0 and image1 are MF residuals (before blurring) in XYB color space + ImageF hf0[2]; + ImageF hf1[2]; + SeparateMFAndHF(params, &image0, &hf0[0], &blur_temp); + SeparateMFAndHF(params, &image1, &hf1[0], &blur_temp); + // image0 and image1 are MF-images in XYB color space + + ImageF block_diff_ac(xsize, ysize); + ZeroFillImage(&block_diff_ac); + // start accumulating ac diff image from MF images + { + ImageF diffs(xsize, ysize); + MaltaDiffMapLF(image0.Plane(1), image1.Plane(1), wMfMalta, wMfMalta, + norm1Mf, &diffs, &block_diff_ac); + MaltaDiffMapLF(image0.Plane(0), image1.Plane(0), wMfMaltaX, wMfMaltaX, + norm1MfX, &diffs, &block_diff_ac); + } + for (size_t c = 0; c < 3; ++c) { + L2Diff(image0.Plane(c), image1.Plane(c), wmul[3 + c], &block_diff_ac); + } + // we will not need the MF-images and more, so we deallocate them to reduce + // peak memory usage + image0 = Image3F(); + image1 = Image3F(); + + ImageF uhf0[2]; + ImageF uhf1[2]; + SeparateHFAndUHF(params, &hf0[0], &uhf0[0], &blur_temp); + SeparateHFAndUHF(params, &hf1[0], &uhf1[0], &blur_temp); + + // continue accumulating ac diff image from HF and UHF images + const float hf_asymmetry = params.hf_asymmetry; + { + ImageF diffs(xsize, ysize); + MaltaDiffMap(uhf0[1], uhf1[1], wUhfMalta * hf_asymmetry, + wUhfMalta / hf_asymmetry, norm1Uhf, &diffs, &block_diff_ac); + MaltaDiffMap(uhf0[0], uhf1[0], wUhfMaltaX * hf_asymmetry, + wUhfMaltaX / hf_asymmetry, norm1UhfX, &diffs, &block_diff_ac); + MaltaDiffMapLF(hf0[1], hf1[1], wHfMalta * std::sqrt(hf_asymmetry), + wHfMalta / std::sqrt(hf_asymmetry), norm1Hf, &diffs, + &block_diff_ac); + MaltaDiffMapLF(hf0[0], hf1[0], wHfMaltaX * std::sqrt(hf_asymmetry), + wHfMaltaX / std::sqrt(hf_asymmetry), norm1HfX, &diffs, + &block_diff_ac); + } + for (size_t c = 0; c < 2; ++c) { + L2DiffAsymmetric(hf0[c], hf1[c], wmul[c] * hf_asymmetry, + wmul[c] / hf_asymmetry, &block_diff_ac); + } + + // compute mask image from HF and UHF X and Y images + ImageF mask(xsize, ysize); + { + ImageF mask0(xsize, ysize); + ImageF mask1(xsize, ysize); + CombineChannelsForMasking(&hf0[0], &uhf0[0], &mask0); + CombineChannelsForMasking(&hf1[0], &uhf1[0], &mask1); + DeallocateHFAndUHF(&hf1[0], &uhf1[0]); + DeallocateHFAndUHF(&hf0[0], &uhf0[0]); + Mask(mask0, mask1, params, &blur_temp, &mask, &block_diff_ac); + } + + // compute final diffmap from mask image and ac and dc diff images + diffmap = ImageF(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + const float* row_dc = block_diff_dc.Row(y); + const float* row_ac = block_diff_ac.Row(y); + float* row_out = diffmap.Row(y); + for (size_t x = 0; x < xsize; ++x) { + const float val = mask.Row(y)[x]; + row_out[x] = sqrt(row_dc[x] * MaskDcY(val) + row_ac[x] * MaskY(val)); + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(SeparateFrequencies); // Local function. +HWY_EXPORT(MaskPsychoImage); // Local function. +HWY_EXPORT(L2DiffAsymmetric); // Local function. +HWY_EXPORT(L2Diff); // Local function. +HWY_EXPORT(SetL2Diff); // Local function. +HWY_EXPORT(CombineChannelsToDiffmap); // Local function. +HWY_EXPORT(MaltaDiffMap); // Local function. +HWY_EXPORT(MaltaDiffMapLF); // Local function. +HWY_EXPORT(OpsinDynamicsImage); // Local function. +HWY_EXPORT(ButteraugliDiffmapInPlace); // Local function. + +#if BUTTERAUGLI_ENABLE_CHECKS + +static inline bool IsNan(const float x) { + uint32_t bits; + memcpy(&bits, &x, sizeof(bits)); + const uint32_t bitmask_exp = 0x7F800000; + return (bits & bitmask_exp) == bitmask_exp && (bits & 0x7FFFFF); +} + +static inline bool IsNan(const double x) { + uint64_t bits; + memcpy(&bits, &x, sizeof(bits)); + return (0x7ff0000000000001ULL <= bits && bits <= 0x7fffffffffffffffULL) || + (0xfff0000000000001ULL <= bits && bits <= 0xffffffffffffffffULL); +} + +static inline void CheckImage(const ImageF& image, const char* name) { + for (size_t y = 0; y < image.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row = image.Row(y); + for (size_t x = 0; x < image.xsize(); ++x) { + if (IsNan(row[x])) { + printf("NAN: Image %s @ %" PRIuS ",%" PRIuS " (of %" PRIuS ",%" PRIuS + ")\n", + name, x, y, image.xsize(), image.ysize()); + exit(1); + } + } + } +} + +#define CHECK_NAN(x, str) \ + do { \ + if (IsNan(x)) { \ + printf("%d: %s\n", __LINE__, str); \ + abort(); \ + } \ + } while (0) + +#define CHECK_IMAGE(image, name) CheckImage(image, name) + +#else // BUTTERAUGLI_ENABLE_CHECKS + +#define CHECK_NAN(x, str) +#define CHECK_IMAGE(image, name) + +#endif // BUTTERAUGLI_ENABLE_CHECKS + +// Calculate a 2x2 subsampled image for purposes of recursive butteraugli at +// multiresolution. +static Image3F SubSample2x(const Image3F& in) { + size_t xs = (in.xsize() + 1) / 2; + size_t ys = (in.ysize() + 1) / 2; + Image3F retval(xs, ys); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < ys; ++y) { + for (size_t x = 0; x < xs; ++x) { + retval.PlaneRow(c, y)[x] = 0; + } + } + } + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < in.ysize(); ++y) { + for (size_t x = 0; x < in.xsize(); ++x) { + retval.PlaneRow(c, y / 2)[x / 2] += 0.25f * in.PlaneRow(c, y)[x]; + } + } + if ((in.xsize() & 1) != 0) { + for (size_t y = 0; y < retval.ysize(); ++y) { + size_t last_column = retval.xsize() - 1; + retval.PlaneRow(c, y)[last_column] *= 2.0f; + } + } + if ((in.ysize() & 1) != 0) { + for (size_t x = 0; x < retval.xsize(); ++x) { + size_t last_row = retval.ysize() - 1; + retval.PlaneRow(c, last_row)[x] *= 2.0f; + } + } + } + return retval; +} + +// Supersample src by 2x and add it to dest. +static void AddSupersampled2x(const ImageF& src, float w, ImageF& dest) { + for (size_t y = 0; y < dest.ysize(); ++y) { + for (size_t x = 0; x < dest.xsize(); ++x) { + // There will be less errors from the more averaged images. + // We take it into account to some extent using a scaler. + static const double kHeuristicMixingValue = 0.3; + dest.Row(y)[x] *= 1.0 - kHeuristicMixingValue * w; + dest.Row(y)[x] += w * src.Row(y / 2)[x / 2]; + } + } +} + +Image3F* ButteraugliComparator::Temp() const { + bool was_in_use = temp_in_use_.test_and_set(std::memory_order_acq_rel); + JXL_ASSERT(!was_in_use); + (void)was_in_use; + return &temp_; +} + +void ButteraugliComparator::ReleaseTemp() const { temp_in_use_.clear(); } + +ButteraugliComparator::ButteraugliComparator(const Image3F& rgb0, + const ButteraugliParams& params) + : xsize_(rgb0.xsize()), + ysize_(rgb0.ysize()), + params_(params), + temp_(xsize_, ysize_) { + if (xsize_ < 8 || ysize_ < 8) { + return; + } + + Image3F xyb0(xsize_, ysize_); + HWY_DYNAMIC_DISPATCH(OpsinDynamicsImage) + (rgb0, params, Temp(), &blur_temp_, &xyb0); + ReleaseTemp(); + HWY_DYNAMIC_DISPATCH(SeparateFrequencies) + (xsize_, ysize_, params_, &blur_temp_, xyb0, pi0_); + + // Awful recursive construction of samples of different resolution. + // This is an after-thought and possibly somewhat parallel in + // functionality with the PsychoImage multi-resolution approach. + sub_.reset(new ButteraugliComparator(SubSample2x(rgb0), params)); +} + +void ButteraugliComparator::Mask(ImageF* BUTTERAUGLI_RESTRICT mask) const { + HWY_DYNAMIC_DISPATCH(MaskPsychoImage) + (pi0_, pi0_, xsize_, ysize_, params_, &blur_temp_, mask, nullptr); +} + +void ButteraugliComparator::Diffmap(const Image3F& rgb1, ImageF& result) const { + if (xsize_ < 8 || ysize_ < 8) { + ZeroFillImage(&result); + return; + } + Image3F xyb1(xsize_, ysize_); + HWY_DYNAMIC_DISPATCH(OpsinDynamicsImage) + (rgb1, params_, Temp(), &blur_temp_, &xyb1); + ReleaseTemp(); + DiffmapOpsinDynamicsImage(xyb1, result); + if (sub_) { + if (sub_->xsize_ < 8 || sub_->ysize_ < 8) { + return; + } + Image3F sub_xyb(sub_->xsize_, sub_->ysize_); + HWY_DYNAMIC_DISPATCH(OpsinDynamicsImage) + (SubSample2x(rgb1), params_, sub_->Temp(), &sub_->blur_temp_, &sub_xyb); + sub_->ReleaseTemp(); + ImageF subresult; + sub_->DiffmapOpsinDynamicsImage(sub_xyb, subresult); + AddSupersampled2x(subresult, 0.5, result); + } +} + +void ButteraugliComparator::DiffmapOpsinDynamicsImage(const Image3F& xyb1, + ImageF& result) const { + if (xsize_ < 8 || ysize_ < 8) { + ZeroFillImage(&result); + return; + } + PsychoImage pi1; + HWY_DYNAMIC_DISPATCH(SeparateFrequencies) + (xsize_, ysize_, params_, &blur_temp_, xyb1, pi1); + result = ImageF(xsize_, ysize_); + DiffmapPsychoImage(pi1, result); +} + +namespace { + +void MaltaDiffMap(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, + ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + HWY_DYNAMIC_DISPATCH(MaltaDiffMap) + (lum0, lum1, w_0gt1, w_0lt1, norm1, diffs, &block_diff_ac->Plane(c)); +} + +void MaltaDiffMapLF(const ImageF& lum0, const ImageF& lum1, const double w_0gt1, + const double w_0lt1, const double norm1, + ImageF* HWY_RESTRICT diffs, + Image3F* HWY_RESTRICT block_diff_ac, size_t c) { + HWY_DYNAMIC_DISPATCH(MaltaDiffMapLF) + (lum0, lum1, w_0gt1, w_0lt1, norm1, diffs, &block_diff_ac->Plane(c)); +} + +} // namespace + +void ButteraugliComparator::DiffmapPsychoImage(const PsychoImage& pi1, + ImageF& diffmap) const { + if (xsize_ < 8 || ysize_ < 8) { + ZeroFillImage(&diffmap); + return; + } + + const float hf_asymmetry_ = params_.hf_asymmetry; + const float xmul_ = params_.xmul; + + ImageF diffs(xsize_, ysize_); + Image3F block_diff_ac(xsize_, ysize_); + ZeroFillImage(&block_diff_ac); + MaltaDiffMap(pi0_.uhf[1], pi1.uhf[1], wUhfMalta * hf_asymmetry_, + wUhfMalta / hf_asymmetry_, norm1Uhf, &diffs, &block_diff_ac, 1); + MaltaDiffMap(pi0_.uhf[0], pi1.uhf[0], wUhfMaltaX * hf_asymmetry_, + wUhfMaltaX / hf_asymmetry_, norm1UhfX, &diffs, &block_diff_ac, + 0); + MaltaDiffMapLF(pi0_.hf[1], pi1.hf[1], wHfMalta * std::sqrt(hf_asymmetry_), + wHfMalta / std::sqrt(hf_asymmetry_), norm1Hf, &diffs, + &block_diff_ac, 1); + MaltaDiffMapLF(pi0_.hf[0], pi1.hf[0], wHfMaltaX * std::sqrt(hf_asymmetry_), + wHfMaltaX / std::sqrt(hf_asymmetry_), norm1HfX, &diffs, + &block_diff_ac, 0); + MaltaDiffMapLF(pi0_.mf.Plane(1), pi1.mf.Plane(1), wMfMalta, wMfMalta, norm1Mf, + &diffs, &block_diff_ac, 1); + MaltaDiffMapLF(pi0_.mf.Plane(0), pi1.mf.Plane(0), wMfMaltaX, wMfMaltaX, + norm1MfX, &diffs, &block_diff_ac, 0); + + Image3F block_diff_dc(xsize_, ysize_); + for (size_t c = 0; c < 3; ++c) { + if (c < 2) { // No blue channel error accumulated at HF. + HWY_DYNAMIC_DISPATCH(L2DiffAsymmetric) + (pi0_.hf[c], pi1.hf[c], wmul[c] * hf_asymmetry_, wmul[c] / hf_asymmetry_, + &block_diff_ac.Plane(c)); + } + HWY_DYNAMIC_DISPATCH(L2Diff) + (pi0_.mf.Plane(c), pi1.mf.Plane(c), wmul[3 + c], &block_diff_ac.Plane(c)); + HWY_DYNAMIC_DISPATCH(SetL2Diff) + (pi0_.lf.Plane(c), pi1.lf.Plane(c), wmul[6 + c], &block_diff_dc.Plane(c)); + } + + ImageF mask; + HWY_DYNAMIC_DISPATCH(MaskPsychoImage) + (pi0_, pi1, xsize_, ysize_, params_, &blur_temp_, &mask, + &block_diff_ac.Plane(1)); + + HWY_DYNAMIC_DISPATCH(CombineChannelsToDiffmap) + (mask, block_diff_dc, block_diff_ac, xmul_, &diffmap); +} + +double ButteraugliScoreFromDiffmap(const ImageF& diffmap, + const ButteraugliParams* params) { + float retval = 0.0f; + for (size_t y = 0; y < diffmap.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row = diffmap.ConstRow(y); + for (size_t x = 0; x < diffmap.xsize(); ++x) { + retval = std::max(retval, row[x]); + } + } + return retval; +} + +bool ButteraugliDiffmap(const Image3F& rgb0, const Image3F& rgb1, + double hf_asymmetry, double xmul, ImageF& diffmap) { + ButteraugliParams params; + params.hf_asymmetry = hf_asymmetry; + params.xmul = xmul; + return ButteraugliDiffmap(rgb0, rgb1, params, diffmap); +} + +template <size_t kMax> +bool ButteraugliDiffmapSmall(const Image3F& rgb0, const Image3F& rgb1, + const ButteraugliParams& params, ImageF& diffmap) { + const size_t xsize = rgb0.xsize(); + const size_t ysize = rgb0.ysize(); + // Butteraugli values for small (where xsize or ysize is smaller + // than 8 pixels) images are non-sensical, but most likely it is + // less disruptive to try to compute something than just give up. + // Temporarily extend the borders of the image to fit 8 x 8 size. + size_t xborder = xsize < kMax ? (kMax - xsize) / 2 : 0; + size_t yborder = ysize < kMax ? (kMax - ysize) / 2 : 0; + size_t xscaled = std::max<size_t>(kMax, xsize); + size_t yscaled = std::max<size_t>(kMax, ysize); + Image3F scaled0(xscaled, yscaled); + Image3F scaled1(xscaled, yscaled); + for (int i = 0; i < 3; ++i) { + for (size_t y = 0; y < yscaled; ++y) { + for (size_t x = 0; x < xscaled; ++x) { + size_t x2 = std::min<size_t>(xsize - 1, x > xborder ? x - xborder : 0); + size_t y2 = std::min<size_t>(ysize - 1, y > yborder ? y - yborder : 0); + scaled0.PlaneRow(i, y)[x] = rgb0.PlaneRow(i, y2)[x2]; + scaled1.PlaneRow(i, y)[x] = rgb1.PlaneRow(i, y2)[x2]; + } + } + } + ImageF diffmap_scaled; + const bool ok = ButteraugliDiffmap(scaled0, scaled1, params, diffmap_scaled); + diffmap = ImageF(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + diffmap.Row(y)[x] = diffmap_scaled.Row(y + yborder)[x + xborder]; + } + } + return ok; +} + +bool ButteraugliDiffmap(const Image3F& rgb0, const Image3F& rgb1, + const ButteraugliParams& params, ImageF& diffmap) { + const size_t xsize = rgb0.xsize(); + const size_t ysize = rgb0.ysize(); + if (xsize < 1 || ysize < 1) { + return JXL_FAILURE("Zero-sized image"); + } + if (!SameSize(rgb0, rgb1)) { + return JXL_FAILURE("Size mismatch"); + } + static const int kMax = 8; + if (xsize < kMax || ysize < kMax) { + return ButteraugliDiffmapSmall<kMax>(rgb0, rgb1, params, diffmap); + } + ButteraugliComparator butteraugli(rgb0, params); + butteraugli.Diffmap(rgb1, diffmap); + return true; +} + +bool ButteraugliInterface(const Image3F& rgb0, const Image3F& rgb1, + float hf_asymmetry, float xmul, ImageF& diffmap, + double& diffvalue) { + ButteraugliParams params; + params.hf_asymmetry = hf_asymmetry; + params.xmul = xmul; + return ButteraugliInterface(rgb0, rgb1, params, diffmap, diffvalue); +} + +bool ButteraugliInterface(const Image3F& rgb0, const Image3F& rgb1, + const ButteraugliParams& params, ImageF& diffmap, + double& diffvalue) { + if (!ButteraugliDiffmap(rgb0, rgb1, params, diffmap)) { + return false; + } + diffvalue = ButteraugliScoreFromDiffmap(diffmap, ¶ms); + return true; +} + +bool ButteraugliInterfaceInPlace(Image3F&& rgb0, Image3F&& rgb1, + const ButteraugliParams& params, + ImageF& diffmap, double& diffvalue) { + const size_t xsize = rgb0.xsize(); + const size_t ysize = rgb0.ysize(); + if (xsize < 1 || ysize < 1) { + return JXL_FAILURE("Zero-sized image"); + } + if (!SameSize(rgb0, rgb1)) { + return JXL_FAILURE("Size mismatch"); + } + static const int kMax = 8; + if (xsize < kMax || ysize < kMax) { + bool ok = ButteraugliDiffmapSmall<kMax>(rgb0, rgb1, params, diffmap); + diffvalue = ButteraugliScoreFromDiffmap(diffmap, ¶ms); + return ok; + } + ImageF subdiffmap; + if (xsize >= 15 && ysize >= 15) { + Image3F rgb0_sub = SubSample2x(rgb0); + Image3F rgb1_sub = SubSample2x(rgb1); + HWY_DYNAMIC_DISPATCH(ButteraugliDiffmapInPlace) + (rgb0_sub, rgb1_sub, params, subdiffmap); + } + HWY_DYNAMIC_DISPATCH(ButteraugliDiffmapInPlace)(rgb0, rgb1, params, diffmap); + if (xsize >= 15 && ysize >= 15) { + AddSupersampled2x(subdiffmap, 0.5, diffmap); + } + diffvalue = ButteraugliScoreFromDiffmap(diffmap, ¶ms); + return true; +} + +double ButteraugliFuzzyClass(double score) { + static const double fuzzy_width_up = 4.8; + static const double fuzzy_width_down = 4.8; + static const double m0 = 2.0; + static const double scaler = 0.7777; + double val; + if (score < 1.0) { + // val in [scaler .. 2.0] + val = m0 / (1.0 + exp((score - 1.0) * fuzzy_width_down)); + val -= 1.0; // from [1 .. 2] to [0 .. 1] + val *= 2.0 - scaler; // from [0 .. 1] to [0 .. 2.0 - scaler] + val += scaler; // from [0 .. 2.0 - scaler] to [scaler .. 2.0] + } else { + // val in [0 .. scaler] + val = m0 / (1.0 + exp((score - 1.0) * fuzzy_width_up)); + val *= scaler; + } + return val; +} + +// #define PRINT_OUT_NORMALIZATION + +double ButteraugliFuzzyInverse(double seek) { + double pos = 0; + // NOLINTNEXTLINE(clang-analyzer-security.FloatLoopCounter) + for (double range = 1.0; range >= 1e-10; range *= 0.5) { + double cur = ButteraugliFuzzyClass(pos); + if (cur < seek) { + pos -= range; + } else { + pos += range; + } + } +#ifdef PRINT_OUT_NORMALIZATION + if (seek == 1.0) { + fprintf(stderr, "Fuzzy inverse %g\n", pos); + } +#endif + return pos; +} + +#ifdef PRINT_OUT_NORMALIZATION +static double print_out_normalization = ButteraugliFuzzyInverse(1.0); +#endif + +namespace { + +void ScoreToRgb(double score, double good_threshold, double bad_threshold, + float rgb[3]) { + double heatmap[12][3] = { + {0, 0, 0}, {0, 0, 1}, + {0, 1, 1}, {0, 1, 0}, // Good level + {1, 1, 0}, {1, 0, 0}, // Bad level + {1, 0, 1}, {0.5, 0.5, 1.0}, + {1.0, 0.5, 0.5}, // Pastel colors for the very bad quality range. + {1.0, 1.0, 0.5}, {1, 1, 1}, + {1, 1, 1}, // Last color repeated to have a solid range of white. + }; + if (score < good_threshold) { + score = (score / good_threshold) * 0.3; + } else if (score < bad_threshold) { + score = 0.3 + + (score - good_threshold) / (bad_threshold - good_threshold) * 0.15; + } else { + score = 0.45 + (score - bad_threshold) / (bad_threshold * 12) * 0.5; + } + static const int kTableSize = sizeof(heatmap) / sizeof(heatmap[0]); + score = std::min<double>(std::max<double>(score * (kTableSize - 1), 0.0), + kTableSize - 2); + int ix = static_cast<int>(score); + ix = std::min(std::max(0, ix), kTableSize - 2); // Handle NaN + double mix = score - ix; + for (int i = 0; i < 3; ++i) { + double v = mix * heatmap[ix + 1][i] + (1 - mix) * heatmap[ix][i]; + rgb[i] = pow(v, 0.5); + } +} + +} // namespace + +Image3F CreateHeatMapImage(const ImageF& distmap, double good_threshold, + double bad_threshold) { + Image3F heatmap(distmap.xsize(), distmap.ysize()); + for (size_t y = 0; y < distmap.ysize(); ++y) { + const float* BUTTERAUGLI_RESTRICT row_distmap = distmap.ConstRow(y); + float* BUTTERAUGLI_RESTRICT row_h0 = heatmap.PlaneRow(0, y); + float* BUTTERAUGLI_RESTRICT row_h1 = heatmap.PlaneRow(1, y); + float* BUTTERAUGLI_RESTRICT row_h2 = heatmap.PlaneRow(2, y); + for (size_t x = 0; x < distmap.xsize(); ++x) { + const float d = row_distmap[x]; + float rgb[3]; + ScoreToRgb(d, good_threshold, bad_threshold, rgb); + row_h0[x] = rgb[0]; + row_h1[x] = rgb[1]; + row_h2[x] = rgb[2]; + } + } + return heatmap; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.h b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.h new file mode 100644 index 0000000000..29130e8768 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli.h @@ -0,0 +1,214 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// Author: Jyrki Alakuijala (jyrki.alakuijala@gmail.com) + +#ifndef LIB_JXL_BUTTERAUGLI_BUTTERAUGLI_H_ +#define LIB_JXL_BUTTERAUGLI_BUTTERAUGLI_H_ + +#include <stdint.h> +#include <stdlib.h> +#include <string.h> + +#include <atomic> +#include <cmath> +#include <memory> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +#define BUTTERAUGLI_ENABLE_CHECKS 0 +#define BUTTERAUGLI_RESTRICT JXL_RESTRICT + +// This is the main interface to butteraugli image similarity +// analysis function. + +namespace jxl { + +struct ButteraugliParams { + // Multiplier for penalizing new HF artifacts more than blurring away + // features. 1.0=neutral. + float hf_asymmetry = 1.0f; + + // Multiplier for the psychovisual difference in the X channel. + float xmul = 1.0f; + + // Number of nits that correspond to 1.0f input values. + float intensity_target = 80.0f; +}; + +// ButteraugliInterface defines the public interface for butteraugli. +// +// It calculates the difference between rgb0 and rgb1. +// +// rgb0 and rgb1 contain the images. rgb0[c][px] and rgb1[c][px] contains +// the red image for c == 0, green for c == 1, blue for c == 2. Location index +// px is calculated as y * xsize + x. +// +// Value of pixels of images rgb0 and rgb1 need to be represented as raw +// intensity. Most image formats store gamma corrected intensity in pixel +// values. This gamma correction has to be removed, by applying the following +// function to values in the 0-1 range: +// butteraugli_val = pow(input_val, gamma); +// A typical value of gamma is 2.2. It is usually stored in the image header. +// Take care not to confuse that value with its inverse. The gamma value should +// be always greater than one. +// Butteraugli does not work as intended if the caller does not perform +// gamma correction. +// +// hf_asymmetry is a multiplier for penalizing new HF artifacts more than +// blurring away features (1.0 -> neutral). +// +// diffmap will contain an image of the size xsize * ysize, containing +// localized differences for values px (indexed with the px the same as rgb0 +// and rgb1). diffvalue will give a global score of similarity. +// +// A diffvalue smaller than kButteraugliGood indicates that images can be +// observed as the same image. +// diffvalue larger than kButteraugliBad indicates that a difference between +// the images can be observed. +// A diffvalue between kButteraugliGood and kButteraugliBad indicates that +// a subtle difference can be observed between the images. +// +// Returns true on success. +bool ButteraugliInterface(const Image3F &rgb0, const Image3F &rgb1, + const ButteraugliParams ¶ms, ImageF &diffmap, + double &diffvalue); + +// Deprecated (calls the previous function) +bool ButteraugliInterface(const Image3F &rgb0, const Image3F &rgb1, + float hf_asymmetry, float xmul, ImageF &diffmap, + double &diffvalue); + +// Same as ButteraugliInterface, but reuses rgb0 and rgb1 for other purposes +// inside the function after they are not needed any more, and it ignores +// params.xmul. +bool ButteraugliInterfaceInPlace(Image3F &&rgb0, Image3F &&rgb1, + const ButteraugliParams ¶ms, + ImageF &diffmap, double &diffvalue); + +// Converts the butteraugli score into fuzzy class values that are continuous +// at the class boundary. The class boundary location is based on human +// raters, but the slope is arbitrary. Particularly, it does not reflect +// the expectation value of probabilities of the human raters. It is just +// expected that a smoother class boundary will allow for higher-level +// optimization algorithms to work faster. +// +// Returns 2.0 for a perfect match, and 1.0 for 'ok', 0.0 for bad. Because the +// scoring is fuzzy, a butteraugli score of 0.96 would return a class of +// around 1.9. +double ButteraugliFuzzyClass(double score); + +// Input values should be in range 0 (bad) to 2 (good). Use +// kButteraugliNormalization as normalization. +double ButteraugliFuzzyInverse(double seek); + +// Implementation details, don't use anything below or your code will +// break in the future. + +#ifdef _MSC_VER +#define BUTTERAUGLI_INLINE __forceinline +#else +#define BUTTERAUGLI_INLINE inline +#endif + +#ifdef __clang__ +// Early versions of Clang did not support __builtin_assume_aligned. +#define BUTTERAUGLI_HAS_ASSUME_ALIGNED __has_builtin(__builtin_assume_aligned) +#elif defined(__GNUC__) +#define BUTTERAUGLI_HAS_ASSUME_ALIGNED 1 +#else +#define BUTTERAUGLI_HAS_ASSUME_ALIGNED 0 +#endif + +// Returns a void* pointer which the compiler then assumes is N-byte aligned. +// Example: float* JXL_RESTRICT aligned = (float*)JXL_ASSUME_ALIGNED(in, 32); +// +// The assignment semantics are required by GCC/Clang. ICC provides an in-place +// __assume_aligned, whereas MSVC's __assume appears unsuitable. +#if BUTTERAUGLI_HAS_ASSUME_ALIGNED +#define BUTTERAUGLI_ASSUME_ALIGNED(ptr, align) \ + __builtin_assume_aligned((ptr), (align)) +#else +#define BUTTERAUGLI_ASSUME_ALIGNED(ptr, align) (ptr) +#endif // BUTTERAUGLI_HAS_ASSUME_ALIGNED + +struct PsychoImage { + ImageF uhf[2]; // XY + ImageF hf[2]; // XY + Image3F mf; // XYB + Image3F lf; // XYB +}; + +// Blur needs a transposed image. +// Hold it here and only allocate on demand to reduce memory usage. +struct BlurTemp { + ImageF *GetTransposed(const ImageF &in) { + if (transposed_temp.xsize() == 0) { + transposed_temp = ImageF(in.ysize(), in.xsize()); + } + return &transposed_temp; + } + + ImageF transposed_temp; +}; + +class ButteraugliComparator { + public: + // Butteraugli is calibrated at xmul = 1.0. We add a multiplier here so that + // we can test the hypothesis that a higher weighing of the X channel would + // improve results at higher Butteraugli values. + ButteraugliComparator(const Image3F &rgb0, const ButteraugliParams ¶ms); + virtual ~ButteraugliComparator() = default; + + // Computes the butteraugli map between the original image given in the + // constructor and the distorted image give here. + void Diffmap(const Image3F &rgb1, ImageF &result) const; + + // Same as above, but OpsinDynamicsImage() was already applied. + void DiffmapOpsinDynamicsImage(const Image3F &xyb1, ImageF &result) const; + + // Same as above, but the frequency decomposition was already applied. + void DiffmapPsychoImage(const PsychoImage &pi1, ImageF &diffmap) const; + + void Mask(ImageF *BUTTERAUGLI_RESTRICT mask) const; + + private: + Image3F *Temp() const; + void ReleaseTemp() const; + + const size_t xsize_; + const size_t ysize_; + ButteraugliParams params_; + PsychoImage pi0_; + + // Shared temporary image storage to reduce the number of allocations; + // obtained via Temp(), must call ReleaseTemp when no longer needed. + mutable Image3F temp_; + mutable std::atomic_flag temp_in_use_ = ATOMIC_FLAG_INIT; + + mutable BlurTemp blur_temp_; + std::unique_ptr<ButteraugliComparator> sub_; +}; + +// Deprecated. +bool ButteraugliDiffmap(const Image3F &rgb0, const Image3F &rgb1, + double hf_asymmetry, double xmul, ImageF &diffmap); + +bool ButteraugliDiffmap(const Image3F &rgb0, const Image3F &rgb1, + const ButteraugliParams ¶ms, ImageF &diffmap); + +double ButteraugliScoreFromDiffmap(const ImageF &diffmap, + const ButteraugliParams *params = nullptr); + +// Generate rgb-representation of the distance between two images. +Image3F CreateHeatMapImage(const ImageF &distmap, double good_threshold, + double bad_threshold); + +} // namespace jxl + +#endif // LIB_JXL_BUTTERAUGLI_BUTTERAUGLI_H_ diff --git a/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli_test.cc b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli_test.cc new file mode 100644 index 0000000000..c2ccf56175 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/butteraugli/butteraugli_test.cc @@ -0,0 +1,117 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/butteraugli/butteraugli.h" + +#include <jxl/types.h> +#include <stddef.h> + +#include <algorithm> +#include <cstdint> +#include <utility> + +#include "lib/extras/metrics.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/test_image.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +using extras::PackedImage; +using extras::PackedPixelFile; +using test::TestImage; + +Image3F SinglePixelImage(float red, float green, float blue) { + Image3F img(1, 1); + img.PlaneRow(0, 0)[0] = red; + img.PlaneRow(1, 0)[0] = green; + img.PlaneRow(2, 0)[0] = blue; + return img; +} + +Image3F GetColorImage(const PackedPixelFile& ppf) { + JXL_CHECK(!ppf.frames.empty()); + const PackedImage& image = ppf.frames[0].color; + const JxlPixelFormat& format = image.format; + const uint8_t* pixels = reinterpret_cast<const uint8_t*>(image.pixels()); + Image3F color(image.xsize, image.ysize); + for (size_t c = 0; c < format.num_channels; ++c) { + JXL_CHECK(ConvertFromExternal(pixels, image.pixels_size, image.xsize, + image.ysize, ppf.info.bits_per_sample, format, + c, nullptr, &color.Plane(c))); + } + return color; +} + +void AddUniformNoise(Image3F* img, float d, size_t seed) { + Rng generator(seed); + for (size_t y = 0; y < img->ysize(); ++y) { + for (int c = 0; c < 3; ++c) { + for (size_t x = 0; x < img->xsize(); ++x) { + img->PlaneRow(c, y)[x] += generator.UniformF(-d, d); + } + } + } +} + +void AddEdge(Image3F* img, float d, size_t x0, size_t y0) { + const size_t h = std::min<size_t>(img->ysize() - y0, 100); + const size_t w = std::min<size_t>(img->xsize() - x0, 5); + for (size_t dy = 0; dy < h; ++dy) { + for (size_t dx = 0; dx < w; ++dx) { + img->PlaneRow(1, y0 + dy)[x0 + dx] += d; + } + } +} + +TEST(ButteraugliInPlaceTest, SinglePixel) { + Image3F rgb0 = SinglePixelImage(0.5f, 0.5f, 0.5f); + Image3F rgb1 = SinglePixelImage(0.5f, 0.49f, 0.5f); + ButteraugliParams ba; + ImageF diffmap; + double diffval; + EXPECT_TRUE(ButteraugliInterface(rgb0, rgb1, ba, diffmap, diffval)); + EXPECT_NEAR(diffval, 2.5, 0.5); + ImageF diffmap2; + double diffval2; + EXPECT_TRUE(ButteraugliInterfaceInPlace(std::move(rgb0), std::move(rgb1), ba, + diffmap2, diffval2)); + EXPECT_NEAR(diffval, diffval2, 1e-10); +} + +TEST(ButteraugliInPlaceTest, LargeImage) { + const size_t xsize = 1024; + const size_t ysize = 1024; + TestImage img; + img.SetDimensions(xsize, ysize).AddFrame().RandomFill(777); + Image3F rgb0 = GetColorImage(img.ppf()); + Image3F rgb1(xsize, ysize); + CopyImageTo(rgb0, &rgb1); + AddUniformNoise(&rgb1, 0.02f, 7777); + AddEdge(&rgb1, 0.1f, xsize / 2, xsize / 2); + ButteraugliParams ba; + ImageF diffmap; + double diffval; + EXPECT_TRUE(ButteraugliInterface(rgb0, rgb1, ba, diffmap, diffval)); + double distp = ComputeDistanceP(diffmap, ba, 3.0); + EXPECT_NEAR(diffval, 4.0, 0.5); + EXPECT_NEAR(distp, 1.5, 0.5); + ImageF diffmap2; + double diffval2; + EXPECT_TRUE(ButteraugliInterfaceInPlace(std::move(rgb0), std::move(rgb1), ba, + diffmap2, diffval2)); + double distp2 = ComputeDistanceP(diffmap2, ba, 3.0); + EXPECT_NEAR(diffval, diffval2, 5e-7); + EXPECT_NEAR(distp, distp2, 1e-7); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/byte_order_test.cc b/third_party/jpeg-xl/lib/jxl/byte_order_test.cc new file mode 100644 index 0000000000..17d7ef6643 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/byte_order_test.cc @@ -0,0 +1,53 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/byte_order.h" + +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +TEST(ByteOrderTest, TestRoundTripBE16) { + const uint32_t in = 0x1234; + uint8_t buf[2]; + StoreBE16(in, buf); + EXPECT_EQ(in, LoadBE16(buf)); + EXPECT_NE(in, LoadLE16(buf)); +} + +TEST(ByteOrderTest, TestRoundTripLE16) { + const uint32_t in = 0x1234; + uint8_t buf[2]; + StoreLE16(in, buf); + EXPECT_EQ(in, LoadLE16(buf)); + EXPECT_NE(in, LoadBE16(buf)); +} + +TEST(ByteOrderTest, TestRoundTripBE32) { + const uint32_t in = 0xFEDCBA98u; + uint8_t buf[4]; + StoreBE32(in, buf); + EXPECT_EQ(in, LoadBE32(buf)); + EXPECT_NE(in, LoadLE32(buf)); +} + +TEST(ByteOrderTest, TestRoundTripLE32) { + const uint32_t in = 0xFEDCBA98u; + uint8_t buf[4]; + StoreLE32(in, buf); + EXPECT_EQ(in, LoadLE32(buf)); + EXPECT_NE(in, LoadBE32(buf)); +} + +TEST(ByteOrderTest, TestRoundTripLE64) { + const uint64_t in = 0xFEDCBA9876543210ull; + uint8_t buf[8]; + StoreLE64(in, buf); + EXPECT_EQ(in, LoadLE64(buf)); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/cache_aligned.cc b/third_party/jpeg-xl/lib/jxl/cache_aligned.cc new file mode 100644 index 0000000000..992efc4d48 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cache_aligned.cc @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/cache_aligned.h" + +#include <stdio.h> +#include <stdlib.h> + +// Disabled: slower than malloc + alignment. +#define JXL_USE_MMAP 0 + +#if JXL_USE_MMAP +#include <sys/mman.h> +#endif + +#include <algorithm> // std::max +#include <atomic> +#include <hwy/base.h> // kMaxVectorSize +#include <limits> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" + +namespace jxl { +namespace { + +#pragma pack(push, 1) +struct AllocationHeader { + void* allocated; + size_t allocated_size; + uint8_t left_padding[hwy::kMaxVectorSize]; +}; +#pragma pack(pop) + +std::atomic<uint64_t> num_allocations{0}; +std::atomic<uint64_t> bytes_in_use{0}; +std::atomic<uint64_t> max_bytes_in_use{0}; + +} // namespace + +// Avoids linker errors in pre-C++17 builds. +constexpr size_t CacheAligned::kPointerSize; +constexpr size_t CacheAligned::kCacheLineSize; +constexpr size_t CacheAligned::kAlignment; +constexpr size_t CacheAligned::kAlias; + +void CacheAligned::PrintStats() { + fprintf( + stderr, "Allocations: %" PRIuS " (max bytes in use: %E)\n", + static_cast<size_t>(num_allocations.load(std::memory_order_relaxed)), + static_cast<double>(max_bytes_in_use.load(std::memory_order_relaxed))); +} + +size_t CacheAligned::NextOffset() { + static std::atomic<uint32_t> next{0}; + constexpr uint32_t kGroups = CacheAligned::kAlias / CacheAligned::kAlignment; + const uint32_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups; + return CacheAligned::kAlignment * group; +} + +void* CacheAligned::Allocate(const size_t payload_size, size_t offset) { + JXL_ASSERT(payload_size <= std::numeric_limits<size_t>::max() / 2); + JXL_ASSERT((offset % kAlignment == 0) && offset <= kAlias); + + // What: | misalign | unused | AllocationHeader |payload + // Size: |<= kAlias | offset | |payload_size + // ^allocated.^aligned.^header............^payload + // The header must immediately precede payload, which must remain aligned. + // To avoid wasting space, the header resides at the end of `unused`, + // which therefore cannot be empty (offset == 0). + if (offset == 0) { + // SVE/RVV vectors can be large, so we cannot rely on them (including the + // padding at the end of AllocationHeader) to fit in kAlignment. + offset = hwy::RoundUpTo(sizeof(AllocationHeader), kAlignment); + } + +#if JXL_USE_MMAP + const size_t allocated_size = offset + payload_size; + const int flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_POPULATE; + void* allocated = + mmap(nullptr, allocated_size, PROT_READ | PROT_WRITE, flags, -1, 0); + if (allocated == MAP_FAILED) return nullptr; + const uintptr_t aligned = reinterpret_cast<uintptr_t>(allocated); +#else + const size_t allocated_size = kAlias + offset + payload_size; + void* allocated = malloc(allocated_size); + if (allocated == nullptr) return nullptr; + // Always round up even if already aligned - we already asked for kAlias + // extra bytes and there's no way to give them back. + uintptr_t aligned = reinterpret_cast<uintptr_t>(allocated) + kAlias; + static_assert((kAlias & (kAlias - 1)) == 0, "kAlias must be a power of 2"); + static_assert(kAlias >= kAlignment, "Cannot align to more than kAlias"); + aligned &= ~(kAlias - 1); +#endif + +#if 0 + // No effect. + uintptr_t page_aligned = reinterpret_cast<uintptr_t>(allocated); + page_aligned &= ~(4096 - 1); + if (madvise(reinterpret_cast<void*>(page_aligned), allocated_size, + MADV_WILLNEED) != 0) { + JXL_NOTIFY_ERROR("madvise failed"); + } +#elif 0 + // INCREASES both first and subsequent decode times. + if (mlock(allocated, allocated_size) != 0) { + JXL_NOTIFY_ERROR("mlock failed"); + } +#endif + + // Update statistics (#allocations and max bytes in use) + num_allocations.fetch_add(1, std::memory_order_relaxed); + const uint64_t prev_bytes = + bytes_in_use.fetch_add(allocated_size, std::memory_order_acq_rel); + uint64_t expected_max = max_bytes_in_use.load(std::memory_order_acquire); + for (;;) { + const uint64_t desired = + std::max(expected_max, prev_bytes + allocated_size); + if (max_bytes_in_use.compare_exchange_strong(expected_max, desired, + std::memory_order_acq_rel)) { + break; + } + } + + const uintptr_t payload = aligned + offset; // still aligned + + // Stash `allocated` and payload_size inside header for use by Free(). + AllocationHeader* header = reinterpret_cast<AllocationHeader*>(payload) - 1; + header->allocated = allocated; + header->allocated_size = allocated_size; + + return JXL_ASSUME_ALIGNED(reinterpret_cast<void*>(payload), 64); +} + +void CacheAligned::Free(const void* aligned_pointer) { + if (aligned_pointer == nullptr) { + return; + } + const uintptr_t payload = reinterpret_cast<uintptr_t>(aligned_pointer); + JXL_ASSERT(payload % kAlignment == 0); + const AllocationHeader* header = + reinterpret_cast<const AllocationHeader*>(payload) - 1; + + // Subtract (2's complement negation). + bytes_in_use.fetch_add(~header->allocated_size + 1, + std::memory_order_acq_rel); + +#if JXL_USE_MMAP + munmap(header->allocated, header->allocated_size); +#else + free(header->allocated); +#endif +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/cache_aligned.h b/third_party/jpeg-xl/lib/jxl/cache_aligned.h new file mode 100644 index 0000000000..d79d7be461 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cache_aligned.h @@ -0,0 +1,67 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_CACHE_ALIGNED_H_ +#define LIB_JXL_BASE_CACHE_ALIGNED_H_ + +// Memory allocator with support for alignment + misalignment. + +#include <stddef.h> +#include <stdint.h> + +#include <memory> + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// Functions that depend on the cache line size. +class CacheAligned { + public: + static void PrintStats(); + + static constexpr size_t kPointerSize = sizeof(void*); + static constexpr size_t kCacheLineSize = 64; + // To avoid RFOs, match L2 fill size (pairs of lines). + static constexpr size_t kAlignment = 2 * kCacheLineSize; + // Minimum multiple for which cache set conflicts and/or loads blocked by + // preceding stores can occur. + static constexpr size_t kAlias = 2048; + + // Returns a 'random' (cyclical) offset suitable for Allocate. + static size_t NextOffset(); + + // Returns null or memory whose address is congruent to `offset` (mod kAlias). + // This reduces cache conflicts and load/store stalls, especially with large + // allocations that would otherwise have similar alignments. At least + // `payload_size` (which can be zero) bytes will be accessible. + static void* Allocate(size_t payload_size, size_t offset); + + static void* Allocate(const size_t payload_size) { + return Allocate(payload_size, NextOffset()); + } + + static void Free(const void* aligned_pointer); +}; + +// Avoids the need for a function pointer (deleter) in CacheAlignedUniquePtr. +struct CacheAlignedDeleter { + void operator()(uint8_t* aligned_pointer) const { + return CacheAligned::Free(aligned_pointer); + } +}; + +using CacheAlignedUniquePtr = std::unique_ptr<uint8_t[], CacheAlignedDeleter>; + +// Does not invoke constructors. +static inline CacheAlignedUniquePtr AllocateArray(const size_t bytes) { + return CacheAlignedUniquePtr( + static_cast<uint8_t*>(CacheAligned::Allocate(bytes)), + CacheAlignedDeleter()); +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_CACHE_ALIGNED_H_ diff --git a/third_party/jpeg-xl/lib/jxl/chroma_from_luma.cc b/third_party/jpeg-xl/lib/jxl/chroma_from_luma.cc new file mode 100644 index 0000000000..63d21cbb4b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/chroma_from_luma.cc @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/chroma_from_luma.h" + +namespace jxl { + +ColorCorrelationMap::ColorCorrelationMap(size_t xsize, size_t ysize, bool XYB) + : ytox_map(DivCeil(xsize, kColorTileDim), DivCeil(ysize, kColorTileDim)), + ytob_map(DivCeil(xsize, kColorTileDim), DivCeil(ysize, kColorTileDim)) { + ZeroFillImage(&ytox_map); + ZeroFillImage(&ytob_map); + if (!XYB) { + base_correlation_b_ = 0; + } + RecomputeDCFactors(); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/chroma_from_luma.h b/third_party/jpeg-xl/lib/jxl/chroma_from_luma.h new file mode 100644 index 0000000000..cb3b710762 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/chroma_from_luma.h @@ -0,0 +1,147 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CHROMA_FROM_LUMA_H_ +#define LIB_JXL_CHROMA_FROM_LUMA_H_ + +// Chroma-from-luma, computed using heuristics to determine the best linear +// model for the X and B channels from the Y channel. + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +// Tile is the rectangular grid of blocks that share color correlation +// parameters ("factor_x/b" such that residual_b = blue - Y * factor_b). +static constexpr size_t kColorTileDim = 64; + +static_assert(kColorTileDim % kBlockDim == 0, + "Color tile dim should be divisible by block dim"); +static constexpr size_t kColorTileDimInBlocks = kColorTileDim / kBlockDim; + +static_assert(kGroupDimInBlocks % kColorTileDimInBlocks == 0, + "Group dim should be divisible by color tile dim"); + +static constexpr uint8_t kDefaultColorFactor = 84; + +// JPEG DCT coefficients are at most 1024. CfL constants are at most 127, and +// the ratio of two entries in a JPEG quantization table is at most 255. Thus, +// since the CfL denominator is 84, this leaves 12 bits of mantissa to be used. +// For extra caution, we use 11. +static constexpr uint8_t kCFLFixedPointPrecision = 11; + +static constexpr U32Enc kColorFactorDist(Val(kDefaultColorFactor), Val(256), + BitsOffset(8, 2), BitsOffset(16, 258)); + +struct ColorCorrelationMap { + ColorCorrelationMap() = default; + // xsize/ysize are in pixels + // set XYB=false to do something close to no-op cmap (needed for now since + // cmap is mandatory) + ColorCorrelationMap(size_t xsize, size_t ysize, bool XYB = true); + + float YtoXRatio(int32_t x_factor) const { + return base_correlation_x_ + x_factor * color_scale_; + } + + float YtoBRatio(int32_t b_factor) const { + return base_correlation_b_ + b_factor * color_scale_; + } + + Status DecodeDC(BitReader* br) { + if (br->ReadFixedBits<1>() == 1) { + // All default. + return true; + } + SetColorFactor(U32Coder::Read(kColorFactorDist, br)); + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &base_correlation_x_)); + if (std::abs(base_correlation_x_) > 4.0f) { + return JXL_FAILURE("Base X correlation is out of range"); + } + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &base_correlation_b_)); + if (std::abs(base_correlation_b_) > 4.0f) { + return JXL_FAILURE("Base B correlation is out of range"); + } + ytox_dc_ = static_cast<int>(br->ReadFixedBits<kBitsPerByte>()) + + std::numeric_limits<int8_t>::min(); + ytob_dc_ = static_cast<int>(br->ReadFixedBits<kBitsPerByte>()) + + std::numeric_limits<int8_t>::min(); + RecomputeDCFactors(); + return true; + } + + // We consider a CfL map to be JPEG-reconstruction-compatible if base + // correlation is 0, no DC correlation is used, and we use the default color + // factor. + bool IsJPEGCompatible() const { + return base_correlation_x_ == 0 && base_correlation_b_ == 0 && + ytob_dc_ == 0 && ytox_dc_ == 0 && + color_factor_ == kDefaultColorFactor; + } + + int32_t RatioJPEG(int32_t factor) const { + return factor * (1 << kCFLFixedPointPrecision) / kDefaultColorFactor; + } + + void SetColorFactor(uint32_t factor) { + color_factor_ = factor; + color_scale_ = 1.0f / color_factor_; + RecomputeDCFactors(); + } + + void SetYToBDC(int32_t ytob_dc) { + ytob_dc_ = ytob_dc; + RecomputeDCFactors(); + } + void SetYToXDC(int32_t ytox_dc) { + ytox_dc_ = ytox_dc; + RecomputeDCFactors(); + } + + int32_t GetYToXDC() const { return ytox_dc_; } + int32_t GetYToBDC() const { return ytob_dc_; } + float GetColorFactor() const { return color_factor_; } + float GetBaseCorrelationX() const { return base_correlation_x_; } + float GetBaseCorrelationB() const { return base_correlation_b_; } + + const float* DCFactors() const { return dc_factors_; } + + void RecomputeDCFactors() { + dc_factors_[0] = YtoXRatio(ytox_dc_); + dc_factors_[2] = YtoBRatio(ytob_dc_); + } + + ImageSB ytox_map; + ImageSB ytob_map; + + private: + float dc_factors_[4] = {}; + // range of factor: -1.51 to +1.52 + uint32_t color_factor_ = kDefaultColorFactor; + float color_scale_ = 1.0f / color_factor_; + float base_correlation_x_ = 0.0f; + float base_correlation_b_ = jxl::cms::kYToBRatio; + int32_t ytox_dc_ = 0; + int32_t ytob_dc_ = 0; +}; + +} // namespace jxl + +#endif // LIB_JXL_CHROMA_FROM_LUMA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/cms/color_encoding_cms.h b/third_party/jpeg-xl/lib/jxl/cms/color_encoding_cms.h new file mode 100644 index 0000000000..db61f820ca --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/color_encoding_cms.h @@ -0,0 +1,623 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CMS_COLOR_ENCODING_CMS_H_ +#define LIB_JXL_CMS_COLOR_ENCODING_CMS_H_ + +#include <jxl/cms_interface.h> +#include <jxl/color_encoding.h> +#include <jxl/types.h> + +#include <cmath> +#include <cstdint> +#include <cstring> +#include <utility> +#include <vector> + +#include "lib/jxl/base/status.h" + +namespace jxl { +namespace cms { + +using IccBytes = std::vector<uint8_t>; + +// Returns whether the two inputs are approximately equal. +static inline bool ApproxEq(const double a, const double b, + double max_l1 = 1E-3) { + // Threshold should be sufficient for ICC's 15-bit fixed-point numbers. + // We have seen differences of 7.1E-5 with lcms2 and 1E-3 with skcms. + return std::abs(a - b) <= max_l1; +} + +// (All CIE units are for the standard 1931 2 degree observer) + +// Color space the color pixel data is encoded in. The color pixel data is +// 3-channel in all cases except in case of kGray, where it uses only 1 channel. +// This also determines the amount of channels used in modular encoding. +enum class ColorSpace : uint32_t { + // Trichromatic color data. This also includes CMYK if a kBlack + // ExtraChannelInfo is present. This implies, if there is an ICC profile, that + // the ICC profile uses a 3-channel color space if no kBlack extra channel is + // present, or uses color space 'CMYK' if a kBlack extra channel is present. + kRGB, + // Single-channel data. This implies, if there is an ICC profile, that the ICC + // profile also represents single-channel data and has the appropriate color + // space ('GRAY'). + kGray, + // Like kRGB, but implies fixed values for primaries etc. + kXYB, + // For non-RGB/gray data, e.g. from non-electro-optical sensors. Otherwise + // the same conditions as kRGB apply. + kUnknown + // NB: don't forget to update EnumBits! +}; + +// Values from CICP ColourPrimaries. +enum class WhitePoint : uint32_t { + kD65 = 1, // sRGB/BT.709/Display P3/BT.2020 + kCustom = 2, // Actual values encoded in separate fields + kE = 10, // XYZ + kDCI = 11, // DCI-P3 + // NB: don't forget to update EnumBits! +}; + +// Values from CICP ColourPrimaries +enum class Primaries : uint32_t { + kSRGB = 1, // Same as BT.709 + kCustom = 2, // Actual values encoded in separate fields + k2100 = 9, // Same as BT.2020 + kP3 = 11, + // NB: don't forget to update EnumBits! +}; + +// Values from CICP TransferCharacteristics +enum class TransferFunction : uint32_t { + k709 = 1, + kUnknown = 2, + kLinear = 8, + kSRGB = 13, + kPQ = 16, // from BT.2100 + kDCI = 17, // from SMPTE RP 431-2 reference projector + kHLG = 18, // from BT.2100 + // NB: don't forget to update EnumBits! +}; + +enum class RenderingIntent : uint32_t { + // Values match ICC sRGB encodings. + kPerceptual = 0, // good for photos, requires a profile with LUT. + kRelative, // good for logos. + kSaturation, // perhaps useful for CG with fully saturated colors. + kAbsolute, // leaves white point unchanged; good for proofing. + // NB: don't forget to update EnumBits! +}; + +// Chromaticity (Y is omitted because it is 1 for white points and implicit for +// primaries) +struct CIExy { + double x = 0.0; + double y = 0.0; +}; + +struct PrimariesCIExy { + CIExy r; + CIExy g; + CIExy b; +}; + +// Serializable form of CIExy. +struct Customxy { + static constexpr uint32_t kMul = 1000000; + static constexpr double kRoughLimit = 4.0; + static constexpr int32_t kMin = -0x200000; + static constexpr int32_t kMax = 0x1FFFFF; + + int32_t x = 0; + int32_t y = 0; + + CIExy GetValue() const { + CIExy xy; + xy.x = x * (1.0 / kMul); + xy.y = y * (1.0 / kMul); + return xy; + } + + Status SetValue(const CIExy& xy) { + bool ok = (std::abs(xy.x) < kRoughLimit) && (std::abs(xy.y) < kRoughLimit); + if (!ok) return JXL_FAILURE("X or Y is out of bounds"); + x = static_cast<int32_t>(roundf(xy.x * kMul)); + if (x < kMin || x > kMax) return JXL_FAILURE("X is out of bounds"); + y = static_cast<int32_t>(roundf(xy.y * kMul)); + if (y < kMin || y > kMax) return JXL_FAILURE("Y is out of bounds"); + return true; + } + + bool IsSame(const Customxy& other) const { + return (x == other.x) && (y == other.y); + } +}; + +static inline Status WhitePointFromExternal(const JxlWhitePoint external, + WhitePoint* out) { + switch (external) { + case JXL_WHITE_POINT_D65: + *out = WhitePoint::kD65; + return true; + case JXL_WHITE_POINT_CUSTOM: + *out = WhitePoint::kCustom; + return true; + case JXL_WHITE_POINT_E: + *out = WhitePoint::kE; + return true; + case JXL_WHITE_POINT_DCI: + *out = WhitePoint::kDCI; + return true; + } + return JXL_FAILURE("Invalid WhitePoint enum value %d", + static_cast<int>(external)); +} + +static inline Status PrimariesFromExternal(const JxlPrimaries external, + Primaries* out) { + switch (external) { + case JXL_PRIMARIES_SRGB: + *out = Primaries::kSRGB; + return true; + case JXL_PRIMARIES_CUSTOM: + *out = Primaries::kCustom; + return true; + case JXL_PRIMARIES_2100: + *out = Primaries::k2100; + return true; + case JXL_PRIMARIES_P3: + *out = Primaries::kP3; + return true; + } + return JXL_FAILURE("Invalid Primaries enum value"); +} + +static inline Status RenderingIntentFromExternal( + const JxlRenderingIntent external, RenderingIntent* out) { + switch (external) { + case JXL_RENDERING_INTENT_PERCEPTUAL: + *out = RenderingIntent::kPerceptual; + return true; + case JXL_RENDERING_INTENT_RELATIVE: + *out = RenderingIntent::kRelative; + return true; + case JXL_RENDERING_INTENT_SATURATION: + *out = RenderingIntent::kSaturation; + return true; + case JXL_RENDERING_INTENT_ABSOLUTE: + *out = RenderingIntent::kAbsolute; + return true; + } + return JXL_FAILURE("Invalid RenderingIntent enum value"); +} + +struct CustomTransferFunction { + // Highest reasonable value for the gamma of a transfer curve. + static constexpr uint32_t kMaxGamma = 8192; + static constexpr uint32_t kGammaMul = 10000000; + + bool have_gamma = false; + + // OETF exponent to go from linear to gamma-compressed. + uint32_t gamma = 0; // Only used if have_gamma_. + + // Can be kUnknown. + TransferFunction transfer_function = + TransferFunction::kSRGB; // Only used if !have_gamma_. + + TransferFunction GetTransferFunction() const { + JXL_ASSERT(!have_gamma); + return transfer_function; + } + void SetTransferFunction(const TransferFunction tf) { + have_gamma = false; + transfer_function = tf; + } + + bool IsUnknown() const { + return !have_gamma && (transfer_function == TransferFunction::kUnknown); + } + bool IsSRGB() const { + return !have_gamma && (transfer_function == TransferFunction::kSRGB); + } + bool IsLinear() const { + return !have_gamma && (transfer_function == TransferFunction::kLinear); + } + bool IsPQ() const { + return !have_gamma && (transfer_function == TransferFunction::kPQ); + } + bool IsHLG() const { + return !have_gamma && (transfer_function == TransferFunction::kHLG); + } + bool Is709() const { + return !have_gamma && (transfer_function == TransferFunction::k709); + } + bool IsDCI() const { + return !have_gamma && (transfer_function == TransferFunction::kDCI); + } + + double GetGamma() const { + JXL_ASSERT(have_gamma); + return gamma * (1.0 / kGammaMul); // (0, 1) + } + Status SetGamma(double new_gamma) { + if (new_gamma < (1.0 / kMaxGamma) || new_gamma > 1.0) { + return JXL_FAILURE("Invalid gamma %f", new_gamma); + } + + have_gamma = false; + if (ApproxEq(new_gamma, 1.0)) { + transfer_function = TransferFunction::kLinear; + return true; + } + if (ApproxEq(new_gamma, 1.0 / 2.6)) { + transfer_function = TransferFunction::kDCI; + return true; + } + // Don't translate 0.45.. to kSRGB nor k709 - that might change pixel + // values because those curves also have a linear part. + + have_gamma = true; + gamma = roundf(new_gamma * kGammaMul); + transfer_function = TransferFunction::kUnknown; + return true; + } + + bool IsSame(const CustomTransferFunction& other) const { + if (have_gamma != other.have_gamma) { + return false; + } + if (have_gamma) { + if (gamma != other.gamma) { + return false; + } + } else { + if (transfer_function != other.transfer_function) { + return false; + } + } + return true; + } +}; + +static inline Status ConvertExternalToInternalTransferFunction( + const JxlTransferFunction external, TransferFunction* internal) { + switch (external) { + case JXL_TRANSFER_FUNCTION_709: + *internal = TransferFunction::k709; + return true; + case JXL_TRANSFER_FUNCTION_UNKNOWN: + *internal = TransferFunction::kUnknown; + return true; + case JXL_TRANSFER_FUNCTION_LINEAR: + *internal = TransferFunction::kLinear; + return true; + case JXL_TRANSFER_FUNCTION_SRGB: + *internal = TransferFunction::kSRGB; + return true; + case JXL_TRANSFER_FUNCTION_PQ: + *internal = TransferFunction::kPQ; + return true; + case JXL_TRANSFER_FUNCTION_DCI: + *internal = TransferFunction::kDCI; + return true; + case JXL_TRANSFER_FUNCTION_HLG: + *internal = TransferFunction::kHLG; + return true; + case JXL_TRANSFER_FUNCTION_GAMMA: + return JXL_FAILURE("Gamma should be handled separately"); + } + return JXL_FAILURE("Invalid TransferFunction enum value"); +} + +// Compact encoding of data required to interpret and translate pixels to a +// known color space. Stored in Metadata. Thread-compatible. +struct ColorEncoding { + // Only valid if HaveFields() + WhitePoint white_point = WhitePoint::kD65; + Primaries primaries = Primaries::kSRGB; // Only valid if HasPrimaries() + RenderingIntent rendering_intent = RenderingIntent::kRelative; + + // When false, fields such as white_point and tf are invalid and must not be + // used. This occurs after setting a raw bytes-only ICC profile, only the + // ICC bytes may be used. The color_space_ field is still valid. + bool have_fields = true; + + IccBytes icc; // Valid ICC profile + + ColorSpace color_space = ColorSpace::kRGB; // Can be kUnknown + bool cmyk = false; + + // "late sync" fields + CustomTransferFunction tf; + Customxy white; // Only used if white_point == kCustom + Customxy red; // Only used if primaries == kCustom + Customxy green; // Only used if primaries == kCustom + Customxy blue; // Only used if primaries == kCustom + + // Returns false if the field is invalid and unusable. + bool HasPrimaries() const { + return (color_space != ColorSpace::kGray) && + (color_space != ColorSpace::kXYB); + } + + size_t Channels() const { return (color_space == ColorSpace::kGray) ? 1 : 3; } + + PrimariesCIExy GetPrimaries() const { + JXL_DASSERT(have_fields); + JXL_ASSERT(HasPrimaries()); + PrimariesCIExy xy; + switch (primaries) { + case Primaries::kCustom: + xy.r = red.GetValue(); + xy.g = green.GetValue(); + xy.b = blue.GetValue(); + return xy; + + case Primaries::kSRGB: + xy.r.x = 0.639998686; + xy.r.y = 0.330010138; + xy.g.x = 0.300003784; + xy.g.y = 0.600003357; + xy.b.x = 0.150002046; + xy.b.y = 0.059997204; + return xy; + + case Primaries::k2100: + xy.r.x = 0.708; + xy.r.y = 0.292; + xy.g.x = 0.170; + xy.g.y = 0.797; + xy.b.x = 0.131; + xy.b.y = 0.046; + return xy; + + case Primaries::kP3: + xy.r.x = 0.680; + xy.r.y = 0.320; + xy.g.x = 0.265; + xy.g.y = 0.690; + xy.b.x = 0.150; + xy.b.y = 0.060; + return xy; + } + JXL_UNREACHABLE("Invalid Primaries %u", static_cast<uint32_t>(primaries)); + } + + Status SetPrimaries(const PrimariesCIExy& xy) { + JXL_DASSERT(have_fields); + JXL_ASSERT(HasPrimaries()); + if (xy.r.x == 0.0 || xy.r.y == 0.0 || xy.g.x == 0.0 || xy.g.y == 0.0 || + xy.b.x == 0.0 || xy.b.y == 0.0) { + return JXL_FAILURE("Invalid primaries %f %f %f %f %f %f", xy.r.x, xy.r.y, + xy.g.x, xy.g.y, xy.b.x, xy.b.y); + } + + if (ApproxEq(xy.r.x, 0.64) && ApproxEq(xy.r.y, 0.33) && + ApproxEq(xy.g.x, 0.30) && ApproxEq(xy.g.y, 0.60) && + ApproxEq(xy.b.x, 0.15) && ApproxEq(xy.b.y, 0.06)) { + primaries = Primaries::kSRGB; + return true; + } + + if (ApproxEq(xy.r.x, 0.708) && ApproxEq(xy.r.y, 0.292) && + ApproxEq(xy.g.x, 0.170) && ApproxEq(xy.g.y, 0.797) && + ApproxEq(xy.b.x, 0.131) && ApproxEq(xy.b.y, 0.046)) { + primaries = Primaries::k2100; + return true; + } + if (ApproxEq(xy.r.x, 0.680) && ApproxEq(xy.r.y, 0.320) && + ApproxEq(xy.g.x, 0.265) && ApproxEq(xy.g.y, 0.690) && + ApproxEq(xy.b.x, 0.150) && ApproxEq(xy.b.y, 0.060)) { + primaries = Primaries::kP3; + return true; + } + + primaries = Primaries::kCustom; + JXL_RETURN_IF_ERROR(red.SetValue(xy.r)); + JXL_RETURN_IF_ERROR(green.SetValue(xy.g)); + JXL_RETURN_IF_ERROR(blue.SetValue(xy.b)); + return true; + } + + CIExy GetWhitePoint() const { + JXL_DASSERT(have_fields); + CIExy xy; + switch (white_point) { + case WhitePoint::kCustom: + return white.GetValue(); + + case WhitePoint::kD65: + xy.x = 0.3127; + xy.y = 0.3290; + return xy; + + case WhitePoint::kDCI: + // From https://ieeexplore.ieee.org/document/7290729 C.2 page 11 + xy.x = 0.314; + xy.y = 0.351; + return xy; + + case WhitePoint::kE: + xy.x = xy.y = 1.0 / 3; + return xy; + } + JXL_UNREACHABLE("Invalid WhitePoint %u", + static_cast<uint32_t>(white_point)); + } + + Status SetWhitePoint(const CIExy& xy) { + JXL_DASSERT(have_fields); + if (xy.x == 0.0 || xy.y == 0.0) { + return JXL_FAILURE("Invalid white point %f %f", xy.x, xy.y); + } + if (ApproxEq(xy.x, 0.3127) && ApproxEq(xy.y, 0.3290)) { + white_point = WhitePoint::kD65; + return true; + } + if (ApproxEq(xy.x, 1.0 / 3) && ApproxEq(xy.y, 1.0 / 3)) { + white_point = WhitePoint::kE; + return true; + } + if (ApproxEq(xy.x, 0.314) && ApproxEq(xy.y, 0.351)) { + white_point = WhitePoint::kDCI; + return true; + } + white_point = WhitePoint::kCustom; + return white.SetValue(xy); + } + + // Checks if the color spaces (including white point / primaries) are the + // same, but ignores the transfer function, rendering intent and ICC bytes. + bool SameColorSpace(const ColorEncoding& other) const { + if (color_space != other.color_space) return false; + + if (white_point != other.white_point) return false; + if (white_point == WhitePoint::kCustom) { + if (!white.IsSame(other.white)) { + return false; + } + } + + if (HasPrimaries() != other.HasPrimaries()) return false; + if (HasPrimaries()) { + if (primaries != other.primaries) return false; + if (primaries == Primaries::kCustom) { + if (!red.IsSame(other.red)) return false; + if (!green.IsSame(other.green)) return false; + if (!blue.IsSame(other.blue)) return false; + } + } + return true; + } + + // Checks if the color space and transfer function are the same, ignoring + // rendering intent and ICC bytes + bool SameColorEncoding(const ColorEncoding& other) const { + return SameColorSpace(other) && tf.IsSame(other.tf); + } + + // Returns true if all fields have been initialized (possibly to kUnknown). + // Returns false if the ICC profile is invalid or decoding it fails. + Status SetFieldsFromICC(IccBytes&& new_icc, const JxlCmsInterface& cms) { + // In case parsing fails, mark the ColorEncoding as invalid. + JXL_ASSERT(!new_icc.empty()); + color_space = ColorSpace::kUnknown; + tf.transfer_function = TransferFunction::kUnknown; + icc.clear(); + + JxlColorEncoding external; + JXL_BOOL new_cmyk; + JXL_RETURN_IF_ERROR(cms.set_fields_from_icc(cms.set_fields_data, + new_icc.data(), new_icc.size(), + &external, &new_cmyk)); + cmyk = new_cmyk; + JXL_RETURN_IF_ERROR(FromExternal(external)); + icc = std::move(new_icc); + return true; + } + + JxlColorEncoding ToExternal() const { + JxlColorEncoding external = {}; + if (!have_fields) { + external.color_space = JXL_COLOR_SPACE_UNKNOWN; + external.primaries = JXL_PRIMARIES_CUSTOM; + external.rendering_intent = JXL_RENDERING_INTENT_PERCEPTUAL; //? + external.transfer_function = JXL_TRANSFER_FUNCTION_UNKNOWN; + external.white_point = JXL_WHITE_POINT_CUSTOM; + return external; + } + external.color_space = static_cast<JxlColorSpace>(color_space); + + external.white_point = static_cast<JxlWhitePoint>(white_point); + + CIExy wp = GetWhitePoint(); + external.white_point_xy[0] = wp.x; + external.white_point_xy[1] = wp.y; + + if (external.color_space == JXL_COLOR_SPACE_RGB || + external.color_space == JXL_COLOR_SPACE_UNKNOWN) { + external.primaries = static_cast<JxlPrimaries>(primaries); + PrimariesCIExy p = GetPrimaries(); + external.primaries_red_xy[0] = p.r.x; + external.primaries_red_xy[1] = p.r.y; + external.primaries_green_xy[0] = p.g.x; + external.primaries_green_xy[1] = p.g.y; + external.primaries_blue_xy[0] = p.b.x; + external.primaries_blue_xy[1] = p.b.y; + } + + if (tf.have_gamma) { + external.transfer_function = JXL_TRANSFER_FUNCTION_GAMMA; + external.gamma = tf.GetGamma(); + } else { + external.transfer_function = + static_cast<JxlTransferFunction>(tf.GetTransferFunction()); + external.gamma = 0; + } + + external.rendering_intent = + static_cast<JxlRenderingIntent>(rendering_intent); + return external; + } + + // NB: does not create ICC. + Status FromExternal(const JxlColorEncoding& external) { + // TODO(eustas): update non-serializable on call-site + color_space = static_cast<ColorSpace>(external.color_space); + + JXL_RETURN_IF_ERROR( + WhitePointFromExternal(external.white_point, &white_point)); + if (external.white_point == JXL_WHITE_POINT_CUSTOM) { + CIExy wp; + wp.x = external.white_point_xy[0]; + wp.y = external.white_point_xy[1]; + JXL_RETURN_IF_ERROR(SetWhitePoint(wp)); + } + + if (external.color_space == JXL_COLOR_SPACE_RGB || + external.color_space == JXL_COLOR_SPACE_UNKNOWN) { + JXL_RETURN_IF_ERROR( + PrimariesFromExternal(external.primaries, &primaries)); + if (external.primaries == JXL_PRIMARIES_CUSTOM) { + PrimariesCIExy primaries; + primaries.r.x = external.primaries_red_xy[0]; + primaries.r.y = external.primaries_red_xy[1]; + primaries.g.x = external.primaries_green_xy[0]; + primaries.g.y = external.primaries_green_xy[1]; + primaries.b.x = external.primaries_blue_xy[0]; + primaries.b.y = external.primaries_blue_xy[1]; + JXL_RETURN_IF_ERROR(SetPrimaries(primaries)); + } + } + CustomTransferFunction tf; + if (external.transfer_function == JXL_TRANSFER_FUNCTION_GAMMA) { + JXL_RETURN_IF_ERROR(tf.SetGamma(external.gamma)); + } else { + TransferFunction tf_enum; + // JXL_TRANSFER_FUNCTION_GAMMA is not handled by this function since + // there's no internal enum value for it. + JXL_RETURN_IF_ERROR(ConvertExternalToInternalTransferFunction( + external.transfer_function, &tf_enum)); + tf.SetTransferFunction(tf_enum); + } + this->tf = tf; + + JXL_RETURN_IF_ERROR(RenderingIntentFromExternal(external.rendering_intent, + &rendering_intent)); + + icc.clear(); + + return true; + } +}; + +} // namespace cms +} // namespace jxl + +#endif // LIB_JXL_CMS_COLOR_ENCODING_CMS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/cms/jxl_cms.cc b/third_party/jpeg-xl/lib/jxl/cms/jxl_cms.cc new file mode 100644 index 0000000000..dd00b8b81f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/jxl_cms.cc @@ -0,0 +1,1343 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> + +#ifndef JPEGXL_ENABLE_SKCMS +#define JPEGXL_ENABLE_SKCMS 0 +#endif + +#include <jxl/cms_interface.h> + +#include <algorithm> +#include <array> +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <memory> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/cms/jxl_cms.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/matrix_ops.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/cms/jxl_cms_internal.h" +#include "lib/jxl/cms/transfer_functions-inl.h" +#include "lib/jxl/color_encoding_internal.h" +#if JPEGXL_ENABLE_SKCMS +#include "skcms.h" +#else // JPEGXL_ENABLE_SKCMS +#include "lcms2.h" +#include "lcms2_plugin.h" +#endif // JPEGXL_ENABLE_SKCMS + +#define JXL_CMS_VERBOSE 0 + +// Define these only once. We can't use HWY_ONCE here because it is defined as +// 1 only on the last pass. +#ifndef LIB_JXL_JXL_CMS_CC +#define LIB_JXL_JXL_CMS_CC + +namespace jxl { +namespace { + +using ::jxl::cms::ColorEncoding; + +struct JxlCms { +#if JPEGXL_ENABLE_SKCMS + IccBytes icc_src, icc_dst; + skcms_ICCProfile profile_src, profile_dst; +#else + void* lcms_transform; +#endif + + // These fields are used when the HLG OOTF or inverse OOTF must be applied. + bool apply_hlg_ootf; + size_t hlg_ootf_num_channels; + // Y component of the primaries. + std::array<float, 3> hlg_ootf_luminances; + + size_t channels_src; + size_t channels_dst; + + std::vector<float> src_storage; + std::vector<float*> buf_src; + std::vector<float> dst_storage; + std::vector<float*> buf_dst; + + float intensity_target; + bool skip_lcms = false; + ExtraTF preprocess = ExtraTF::kNone; + ExtraTF postprocess = ExtraTF::kNone; +}; + +Status ApplyHlgOotf(JxlCms* t, float* JXL_RESTRICT buf, size_t xsize, + bool forward); +} // namespace +} // namespace jxl + +#endif // LIB_JXL_JXL_CMS_CC + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +#if JXL_CMS_VERBOSE >= 2 +const size_t kX = 0; // pixel index, multiplied by 3 for RGB +#endif + +// xform_src = UndoGammaCompression(buf_src). +Status BeforeTransform(JxlCms* t, const float* buf_src, float* xform_src, + size_t buf_size) { + switch (t->preprocess) { + case ExtraTF::kNone: + JXL_DASSERT(false); // unreachable + break; + + case ExtraTF::kPQ: { + HWY_FULL(float) df; + TF_PQ tf_pq(t->intensity_target); + for (size_t i = 0; i < buf_size; i += Lanes(df)) { + const auto val = Load(df, buf_src + i); + const auto result = tf_pq.DisplayFromEncoded(df, val); + Store(result, df, xform_src + i); + } +#if JXL_CMS_VERBOSE >= 2 + printf("pre in %.4f %.4f %.4f undoPQ %.4f %.4f %.4f\n", buf_src[3 * kX], + buf_src[3 * kX + 1], buf_src[3 * kX + 2], xform_src[3 * kX], + xform_src[3 * kX + 1], xform_src[3 * kX + 2]); +#endif + break; + } + + case ExtraTF::kHLG: + for (size_t i = 0; i < buf_size; ++i) { + xform_src[i] = static_cast<float>( + TF_HLG_Base::DisplayFromEncoded(static_cast<double>(buf_src[i]))); + } + if (t->apply_hlg_ootf) { + JXL_RETURN_IF_ERROR( + ApplyHlgOotf(t, xform_src, buf_size, /*forward=*/true)); + } +#if JXL_CMS_VERBOSE >= 2 + printf("pre in %.4f %.4f %.4f undoHLG %.4f %.4f %.4f\n", buf_src[3 * kX], + buf_src[3 * kX + 1], buf_src[3 * kX + 2], xform_src[3 * kX], + xform_src[3 * kX + 1], xform_src[3 * kX + 2]); +#endif + break; + + case ExtraTF::kSRGB: + HWY_FULL(float) df; + for (size_t i = 0; i < buf_size; i += Lanes(df)) { + const auto val = Load(df, buf_src + i); + const auto result = TF_SRGB().DisplayFromEncoded(val); + Store(result, df, xform_src + i); + } +#if JXL_CMS_VERBOSE >= 2 + printf("pre in %.4f %.4f %.4f undoSRGB %.4f %.4f %.4f\n", buf_src[3 * kX], + buf_src[3 * kX + 1], buf_src[3 * kX + 2], xform_src[3 * kX], + xform_src[3 * kX + 1], xform_src[3 * kX + 2]); +#endif + break; + } + return true; +} + +// Applies gamma compression in-place. +Status AfterTransform(JxlCms* t, float* JXL_RESTRICT buf_dst, size_t buf_size) { + switch (t->postprocess) { + case ExtraTF::kNone: + JXL_DASSERT(false); // unreachable + break; + case ExtraTF::kPQ: { + HWY_FULL(float) df; + TF_PQ tf_pq(t->intensity_target); + for (size_t i = 0; i < buf_size; i += Lanes(df)) { + const auto val = Load(df, buf_dst + i); + const auto result = tf_pq.EncodedFromDisplay(df, val); + Store(result, df, buf_dst + i); + } +#if JXL_CMS_VERBOSE >= 2 + printf("after PQ enc %.4f %.4f %.4f\n", buf_dst[3 * kX], + buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]); +#endif + break; + } + case ExtraTF::kHLG: + if (t->apply_hlg_ootf) { + JXL_RETURN_IF_ERROR( + ApplyHlgOotf(t, buf_dst, buf_size, /*forward=*/false)); + } + for (size_t i = 0; i < buf_size; ++i) { + buf_dst[i] = static_cast<float>( + TF_HLG_Base::EncodedFromDisplay(static_cast<double>(buf_dst[i]))); + } +#if JXL_CMS_VERBOSE >= 2 + printf("after HLG enc %.4f %.4f %.4f\n", buf_dst[3 * kX], + buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]); +#endif + break; + case ExtraTF::kSRGB: + HWY_FULL(float) df; + for (size_t i = 0; i < buf_size; i += Lanes(df)) { + const auto val = Load(df, buf_dst + i); + const auto result = TF_SRGB().EncodedFromDisplay(df, val); + Store(result, df, buf_dst + i); + } +#if JXL_CMS_VERBOSE >= 2 + printf("after SRGB enc %.4f %.4f %.4f\n", buf_dst[3 * kX], + buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]); +#endif + break; + } + return true; +} + +Status DoColorSpaceTransform(void* cms_data, const size_t thread, + const float* buf_src, float* buf_dst, + size_t xsize) { + // No lock needed. + JxlCms* t = reinterpret_cast<JxlCms*>(cms_data); + + const float* xform_src = buf_src; // Read-only. + if (t->preprocess != ExtraTF::kNone) { + float* mutable_xform_src = t->buf_src[thread]; // Writable buffer. + JXL_RETURN_IF_ERROR(BeforeTransform(t, buf_src, mutable_xform_src, + xsize * t->channels_src)); + xform_src = mutable_xform_src; + } + +#if JPEGXL_ENABLE_SKCMS + if (t->channels_src == 1 && !t->skip_lcms) { + // Expand from 1 to 3 channels, starting from the end in case + // xform_src == t->buf_src[thread]. + float* mutable_xform_src = t->buf_src[thread]; + for (size_t i = 0; i < xsize; ++i) { + const size_t x = xsize - i - 1; + mutable_xform_src[x * 3] = mutable_xform_src[x * 3 + 1] = + mutable_xform_src[x * 3 + 2] = xform_src[x]; + } + xform_src = mutable_xform_src; + } +#else + if (t->channels_src == 4 && !t->skip_lcms) { + // LCMS does CMYK in a weird way: 0 = white, 100 = max ink + float* mutable_xform_src = t->buf_src[thread]; + for (size_t x = 0; x < xsize * 4; ++x) { + mutable_xform_src[x] = 100.f - 100.f * mutable_xform_src[x]; + } + xform_src = mutable_xform_src; + } +#endif + +#if JXL_CMS_VERBOSE >= 2 + // Save inputs for printing before in-place transforms overwrite them. + const float in0 = xform_src[3 * kX + 0]; + const float in1 = xform_src[3 * kX + 1]; + const float in2 = xform_src[3 * kX + 2]; +#endif + + if (t->skip_lcms) { + if (buf_dst != xform_src) { + memcpy(buf_dst, xform_src, xsize * t->channels_src * sizeof(*buf_dst)); + } // else: in-place, no need to copy + } else { +#if JPEGXL_ENABLE_SKCMS + JXL_CHECK( + skcms_Transform(xform_src, + (t->channels_src == 4 ? skcms_PixelFormat_RGBA_ffff + : skcms_PixelFormat_RGB_fff), + skcms_AlphaFormat_Opaque, &t->profile_src, buf_dst, + skcms_PixelFormat_RGB_fff, skcms_AlphaFormat_Opaque, + &t->profile_dst, xsize)); +#else // JPEGXL_ENABLE_SKCMS + cmsDoTransform(t->lcms_transform, xform_src, buf_dst, + static_cast<cmsUInt32Number>(xsize)); +#endif // JPEGXL_ENABLE_SKCMS + } +#if JXL_CMS_VERBOSE >= 2 + printf("xform skip%d: %.4f %.4f %.4f (%p) -> (%p) %.4f %.4f %.4f\n", + t->skip_lcms, in0, in1, in2, xform_src, buf_dst, buf_dst[3 * kX], + buf_dst[3 * kX + 1], buf_dst[3 * kX + 2]); +#endif + +#if JPEGXL_ENABLE_SKCMS + if (t->channels_dst == 1 && !t->skip_lcms) { + // Contract back from 3 to 1 channel, this time forward. + float* grayscale_buf_dst = t->buf_dst[thread]; + for (size_t x = 0; x < xsize; ++x) { + grayscale_buf_dst[x] = buf_dst[x * 3]; + } + buf_dst = grayscale_buf_dst; + } +#endif + + if (t->postprocess != ExtraTF::kNone) { + JXL_RETURN_IF_ERROR(AfterTransform(t, buf_dst, xsize * t->channels_dst)); + } + return true; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +namespace { + +HWY_EXPORT(DoColorSpaceTransform); +int DoColorSpaceTransform(void* t, size_t thread, const float* buf_src, + float* buf_dst, size_t xsize) { + return HWY_DYNAMIC_DISPATCH(DoColorSpaceTransform)(t, thread, buf_src, + buf_dst, xsize); +} + +// Define to 1 on OS X as a workaround for older LCMS lacking MD5. +#define JXL_CMS_OLD_VERSION 0 + +#if JPEGXL_ENABLE_SKCMS + +JXL_MUST_USE_RESULT CIExy CIExyFromXYZ(const float XYZ[3]) { + const float factor = 1.f / (XYZ[0] + XYZ[1] + XYZ[2]); + CIExy xy; + xy.x = XYZ[0] * factor; + xy.y = XYZ[1] * factor; + return xy; +} + +#else // JPEGXL_ENABLE_SKCMS +// (LCMS interface requires xyY but we omit the Y for white points/primaries.) + +JXL_MUST_USE_RESULT CIExy CIExyFromxyY(const cmsCIExyY& xyY) { + CIExy xy; + xy.x = xyY.x; + xy.y = xyY.y; + return xy; +} + +JXL_MUST_USE_RESULT CIExy CIExyFromXYZ(const cmsCIEXYZ& XYZ) { + cmsCIExyY xyY; + cmsXYZ2xyY(/*Dest=*/&xyY, /*Source=*/&XYZ); + return CIExyFromxyY(xyY); +} + +JXL_MUST_USE_RESULT cmsCIEXYZ D50_XYZ() { + // Quantized D50 as stored in ICC profiles. + return {0.96420288, 1.0, 0.82490540}; +} + +// RAII + +struct ProfileDeleter { + void operator()(void* p) { cmsCloseProfile(p); } +}; +using Profile = std::unique_ptr<void, ProfileDeleter>; + +struct TransformDeleter { + void operator()(void* p) { cmsDeleteTransform(p); } +}; +using Transform = std::unique_ptr<void, TransformDeleter>; + +struct CurveDeleter { + void operator()(cmsToneCurve* p) { cmsFreeToneCurve(p); } +}; +using Curve = std::unique_ptr<cmsToneCurve, CurveDeleter>; + +Status CreateProfileXYZ(const cmsContext context, + Profile* JXL_RESTRICT profile) { + profile->reset(cmsCreateXYZProfileTHR(context)); + if (profile->get() == nullptr) return JXL_FAILURE("Failed to create XYZ"); + return true; +} + +#endif // !JPEGXL_ENABLE_SKCMS + +#if JPEGXL_ENABLE_SKCMS +// IMPORTANT: icc must outlive profile. +Status DecodeProfile(const uint8_t* icc, size_t size, + skcms_ICCProfile* const profile) { + if (!skcms_Parse(icc, size, profile)) { + return JXL_FAILURE("Failed to parse ICC profile with %" PRIuS " bytes", + size); + } + return true; +} +#else // JPEGXL_ENABLE_SKCMS +Status DecodeProfile(const cmsContext context, Span<const uint8_t> icc, + Profile* profile) { + profile->reset(cmsOpenProfileFromMemTHR(context, icc.data(), icc.size())); + if (profile->get() == nullptr) { + return JXL_FAILURE("Failed to decode profile"); + } + + // WARNING: due to the LCMS MD5 issue mentioned above, many existing + // profiles have incorrect MD5, so do not even bother checking them nor + // generating warning clutter. + + return true; +} +#endif // JPEGXL_ENABLE_SKCMS + +#if JPEGXL_ENABLE_SKCMS + +ColorSpace ColorSpaceFromProfile(const skcms_ICCProfile& profile) { + switch (profile.data_color_space) { + case skcms_Signature_RGB: + case skcms_Signature_CMYK: + // spec says CMYK is encoded as RGB (the kBlack extra channel signals that + // it is actually CMYK) + return ColorSpace::kRGB; + case skcms_Signature_Gray: + return ColorSpace::kGray; + default: + return ColorSpace::kUnknown; + } +} + +// vector_out := matmul(matrix, vector_in) +void MatrixProduct(const skcms_Matrix3x3& matrix, const float vector_in[3], + float vector_out[3]) { + for (int i = 0; i < 3; ++i) { + vector_out[i] = 0; + for (int j = 0; j < 3; ++j) { + vector_out[i] += matrix.vals[i][j] * vector_in[j]; + } + } +} + +// Returns white point that was specified when creating the profile. +JXL_MUST_USE_RESULT Status UnadaptedWhitePoint(const skcms_ICCProfile& profile, + CIExy* out) { + float media_white_point_XYZ[3]; + if (!skcms_GetWTPT(&profile, media_white_point_XYZ)) { + return JXL_FAILURE("ICC profile does not contain WhitePoint tag"); + } + skcms_Matrix3x3 CHAD; + if (!skcms_GetCHAD(&profile, &CHAD)) { + // If there is no chromatic adaptation matrix, it means that the white point + // is already unadapted. + *out = CIExyFromXYZ(media_white_point_XYZ); + return true; + } + // Otherwise, it has been adapted to the PCS white point using said matrix, + // and the adaptation needs to be undone. + skcms_Matrix3x3 inverse_CHAD; + if (!skcms_Matrix3x3_invert(&CHAD, &inverse_CHAD)) { + return JXL_FAILURE("Non-invertible ChromaticAdaptation matrix"); + } + float unadapted_white_point_XYZ[3]; + MatrixProduct(inverse_CHAD, media_white_point_XYZ, unadapted_white_point_XYZ); + *out = CIExyFromXYZ(unadapted_white_point_XYZ); + return true; +} + +Status IdentifyPrimaries(const skcms_ICCProfile& profile, + const CIExy& wp_unadapted, ColorEncoding* c) { + if (!c->HasPrimaries()) return true; + + skcms_Matrix3x3 CHAD, inverse_CHAD; + if (skcms_GetCHAD(&profile, &CHAD)) { + JXL_RETURN_IF_ERROR(skcms_Matrix3x3_invert(&CHAD, &inverse_CHAD)); + } else { + static constexpr skcms_Matrix3x3 kLMSFromXYZ = { + {{0.8951, 0.2664, -0.1614}, + {-0.7502, 1.7135, 0.0367}, + {0.0389, -0.0685, 1.0296}}}; + static constexpr skcms_Matrix3x3 kXYZFromLMS = { + {{0.9869929, -0.1470543, 0.1599627}, + {0.4323053, 0.5183603, 0.0492912}, + {-0.0085287, 0.0400428, 0.9684867}}}; + static constexpr float kWpD50XYZ[3] = {0.96420288, 1.0, 0.82490540}; + float wp_unadapted_XYZ[3]; + JXL_RETURN_IF_ERROR( + CIEXYZFromWhiteCIExy(wp_unadapted.x, wp_unadapted.y, wp_unadapted_XYZ)); + float wp_D50_LMS[3], wp_unadapted_LMS[3]; + MatrixProduct(kLMSFromXYZ, kWpD50XYZ, wp_D50_LMS); + MatrixProduct(kLMSFromXYZ, wp_unadapted_XYZ, wp_unadapted_LMS); + inverse_CHAD = {{{wp_unadapted_LMS[0] / wp_D50_LMS[0], 0, 0}, + {0, wp_unadapted_LMS[1] / wp_D50_LMS[1], 0}, + {0, 0, wp_unadapted_LMS[2] / wp_D50_LMS[2]}}}; + inverse_CHAD = skcms_Matrix3x3_concat(&kXYZFromLMS, &inverse_CHAD); + inverse_CHAD = skcms_Matrix3x3_concat(&inverse_CHAD, &kLMSFromXYZ); + } + + float XYZ[3]; + PrimariesCIExy primaries; + CIExy* const chromaticities[] = {&primaries.r, &primaries.g, &primaries.b}; + for (int i = 0; i < 3; ++i) { + float RGB[3] = {}; + RGB[i] = 1; + skcms_Transform(RGB, skcms_PixelFormat_RGB_fff, skcms_AlphaFormat_Opaque, + &profile, XYZ, skcms_PixelFormat_RGB_fff, + skcms_AlphaFormat_Opaque, skcms_XYZD50_profile(), 1); + float unadapted_XYZ[3]; + MatrixProduct(inverse_CHAD, XYZ, unadapted_XYZ); + *chromaticities[i] = CIExyFromXYZ(unadapted_XYZ); + } + return c->SetPrimaries(primaries); +} + +bool IsApproximatelyEqual(const skcms_ICCProfile& profile, + const ColorEncoding& JXL_RESTRICT c) { + IccBytes bytes; + if (!MaybeCreateProfile(c.ToExternal(), &bytes)) { + return false; + } + + skcms_ICCProfile profile_test; + if (!DecodeProfile(bytes.data(), bytes.size(), &profile_test)) { + return false; + } + + if (!skcms_ApproximatelyEqualProfiles(&profile_test, &profile)) { + return false; + } + + return true; +} + +void DetectTransferFunction(const skcms_ICCProfile& profile, + ColorEncoding* JXL_RESTRICT c) { + JXL_CHECK(c->color_space != ColorSpace::kXYB); + + float gamma[3] = {}; + if (profile.has_trc) { + const auto IsGamma = [](const skcms_TransferFunction& tf) { + return tf.a == 1 && tf.b == 0 && + /* if b and d are zero, it is fine for c not to be */ tf.d == 0 && + tf.e == 0 && tf.f == 0; + }; + for (int i = 0; i < 3; ++i) { + if (profile.trc[i].table_entries == 0 && + IsGamma(profile.trc->parametric)) { + gamma[i] = 1.f / profile.trc->parametric.g; + } else { + skcms_TransferFunction approximate_tf; + float max_error; + if (skcms_ApproximateCurve(&profile.trc[i], &approximate_tf, + &max_error)) { + if (IsGamma(approximate_tf)) { + gamma[i] = 1.f / approximate_tf.g; + } + } + } + } + } + if (gamma[0] != 0 && std::abs(gamma[0] - gamma[1]) < 1e-4f && + std::abs(gamma[1] - gamma[2]) < 1e-4f) { + if (c->tf.SetGamma(gamma[0])) { + if (IsApproximatelyEqual(profile, *c)) return; + } + } + + for (TransferFunction tf : Values<TransferFunction>()) { + // Can only create profile from known transfer function. + if (tf == TransferFunction::kUnknown) continue; + c->tf.SetTransferFunction(tf); + if (IsApproximatelyEqual(profile, *c)) return; + } + + c->tf.SetTransferFunction(TransferFunction::kUnknown); +} + +#else // JPEGXL_ENABLE_SKCMS + +uint32_t Type32(const ColorEncoding& c, bool cmyk) { + if (cmyk) return TYPE_CMYK_FLT; + if (c.color_space == ColorSpace::kGray) return TYPE_GRAY_FLT; + return TYPE_RGB_FLT; +} + +uint32_t Type64(const ColorEncoding& c) { + if (c.color_space == ColorSpace::kGray) return TYPE_GRAY_DBL; + return TYPE_RGB_DBL; +} + +ColorSpace ColorSpaceFromProfile(const Profile& profile) { + switch (cmsGetColorSpace(profile.get())) { + case cmsSigRgbData: + case cmsSigCmykData: + return ColorSpace::kRGB; + case cmsSigGrayData: + return ColorSpace::kGray; + default: + return ColorSpace::kUnknown; + } +} + +// "profile1" is pre-decoded to save time in DetectTransferFunction. +Status ProfileEquivalentToICC(const cmsContext context, const Profile& profile1, + const IccBytes& icc, const ColorEncoding& c) { + const uint32_t type_src = Type64(c); + + Profile profile2; + JXL_RETURN_IF_ERROR(DecodeProfile(context, Bytes(icc), &profile2)); + + Profile profile_xyz; + JXL_RETURN_IF_ERROR(CreateProfileXYZ(context, &profile_xyz)); + + const uint32_t intent = INTENT_RELATIVE_COLORIMETRIC; + const uint32_t flags = cmsFLAGS_NOOPTIMIZE | cmsFLAGS_BLACKPOINTCOMPENSATION | + cmsFLAGS_HIGHRESPRECALC; + Transform xform1(cmsCreateTransformTHR(context, profile1.get(), type_src, + profile_xyz.get(), TYPE_XYZ_DBL, + intent, flags)); + Transform xform2(cmsCreateTransformTHR(context, profile2.get(), type_src, + profile_xyz.get(), TYPE_XYZ_DBL, + intent, flags)); + if (xform1 == nullptr || xform2 == nullptr) { + return JXL_FAILURE("Failed to create transform"); + } + + double in[3]; + double out1[3]; + double out2[3]; + + // Uniformly spaced samples from very dark to almost fully bright. + const double init = 1E-3; + const double step = 0.2; + + if (c.color_space == ColorSpace::kGray) { + // Finer sampling and replicate each component. + for (in[0] = init; in[0] < 1.0; in[0] += step / 8) { + cmsDoTransform(xform1.get(), in, out1, 1); + cmsDoTransform(xform2.get(), in, out2, 1); + if (!cms::ApproxEq(out1[0], out2[0], 2E-4)) { + return false; + } + } + } else { + for (in[0] = init; in[0] < 1.0; in[0] += step) { + for (in[1] = init; in[1] < 1.0; in[1] += step) { + for (in[2] = init; in[2] < 1.0; in[2] += step) { + cmsDoTransform(xform1.get(), in, out1, 1); + cmsDoTransform(xform2.get(), in, out2, 1); + for (size_t i = 0; i < 3; ++i) { + if (!cms::ApproxEq(out1[i], out2[i], 2E-4)) { + return false; + } + } + } + } + } + } + + return true; +} + +// Returns white point that was specified when creating the profile. +// NOTE: we can't just use cmsSigMediaWhitePointTag because its interpretation +// differs between ICC versions. +JXL_MUST_USE_RESULT cmsCIEXYZ UnadaptedWhitePoint(const cmsContext context, + const Profile& profile, + const ColorEncoding& c) { + const cmsCIEXYZ* white_point = static_cast<const cmsCIEXYZ*>( + cmsReadTag(profile.get(), cmsSigMediaWhitePointTag)); + if (white_point != nullptr && + cmsReadTag(profile.get(), cmsSigChromaticAdaptationTag) == nullptr) { + // No chromatic adaptation matrix: the white point is already unadapted. + return *white_point; + } + + cmsCIEXYZ XYZ = {1.0, 1.0, 1.0}; + Profile profile_xyz; + if (!CreateProfileXYZ(context, &profile_xyz)) return XYZ; + // Array arguments are one per profile. + cmsHPROFILE profiles[2] = {profile.get(), profile_xyz.get()}; + // Leave white point unchanged - that is what we're trying to extract. + cmsUInt32Number intents[2] = {INTENT_ABSOLUTE_COLORIMETRIC, + INTENT_ABSOLUTE_COLORIMETRIC}; + cmsBool black_compensation[2] = {0, 0}; + cmsFloat64Number adaption[2] = {0.0, 0.0}; + // Only transforming a single pixel, so skip expensive optimizations. + cmsUInt32Number flags = cmsFLAGS_NOOPTIMIZE | cmsFLAGS_HIGHRESPRECALC; + Transform xform(cmsCreateExtendedTransform( + context, 2, profiles, black_compensation, intents, adaption, nullptr, 0, + Type64(c), TYPE_XYZ_DBL, flags)); + if (!xform) return XYZ; // TODO(lode): return error + + // xy are relative, so magnitude does not matter if we ignore output Y. + const cmsFloat64Number in[3] = {1.0, 1.0, 1.0}; + cmsDoTransform(xform.get(), in, &XYZ.X, 1); + return XYZ; +} + +Status IdentifyPrimaries(const cmsContext context, const Profile& profile, + const cmsCIEXYZ& wp_unadapted, ColorEncoding* c) { + if (!c->HasPrimaries()) return true; + if (ColorSpaceFromProfile(profile) == ColorSpace::kUnknown) return true; + + // These were adapted to the profile illuminant before storing in the profile. + const cmsCIEXYZ* adapted_r = static_cast<const cmsCIEXYZ*>( + cmsReadTag(profile.get(), cmsSigRedColorantTag)); + const cmsCIEXYZ* adapted_g = static_cast<const cmsCIEXYZ*>( + cmsReadTag(profile.get(), cmsSigGreenColorantTag)); + const cmsCIEXYZ* adapted_b = static_cast<const cmsCIEXYZ*>( + cmsReadTag(profile.get(), cmsSigBlueColorantTag)); + + cmsCIEXYZ converted_rgb[3]; + if (adapted_r == nullptr || adapted_g == nullptr || adapted_b == nullptr) { + // No colorant tag, determine the XYZ coordinates of the primaries by + // converting from the colorspace. + Profile profile_xyz; + if (!CreateProfileXYZ(context, &profile_xyz)) { + return JXL_FAILURE("Failed to retrieve colorants"); + } + // Array arguments are one per profile. + cmsHPROFILE profiles[2] = {profile.get(), profile_xyz.get()}; + cmsUInt32Number intents[2] = {INTENT_RELATIVE_COLORIMETRIC, + INTENT_RELATIVE_COLORIMETRIC}; + cmsBool black_compensation[2] = {0, 0}; + cmsFloat64Number adaption[2] = {0.0, 0.0}; + // Only transforming three pixels, so skip expensive optimizations. + cmsUInt32Number flags = cmsFLAGS_NOOPTIMIZE | cmsFLAGS_HIGHRESPRECALC; + Transform xform(cmsCreateExtendedTransform( + context, 2, profiles, black_compensation, intents, adaption, nullptr, 0, + Type64(*c), TYPE_XYZ_DBL, flags)); + if (!xform) return JXL_FAILURE("Failed to retrieve colorants"); + + const cmsFloat64Number in[9] = {1.0, 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 1.0}; + cmsDoTransform(xform.get(), in, &converted_rgb->X, 3); + adapted_r = &converted_rgb[0]; + adapted_g = &converted_rgb[1]; + adapted_b = &converted_rgb[2]; + } + + // TODO(janwas): no longer assume Bradford and D50. + // Undo the chromatic adaptation. + const cmsCIEXYZ d50 = D50_XYZ(); + + cmsCIEXYZ r, g, b; + cmsAdaptToIlluminant(&r, &d50, &wp_unadapted, adapted_r); + cmsAdaptToIlluminant(&g, &d50, &wp_unadapted, adapted_g); + cmsAdaptToIlluminant(&b, &d50, &wp_unadapted, adapted_b); + + const PrimariesCIExy rgb = {CIExyFromXYZ(r), CIExyFromXYZ(g), + CIExyFromXYZ(b)}; + return c->SetPrimaries(rgb); +} + +void DetectTransferFunction(const cmsContext context, const Profile& profile, + ColorEncoding* JXL_RESTRICT c) { + JXL_CHECK(c->color_space != ColorSpace::kXYB); + + float gamma = 0; + if (const auto* gray_trc = reinterpret_cast<const cmsToneCurve*>( + cmsReadTag(profile.get(), cmsSigGrayTRCTag))) { + const double estimated_gamma = + cmsEstimateGamma(gray_trc, /*precision=*/1e-4); + if (estimated_gamma > 0) { + gamma = 1. / estimated_gamma; + } + } else { + float rgb_gamma[3] = {}; + int i = 0; + for (const auto tag : + {cmsSigRedTRCTag, cmsSigGreenTRCTag, cmsSigBlueTRCTag}) { + if (const auto* trc = reinterpret_cast<const cmsToneCurve*>( + cmsReadTag(profile.get(), tag))) { + const double estimated_gamma = + cmsEstimateGamma(trc, /*precision=*/1e-4); + if (estimated_gamma > 0) { + rgb_gamma[i] = 1. / estimated_gamma; + } + } + ++i; + } + if (rgb_gamma[0] != 0 && std::abs(rgb_gamma[0] - rgb_gamma[1]) < 1e-4f && + std::abs(rgb_gamma[1] - rgb_gamma[2]) < 1e-4f) { + gamma = rgb_gamma[0]; + } + } + + if (gamma != 0 && c->tf.SetGamma(gamma)) { + IccBytes icc_test; + if (MaybeCreateProfile(c->ToExternal(), &icc_test) && + ProfileEquivalentToICC(context, profile, icc_test, *c)) { + return; + } + } + + for (TransferFunction tf : Values<TransferFunction>()) { + // Can only create profile from known transfer function. + if (tf == TransferFunction::kUnknown) continue; + + c->tf.SetTransferFunction(tf); + + IccBytes icc_test; + if (MaybeCreateProfile(c->ToExternal(), &icc_test) && + ProfileEquivalentToICC(context, profile, icc_test, *c)) { + return; + } + } + + c->tf.SetTransferFunction(TransferFunction::kUnknown); +} + +void ErrorHandler(cmsContext context, cmsUInt32Number code, const char* text) { + JXL_WARNING("LCMS error %u: %s", code, text); +} + +// Returns a context for the current thread, creating it if necessary. +cmsContext GetContext() { + static thread_local void* context_; + if (context_ == nullptr) { + context_ = cmsCreateContext(nullptr, nullptr); + JXL_ASSERT(context_ != nullptr); + + cmsSetLogErrorHandlerTHR(static_cast<cmsContext>(context_), &ErrorHandler); + } + return static_cast<cmsContext>(context_); +} + +#endif // JPEGXL_ENABLE_SKCMS + +Status GetPrimariesLuminances(const ColorEncoding& encoding, + float luminances[3]) { + // Explanation: + // We know that the three primaries must sum to white: + // + // [Xr, Xg, Xb; [1; [Xw; + // Yr, Yg, Yb; × 1; = Yw; + // Zr, Zg, Zb] 1] Zw] + // + // By noting that X = x·(X+Y+Z), Y = y·(X+Y+Z) and Z = z·(X+Y+Z) (note the + // lower case indicating chromaticity), and factoring the totals (X+Y+Z) out + // of the left matrix and into the all-ones vector, we get: + // + // [xr, xg, xb; [Xr + Yr + Zr; [Xw; + // yr, yg, yb; × Xg + Yg + Zg; = Yw; + // zr, zg, zb] Xb + Yb + Zb] Zw] + // + // Which makes it apparent that we can compute those totals as: + // + // [Xr + Yr + Zr; inv([xr, xg, xb; [Xw; + // Xg + Yg + Zg; = yr, yg, yb; × Yw; + // Xb + Yb + Zb] zr, zg, zb]) Zw] + // + // From there, by multiplying each total by its corresponding y, we get Y for + // that primary. + + float white_XYZ[3]; + CIExy wp = encoding.GetWhitePoint(); + JXL_RETURN_IF_ERROR(CIEXYZFromWhiteCIExy(wp.x, wp.y, white_XYZ)); + + const PrimariesCIExy primaries = encoding.GetPrimaries(); + double chromaticities[3][3] = { + {primaries.r.x, primaries.g.x, primaries.b.x}, + {primaries.r.y, primaries.g.y, primaries.b.y}, + {1 - primaries.r.x - primaries.r.y, 1 - primaries.g.x - primaries.g.y, + 1 - primaries.b.x - primaries.b.y}}; + JXL_RETURN_IF_ERROR(Inv3x3Matrix(&chromaticities[0][0])); + const double ys[3] = {primaries.r.y, primaries.g.y, primaries.b.y}; + for (size_t i = 0; i < 3; ++i) { + luminances[i] = ys[i] * (chromaticities[i][0] * white_XYZ[0] + + chromaticities[i][1] * white_XYZ[1] + + chromaticities[i][2] * white_XYZ[2]); + } + return true; +} + +Status ApplyHlgOotf(JxlCms* t, float* JXL_RESTRICT buf, size_t xsize, + bool forward) { + if (295 <= t->intensity_target && t->intensity_target <= 305) { + // The gamma is approximately 1 so this can essentially be skipped. + return true; + } + float gamma = 1.2f * std::pow(1.111f, std::log2(t->intensity_target * 1e-3f)); + if (!forward) gamma = 1.f / gamma; + + switch (t->hlg_ootf_num_channels) { + case 1: + for (size_t x = 0; x < xsize; ++x) { + buf[x] = std::pow(buf[x], gamma); + } + break; + + case 3: + for (size_t x = 0; x < xsize; x += 3) { + const float luminance = buf[x] * t->hlg_ootf_luminances[0] + + buf[x + 1] * t->hlg_ootf_luminances[1] + + buf[x + 2] * t->hlg_ootf_luminances[2]; + const float ratio = std::pow(luminance, gamma - 1); + if (std::isfinite(ratio)) { + buf[x] *= ratio; + buf[x + 1] *= ratio; + buf[x + 2] *= ratio; + if (forward && gamma < 1) { + // If gamma < 1, the ratio above will be > 1 which can push bright + // saturated highlights out of gamut. There are several possible + // ways to bring them back in-gamut; this one preserves hue and + // saturation at the slight expense of luminance. If !forward, the + // previously-applied forward OOTF with gamma > 1 already pushed + // those highlights down and we are simply putting them back where + // they were so this is not necessary. + const float maximum = + std::max(buf[x], std::max(buf[x + 1], buf[x + 2])); + if (maximum > 1) { + const float normalizer = 1.f / maximum; + buf[x] *= normalizer; + buf[x + 1] *= normalizer; + buf[x + 2] *= normalizer; + } + } + } + } + break; + + default: + return JXL_FAILURE("HLG OOTF not implemented for %" PRIuS " channels", + t->hlg_ootf_num_channels); + } + return true; +} + +bool IsKnownTransferFunction(jxl::cms::TransferFunction tf) { + using TF = jxl::cms::TransferFunction; + // All but kUnknown + return tf == TF::k709 || tf == TF::kLinear || tf == TF::kSRGB || + tf == TF::kPQ || tf == TF::kDCI || tf == TF::kHLG; +} + +constexpr uint8_t kColorPrimariesP3_D65 = 12; + +bool IsKnownColorPrimaries(uint8_t color_primaries) { + using P = jxl::cms::Primaries; + // All but kCustom + if (color_primaries == kColorPrimariesP3_D65) return true; + const auto p = static_cast<Primaries>(color_primaries); + return p == P::kSRGB || p == P::k2100 || p == P::kP3; +} + +bool ApplyCICP(const uint8_t color_primaries, + const uint8_t transfer_characteristics, + const uint8_t matrix_coefficients, const uint8_t full_range, + ColorEncoding* JXL_RESTRICT c) { + if (matrix_coefficients != 0) return false; + if (full_range != 1) return false; + + const auto primaries = static_cast<Primaries>(color_primaries); + const auto tf = static_cast<TransferFunction>(transfer_characteristics); + if (!IsKnownTransferFunction(tf)) return false; + if (!IsKnownColorPrimaries(color_primaries)) return false; + c->color_space = ColorSpace::kRGB; + c->tf.SetTransferFunction(tf); + if (primaries == Primaries::kP3) { + c->white_point = WhitePoint::kDCI; + c->primaries = Primaries::kP3; + } else if (color_primaries == kColorPrimariesP3_D65) { + c->white_point = WhitePoint::kD65; + c->primaries = Primaries::kP3; + } else { + c->white_point = WhitePoint::kD65; + c->primaries = primaries; + } + return true; +} + +JXL_BOOL JxlCmsSetFieldsFromICC(void* user_data, const uint8_t* icc_data, + size_t icc_size, JxlColorEncoding* c, + JXL_BOOL* cmyk) { + if (c == nullptr) return JXL_FALSE; + if (cmyk == nullptr) return JXL_FALSE; + + *cmyk = JXL_FALSE; + + // In case parsing fails, mark the ColorEncoding as invalid. + c->color_space = JXL_COLOR_SPACE_UNKNOWN; + c->transfer_function = JXL_TRANSFER_FUNCTION_UNKNOWN; + + if (icc_size == 0) return JXL_FAILURE("Empty ICC profile"); + + ColorEncoding c_enc; + +#if JPEGXL_ENABLE_SKCMS + if (icc_size < 128) { + return JXL_FAILURE("ICC file too small"); + } + + skcms_ICCProfile profile; + JXL_RETURN_IF_ERROR(skcms_Parse(icc_data, icc_size, &profile)); + + // skcms does not return the rendering intent, so get it from the file. It + // is encoded as big-endian 32-bit integer in bytes 60..63. + uint32_t rendering_intent32 = icc_data[67]; + if (rendering_intent32 > 3 || icc_data[64] != 0 || icc_data[65] != 0 || + icc_data[66] != 0) { + return JXL_FAILURE("Invalid rendering intent %u\n", rendering_intent32); + } + // ICC and RenderingIntent have the same values (0..3). + c_enc.rendering_intent = static_cast<RenderingIntent>(rendering_intent32); + + if (profile.has_CICP && + ApplyCICP(profile.CICP.color_primaries, + profile.CICP.transfer_characteristics, + profile.CICP.matrix_coefficients, + profile.CICP.video_full_range_flag, &c_enc)) { + *c = c_enc.ToExternal(); + return true; + } + + c_enc.color_space = ColorSpaceFromProfile(profile); + *cmyk = (profile.data_color_space == skcms_Signature_CMYK); + + CIExy wp_unadapted; + JXL_RETURN_IF_ERROR(UnadaptedWhitePoint(profile, &wp_unadapted)); + JXL_RETURN_IF_ERROR(c_enc.SetWhitePoint(wp_unadapted)); + + // Relies on color_space. + JXL_RETURN_IF_ERROR(IdentifyPrimaries(profile, wp_unadapted, &c_enc)); + + // Relies on color_space/white point/primaries being set already. + DetectTransferFunction(profile, &c_enc); +#else // JPEGXL_ENABLE_SKCMS + + const cmsContext context = GetContext(); + + Profile profile; + JXL_RETURN_IF_ERROR( + DecodeProfile(context, Bytes(icc_data, icc_size), &profile)); + + const cmsUInt32Number rendering_intent32 = + cmsGetHeaderRenderingIntent(profile.get()); + if (rendering_intent32 > 3) { + return JXL_FAILURE("Invalid rendering intent %u\n", rendering_intent32); + } + // ICC and RenderingIntent have the same values (0..3). + c_enc.rendering_intent = static_cast<RenderingIntent>(rendering_intent32); + + static constexpr size_t kCICPSize = 12; + static constexpr auto kCICPSignature = + static_cast<cmsTagSignature>(0x63696370); + uint8_t cicp_buffer[kCICPSize]; + if (cmsReadRawTag(profile.get(), kCICPSignature, cicp_buffer, kCICPSize) == + kCICPSize && + ApplyCICP(cicp_buffer[8], cicp_buffer[9], cicp_buffer[10], + cicp_buffer[11], &c_enc)) { + *c = c_enc.ToExternal(); + return true; + } + + c_enc.color_space = ColorSpaceFromProfile(profile); + if (cmsGetColorSpace(profile.get()) == cmsSigCmykData) { + *cmyk = JXL_TRUE; + *c = c_enc.ToExternal(); + return true; + } + + const cmsCIEXYZ wp_unadapted = UnadaptedWhitePoint(context, profile, c_enc); + JXL_RETURN_IF_ERROR(c_enc.SetWhitePoint(CIExyFromXYZ(wp_unadapted))); + + // Relies on color_space. + JXL_RETURN_IF_ERROR( + IdentifyPrimaries(context, profile, wp_unadapted, &c_enc)); + + // Relies on color_space/white point/primaries being set already. + DetectTransferFunction(context, profile, &c_enc); + +#endif // JPEGXL_ENABLE_SKCMS + + *c = c_enc.ToExternal(); + return true; +} + +} // namespace + +namespace { + +void JxlCmsDestroy(void* cms_data) { + if (cms_data == nullptr) return; + JxlCms* t = reinterpret_cast<JxlCms*>(cms_data); +#if !JPEGXL_ENABLE_SKCMS + TransformDeleter()(t->lcms_transform); +#endif + delete t; +} + +void AllocateBuffer(size_t length, size_t num_threads, + std::vector<float>* storage, std::vector<float*>* view) { + constexpr size_t kAlign = 128 / sizeof(float); + size_t stride = RoundUpTo(length, kAlign); + storage->resize(stride * num_threads + kAlign); + intptr_t addr = reinterpret_cast<intptr_t>(storage->data()); + size_t offset = + (RoundUpTo(addr, kAlign * sizeof(float)) - addr) / sizeof(float); + view->clear(); + view->reserve(num_threads); + for (size_t i = 0; i < num_threads; ++i) { + view->emplace_back(storage->data() + offset + i * stride); + } +} + +void* JxlCmsInit(void* init_data, size_t num_threads, size_t xsize, + const JxlColorProfile* input, const JxlColorProfile* output, + float intensity_target) { + JXL_ASSERT(init_data != nullptr); + auto cms = static_cast<const JxlCmsInterface*>(init_data); + auto t = jxl::make_unique<JxlCms>(); + IccBytes icc_src, icc_dst; + if (input->icc.size == 0) { + JXL_NOTIFY_ERROR("JxlCmsInit: empty input ICC"); + return nullptr; + } + if (output->icc.size == 0) { + JXL_NOTIFY_ERROR("JxlCmsInit: empty OUTPUT ICC"); + return nullptr; + } + icc_src.assign(input->icc.data, input->icc.data + input->icc.size); + ColorEncoding c_src; + if (!c_src.SetFieldsFromICC(std::move(icc_src), *cms)) { + JXL_NOTIFY_ERROR("JxlCmsInit: failed to parse input ICC"); + return nullptr; + } + icc_dst.assign(output->icc.data, output->icc.data + output->icc.size); + ColorEncoding c_dst; + if (!c_dst.SetFieldsFromICC(std::move(icc_dst), *cms)) { + JXL_NOTIFY_ERROR("JxlCmsInit: failed to parse output ICC"); + return nullptr; + } +#if JXL_CMS_VERBOSE + printf("%s -> %s\n", Description(c_src).c_str(), Description(c_dst).c_str()); +#endif + +#if JPEGXL_ENABLE_SKCMS + if (!DecodeProfile(input->icc.data, input->icc.size, &t->profile_src)) { + JXL_NOTIFY_ERROR("JxlCmsInit: skcms failed to parse input ICC"); + return nullptr; + } + if (!DecodeProfile(output->icc.data, output->icc.size, &t->profile_dst)) { + JXL_NOTIFY_ERROR("JxlCmsInit: skcms failed to parse output ICC"); + return nullptr; + } +#else // JPEGXL_ENABLE_SKCMS + const cmsContext context = GetContext(); + Profile profile_src, profile_dst; + if (!DecodeProfile(context, Bytes(c_src.icc), &profile_src)) { + JXL_NOTIFY_ERROR("JxlCmsInit: lcms failed to parse input ICC"); + return nullptr; + } + if (!DecodeProfile(context, Bytes(c_dst.icc), &profile_dst)) { + JXL_NOTIFY_ERROR("JxlCmsInit: lcms failed to parse output ICC"); + return nullptr; + } +#endif // JPEGXL_ENABLE_SKCMS + + t->skip_lcms = false; + if (c_src.SameColorEncoding(c_dst)) { + t->skip_lcms = true; +#if JXL_CMS_VERBOSE + printf("Skip CMS\n"); +#endif + } + + t->apply_hlg_ootf = c_src.tf.IsHLG() != c_dst.tf.IsHLG(); + if (t->apply_hlg_ootf) { + const ColorEncoding* c_hlg = c_src.tf.IsHLG() ? &c_src : &c_dst; + t->hlg_ootf_num_channels = c_hlg->Channels(); + if (t->hlg_ootf_num_channels == 3 && + !GetPrimariesLuminances(*c_hlg, t->hlg_ootf_luminances.data())) { + JXL_NOTIFY_ERROR( + "JxlCmsInit: failed to compute the luminances of primaries"); + return nullptr; + } + } + + // Special-case SRGB <=> linear if the primaries / white point are the same, + // or any conversion where PQ or HLG is involved: + bool src_linear = c_src.tf.IsLinear(); + const bool dst_linear = c_dst.tf.IsLinear(); + + if (c_src.tf.IsPQ() || c_src.tf.IsHLG() || + (c_src.tf.IsSRGB() && dst_linear && c_src.SameColorSpace(c_dst))) { + // Construct new profile as if the data were already/still linear. + ColorEncoding c_linear_src = c_src; + c_linear_src.tf.SetTransferFunction(TransferFunction::kLinear); +#if JPEGXL_ENABLE_SKCMS + skcms_ICCProfile new_src; +#else // JPEGXL_ENABLE_SKCMS + Profile new_src; +#endif // JPEGXL_ENABLE_SKCMS + // Only enable ExtraTF if profile creation succeeded. + if (MaybeCreateProfile(c_linear_src.ToExternal(), &icc_src) && +#if JPEGXL_ENABLE_SKCMS + DecodeProfile(icc_src.data(), icc_src.size(), &new_src)) { +#else // JPEGXL_ENABLE_SKCMS + DecodeProfile(context, Bytes(icc_src), &new_src)) { +#endif // JPEGXL_ENABLE_SKCMS +#if JXL_CMS_VERBOSE + printf("Special HLG/PQ/sRGB -> linear\n"); +#endif +#if JPEGXL_ENABLE_SKCMS + t->icc_src = std::move(icc_src); + t->profile_src = new_src; +#else // JPEGXL_ENABLE_SKCMS + profile_src.swap(new_src); +#endif // JPEGXL_ENABLE_SKCMS + t->preprocess = c_src.tf.IsSRGB() + ? ExtraTF::kSRGB + : (c_src.tf.IsPQ() ? ExtraTF::kPQ : ExtraTF::kHLG); + c_src = c_linear_src; + src_linear = true; + } else { + if (t->apply_hlg_ootf) { + JXL_NOTIFY_ERROR( + "Failed to create extra linear source profile, and HLG OOTF " + "required"); + return nullptr; + } + JXL_WARNING("Failed to create extra linear destination profile"); + } + } + + if (c_dst.tf.IsPQ() || c_dst.tf.IsHLG() || + (c_dst.tf.IsSRGB() && src_linear && c_src.SameColorSpace(c_dst))) { + ColorEncoding c_linear_dst = c_dst; + c_linear_dst.tf.SetTransferFunction(TransferFunction::kLinear); +#if JPEGXL_ENABLE_SKCMS + skcms_ICCProfile new_dst; +#else // JPEGXL_ENABLE_SKCMS + Profile new_dst; +#endif // JPEGXL_ENABLE_SKCMS + // Only enable ExtraTF if profile creation succeeded. + if (MaybeCreateProfile(c_linear_dst.ToExternal(), &icc_dst) && +#if JPEGXL_ENABLE_SKCMS + DecodeProfile(icc_dst.data(), icc_dst.size(), &new_dst)) { +#else // JPEGXL_ENABLE_SKCMS + DecodeProfile(context, Bytes(icc_dst), &new_dst)) { +#endif // JPEGXL_ENABLE_SKCMS +#if JXL_CMS_VERBOSE + printf("Special linear -> HLG/PQ/sRGB\n"); +#endif +#if JPEGXL_ENABLE_SKCMS + t->icc_dst = std::move(icc_dst); + t->profile_dst = new_dst; +#else // JPEGXL_ENABLE_SKCMS + profile_dst.swap(new_dst); +#endif // JPEGXL_ENABLE_SKCMS + t->postprocess = c_dst.tf.IsSRGB() + ? ExtraTF::kSRGB + : (c_dst.tf.IsPQ() ? ExtraTF::kPQ : ExtraTF::kHLG); + c_dst = c_linear_dst; + } else { + if (t->apply_hlg_ootf) { + JXL_NOTIFY_ERROR( + "Failed to create extra linear destination profile, and inverse " + "HLG OOTF required"); + return nullptr; + } + JXL_WARNING("Failed to create extra linear destination profile"); + } + } + + if (c_src.SameColorEncoding(c_dst)) { +#if JXL_CMS_VERBOSE + printf("Same intermediary linear profiles, skipping CMS\n"); +#endif + t->skip_lcms = true; + } + +#if JPEGXL_ENABLE_SKCMS + if (!skcms_MakeUsableAsDestination(&t->profile_dst)) { + JXL_NOTIFY_ERROR( + "Failed to make %s usable as a color transform destination", + ColorEncodingDescription(c_dst.ToExternal()).c_str()); + return nullptr; + } +#endif // JPEGXL_ENABLE_SKCMS + + // Not including alpha channel (copied separately). + const size_t channels_src = (c_src.cmyk ? 4 : c_src.Channels()); + const size_t channels_dst = c_dst.Channels(); + JXL_CHECK(channels_src == channels_dst || + (channels_src == 4 && channels_dst == 3)); +#if JXL_CMS_VERBOSE + printf("Channels: %" PRIuS "; Threads: %" PRIuS "\n", channels_src, + num_threads); +#endif + +#if !JPEGXL_ENABLE_SKCMS + // Type includes color space (XYZ vs RGB), so can be different. + const uint32_t type_src = Type32(c_src, channels_src == 4); + const uint32_t type_dst = Type32(c_dst, false); + const uint32_t intent = static_cast<uint32_t>(c_dst.rendering_intent); + // Use cmsFLAGS_NOCACHE to disable the 1-pixel cache and make calling + // cmsDoTransform() thread-safe. + const uint32_t flags = cmsFLAGS_NOCACHE | cmsFLAGS_BLACKPOINTCOMPENSATION | + cmsFLAGS_HIGHRESPRECALC; + t->lcms_transform = + cmsCreateTransformTHR(context, profile_src.get(), type_src, + profile_dst.get(), type_dst, intent, flags); + if (t->lcms_transform == nullptr) { + JXL_NOTIFY_ERROR("Failed to create transform"); + return nullptr; + } +#endif // !JPEGXL_ENABLE_SKCMS + + // Ideally LCMS would convert directly from External to Image3. However, + // cmsDoTransformLineStride only accepts 32-bit BytesPerPlaneIn, whereas our + // planes can be more than 4 GiB apart. Hence, transform inputs/outputs must + // be interleaved. Calling cmsDoTransform for each pixel is expensive + // (indirect call). We therefore transform rows, which requires per-thread + // buffers. To avoid separate allocations, we use the rows of an image. + // Because LCMS apparently also cannot handle <= 16 bit inputs and 32-bit + // outputs (or vice versa), we use floating point input/output. + t->channels_src = channels_src; + t->channels_dst = channels_dst; + size_t actual_channels_src = channels_src; + size_t actual_channels_dst = channels_dst; +#if JPEGXL_ENABLE_SKCMS + // SkiaCMS doesn't support grayscale float buffers, so we create space for RGB + // float buffers anyway. + actual_channels_src = (channels_src == 4 ? 4 : 3); + actual_channels_dst = 3; +#endif + AllocateBuffer(xsize * actual_channels_src, num_threads, &t->src_storage, + &t->buf_src); + AllocateBuffer(xsize * actual_channels_dst, num_threads, &t->dst_storage, + &t->buf_dst); + t->intensity_target = intensity_target; + return t.release(); +} + +float* JxlCmsGetSrcBuf(void* cms_data, size_t thread) { + JxlCms* t = reinterpret_cast<JxlCms*>(cms_data); + return t->buf_src[thread]; +} + +float* JxlCmsGetDstBuf(void* cms_data, size_t thread) { + JxlCms* t = reinterpret_cast<JxlCms*>(cms_data); + return t->buf_dst[thread]; +} + +} // namespace + +extern "C" { + +JXL_CMS_EXPORT const JxlCmsInterface* JxlGetDefaultCms() { + static constexpr JxlCmsInterface kInterface = { + /*set_fields_data=*/nullptr, + /*set_fields_from_icc=*/&JxlCmsSetFieldsFromICC, + /*init_data=*/const_cast<void*>(static_cast<const void*>(&kInterface)), + /*init=*/&JxlCmsInit, + /*get_src_buf=*/&JxlCmsGetSrcBuf, + /*get_dst_buf=*/&JxlCmsGetDstBuf, + /*run=*/&DoColorSpaceTransform, + /*destroy=*/&JxlCmsDestroy}; + return &kInterface; +} + +} // extern "C" + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/cms/jxl_cms_internal.h b/third_party/jpeg-xl/lib/jxl/cms/jxl_cms_internal.h new file mode 100644 index 0000000000..c00fe82d8c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/jxl_cms_internal.h @@ -0,0 +1,1083 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CMS_JXL_CMS_INTERNAL_H_ +#define LIB_JXL_CMS_JXL_CMS_INTERNAL_H_ + +// ICC profiles and color space conversions. + +#include <jxl/color_encoding.h> + +#include <algorithm> +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <string> +#include <vector> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/matrix_ops.h" +#include "lib/jxl/base/span.h" // Bytes +#include "lib/jxl/base/status.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/cms/tone_mapping.h" +#include "lib/jxl/cms/transfer_functions.h" + +#ifndef JXL_ENABLE_3D_ICC_TONEMAPPING +#define JXL_ENABLE_3D_ICC_TONEMAPPING 1 +#endif + +namespace jxl { + +enum class ExtraTF { + kNone, + kPQ, + kHLG, + kSRGB, +}; + +static Status PrimariesToXYZ(float rx, float ry, float gx, float gy, float bx, + float by, float wx, float wy, float matrix[9]) { + bool ok = (wx >= 0) && (wx <= 1) && (wy > 0) && (wy <= 1); + if (!ok) { + return JXL_FAILURE("Invalid white point"); + } + // TODO(lode): also require rx, ry, gx, gy, bx, to be in range 0-1? ICC + // profiles in theory forbid negative XYZ values, but in practice the ACES P0 + // color space uses a negative y for the blue primary. + float primaries[9] = { + rx, gx, bx, ry, gy, by, 1.0f - rx - ry, 1.0f - gx - gy, 1.0f - bx - by}; + float primaries_inv[9]; + memcpy(primaries_inv, primaries, sizeof(float) * 9); + JXL_RETURN_IF_ERROR(Inv3x3Matrix(primaries_inv)); + + float w[3] = {wx / wy, 1.0f, (1.0f - wx - wy) / wy}; + // 1 / tiny float can still overflow + JXL_RETURN_IF_ERROR(std::isfinite(w[0]) && std::isfinite(w[2])); + float xyz[3]; + Mul3x3Vector(primaries_inv, w, xyz); + + float a[9] = { + xyz[0], 0, 0, 0, xyz[1], 0, 0, 0, xyz[2], + }; + + Mul3x3Matrix(primaries, a, matrix); + return true; +} + +/* Chromatic adaptation matrices*/ +constexpr float kBradford[9] = { + 0.8951f, 0.2664f, -0.1614f, -0.7502f, 1.7135f, + 0.0367f, 0.0389f, -0.0685f, 1.0296f, +}; +constexpr float kBradfordInv[9] = { + 0.9869929f, -0.1470543f, 0.1599627f, 0.4323053f, 0.5183603f, + 0.0492912f, -0.0085287f, 0.0400428f, 0.9684867f, +}; + +// Adapts whitepoint x, y to D50 +static Status AdaptToXYZD50(float wx, float wy, float matrix[9]) { + bool ok = (wx >= 0) && (wx <= 1) && (wy > 0) && (wy <= 1); + if (!ok) { + // Out of range values can cause division through zero + // further down with the bradford adaptation too. + return JXL_FAILURE("Invalid white point"); + } + float w[3] = {wx / wy, 1.0f, (1.0f - wx - wy) / wy}; + // 1 / tiny float can still overflow + JXL_RETURN_IF_ERROR(std::isfinite(w[0]) && std::isfinite(w[2])); + float w50[3] = {0.96422f, 1.0f, 0.82521f}; + + float lms[3]; + float lms50[3]; + + Mul3x3Vector(kBradford, w, lms); + Mul3x3Vector(kBradford, w50, lms50); + + if (lms[0] == 0 || lms[1] == 0 || lms[2] == 0) { + return JXL_FAILURE("Invalid white point"); + } + float a[9] = { + // /----> 0, 1, 2, 3, /----> 4, 5, 6, 7, /----> 8, + lms50[0] / lms[0], 0, 0, 0, lms50[1] / lms[1], 0, 0, 0, lms50[2] / lms[2], + }; + if (!std::isfinite(a[0]) || !std::isfinite(a[4]) || !std::isfinite(a[8])) { + return JXL_FAILURE("Invalid white point"); + } + + float b[9]; + Mul3x3Matrix(a, kBradford, b); + Mul3x3Matrix(kBradfordInv, b, matrix); + + return true; +} + +static Status PrimariesToXYZD50(float rx, float ry, float gx, float gy, + float bx, float by, float wx, float wy, + float matrix[9]) { + float toXYZ[9]; + JXL_RETURN_IF_ERROR(PrimariesToXYZ(rx, ry, gx, gy, bx, by, wx, wy, toXYZ)); + float d50[9]; + JXL_RETURN_IF_ERROR(AdaptToXYZD50(wx, wy, d50)); + + Mul3x3Matrix(d50, toXYZ, matrix); + return true; +} + +static Status ToneMapPixel(const JxlColorEncoding& c, const float in[3], + uint8_t pcslab_out[3]) { + float primaries_XYZ[9]; + JXL_RETURN_IF_ERROR(PrimariesToXYZ( + c.primaries_red_xy[0], c.primaries_red_xy[1], c.primaries_green_xy[0], + c.primaries_green_xy[1], c.primaries_blue_xy[0], c.primaries_blue_xy[1], + c.white_point_xy[0], c.white_point_xy[1], primaries_XYZ)); + const float luminances[3] = {primaries_XYZ[3], primaries_XYZ[4], + primaries_XYZ[5]}; + float linear[3]; + JxlTransferFunction tf = c.transfer_function; + if (tf == JXL_TRANSFER_FUNCTION_PQ) { + for (size_t i = 0; i < 3; ++i) { + linear[i] = TF_PQ_Base::DisplayFromEncoded( + /*display_intensity_target=*/10000.0, in[i]); + } + } else { + for (size_t i = 0; i < 3; ++i) { + linear[i] = TF_HLG_Base::DisplayFromEncoded(in[i]); + } + } + if (tf == JXL_TRANSFER_FUNCTION_PQ) { + Rec2408ToneMapperBase tone_mapper({0, 10000}, {0, 250}, luminances); + tone_mapper.ToneMap(&linear[0], &linear[1], &linear[2]); + } else { + HlgOOTF_Base ootf(/*source_luminance=*/300, /*target_luminance=*/80, + luminances); + ootf.Apply(&linear[0], &linear[1], &linear[2]); + } + GamutMapScalar(&linear[0], &linear[1], &linear[2], luminances, + /*preserve_saturation=*/0.3f); + + float chad[9]; + JXL_RETURN_IF_ERROR( + AdaptToXYZD50(c.white_point_xy[0], c.white_point_xy[1], chad)); + float to_xyzd50[9]; + Mul3x3Matrix(chad, primaries_XYZ, to_xyzd50); + + float xyz[3] = {0, 0, 0}; + for (size_t xyz_c = 0; xyz_c < 3; ++xyz_c) { + for (size_t rgb_c = 0; rgb_c < 3; ++rgb_c) { + xyz[xyz_c] += linear[rgb_c] * to_xyzd50[3 * xyz_c + rgb_c]; + } + } + + const auto lab_f = [](const float x) { + static constexpr float kDelta = 6. / 29; + return x <= kDelta * kDelta * kDelta + ? x * (1 / (3 * kDelta * kDelta)) + 4.f / 29 + : std::cbrt(x); + }; + static constexpr float kXn = 0.964212; + static constexpr float kYn = 1; + static constexpr float kZn = 0.825188; + + const float f_x = lab_f(xyz[0] / kXn); + const float f_y = lab_f(xyz[1] / kYn); + const float f_z = lab_f(xyz[2] / kZn); + + pcslab_out[0] = + static_cast<uint8_t>(.5f + 255.f * Clamp1(1.16f * f_y - .16f, 0.f, 1.f)); + pcslab_out[1] = static_cast<uint8_t>( + .5f + 128.f + Clamp1(500 * (f_x - f_y), -128.f, 127.f)); + pcslab_out[2] = static_cast<uint8_t>( + .5f + 128.f + Clamp1(200 * (f_y - f_z), -128.f, 127.f)); + + return true; +} + +static std::vector<uint16_t> CreateTableCurve(uint32_t N, const ExtraTF tf, + bool tone_map) { + // The generated PQ curve will make room for highlights up to this luminance. + // TODO(sboukortt): make this variable? + static constexpr float kPQIntensityTarget = 10000; + + JXL_ASSERT(N <= 4096); // ICC MFT2 only allows 4K entries + JXL_ASSERT(tf == ExtraTF::kPQ || tf == ExtraTF::kHLG); + + static constexpr float kLuminances[] = {1.f / 3, 1.f / 3, 1.f / 3}; + Rec2408ToneMapperBase tone_mapper({0, kPQIntensityTarget}, + {0, kDefaultIntensityTarget}, kLuminances); + // No point using float - LCMS converts to 16-bit for A2B/MFT. + std::vector<uint16_t> table(N); + for (uint32_t i = 0; i < N; ++i) { + const float x = static_cast<float>(i) / (N - 1); // 1.0 at index N - 1. + const double dx = static_cast<double>(x); + // LCMS requires EOTF (e.g. 2.4 exponent). + double y = (tf == ExtraTF::kHLG) + ? TF_HLG_Base::DisplayFromEncoded(dx) + : TF_PQ_Base::DisplayFromEncoded(kPQIntensityTarget, dx); + if (tone_map && tf == ExtraTF::kPQ && + kPQIntensityTarget > kDefaultIntensityTarget) { + float r = y * 10000 / kPQIntensityTarget, g = r, b = r; + tone_mapper.ToneMap(&r, &g, &b); + y = r; + } + JXL_ASSERT(y >= 0.0); + // Clamp to table range - necessary for HLG. + if (y > 1.0) y = 1.0; + // 1.0 corresponds to table value 0xFFFF. + table[i] = static_cast<uint16_t>(roundf(y * 65535.0)); + } + return table; +} + +static Status CIEXYZFromWhiteCIExy(double wx, double wy, float XYZ[3]) { + // Target Y = 1. + if (std::abs(wy) < 1e-12) return JXL_FAILURE("Y value is too small"); + const float factor = 1 / wy; + XYZ[0] = wx * factor; + XYZ[1] = 1; + XYZ[2] = (1 - wx - wy) * factor; + return true; +} + +namespace detail { + +constexpr bool kEnable3DToneMapping = JXL_ENABLE_3D_ICC_TONEMAPPING; + +static bool CanToneMap(const JxlColorEncoding& encoding) { + // If the color space cannot be represented by a CICP tag in the ICC profile + // then the rest of the profile must unambiguously identify it; we have less + // freedom to do use it for tone mapping. + JxlTransferFunction tf = encoding.transfer_function; + JxlPrimaries p = encoding.primaries; + JxlWhitePoint wp = encoding.white_point; + return encoding.color_space == JXL_COLOR_SPACE_RGB && + (tf == JXL_TRANSFER_FUNCTION_PQ || tf == JXL_TRANSFER_FUNCTION_HLG) && + ((p == JXL_PRIMARIES_P3 && + (wp == JXL_WHITE_POINT_D65 || wp == JXL_WHITE_POINT_DCI)) || + (p != JXL_PRIMARIES_CUSTOM && wp == JXL_WHITE_POINT_D65)); +} + +static void ICCComputeMD5(const std::vector<uint8_t>& data, uint8_t sum[16]) + JXL_NO_SANITIZE("unsigned-integer-overflow") { + std::vector<uint8_t> data64 = data; + data64.push_back(128); + // Add bytes such that ((size + 8) & 63) == 0. + size_t extra = ((64 - ((data64.size() + 8) & 63)) & 63); + data64.resize(data64.size() + extra, 0); + for (uint64_t i = 0; i < 64; i += 8) { + data64.push_back(static_cast<uint64_t>(data.size() << 3u) >> i); + } + + static const uint32_t sineparts[64] = { + 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, + 0xa8304613, 0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, + 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340, + 0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, + 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8, + 0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, + 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa, + 0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, + 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, + 0xffeff47d, 0x85845dd1, 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, + 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391, + }; + static const uint32_t shift[64] = { + 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, + 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, + 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, + 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, + }; + + uint32_t a0 = 0x67452301, b0 = 0xefcdab89, c0 = 0x98badcfe, d0 = 0x10325476; + + for (size_t i = 0; i < data64.size(); i += 64) { + uint32_t a = a0, b = b0, c = c0, d = d0, f, g; + for (size_t j = 0; j < 64; j++) { + if (j < 16) { + f = (b & c) | ((~b) & d); + g = j; + } else if (j < 32) { + f = (d & b) | ((~d) & c); + g = (5 * j + 1) & 0xf; + } else if (j < 48) { + f = b ^ c ^ d; + g = (3 * j + 5) & 0xf; + } else { + f = c ^ (b | (~d)); + g = (7 * j) & 0xf; + } + uint32_t dg0 = data64[i + g * 4 + 0], dg1 = data64[i + g * 4 + 1], + dg2 = data64[i + g * 4 + 2], dg3 = data64[i + g * 4 + 3]; + uint32_t u = dg0 | (dg1 << 8u) | (dg2 << 16u) | (dg3 << 24u); + f += a + sineparts[j] + u; + a = d; + d = c; + c = b; + b += (f << shift[j]) | (f >> (32u - shift[j])); + } + a0 += a; + b0 += b; + c0 += c; + d0 += d; + } + sum[0] = a0; + sum[1] = a0 >> 8u; + sum[2] = a0 >> 16u; + sum[3] = a0 >> 24u; + sum[4] = b0; + sum[5] = b0 >> 8u; + sum[6] = b0 >> 16u; + sum[7] = b0 >> 24u; + sum[8] = c0; + sum[9] = c0 >> 8u; + sum[10] = c0 >> 16u; + sum[11] = c0 >> 24u; + sum[12] = d0; + sum[13] = d0 >> 8u; + sum[14] = d0 >> 16u; + sum[15] = d0 >> 24u; +} + +static Status CreateICCChadMatrix(double wx, double wy, float result[9]) { + float m[9]; + if (wy == 0) { // WhitePoint can not be pitch-black. + return JXL_FAILURE("Invalid WhitePoint"); + } + JXL_RETURN_IF_ERROR(AdaptToXYZD50(wx, wy, m)); + memcpy(result, m, sizeof(float) * 9); + return true; +} + +// Creates RGB to XYZ matrix given RGB primaries and whitepoint in xy. +static Status CreateICCRGBMatrix(double rx, double ry, double gx, double gy, + double bx, double by, double wx, double wy, + float result[9]) { + float m[9]; + JXL_RETURN_IF_ERROR(PrimariesToXYZD50(rx, ry, gx, gy, bx, by, wx, wy, m)); + memcpy(result, m, sizeof(float) * 9); + return true; +} + +static void WriteICCUint32(uint32_t value, size_t pos, + std::vector<uint8_t>* icc) { + if (icc->size() < pos + 4) icc->resize(pos + 4); + (*icc)[pos + 0] = (value >> 24u) & 255; + (*icc)[pos + 1] = (value >> 16u) & 255; + (*icc)[pos + 2] = (value >> 8u) & 255; + (*icc)[pos + 3] = value & 255; +} + +static void WriteICCUint16(uint16_t value, size_t pos, + std::vector<uint8_t>* icc) { + if (icc->size() < pos + 2) icc->resize(pos + 2); + (*icc)[pos + 0] = (value >> 8u) & 255; + (*icc)[pos + 1] = value & 255; +} + +static void WriteICCUint8(uint8_t value, size_t pos, + std::vector<uint8_t>* icc) { + if (icc->size() < pos + 1) icc->resize(pos + 1); + (*icc)[pos] = value; +} + +// Writes a 4-character tag +static void WriteICCTag(const char* value, size_t pos, + std::vector<uint8_t>* icc) { + if (icc->size() < pos + 4) icc->resize(pos + 4); + memcpy(icc->data() + pos, value, 4); +} + +static Status WriteICCS15Fixed16(float value, size_t pos, + std::vector<uint8_t>* icc) { + // "nextafterf" for 32768.0f towards zero are: + // 32767.998046875, 32767.99609375, 32767.994140625 + // Even the first value works well,... + bool ok = (-32767.995f <= value) && (value <= 32767.995f); + if (!ok) return JXL_FAILURE("ICC value is out of range / NaN"); + int32_t i = value * 65536.0f + 0.5f; + // Use two's complement + uint32_t u = static_cast<uint32_t>(i); + WriteICCUint32(u, pos, icc); + return true; +} + +static Status CreateICCHeader(const JxlColorEncoding& c, + std::vector<uint8_t>* header) { + // TODO(lode): choose color management engine name, e.g. "skia" if + // integrated in skia. + static const char* kCmm = "jxl "; + + header->resize(128, 0); + + WriteICCUint32(0, 0, header); // size, correct value filled in at end + WriteICCTag(kCmm, 4, header); + WriteICCUint32(0x04400000u, 8, header); + const char* profile_type = + c.color_space == JXL_COLOR_SPACE_XYB ? "scnr" : "mntr"; + WriteICCTag(profile_type, 12, header); + WriteICCTag(c.color_space == JXL_COLOR_SPACE_GRAY ? "GRAY" : "RGB ", 16, + header); + if (kEnable3DToneMapping && CanToneMap(c)) { + // We are going to use a 3D LUT for tone mapping, which will be more compact + // with an 8-bit LUT to CIELAB than with a 16-bit LUT to XYZ. 8-bit XYZ + // would not be viable due to XYZ being linear, whereas it is fine with + // CIELAB's ~cube root. + WriteICCTag("Lab ", 20, header); + } else { + WriteICCTag("XYZ ", 20, header); + } + + // Three uint32_t's date/time encoding. + // TODO(lode): encode actual date and time, this is a placeholder + uint32_t year = 2019, month = 12, day = 1; + uint32_t hour = 0, minute = 0, second = 0; + WriteICCUint16(year, 24, header); + WriteICCUint16(month, 26, header); + WriteICCUint16(day, 28, header); + WriteICCUint16(hour, 30, header); + WriteICCUint16(minute, 32, header); + WriteICCUint16(second, 34, header); + + WriteICCTag("acsp", 36, header); + WriteICCTag("APPL", 40, header); + WriteICCUint32(0, 44, header); // flags + WriteICCUint32(0, 48, header); // device manufacturer + WriteICCUint32(0, 52, header); // device model + WriteICCUint32(0, 56, header); // device attributes + WriteICCUint32(0, 60, header); // device attributes + WriteICCUint32(static_cast<uint32_t>(c.rendering_intent), 64, header); + + // Mandatory D50 white point of profile connection space + WriteICCUint32(0x0000f6d6, 68, header); + WriteICCUint32(0x00010000, 72, header); + WriteICCUint32(0x0000d32d, 76, header); + + WriteICCTag(kCmm, 80, header); + + return true; +} + +static void AddToICCTagTable(const char* tag, size_t offset, size_t size, + std::vector<uint8_t>* tagtable, + std::vector<size_t>* offsets) { + WriteICCTag(tag, tagtable->size(), tagtable); + // writing true offset deferred to later + WriteICCUint32(0, tagtable->size(), tagtable); + offsets->push_back(offset); + WriteICCUint32(size, tagtable->size(), tagtable); +} + +static void FinalizeICCTag(std::vector<uint8_t>* tags, size_t* offset, + size_t* size) { + while ((tags->size() & 3) != 0) { + tags->push_back(0); + } + *offset += *size; + *size = tags->size() - *offset; +} + +// The input text must be ASCII, writing other characters to UTF-16 is not +// implemented. +static void CreateICCMlucTag(const std::string& text, + std::vector<uint8_t>* tags) { + WriteICCTag("mluc", tags->size(), tags); + WriteICCUint32(0, tags->size(), tags); + WriteICCUint32(1, tags->size(), tags); + WriteICCUint32(12, tags->size(), tags); + WriteICCTag("enUS", tags->size(), tags); + WriteICCUint32(text.size() * 2, tags->size(), tags); + WriteICCUint32(28, tags->size(), tags); + for (size_t i = 0; i < text.size(); i++) { + tags->push_back(0); // prepend 0 for UTF-16 + tags->push_back(text[i]); + } +} + +static Status CreateICCXYZTag(float xyz[3], std::vector<uint8_t>* tags) { + WriteICCTag("XYZ ", tags->size(), tags); + WriteICCUint32(0, tags->size(), tags); + for (size_t i = 0; i < 3; ++i) { + JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(xyz[i], tags->size(), tags)); + } + return true; +} + +static Status CreateICCChadTag(float chad[9], std::vector<uint8_t>* tags) { + WriteICCTag("sf32", tags->size(), tags); + WriteICCUint32(0, tags->size(), tags); + for (size_t i = 0; i < 9; i++) { + JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(chad[i], tags->size(), tags)); + } + return true; +} + +static void MaybeCreateICCCICPTag(const JxlColorEncoding& c, + std::vector<uint8_t>* tags, size_t* offset, + size_t* size, std::vector<uint8_t>* tagtable, + std::vector<size_t>* offsets) { + if (c.color_space != JXL_COLOR_SPACE_RGB) { + return; + } + uint8_t primaries = 0; + if (c.primaries == JXL_PRIMARIES_P3) { + if (c.white_point == JXL_WHITE_POINT_D65) { + primaries = 12; + } else if (c.white_point == JXL_WHITE_POINT_DCI) { + primaries = 11; + } else { + return; + } + } else if (c.primaries != JXL_PRIMARIES_CUSTOM && + c.white_point == JXL_WHITE_POINT_D65) { + primaries = static_cast<uint8_t>(c.primaries); + } else { + return; + } + JxlTransferFunction tf = c.transfer_function; + if (tf == JXL_TRANSFER_FUNCTION_UNKNOWN || + tf == JXL_TRANSFER_FUNCTION_GAMMA) { + return; + } + WriteICCTag("cicp", tags->size(), tags); + WriteICCUint32(0, tags->size(), tags); + WriteICCUint8(primaries, tags->size(), tags); + WriteICCUint8(static_cast<uint8_t>(tf), tags->size(), tags); + // Matrix + WriteICCUint8(0, tags->size(), tags); + // Full range + WriteICCUint8(1, tags->size(), tags); + FinalizeICCTag(tags, offset, size); + AddToICCTagTable("cicp", *offset, *size, tagtable, offsets); +} + +static void CreateICCCurvCurvTag(const std::vector<uint16_t>& curve, + std::vector<uint8_t>* tags) { + size_t pos = tags->size(); + tags->resize(tags->size() + 12 + curve.size() * 2, 0); + WriteICCTag("curv", pos, tags); + WriteICCUint32(0, pos + 4, tags); + WriteICCUint32(curve.size(), pos + 8, tags); + for (size_t i = 0; i < curve.size(); i++) { + WriteICCUint16(curve[i], pos + 12 + i * 2, tags); + } +} + +// Writes 12 + 4*params.size() bytes +static Status CreateICCCurvParaTag(std::vector<float> params, size_t curve_type, + std::vector<uint8_t>* tags) { + WriteICCTag("para", tags->size(), tags); + WriteICCUint32(0, tags->size(), tags); + WriteICCUint16(curve_type, tags->size(), tags); + WriteICCUint16(0, tags->size(), tags); + for (size_t i = 0; i < params.size(); i++) { + JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(params[i], tags->size(), tags)); + } + return true; +} + +static Status CreateICCLutAtoBTagForXYB(std::vector<uint8_t>* tags) { + WriteICCTag("mAB ", tags->size(), tags); + // 4 reserved bytes set to 0 + WriteICCUint32(0, tags->size(), tags); + // number of input channels + WriteICCUint8(3, tags->size(), tags); + // number of output channels + WriteICCUint8(3, tags->size(), tags); + // 2 reserved bytes for padding + WriteICCUint16(0, tags->size(), tags); + // offset to first B curve + WriteICCUint32(32, tags->size(), tags); + // offset to matrix + WriteICCUint32(244, tags->size(), tags); + // offset to first M curve + WriteICCUint32(148, tags->size(), tags); + // offset to CLUT + WriteICCUint32(80, tags->size(), tags); + // offset to first A curve + // (reuse linear B curves) + WriteICCUint32(32, tags->size(), tags); + + // offset = 32 + // no-op curves + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag({1.0f}, 0, tags)); + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag({1.0f}, 0, tags)); + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag({1.0f}, 0, tags)); + // offset = 80 + // number of grid points for each input channel + for (int i = 0; i < 16; ++i) { + WriteICCUint8(i < 3 ? 2 : 0, tags->size(), tags); + } + // precision = 2 + WriteICCUint8(2, tags->size(), tags); + // 3 bytes of padding + WriteICCUint8(0, tags->size(), tags); + WriteICCUint16(0, tags->size(), tags); + // 2*2*2*3 entries of 2 bytes each = 48 bytes + const jxl::cms::ColorCube3D& cube = jxl::cms::UnscaledA2BCube(); + for (size_t ix = 0; ix < 2; ++ix) { + for (size_t iy = 0; iy < 2; ++iy) { + for (size_t ib = 0; ib < 2; ++ib) { + const jxl::cms::ColorCube0D& out_f = cube[ix][iy][ib]; + for (int i = 0; i < 3; ++i) { + int32_t val = static_cast<int32_t>(0.5f + 65535 * out_f[i]); + JXL_DASSERT(val >= 0 && val <= 65535); + WriteICCUint16(val, tags->size(), tags); + } + } + } + } + // offset = 148 + // 3 curves with 5 parameters = 3 * (12 + 5 * 4) = 96 bytes + for (size_t i = 0; i < 3; ++i) { + const float b = -jxl::cms::kXYBOffset[i] - + std::cbrt(jxl::cms::kNegOpsinAbsorbanceBiasRGB[i]); + std::vector<float> params = { + 3, + 1.0f / jxl::cms::kXYBScale[i], + b, + 0, // unused + std::max(0.f, -b * jxl::cms::kXYBScale[i]), // make skcms happy + }; + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag(params, 3, tags)); + } + // offset = 244 + const double matrix[] = {1.5170095, -1.1065225, 0.071623, + -0.050022, 0.5683655, -0.018344, + -1.387676, 1.1145555, 0.6857255}; + // 12 * 4 = 48 bytes + for (size_t i = 0; i < 9; ++i) { + JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(matrix[i], tags->size(), tags)); + } + for (size_t i = 0; i < 3; ++i) { + float intercept = 0; + for (size_t j = 0; j < 3; ++j) { + intercept += matrix[i * 3 + j] * jxl::cms::kNegOpsinAbsorbanceBiasRGB[j]; + } + JXL_RETURN_IF_ERROR(WriteICCS15Fixed16(intercept, tags->size(), tags)); + } + return true; +} + +static Status CreateICCLutAtoBTagForHDR(JxlColorEncoding c, + std::vector<uint8_t>* tags) { + static constexpr size_t k3DLutDim = 9; + WriteICCTag("mft1", tags->size(), tags); + // 4 reserved bytes set to 0 + WriteICCUint32(0, tags->size(), tags); + // number of input channels + WriteICCUint8(3, tags->size(), tags); + // number of output channels + WriteICCUint8(3, tags->size(), tags); + // number of CLUT grid points + WriteICCUint8(k3DLutDim, tags->size(), tags); + // 1 reserved bytes for padding + WriteICCUint8(0, tags->size(), tags); + + // Matrix (per specification, must be identity if input is not XYZ) + for (size_t i = 0; i < 3; ++i) { + for (size_t j = 0; j < 3; ++j) { + JXL_RETURN_IF_ERROR( + WriteICCS15Fixed16(i == j ? 1.f : 0.f, tags->size(), tags)); + } + } + + // Input tables + for (size_t c = 0; c < 3; ++c) { + for (size_t i = 0; i < 256; ++i) { + WriteICCUint8(i, tags->size(), tags); + } + } + + for (size_t ix = 0; ix < k3DLutDim; ++ix) { + for (size_t iy = 0; iy < k3DLutDim; ++iy) { + for (size_t ib = 0; ib < k3DLutDim; ++ib) { + float f[3] = {ix * (1.0f / (k3DLutDim - 1)), + iy * (1.0f / (k3DLutDim - 1)), + ib * (1.0f / (k3DLutDim - 1))}; + uint8_t pcslab_out[3]; + JXL_RETURN_IF_ERROR(ToneMapPixel(c, f, pcslab_out)); + for (uint8_t val : pcslab_out) { + WriteICCUint8(val, tags->size(), tags); + } + } + } + } + + // Output tables + for (size_t c = 0; c < 3; ++c) { + for (size_t i = 0; i < 256; ++i) { + WriteICCUint8(i, tags->size(), tags); + } + } + + return true; +} + +// Some software (Apple Safari, Preview) requires this. +static Status CreateICCNoOpBToATag(std::vector<uint8_t>* tags) { + WriteICCTag("mBA ", tags->size(), tags); + // 4 reserved bytes set to 0 + WriteICCUint32(0, tags->size(), tags); + // number of input channels + WriteICCUint8(3, tags->size(), tags); + // number of output channels + WriteICCUint8(3, tags->size(), tags); + // 2 reserved bytes for padding + WriteICCUint16(0, tags->size(), tags); + // offset to first B curve + WriteICCUint32(32, tags->size(), tags); + // offset to matrix + WriteICCUint32(0, tags->size(), tags); + // offset to first M curve + WriteICCUint32(0, tags->size(), tags); + // offset to CLUT + WriteICCUint32(0, tags->size(), tags); + // offset to first A curve + WriteICCUint32(0, tags->size(), tags); + + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag({1.0f}, 0, tags)); + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag({1.0f}, 0, tags)); + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag({1.0f}, 0, tags)); + + return true; +} + +// These strings are baked into Description - do not change. + +static std::string ToString(JxlColorSpace color_space) { + switch (color_space) { + case JXL_COLOR_SPACE_RGB: + return "RGB"; + case JXL_COLOR_SPACE_GRAY: + return "Gra"; + case JXL_COLOR_SPACE_XYB: + return "XYB"; + case JXL_COLOR_SPACE_UNKNOWN: + return "CS?"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_UNREACHABLE("Invalid ColorSpace %u", static_cast<uint32_t>(color_space)); +} + +static std::string ToString(JxlWhitePoint white_point) { + switch (white_point) { + case JXL_WHITE_POINT_D65: + return "D65"; + case JXL_WHITE_POINT_CUSTOM: + return "Cst"; + case JXL_WHITE_POINT_E: + return "EER"; + case JXL_WHITE_POINT_DCI: + return "DCI"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_UNREACHABLE("Invalid WhitePoint %u", static_cast<uint32_t>(white_point)); +} + +static std::string ToString(JxlPrimaries primaries) { + switch (primaries) { + case JXL_PRIMARIES_SRGB: + return "SRG"; + case JXL_PRIMARIES_2100: + return "202"; + case JXL_PRIMARIES_P3: + return "DCI"; + case JXL_PRIMARIES_CUSTOM: + return "Cst"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_UNREACHABLE("Invalid Primaries %u", static_cast<uint32_t>(primaries)); +} + +static std::string ToString(JxlTransferFunction transfer_function) { + switch (transfer_function) { + case JXL_TRANSFER_FUNCTION_SRGB: + return "SRG"; + case JXL_TRANSFER_FUNCTION_LINEAR: + return "Lin"; + case JXL_TRANSFER_FUNCTION_709: + return "709"; + case JXL_TRANSFER_FUNCTION_PQ: + return "PeQ"; + case JXL_TRANSFER_FUNCTION_HLG: + return "HLG"; + case JXL_TRANSFER_FUNCTION_DCI: + return "DCI"; + case JXL_TRANSFER_FUNCTION_UNKNOWN: + return "TF?"; + case JXL_TRANSFER_FUNCTION_GAMMA: + JXL_UNREACHABLE("Invalid TransferFunction: gamma"); + } + // Should not happen - visitor fails if enum is invalid. + JXL_UNREACHABLE("Invalid TransferFunction %u", + static_cast<uint32_t>(transfer_function)); +} + +static std::string ToString(JxlRenderingIntent rendering_intent) { + switch (rendering_intent) { + case JXL_RENDERING_INTENT_PERCEPTUAL: + return "Per"; + case JXL_RENDERING_INTENT_RELATIVE: + return "Rel"; + case JXL_RENDERING_INTENT_SATURATION: + return "Sat"; + case JXL_RENDERING_INTENT_ABSOLUTE: + return "Abs"; + } + // Should not happen - visitor fails if enum is invalid. + JXL_UNREACHABLE("Invalid RenderingIntent %u", + static_cast<uint32_t>(rendering_intent)); +} + +static std::string ColorEncodingDescriptionImpl(const JxlColorEncoding& c) { + std::string d = ToString(c.color_space); + + bool explicit_wp_tf = (c.color_space != JXL_COLOR_SPACE_XYB); + if (explicit_wp_tf) { + d += '_'; + if (c.white_point == JXL_WHITE_POINT_CUSTOM) { + d += jxl::ToString(c.white_point_xy[0]) + ';'; + d += jxl::ToString(c.white_point_xy[1]); + } else { + d += ToString(c.white_point); + } + } + + if ((c.color_space != JXL_COLOR_SPACE_GRAY) && + (c.color_space != JXL_COLOR_SPACE_XYB)) { + d += '_'; + if (c.primaries == JXL_PRIMARIES_CUSTOM) { + d += jxl::ToString(c.primaries_red_xy[0]) + ';'; + d += jxl::ToString(c.primaries_red_xy[1]) + ';'; + d += jxl::ToString(c.primaries_green_xy[0]) + ';'; + d += jxl::ToString(c.primaries_green_xy[1]) + ';'; + d += jxl::ToString(c.primaries_blue_xy[0]) + ';'; + d += jxl::ToString(c.primaries_blue_xy[1]); + } else { + d += ToString(c.primaries); + } + } + + d += '_'; + d += ToString(c.rendering_intent); + + if (explicit_wp_tf) { + JxlTransferFunction tf = c.transfer_function; + d += '_'; + if (tf == JXL_TRANSFER_FUNCTION_GAMMA) { + d += 'g'; + d += jxl::ToString(c.gamma); + } else { + d += ToString(tf); + } + } + return d; +} + +static Status MaybeCreateProfileImpl(const JxlColorEncoding& c, + std::vector<uint8_t>* icc) { + std::vector<uint8_t> header, tagtable, tags; + JxlTransferFunction tf = c.transfer_function; + if (c.color_space == JXL_COLOR_SPACE_UNKNOWN || + tf == JXL_TRANSFER_FUNCTION_UNKNOWN) { + return false; // Not an error + } + + switch (c.color_space) { + case JXL_COLOR_SPACE_RGB: + case JXL_COLOR_SPACE_GRAY: + case JXL_COLOR_SPACE_XYB: + break; // OK + default: + return JXL_FAILURE("Invalid CS %u", + static_cast<unsigned int>(c.color_space)); + } + + if (c.color_space == JXL_COLOR_SPACE_XYB && + c.rendering_intent != JXL_RENDERING_INTENT_PERCEPTUAL) { + return JXL_FAILURE( + "Only perceptual rendering intent implemented for XYB " + "ICC profile."); + } + + JXL_RETURN_IF_ERROR(CreateICCHeader(c, &header)); + + std::vector<size_t> offsets; + // tag count, deferred to later + WriteICCUint32(0, tagtable.size(), &tagtable); + + size_t tag_offset = 0, tag_size = 0; + + CreateICCMlucTag(ColorEncodingDescriptionImpl(c), &tags); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("desc", tag_offset, tag_size, &tagtable, &offsets); + + const std::string copyright = "CC0"; + CreateICCMlucTag(copyright, &tags); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("cprt", tag_offset, tag_size, &tagtable, &offsets); + + // TODO(eustas): isn't it the other way round: gray image has d50 WhitePoint? + if (c.color_space == JXL_COLOR_SPACE_GRAY) { + float wtpt[3]; + JXL_RETURN_IF_ERROR( + CIEXYZFromWhiteCIExy(c.white_point_xy[0], c.white_point_xy[1], wtpt)); + JXL_RETURN_IF_ERROR(CreateICCXYZTag(wtpt, &tags)); + } else { + float d50[3] = {0.964203, 1.0, 0.824905}; + JXL_RETURN_IF_ERROR(CreateICCXYZTag(d50, &tags)); + } + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("wtpt", tag_offset, tag_size, &tagtable, &offsets); + + if (c.color_space != JXL_COLOR_SPACE_GRAY) { + // Chromatic adaptation matrix + float chad[9]; + JXL_RETURN_IF_ERROR( + CreateICCChadMatrix(c.white_point_xy[0], c.white_point_xy[1], chad)); + + JXL_RETURN_IF_ERROR(CreateICCChadTag(chad, &tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("chad", tag_offset, tag_size, &tagtable, &offsets); + } + + if (c.color_space == JXL_COLOR_SPACE_RGB) { + MaybeCreateICCCICPTag(c, &tags, &tag_offset, &tag_size, &tagtable, + &offsets); + + float m[9]; + JXL_RETURN_IF_ERROR(CreateICCRGBMatrix( + c.primaries_red_xy[0], c.primaries_red_xy[1], c.primaries_green_xy[0], + c.primaries_green_xy[1], c.primaries_blue_xy[0], c.primaries_blue_xy[1], + c.white_point_xy[0], c.white_point_xy[1], m)); + float r[3] = {m[0], m[3], m[6]}; + float g[3] = {m[1], m[4], m[7]}; + float b[3] = {m[2], m[5], m[8]}; + + JXL_RETURN_IF_ERROR(CreateICCXYZTag(r, &tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("rXYZ", tag_offset, tag_size, &tagtable, &offsets); + + JXL_RETURN_IF_ERROR(CreateICCXYZTag(g, &tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("gXYZ", tag_offset, tag_size, &tagtable, &offsets); + + JXL_RETURN_IF_ERROR(CreateICCXYZTag(b, &tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("bXYZ", tag_offset, tag_size, &tagtable, &offsets); + } + + if (c.color_space == JXL_COLOR_SPACE_XYB) { + JXL_RETURN_IF_ERROR(CreateICCLutAtoBTagForXYB(&tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("A2B0", tag_offset, tag_size, &tagtable, &offsets); + JXL_RETURN_IF_ERROR(CreateICCNoOpBToATag(&tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("B2A0", tag_offset, tag_size, &tagtable, &offsets); + } else if (kEnable3DToneMapping && CanToneMap(c)) { + JXL_RETURN_IF_ERROR(CreateICCLutAtoBTagForHDR(c, &tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("A2B0", tag_offset, tag_size, &tagtable, &offsets); + JXL_RETURN_IF_ERROR(CreateICCNoOpBToATag(&tags)); + FinalizeICCTag(&tags, &tag_offset, &tag_size); + AddToICCTagTable("B2A0", tag_offset, tag_size, &tagtable, &offsets); + } else { + if (tf == JXL_TRANSFER_FUNCTION_GAMMA) { + float gamma = 1.0 / c.gamma; + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag({gamma}, 0, &tags)); + } else if (c.color_space != JXL_COLOR_SPACE_XYB) { + switch (tf) { + case JXL_TRANSFER_FUNCTION_HLG: + CreateICCCurvCurvTag( + CreateTableCurve(64, ExtraTF::kHLG, CanToneMap(c)), &tags); + break; + case JXL_TRANSFER_FUNCTION_PQ: + CreateICCCurvCurvTag( + CreateTableCurve(64, ExtraTF::kPQ, CanToneMap(c)), &tags); + break; + case JXL_TRANSFER_FUNCTION_SRGB: + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag( + {2.4, 1.0 / 1.055, 0.055 / 1.055, 1.0 / 12.92, 0.04045}, 3, + &tags)); + break; + case JXL_TRANSFER_FUNCTION_709: + JXL_RETURN_IF_ERROR(CreateICCCurvParaTag( + {1.0 / 0.45, 1.0 / 1.099, 0.099 / 1.099, 1.0 / 4.5, 0.081}, 3, + &tags)); + break; + case JXL_TRANSFER_FUNCTION_LINEAR: + JXL_RETURN_IF_ERROR( + CreateICCCurvParaTag({1.0, 1.0, 0.0, 1.0, 0.0}, 3, &tags)); + break; + case JXL_TRANSFER_FUNCTION_DCI: + JXL_RETURN_IF_ERROR( + CreateICCCurvParaTag({2.6, 1.0, 0.0, 1.0, 0.0}, 3, &tags)); + break; + default: + JXL_UNREACHABLE("Unknown TF %u", static_cast<unsigned int>(tf)); + } + } + FinalizeICCTag(&tags, &tag_offset, &tag_size); + if (c.color_space == JXL_COLOR_SPACE_GRAY) { + AddToICCTagTable("kTRC", tag_offset, tag_size, &tagtable, &offsets); + } else { + AddToICCTagTable("rTRC", tag_offset, tag_size, &tagtable, &offsets); + AddToICCTagTable("gTRC", tag_offset, tag_size, &tagtable, &offsets); + AddToICCTagTable("bTRC", tag_offset, tag_size, &tagtable, &offsets); + } + } + + // Tag count + WriteICCUint32(offsets.size(), 0, &tagtable); + for (size_t i = 0; i < offsets.size(); i++) { + WriteICCUint32(offsets[i] + header.size() + tagtable.size(), 4 + 12 * i + 4, + &tagtable); + } + + // ICC profile size + WriteICCUint32(header.size() + tagtable.size() + tags.size(), 0, &header); + + *icc = header; + Bytes(tagtable).AppendTo(icc); + Bytes(tags).AppendTo(icc); + + // The MD5 checksum must be computed on the profile with profile flags, + // rendering intent, and region of the checksum itself, set to 0. + // TODO(lode): manually verify with a reliable tool that this creates correct + // signature (profile id) for ICC profiles. + std::vector<uint8_t> icc_sum = *icc; + if (icc_sum.size() >= 64 + 4) { + memset(icc_sum.data() + 44, 0, 4); + memset(icc_sum.data() + 64, 0, 4); + } + uint8_t checksum[16]; + detail::ICCComputeMD5(icc_sum, checksum); + + memcpy(icc->data() + 84, checksum, sizeof(checksum)); + + return true; +} + +} // namespace detail + +// Returns a representation of the ColorEncoding fields (not icc). +// Example description: "RGB_D65_SRG_Rel_Lin" +static JXL_MAYBE_UNUSED std::string ColorEncodingDescription( + const JxlColorEncoding& c) { + return detail::ColorEncodingDescriptionImpl(c); +} + +// NOTE: for XYB colorspace, the created profile can be used to transform a +// *scaled* XYB image (created by ScaleXYB()) to another colorspace. +static JXL_MAYBE_UNUSED Status MaybeCreateProfile(const JxlColorEncoding& c, + std::vector<uint8_t>* icc) { + return detail::MaybeCreateProfileImpl(c, icc); +} + +} // namespace jxl + +#endif // LIB_JXL_CMS_JXL_CMS_INTERNAL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/cms/opsin_params.h b/third_party/jpeg-xl/lib/jxl/cms/opsin_params.h new file mode 100644 index 0000000000..48e8e254f7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/opsin_params.h @@ -0,0 +1,160 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CMS_OPSIN_PARAMS_H_ +#define LIB_JXL_CMS_OPSIN_PARAMS_H_ + +#include <array> + +// Constants that define the XYB color space. + +namespace jxl { +namespace cms { + +// Parameters for opsin absorbance. +constexpr float kM02 = 0.078f; +constexpr float kM00 = 0.30f; +constexpr float kM01 = 1.0f - kM02 - kM00; + +constexpr float kM12 = 0.078f; +constexpr float kM10 = 0.23f; +constexpr float kM11 = 1.0f - kM12 - kM10; + +constexpr float kM20 = 0.24342268924547819f; +constexpr float kM21 = 0.20476744424496821f; +constexpr float kM22 = 1.0f - kM20 - kM21; + +constexpr float kBScale = 1.0f; +constexpr float kYToBRatio = 1.0f; // works better with 0.50017729543783418 +constexpr float kBToYRatio = 1.0f / kYToBRatio; + +constexpr float kOpsinAbsorbanceBias0 = 0.0037930732552754493f; +constexpr float kOpsinAbsorbanceBias1 = kOpsinAbsorbanceBias0; +constexpr float kOpsinAbsorbanceBias2 = kOpsinAbsorbanceBias0; + +// Opsin absorbance matrix is now frozen. +constexpr std::array<float, 9> kOpsinAbsorbanceMatrix = { + kM00, kM01, kM02, kM10, kM11, kM12, kM20, kM21, kM22, +}; + +constexpr std::array<float, 9> kDefaultInverseOpsinAbsorbanceMatrix = { + 11.031566901960783f, -9.866943921568629f, -0.16462299647058826f, + -3.254147380392157f, 4.418770392156863f, -0.16462299647058826f, + -3.6588512862745097f, 2.7129230470588235f, 1.9459282392156863f}; + +// Must be the inverse matrix of kOpsinAbsorbanceMatrix and match the spec. +static inline const float* DefaultInverseOpsinAbsorbanceMatrix() { + return kDefaultInverseOpsinAbsorbanceMatrix.data(); +} + +constexpr std::array<float, 3> kOpsinAbsorbanceBias = { + kOpsinAbsorbanceBias0, + kOpsinAbsorbanceBias1, + kOpsinAbsorbanceBias2, +}; + +constexpr std::array<float, 4> kNegOpsinAbsorbanceBiasRGB = { + -kOpsinAbsorbanceBias0, -kOpsinAbsorbanceBias1, -kOpsinAbsorbanceBias2, + 1.0f}; + +constexpr float kScaledXYBOffset0 = 0.015386134f; +constexpr float kScaledXYBOffset1 = 0.0f; +constexpr float kScaledXYBOffset2 = 0.27770459f; + +constexpr std::array<float, 3> kScaledXYBOffset = { + kScaledXYBOffset0, kScaledXYBOffset1, kScaledXYBOffset2}; + +constexpr float kScaledXYBScale0 = 22.995788804f; +constexpr float kScaledXYBScale1 = 1.183000077f; +constexpr float kScaledXYBScale2 = 1.502141333f; + +constexpr std::array<float, 3> kScaledXYBScale = { + kScaledXYBScale0, + kScaledXYBScale1, + kScaledXYBScale2, +}; + +// NB(eustas): following function/variable names are just "namos". + +// More precise calculation of 1 / ((1 / r1) + (1 / r2)) +constexpr float ReciprocialSum(float r1, float r2) { + return (r1 * r2) / (r1 + r2); +} + +constexpr float kXYBOffset0 = kScaledXYBOffset0 + kScaledXYBOffset1; +constexpr float kXYBOffset1 = + kScaledXYBOffset1 - kScaledXYBOffset0 + (1.0f / kScaledXYBScale0); +constexpr float kXYBOffset2 = kScaledXYBOffset1 + kScaledXYBOffset2; + +constexpr std::array<float, 3> kXYBOffset = {kXYBOffset0, kXYBOffset1, + kXYBOffset2}; + +constexpr float kXYBScale0 = ReciprocialSum(kScaledXYBScale0, kScaledXYBScale1); +constexpr float kXYBScale1 = ReciprocialSum(kScaledXYBScale0, kScaledXYBScale1); +constexpr float kXYBScale2 = ReciprocialSum(kScaledXYBScale1, kScaledXYBScale2); + +constexpr std::array<float, 3> kXYBScale = {kXYBScale0, kXYBScale1, kXYBScale2}; + +template <size_t idx> +constexpr float ScaledXYBScale() { + return (idx == 0) ? kScaledXYBScale0 + : (idx == 1) ? kScaledXYBScale1 + : kScaledXYBScale2; +} + +template <size_t idx> +constexpr float ScaledXYBOffset() { + return (idx == 0) ? kScaledXYBOffset0 + : (idx == 1) ? kScaledXYBOffset1 + : kScaledXYBOffset2; +} + +template <size_t x, size_t y, size_t b, size_t idx> +constexpr float XYBCorner() { + return (((idx == 0) ? x + : (idx == 1) ? y + : b) / + ScaledXYBScale<idx>()) - + ScaledXYBOffset<idx>(); +} + +template <size_t x, size_t y, size_t b, size_t idx> +constexpr float ScaledA2BCorner() { + return (idx == 0) ? (XYBCorner<x, y, b, 1>() + XYBCorner<x, y, b, 0>()) + : (idx == 1) ? (XYBCorner<x, y, b, 1>() - XYBCorner<x, y, b, 0>()) + : (XYBCorner<x, y, b, 2>() + XYBCorner<x, y, b, 1>()); +} + +typedef std::array<float, 3> ColorCube0D; +template <size_t x, size_t y, size_t b> +constexpr ColorCube0D UnscaledA2BCorner() { + return {(ScaledA2BCorner<x, y, b, 0>() + kXYBOffset0) * kXYBScale0, + (ScaledA2BCorner<x, y, b, 1>() + kXYBOffset1) * kXYBScale1, + (ScaledA2BCorner<x, y, b, 2>() + kXYBOffset2) * kXYBScale2}; +} + +typedef std::array<ColorCube0D, 2> ColorCube1D; +template <size_t x, size_t y> +constexpr ColorCube1D UnscaledA2BCubeXY() { + return {UnscaledA2BCorner<x, y, 0>(), UnscaledA2BCorner<x, y, 1>()}; +} + +typedef std::array<ColorCube1D, 2> ColorCube2D; +template <size_t x> +constexpr ColorCube2D UnscaledA2BCubeX() { + return {UnscaledA2BCubeXY<x, 0>(), UnscaledA2BCubeXY<x, 1>()}; +} + +typedef std::array<ColorCube2D, 2> ColorCube3D; +constexpr ColorCube3D UnscaledA2BCube() { + return {UnscaledA2BCubeX<0>(), UnscaledA2BCubeX<1>()}; +} + +constexpr ColorCube3D kUnscaledA2BCube = UnscaledA2BCube(); + +} // namespace cms +} // namespace jxl + +#endif // LIB_JXL_CMS_OPSIN_PARAMS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/cms/tone_mapping-inl.h b/third_party/jpeg-xl/lib/jxl/cms/tone_mapping-inl.h new file mode 100644 index 0000000000..3d94ccea12 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/tone_mapping-inl.h @@ -0,0 +1,191 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_CMS_TONE_MAPPING_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_CMS_TONE_MAPPING_INL_H_ +#undef LIB_JXL_CMS_TONE_MAPPING_INL_H_ +#else +#define LIB_JXL_CMS_TONE_MAPPING_INL_H_ +#endif + +#include <hwy/highway.h> + +#include "lib/jxl/cms/tone_mapping.h" +#include "lib/jxl/cms/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Clamp; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +template <typename D> +class Rec2408ToneMapper : Rec2408ToneMapperBase { + private: + using V = hwy::HWY_NAMESPACE::Vec<D>; + + public: + using Rec2408ToneMapperBase::Rec2408ToneMapperBase; + + void ToneMap(V* red, V* green, V* blue) const { + const V luminance = Mul(Set(df_, source_range_.second), + (MulAdd(Set(df_, red_Y_), *red, + MulAdd(Set(df_, green_Y_), *green, + Mul(Set(df_, blue_Y_), *blue))))); + const V pq_mastering_min = Set(df_, pq_mastering_min_); + const V inv_pq_mastering_range = Set(df_, inv_pq_mastering_range_); + const V normalized_pq = Min( + Set(df_, 1.f), + Mul(Sub(InvEOTF(luminance), pq_mastering_min), inv_pq_mastering_range)); + const V ks = Set(df_, ks_); + const V e2 = + IfThenElse(Lt(normalized_pq, ks), normalized_pq, P(normalized_pq)); + const V one_minus_e2 = Sub(Set(df_, 1), e2); + const V one_minus_e2_2 = Mul(one_minus_e2, one_minus_e2); + const V one_minus_e2_4 = Mul(one_minus_e2_2, one_minus_e2_2); + const V b = Set(df_, min_lum_); + const V e3 = MulAdd(b, one_minus_e2_4, e2); + const V pq_mastering_range = Set(df_, pq_mastering_range_); + const V e4 = MulAdd(e3, pq_mastering_range, pq_mastering_min); + const V new_luminance = + Min(Set(df_, target_range_.second), + ZeroIfNegative(tf_pq_.DisplayFromEncoded(df_, e4))); + const V min_luminance = Set(df_, 1e-6f); + const auto use_cap = Le(luminance, min_luminance); + const V ratio = Div(new_luminance, Max(luminance, min_luminance)); + const V cap = Mul(new_luminance, Set(df_, inv_target_peak_)); + const V normalizer = Set(df_, normalizer_); + const V multiplier = Mul(ratio, normalizer); + for (V* const val : {red, green, blue}) { + *val = IfThenElse(use_cap, cap, Mul(*val, multiplier)); + } + } + + private: + V InvEOTF(const V luminance) const { + return tf_pq_.EncodedFromDisplay(df_, luminance); + } + V T(const V a) const { + const V ks = Set(df_, ks_); + const V inv_one_minus_ks = Set(df_, inv_one_minus_ks_); + return Mul(Sub(a, ks), inv_one_minus_ks); + } + V P(const V b) const { + const V t_b = T(b); + const V t_b_2 = Mul(t_b, t_b); + const V t_b_3 = Mul(t_b_2, t_b); + const V ks = Set(df_, ks_); + const V max_lum = Set(df_, max_lum_); + return MulAdd( + MulAdd(Set(df_, 2), t_b_3, MulAdd(Set(df_, -3), t_b_2, Set(df_, 1))), + ks, + MulAdd(Add(t_b_3, MulAdd(Set(df_, -2), t_b_2, t_b)), + Sub(Set(df_, 1), ks), + Mul(MulAdd(Set(df_, -2), t_b_3, Mul(Set(df_, 3), t_b_2)), + max_lum))); + } + + D df_; + const TF_PQ tf_pq_ = TF_PQ(/*display_intensity_target=*/1.0); +}; + +class HlgOOTF : HlgOOTF_Base { + public: + using HlgOOTF_Base::HlgOOTF_Base; + + static HlgOOTF FromSceneLight(float display_luminance, + const float primaries_luminances[3]) { + return HlgOOTF(/*gamma=*/1.2f * + std::pow(1.111f, std::log2(display_luminance / 1000.f)), + primaries_luminances); + } + + static HlgOOTF ToSceneLight(float display_luminance, + const float primaries_luminances[3]) { + return HlgOOTF( + /*gamma=*/(1 / 1.2f) * + std::pow(1.111f, -std::log2(display_luminance / 1000.f)), + primaries_luminances); + } + + template <typename V> + void Apply(V* red, V* green, V* blue) const { + hwy::HWY_NAMESPACE::DFromV<V> df; + if (!apply_ootf_) return; + const V luminance = + MulAdd(Set(df, red_Y_), *red, + MulAdd(Set(df, green_Y_), *green, Mul(Set(df, blue_Y_), *blue))); + const V ratio = + Min(FastPowf(df, luminance, Set(df, exponent_)), Set(df, 1e9)); + *red = Mul(*red, ratio); + *green = Mul(*green, ratio); + *blue = Mul(*blue, ratio); + } + + bool WarrantsGamutMapping() const { return apply_ootf_ && exponent_ < 0; } +}; + +template <typename V> +void GamutMap(V* red, V* green, V* blue, const float primaries_luminances[3], + float preserve_saturation = 0.1f) { + hwy::HWY_NAMESPACE::DFromV<V> df; + const V luminance = + MulAdd(Set(df, primaries_luminances[0]), *red, + MulAdd(Set(df, primaries_luminances[1]), *green, + Mul(Set(df, primaries_luminances[2]), *blue))); + + // Desaturate out-of-gamut pixels. This is done by mixing each pixel + // with just enough gray of the target luminance to make all + // components non-negative. + // - For saturation preservation, if a component is still larger than + // 1 then the pixel is normalized to have a maximum component of 1. + // That will reduce its luminance. + // - For luminance preservation, getting all components below 1 is + // done by mixing in yet more gray. That will desaturate it further. + const V zero = Zero(df); + const V one = Set(df, 1); + V gray_mix_saturation = zero; + V gray_mix_luminance = zero; + for (const V* ch : {red, green, blue}) { + const V& val = *ch; + const V val_minus_gray = Sub(val, luminance); + const V inv_val_minus_gray = + Div(one, IfThenElse(Eq(val_minus_gray, zero), one, val_minus_gray)); + const V val_over_val_minus_gray = Mul(val, inv_val_minus_gray); + gray_mix_saturation = + IfThenElse(Ge(val_minus_gray, zero), gray_mix_saturation, + Max(gray_mix_saturation, val_over_val_minus_gray)); + gray_mix_luminance = + Max(gray_mix_luminance, + IfThenElse(Le(val_minus_gray, zero), gray_mix_saturation, + Sub(val_over_val_minus_gray, inv_val_minus_gray))); + } + const V gray_mix = Clamp( + MulAdd(Set(df, preserve_saturation), + Sub(gray_mix_saturation, gray_mix_luminance), gray_mix_luminance), + zero, one); + for (V* const ch : {red, green, blue}) { + V& val = *ch; + val = MulAdd(gray_mix, Sub(luminance, val), val); + } + const V max_clr = Max(Max(one, *red), Max(*green, *blue)); + const V normalizer = Div(one, max_clr); + for (V* const ch : {red, green, blue}) { + V& val = *ch; + val = Mul(val, normalizer); + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_CMS_TONE_MAPPING_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/cms/tone_mapping.h b/third_party/jpeg-xl/lib/jxl/cms/tone_mapping.h new file mode 100644 index 0000000000..a114109ea6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/tone_mapping.h @@ -0,0 +1,179 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CMS_TONE_MAPPING_H_ +#define LIB_JXL_CMS_TONE_MAPPING_H_ + +#include <algorithm> +#include <cmath> +#include <utility> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/cms/transfer_functions.h" + +namespace jxl { + +class Rec2408ToneMapperBase { + public: + explicit Rec2408ToneMapperBase(std::pair<float, float> source_range, + std::pair<float, float> target_range, + const float primaries_luminances[3]) + : source_range_(source_range), + target_range_(target_range), + red_Y_(primaries_luminances[0]), + green_Y_(primaries_luminances[1]), + blue_Y_(primaries_luminances[2]) {} + + // TODO(eustas): test me + void ToneMap(float* red, float* green, float* blue) const { + const float luminance = + source_range_.second * + (red_Y_ * *red + green_Y_ * *green + blue_Y_ * *blue); + const float normalized_pq = + std::min(1.f, (InvEOTF(luminance) - pq_mastering_min_) * + inv_pq_mastering_range_); + const float e2 = (normalized_pq < ks_) ? normalized_pq : P(normalized_pq); + const float one_minus_e2 = 1 - e2; + const float one_minus_e2_2 = one_minus_e2 * one_minus_e2; + const float one_minus_e2_4 = one_minus_e2_2 * one_minus_e2_2; + const float e3 = min_lum_ * one_minus_e2_4 + e2; + const float e4 = e3 * pq_mastering_range_ + pq_mastering_min_; + const float d4 = + TF_PQ_Base::DisplayFromEncoded(/*display_intensity_target=*/1.0, e4); + const float new_luminance = Clamp1(d4, 0.f, target_range_.second); + const float min_luminance = 1e-6f; + const bool use_cap = (luminance <= min_luminance); + const float ratio = new_luminance / std::max(luminance, min_luminance); + const float cap = new_luminance * inv_target_peak_; + const float multiplier = ratio * normalizer_; + for (float* const val : {red, green, blue}) { + *val = use_cap ? cap : *val * multiplier; + } + } + + protected: + float InvEOTF(const float luminance) const { + return TF_PQ_Base::EncodedFromDisplay(/*display_intensity_target=*/1.0, + luminance); + } + float T(const float a) const { return (a - ks_) * inv_one_minus_ks_; } + float P(const float b) const { + const float t_b = T(b); + const float t_b_2 = t_b * t_b; + const float t_b_3 = t_b_2 * t_b; + return (2 * t_b_3 - 3 * t_b_2 + 1) * ks_ + + (t_b_3 - 2 * t_b_2 + t_b) * (1 - ks_) + + (-2 * t_b_3 + 3 * t_b_2) * max_lum_; + } + + const std::pair<float, float> source_range_; + const std::pair<float, float> target_range_; + const float red_Y_; + const float green_Y_; + const float blue_Y_; + + const float pq_mastering_min_ = InvEOTF(source_range_.first); + const float pq_mastering_max_ = InvEOTF(source_range_.second); + const float pq_mastering_range_ = pq_mastering_max_ - pq_mastering_min_; + const float inv_pq_mastering_range_ = 1.0f / pq_mastering_range_; + // TODO(eustas): divide instead of inverse-multiply? + const float min_lum_ = (InvEOTF(target_range_.first) - pq_mastering_min_) * + inv_pq_mastering_range_; + // TODO(eustas): divide instead of inverse-multiply? + const float max_lum_ = (InvEOTF(target_range_.second) - pq_mastering_min_) * + inv_pq_mastering_range_; + const float ks_ = 1.5f * max_lum_ - 0.5f; + + const float inv_one_minus_ks_ = 1.0f / std::max(1e-6f, 1.0f - ks_); + + const float normalizer_ = source_range_.second / target_range_.second; + const float inv_target_peak_ = 1.f / target_range_.second; +}; + +class HlgOOTF_Base { + public: + explicit HlgOOTF_Base(float source_luminance, float target_luminance, + const float primaries_luminances[3]) + : HlgOOTF_Base(/*gamma=*/std::pow(1.111f, std::log2(target_luminance / + source_luminance)), + primaries_luminances) {} + + // TODO(eustas): test me + void Apply(float* red, float* green, float* blue) const { + if (!apply_ootf_) return; + const float luminance = red_Y_ * *red + green_Y_ * *green + blue_Y_ * *blue; + const float ratio = std::min<float>(powf(luminance, exponent_), 1e9); + *red *= ratio; + *green *= ratio; + *blue *= ratio; + } + + protected: + explicit HlgOOTF_Base(float gamma, const float luminances[3]) + : exponent_(gamma - 1), + red_Y_(luminances[0]), + green_Y_(luminances[1]), + blue_Y_(luminances[2]) {} + const float exponent_; + const bool apply_ootf_ = exponent_ < -0.01f || 0.01f < exponent_; + const float red_Y_; + const float green_Y_; + const float blue_Y_; +}; + +static JXL_MAYBE_UNUSED void GamutMapScalar(float* red, float* green, + float* blue, + const float primaries_luminances[3], + float preserve_saturation = 0.1f) { + const float luminance = primaries_luminances[0] * *red + + primaries_luminances[1] * *green + + primaries_luminances[2] * *blue; + + // Desaturate out-of-gamut pixels. This is done by mixing each pixel + // with just enough gray of the target luminance to make all + // components non-negative. + // - For saturation preservation, if a component is still larger than + // 1 then the pixel is normalized to have a maximum component of 1. + // That will reduce its luminance. + // - For luminance preservation, getting all components below 1 is + // done by mixing in yet more gray. That will desaturate it further. + float gray_mix_saturation = 0.0f; + float gray_mix_luminance = 0.0f; + for (const float* ch : {red, green, blue}) { + const float& val = *ch; + const float val_minus_gray = val - luminance; + const float inv_val_minus_gray = + 1.0f / ((val_minus_gray == 0.0f) ? 1.0f : val_minus_gray); + const float val_over_val_minus_gray = val * inv_val_minus_gray; + gray_mix_saturation = + (val_minus_gray >= 0.0f) + ? gray_mix_saturation + : std::max(gray_mix_saturation, val_over_val_minus_gray); + gray_mix_luminance = + std::max(gray_mix_luminance, + (val_minus_gray <= 0.0f) + ? gray_mix_saturation + : (val_over_val_minus_gray - inv_val_minus_gray)); + } + const float gray_mix = + Clamp1((preserve_saturation * (gray_mix_saturation - gray_mix_luminance) + + gray_mix_luminance), + 0.0f, 1.0f); + for (float* const ch : {red, green, blue}) { + float& val = *ch; + val = gray_mix * (luminance - val) + val; + } + const float max_clr = std::max({1.0f, *red, *green, *blue}); + const float normalizer = 1.0f / max_clr; + for (float* const ch : {red, green, blue}) { + float& val = *ch; + val *= normalizer; + } +} + +} // namespace jxl + +#endif // LIB_JXL_CMS_TONE_MAPPING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/cms/tone_mapping_test.cc b/third_party/jpeg-xl/lib/jxl/cms/tone_mapping_test.cc new file mode 100644 index 0000000000..dda2bbb0aa --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/tone_mapping_test.cc @@ -0,0 +1,147 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/cms/tone_mapping_test.cc" +#include "lib/jxl/cms/tone_mapping.h" + +#include <cstdio> +#include <hwy/foreach_target.h> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/cms/tone_mapping-inl.h" +#include "lib/jxl/testing.h" + +// Test utils +#include <hwy/highway.h> +#include <hwy/tests/hwy_gtest.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +HWY_NOINLINE void TestRec2408ToneMap() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + float src = 11000.0 + rng.UniformF(-150.0f, 150.0f); + float tgt = 250 + rng.UniformF(-5.0f, 5.0f); + float luminances[3] = {rng.UniformF(0.2f, 0.4f), rng.UniformF(0.2f, 0.4f), + rng.UniformF(0.2f, 0.4f)}; + float rgb[3] = {rng.UniformF(0.0f, 1.0f), rng.UniformF(0.0f, 1.0f), + rng.UniformF(0.0f, 1.0f)}; + Rec2408ToneMapper<decltype(d)> tone_mapper({0, src}, {0, tgt}, luminances); + auto r = Set(d, rgb[0]); + auto g = Set(d, rgb[1]); + auto b = Set(d, rgb[2]); + tone_mapper.ToneMap(&r, &g, &b); + Rec2408ToneMapperBase tone_mapper_base({0, src}, {0, tgt}, luminances); + tone_mapper_base.ToneMap(&rgb[0], &rgb[1], &rgb[2]); + const float actual_r = GetLane(r); + const float expected_r = rgb[0]; + const float abs_err_r = std::abs(expected_r - actual_r); + EXPECT_LT(abs_err_r, 2.75e-5); + const float actual_g = GetLane(g); + const float expected_g = rgb[1]; + const float abs_err_g = std::abs(expected_g - actual_g); + EXPECT_LT(abs_err_g, 2.75e-5); + const float actual_b = GetLane(b); + const float expected_b = rgb[2]; + const float abs_err_b = std::abs(expected_b - actual_b); + EXPECT_LT(abs_err_b, 2.75e-5); + max_abs_err = std::max({max_abs_err, abs_err_r, abs_err_g, abs_err_b}); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +HWY_NOINLINE void TestHlgOotfApply() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + float src = 300.0 + rng.UniformF(-50.0f, 50.0f); + float tgt = 80 + rng.UniformF(-5.0f, 5.0f); + float luminances[3] = {rng.UniformF(0.2f, 0.4f), rng.UniformF(0.2f, 0.4f), + rng.UniformF(0.2f, 0.4f)}; + float rgb[3] = {rng.UniformF(0.0f, 1.0f), rng.UniformF(0.0f, 1.0f), + rng.UniformF(0.0f, 1.0f)}; + HlgOOTF ootf(src, tgt, luminances); + auto r = Set(d, rgb[0]); + auto g = Set(d, rgb[1]); + auto b = Set(d, rgb[2]); + ootf.Apply(&r, &g, &b); + HlgOOTF_Base ootf_base(src, tgt, luminances); + ootf_base.Apply(&rgb[0], &rgb[1], &rgb[2]); + const float actual_r = GetLane(r); + const float expected_r = rgb[0]; + const float abs_err_r = std::abs(expected_r - actual_r); + EXPECT_LT(abs_err_r, 7.2e-7); + const float actual_g = GetLane(g); + const float expected_g = rgb[1]; + const float abs_err_g = std::abs(expected_g - actual_g); + EXPECT_LT(abs_err_g, 7.2e-7); + const float actual_b = GetLane(b); + const float expected_b = rgb[2]; + const float abs_err_b = std::abs(expected_b - actual_b); + EXPECT_LT(abs_err_b, 7.2e-7); + max_abs_err = std::max({max_abs_err, abs_err_r, abs_err_g, abs_err_b}); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +HWY_NOINLINE void TestGamutMap() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + float preserve_saturation = rng.UniformF(0.2f, 0.4f); + float luminances[3] = {rng.UniformF(0.2f, 0.4f), rng.UniformF(0.2f, 0.4f), + rng.UniformF(0.2f, 0.4f)}; + float rgb[3] = {rng.UniformF(0.0f, 1.0f), rng.UniformF(0.0f, 1.0f), + rng.UniformF(0.0f, 1.0f)}; + auto r = Set(d, rgb[0]); + auto g = Set(d, rgb[1]); + auto b = Set(d, rgb[2]); + GamutMap(&r, &g, &b, luminances, preserve_saturation); + GamutMapScalar(&rgb[0], &rgb[1], &rgb[2], luminances, preserve_saturation); + const float actual_r = GetLane(r); + const float expected_r = rgb[0]; + const float abs_err_r = std::abs(expected_r - actual_r); + EXPECT_LT(abs_err_r, 1e-10); + const float actual_g = GetLane(g); + const float expected_g = rgb[1]; + const float abs_err_g = std::abs(expected_g - actual_g); + EXPECT_LT(abs_err_g, 1e-10); + const float actual_b = GetLane(b); + const float expected_b = rgb[2]; + const float abs_err_b = std::abs(expected_b - actual_b); + EXPECT_LT(abs_err_b, 1e-10); + max_abs_err = std::max({max_abs_err, abs_err_r, abs_err_g, abs_err_b}); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class ToneMappingTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(ToneMappingTargetTest); + +HWY_EXPORT_AND_TEST_P(ToneMappingTargetTest, TestRec2408ToneMap); +HWY_EXPORT_AND_TEST_P(ToneMappingTargetTest, TestHlgOotfApply); +HWY_EXPORT_AND_TEST_P(ToneMappingTargetTest, TestGamutMap); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/cms/transfer_functions-inl.h b/third_party/jpeg-xl/lib/jxl/cms/transfer_functions-inl.h new file mode 100644 index 0000000000..84bcbb45ed --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/transfer_functions-inl.h @@ -0,0 +1,334 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Transfer functions for color encodings. + +#if defined(LIB_JXL_CMS_TRANSFER_FUNCTIONS_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_CMS_TRANSFER_FUNCTIONS_INL_H_ +#undef LIB_JXL_CMS_TRANSFER_FUNCTIONS_INL_H_ +#else +#define LIB_JXL_CMS_TRANSFER_FUNCTIONS_INL_H_ +#endif + +#include <algorithm> +#include <cmath> +#include <hwy/highway.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/fast_math-inl.h" +#include "lib/jxl/base/rational_polynomial-inl.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/cms/transfer_functions.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::AndNot; +using hwy::HWY_NAMESPACE::Gt; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::Or; +using hwy::HWY_NAMESPACE::Sqrt; +using hwy::HWY_NAMESPACE::TableLookupBytes; + +// Definitions for BT.2100-2 transfer functions (used inside/outside SIMD): +// "display" is linear light (nits) normalized to [0, 1]. +// "encoded" is a nonlinear encoding (e.g. PQ) in [0, 1]. +// "scene" is a linear function of photon counts, normalized to [0, 1]. + +// Despite the stated ranges, we need unbounded transfer functions: see +// http://www.littlecms.com/CIC18_UnboundedCMM.pdf. Inputs can be negative or +// above 1 due to chromatic adaptation. To avoid severe round-trip errors caused +// by clamping, we mirror negative inputs via copysign (f(-x) = -f(x), see +// https://developer.apple.com/documentation/coregraphics/cgcolorspace/1644735-extendedsrgb) +// and extend the function domains above 1. + +// Hybrid Log-Gamma. +class TF_HLG : TF_HLG_Base { + public: + // Maximum error 5e-7. + template <class D, class V> + JXL_INLINE V EncodedFromDisplay(D d, V x) const { + const hwy::HWY_NAMESPACE::Rebind<uint32_t, D> du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + const V below_div12 = Sqrt(Mul(Set(d, 3.0f), x)); + const V e = + MulAdd(Set(d, kA * 0.693147181f), + FastLog2f(d, MulAdd(Set(d, 12), x, Set(d, -kB))), Set(d, kC)); + const V magnitude = IfThenElse(Le(x, Set(d, kDiv12)), below_div12, e); + return Or(AndNot(kSign, magnitude), original_sign); + } +}; + +class TF_709 { + public: + JXL_INLINE double EncodedFromDisplay(const double d) const { + if (d < kThresh) return kMulLow * d; + return kMulHi * std::pow(d, kPowHi) + kSub; + } + + // Maximum error 1e-6. + template <class D, class V> + JXL_INLINE V EncodedFromDisplay(D d, V x) const { + auto low = Mul(Set(d, kMulLow), x); + auto hi = + MulAdd(Set(d, kMulHi), FastPowf(d, x, Set(d, kPowHi)), Set(d, kSub)); + return IfThenElse(Le(x, Set(d, kThresh)), low, hi); + } + + template <class D, class V> + JXL_INLINE V DisplayFromEncoded(D d, V x) const { + auto low = Mul(Set(d, kInvMulLow), x); + auto hi = FastPowf(d, MulAdd(x, Set(d, kInvMulHi), Set(d, kInvAdd)), + Set(d, kInvPowHi)); + return IfThenElse(Lt(x, Set(d, kInvThresh)), low, hi); + } + + private: + static constexpr double kThresh = 0.018; + static constexpr double kMulLow = 4.5; + static constexpr double kMulHi = 1.099; + static constexpr double kPowHi = 0.45; + static constexpr double kSub = -0.099; + + static constexpr double kInvThresh = 0.081; + static constexpr double kInvMulLow = 1 / 4.5; + static constexpr double kInvMulHi = 1 / 1.099; + static constexpr double kInvPowHi = 1 / 0.45; + static constexpr double kInvAdd = 0.099 * kInvMulHi; +}; + +// Perceptual Quantization +class TF_PQ : TF_PQ_Base { + public: + explicit TF_PQ(float display_intensity_target = kDefaultIntensityTarget) + : display_scaling_factor_to_10000_nits_(display_intensity_target * + (1.0f / 10000.0f)), + display_scaling_factor_from_10000_nits_(10000.0f / + display_intensity_target) {} + + // Maximum error 3e-6 + template <class D, class V> + JXL_INLINE V DisplayFromEncoded(D d, V x) const { + const hwy::HWY_NAMESPACE::Rebind<uint32_t, D> du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + // 4-over-4-degree rational polynomial approximation on x+x*x. This improves + // the maximum error by about 5x over a rational polynomial for x. + auto xpxx = MulAdd(x, x, x); + HWY_ALIGN constexpr float p[(4 + 1) * 4] = { + HWY_REP4(2.62975656e-04f), HWY_REP4(-6.23553089e-03f), + HWY_REP4(7.38602301e-01f), HWY_REP4(2.64553172e+00f), + HWY_REP4(5.50034862e-01f), + }; + HWY_ALIGN constexpr float q[(4 + 1) * 4] = { + HWY_REP4(4.21350107e+02f), HWY_REP4(-4.28736818e+02f), + HWY_REP4(1.74364667e+02f), HWY_REP4(-3.39078883e+01f), + HWY_REP4(2.67718770e+00f), + }; + auto magnitude = EvalRationalPolynomial(d, xpxx, p, q); + return Or( + AndNot(kSign, + Mul(magnitude, Set(d, display_scaling_factor_from_10000_nits_))), + original_sign); + } + + // Maximum error 7e-7. + template <class D, class V> + JXL_INLINE V EncodedFromDisplay(D d, V x) const { + const hwy::HWY_NAMESPACE::Rebind<uint32_t, D> du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + // 4-over-4-degree rational polynomial approximation on x**0.25, with two + // different polynomials above and below 1e-4. + auto xto025 = + Sqrt(Sqrt(Mul(x, Set(d, display_scaling_factor_to_10000_nits_)))); + HWY_ALIGN constexpr float p[(4 + 1) * 4] = { + HWY_REP4(1.351392e-02f), HWY_REP4(-1.095778e+00f), + HWY_REP4(5.522776e+01f), HWY_REP4(1.492516e+02f), + HWY_REP4(4.838434e+01f), + }; + HWY_ALIGN constexpr float q[(4 + 1) * 4] = { + HWY_REP4(1.012416e+00f), HWY_REP4(2.016708e+01f), + HWY_REP4(9.263710e+01f), HWY_REP4(1.120607e+02f), + HWY_REP4(2.590418e+01f), + }; + + HWY_ALIGN constexpr float plo[(4 + 1) * 4] = { + HWY_REP4(9.863406e-06f), HWY_REP4(3.881234e-01f), + HWY_REP4(1.352821e+02f), HWY_REP4(6.889862e+04f), + HWY_REP4(-2.864824e+05f), + }; + HWY_ALIGN constexpr float qlo[(4 + 1) * 4] = { + HWY_REP4(3.371868e+01f), HWY_REP4(1.477719e+03f), + HWY_REP4(1.608477e+04f), HWY_REP4(-4.389884e+04f), + HWY_REP4(-2.072546e+05f), + }; + + auto magnitude = IfThenElse(Lt(x, Set(d, 1e-4f)), + EvalRationalPolynomial(d, xto025, plo, qlo), + EvalRationalPolynomial(d, xto025, p, q)); + return Or(AndNot(kSign, magnitude), original_sign); + } + + private: + const float display_scaling_factor_to_10000_nits_; + const float display_scaling_factor_from_10000_nits_; +}; + +// sRGB +class TF_SRGB { + public: + template <typename V> + JXL_INLINE V DisplayFromEncoded(V x) const { + const HWY_FULL(float) d; + const HWY_FULL(uint32_t) du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + + // TODO(janwas): range reduction + // Computed via af_cheb_rational (k=100); replicated 4x. + HWY_ALIGN constexpr float p[(4 + 1) * 4] = { + 2.200248328e-04f, 2.200248328e-04f, 2.200248328e-04f, 2.200248328e-04f, + 1.043637593e-02f, 1.043637593e-02f, 1.043637593e-02f, 1.043637593e-02f, + 1.624820318e-01f, 1.624820318e-01f, 1.624820318e-01f, 1.624820318e-01f, + 7.961564959e-01f, 7.961564959e-01f, 7.961564959e-01f, 7.961564959e-01f, + 8.210152774e-01f, 8.210152774e-01f, 8.210152774e-01f, 8.210152774e-01f, + }; + HWY_ALIGN constexpr float q[(4 + 1) * 4] = { + 2.631846970e-01f, 2.631846970e-01f, 2.631846970e-01f, + 2.631846970e-01f, 1.076976492e+00f, 1.076976492e+00f, + 1.076976492e+00f, 1.076976492e+00f, 4.987528350e-01f, + 4.987528350e-01f, 4.987528350e-01f, 4.987528350e-01f, + -5.512498495e-02f, -5.512498495e-02f, -5.512498495e-02f, + -5.512498495e-02f, 6.521209011e-03f, 6.521209011e-03f, + 6.521209011e-03f, 6.521209011e-03f, + }; + const V linear = Mul(x, Set(d, kLowDivInv)); + const V poly = EvalRationalPolynomial(d, x, p, q); + const V magnitude = + IfThenElse(Gt(x, Set(d, kThreshSRGBToLinear)), poly, linear); + return Or(AndNot(kSign, magnitude), original_sign); + } + + // Error ~5e-07 + template <class D, class V> + JXL_INLINE V EncodedFromDisplay(D d, V x) const { + const hwy::HWY_NAMESPACE::Rebind<uint32_t, D> du; + const V kSign = BitCast(d, Set(du, 0x80000000u)); + const V original_sign = And(x, kSign); + x = AndNot(kSign, x); // abs + + // Computed via af_cheb_rational (k=100); replicated 4x. + HWY_ALIGN constexpr float p[(4 + 1) * 4] = { + -5.135152395e-04f, -5.135152395e-04f, -5.135152395e-04f, + -5.135152395e-04f, 5.287254571e-03f, 5.287254571e-03f, + 5.287254571e-03f, 5.287254571e-03f, 3.903842876e-01f, + 3.903842876e-01f, 3.903842876e-01f, 3.903842876e-01f, + 1.474205315e+00f, 1.474205315e+00f, 1.474205315e+00f, + 1.474205315e+00f, 7.352629620e-01f, 7.352629620e-01f, + 7.352629620e-01f, 7.352629620e-01f, + }; + HWY_ALIGN constexpr float q[(4 + 1) * 4] = { + 1.004519624e-02f, 1.004519624e-02f, 1.004519624e-02f, 1.004519624e-02f, + 3.036675394e-01f, 3.036675394e-01f, 3.036675394e-01f, 3.036675394e-01f, + 1.340816930e+00f, 1.340816930e+00f, 1.340816930e+00f, 1.340816930e+00f, + 9.258482155e-01f, 9.258482155e-01f, 9.258482155e-01f, 9.258482155e-01f, + 2.424867759e-02f, 2.424867759e-02f, 2.424867759e-02f, 2.424867759e-02f, + }; + const V linear = Mul(x, Set(d, kLowDiv)); + const V poly = EvalRationalPolynomial(d, Sqrt(x), p, q); + const V magnitude = + IfThenElse(Gt(x, Set(d, kThreshLinearToSRGB)), poly, linear); + return Or(AndNot(kSign, magnitude), original_sign); + } + + private: + static constexpr float kThreshSRGBToLinear = 0.04045f; + static constexpr float kThreshLinearToSRGB = 0.0031308f; + static constexpr float kLowDiv = 12.92f; + static constexpr float kLowDivInv = 1.0f / kLowDiv; +}; + +// Linear to sRGB conversion with error of at most 1.2e-4. +template <typename D, typename V> +V FastLinearToSRGB(D d, V v) { + const hwy::HWY_NAMESPACE::Rebind<uint32_t, D> du; + const hwy::HWY_NAMESPACE::Rebind<int32_t, D> di; + // Convert to 0.25 - 0.5 range. + auto v025_05 = BitCast( + d, And(Or(BitCast(du, v), Set(du, 0x3e800000)), Set(du, 0x3effffff))); + // third degree polynomial approximation between 0.25 and 0.5 + // of 1.055/2^(7/2.4) * x^(1/2.4) * 0.5. A degree 4 polynomial only improves + // accuracy by about 3x. + auto d1 = MulAdd(v025_05, Set(d, 0.059914046f), Set(d, -0.108894556f)); + auto d2 = MulAdd(d1, v025_05, Set(d, 0.107963754f)); + auto pow = MulAdd(d2, v025_05, Set(d, 0.018092343f)); + // Compute extra multiplier depending on exponent. Valid exponent range for + // [0.0031308f, 1.0) is 0...8 after subtracting 118. + // The next three constants contain a representation of the powers of + // 2**(1/2.4) = 2**(5/12) times two; in particular, bits from 26 to 31 are + // always the same and in k2to512powers_basebits, and the two arrays contain + // the next groups of 8 bits. This ends up being a 22-bit representation (with + // a mantissa of 13 bits). The choice of polynomial to approximate is such + // that the multiplication factor has the highest 5 bits constant, and that + // the factor for the lowest possible exponent is a power of two (thus making + // the additional bits 0, which is used to correctly merge back together the + // floats). + constexpr uint32_t k2to512powers_basebits = 0x40000000; + HWY_ALIGN constexpr uint8_t k2to512powers_25to18bits[16] = { + 0x0, 0xa, 0x19, 0x26, 0x32, 0x41, 0x4d, 0x5c, + 0x68, 0x75, 0x83, 0x8f, 0xa0, 0xaa, 0xb9, 0xc6, + }; + HWY_ALIGN constexpr uint8_t k2to512powers_17to10bits[16] = { + 0x0, 0xb7, 0x4, 0xd, 0xcb, 0xe7, 0x41, 0x68, + 0x51, 0xd1, 0xeb, 0xf2, 0x0, 0xb7, 0x4, 0xd, + }; + // Note that vld1q_s8_x2 on ARM seems to actually be slower. +#if HWY_TARGET != HWY_SCALAR + using hwy::HWY_NAMESPACE::ShiftLeft; + using hwy::HWY_NAMESPACE::ShiftRight; + // Every lane of exp is now (if cast to byte) {0, 0, 0, <index for lookup>}. + auto exp = Sub(ShiftRight<23>(BitCast(di, v)), Set(di, 118)); + auto pow25to18bits = TableLookupBytes( + LoadDup128(di, + reinterpret_cast<const int32_t*>(k2to512powers_25to18bits)), + exp); + auto pow17to10bits = TableLookupBytes( + LoadDup128(di, + reinterpret_cast<const int32_t*>(k2to512powers_17to10bits)), + exp); + // Now, pow* contain {0, 0, 0, <part of float repr of multiplier>}. Here + // we take advantage of the fact that each table has its position 0 equal to + // 0. + // We can now just reassemble the float. + auto mul = BitCast( + d, Or(Or(ShiftLeft<18>(pow25to18bits), ShiftLeft<10>(pow17to10bits)), + Set(di, k2to512powers_basebits))); +#else + // Fallback for scalar. + uint32_t exp = ((BitCast(di, v).raw >> 23) - 118) & 0xf; + auto mul = BitCast(d, Set(di, (k2to512powers_25to18bits[exp] << 18) | + (k2to512powers_17to10bits[exp] << 10) | + k2to512powers_basebits)); +#endif + return IfThenElse(Lt(v, Set(d, 0.0031308f)), Mul(v, Set(d, 12.92f)), + MulAdd(pow, mul, Set(d, -0.055))); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_CMS_TRANSFER_FUNCTIONS_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/cms/transfer_functions.h b/third_party/jpeg-xl/lib/jxl/cms/transfer_functions.h new file mode 100644 index 0000000000..4e5273d5d3 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/transfer_functions.h @@ -0,0 +1,131 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Transfer functions for color encodings. + +#ifndef LIB_JXL_CMS_TRANSFER_FUNCTIONS_H_ +#define LIB_JXL_CMS_TRANSFER_FUNCTIONS_H_ + +#include <algorithm> +#include <cmath> + +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Definitions for BT.2100-2 transfer functions (used inside/outside SIMD): +// "display" is linear light (nits) normalized to [0, 1]. +// "encoded" is a nonlinear encoding (e.g. PQ) in [0, 1]. +// "scene" is a linear function of photon counts, normalized to [0, 1]. + +// Despite the stated ranges, we need unbounded transfer functions: see +// http://www.littlecms.com/CIC18_UnboundedCMM.pdf. Inputs can be negative or +// above 1 due to chromatic adaptation. To avoid severe round-trip errors caused +// by clamping, we mirror negative inputs via copysign (f(-x) = -f(x), see +// https://developer.apple.com/documentation/coregraphics/cgcolorspace/1644735-extendedsrgb) +// and extend the function domains above 1. + +// Hybrid Log-Gamma. +class TF_HLG_Base { + public: + // EOTF. e = encoded. + static double DisplayFromEncoded(const double e) { return OOTF(InvOETF(e)); } + + // Inverse EOTF. d = display. + static double EncodedFromDisplay(const double d) { return OETF(InvOOTF(d)); } + + private: + // OETF (defines the HLG approach). s = scene, returns encoded. + static double OETF(double s) { + if (s == 0.0) return 0.0; + const double original_sign = s; + s = std::abs(s); + + if (s <= kDiv12) return copysignf(std::sqrt(3.0 * s), original_sign); + + const double e = kA * std::log(12 * s - kB) + kC; + JXL_ASSERT(e > 0.0); + return copysignf(e, original_sign); + } + + // e = encoded, returns scene. + static double InvOETF(double e) { + if (e == 0.0) return 0.0; + const double original_sign = e; + e = std::abs(e); + + if (e <= 0.5) return copysignf(e * e * (1.0 / 3), original_sign); + + const double s = (std::exp((e - kC) * kRA) + kB) * kDiv12; + JXL_ASSERT(s >= 0); + return copysignf(s, original_sign); + } + + // s = scene, returns display. + static double OOTF(const double s) { + // The actual (red channel) OOTF is RD = alpha * YS^(gamma-1) * RS, where + // YS = 0.2627 * RS + 0.6780 * GS + 0.0593 * BS. Let alpha = 1 so we return + // "display" (normalized [0, 1]) instead of nits. Our transfer function + // interface does not allow a dependency on YS. Fortunately, the system + // gamma at 334 nits is 1.0, so this reduces to RD = RS. + return s; + } + + // d = display, returns scene. + static double InvOOTF(const double d) { + return d; // see OOTF(). + } + + protected: + static constexpr double kA = 0.17883277; + static constexpr double kRA = 1.0 / kA; + static constexpr double kB = 1 - 4 * kA; + static constexpr double kC = 0.5599107295; + static constexpr double kDiv12 = 1.0 / 12; +}; + +// Perceptual Quantization +class TF_PQ_Base { + public: + static double DisplayFromEncoded(float display_intensity_target, double e) { + if (e == 0.0) return 0.0; + const double original_sign = e; + e = std::abs(e); + + const double xp = std::pow(e, 1.0 / kM2); + const double num = std::max(xp - kC1, 0.0); + const double den = kC2 - kC3 * xp; + JXL_DASSERT(den != 0.0); + const double d = std::pow(num / den, 1.0 / kM1); + JXL_DASSERT(d >= 0.0); // Equal for e ~= 1E-9 + return copysignf(d * (10000.0f / display_intensity_target), original_sign); + } + + // Inverse EOTF. d = display. + static double EncodedFromDisplay(float display_intensity_target, double d) { + if (d == 0.0) return 0.0; + const double original_sign = d; + d = std::abs(d); + + const double xp = + std::pow(d * (display_intensity_target * (1.0f / 10000.0f)), kM1); + const double num = kC1 + xp * kC2; + const double den = 1.0 + xp * kC3; + const double e = std::pow(num / den, kM2); + JXL_DASSERT(e > 0.0); + return copysignf(e, original_sign); + } + + protected: + static constexpr double kM1 = 2610.0 / 16384; + static constexpr double kM2 = (2523.0 / 4096) * 128; + static constexpr double kC1 = 3424.0 / 4096; + static constexpr double kC2 = (2413.0 / 4096) * 32; + static constexpr double kC3 = (2392.0 / 4096) * 32; +}; + +} // namespace jxl + +#endif // LIB_JXL_CMS_TRANSFER_FUNCTIONS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/cms/transfer_functions_test.cc b/third_party/jpeg-xl/lib/jxl/cms/transfer_functions_test.cc new file mode 100644 index 0000000000..26de409a4e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/cms/transfer_functions_test.cc @@ -0,0 +1,94 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/cms/transfer_functions_test.cc" +#include "lib/jxl/cms/transfer_functions.h" + +#include <cstdio> +#include <hwy/foreach_target.h> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/cms/transfer_functions-inl.h" +#include "lib/jxl/testing.h" + +// Test utils +#include <hwy/highway.h> +#include <hwy/tests/hwy_gtest.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +HWY_NOINLINE void TestPqEncodedFromDisplay() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + double intensity = 11000.0 + rng.UniformF(-150.0f, 150.0f); + TF_PQ tf_pq(intensity); + const float f = rng.UniformF(0.0f, 1.0f); + const float actual = GetLane(tf_pq.EncodedFromDisplay(d, Set(d, f))); + const float expected = TF_PQ_Base::EncodedFromDisplay(intensity, f); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 5e-7) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +HWY_NOINLINE void TestHlgEncodedFromDisplay() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(0.0f, 1.0f); + const float actual = GetLane(TF_HLG().EncodedFromDisplay(d, Set(d, f))); + const float expected = TF_HLG_Base::EncodedFromDisplay(f); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 4e-7) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +HWY_NOINLINE void TestPqDisplayFromEncoded() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + double intensity = 11000.0 + rng.UniformF(-150.0f, 150.0f); + TF_PQ tf_pq(intensity); + const float f = rng.UniformF(0.0f, 1.0f); + const float actual = GetLane(tf_pq.DisplayFromEncoded(d, Set(d, f))); + const float expected = TF_PQ_Base::DisplayFromEncoded(intensity, f); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 3E-6) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class TransferFunctionsTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(TransferFunctionsTargetTest); + +HWY_EXPORT_AND_TEST_P(TransferFunctionsTargetTest, TestPqEncodedFromDisplay); +HWY_EXPORT_AND_TEST_P(TransferFunctionsTargetTest, TestHlgEncodedFromDisplay); +HWY_EXPORT_AND_TEST_P(TransferFunctionsTargetTest, TestPqDisplayFromEncoded); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/codec_in_out.h b/third_party/jpeg-xl/lib/jxl/codec_in_out.h new file mode 100644 index 0000000000..028f3ecaac --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/codec_in_out.h @@ -0,0 +1,115 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CODEC_IN_OUT_H_ +#define LIB_JXL_CODEC_IN_OUT_H_ + +// Holds inputs/outputs for decoding/encoding images. + +#include <stddef.h> +#include <stdint.h> + +#include <type_traits> +#include <utility> +#include <vector> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/luminance.h" + +namespace jxl { + +// Optional text/EXIF metadata. +struct Blobs { + std::vector<uint8_t> exif; + std::vector<uint8_t> iptc; + std::vector<uint8_t> jumbf; + std::vector<uint8_t> xmp; +}; + +// Holds a preview, a main image or one or more frames, plus the inputs/outputs +// to/from decoding/encoding. +class CodecInOut { + public: + CodecInOut() : preview_frame(&metadata.m) { + frames.reserve(1); + frames.emplace_back(&metadata.m); + } + + // Move-only. + CodecInOut(CodecInOut&&) = default; + CodecInOut& operator=(CodecInOut&&) = default; + + size_t LastStillFrame() const { + JXL_DASSERT(!frames.empty()); + size_t last = 0; + for (size_t i = 0; i < frames.size(); i++) { + last = i; + if (frames[i].duration > 0) break; + } + return last; + } + + ImageBundle& Main() { return frames[LastStillFrame()]; } + const ImageBundle& Main() const { return frames[LastStillFrame()]; } + + // If c_current.IsGray(), all planes must be identical. + void SetFromImage(Image3F&& color, const ColorEncoding& c_current) { + Main().SetFromImage(std::move(color), c_current); + SetIntensityTarget(&this->metadata.m); + SetSize(Main().xsize(), Main().ysize()); + } + + void SetSize(size_t xsize, size_t ysize) { + JXL_CHECK(metadata.size.Set(xsize, ysize)); + } + + void CheckMetadata() const { + JXL_CHECK(metadata.m.bit_depth.bits_per_sample != 0); + JXL_CHECK(!metadata.m.color_encoding.ICC().empty()); + + if (preview_frame.xsize() != 0) preview_frame.VerifyMetadata(); + JXL_CHECK(preview_frame.metadata() == &metadata.m); + + for (const ImageBundle& ib : frames) { + ib.VerifyMetadata(); + JXL_CHECK(ib.metadata() == &metadata.m); + } + } + + size_t xsize() const { return metadata.size.xsize(); } + size_t ysize() const { return metadata.size.ysize(); } + void ShrinkTo(size_t xsize, size_t ysize) { + // preview is unaffected. + for (ImageBundle& ib : frames) { + ib.ShrinkTo(xsize, ysize); + } + SetSize(xsize, ysize); + } + + // -- DECODER OUTPUT, ENCODER INPUT: + + // Metadata stored into / retrieved from bitstreams. + + Blobs blobs; + + CodecMetadata metadata; // applies to preview and all frames + + // If metadata.have_preview: + ImageBundle preview_frame; + + std::vector<ImageBundle> frames; // size=1 if !metadata.have_animation + + // If the image should be written to a JPEG, use this quality for encoding. + size_t jpeg_quality; +}; + +} // namespace jxl + +#endif // LIB_JXL_CODEC_IN_OUT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/coeff_order.cc b/third_party/jpeg-xl/lib/jxl/coeff_order.cc new file mode 100644 index 0000000000..296a7cb2f0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/coeff_order.cc @@ -0,0 +1,150 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/coeff_order.h" + +#include <stdint.h> + +#include <algorithm> +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/lehmer_code.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +uint32_t CoeffOrderContext(uint32_t val) { + uint32_t token, nbits, bits; + HybridUintConfig(0, 0, 0).Encode(val, &token, &nbits, &bits); + return std::min(token, kPermutationContexts - 1); +} + +namespace { +Status ReadPermutation(size_t skip, size_t size, coeff_order_t* order, + BitReader* br, ANSSymbolReader* reader, + const std::vector<uint8_t>& context_map) { + std::vector<LehmerT> lehmer(size); + // temp space needs to be as large as the next power of 2, so doubling the + // allocated size is enough. + std::vector<uint32_t> temp(size * 2); + uint32_t end = + reader->ReadHybridUint(CoeffOrderContext(size), br, context_map) + skip; + if (end > size) { + return JXL_FAILURE("Invalid permutation size"); + } + uint32_t last = 0; + for (size_t i = skip; i < end; ++i) { + lehmer[i] = + reader->ReadHybridUint(CoeffOrderContext(last), br, context_map); + last = lehmer[i]; + if (lehmer[i] >= size - i) { + return JXL_FAILURE("Invalid lehmer code"); + } + } + if (order == nullptr) return true; + DecodeLehmerCode(lehmer.data(), temp.data(), size, order); + return true; +} + +} // namespace + +Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order, + BitReader* br) { + std::vector<uint8_t> context_map; + ANSCode code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kPermutationContexts, &code, &context_map)); + ANSSymbolReader reader(&code, br); + JXL_RETURN_IF_ERROR( + ReadPermutation(skip, size, order, br, &reader, context_map)); + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("Invalid ANS stream"); + } + return true; +} + +namespace { + +Status DecodeCoeffOrder(AcStrategy acs, coeff_order_t* order, BitReader* br, + ANSSymbolReader* reader, + std::vector<coeff_order_t>& natural_order, + const std::vector<uint8_t>& context_map) { + const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y(); + const size_t size = kDCTBlockSize * llf; + + JXL_RETURN_IF_ERROR( + ReadPermutation(llf, size, order, br, reader, context_map)); + if (order == nullptr) return true; + for (size_t k = 0; k < size; ++k) { + order[k] = natural_order[order[k]]; + } + return true; +} + +} // namespace + +Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs, + coeff_order_t* order, BitReader* br) { + uint16_t computed = 0; + std::vector<uint8_t> context_map; + ANSCode code; + std::unique_ptr<ANSSymbolReader> reader; + std::vector<coeff_order_t> natural_order; + // Bitstream does not have histograms if no coefficient order is used. + if (used_orders != 0) { + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kPermutationContexts, &code, &context_map)); + reader = make_unique<ANSSymbolReader>(&code, br); + } + uint32_t acs_mask = 0; + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + if ((used_acs & (1 << o)) == 0) continue; + acs_mask |= 1 << kStrategyOrder[o]; + } + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + uint8_t ord = kStrategyOrder[o]; + if (computed & (1 << ord)) continue; + computed |= 1 << ord; + AcStrategy acs = AcStrategy::FromRawStrategy(o); + bool used = (acs_mask & (1 << ord)) != 0; + + const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y(); + const size_t size = kDCTBlockSize * llf; + + if (used || (used_orders & (1 << ord))) { + if (natural_order.size() < size) natural_order.resize(size); + acs.ComputeNaturalCoeffOrder(natural_order.data()); + } + + if ((used_orders & (1 << ord)) == 0) { + // No need to set the default order if no ACS uses this order. + if (used) { + for (size_t c = 0; c < 3; c++) { + memcpy(&order[CoeffOrderOffset(ord, c)], natural_order.data(), + size * sizeof(*order)); + } + } + } else { + for (size_t c = 0; c < 3; c++) { + coeff_order_t* dest = used ? &order[CoeffOrderOffset(ord, c)] : nullptr; + JXL_RETURN_IF_ERROR(DecodeCoeffOrder(acs, dest, br, reader.get(), + natural_order, context_map)); + } + } + } + if (used_orders && !reader->CheckANSFinalState()) { + return JXL_FAILURE("Invalid ANS stream"); + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/coeff_order.h b/third_party/jpeg-xl/lib/jxl/coeff_order.h new file mode 100644 index 0000000000..75f6f99e9f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/coeff_order.h @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COEFF_ORDER_H_ +#define LIB_JXL_COEFF_ORDER_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/frame_dimensions.h" + +namespace jxl { + +class BitReader; + +// Those offsets get multiplied by kDCTBlockSize. +static constexpr size_t kCoeffOrderOffset[] = { + 0, 1, 2, 3, 4, 5, 6, 10, 14, 18, + 34, 50, 66, 68, 70, 72, 76, 80, 84, 92, + 100, 108, 172, 236, 300, 332, 364, 396, 652, 908, + 1164, 1292, 1420, 1548, 2572, 3596, 4620, 5132, 5644, 6156, +}; +static_assert(3 * kNumOrders + 1 == + sizeof(kCoeffOrderOffset) / sizeof(*kCoeffOrderOffset), + "Update this array when adding or removing order types."); + +static constexpr size_t CoeffOrderOffset(size_t order, size_t c) { + return kCoeffOrderOffset[3 * order + c] * kDCTBlockSize; +} + +static constexpr size_t kCoeffOrderMaxSize = + kCoeffOrderOffset[3 * kNumOrders] * kDCTBlockSize; + +// Mapping from AC strategy to order bucket. Strategies with different natural +// orders must have different buckets. +constexpr uint8_t kStrategyOrder[] = { + 0, 1, 1, 1, 2, 3, 4, 4, 5, 5, 6, 6, 1, 1, + 1, 1, 1, 1, 7, 8, 8, 9, 10, 10, 11, 12, 12, +}; + +static_assert(AcStrategy::kNumValidStrategies == + sizeof(kStrategyOrder) / sizeof(*kStrategyOrder), + "Update this array when adding or removing AC strategies."); + +constexpr uint32_t kPermutationContexts = 8; + +uint32_t CoeffOrderContext(uint32_t val); + +Status DecodeCoeffOrders(uint16_t used_orders, uint32_t used_acs, + coeff_order_t* order, BitReader* br); + +Status DecodePermutation(size_t skip, size_t size, coeff_order_t* order, + BitReader* br); + +} // namespace jxl + +#endif // LIB_JXL_COEFF_ORDER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/coeff_order_fwd.h b/third_party/jpeg-xl/lib/jxl/coeff_order_fwd.h new file mode 100644 index 0000000000..26306575c1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/coeff_order_fwd.h @@ -0,0 +1,47 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COEFF_ORDER_FWD_H_ +#define LIB_JXL_COEFF_ORDER_FWD_H_ + +// Breaks circular dependency between ac_strategy and coeff_order. + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// Needs at least 16 bits. A 32-bit type speeds up DecodeAC by 2% at the cost of +// more memory. +using coeff_order_t = uint32_t; + +// Maximum number of orders to be used. Note that this needs to be multiplied by +// the number of channels. One per "size class" (plus one extra for DCT8), +// shared between transforms of size XxY and of size YxX. +constexpr uint8_t kNumOrders = 13; + +// DCT coefficients are laid out in such a way that the number of rows of +// coefficients is always the smaller coordinate. +JXL_INLINE constexpr size_t CoefficientRows(size_t rows, size_t columns) { + return rows < columns ? rows : columns; +} + +JXL_INLINE constexpr size_t CoefficientColumns(size_t rows, size_t columns) { + return rows < columns ? columns : rows; +} + +JXL_INLINE void CoefficientLayout(size_t* JXL_RESTRICT rows, + size_t* JXL_RESTRICT columns) { + size_t r = *rows; + size_t c = *columns; + *rows = CoefficientRows(r, c); + *columns = CoefficientColumns(r, c); +} + +} // namespace jxl + +#endif // LIB_JXL_COEFF_ORDER_FWD_H_ diff --git a/third_party/jpeg-xl/lib/jxl/coeff_order_test.cc b/third_party/jpeg-xl/lib/jxl/coeff_order_test.cc new file mode 100644 index 0000000000..a88dcfa274 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/coeff_order_test.cc @@ -0,0 +1,95 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/coeff_order.h" + +#include <algorithm> +#include <numeric> // iota +#include <utility> +#include <vector> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_coeff_order.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +void RoundtripPermutation(coeff_order_t* perm, coeff_order_t* out, size_t len, + size_t* size) { + BitWriter writer; + EncodePermutation(perm, 0, len, &writer, 0, nullptr); + writer.ZeroPadToByte(); + Status status = true; + { + BitReader reader(writer.GetSpan()); + BitReaderScopedCloser closer(&reader, &status); + ASSERT_TRUE(DecodePermutation(0, len, out, &reader)); + } + ASSERT_TRUE(status); + *size = writer.GetSpan().size(); +} + +enum Permutation { kIdentity, kFewSwaps, kFewSlides, kRandom }; + +constexpr size_t kSwaps = 32; + +void TestPermutation(Permutation kind, size_t len) { + std::vector<coeff_order_t> perm(len); + std::iota(perm.begin(), perm.end(), 0); + Rng rng(0); + if (kind == kFewSwaps) { + for (size_t i = 0; i < kSwaps; i++) { + size_t a = rng.UniformU(0, len - 1); + size_t b = rng.UniformU(0, len - 1); + std::swap(perm[a], perm[b]); + } + } + if (kind == kFewSlides) { + for (size_t i = 0; i < kSwaps; i++) { + size_t a = rng.UniformU(0, len - 1); + size_t b = rng.UniformU(0, len - 1); + size_t from = std::min(a, b); + size_t to = std::max(a, b); + size_t start = perm[from]; + for (size_t j = from; j < to; j++) { + perm[j] = perm[j + 1]; + } + perm[to] = start; + } + } + if (kind == kRandom) { + rng.Shuffle(perm.data(), perm.size()); + } + std::vector<coeff_order_t> out(len); + size_t size = 0; + RoundtripPermutation(perm.data(), out.data(), len, &size); + for (size_t idx = 0; idx < len; idx++) { + EXPECT_EQ(perm[idx], out[idx]); + } + printf("Encoded size: %" PRIuS "\n", size); +} + +TEST(CoeffOrderTest, IdentitySmall) { TestPermutation(kIdentity, 256); } +TEST(CoeffOrderTest, FewSlidesSmall) { TestPermutation(kFewSlides, 256); } +TEST(CoeffOrderTest, FewSwapsSmall) { TestPermutation(kFewSwaps, 256); } +TEST(CoeffOrderTest, RandomSmall) { TestPermutation(kRandom, 256); } + +TEST(CoeffOrderTest, IdentityMedium) { TestPermutation(kIdentity, 1 << 12); } +TEST(CoeffOrderTest, FewSlidesMedium) { TestPermutation(kFewSlides, 1 << 12); } +TEST(CoeffOrderTest, FewSwapsMedium) { TestPermutation(kFewSwaps, 1 << 12); } +TEST(CoeffOrderTest, RandomMedium) { TestPermutation(kRandom, 1 << 12); } + +TEST(CoeffOrderTest, IdentityBig) { TestPermutation(kIdentity, 1 << 16); } +TEST(CoeffOrderTest, FewSlidesBig) { TestPermutation(kFewSlides, 1 << 16); } +TEST(CoeffOrderTest, FewSwapsBig) { TestPermutation(kFewSwaps, 1 << 16); } +TEST(CoeffOrderTest, RandomBig) { TestPermutation(kRandom, 1 << 16); } + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc new file mode 100644 index 0000000000..19273dad3c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.cc @@ -0,0 +1,208 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/color_encoding_internal.h" + +#include <array> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/cms/color_encoding_cms.h" +#include "lib/jxl/cms/jxl_cms_internal.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/pack_signed.h" + +namespace jxl { + +bool CustomTransferFunction::SetImplicit() { + if (nonserialized_color_space == ColorSpace::kXYB) { + if (!storage_.SetGamma(1.0 / 3)) JXL_ASSERT(false); + return true; + } + return false; +} + +std::array<ColorEncoding, 2> ColorEncoding::CreateC2(Primaries pr, + TransferFunction tf) { + std::array<ColorEncoding, 2> c2; + + ColorEncoding* c_rgb = c2.data() + 0; + c_rgb->SetColorSpace(ColorSpace::kRGB); + c_rgb->storage_.white_point = WhitePoint::kD65; + c_rgb->storage_.primaries = pr; + c_rgb->storage_.tf.SetTransferFunction(tf); + JXL_CHECK(c_rgb->CreateICC()); + + ColorEncoding* c_gray = c2.data() + 1; + c_gray->SetColorSpace(ColorSpace::kGray); + c_gray->storage_.white_point = WhitePoint::kD65; + c_gray->storage_.primaries = pr; + c_gray->storage_.tf.SetTransferFunction(tf); + JXL_CHECK(c_gray->CreateICC()); + + return c2; +} + +const ColorEncoding& ColorEncoding::SRGB(bool is_gray) { + static std::array<ColorEncoding, 2> c2 = + CreateC2(Primaries::kSRGB, TransferFunction::kSRGB); + return c2[is_gray]; +} +const ColorEncoding& ColorEncoding::LinearSRGB(bool is_gray) { + static std::array<ColorEncoding, 2> c2 = + CreateC2(Primaries::kSRGB, TransferFunction::kLinear); + return c2[is_gray]; +} + +Status ColorEncoding::SetWhitePointType(const WhitePoint& wp) { + JXL_DASSERT(storage_.have_fields); + storage_.white_point = wp; + return true; +} + +Status ColorEncoding::SetPrimariesType(const Primaries& p) { + JXL_DASSERT(storage_.have_fields); + JXL_ASSERT(HasPrimaries()); + storage_.primaries = p; + return true; +} + +void ColorEncoding::DecideIfWantICC(const JxlCmsInterface& cms) { + if (storage_.icc.empty()) return; + + JxlColorEncoding c; + JXL_BOOL cmyk; + if (!cms.set_fields_from_icc(cms.set_fields_data, storage_.icc.data(), + storage_.icc.size(), &c, &cmyk)) { + return; + } + if (cmyk) return; + + std::vector<uint8_t> icc; + if (!MaybeCreateProfile(c, &icc)) return; + + want_icc_ = false; +} + +Customxy::Customxy() { Bundle::Init(this); } +Status Customxy::VisitFields(Visitor* JXL_RESTRICT visitor) { + uint32_t ux = PackSigned(storage_.x); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(19), BitsOffset(19, 524288), + BitsOffset(20, 1048576), + BitsOffset(21, 2097152), 0, &ux)); + storage_.x = UnpackSigned(ux); + uint32_t uy = PackSigned(storage_.y); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(19), BitsOffset(19, 524288), + BitsOffset(20, 1048576), + BitsOffset(21, 2097152), 0, &uy)); + storage_.y = UnpackSigned(uy); + return true; +} + +CustomTransferFunction::CustomTransferFunction() { Bundle::Init(this); } +Status CustomTransferFunction::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->Conditional(!SetImplicit())) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &storage_.have_gamma)); + + if (visitor->Conditional(storage_.have_gamma)) { + // Gamma is represented as a 24-bit int, the exponent used is + // gamma_ / 1e7. Valid values are (0, 1]. On the low end side, we also + // limit it to kMaxGamma/1e7. + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits( + 24, ::jxl::cms::CustomTransferFunction::kGammaMul, &storage_.gamma)); + if (storage_.gamma > ::jxl::cms::CustomTransferFunction::kGammaMul || + static_cast<uint64_t>(storage_.gamma) * + ::jxl::cms::CustomTransferFunction::kMaxGamma < + ::jxl::cms::CustomTransferFunction::kGammaMul) { + return JXL_FAILURE("Invalid gamma %u", storage_.gamma); + } + } + + if (visitor->Conditional(!storage_.have_gamma)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->Enum(TransferFunction::kSRGB, &storage_.transfer_function)); + } + } + + return true; +} + +ColorEncoding::ColorEncoding() { Bundle::Init(this); } +Status ColorEncoding::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &want_icc_)); + + // Always send even if want_icc_ because this affects decoding. + // We can skip the white point/primaries because they do not. + JXL_QUIET_RETURN_IF_ERROR( + visitor->Enum(ColorSpace::kRGB, &storage_.color_space)); + + if (visitor->Conditional(!WantICC())) { + // Serialize enums. NOTE: we set the defaults to the most common values so + // ImageMetadata.all_default is true in the common case. + + if (visitor->Conditional(!ImplicitWhitePoint())) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->Enum(WhitePoint::kD65, &storage_.white_point)); + if (visitor->Conditional(storage_.white_point == WhitePoint::kCustom)) { + white_.storage_ = storage_.white; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&white_)); + storage_.white = white_.storage_; + } + } + + if (visitor->Conditional(HasPrimaries())) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->Enum(Primaries::kSRGB, &storage_.primaries)); + if (visitor->Conditional(storage_.primaries == Primaries::kCustom)) { + red_.storage_ = storage_.red; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&red_)); + storage_.red = red_.storage_; + green_.storage_ = storage_.green; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&green_)); + storage_.green = green_.storage_; + blue_.storage_ = storage_.blue; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&blue_)); + storage_.blue = blue_.storage_; + } + } + + tf_.nonserialized_color_space = storage_.color_space; + tf_.storage_ = storage_.tf; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&tf_)); + storage_.tf = tf_.storage_; + + JXL_QUIET_RETURN_IF_ERROR( + visitor->Enum(RenderingIntent::kRelative, &storage_.rendering_intent)); + + // We didn't have ICC, so all fields should be known. + if (storage_.color_space == ColorSpace::kUnknown || + storage_.tf.IsUnknown()) { + return JXL_FAILURE( + "No ICC but cs %u and tf %u%s", + static_cast<unsigned int>(storage_.color_space), + storage_.tf.have_gamma + ? 0 + : static_cast<unsigned int>(storage_.tf.transfer_function), + storage_.tf.have_gamma ? "(gamma)" : ""); + } + + JXL_RETURN_IF_ERROR(CreateICC()); + } + + if (WantICC() && visitor->IsReading()) { + // Haven't called SetICC() yet, do nothing. + } else { + if (ICC().empty()) return JXL_FAILURE("Empty ICC"); + } + + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h new file mode 100644 index 0000000000..0a104a12b2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/color_encoding_internal.h @@ -0,0 +1,361 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COLOR_ENCODING_INTERNAL_H_ +#define LIB_JXL_COLOR_ENCODING_INTERNAL_H_ + +// Metadata for color space conversions. + +#include <jxl/cms_interface.h> +#include <jxl/color_encoding.h> +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <cstdlib> // free +#include <ostream> +#include <string> +#include <utility> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/cms/color_encoding_cms.h" +#include "lib/jxl/cms/jxl_cms_internal.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +using IccBytes = ::jxl::cms::IccBytes; +using ColorSpace = ::jxl::cms::ColorSpace; +using WhitePoint = ::jxl::cms::WhitePoint; +using Primaries = ::jxl::cms::Primaries; +using TransferFunction = ::jxl::cms::TransferFunction; +using RenderingIntent = ::jxl::cms::RenderingIntent; +using CIExy = ::jxl::cms::CIExy; +using PrimariesCIExy = ::jxl::cms::PrimariesCIExy; + +namespace cms { + +static inline const char* EnumName(ColorSpace /*unused*/) { + return "ColorSpace"; +} +static inline constexpr uint64_t EnumBits(ColorSpace /*unused*/) { + using CS = ColorSpace; + return MakeBit(CS::kRGB) | MakeBit(CS::kGray) | MakeBit(CS::kXYB) | + MakeBit(CS::kUnknown); +} + +static inline const char* EnumName(WhitePoint /*unused*/) { + return "WhitePoint"; +} +static inline constexpr uint64_t EnumBits(WhitePoint /*unused*/) { + return MakeBit(WhitePoint::kD65) | MakeBit(WhitePoint::kCustom) | + MakeBit(WhitePoint::kE) | MakeBit(WhitePoint::kDCI); +} + +static inline const char* EnumName(Primaries /*unused*/) { return "Primaries"; } +static inline constexpr uint64_t EnumBits(Primaries /*unused*/) { + using Pr = Primaries; + return MakeBit(Pr::kSRGB) | MakeBit(Pr::kCustom) | MakeBit(Pr::k2100) | + MakeBit(Pr::kP3); +} + +static inline const char* EnumName(TransferFunction /*unused*/) { + return "TransferFunction"; +} + +static inline constexpr uint64_t EnumBits(TransferFunction /*unused*/) { + using TF = TransferFunction; + return MakeBit(TF::k709) | MakeBit(TF::kLinear) | MakeBit(TF::kSRGB) | + MakeBit(TF::kPQ) | MakeBit(TF::kDCI) | MakeBit(TF::kHLG) | + MakeBit(TF::kUnknown); +} + +static inline const char* EnumName(RenderingIntent /*unused*/) { + return "RenderingIntent"; +} +static inline constexpr uint64_t EnumBits(RenderingIntent /*unused*/) { + using RI = RenderingIntent; + return MakeBit(RI::kPerceptual) | MakeBit(RI::kRelative) | + MakeBit(RI::kSaturation) | MakeBit(RI::kAbsolute); +} + +} // namespace cms + +struct ColorEncoding; + +// Serializable form of CIExy. +struct Customxy : public Fields { + Customxy(); + JXL_FIELDS_NAME(Customxy) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + private: + friend struct ColorEncoding; + ::jxl::cms::Customxy storage_; +}; + +struct CustomTransferFunction : public Fields { + CustomTransferFunction(); + JXL_FIELDS_NAME(CustomTransferFunction) + + // Sets fields and returns true if nonserialized_color_space has an implicit + // transfer function, otherwise leaves fields unchanged and returns false. + bool SetImplicit(); + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Must be set before calling VisitFields! + ColorSpace nonserialized_color_space = ColorSpace::kRGB; + + private: + friend struct ColorEncoding; + ::jxl::cms::CustomTransferFunction storage_; +}; + +// Compact encoding of data required to interpret and translate pixels to a +// known color space. Stored in Metadata. Thread-compatible. +struct ColorEncoding : public Fields { + ColorEncoding(); + JXL_FIELDS_NAME(ColorEncoding) + + // Returns ready-to-use color encodings (initialized on-demand). + static const ColorEncoding& SRGB(bool is_gray = false); + static const ColorEncoding& LinearSRGB(bool is_gray = false); + + // Returns true if an ICC profile was successfully created from fields. + // Must be called after modifying fields. Defined in color_management.cc. + Status CreateICC() { + storage_.icc.clear(); + const JxlColorEncoding external = ToExternal(); + if (!MaybeCreateProfile(external, &storage_.icc)) { + storage_.icc.clear(); + return JXL_FAILURE("Failed to create ICC profile"); + } + return true; + } + + // Returns non-empty and valid ICC profile, unless: + // - WantICC() == true and SetICC() was not yet called; + // - after a failed call to SetSRGB(), SetICC(), or CreateICC(). + const IccBytes& ICC() const { return storage_.icc; } + + // Returns true if `icc` is assigned and decoded successfully. If so, + // subsequent WantICC() will return true until DecideIfWantICC() changes it. + // Returning false indicates data has been lost. + Status SetICC(IccBytes&& icc, const JxlCmsInterface* cms) { + JXL_ASSERT(cms != nullptr); + JXL_ASSERT(!icc.empty()); + want_icc_ = storage_.SetFieldsFromICC(std::move(icc), *cms); + return want_icc_; + } + + // Sets the raw ICC profile bytes, without parsing the ICC, and without + // updating the direct fields such as whitepoint, primaries and color + // space. Functions to get and set fields, such as SetWhitePoint, cannot be + // used anymore after this and functions such as IsSRGB return false no matter + // what the contents of the icc profile. + void SetICCRaw(IccBytes&& icc) { + JXL_ASSERT(!icc.empty()); + storage_.icc = std::move(icc); + storage_.have_fields = false; + want_icc_ = true; + } + + // Returns whether to send the ICC profile in the codestream. + bool WantICC() const { return want_icc_; } + + // Return whether the direct fields are set, if false but ICC is set, only + // raw ICC bytes are known. + bool HaveFields() const { return storage_.have_fields; } + + // Causes WantICC() to return false if ICC() can be reconstructed from fields. + void DecideIfWantICC(const JxlCmsInterface& cms); + + bool IsGray() const { return storage_.color_space == ColorSpace::kGray; } + bool IsCMYK() const { return storage_.cmyk; } + size_t Channels() const { return storage_.Channels(); } + + // Returns false if the field is invalid and unusable. + bool HasPrimaries() const { return storage_.HasPrimaries(); } + + // Returns true after setting the field to a value defined by color_space, + // otherwise false and leaves the field unchanged. + bool ImplicitWhitePoint() { + // TODO(eustas): inline + if (storage_.color_space == ColorSpace::kXYB) { + storage_.white_point = WhitePoint::kD65; + return true; + } + return false; + } + + // Returns whether the color space is known to be sRGB. If a raw unparsed ICC + // profile is set without the fields being set, this returns false, even if + // the content of the ICC profile would match sRGB. + bool IsSRGB() const { + if (!storage_.have_fields) return false; + if (!IsGray() && storage_.color_space != ColorSpace::kRGB) return false; + if (storage_.white_point != WhitePoint::kD65) return false; + if (storage_.primaries != Primaries::kSRGB) return false; + if (!storage_.tf.IsSRGB()) return false; + return true; + } + + // Returns whether the color space is known to be linear sRGB. If a raw + // unparsed ICC profile is set without the fields being set, this returns + // false, even if the content of the ICC profile would match linear sRGB. + bool IsLinearSRGB() const { + if (!storage_.have_fields) return false; + if (!IsGray() && storage_.color_space != ColorSpace::kRGB) return false; + if (storage_.white_point != WhitePoint::kD65) return false; + if (storage_.primaries != Primaries::kSRGB) return false; + if (!storage_.tf.IsLinear()) return false; + return true; + } + + Status SetSRGB(const ColorSpace cs, + const RenderingIntent ri = RenderingIntent::kRelative) { + storage_.icc.clear(); + JXL_ASSERT(cs == ColorSpace::kGray || cs == ColorSpace::kRGB); + storage_.color_space = cs; + storage_.white_point = WhitePoint::kD65; + storage_.primaries = Primaries::kSRGB; + storage_.tf.transfer_function = TransferFunction::kSRGB; + storage_.rendering_intent = ri; + return CreateICC(); + } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Accessors ensure tf.nonserialized_color_space is updated at the same time. + ColorSpace GetColorSpace() const { return storage_.color_space; } + void SetColorSpace(const ColorSpace cs) { storage_.color_space = cs; } + CIExy GetWhitePoint() const { return storage_.GetWhitePoint(); } + + WhitePoint GetWhitePointType() const { return storage_.white_point; } + Status SetWhitePointType(const WhitePoint& wp); + PrimariesCIExy GetPrimaries() const { return storage_.GetPrimaries(); } + + Primaries GetPrimariesType() const { return storage_.primaries; } + Status SetPrimariesType(const Primaries& p); + + jxl::cms::CustomTransferFunction& Tf() { return storage_.tf; } + const jxl::cms::CustomTransferFunction& Tf() const { return storage_.tf; } + + RenderingIntent GetRenderingIntent() const { + return storage_.rendering_intent; + } + void SetRenderingIntent(const RenderingIntent& ri) { + storage_.rendering_intent = ri; + } + + bool SameColorEncoding(const ColorEncoding& other) const { + return storage_.SameColorEncoding(other.storage_); + } + + mutable bool all_default; + + JxlColorEncoding ToExternal() const { return storage_.ToExternal(); } + Status FromExternal(const JxlColorEncoding& external) { + JXL_RETURN_IF_ERROR(storage_.FromExternal(external)); + (void)CreateICC(); + return true; + } + const jxl::cms::ColorEncoding& View() const { return storage_; } + std::string Description() const; + + private: + static std::array<ColorEncoding, 2> CreateC2(Primaries pr, + TransferFunction tf); + + // If true, the codestream contains an ICC profile and we do not serialize + // fields. Otherwise, fields are serialized and we create an ICC profile. + bool want_icc_; + + ::jxl::cms::ColorEncoding storage_; + // Only used if white_point == kCustom. + Customxy white_; + + // Only valid if HaveFields() + CustomTransferFunction tf_; + + // Only used if primaries == kCustom. + Customxy red_; + Customxy green_; + Customxy blue_; +}; + +static inline std::string Description(const ColorEncoding& c) { + const JxlColorEncoding external = c.View().ToExternal(); + return ColorEncodingDescription(external); +} + +static inline std::ostream& operator<<(std::ostream& os, + const ColorEncoding& c) { + return os << Description(c); +} + +class ColorSpaceTransform { + public: + explicit ColorSpaceTransform(const JxlCmsInterface& cms) : cms_(cms) {} + ~ColorSpaceTransform() { + if (cms_data_ != nullptr) { + cms_.destroy(cms_data_); + } + } + + // Cannot copy. + ColorSpaceTransform(const ColorSpaceTransform&) = delete; + ColorSpaceTransform& operator=(const ColorSpaceTransform&) = delete; + + Status Init(const ColorEncoding& c_src, const ColorEncoding& c_dst, + float intensity_target, size_t xsize, size_t num_threads) { + xsize_ = xsize; + JxlColorProfile input_profile; + icc_src_ = c_src.ICC(); + input_profile.icc.data = icc_src_.data(); + input_profile.icc.size = icc_src_.size(); + input_profile.color_encoding = c_src.ToExternal(); + input_profile.num_channels = c_src.IsCMYK() ? 4 : c_src.Channels(); + JxlColorProfile output_profile; + icc_dst_ = c_dst.ICC(); + output_profile.icc.data = icc_dst_.data(); + output_profile.icc.size = icc_dst_.size(); + output_profile.color_encoding = c_dst.ToExternal(); + if (c_dst.IsCMYK()) + return JXL_FAILURE("Conversion to CMYK is not supported"); + output_profile.num_channels = c_dst.Channels(); + cms_data_ = cms_.init(cms_.init_data, num_threads, xsize, &input_profile, + &output_profile, intensity_target); + JXL_RETURN_IF_ERROR(cms_data_ != nullptr); + return true; + } + + float* BufSrc(const size_t thread) const { + return cms_.get_src_buf(cms_data_, thread); + } + + float* BufDst(const size_t thread) const { + return cms_.get_dst_buf(cms_data_, thread); + } + + Status Run(const size_t thread, const float* buf_src, float* buf_dst) { + return cms_.run(cms_data_, thread, buf_src, buf_dst, xsize_); + } + + private: + JxlCmsInterface cms_; + void* cms_data_ = nullptr; + // The interface may retain pointers into these. + IccBytes icc_src_; + IccBytes icc_dst_; + size_t xsize_; +}; + +} // namespace jxl + +#endif // LIB_JXL_COLOR_ENCODING_INTERNAL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/color_encoding_internal_test.cc b/third_party/jpeg-xl/lib/jxl/color_encoding_internal_test.cc new file mode 100644 index 0000000000..4d2d3e8119 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/color_encoding_internal_test.cc @@ -0,0 +1,155 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/color_encoding_internal.h" + +#include <jxl/color_encoding.h> + +#include <cstdlib> // rand + +#include "lib/jxl/cms/color_encoding_cms.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +using jxl::cms::ColorEncoding; + +TEST(ColorEncodingTest, RoundTripAll) { + for (const test::ColorEncodingDescriptor& cdesc : test::AllEncodings()) { + ColorEncoding c_original = test::ColorEncodingFromDescriptor(cdesc).View(); + // Verify Set(Get) yields the same white point/primaries/gamma. + { + ColorEncoding c; + EXPECT_TRUE(c.SetWhitePoint(c_original.GetWhitePoint())); + EXPECT_EQ(c_original.white_point, c.white_point); + } + { + ColorEncoding c; + EXPECT_TRUE(c.SetPrimaries(c_original.GetPrimaries())); + EXPECT_EQ(c_original.primaries, c.primaries); + } + if (c_original.tf.have_gamma) { + ColorEncoding c; + EXPECT_TRUE(c.tf.SetGamma(c_original.tf.GetGamma())); + EXPECT_TRUE(c_original.tf.IsSame(c.tf)); + } + } +} + +TEST(ColorEncodingTest, CustomWhitePoint) { + ColorEncoding c; + // Nonsensical values + CIExy xy_in; + xy_in.x = 0.8; + xy_in.y = 0.01; + EXPECT_TRUE(c.SetWhitePoint(xy_in)); + const CIExy xy = c.GetWhitePoint(); + + ColorEncoding c2; + EXPECT_TRUE(c2.SetWhitePoint(xy)); + EXPECT_TRUE(c.SameColorSpace(c2)); +} + +TEST(ColorEncodingTest, CustomPrimaries) { + ColorEncoding c; + PrimariesCIExy xy_in; + // Nonsensical values + xy_in.r.x = -0.01; + xy_in.r.y = 0.2; + xy_in.g.x = 0.4; + xy_in.g.y = 0.401; + xy_in.b.x = 1.1; + xy_in.b.y = -1.2; + EXPECT_TRUE(c.SetPrimaries(xy_in)); + const PrimariesCIExy xy = c.GetPrimaries(); + + ColorEncoding c2; + EXPECT_TRUE(c2.SetPrimaries(xy)); + EXPECT_TRUE(c.SameColorSpace(c2)); +} + +TEST(ColorEncodingTest, CustomGamma) { + ColorEncoding c; +#ifndef JXL_CRASH_ON_ERROR + EXPECT_FALSE(c.tf.SetGamma(0.0)); + EXPECT_FALSE(c.tf.SetGamma(-1E-6)); + EXPECT_FALSE(c.tf.SetGamma(1.001)); +#endif + EXPECT_TRUE(c.tf.SetGamma(1.0)); + EXPECT_FALSE(c.tf.have_gamma); + EXPECT_TRUE(c.tf.IsLinear()); + + EXPECT_TRUE(c.tf.SetGamma(0.123)); + EXPECT_TRUE(c.tf.have_gamma); + const double gamma = c.tf.GetGamma(); + + ColorEncoding c2; + EXPECT_TRUE(c2.tf.SetGamma(gamma)); + EXPECT_TRUE(c.SameColorEncoding(c2)); + EXPECT_TRUE(c2.tf.have_gamma); +} + +TEST(ColorEncodingTest, InternalExternalConversion) { + ColorEncoding source_internal; + ColorEncoding destination_internal; + + for (int i = 0; i < 100; i++) { + source_internal.color_space = static_cast<ColorSpace>(rand() % 4); + CIExy wp; + wp.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + wp.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + EXPECT_TRUE(source_internal.SetWhitePoint(wp)); + if (source_internal.HasPrimaries()) { + PrimariesCIExy primaries; + primaries.r.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.r.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.g.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.g.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.b.x = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + primaries.b.y = (float(rand()) / float((RAND_MAX)) * 0.5) + 0.25; + EXPECT_TRUE(source_internal.SetPrimaries(primaries)); + } + jxl::cms::CustomTransferFunction tf; + EXPECT_TRUE(tf.SetGamma((float(rand()) / float((RAND_MAX)) * 0.5) + 0.25)); + source_internal.tf = tf; + source_internal.rendering_intent = static_cast<RenderingIntent>(rand() % 4); + + JxlColorEncoding external = source_internal.ToExternal(); + EXPECT_TRUE(destination_internal.FromExternal(external)); + + EXPECT_EQ(source_internal.color_space, destination_internal.color_space); + EXPECT_EQ(source_internal.white_point, destination_internal.white_point); + CIExy src_wp = source_internal.GetWhitePoint(); + CIExy dst_wp = destination_internal.GetWhitePoint(); + EXPECT_EQ(src_wp.x, dst_wp.x); + EXPECT_EQ(src_wp.y, dst_wp.y); + if (source_internal.HasPrimaries()) { + PrimariesCIExy src_p = source_internal.GetPrimaries(); + PrimariesCIExy dst_p = destination_internal.GetPrimaries(); + EXPECT_EQ(src_p.r.x, dst_p.r.x); + EXPECT_EQ(src_p.r.y, dst_p.r.y); + EXPECT_EQ(src_p.g.x, dst_p.g.x); + EXPECT_EQ(src_p.g.y, dst_p.g.y); + EXPECT_EQ(src_p.b.x, dst_p.b.x); + EXPECT_EQ(src_p.b.y, dst_p.b.y); + } + EXPECT_EQ(source_internal.tf.have_gamma, + destination_internal.tf.have_gamma); + if (source_internal.tf.have_gamma) { + EXPECT_EQ(source_internal.tf.GetGamma(), + destination_internal.tf.GetGamma()); + } else { + EXPECT_EQ(source_internal.tf.GetTransferFunction(), + destination_internal.tf.GetTransferFunction()); + } + EXPECT_EQ(source_internal.rendering_intent, + destination_internal.rendering_intent); + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/color_management_test.cc b/third_party/jpeg-xl/lib/jxl/color_management_test.cc new file mode 100644 index 0000000000..ca50c9960e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/color_management_test.cc @@ -0,0 +1,469 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> +#include <jxl/cms_interface.h> +#include <stdint.h> + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <cstdlib> +#include <ostream> +#include <string> +#include <utility> +#include <vector> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/cms/color_encoding_cms.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { + +std::ostream& operator<<(std::ostream& os, const CIExy& xy) { + return os << "{x=" << xy.x << ", y=" << xy.y << "}"; +} + +std::ostream& operator<<(std::ostream& os, const PrimariesCIExy& primaries) { + return os << "{r=" << primaries.r << ", g=" << primaries.g + << ", b=" << primaries.b << "}"; +} + +namespace { + +using ::testing::ElementsAre; +using ::testing::FloatNear; + +// Small enough to be fast. If changed, must update Generate*. +static constexpr size_t kWidth = 16; + +static constexpr size_t kNumThreads = 1; // only have a single row. + +MATCHER_P(HasSameFieldsAs, expected, "") { + if (arg.GetRenderingIntent() != expected.GetRenderingIntent()) { + *result_listener << "which has a different rendering intent: " + << ToString(arg.GetRenderingIntent()) << " instead of " + << ToString(expected.GetRenderingIntent()); + return false; + } + if (arg.GetColorSpace() != expected.GetColorSpace()) { + *result_listener << "which has a different color space: " + << ToString(arg.GetColorSpace()) << " instead of " + << ToString(expected.GetColorSpace()); + return false; + } + if (arg.GetWhitePointType() != expected.GetWhitePointType()) { + *result_listener << "which has a different white point: " + << ToString(arg.GetWhitePointType()) << " instead of " + << ToString(expected.GetWhitePointType()); + return false; + } + if (arg.HasPrimaries() && + arg.GetPrimariesType() != expected.GetPrimariesType()) { + *result_listener << "which has different primaries: " + << ToString(arg.GetPrimariesType()) << " instead of " + << ToString(expected.GetPrimariesType()); + return false; + } + if (!arg.Tf().IsSame(expected.Tf())) { + static const auto tf_to_string = + [](const jxl::cms::CustomTransferFunction& tf) { + if (tf.have_gamma) { + return "g" + ToString(tf.GetGamma()); + } + return ToString(tf.transfer_function); + }; + *result_listener << "which has a different transfer function: " + << tf_to_string(arg.Tf()) << " instead of " + << tf_to_string(expected.Tf()); + return false; + } + return true; +} + +struct Globals { + // TODO(deymo): Make this a const. + static Globals* GetInstance() { + static Globals ret; + return &ret; + } + + private: + Globals() { + in_gray = GenerateGray(); + in_color = GenerateColor(); + out_gray = ImageF(kWidth, 1); + out_color = ImageF(kWidth * 3, 1); + + c_native = ColorEncoding::LinearSRGB(/*is_gray=*/false); + c_gray = ColorEncoding::LinearSRGB(/*is_gray=*/true); + } + + static ImageF GenerateGray() { + ImageF gray(kWidth, 1); + float* JXL_RESTRICT row = gray.Row(0); + // Increasing left to right + for (uint32_t x = 0; x < kWidth; ++x) { + row[x] = x * 1.0f / (kWidth - 1); // [0, 1] + } + return gray; + } + + static ImageF GenerateColor() { + ImageF image(kWidth * 3, 1); + float* JXL_RESTRICT interleaved = image.Row(0); + std::fill(interleaved, interleaved + kWidth * 3, 0.0f); + + // [0, 4): neutral + for (int32_t x = 0; x < 4; ++x) { + interleaved[3 * x + 0] = x * 1.0f / 3; // [0, 1] + interleaved[3 * x + 2] = interleaved[3 * x + 1] = interleaved[3 * x + 0]; + } + + // [4, 13): pure RGB with low/medium/high saturation + for (int32_t c = 0; c < 3; ++c) { + interleaved[3 * (4 + c) + c] = 0.08f + c * 0.01f; + interleaved[3 * (7 + c) + c] = 0.75f + c * 0.01f; + interleaved[3 * (10 + c) + c] = 1.0f; + } + + // [13, 16): impure, not quite saturated RGB + interleaved[3 * 13 + 0] = 0.86f; + interleaved[3 * 13 + 2] = interleaved[3 * 13 + 1] = 0.16f; + interleaved[3 * 14 + 1] = 0.87f; + interleaved[3 * 14 + 2] = interleaved[3 * 14 + 0] = 0.16f; + interleaved[3 * 15 + 2] = 0.88f; + interleaved[3 * 15 + 1] = interleaved[3 * 15 + 0] = 0.16f; + + return image; + } + + public: + // ImageF so we can use VerifyRelativeError; all are interleaved RGB. + ImageF in_gray; + ImageF in_color; + ImageF out_gray; + ImageF out_color; + ColorEncoding c_native; + ColorEncoding c_gray; +}; + +class ColorManagementTest + : public ::testing::TestWithParam<test::ColorEncodingDescriptor> { + public: + // "Same" pixels after converting g->c_native -> c -> g->c_native. + static void VerifyPixelRoundTrip(const ColorEncoding& c) { + Globals* g = Globals::GetInstance(); + const ColorEncoding& c_native = c.IsGray() ? g->c_gray : g->c_native; + const JxlCmsInterface& cms = *JxlGetDefaultCms(); + ColorSpaceTransform xform_fwd(cms); + ColorSpaceTransform xform_rev(cms); + const float intensity_target = + c.Tf().IsHLG() ? 1000 : kDefaultIntensityTarget; + ASSERT_TRUE( + xform_fwd.Init(c_native, c, intensity_target, kWidth, kNumThreads)); + ASSERT_TRUE( + xform_rev.Init(c, c_native, intensity_target, kWidth, kNumThreads)); + + const size_t thread = 0; + const ImageF& in = c.IsGray() ? g->in_gray : g->in_color; + ImageF* JXL_RESTRICT out = c.IsGray() ? &g->out_gray : &g->out_color; + ASSERT_TRUE(xform_fwd.Run(thread, in.Row(0), xform_fwd.BufDst(thread))); + ASSERT_TRUE(xform_rev.Run(thread, xform_fwd.BufDst(thread), out->Row(0))); + + // With lcms2, this value is lower: 5E-5 + double max_l1 = 7E-4; + // Most are lower; reached 3E-7 with D60 AP0. + double max_rel = 4E-7; + if (c.IsGray()) max_rel = 2E-5; + JXL_ASSERT_OK(VerifyRelativeError(in, *out, max_l1, max_rel, _)); + } +}; +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(ColorManagementTestInstantiation, + ColorManagementTest, + ::testing::ValuesIn(test::AllEncodings())); + +// Exercises the ColorManagement interface for ALL ColorEncoding synthesizable +// via enums. +TEST_P(ColorManagementTest, VerifyAllProfiles) { + ColorEncoding c = ColorEncodingFromDescriptor(GetParam()); + printf("%s\n", Description(c).c_str()); + + // Can create profile. + ASSERT_TRUE(c.CreateICC()); + + // Can set an equivalent ColorEncoding from the generated ICC profile. + ColorEncoding c3; + ASSERT_TRUE(c3.SetICC(IccBytes(c.ICC()), JxlGetDefaultCms())); + EXPECT_THAT(c3, HasSameFieldsAs(c)); + + VerifyPixelRoundTrip(c); +} + +testing::Matcher<CIExy> CIExyIs(const double x, const double y) { + static constexpr double kMaxError = 1e-4; + return testing::AllOf( + testing::Field(&CIExy::x, testing::DoubleNear(x, kMaxError)), + testing::Field(&CIExy::y, testing::DoubleNear(y, kMaxError))); +} + +testing::Matcher<PrimariesCIExy> PrimariesAre( + const testing::Matcher<CIExy>& r, const testing::Matcher<CIExy>& g, + const testing::Matcher<CIExy>& b) { + return testing::AllOf(testing::Field(&PrimariesCIExy::r, r), + testing::Field(&PrimariesCIExy::g, g), + testing::Field(&PrimariesCIExy::b, b)); +} + +TEST_F(ColorManagementTest, sRGBChromaticity) { + const ColorEncoding sRGB = ColorEncoding::SRGB(); + EXPECT_THAT(sRGB.GetWhitePoint(), CIExyIs(0.3127, 0.3290)); + EXPECT_THAT(sRGB.GetPrimaries(), + PrimariesAre(CIExyIs(0.64, 0.33), CIExyIs(0.30, 0.60), + CIExyIs(0.15, 0.06))); +} + +TEST_F(ColorManagementTest, D2700Chromaticity) { + std::vector<uint8_t> icc_data = + jxl::test::ReadTestData("jxl/color_management/sRGB-D2700.icc"); + IccBytes icc; + Bytes(icc_data).AppendTo(&icc); + ColorEncoding sRGB_D2700; + ASSERT_TRUE(sRGB_D2700.SetICC(std::move(icc), JxlGetDefaultCms())); + + EXPECT_THAT(sRGB_D2700.GetWhitePoint(), CIExyIs(0.45986, 0.41060)); + // The illuminant-relative chromaticities of this profile's primaries are the + // same as for sRGB. It is the PCS-relative chromaticities that would be + // different. + EXPECT_THAT(sRGB_D2700.GetPrimaries(), + PrimariesAre(CIExyIs(0.64, 0.33), CIExyIs(0.30, 0.60), + CIExyIs(0.15, 0.06))); +} + +TEST_F(ColorManagementTest, D2700ToSRGB) { + std::vector<uint8_t> icc_data = + jxl::test::ReadTestData("jxl/color_management/sRGB-D2700.icc"); + IccBytes icc; + Bytes(icc_data).AppendTo(&icc); + ColorEncoding sRGB_D2700; + ASSERT_TRUE(sRGB_D2700.SetICC(std::move(icc), JxlGetDefaultCms())); + + ColorSpaceTransform transform(*JxlGetDefaultCms()); + ASSERT_TRUE(transform.Init(sRGB_D2700, ColorEncoding::SRGB(), + kDefaultIntensityTarget, 1, 1)); + const float sRGB_D2700_values[3] = {0.863, 0.737, 0.490}; + float sRGB_values[3]; + ASSERT_TRUE(transform.Run(0, sRGB_D2700_values, sRGB_values)); + EXPECT_THAT(sRGB_values, + ElementsAre(FloatNear(0.914, 1e-3), FloatNear(0.745, 1e-3), + FloatNear(0.601, 1e-3))); +} + +TEST_F(ColorManagementTest, P3HlgTo2020Hlg) { + ColorEncoding p3_hlg; + p3_hlg.SetColorSpace(ColorSpace::kRGB); + ASSERT_TRUE(p3_hlg.SetWhitePointType(WhitePoint::kD65)); + ASSERT_TRUE(p3_hlg.SetPrimariesType(Primaries::kP3)); + p3_hlg.Tf().SetTransferFunction(TransferFunction::kHLG); + ASSERT_TRUE(p3_hlg.CreateICC()); + + ColorEncoding rec2020_hlg = p3_hlg; + ASSERT_TRUE(rec2020_hlg.SetPrimariesType(Primaries::k2100)); + ASSERT_TRUE(rec2020_hlg.CreateICC()); + + ColorSpaceTransform transform(*JxlGetDefaultCms()); + ASSERT_TRUE(transform.Init(p3_hlg, rec2020_hlg, 1000, 1, 1)); + const float p3_hlg_values[3] = {0., 0.75, 0.}; + float rec2020_hlg_values[3]; + ASSERT_TRUE(transform.Run(0, p3_hlg_values, rec2020_hlg_values)); + EXPECT_THAT(rec2020_hlg_values, + ElementsAre(FloatNear(0.3973, 1e-4), FloatNear(0.7382, 1e-4), + FloatNear(0.1183, 1e-4))); +} + +TEST_F(ColorManagementTest, HlgOotf) { + ColorEncoding p3_hlg; + p3_hlg.SetColorSpace(ColorSpace::kRGB); + ASSERT_TRUE(p3_hlg.SetWhitePointType(WhitePoint::kD65)); + ASSERT_TRUE(p3_hlg.SetPrimariesType(Primaries::kP3)); + p3_hlg.Tf().SetTransferFunction(TransferFunction::kHLG); + ASSERT_TRUE(p3_hlg.CreateICC()); + + ColorSpaceTransform transform_to_1000(*JxlGetDefaultCms()); + ASSERT_TRUE( + transform_to_1000.Init(p3_hlg, ColorEncoding::LinearSRGB(), 1000, 1, 1)); + // HDR reference white: https://www.itu.int/pub/R-REP-BT.2408-4-2021 + float p3_hlg_values[3] = {0.75, 0.75, 0.75}; + float linear_srgb_values[3]; + ASSERT_TRUE(transform_to_1000.Run(0, p3_hlg_values, linear_srgb_values)); + // On a 1000-nit display, HDR reference white should be 203 cd/m² which is + // 0.203 times the maximum. + EXPECT_THAT(linear_srgb_values, + ElementsAre(FloatNear(0.203, 1e-3), FloatNear(0.203, 1e-3), + FloatNear(0.203, 1e-3))); + + ColorSpaceTransform transform_to_400(*JxlGetDefaultCms()); + ASSERT_TRUE( + transform_to_400.Init(p3_hlg, ColorEncoding::LinearSRGB(), 400, 1, 1)); + ASSERT_TRUE(transform_to_400.Run(0, p3_hlg_values, linear_srgb_values)); + // On a 400-nit display, it should be 100 cd/m². + EXPECT_THAT(linear_srgb_values, + ElementsAre(FloatNear(0.250, 1e-3), FloatNear(0.250, 1e-3), + FloatNear(0.250, 1e-3))); + + p3_hlg_values[2] = 0.50; + ASSERT_TRUE(transform_to_1000.Run(0, p3_hlg_values, linear_srgb_values)); + EXPECT_THAT(linear_srgb_values, + ElementsAre(FloatNear(0.201, 1e-3), FloatNear(0.201, 1e-3), + FloatNear(0.050, 1e-3))); + + ColorSpaceTransform transform_from_400(*JxlGetDefaultCms()); + ASSERT_TRUE( + transform_from_400.Init(ColorEncoding::LinearSRGB(), p3_hlg, 400, 1, 1)); + linear_srgb_values[0] = linear_srgb_values[1] = linear_srgb_values[2] = 0.250; + ASSERT_TRUE(transform_from_400.Run(0, linear_srgb_values, p3_hlg_values)); + EXPECT_THAT(p3_hlg_values, + ElementsAre(FloatNear(0.75, 1e-3), FloatNear(0.75, 1e-3), + FloatNear(0.75, 1e-3))); + + ColorEncoding grayscale_hlg; + grayscale_hlg.SetColorSpace(ColorSpace::kGray); + ASSERT_TRUE(grayscale_hlg.SetWhitePointType(WhitePoint::kD65)); + grayscale_hlg.Tf().SetTransferFunction(TransferFunction::kHLG); + ASSERT_TRUE(grayscale_hlg.CreateICC()); + + ColorSpaceTransform grayscale_transform(*JxlGetDefaultCms()); + ASSERT_TRUE(grayscale_transform.Init( + grayscale_hlg, ColorEncoding::LinearSRGB(/*is_gray=*/true), 1000, 1, 1)); + const float grayscale_hlg_value = 0.75; + float linear_grayscale_value; + ASSERT_TRUE(grayscale_transform.Run(0, &grayscale_hlg_value, + &linear_grayscale_value)); + EXPECT_THAT(linear_grayscale_value, FloatNear(0.203, 1e-3)); +} + +TEST_F(ColorManagementTest, XYBProfile) { + ColorEncoding c_xyb; + c_xyb.SetColorSpace(ColorSpace::kXYB); + c_xyb.SetRenderingIntent(RenderingIntent::kPerceptual); + ASSERT_TRUE(c_xyb.CreateICC()); + ColorEncoding c_native = ColorEncoding::LinearSRGB(false); + + static const size_t kGridDim = 17; + static const size_t kNumColors = kGridDim * kGridDim * kGridDim; + const JxlCmsInterface& cms = *JxlGetDefaultCms(); + ColorSpaceTransform xform(cms); + ASSERT_TRUE( + xform.Init(c_xyb, c_native, kDefaultIntensityTarget, kNumColors, 1)); + + ImageMetadata metadata; + metadata.color_encoding = c_native; + ImageBundle ib(&metadata); + Image3F native(kNumColors, 1); + float mul = 1.0f / (kGridDim - 1); + for (size_t ir = 0, x = 0; ir < kGridDim; ++ir) { + for (size_t ig = 0; ig < kGridDim; ++ig) { + for (size_t ib = 0; ib < kGridDim; ++ib, ++x) { + native.PlaneRow(0, 0)[x] = ir * mul; + native.PlaneRow(1, 0)[x] = ig * mul; + native.PlaneRow(2, 0)[x] = ib * mul; + } + } + } + ib.SetFromImage(std::move(native), c_native); + const Image3F& in = *ib.color(); + Image3F opsin(kNumColors, 1); + ToXYB(ib, nullptr, &opsin, cms, nullptr); + + Image3F opsin2(kNumColors, 1); + CopyImageTo(opsin, &opsin2); + ScaleXYB(&opsin2); + + float* src = xform.BufSrc(0); + for (size_t i = 0; i < kNumColors; ++i) { + for (size_t c = 0; c < 3; ++c) { + src[3 * i + c] = opsin2.PlaneRow(c, 0)[i]; + } + } + + float* dst = xform.BufDst(0); + ASSERT_TRUE(xform.Run(0, src, dst)); + + Image3F out(kNumColors, 1); + for (size_t i = 0; i < kNumColors; ++i) { + for (size_t c = 0; c < 3; ++c) { + out.PlaneRow(c, 0)[i] = dst[3 * i + c]; + } + } + + auto debug_print_color = [&](size_t i) { + printf( + "(%f, %f, %f) -> (%9.6f, %f, %f) -> (%f, %f, %f) -> " + "(%9.6f, %9.6f, %9.6f)", + in.PlaneRow(0, 0)[i], in.PlaneRow(1, 0)[i], in.PlaneRow(2, 0)[i], + opsin.PlaneRow(0, 0)[i], opsin.PlaneRow(1, 0)[i], + opsin.PlaneRow(2, 0)[i], opsin2.PlaneRow(0, 0)[i], + opsin2.PlaneRow(1, 0)[i], opsin2.PlaneRow(2, 0)[i], + out.PlaneRow(0, 0)[i], out.PlaneRow(1, 0)[i], out.PlaneRow(2, 0)[i]); + }; + + float max_err[3] = {}; + size_t max_err_i[3] = {}; + for (size_t i = 0; i < kNumColors; ++i) { + for (size_t c = 0; c < 3; ++c) { + // debug_print_color(i); printf("\n"); + float err = std::abs(in.PlaneRow(c, 0)[i] - out.PlaneRow(c, 0)[i]); + if (err > max_err[c]) { + max_err[c] = err; + max_err_i[c] = i; + } + } + } + static float kMaxError[3] = {9e-4, 4e-4, 5e-4}; + printf("Maximum errors:\n"); + for (size_t c = 0; c < 3; ++c) { + debug_print_color(max_err_i[c]); + printf(" %f\n", max_err[c]); + EXPECT_LT(max_err[c], kMaxError[c]); + } +} + +TEST_F(ColorManagementTest, GoldenXYBCube) { + std::vector<int32_t> actual; + const jxl::cms::ColorCube3D& cube = jxl::cms::UnscaledA2BCube(); + for (size_t ix = 0; ix < 2; ++ix) { + for (size_t iy = 0; iy < 2; ++iy) { + for (size_t ib = 0; ib < 2; ++ib) { + const jxl::cms::ColorCube0D& out_f = cube[ix][iy][ib]; + for (int i = 0; i < 3; ++i) { + int32_t val = static_cast<int32_t>(0.5f + 65535 * out_f[i]); + ASSERT_TRUE(val >= 0 && val <= 65535); + actual.push_back(val); + } + } + } + } + + std::vector<int32_t> expected = {0, 3206, 0, 0, 3206, 28873, + 62329, 65535, 36662, 62329, 65535, 65535, + 3206, 0, 0, 3206, 0, 28873, + 65535, 62329, 36662, 65535, 62329, 65535}; + EXPECT_EQ(actual, expected); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/common.h b/third_party/jpeg-xl/lib/jxl/common.h new file mode 100644 index 0000000000..d619711c9f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/common.h @@ -0,0 +1,38 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COMMON_H_ +#define LIB_JXL_COMMON_H_ + +// Shared constants. + +#include <cstddef> + +#ifndef JXL_HIGH_PRECISION +#define JXL_HIGH_PRECISION 1 +#endif + +// Macro that defines whether support for decoding JXL files to JPEG is enabled. +#ifndef JPEGXL_ENABLE_TRANSCODE_JPEG +#define JPEGXL_ENABLE_TRANSCODE_JPEG 1 +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +// Macro that defines whether support for decoding boxes is enabled. +#ifndef JPEGXL_ENABLE_BOXES +#define JPEGXL_ENABLE_BOXES 1 +#endif // JPEGXL_ENABLE_BOXES + +namespace jxl { +// Some enums and typedefs used by more than one header file. + +// Maximum number of passes in an image. +constexpr size_t kMaxNumPasses = 11; + +// Maximum number of reference frames. +constexpr size_t kMaxNumReferenceFrames = 4; + +} // namespace jxl + +#endif // LIB_JXL_COMMON_H_ diff --git a/third_party/jpeg-xl/lib/jxl/compressed_dc.cc b/third_party/jpeg-xl/lib/jxl/compressed_dc.cc new file mode 100644 index 0000000000..b21b1da18b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/compressed_dc.cc @@ -0,0 +1,313 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/compressed_dc.h" + +#include <stdint.h> +#include <stdlib.h> +#include <string.h> + +#include <algorithm> +#include <array> +#include <memory> +#include <utility> +#include <vector> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/compressed_dc.cc" +#include <hwy/aligned_allocator.h> +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/image.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +using D = HWY_FULL(float); +using DScalar = HWY_CAPPED(float, 1); + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +// TODO(veluca): optimize constants. +const float w1 = 0.20345139757231578f; +const float w2 = 0.0334829185968739f; +const float w0 = 1.0f - 4.0f * (w1 + w2); + +template <class V> +V MaxWorkaround(V a, V b) { +#if (HWY_TARGET == HWY_AVX3) && HWY_COMPILER_CLANG <= 800 + // Prevents "Do not know how to split the result of this operator" error + return IfThenElse(a > b, a, b); +#else + return Max(a, b); +#endif +} + +template <typename D> +JXL_INLINE void ComputePixelChannel(const D d, const float dc_factor, + const float* JXL_RESTRICT row_top, + const float* JXL_RESTRICT row, + const float* JXL_RESTRICT row_bottom, + Vec<D>* JXL_RESTRICT mc, + Vec<D>* JXL_RESTRICT sm, + Vec<D>* JXL_RESTRICT gap, size_t x) { + const auto tl = LoadU(d, row_top + x - 1); + const auto tc = Load(d, row_top + x); + const auto tr = LoadU(d, row_top + x + 1); + + const auto ml = LoadU(d, row + x - 1); + *mc = Load(d, row + x); + const auto mr = LoadU(d, row + x + 1); + + const auto bl = LoadU(d, row_bottom + x - 1); + const auto bc = Load(d, row_bottom + x); + const auto br = LoadU(d, row_bottom + x + 1); + + const auto w_center = Set(d, w0); + const auto w_side = Set(d, w1); + const auto w_corner = Set(d, w2); + + const auto corner = Add(Add(tl, tr), Add(bl, br)); + const auto side = Add(Add(ml, mr), Add(tc, bc)); + *sm = MulAdd(corner, w_corner, MulAdd(side, w_side, Mul(*mc, w_center))); + + const auto dc_quant = Set(d, dc_factor); + *gap = MaxWorkaround(*gap, Abs(Div(Sub(*mc, *sm), dc_quant))); +} + +template <typename D> +JXL_INLINE void ComputePixel( + const float* JXL_RESTRICT dc_factors, + const float* JXL_RESTRICT* JXL_RESTRICT rows_top, + const float* JXL_RESTRICT* JXL_RESTRICT rows, + const float* JXL_RESTRICT* JXL_RESTRICT rows_bottom, + float* JXL_RESTRICT* JXL_RESTRICT out_rows, size_t x) { + const D d; + auto mc_x = Undefined(d); + auto mc_y = Undefined(d); + auto mc_b = Undefined(d); + auto sm_x = Undefined(d); + auto sm_y = Undefined(d); + auto sm_b = Undefined(d); + auto gap = Set(d, 0.5f); + ComputePixelChannel(d, dc_factors[0], rows_top[0], rows[0], rows_bottom[0], + &mc_x, &sm_x, &gap, x); + ComputePixelChannel(d, dc_factors[1], rows_top[1], rows[1], rows_bottom[1], + &mc_y, &sm_y, &gap, x); + ComputePixelChannel(d, dc_factors[2], rows_top[2], rows[2], rows_bottom[2], + &mc_b, &sm_b, &gap, x); + auto factor = MulAdd(Set(d, -4.0f), gap, Set(d, 3.0f)); + factor = ZeroIfNegative(factor); + + auto out = MulAdd(Sub(sm_x, mc_x), factor, mc_x); + Store(out, d, out_rows[0] + x); + out = MulAdd(Sub(sm_y, mc_y), factor, mc_y); + Store(out, d, out_rows[1] + x); + out = MulAdd(Sub(sm_b, mc_b), factor, mc_b); + Store(out, d, out_rows[2] + x); +} + +void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc, + ThreadPool* pool) { + const size_t xsize = dc->xsize(); + const size_t ysize = dc->ysize(); + if (ysize <= 2 || xsize <= 2) return; + + // TODO(veluca): use tile-based processing? + // TODO(veluca): decide if changes to the y channel should be propagated to + // the x and b channels through color correlation. + JXL_ASSERT(w1 + w2 < 0.25f); + + Image3F smoothed(xsize, ysize); + // Fill in borders that the loop below will not. First and last are unused. + for (size_t c = 0; c < 3; c++) { + for (size_t y : {size_t(0), ysize - 1}) { + memcpy(smoothed.PlaneRow(c, y), dc->PlaneRow(c, y), + xsize * sizeof(float)); + } + } + auto process_row = [&](const uint32_t y, size_t /*thread*/) { + const float* JXL_RESTRICT rows_top[3]{ + dc->ConstPlaneRow(0, y - 1), + dc->ConstPlaneRow(1, y - 1), + dc->ConstPlaneRow(2, y - 1), + }; + const float* JXL_RESTRICT rows[3] = { + dc->ConstPlaneRow(0, y), + dc->ConstPlaneRow(1, y), + dc->ConstPlaneRow(2, y), + }; + const float* JXL_RESTRICT rows_bottom[3] = { + dc->ConstPlaneRow(0, y + 1), + dc->ConstPlaneRow(1, y + 1), + dc->ConstPlaneRow(2, y + 1), + }; + float* JXL_RESTRICT rows_out[3] = { + smoothed.PlaneRow(0, y), + smoothed.PlaneRow(1, y), + smoothed.PlaneRow(2, y), + }; + for (size_t x : {size_t(0), xsize - 1}) { + for (size_t c = 0; c < 3; c++) { + rows_out[c][x] = rows[c][x]; + } + } + + size_t x = 1; + // First pixels + const size_t N = Lanes(D()); + for (; x < std::min(N, xsize - 1); x++) { + ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out, + x); + } + // Full vectors. + for (; x + N <= xsize - 1; x += N) { + ComputePixel<D>(dc_factors, rows_top, rows, rows_bottom, rows_out, x); + } + // Last pixels. + for (; x < xsize - 1; x++) { + ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out, + x); + } + }; + JXL_CHECK(RunOnPool(pool, 1, ysize - 1, ThreadPool::NoInit, process_row, + "DCSmoothingRow")); + dc->Swap(smoothed); +} + +// DC dequantization. +void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, + const float* dc_factors, float mul, const float* cfl_factors, + YCbCrChromaSubsampling chroma_subsampling, + const BlockCtxMap& bctx) { + const HWY_FULL(float) df; + const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float + if (chroma_subsampling.Is444()) { + const auto fac_x = Set(df, dc_factors[0] * mul); + const auto fac_y = Set(df, dc_factors[1] * mul); + const auto fac_b = Set(df, dc_factors[2] * mul); + const auto cfl_fac_x = Set(df, cfl_factors[0]); + const auto cfl_fac_b = Set(df, cfl_factors[2]); + for (size_t y = 0; y < r.ysize(); y++) { + float* dec_row_x = r.PlaneRow(dc, 0, y); + float* dec_row_y = r.PlaneRow(dc, 1, y); + float* dec_row_b = r.PlaneRow(dc, 2, y); + const int32_t* quant_row_x = in.channel[1].plane.Row(y); + const int32_t* quant_row_y = in.channel[0].plane.Row(y); + const int32_t* quant_row_b = in.channel[2].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x += Lanes(di)) { + const auto in_q_x = Load(di, quant_row_x + x); + const auto in_q_y = Load(di, quant_row_y + x); + const auto in_q_b = Load(di, quant_row_b + x); + const auto in_x = Mul(ConvertTo(df, in_q_x), fac_x); + const auto in_y = Mul(ConvertTo(df, in_q_y), fac_y); + const auto in_b = Mul(ConvertTo(df, in_q_b), fac_b); + Store(in_y, df, dec_row_y + x); + Store(MulAdd(in_y, cfl_fac_x, in_x), df, dec_row_x + x); + Store(MulAdd(in_y, cfl_fac_b, in_b), df, dec_row_b + x); + } + } + } else { + for (size_t c : {1, 0, 2}) { + Rect rect(r.x0() >> chroma_subsampling.HShift(c), + r.y0() >> chroma_subsampling.VShift(c), + r.xsize() >> chroma_subsampling.HShift(c), + r.ysize() >> chroma_subsampling.VShift(c)); + const auto fac = Set(df, dc_factors[c] * mul); + const Channel& ch = in.channel[c < 2 ? c ^ 1 : c]; + for (size_t y = 0; y < rect.ysize(); y++) { + const int32_t* quant_row = ch.plane.Row(y); + float* row = rect.PlaneRow(dc, c, y); + for (size_t x = 0; x < rect.xsize(); x += Lanes(di)) { + const auto in_q = Load(di, quant_row + x); + const auto in = Mul(ConvertTo(df, in_q), fac); + Store(in, df, row + x); + } + } + } + } + if (bctx.num_dc_ctxs <= 1) { + for (size_t y = 0; y < r.ysize(); y++) { + uint8_t* qdc_row = r.Row(quant_dc, y); + memset(qdc_row, 0, sizeof(*qdc_row) * r.xsize()); + } + } else { + for (size_t y = 0; y < r.ysize(); y++) { + uint8_t* qdc_row_val = r.Row(quant_dc, y); + const int32_t* quant_row_x = + in.channel[1].plane.Row(y >> chroma_subsampling.VShift(0)); + const int32_t* quant_row_y = + in.channel[0].plane.Row(y >> chroma_subsampling.VShift(1)); + const int32_t* quant_row_b = + in.channel[2].plane.Row(y >> chroma_subsampling.VShift(2)); + for (size_t x = 0; x < r.xsize(); x++) { + int bucket_x = 0, bucket_y = 0, bucket_b = 0; + for (int t : bctx.dc_thresholds[0]) { + if (quant_row_x[x >> chroma_subsampling.HShift(0)] > t) bucket_x++; + } + for (int t : bctx.dc_thresholds[1]) { + if (quant_row_y[x >> chroma_subsampling.HShift(1)] > t) bucket_y++; + } + for (int t : bctx.dc_thresholds[2]) { + if (quant_row_b[x >> chroma_subsampling.HShift(2)] > t) bucket_b++; + } + int bucket = bucket_x; + bucket *= bctx.dc_thresholds[2].size() + 1; + bucket += bucket_b; + bucket *= bctx.dc_thresholds[1].size() + 1; + bucket += bucket_y; + qdc_row_val[x] = bucket; + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(DequantDC); +HWY_EXPORT(AdaptiveDCSmoothing); +void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc, + ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(AdaptiveDCSmoothing)(dc_factors, dc, pool); +} + +void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, + const float* dc_factors, float mul, const float* cfl_factors, + YCbCrChromaSubsampling chroma_subsampling, + const BlockCtxMap& bctx) { + return HWY_DYNAMIC_DISPATCH(DequantDC)(r, dc, quant_dc, in, dc_factors, mul, + cfl_factors, chroma_subsampling, bctx); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/compressed_dc.h b/third_party/jpeg-xl/lib/jxl/compressed_dc.h new file mode 100644 index 0000000000..b06e5931f0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/compressed_dc.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_COMPRESSED_DC_H_ +#define LIB_JXL_COMPRESSED_DC_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/modular_image.h" + +// DC handling functions: encoding and decoding of DC to and from bitstream, and +// related function to initialize the per-group decoder cache. + +namespace jxl { + +// Smooth DC in already-smooth areas, to counteract banding. +void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc, + ThreadPool* pool); + +void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, + const float* dc_factors, float mul, const float* cfl_factors, + YCbCrChromaSubsampling chroma_subsampling, + const BlockCtxMap& bctx); + +} // namespace jxl + +#endif // LIB_JXL_COMPRESSED_DC_H_ diff --git a/third_party/jpeg-xl/lib/jxl/convolve-inl.h b/third_party/jpeg-xl/lib/jxl/convolve-inl.h new file mode 100644 index 0000000000..cd79153a3a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve-inl.h @@ -0,0 +1,295 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_CONVOLVE_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_CONVOLVE_INL_H_ +#undef LIB_JXL_CONVOLVE_INL_H_ +#else +#define LIB_JXL_CONVOLVE_INL_H_ +#endif + +#include <hwy/highway.h> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/image_ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Broadcast; +#if HWY_TARGET != HWY_SCALAR +using hwy::HWY_NAMESPACE::CombineShiftRightBytes; +#endif +using hwy::HWY_NAMESPACE::TableLookupLanes; +using hwy::HWY_NAMESPACE::Vec; + +// Synthesizes left/right neighbors from a vector of center pixels. +class Neighbors { + public: + using D = HWY_CAPPED(float, 16); + using V = Vec<D>; + + // Returns l[i] == c[Mirror(i - 1)]. + HWY_INLINE HWY_MAYBE_UNUSED static V FirstL1(const V c) { +#if HWY_CAP_GE256 + const D d; + HWY_ALIGN constexpr int32_t lanes[16] = {0, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14}; + const auto indices = SetTableIndices(d, lanes); + // c = PONM'LKJI + return TableLookupLanes(c, indices); // ONML'KJII +#elif HWY_TARGET == HWY_SCALAR + return c; // Same (the first mirrored value is the last valid one) +#else // 128 bit + // c = LKJI +#if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) + return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(2, 1, 0, 0))}; // KJII +#else + const D d; + // TODO(deymo): Figure out if this can be optimized using a single vsri + // instruction to convert LKJI to KJII. + HWY_ALIGN constexpr int lanes[4] = {0, 0, 1, 2}; // KJII + const auto indices = SetTableIndices(d, lanes); + return TableLookupLanes(c, indices); +#endif +#endif + } + + // Returns l[i] == c[Mirror(i - 2)]. + HWY_INLINE HWY_MAYBE_UNUSED static V FirstL2(const V c) { +#if HWY_CAP_GE256 + const D d; + HWY_ALIGN constexpr int32_t lanes[16] = {1, 0, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13}; + const auto indices = SetTableIndices(d, lanes); + // c = PONM'LKJI + return TableLookupLanes(c, indices); // NMLK'JIIJ +#elif HWY_TARGET == HWY_SCALAR + const D d; + JXL_ASSERT(false); // unsupported, avoid calling this. + return Zero(d); +#else // 128 bit + // c = LKJI +#if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) + return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(1, 0, 0, 1))}; // JIIJ +#else + const D d; + HWY_ALIGN constexpr int lanes[4] = {1, 0, 0, 1}; // JIIJ + const auto indices = SetTableIndices(d, lanes); + return TableLookupLanes(c, indices); +#endif +#endif + } + + // Returns l[i] == c[Mirror(i - 3)]. + HWY_INLINE HWY_MAYBE_UNUSED static V FirstL3(const V c) { +#if HWY_CAP_GE256 + const D d; + HWY_ALIGN constexpr int32_t lanes[16] = {2, 1, 0, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12}; + const auto indices = SetTableIndices(d, lanes); + // c = PONM'LKJI + return TableLookupLanes(c, indices); // MLKJ'IIJK +#elif HWY_TARGET == HWY_SCALAR + const D d; + JXL_ASSERT(false); // unsupported, avoid calling this. + return Zero(d); +#else // 128 bit + // c = LKJI +#if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) + return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(0, 0, 1, 2))}; // IIJK +#else + const D d; + HWY_ALIGN constexpr int lanes[4] = {2, 1, 0, 0}; // IIJK + const auto indices = SetTableIndices(d, lanes); + return TableLookupLanes(c, indices); +#endif +#endif + } +}; + +#if HWY_TARGET != HWY_SCALAR + +// Returns indices for SetTableIndices such that TableLookupLanes on the +// rightmost unaligned vector (rightmost sample in its most-significant lane) +// returns the mirrored values, with the mirror outside the last valid sample. +static inline const int32_t* MirrorLanes(const size_t mod) { + const HWY_CAPPED(float, 16) d; + constexpr size_t kN = MaxLanes(d); + + // For mod = `image width mod 16` 0..15: + // last full vec mirrored (mem order) loadedVec mirrorVec idxVec + // 0123456789abcdef| fedcba9876543210 fed..210 012..def 012..def + // 0123456789abcdef|0 0fedcba98765432 0fe..321 234..f00 123..eff + // 0123456789abcdef|01 10fedcba987654 10f..432 456..110 234..ffe + // 0123456789abcdef|012 210fedcba9876 210..543 67..2210 34..ffed + // 0123456789abcdef|0123 3210fedcba98 321..654 8..33210 4..ffedc + // 0123456789abcdef|01234 43210fedcba + // 0123456789abcdef|012345 543210fedc + // 0123456789abcdef|0123456 6543210fe + // 0123456789abcdef|01234567 76543210 + // 0123456789abcdef|012345678 8765432 + // 0123456789abcdef|0123456789 987654 + // 0123456789abcdef|0123456789A A9876 + // 0123456789abcdef|0123456789AB BA98 + // 0123456789abcdef|0123456789ABC CBA + // 0123456789abcdef|0123456789ABCD DC + // 0123456789abcdef|0123456789ABCDE E EDC..10f EED..210 ffe..321 +#if HWY_CAP_GE512 + HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, // + 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; +#elif HWY_CAP_GE256 + HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { + 1, 2, 3, 4, 5, 6, 7, 7, // + 6, 5, 4, 3, 2, 1, 0}; +#else // 128-bit + HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = {1, 2, 3, 3, // + 2, 1, 0}; +#endif + return idx_lanes + kN - 1 - mod; +} + +#endif // HWY_TARGET != HWY_SCALAR + +// Single entry point for convolution. +// "Strategy" (Direct*/Separable*) decides kernel size and how to evaluate it. +template <class Strategy> +class ConvolveT { + static constexpr int64_t kRadius = Strategy::kRadius; + using Simd = HWY_CAPPED(float, 16); + + public: + static size_t MinWidth() { +#if HWY_TARGET == HWY_SCALAR + // First/Last use mirrored loads of up to +/- kRadius. + return 2 * kRadius; +#else + return Lanes(Simd()) + kRadius; +#endif + } + + // "Image" is ImageF or Image3F. + template <class Image, class Weights> + static void Run(const Image& in, const Rect& rect, const Weights& weights, + ThreadPool* pool, Image* out) { + JXL_CHECK(SameSize(rect, *out)); + JXL_CHECK(rect.xsize() >= MinWidth()); + + static_assert(int64_t(kRadius) <= 3, + "Must handle [0, kRadius) and >= kRadius"); + switch (rect.xsize() % Lanes(Simd())) { + case 0: + return RunRows<0>(in, rect, weights, pool, out); + case 1: + return RunRows<1>(in, rect, weights, pool, out); + case 2: + return RunRows<2>(in, rect, weights, pool, out); + default: + return RunRows<3>(in, rect, weights, pool, out); + } + } + + private: + template <size_t kSizeModN, class WrapRow, class Weights> + static JXL_INLINE void RunRow(const float* JXL_RESTRICT in, + const size_t xsize, const int64_t stride, + const WrapRow& wrap_row, const Weights& weights, + float* JXL_RESTRICT out) { + Strategy::template ConvolveRow<kSizeModN>(in, xsize, stride, wrap_row, + weights, out); + } + + template <size_t kSizeModN, class Weights> + static JXL_INLINE void RunBorderRows(const ImageF& in, const Rect& rect, + const int64_t ybegin, const int64_t yend, + const Weights& weights, ImageF* out) { + const int64_t stride = in.PixelsPerRow(); + const WrapRowMirror wrap_row(in, rect.ysize()); + for (int64_t y = ybegin; y < yend; ++y) { + RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride, wrap_row, + weights, out->Row(y)); + } + } + + // Image3F. + template <size_t kSizeModN, class Weights> + static JXL_INLINE void RunBorderRows(const Image3F& in, const Rect& rect, + const int64_t ybegin, const int64_t yend, + const Weights& weights, Image3F* out) { + const int64_t stride = in.PixelsPerRow(); + for (int64_t y = ybegin; y < yend; ++y) { + for (size_t c = 0; c < 3; ++c) { + const WrapRowMirror wrap_row(in.Plane(c), rect.ysize()); + RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(), stride, + wrap_row, weights, out->PlaneRow(c, y)); + } + } + } + + template <size_t kSizeModN, class Weights> + static JXL_INLINE void RunInteriorRows(const ImageF& in, const Rect& rect, + const int64_t ybegin, + const int64_t yend, + const Weights& weights, + ThreadPool* pool, ImageF* out) { + const int64_t stride = in.PixelsPerRow(); + JXL_CHECK(RunOnPool( + pool, ybegin, yend, ThreadPool::NoInit, + [&](const uint32_t y, size_t /*thread*/) HWY_ATTR { + RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride, + WrapRowUnchanged(), weights, out->Row(y)); + }, + "Convolve")); + } + + // Image3F. + template <size_t kSizeModN, class Weights> + static JXL_INLINE void RunInteriorRows(const Image3F& in, const Rect& rect, + const int64_t ybegin, + const int64_t yend, + const Weights& weights, + ThreadPool* pool, Image3F* out) { + const int64_t stride = in.PixelsPerRow(); + JXL_CHECK(RunOnPool( + pool, ybegin, yend, ThreadPool::NoInit, + [&](const uint32_t y, size_t /*thread*/) HWY_ATTR { + for (size_t c = 0; c < 3; ++c) { + RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(), + stride, WrapRowUnchanged(), weights, + out->PlaneRow(c, y)); + } + }, + "Convolve3")); + } + + template <size_t kSizeModN, class Image, class Weights> + static JXL_INLINE void RunRows(const Image& in, const Rect& rect, + const Weights& weights, ThreadPool* pool, + Image* out) { + const int64_t ysize = rect.ysize(); + RunBorderRows<kSizeModN>(in, rect, 0, std::min(int64_t(kRadius), ysize), + weights, out); + if (ysize > 2 * int64_t(kRadius)) { + RunInteriorRows<kSizeModN>(in, rect, int64_t(kRadius), + ysize - int64_t(kRadius), weights, pool, out); + } + if (ysize > int64_t(kRadius)) { + RunBorderRows<kSizeModN>(in, rect, ysize - int64_t(kRadius), ysize, + weights, out); + } + } +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_CONVOLVE_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/convolve.h b/third_party/jpeg-xl/lib/jxl/convolve.h new file mode 100644 index 0000000000..5231ae2640 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve.h @@ -0,0 +1,88 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_CONVOLVE_H_ +#define LIB_JXL_CONVOLVE_H_ + +// 2D convolution. + +#include <stddef.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// No valid values outside [0, xsize), but the strategy may still safely load +// the preceding vector, and/or round xsize up to the vector lane count. This +// avoids needing PadImage. +// Requires xsize >= kConvolveLanes + kConvolveMaxRadius. +static constexpr size_t kConvolveMaxRadius = 3; + +// Weights must already be normalized. + +struct WeightsSymmetric3 { + // d r d (each replicated 4x) + // r c r + // d r d + float c[4]; + float r[4]; + float d[4]; +}; + +struct WeightsSymmetric5 { + // The lower-right quadrant is: c r R (each replicated 4x) + // r d L + // R L D + float c[4]; + float r[4]; + float R[4]; + float d[4]; + float D[4]; + float L[4]; +}; + +// Weights for separable 5x5 filters (typically but not necessarily the same +// values for horizontal and vertical directions). The kernel must already be +// normalized, but note that values for negative offsets are omitted, so the +// given values do not sum to 1. +struct WeightsSeparable5 { + // Horizontal 1D, distances 0..2 (each replicated 4x) + float horz[3 * 4]; + float vert[3 * 4]; +}; + +const WeightsSymmetric3& WeightsSymmetric3Lowpass(); +const WeightsSeparable5& WeightsSeparable5Lowpass(); +const WeightsSymmetric5& WeightsSymmetric5Lowpass(); + +void SlowSymmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out); + +void SlowSeparable5(const ImageF& in, const Rect& in_rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out, const Rect& out_rect); + +void Symmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* out); + +void Symmetric5(const ImageF& in, const Rect& in_rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out, const Rect& out_rect); + +void Symmetric5(const ImageF& in, const Rect& rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out); + +void Separable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out); + +} // namespace jxl + +#endif // LIB_JXL_CONVOLVE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/convolve_separable5.cc b/third_party/jpeg-xl/lib/jxl/convolve_separable5.cc new file mode 100644 index 0000000000..db533606a1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve_separable5.cc @@ -0,0 +1,261 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve_separable5.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/convolve-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Vec; + +// 5x5 convolution by separable kernel with a single scan through the input. +// This is more cache-efficient than separate horizontal/vertical passes, and +// possibly faster (given enough registers) than tiling and/or transposing. +// +// Overview: imagine a 5x5 window around a central pixel. First convolve the +// rows by multiplying the pixels with the corresponding weights from +// WeightsSeparable5.horz[abs(x_offset) * 4]. Then multiply each of these +// intermediate results by the corresponding vertical weight, i.e. +// vert[abs(y_offset) * 4]. Finally, store the sum of these values as the +// convolution result at the position of the central pixel in the output. +// +// Each of these operations uses SIMD vectors. The central pixel and most +// importantly the output are aligned, so neighnoring pixels (e.g. x_offset=1) +// require unaligned loads. Because weights are supplied in identical groups of +// 4, we can use LoadDup128 to load them (slightly faster). +// +// Uses mirrored boundary handling. Until x >= kRadius, the horizontal +// convolution uses Neighbors class to shuffle vectors as if each of its lanes +// had been loaded from the mirrored offset. Similarly, the last full vector to +// write uses mirroring. In the case of scalar vectors, Neighbors is not usable +// and the value is loaded directly. Otherwise, the number of valid pixels +// modulo the vector size enables a small optimization: for smaller offsets, +// a non-mirrored load is sufficient. +class Separable5Strategy { + using D = HWY_CAPPED(float, 16); + using V = Vec<D>; + + public: + static constexpr int64_t kRadius = 2; + + template <size_t kSizeModN, class WrapRow> + static JXL_MAYBE_INLINE void ConvolveRow( + const float* const JXL_RESTRICT row_m, const size_t xsize, + const int64_t stride, const WrapRow& wrap_row, + const WeightsSeparable5& weights, float* const JXL_RESTRICT row_out) { + const D d; + const int64_t neg_stride = -stride; // allows LEA addressing. + const float* const JXL_RESTRICT row_t2 = + wrap_row(row_m + 2 * neg_stride, stride); + const float* const JXL_RESTRICT row_t1 = + wrap_row(row_m + 1 * neg_stride, stride); + const float* const JXL_RESTRICT row_b1 = + wrap_row(row_m + 1 * stride, stride); + const float* const JXL_RESTRICT row_b2 = + wrap_row(row_m + 2 * stride, stride); + + const V wh0 = LoadDup128(d, weights.horz + 0 * 4); + const V wh1 = LoadDup128(d, weights.horz + 1 * 4); + const V wh2 = LoadDup128(d, weights.horz + 2 * 4); + const V wv0 = LoadDup128(d, weights.vert + 0 * 4); + const V wv1 = LoadDup128(d, weights.vert + 1 * 4); + const V wv2 = LoadDup128(d, weights.vert + 2 * 4); + + size_t x = 0; + + // More than one iteration for scalars. + for (; x < kRadius; x += Lanes(d)) { + const V conv0 = + Mul(HorzConvolveFirst(row_m, x, xsize, wh0, wh1, wh2), wv0); + + const V conv1t = HorzConvolveFirst(row_t1, x, xsize, wh0, wh1, wh2); + const V conv1b = HorzConvolveFirst(row_b1, x, xsize, wh0, wh1, wh2); + const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0); + + const V conv2t = HorzConvolveFirst(row_t2, x, xsize, wh0, wh1, wh2); + const V conv2b = HorzConvolveFirst(row_b2, x, xsize, wh0, wh1, wh2); + const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1); + Store(conv2, d, row_out + x); + } + + // Main loop: load inputs without padding + for (; x + Lanes(d) + kRadius <= xsize; x += Lanes(d)) { + const V conv0 = Mul(HorzConvolve(row_m + x, wh0, wh1, wh2), wv0); + + const V conv1t = HorzConvolve(row_t1 + x, wh0, wh1, wh2); + const V conv1b = HorzConvolve(row_b1 + x, wh0, wh1, wh2); + const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0); + + const V conv2t = HorzConvolve(row_t2 + x, wh0, wh1, wh2); + const V conv2b = HorzConvolve(row_b2 + x, wh0, wh1, wh2); + const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1); + Store(conv2, d, row_out + x); + } + + // Last full vector to write (the above loop handled mod >= kRadius) +#if HWY_TARGET == HWY_SCALAR + while (x < xsize) { +#else + if (kSizeModN < kRadius) { +#endif + const V conv0 = + Mul(HorzConvolveLast<kSizeModN>(row_m, x, xsize, wh0, wh1, wh2), wv0); + + const V conv1t = + HorzConvolveLast<kSizeModN>(row_t1, x, xsize, wh0, wh1, wh2); + const V conv1b = + HorzConvolveLast<kSizeModN>(row_b1, x, xsize, wh0, wh1, wh2); + const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0); + + const V conv2t = + HorzConvolveLast<kSizeModN>(row_t2, x, xsize, wh0, wh1, wh2); + const V conv2b = + HorzConvolveLast<kSizeModN>(row_b2, x, xsize, wh0, wh1, wh2); + const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1); + Store(conv2, d, row_out + x); + x += Lanes(d); + } + + // If mod = 0, the above vector was the last. + if (kSizeModN != 0) { + for (; x < xsize; ++x) { + float mul = 0.0f; + for (int64_t dy = -kRadius; dy <= kRadius; ++dy) { + const float wy = weights.vert[std::abs(dy) * 4]; + const float* clamped_row = wrap_row(row_m + dy * stride, stride); + for (int64_t dx = -kRadius; dx <= kRadius; ++dx) { + const float wx = weights.horz[std::abs(dx) * 4]; + const int64_t clamped_x = Mirror(x + dx, xsize); + mul += clamped_row[clamped_x] * wx * wy; + } + } + row_out[x] = mul; + } + } + } + + private: + // Same as HorzConvolve for the first/last vector in a row. + static JXL_MAYBE_INLINE V HorzConvolveFirst( + const float* const JXL_RESTRICT row, const int64_t x, const int64_t xsize, + const V wh0, const V wh1, const V wh2) { + const D d; + const V c = LoadU(d, row + x); + const V mul0 = Mul(c, wh0); + +#if HWY_TARGET == HWY_SCALAR + const V l1 = LoadU(d, row + Mirror(x - 1, xsize)); + const V l2 = LoadU(d, row + Mirror(x - 2, xsize)); +#else + (void)xsize; + const V l1 = Neighbors::FirstL1(c); + const V l2 = Neighbors::FirstL2(c); +#endif + + const V r1 = LoadU(d, row + x + 1); + const V r2 = LoadU(d, row + x + 2); + + const V mul1 = MulAdd(Add(l1, r1), wh1, mul0); + const V mul2 = MulAdd(Add(l2, r2), wh2, mul1); + return mul2; + } + + template <size_t kSizeModN> + static JXL_MAYBE_INLINE V + HorzConvolveLast(const float* const JXL_RESTRICT row, const int64_t x, + const int64_t xsize, const V wh0, const V wh1, const V wh2) { + const D d; + const V c = LoadU(d, row + x); + const V mul0 = Mul(c, wh0); + + const V l1 = LoadU(d, row + x - 1); + const V l2 = LoadU(d, row + x - 2); + + V r1, r2; +#if HWY_TARGET == HWY_SCALAR + r1 = LoadU(d, row + Mirror(x + 1, xsize)); + r2 = LoadU(d, row + Mirror(x + 2, xsize)); +#else + const size_t N = Lanes(d); + if (kSizeModN == 0) { + r2 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 2))); + r1 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 1))); + } else { // == 1 + const auto last = LoadU(d, row + xsize - N); + r2 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 1))); + r1 = last; + } +#endif + + // Sum of pixels with Manhattan distance i, multiplied by weights[i]. + const V sum1 = Add(l1, r1); + const V mul1 = MulAdd(sum1, wh1, mul0); + const V sum2 = Add(l2, r2); + const V mul2 = MulAdd(sum2, wh2, mul1); + return mul2; + } + + // Requires kRadius valid pixels before/after pos. + static JXL_MAYBE_INLINE V HorzConvolve(const float* const JXL_RESTRICT pos, + const V wh0, const V wh1, + const V wh2) { + const D d; + const V c = LoadU(d, pos); + const V mul0 = Mul(c, wh0); + + // Loading anew is faster than combining vectors. + const V l1 = LoadU(d, pos - 1); + const V r1 = LoadU(d, pos + 1); + const V l2 = LoadU(d, pos - 2); + const V r2 = LoadU(d, pos + 2); + // Sum of pixels with Manhattan distance i, multiplied by weights[i]. + const V sum1 = Add(l1, r1); + const V mul1 = MulAdd(sum1, wh1, mul0); + const V sum2 = Add(l2, r2); + const V mul2 = MulAdd(sum2, wh2, mul1); + return mul2; + } +}; + +void Separable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out) { + using Conv = ConvolveT<Separable5Strategy>; + if (rect.xsize() >= Conv::MinWidth()) { + return Conv::Run(in, rect, weights, pool, out); + } + + return SlowSeparable5(in, rect, weights, pool, out, Rect(*out)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(Separable5); +void Separable5(const ImageF& in, const Rect& rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out) { + return HWY_DYNAMIC_DISPATCH(Separable5)(in, rect, weights, pool, out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/convolve_slow.cc b/third_party/jpeg-xl/lib/jxl/convolve_slow.cc new file mode 100644 index 0000000000..655e040885 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve_slow.cc @@ -0,0 +1,198 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#include "lib/jxl/convolve-inl.h" + +namespace jxl { + +//------------------------------------------------------------------------------ +// Kernels + +// 4 instances of a given literal value, useful as input to LoadDup128. +#define JXL_REP4(literal) literal, literal, literal, literal + +// Concentrates energy in low-frequency components (e.g. for antialiasing). +const WeightsSymmetric3& WeightsSymmetric3Lowpass() { + // Computed by research/convolve_weights.py's cubic spline approximations of + // prolate spheroidal wave functions. + constexpr float w0 = 0.36208932f; + constexpr float w1 = 0.12820096f; + constexpr float w2 = 0.03127668f; + static constexpr WeightsSymmetric3 weights = { + {JXL_REP4(w0)}, {JXL_REP4(w1)}, {JXL_REP4(w2)}}; + return weights; +} + +const WeightsSeparable5& WeightsSeparable5Lowpass() { + constexpr float w0 = 0.41714928f; + constexpr float w1 = 0.25539268f; + constexpr float w2 = 0.03603267f; + static constexpr WeightsSeparable5 weights = { + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}, + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}}; + return weights; +} + +const WeightsSymmetric5& WeightsSymmetric5Lowpass() { + static constexpr WeightsSymmetric5 weights = { + {JXL_REP4(0.1740135f)}, {JXL_REP4(0.1065369f)}, {JXL_REP4(0.0150310f)}, + {JXL_REP4(0.0652254f)}, {JXL_REP4(0.0012984f)}, {JXL_REP4(0.0092025f)}}; + return weights; +} + +const WeightsSeparable5& WeightsSeparable5Gaussian1() { + constexpr float w0 = 0.38774f; + constexpr float w1 = 0.24477f; + constexpr float w2 = 0.06136f; + static constexpr WeightsSeparable5 weights = { + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}, + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}}; + return weights; +} + +const WeightsSeparable5& WeightsSeparable5Gaussian2() { + constexpr float w0 = 0.250301f; + constexpr float w1 = 0.221461f; + constexpr float w2 = 0.153388f; + static constexpr WeightsSeparable5 weights = { + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}, + {JXL_REP4(w0), JXL_REP4(w1), JXL_REP4(w2)}}; + return weights; +} + +#undef JXL_REP4 + +//------------------------------------------------------------------------------ +// Slow + +namespace { + +template <class WrapX, class WrapY> +float SlowSymmetric3Pixel(const ImageF& in, const int64_t ix, const int64_t iy, + const int64_t xsize, const int64_t ysize, + const WeightsSymmetric3& weights) { + float sum = 0.0f; + + // ix: image; kx: kernel + for (int64_t ky = -1; ky <= 1; ky++) { + const int64_t y = WrapY()(iy + ky, ysize); + const float* JXL_RESTRICT row_in = in.ConstRow(static_cast<size_t>(y)); + + const float wc = ky == 0 ? weights.c[0] : weights.r[0]; + const float wlr = ky == 0 ? weights.r[0] : weights.d[0]; + + const int64_t xm1 = WrapX()(ix - 1, xsize); + const int64_t xp1 = WrapX()(ix + 1, xsize); + sum += row_in[ix] * wc + (row_in[xm1] + row_in[xp1]) * wlr; + } + return sum; +} + +template <class WrapY> +void SlowSymmetric3Row(const ImageF& in, const int64_t iy, const int64_t xsize, + const int64_t ysize, const WeightsSymmetric3& weights, + float* JXL_RESTRICT row_out) { + row_out[0] = + SlowSymmetric3Pixel<WrapMirror, WrapY>(in, 0, iy, xsize, ysize, weights); + for (int64_t ix = 1; ix < xsize - 1; ix++) { + row_out[ix] = SlowSymmetric3Pixel<WrapUnchanged, WrapY>(in, ix, iy, xsize, + ysize, weights); + } + { + const int64_t ix = xsize - 1; + row_out[ix] = SlowSymmetric3Pixel<WrapMirror, WrapY>(in, ix, iy, xsize, + ysize, weights); + } +} + +} // namespace + +void SlowSymmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out) { + const int64_t xsize = static_cast<int64_t>(rect.xsize()); + const int64_t ysize = static_cast<int64_t>(rect.ysize()); + const int64_t kRadius = 1; + + JXL_CHECK(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t iy = task; + float* JXL_RESTRICT out_row = out->Row(static_cast<size_t>(iy)); + + if (iy < kRadius || iy >= ysize - kRadius) { + SlowSymmetric3Row<WrapMirror>(in, iy, xsize, ysize, weights, out_row); + } else { + SlowSymmetric3Row<WrapUnchanged>(in, iy, xsize, ysize, weights, + out_row); + } + }, + "SlowSymmetric3")); +} + +namespace { + +// Separable kernels, any radius. +float SlowSeparablePixel(const ImageF& in, const Rect& rect, const int64_t x, + const int64_t y, const int64_t radius, + const float* JXL_RESTRICT horz_weights, + const float* JXL_RESTRICT vert_weights) { + const size_t xsize = in.xsize(); + const size_t ysize = in.ysize(); + const WrapMirror wrap; + + float mul = 0.0f; + for (int dy = -radius; dy <= radius; ++dy) { + const float wy = vert_weights[std::abs(dy) * 4]; + const size_t sy = wrap(rect.y0() + y + dy, ysize); + JXL_CHECK(sy < ysize); + const float* const JXL_RESTRICT row = in.ConstRow(sy); + for (int dx = -radius; dx <= radius; ++dx) { + const float wx = horz_weights[std::abs(dx) * 4]; + const size_t sx = wrap(rect.x0() + x + dx, xsize); + JXL_CHECK(sx < xsize); + mul += row[sx] * wx * wy; + } + } + return mul; +} + +template <int R, typename Weights> +void SlowSeparable(const ImageF& in, const Rect& in_rect, + const Weights& weights, ThreadPool* pool, ImageF* out, + const Rect& out_rect) { + JXL_ASSERT(in_rect.xsize() == out_rect.xsize()); + JXL_ASSERT(in_rect.ysize() == out_rect.ysize()); + JXL_ASSERT(in_rect.IsInside(Rect(in))); + JXL_ASSERT(out_rect.IsInside(Rect(*out))); + const float* horz_weights = &weights.horz[0]; + const float* vert_weights = &weights.vert[0]; + + const size_t ysize = in_rect.ysize(); + JXL_CHECK(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + + float* const JXL_RESTRICT row_out = out_rect.Row(out, y); + for (size_t x = 0; x < in_rect.xsize(); ++x) { + row_out[x] = SlowSeparablePixel(in, in_rect, x, y, /*radius=*/R, + horz_weights, vert_weights); + } + }, + "SlowSeparable")); +} + +} // namespace + +void SlowSeparable5(const ImageF& in, const Rect& in_rect, + const WeightsSeparable5& weights, ThreadPool* pool, + ImageF* out, const Rect& out_rect) { + SlowSeparable<2>(in, in_rect, weights, pool, out, out_rect); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/convolve_symmetric3.cc b/third_party/jpeg-xl/lib/jxl/convolve_symmetric3.cc new file mode 100644 index 0000000000..06b59dfb60 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve_symmetric3.cc @@ -0,0 +1,194 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve_symmetric3.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/convolve-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Vec; + +template <class WrapY, class V> +static V WeightedSum(const ImageF& in, const WrapY wrap_y, const size_t ix, + const int64_t iy, const size_t ysize, const V wx0, + const V wx1, const V wx2) { + const HWY_FULL(float) d; + const float* JXL_RESTRICT center = in.ConstRow(wrap_y(iy, ysize)) + ix; + const auto in_m2 = LoadU(d, center - 2); + const auto in_p2 = LoadU(d, center + 2); + const auto in_m1 = LoadU(d, center - 1); + const auto in_p1 = LoadU(d, center + 1); + const auto in_00 = Load(d, center); + const auto sum_2 = Mul(wx2, Add(in_m2, in_p2)); + const auto sum_1 = Mul(wx1, Add(in_m1, in_p1)); + const auto sum_0 = Mul(wx0, in_00); + return Add(sum_2, Add(sum_1, sum_0)); +} + +// 3x3 convolution by symmetric kernel with a single scan through the input. +class Symmetric3Strategy { + using D = HWY_CAPPED(float, 16); + using V = Vec<D>; + + public: + static constexpr int64_t kRadius = 1; + + // Only accesses pixels in [0, xsize). + template <size_t kSizeModN, class WrapRow> + static JXL_MAYBE_INLINE void ConvolveRow( + const float* const JXL_RESTRICT row_m, const size_t xsize, + const int64_t stride, const WrapRow& wrap_row, + const WeightsSymmetric3& weights, float* const JXL_RESTRICT row_out) { + const D d; + // t, m, b = top, middle, bottom row; + const float* const JXL_RESTRICT row_t = wrap_row(row_m - stride, stride); + const float* const JXL_RESTRICT row_b = wrap_row(row_m + stride, stride); + + // Must load in advance - compiler doesn't understand LoadDup128 and + // schedules them too late. + const V w0 = LoadDup128(d, weights.c); + const V w1 = LoadDup128(d, weights.r); + const V w2 = LoadDup128(d, weights.d); + + // l, c, r = left, center, right. Leftmost vector: need FirstL1. + { + const V tc = LoadU(d, row_t + 0); + const V mc = LoadU(d, row_m + 0); + const V bc = LoadU(d, row_b + 0); + const V tl = Neighbors::FirstL1(tc); + const V tr = LoadU(d, row_t + 0 + 1); + const V ml = Neighbors::FirstL1(mc); + const V mr = LoadU(d, row_m + 0 + 1); + const V bl = Neighbors::FirstL1(bc); + const V br = LoadU(d, row_b + 0 + 1); + const V conv = + WeightedSum(tl, tc, tr, ml, mc, mr, bl, bc, br, w0, w1, w2); + Store(conv, d, row_out + 0); + } + + // Loop as long as we can load enough new values: + const size_t N = Lanes(d); + size_t x = N; + for (; x + N + kRadius <= xsize; x += N) { + const auto conv = ConvolveValid(row_t, row_m, row_b, x, w0, w1, w2); + Store(conv, d, row_out + x); + } + + // For final (partial) vector: + const V tc = LoadU(d, row_t + x); + const V mc = LoadU(d, row_m + x); + const V bc = LoadU(d, row_b + x); + + V tr, mr, br; +#if HWY_TARGET == HWY_SCALAR + tr = tc; // Single-lane => mirrored right neighbor = center value. + mr = mc; + br = bc; +#else + if (kSizeModN == 0) { + // The above loop didn't handle the last vector because it needs an + // additional right neighbor (generated via mirroring). + auto mirror = SetTableIndices(d, MirrorLanes(N - 1)); + tr = TableLookupLanes(tc, mirror); + mr = TableLookupLanes(mc, mirror); + br = TableLookupLanes(bc, mirror); + } else { + auto mirror = SetTableIndices(d, MirrorLanes((xsize % N) - 1)); + // Loads last valid value into uppermost lane and mirrors. + tr = TableLookupLanes(LoadU(d, row_t + xsize - N), mirror); + mr = TableLookupLanes(LoadU(d, row_m + xsize - N), mirror); + br = TableLookupLanes(LoadU(d, row_b + xsize - N), mirror); + } +#endif + + const V tl = LoadU(d, row_t + x - 1); + const V ml = LoadU(d, row_m + x - 1); + const V bl = LoadU(d, row_b + x - 1); + const V conv = WeightedSum(tl, tc, tr, ml, mc, mr, bl, bc, br, w0, w1, w2); + Store(conv, d, row_out + x); + } + + private: + // Returns sum{x_i * w_i}. + template <class V> + static JXL_MAYBE_INLINE V WeightedSum(const V tl, const V tc, const V tr, + const V ml, const V mc, const V mr, + const V bl, const V bc, const V br, + const V w0, const V w1, const V w2) { + const V sum_tb = Add(tc, bc); + + // Faster than 5 mul + 4 FMA. + const V mul0 = Mul(mc, w0); + const V sum_lr = Add(ml, mr); + + const V x1 = Add(sum_tb, sum_lr); + const V mul1 = MulAdd(x1, w1, mul0); + + const V sum_t2 = Add(tl, tr); + const V sum_b2 = Add(bl, br); + const V x2 = Add(sum_t2, sum_b2); + const V mul2 = MulAdd(x2, w2, mul1); + return mul2; + } + + static JXL_MAYBE_INLINE V ConvolveValid(const float* JXL_RESTRICT row_t, + const float* JXL_RESTRICT row_m, + const float* JXL_RESTRICT row_b, + const int64_t x, const V w0, + const V w1, const V w2) { + const D d; + const V tc = LoadU(d, row_t + x); + const V mc = LoadU(d, row_m + x); + const V bc = LoadU(d, row_b + x); + const V tl = LoadU(d, row_t + x - 1); + const V tr = LoadU(d, row_t + x + 1); + const V ml = LoadU(d, row_m + x - 1); + const V mr = LoadU(d, row_m + x + 1); + const V bl = LoadU(d, row_b + x - 1); + const V br = LoadU(d, row_b + x + 1); + return WeightedSum(tl, tc, tr, ml, mc, mr, bl, bc, br, w0, w1, w2); + } +}; + +void Symmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* out) { + using Conv = ConvolveT<Symmetric3Strategy>; + if (rect.xsize() >= Conv::MinWidth()) { + return Conv::Run(in, rect, weights, pool, out); + } + + return SlowSymmetric3(in, rect, weights, pool, out); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(Symmetric3); +void Symmetric3(const ImageF& in, const Rect& rect, + const WeightsSymmetric3& weights, ThreadPool* pool, + ImageF* out) { + return HWY_DYNAMIC_DISPATCH(Symmetric3)(in, rect, weights, pool, out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/convolve_symmetric5.cc b/third_party/jpeg-xl/lib/jxl/convolve_symmetric5.cc new file mode 100644 index 0000000000..2e203fd08f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve_symmetric5.cc @@ -0,0 +1,189 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve_symmetric5.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/convolve-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::Vec; + +// Weighted sum of 1x5 pixels around ix, iy with [wx2 wx1 wx0 wx1 wx2]. +template <class WrapY> +static float WeightedSumBorder(const ImageF& in, const WrapY wrap_y, + const int64_t ix, const int64_t iy, + const size_t xsize, const size_t ysize, + const float wx0, const float wx1, + const float wx2) { + const WrapMirror wrap_x; + const float* JXL_RESTRICT row = in.ConstRow(wrap_y(iy, ysize)); + const float in_m2 = row[wrap_x(ix - 2, xsize)]; + const float in_p2 = row[wrap_x(ix + 2, xsize)]; + const float in_m1 = row[wrap_x(ix - 1, xsize)]; + const float in_p1 = row[wrap_x(ix + 1, xsize)]; + const float in_00 = row[ix]; + const float sum_2 = wx2 * (in_m2 + in_p2); + const float sum_1 = wx1 * (in_m1 + in_p1); + const float sum_0 = wx0 * in_00; + return sum_2 + (sum_1 + sum_0); +} + +template <class WrapY, class V> +static V WeightedSum(const ImageF& in, const WrapY wrap_y, const size_t ix, + const int64_t iy, const size_t ysize, const V wx0, + const V wx1, const V wx2) { + const HWY_FULL(float) d; + const float* JXL_RESTRICT center = in.ConstRow(wrap_y(iy, ysize)) + ix; + const auto in_m2 = LoadU(d, center - 2); + const auto in_p2 = LoadU(d, center + 2); + const auto in_m1 = LoadU(d, center - 1); + const auto in_p1 = LoadU(d, center + 1); + const auto in_00 = LoadU(d, center); + const auto sum_2 = Mul(wx2, Add(in_m2, in_p2)); + const auto sum_1 = Mul(wx1, Add(in_m1, in_p1)); + const auto sum_0 = Mul(wx0, in_00); + return Add(sum_2, Add(sum_1, sum_0)); +} + +// Produces result for one pixel +template <class WrapY> +float Symmetric5Border(const ImageF& in, const int64_t ix, const int64_t iy, + const WeightsSymmetric5& weights) { + const float w0 = weights.c[0]; + const float w1 = weights.r[0]; + const float w2 = weights.R[0]; + const float w4 = weights.d[0]; + const float w5 = weights.L[0]; + const float w8 = weights.D[0]; + + const size_t xsize = in.xsize(); + const size_t ysize = in.ysize(); + const WrapY wrap_y; + // Unrolled loop over all 5 rows of the kernel. + float sum0 = WeightedSumBorder(in, wrap_y, ix, iy, xsize, ysize, w0, w1, w2); + + sum0 += WeightedSumBorder(in, wrap_y, ix, iy - 2, xsize, ysize, w2, w5, w8); + float sum1 = + WeightedSumBorder(in, wrap_y, ix, iy + 2, xsize, ysize, w2, w5, w8); + + sum0 += WeightedSumBorder(in, wrap_y, ix, iy - 1, xsize, ysize, w1, w4, w5); + sum1 += WeightedSumBorder(in, wrap_y, ix, iy + 1, xsize, ysize, w1, w4, w5); + + return sum0 + sum1; +} + +// Produces result for one vector's worth of pixels +template <class WrapY> +static void Symmetric5Interior(const ImageF& in, const int64_t ix, + const int64_t rix, const int64_t iy, + const WeightsSymmetric5& weights, + float* JXL_RESTRICT row_out) { + const HWY_FULL(float) d; + + const auto w0 = LoadDup128(d, weights.c); + const auto w1 = LoadDup128(d, weights.r); + const auto w2 = LoadDup128(d, weights.R); + const auto w4 = LoadDup128(d, weights.d); + const auto w5 = LoadDup128(d, weights.L); + const auto w8 = LoadDup128(d, weights.D); + + const size_t ysize = in.ysize(); + const WrapY wrap_y; + // Unrolled loop over all 5 rows of the kernel. + auto sum0 = WeightedSum(in, wrap_y, ix, iy, ysize, w0, w1, w2); + + sum0 = Add(sum0, WeightedSum(in, wrap_y, ix, iy - 2, ysize, w2, w5, w8)); + auto sum1 = WeightedSum(in, wrap_y, ix, iy + 2, ysize, w2, w5, w8); + + sum0 = Add(sum0, WeightedSum(in, wrap_y, ix, iy - 1, ysize, w1, w4, w5)); + sum1 = Add(sum1, WeightedSum(in, wrap_y, ix, iy + 1, ysize, w1, w4, w5)); + + StoreU(Add(sum0, sum1), d, row_out + rix); +} + +template <class WrapY> +static void Symmetric5Row(const ImageF& in, const Rect& rect, const int64_t iy, + const WeightsSymmetric5& weights, + float* JXL_RESTRICT row_out) { + const int64_t kRadius = 2; + const size_t xend = rect.x1(); + + size_t rix = 0; + size_t ix = rect.x0(); + const HWY_FULL(float) d; + const size_t N = Lanes(d); + const size_t aligned_x = RoundUpTo(kRadius, N); + for (; ix < std::min(aligned_x, xend); ++ix, ++rix) { + row_out[rix] = Symmetric5Border<WrapY>(in, ix, iy, weights); + } + for (; ix + N + kRadius <= xend; ix += N, rix += N) { + Symmetric5Interior<WrapY>(in, ix, rix, iy, weights, row_out); + } + for (; ix < xend; ++ix, ++rix) { + row_out[rix] = Symmetric5Border<WrapY>(in, ix, iy, weights); + } +} + +// Semi-vectorized (interior pixels Fonly); called directly like slow::, unlike +// the fully vectorized strategies below. +void Symmetric5(const ImageF& in, const Rect& in_rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out, const Rect& out_rect) { + JXL_ASSERT(in_rect.xsize() == out_rect.xsize()); + JXL_ASSERT(in_rect.ysize() == out_rect.ysize()); + const size_t ysize = in_rect.ysize(); + JXL_CHECK(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t riy = task; + const int64_t iy = in_rect.y0() + riy; + + if (iy < 2 || iy >= static_cast<ssize_t>(in.ysize()) - 2) { + Symmetric5Row<WrapMirror>(in, in_rect, iy, weights, + out_rect.Row(out, riy)); + } else { + Symmetric5Row<WrapUnchanged>(in, in_rect, iy, weights, + out_rect.Row(out, riy)); + } + }, + "Symmetric5x5Convolution")); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(Symmetric5); +void Symmetric5(const ImageF& in, const Rect& in_rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out, const Rect& out_rect) { + return HWY_DYNAMIC_DISPATCH(Symmetric5)(in, in_rect, weights, pool, out, + out_rect); +} + +void Symmetric5(const ImageF& in, const Rect& rect, + const WeightsSymmetric5& weights, ThreadPool* pool, + ImageF* JXL_RESTRICT out) { + return Symmetric5(in, rect, weights, pool, out, Rect(*out)); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/convolve_test.cc b/third_party/jpeg-xl/lib/jxl/convolve_test.cc new file mode 100644 index 0000000000..6a8dc9c400 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/convolve_test.cc @@ -0,0 +1,265 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/convolve.h" + +#include <time.h> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/convolve_test.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> +#include <hwy/nanobenchmark.h> +#include <hwy/tests/hwy_gtest.h> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +#ifndef JXL_DEBUG_CONVOLVE +#define JXL_DEBUG_CONVOLVE 0 +#endif + +#include "lib/jxl/convolve-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +void TestNeighbors() { + const Neighbors::D d; + const Neighbors::V v = Iota(d, 0); + constexpr size_t kMaxVectorSize = 64; + constexpr size_t M = kMaxVectorSize / sizeof(float); + HWY_ALIGN float actual[M] = {0}; + + HWY_ALIGN float first_l1[M] = {0, 0, 1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12, 13, 14}; + Store(Neighbors::FirstL1(v), d, actual); + const size_t N = Lanes(d); + ASSERT_LE(N, M); + EXPECT_EQ(std::vector<float>(first_l1, first_l1 + N), + std::vector<float>(actual, actual + N)); + +#if HWY_TARGET != HWY_SCALAR + HWY_ALIGN float first_l2[M] = {1, 0, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, 12, 13}; + Store(Neighbors::FirstL2(v), d, actual); + EXPECT_EQ(std::vector<float>(first_l2, first_l2 + N), + std::vector<float>(actual, actual + N)); + + HWY_ALIGN float first_l3[] = {2, 1, 0, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12}; + Store(Neighbors::FirstL3(v), d, actual); + EXPECT_EQ(std::vector<float>(first_l3, first_l3 + N), + std::vector<float>(actual, actual + N)); +#endif // HWY_TARGET != HWY_SCALAR +} + +void VerifySymmetric3(const size_t xsize, const size_t ysize, ThreadPool* pool, + Rng* rng) { + const Rect rect(0, 0, xsize, ysize); + + ImageF in(xsize, ysize); + GenerateImage(*rng, &in, 0.0f, 1.0f); + + ImageF out_expected(xsize, ysize); + ImageF out_actual(xsize, ysize); + + const WeightsSymmetric3& weights = WeightsSymmetric3Lowpass(); + Symmetric3(in, rect, weights, pool, &out_expected); + SlowSymmetric3(in, rect, weights, pool, &out_actual); + + JXL_ASSERT_OK(VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f, _)); +} + +std::vector<Rect> GenerateTestRectangles(size_t xsize, size_t ysize) { + std::vector<Rect> out; + for (size_t tl : {0, 1, 13}) { + for (size_t br : {0, 1, 13}) { + if (xsize > tl + br && ysize > tl + br) { + out.push_back(Rect(tl, tl, xsize - tl - br, ysize - tl - br)); + } + } + } + return out; +} + +// Ensures Symmetric and Separable give the same result. +void VerifySymmetric5(const size_t xsize, const size_t ysize, ThreadPool* pool, + Rng* rng) { + ImageF in(xsize, ysize); + GenerateImage(*rng, &in, 0.0f, 1.0f); + + for (const Rect& in_rect : GenerateTestRectangles(xsize, ysize)) { + JXL_DEBUG(JXL_DEBUG_CONVOLVE, + "in_rect: %" PRIuS "x%" PRIuS "+%" PRIuS ",%" PRIuS "", + in_rect.xsize(), in_rect.ysize(), in_rect.x0(), in_rect.y0()); + { + Rect out_rect = in_rect; + ImageF out_expected(xsize, ysize); + ImageF out_actual(xsize, ysize); + FillImage(-1.0f, &out_expected); + FillImage(-1.0f, &out_actual); + + SlowSeparable5(in, in_rect, WeightsSeparable5Lowpass(), pool, + &out_expected, out_rect); + Symmetric5(in, in_rect, WeightsSymmetric5Lowpass(), pool, &out_actual, + out_rect); + + JXL_ASSERT_OK( + VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f, _)); + } + { + Rect out_rect(0, 0, in_rect.xsize(), in_rect.ysize()); + ImageF out_expected(out_rect.xsize(), out_rect.ysize()); + ImageF out_actual(out_rect.xsize(), out_rect.ysize()); + + SlowSeparable5(in, in_rect, WeightsSeparable5Lowpass(), pool, + &out_expected, out_rect); + Symmetric5(in, in_rect, WeightsSymmetric5Lowpass(), pool, &out_actual, + out_rect); + + JXL_ASSERT_OK( + VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f, _)); + } + } +} + +void VerifySeparable5(const size_t xsize, const size_t ysize, ThreadPool* pool, + Rng* rng) { + const Rect rect(0, 0, xsize, ysize); + + ImageF in(xsize, ysize); + GenerateImage(*rng, &in, 0.0f, 1.0f); + + ImageF out_expected(xsize, ysize); + ImageF out_actual(xsize, ysize); + + const WeightsSeparable5& weights = WeightsSeparable5Lowpass(); + SlowSeparable5(in, rect, weights, pool, &out_expected, rect); + Separable5(in, rect, weights, pool, &out_actual); + + JXL_ASSERT_OK(VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f, _)); +} + +// For all xsize/ysize and kernels: +void TestConvolve() { + TestNeighbors(); + + test::ThreadPoolForTests pool(4); + EXPECT_EQ(true, + RunOnPool( + &pool, kConvolveMaxRadius, 40, ThreadPool::NoInit, + [](const uint32_t task, size_t /*thread*/) { + const size_t xsize = task; + Rng rng(129 + 13 * xsize); + + ThreadPool* null_pool = nullptr; + test::ThreadPoolForTests pool3(3); + for (size_t ysize = kConvolveMaxRadius; ysize < 16; ++ysize) { + JXL_DEBUG(JXL_DEBUG_CONVOLVE, + "%" PRIuS " x %" PRIuS " (target %" PRIx64 + ")===============================", + xsize, ysize, static_cast<int64_t>(HWY_TARGET)); + + JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sym3------------------"); + VerifySymmetric3(xsize, ysize, null_pool, &rng); + VerifySymmetric3(xsize, ysize, &pool3, &rng); + + JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sym5------------------"); + VerifySymmetric5(xsize, ysize, null_pool, &rng); + VerifySymmetric5(xsize, ysize, &pool3, &rng); + + JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sep5------------------"); + VerifySeparable5(xsize, ysize, null_pool, &rng); + VerifySeparable5(xsize, ysize, &pool3, &rng); + } + }, + "TestConvolve")); +} + +// Measures durations, verifies results, prints timings. `unpredictable1` +// must have value 1 (unknown to the compiler to prevent elision). +template <class Conv> +void BenchmarkConv(const char* caption, const Conv& conv, + const hwy::FuncInput unpredictable1) { + const size_t kNumInputs = 1; + const hwy::FuncInput inputs[kNumInputs] = {unpredictable1}; + hwy::Result results[kNumInputs]; + + const size_t kDim = 160; // in+out fit in L2 + ImageF in(kDim, kDim); + ZeroFillImage(&in); + in.Row(kDim / 2)[kDim / 2] = unpredictable1; + ImageF out(kDim, kDim); + + hwy::Params p; + p.verbose = false; + p.max_evals = 7; + p.target_rel_mad = 0.002; + const size_t num_results = MeasureClosure( + [&in, &conv, &out](const hwy::FuncInput input) { + conv(in, &out); + return out.Row(input)[0]; + }, + inputs, kNumInputs, results, p); + if (num_results != kNumInputs) { + fprintf(stderr, "MeasureClosure failed.\n"); + } + for (size_t i = 0; i < num_results; ++i) { + const double seconds = static_cast<double>(results[i].ticks) / + hwy::platform::InvariantTicksPerSecond(); + printf("%12s: %7.2f MP/s (MAD=%4.2f%%)\n", caption, + kDim * kDim * 1E-6 / seconds, + static_cast<double>(results[i].variability) * 100.0); + } +} + +struct ConvSymmetric3 { + void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) const { + ThreadPool* null_pool = nullptr; + Symmetric3(in, Rect(in), WeightsSymmetric3Lowpass(), null_pool, out); + } +}; + +struct ConvSeparable5 { + void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) const { + ThreadPool* null_pool = nullptr; + Separable5(in, Rect(in), WeightsSeparable5Lowpass(), null_pool, out); + } +}; + +void BenchmarkAll() { +#if 0 // disabled to avoid test timeouts, run manually on demand + const hwy::FuncInput unpredictable1 = time(nullptr) != 1234; + BenchmarkConv("Symmetric3", ConvSymmetric3(), unpredictable1); + BenchmarkConv("Separable5", ConvSeparable5(), unpredictable1); +#endif +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class ConvolveTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(ConvolveTest); + +HWY_EXPORT_AND_TEST_P(ConvolveTest, TestConvolve); + +HWY_EXPORT_AND_TEST_P(ConvolveTest, BenchmarkAll); + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/data_parallel_test.cc b/third_party/jpeg-xl/lib/jxl/data_parallel_test.cc new file mode 100644 index 0000000000..ee2a97f93a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/data_parallel_test.cc @@ -0,0 +1,87 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/data_parallel.h" + +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +class DataParallelTest : public ::testing::Test { + protected: + // A fake class to verify that DataParallel is properly calling the + // client-provided runner functions. + static int FakeRunner(void* runner_opaque, void* jpegxl_opaque, + JxlParallelRunInit init, JxlParallelRunFunction func, + uint32_t start_range, uint32_t end_range) { + DataParallelTest* self = static_cast<DataParallelTest*>(runner_opaque); + self->runner_called_++; + self->jpegxl_opaque_ = jpegxl_opaque; + self->init_ = init; + self->func_ = func; + self->start_range_ = start_range; + self->end_range_ = end_range; + return self->runner_return_; + } + + ThreadPool pool_{&DataParallelTest::FakeRunner, this}; + + // Number of times FakeRunner() was called. + int runner_called_ = 0; + + // Parameters passed to FakeRunner. + void* jpegxl_opaque_ = nullptr; + JxlParallelRunInit init_ = nullptr; + JxlParallelRunFunction func_ = nullptr; + uint32_t start_range_ = -1; + uint32_t end_range_ = -1; + + // Return value that FakeRunner will return. + int runner_return_ = 0; +}; + +// JxlParallelRunInit interface. +typedef int (*JxlParallelRunInit)(); + +} // namespace + +TEST_F(DataParallelTest, RunnerCalledParameters) { + EXPECT_TRUE(pool_.Run( + 1234, 5678, [](size_t /* num_threads */) { return true; }, + [](uint32_t /* task */, size_t /* thread */) { return; })); + EXPECT_EQ(1, runner_called_); + EXPECT_NE(nullptr, init_); + EXPECT_NE(nullptr, func_); + EXPECT_NE(nullptr, jpegxl_opaque_); + EXPECT_EQ(1234u, start_range_); + EXPECT_EQ(5678u, end_range_); +} + +TEST_F(DataParallelTest, RunnerFailurePropagates) { + runner_return_ = -1; // FakeRunner return value. + EXPECT_FALSE(pool_.Run( + 1234, 5678, [](size_t /* num_threads */) { return false; }, + [](uint32_t /* task */, size_t /* thread */) { return; })); + EXPECT_FALSE(RunOnPool( + nullptr, 1234, 5678, [](size_t /* num_threads */) { return false; }, + [](uint32_t /* task */, size_t /* thread */) { return; }, "Test")); +} + +TEST_F(DataParallelTest, RunnerNotCalledOnEmptyRange) { + runner_return_ = -1; // FakeRunner return value. + EXPECT_TRUE(pool_.Run( + 123, 123, [](size_t /* num_threads */) { return false; }, + [](uint32_t /* task */, size_t /* thread */) { return; })); + EXPECT_TRUE(RunOnPool( + nullptr, 123, 123, [](size_t /* num_threads */) { return false; }, + [](uint32_t /* task */, size_t /* thread */) { return; }, "Test")); + // We don't call the external runner when the range is empty. We don't even + // need to call the init function. + EXPECT_EQ(0, runner_called_); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dct-inl.h b/third_party/jpeg-xl/lib/jxl/dct-inl.h new file mode 100644 index 0000000000..cb6c54bc46 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct-inl.h @@ -0,0 +1,339 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Fast SIMD floating-point (I)DCT, any power of two. + +#if defined(LIB_JXL_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DCT_INL_H_ +#undef LIB_JXL_DCT_INL_H_ +#else +#define LIB_JXL_DCT_INL_H_ +#endif + +#include <stddef.h> + +#include <hwy/highway.h> + +#include "lib/jxl/dct_block-inl.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/transpose-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::NegMulAdd; +using hwy::HWY_NAMESPACE::Sub; + +template <size_t SZ> +struct FVImpl { + using type = HWY_CAPPED(float, SZ); +}; + +template <> +struct FVImpl<0> { + using type = HWY_FULL(float); +}; + +template <size_t SZ> +using FV = typename FVImpl<SZ>::type; + +// Implementation of Lowest Complexity Self Recursive Radix-2 DCT II/III +// Algorithms, by Siriani M. Perera and Jianhua Liu. + +template <size_t N, size_t SZ> +struct CoeffBundle { + static void AddReverse(const float* JXL_RESTRICT ain1, + const float* JXL_RESTRICT ain2, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N; i++) { + auto in1 = Load(FV<SZ>(), ain1 + i * SZ); + auto in2 = Load(FV<SZ>(), ain2 + (N - i - 1) * SZ); + Store(Add(in1, in2), FV<SZ>(), aout + i * SZ); + } + } + static void SubReverse(const float* JXL_RESTRICT ain1, + const float* JXL_RESTRICT ain2, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N; i++) { + auto in1 = Load(FV<SZ>(), ain1 + i * SZ); + auto in2 = Load(FV<SZ>(), ain2 + (N - i - 1) * SZ); + Store(Sub(in1, in2), FV<SZ>(), aout + i * SZ); + } + } + static void B(float* JXL_RESTRICT coeff) { + auto sqrt2 = Set(FV<SZ>(), kSqrt2); + auto in1 = Load(FV<SZ>(), coeff); + auto in2 = Load(FV<SZ>(), coeff + SZ); + Store(MulAdd(in1, sqrt2, in2), FV<SZ>(), coeff); + for (size_t i = 1; i + 1 < N; i++) { + auto in1 = Load(FV<SZ>(), coeff + i * SZ); + auto in2 = Load(FV<SZ>(), coeff + (i + 1) * SZ); + Store(Add(in1, in2), FV<SZ>(), coeff + i * SZ); + } + } + static void BTranspose(float* JXL_RESTRICT coeff) { + for (size_t i = N - 1; i > 0; i--) { + auto in1 = Load(FV<SZ>(), coeff + i * SZ); + auto in2 = Load(FV<SZ>(), coeff + (i - 1) * SZ); + Store(Add(in1, in2), FV<SZ>(), coeff + i * SZ); + } + auto sqrt2 = Set(FV<SZ>(), kSqrt2); + auto in1 = Load(FV<SZ>(), coeff); + Store(Mul(in1, sqrt2), FV<SZ>(), coeff); + } + // Ideally optimized away by compiler (except the multiply). + static void InverseEvenOdd(const float* JXL_RESTRICT ain, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N / 2; i++) { + auto in1 = Load(FV<SZ>(), ain + i * SZ); + Store(in1, FV<SZ>(), aout + 2 * i * SZ); + } + for (size_t i = N / 2; i < N; i++) { + auto in1 = Load(FV<SZ>(), ain + i * SZ); + Store(in1, FV<SZ>(), aout + (2 * (i - N / 2) + 1) * SZ); + } + } + // Ideally optimized away by compiler. + static void ForwardEvenOdd(const float* JXL_RESTRICT ain, size_t ain_stride, + float* JXL_RESTRICT aout) { + for (size_t i = 0; i < N / 2; i++) { + auto in1 = LoadU(FV<SZ>(), ain + 2 * i * ain_stride); + Store(in1, FV<SZ>(), aout + i * SZ); + } + for (size_t i = N / 2; i < N; i++) { + auto in1 = LoadU(FV<SZ>(), ain + (2 * (i - N / 2) + 1) * ain_stride); + Store(in1, FV<SZ>(), aout + i * SZ); + } + } + // Invoked on full vector. + static void Multiply(float* JXL_RESTRICT coeff) { + for (size_t i = 0; i < N / 2; i++) { + auto in1 = Load(FV<SZ>(), coeff + (N / 2 + i) * SZ); + auto mul = Set(FV<SZ>(), WcMultipliers<N>::kMultipliers[i]); + Store(Mul(in1, mul), FV<SZ>(), coeff + (N / 2 + i) * SZ); + } + } + static void MultiplyAndAdd(const float* JXL_RESTRICT coeff, + float* JXL_RESTRICT out, size_t out_stride) { + for (size_t i = 0; i < N / 2; i++) { + auto mul = Set(FV<SZ>(), WcMultipliers<N>::kMultipliers[i]); + auto in1 = Load(FV<SZ>(), coeff + i * SZ); + auto in2 = Load(FV<SZ>(), coeff + (N / 2 + i) * SZ); + auto out1 = MulAdd(mul, in2, in1); + auto out2 = NegMulAdd(mul, in2, in1); + StoreU(out1, FV<SZ>(), out + i * out_stride); + StoreU(out2, FV<SZ>(), out + (N - i - 1) * out_stride); + } + } + template <typename Block> + static void LoadFromBlock(const Block& in, size_t off, + float* JXL_RESTRICT coeff) { + for (size_t i = 0; i < N; i++) { + Store(in.LoadPart(FV<SZ>(), i, off), FV<SZ>(), coeff + i * SZ); + } + } + template <typename Block> + static void StoreToBlockAndScale(const float* JXL_RESTRICT coeff, + const Block& out, size_t off) { + auto mul = Set(FV<SZ>(), 1.0f / N); + for (size_t i = 0; i < N; i++) { + out.StorePart(FV<SZ>(), Mul(mul, Load(FV<SZ>(), coeff + i * SZ)), i, off); + } + } +}; + +template <size_t N, size_t SZ> +struct DCT1DImpl; + +template <size_t SZ> +struct DCT1DImpl<1, SZ> { + JXL_INLINE void operator()(float* JXL_RESTRICT mem, float*) {} +}; + +template <size_t SZ> +struct DCT1DImpl<2, SZ> { + JXL_INLINE void operator()(float* JXL_RESTRICT mem, float*) { + auto in1 = Load(FV<SZ>(), mem); + auto in2 = Load(FV<SZ>(), mem + SZ); + Store(Add(in1, in2), FV<SZ>(), mem); + Store(Sub(in1, in2), FV<SZ>(), mem + SZ); + } +}; + +template <size_t N, size_t SZ> +struct DCT1DImpl { + void operator()(float* JXL_RESTRICT mem, float* JXL_RESTRICT tmp) { + CoeffBundle<N / 2, SZ>::AddReverse(mem, mem + N / 2 * SZ, tmp); + DCT1DImpl<N / 2, SZ>()(tmp, tmp + N * SZ); + CoeffBundle<N / 2, SZ>::SubReverse(mem, mem + N / 2 * SZ, tmp + N / 2 * SZ); + CoeffBundle<N, SZ>::Multiply(tmp); + DCT1DImpl<N / 2, SZ>()(tmp + N / 2 * SZ, tmp + N * SZ); + CoeffBundle<N / 2, SZ>::B(tmp + N / 2 * SZ); + CoeffBundle<N, SZ>::InverseEvenOdd(tmp, mem); + } +}; + +template <size_t N, size_t SZ> +struct IDCT1DImpl; + +template <size_t SZ> +struct IDCT1DImpl<1, SZ> { + JXL_INLINE void operator()(const float* from, size_t from_stride, float* to, + size_t to_stride, float* JXL_RESTRICT) { + StoreU(LoadU(FV<SZ>(), from), FV<SZ>(), to); + } +}; + +template <size_t SZ> +struct IDCT1DImpl<2, SZ> { + JXL_INLINE void operator()(const float* from, size_t from_stride, float* to, + size_t to_stride, float* JXL_RESTRICT) { + JXL_DASSERT(from_stride >= SZ); + JXL_DASSERT(to_stride >= SZ); + auto in1 = LoadU(FV<SZ>(), from); + auto in2 = LoadU(FV<SZ>(), from + from_stride); + StoreU(Add(in1, in2), FV<SZ>(), to); + StoreU(Sub(in1, in2), FV<SZ>(), to + to_stride); + } +}; + +template <size_t N, size_t SZ> +struct IDCT1DImpl { + void operator()(const float* from, size_t from_stride, float* to, + size_t to_stride, float* JXL_RESTRICT tmp) { + JXL_DASSERT(from_stride >= SZ); + JXL_DASSERT(to_stride >= SZ); + CoeffBundle<N, SZ>::ForwardEvenOdd(from, from_stride, tmp); + IDCT1DImpl<N / 2, SZ>()(tmp, SZ, tmp, SZ, tmp + N * SZ); + CoeffBundle<N / 2, SZ>::BTranspose(tmp + N / 2 * SZ); + IDCT1DImpl<N / 2, SZ>()(tmp + N / 2 * SZ, SZ, tmp + N / 2 * SZ, SZ, + tmp + N * SZ); + CoeffBundle<N, SZ>::MultiplyAndAdd(tmp, to, to_stride); + } +}; + +template <size_t N, size_t M_or_0, typename FromBlock, typename ToBlock> +void DCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp, + float* JXL_RESTRICT tmp) { + size_t M = M_or_0 != 0 ? M_or_0 : Mp; + constexpr size_t SZ = MaxLanes(FV<M_or_0>()); + for (size_t i = 0; i < M; i += Lanes(FV<M_or_0>())) { + // TODO(veluca): consider removing the temporary memory here (as is done in + // IDCT), if it turns out that some compilers don't optimize away the loads + // and this is performance-critical. + CoeffBundle<N, SZ>::LoadFromBlock(from, i, tmp); + DCT1DImpl<N, SZ>()(tmp, tmp + N * SZ); + CoeffBundle<N, SZ>::StoreToBlockAndScale(tmp, to, i); + } +} + +template <size_t N, size_t M_or_0, typename FromBlock, typename ToBlock> +void IDCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp, + float* JXL_RESTRICT tmp) { + size_t M = M_or_0 != 0 ? M_or_0 : Mp; + constexpr size_t SZ = MaxLanes(FV<M_or_0>()); + for (size_t i = 0; i < M; i += Lanes(FV<M_or_0>())) { + IDCT1DImpl<N, SZ>()(from.Address(0, i), from.Stride(), to.Address(0, i), + to.Stride(), tmp); + } +} + +template <size_t N, size_t M, typename = void> +struct DCT1D { + template <typename FromBlock, typename ToBlock> + void operator()(const FromBlock& from, const ToBlock& to, + float* JXL_RESTRICT tmp) { + return DCT1DWrapper<N, M>(from, to, M, tmp); + } +}; + +template <size_t N, size_t M> +struct DCT1D<N, M, typename std::enable_if<(M > MaxLanes(FV<0>()))>::type> { + template <typename FromBlock, typename ToBlock> + void operator()(const FromBlock& from, const ToBlock& to, + float* JXL_RESTRICT tmp) { + return NoInlineWrapper(DCT1DWrapper<N, 0, FromBlock, ToBlock>, from, to, M, + tmp); + } +}; + +template <size_t N, size_t M, typename = void> +struct IDCT1D { + template <typename FromBlock, typename ToBlock> + void operator()(const FromBlock& from, const ToBlock& to, + float* JXL_RESTRICT tmp) { + return IDCT1DWrapper<N, M>(from, to, M, tmp); + } +}; + +template <size_t N, size_t M> +struct IDCT1D<N, M, typename std::enable_if<(M > MaxLanes(FV<0>()))>::type> { + template <typename FromBlock, typename ToBlock> + void operator()(const FromBlock& from, const ToBlock& to, + float* JXL_RESTRICT tmp) { + return NoInlineWrapper(IDCT1DWrapper<N, 0, FromBlock, ToBlock>, from, to, M, + tmp); + } +}; + +// Computes the maybe-transposed, scaled DCT of a block, that needs to be +// HWY_ALIGN'ed. +template <size_t ROWS, size_t COLS> +struct ComputeScaledDCT { + // scratch_space must be aligned, and should have space for ROWS*COLS + // floats. + template <class From> + HWY_MAYBE_UNUSED void operator()(const From& from, float* to, + float* JXL_RESTRICT scratch_space) { + float* JXL_RESTRICT block = scratch_space; + float* JXL_RESTRICT tmp = scratch_space + ROWS * COLS; + if (ROWS < COLS) { + DCT1D<ROWS, COLS>()(from, DCTTo(block, COLS), tmp); + Transpose<ROWS, COLS>::Run(DCTFrom(block, COLS), DCTTo(to, ROWS)); + DCT1D<COLS, ROWS>()(DCTFrom(to, ROWS), DCTTo(block, ROWS), tmp); + Transpose<COLS, ROWS>::Run(DCTFrom(block, ROWS), DCTTo(to, COLS)); + } else { + DCT1D<ROWS, COLS>()(from, DCTTo(to, COLS), tmp); + Transpose<ROWS, COLS>::Run(DCTFrom(to, COLS), DCTTo(block, ROWS)); + DCT1D<COLS, ROWS>()(DCTFrom(block, ROWS), DCTTo(to, ROWS), tmp); + } + } +}; +// Computes the maybe-transposed, scaled IDCT of a block, that needs to be +// HWY_ALIGN'ed. +template <size_t ROWS, size_t COLS> +struct ComputeScaledIDCT { + // scratch_space must be aligned, and should have space for ROWS*COLS + // floats. + template <class To> + HWY_MAYBE_UNUSED void operator()(float* JXL_RESTRICT from, const To& to, + float* JXL_RESTRICT scratch_space) { + float* JXL_RESTRICT block = scratch_space; + float* JXL_RESTRICT tmp = scratch_space + ROWS * COLS; + // Reverse the steps done in ComputeScaledDCT. + if (ROWS < COLS) { + Transpose<ROWS, COLS>::Run(DCTFrom(from, COLS), DCTTo(block, ROWS)); + IDCT1D<COLS, ROWS>()(DCTFrom(block, ROWS), DCTTo(from, ROWS), tmp); + Transpose<COLS, ROWS>::Run(DCTFrom(from, ROWS), DCTTo(block, COLS)); + IDCT1D<ROWS, COLS>()(DCTFrom(block, COLS), to, tmp); + } else { + IDCT1D<COLS, ROWS>()(DCTFrom(from, ROWS), DCTTo(block, ROWS), tmp); + Transpose<COLS, ROWS>::Run(DCTFrom(block, ROWS), DCTTo(from, COLS)); + IDCT1D<ROWS, COLS>()(DCTFrom(from, COLS), to, tmp); + } + } +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); +#endif // LIB_JXL_DCT_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dct_block-inl.h b/third_party/jpeg-xl/lib/jxl/dct_block-inl.h new file mode 100644 index 0000000000..50646a737f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_block-inl.h @@ -0,0 +1,108 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Adapters for DCT input/output: from/to contiguous blocks or image rows. + +#if defined(LIB_JXL_DCT_BLOCK_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DCT_BLOCK_INL_H_ +#undef LIB_JXL_DCT_BLOCK_INL_H_ +#else +#define LIB_JXL_DCT_BLOCK_INL_H_ +#endif + +#include <stddef.h> + +#include <hwy/highway.h> + +#include "lib/jxl/base/status.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Vec; + +// Block: (x, y) <-> (N * y + x) +// Lines: (x, y) <-> (stride * y + x) +// +// I.e. Block is a specialization of Lines with fixed stride. +// +// FromXXX should implement Read and Load (Read vector). +// ToXXX should implement Write and Store (Write vector). + +template <size_t N> +using BlockDesc = HWY_CAPPED(float, N); + +// Here and in the following, the SZ template parameter specifies the number of +// values to load/store. Needed because we want to handle 4x4 sub-blocks of +// 16x16 blocks. +class DCTFrom { + public: + DCTFrom(const float* data, size_t stride) : stride_(stride), data_(data) {} + + template <typename D> + HWY_INLINE Vec<D> LoadPart(D, const size_t row, size_t i) const { + JXL_DASSERT(Lanes(D()) <= stride_); + // Since these functions are used also for DC, no alignment at all is + // guaranteed in the case of floating blocks. + // TODO(veluca): consider using a different class for DC-to-LF and + // DC-from-LF, or copying DC values to/from a temporary aligned location. + return LoadU(D(), Address(row, i)); + } + + HWY_INLINE float Read(const size_t row, const size_t i) const { + return *Address(row, i); + } + + constexpr HWY_INLINE const float* Address(const size_t row, + const size_t i) const { + return data_ + row * stride_ + i; + } + + size_t Stride() const { return stride_; } + + private: + size_t stride_; + const float* JXL_RESTRICT data_; +}; + +class DCTTo { + public: + DCTTo(float* data, size_t stride) : stride_(stride), data_(data) {} + + template <typename D> + HWY_INLINE void StorePart(D, const Vec<D>& v, const size_t row, + size_t i) const { + JXL_DASSERT(Lanes(D()) <= stride_); + // Since these functions are used also for DC, no alignment at all is + // guaranteed in the case of floating blocks. + // TODO(veluca): consider using a different class for DC-to-LF and + // DC-from-LF, or copying DC values to/from a temporary aligned location. + StoreU(v, D(), Address(row, i)); + } + + HWY_INLINE void Write(float v, const size_t row, const size_t i) const { + *Address(row, i) = v; + } + + constexpr HWY_INLINE float* Address(const size_t row, const size_t i) const { + return data_ + row * stride_ + i; + } + + size_t Stride() const { return stride_; } + + private: + size_t stride_; + float* JXL_RESTRICT data_; +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_DCT_BLOCK_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dct_for_test.h b/third_party/jpeg-xl/lib/jxl/dct_for_test.h new file mode 100644 index 0000000000..58dd75e20e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_for_test.h @@ -0,0 +1,99 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DCT_FOR_TEST_H_ +#define LIB_JXL_DCT_FOR_TEST_H_ + +// Unoptimized DCT only for use in tests. + +#include <string.h> // memcpy + +#include <cmath> +#include <vector> + +#include "lib/jxl/base/common.h" + +namespace jxl { + +namespace test { +static inline double alpha(int u) { return u == 0 ? 0.7071067811865475 : 1.0; } + +// N-DCT on M columns, divided by sqrt(N). Matches the definition in the spec. +template <size_t N, size_t M> +void DCT1D(double block[N * M], double out[N * M]) { + std::vector<double> matrix(N * N); + const double scale = std::sqrt(2.0) / N; + for (size_t y = 0; y < N; y++) { + for (size_t u = 0; u < N; u++) { + matrix[N * u + y] = alpha(u) * cos((y + 0.5) * u * Pi(1.0 / N)) * scale; + } + } + for (size_t x = 0; x < M; x++) { + for (size_t u = 0; u < N; u++) { + out[M * u + x] = 0; + for (size_t y = 0; y < N; y++) { + out[M * u + x] += matrix[N * u + y] * block[M * y + x]; + } + } + } +} + +// N-IDCT on M columns, multiplied by sqrt(N). Matches the definition in the +// spec. +template <size_t N, size_t M> +void IDCT1D(double block[N * M], double out[N * M]) { + std::vector<double> matrix(N * N); + const double scale = std::sqrt(2.0); + for (size_t y = 0; y < N; y++) { + for (size_t u = 0; u < N; u++) { + // Transpose of DCT matrix. + matrix[N * y + u] = alpha(u) * cos((y + 0.5) * u * Pi(1.0 / N)) * scale; + } + } + for (size_t x = 0; x < M; x++) { + for (size_t u = 0; u < N; u++) { + out[M * u + x] = 0; + for (size_t y = 0; y < N; y++) { + out[M * u + x] += matrix[N * u + y] * block[M * y + x]; + } + } + } +} + +template <size_t N, size_t M> +void TransposeBlock(double in[N * M], double out[M * N]) { + for (size_t x = 0; x < N; x++) { + for (size_t y = 0; y < M; y++) { + out[y * N + x] = in[x * M + y]; + } + } +} +} // namespace test + +// Untransposed DCT. +template <size_t N> +void DCTSlow(double block[N * N]) { + constexpr size_t kBlockSize = N * N; + std::vector<double> g(kBlockSize); + test::DCT1D<N, N>(block, g.data()); + test::TransposeBlock<N, N>(g.data(), block); + test::DCT1D<N, N>(block, g.data()); + test::TransposeBlock<N, N>(g.data(), block); +} + +// Untransposed IDCT. +template <size_t N> +void IDCTSlow(double block[N * N]) { + constexpr size_t kBlockSize = N * N; + std::vector<double> g(kBlockSize); + test::IDCT1D<N, N>(block, g.data()); + test::TransposeBlock<N, N>(g.data(), block); + test::IDCT1D<N, N>(block, g.data()); + test::TransposeBlock<N, N>(g.data(), block); +} + +} // namespace jxl + +#endif // LIB_JXL_DCT_FOR_TEST_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dct_scales.cc b/third_party/jpeg-xl/lib/jxl/dct_scales.cc new file mode 100644 index 0000000000..f9e89a6014 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_scales.cc @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dct_scales.h" + +namespace jxl { + +// Definition of constexpr arrays. +constexpr float DCTResampleScales<1, 8>::kScales[]; +constexpr float DCTResampleScales<2, 16>::kScales[]; +constexpr float DCTResampleScales<4, 32>::kScales[]; +constexpr float DCTResampleScales<8, 64>::kScales[]; +constexpr float DCTResampleScales<16, 128>::kScales[]; +constexpr float DCTResampleScales<32, 256>::kScales[]; +constexpr float DCTResampleScales<8, 1>::kScales[]; +constexpr float DCTResampleScales<16, 2>::kScales[]; +constexpr float DCTResampleScales<32, 4>::kScales[]; +constexpr float DCTResampleScales<64, 8>::kScales[]; +constexpr float DCTResampleScales<128, 16>::kScales[]; +constexpr float DCTResampleScales<256, 32>::kScales[]; +constexpr float WcMultipliers<4>::kMultipliers[]; +constexpr float WcMultipliers<8>::kMultipliers[]; +constexpr float WcMultipliers<16>::kMultipliers[]; +constexpr float WcMultipliers<32>::kMultipliers[]; +constexpr float WcMultipliers<64>::kMultipliers[]; +constexpr float WcMultipliers<128>::kMultipliers[]; +constexpr float WcMultipliers<256>::kMultipliers[]; + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dct_scales.h b/third_party/jpeg-xl/lib/jxl/dct_scales.h new file mode 100644 index 0000000000..23af03d60f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_scales.h @@ -0,0 +1,379 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DCT_SCALES_H_ +#define LIB_JXL_DCT_SCALES_H_ + +// Scaling factors. + +#include <stddef.h> + +namespace jxl { + +static constexpr float kSqrt2 = 1.41421356237f; +static constexpr float kSqrt0_5 = 0.70710678118f; + +// For n != 0, the n-th basis function of a N-DCT, evaluated in pixel k, has a +// value of cos((k+1/2) n/(2N) pi). When downsampling by 2x, we average +// the values for pixel k and k+1 to get the value for pixel (k/2), thus we get +// +// [cos((k+1/2) n/N pi) + cos((k+3/2) n/N pi)]/2 = +// cos(n/(2N) pi) cos((k+1) n/N pi) = +// cos(n/(2N) pi) cos(((k/2)+1/2) n/(N/2) pi) +// +// which is exactly the same as the value of pixel k/2 of a N/2-sized DCT, +// except for the cos(n/(2N) pi) scaling factor (which does *not* +// depend on the pixel). Thus, when using the lower-frequency coefficients of a +// DCT-N to compute a DCT-(N/2), they should be scaled by this constant. Scaling +// factors for a DCT-(N/4) etc can then be obtained by successive +// multiplications. The structs below contain the above-mentioned scaling +// factors. +// +// Python code for the tables below: +// +// for i in range(N // 8): +// v = math.cos(i / (2 * N) * math.pi) +// v *= math.cos(i / (N) * math.pi) +// v *= math.cos(i / (N / 2) * math.pi) +// print(v, end=", ") + +template <size_t FROM, size_t TO> +struct DCTResampleScales; + +template <> +struct DCTResampleScales<8, 1> { + static constexpr float kScales[] = { + 1.000000000000000000, + }; +}; + +template <> +struct DCTResampleScales<16, 2> { + static constexpr float kScales[] = { + 1.000000000000000000, + 0.901764195028874394, + }; +}; + +template <> +struct DCTResampleScales<32, 4> { + static constexpr float kScales[] = { + 1.000000000000000000, + 0.974886821136879522, + 0.901764195028874394, + 0.787054918159101335, + }; +}; + +template <> +struct DCTResampleScales<64, 8> { + static constexpr float kScales[] = { + 1.0000000000000000, 0.9936866130906366, 0.9748868211368796, + 0.9440180941651672, 0.9017641950288744, 0.8490574973847023, + 0.7870549181591013, 0.7171081282466044, + }; +}; + +template <> +struct DCTResampleScales<128, 16> { + static constexpr float kScales[] = { + 1.0, + 0.9984194528776054, + 0.9936866130906366, + 0.9858278282666936, + 0.9748868211368796, + 0.9609244059440204, + 0.9440180941651672, + 0.9242615922757944, + 0.9017641950288744, + 0.8766500784429904, + 0.8490574973847023, + 0.8191378932865928, + 0.7870549181591013, + 0.7529833816270532, + 0.7171081282466044, + 0.6796228528314651, + }; +}; + +template <> +struct DCTResampleScales<256, 32> { + static constexpr float kScales[] = { + 1.0, + 0.9996047255830407, + 0.9984194528776054, + 0.9964458326264695, + 0.9936866130906366, + 0.9901456355893141, + 0.9858278282666936, + 0.9807391980963174, + 0.9748868211368796, + 0.9682788310563117, + 0.9609244059440204, + 0.9528337534340876, + 0.9440180941651672, + 0.9344896436056892, + 0.9242615922757944, + 0.913348084400198, + 0.9017641950288744, + 0.8895259056651056, + 0.8766500784429904, + 0.8631544288990163, + 0.8490574973847023, + 0.8343786191696513, + 0.8191378932865928, + 0.8033561501721485, + 0.7870549181591013, + 0.7702563888779096, + 0.7529833816270532, + 0.7352593067735488, + 0.7171081282466044, + 0.6985543251889097, + 0.6796228528314651, + 0.6603391026591464, + }; +}; + +// Inverses of the above. +template <> +struct DCTResampleScales<1, 8> { + static constexpr float kScales[] = { + 1.000000000000000000, + }; +}; + +template <> +struct DCTResampleScales<2, 16> { + static constexpr float kScales[] = { + 1.000000000000000000, + 1.108937353592731823, + }; +}; + +template <> +struct DCTResampleScales<4, 32> { + static constexpr float kScales[] = { + 1.000000000000000000, + 1.025760096781116015, + 1.108937353592731823, + 1.270559368765487251, + }; +}; + +template <> +struct DCTResampleScales<8, 64> { + static constexpr float kScales[] = { + 1.0000000000000000, 1.0063534990068217, 1.0257600967811158, + 1.0593017296817173, 1.1089373535927318, 1.1777765381970435, + 1.2705593687654873, 1.3944898413647777, + }; +}; + +template <> +struct DCTResampleScales<16, 128> { + static constexpr float kScales[] = { + 1.0, + 1.0015830492062623, + 1.0063534990068217, + 1.0143759095928793, + 1.0257600967811158, + 1.0406645869480142, + 1.0593017296817173, + 1.0819447744633812, + 1.1089373535927318, + 1.1407059950032632, + 1.1777765381970435, + 1.2207956782315876, + 1.2705593687654873, + 1.3280505578213306, + 1.3944898413647777, + 1.4714043176061107, + }; +}; + +template <> +struct DCTResampleScales<32, 256> { + static constexpr float kScales[] = { + 1.0, + 1.0003954307206069, + 1.0015830492062623, + 1.0035668445360069, + 1.0063534990068217, + 1.009952439375063, + 1.0143759095928793, + 1.0196390660647288, + 1.0257600967811158, + 1.0327603660498115, + 1.0406645869480142, + 1.049501024072585, + 1.0593017296817173, + 1.0701028169146336, + 1.0819447744633812, + 1.0948728278734026, + 1.1089373535927318, + 1.124194353004584, + 1.1407059950032632, + 1.158541237256391, + 1.1777765381970435, + 1.1984966740820495, + 1.2207956782315876, + 1.244777922949508, + 1.2705593687654873, + 1.2982690107339132, + 1.3280505578213306, + 1.3600643892400104, + 1.3944898413647777, + 1.4315278911623237, + 1.4714043176061107, + 1.5143734423314616, + }; +}; + +// Constants for DCT implementation. Generated by the following snippet: +// for i in range(N // 2): +// print(1.0 / (2 * math.cos((i + 0.5) * math.pi / N)), end=", ") +template <size_t N> +struct WcMultipliers; + +template <> +struct WcMultipliers<4> { + static constexpr float kMultipliers[] = { + 0.541196100146197, + 1.3065629648763764, + }; +}; + +template <> +struct WcMultipliers<8> { + static constexpr float kMultipliers[] = { + 0.5097955791041592, + 0.6013448869350453, + 0.8999762231364156, + 2.5629154477415055, + }; +}; + +template <> +struct WcMultipliers<16> { + static constexpr float kMultipliers[] = { + 0.5024192861881557, 0.5224986149396889, 0.5669440348163577, + 0.6468217833599901, 0.7881546234512502, 1.060677685990347, + 1.7224470982383342, 5.101148618689155, + }; +}; + +template <> +struct WcMultipliers<32> { + static constexpr float kMultipliers[] = { + 0.5006029982351963, 0.5054709598975436, 0.5154473099226246, + 0.5310425910897841, 0.5531038960344445, 0.5829349682061339, + 0.6225041230356648, 0.6748083414550057, 0.7445362710022986, + 0.8393496454155268, 0.9725682378619608, 1.1694399334328847, + 1.4841646163141662, 2.057781009953411, 3.407608418468719, + 10.190008123548033, + }; +}; +template <> +struct WcMultipliers<64> { + static constexpr float kMultipliers[] = { + 0.500150636020651, 0.5013584524464084, 0.5037887256810443, + 0.5074711720725553, 0.5124514794082247, 0.5187927131053328, + 0.52657731515427, 0.535909816907992, 0.5469204379855088, + 0.5597698129470802, 0.57465518403266, 0.5918185358574165, + 0.6115573478825099, 0.6342389366884031, 0.6603198078137061, + 0.6903721282002123, 0.7251205223771985, 0.7654941649730891, + 0.8127020908144905, 0.8683447152233481, 0.9345835970364075, + 1.0144082649970547, 1.1120716205797176, 1.233832737976571, + 1.3892939586328277, 1.5939722833856311, 1.8746759800084078, + 2.282050068005162, 2.924628428158216, 4.084611078129248, + 6.796750711673633, 20.373878167231453, + }; +}; +template <> +struct WcMultipliers<128> { + static constexpr float kMultipliers[] = { + 0.5000376519155477, 0.5003390374428216, 0.5009427176380873, + 0.5018505174842379, 0.5030651913013697, 0.5045904432216454, + 0.5064309549285542, 0.5085924210498143, 0.5110815927066812, + 0.5139063298475396, 0.5170756631334912, 0.5205998663018917, + 0.524490540114724, 0.5287607092074876, 0.5334249333971333, + 0.538499435291984, 0.5440022463817783, 0.549953374183236, + 0.5563749934898856, 0.5632916653417023, 0.5707305880121454, + 0.5787218851348208, 0.5872989370937893, 0.5964987630244563, + 0.606362462272146, 0.6169357260050706, 0.6282694319707711, + 0.6404203382416639, 0.6534518953751283, 0.6674352009263413, + 0.6824501259764195, 0.6985866506472291, 0.7159464549705746, + 0.7346448236478627, 0.7548129391165311, 0.776600658233963, + 0.8001798956216941, 0.8257487738627852, 0.8535367510066064, + 0.8838110045596234, 0.9168844461846523, 0.9531258743921193, + 0.9929729612675466, 1.036949040910389, 1.0856850642580145, + 1.1399486751015042, 1.2006832557294167, 1.2690611716991191, + 1.346557628206286, 1.4350550884414341, 1.5369941008524954, + 1.6555965242641195, 1.7952052190778898, 1.961817848571166, + 2.163957818751979, 2.4141600002500763, 2.7316450287739396, + 3.147462191781909, 3.7152427383269746, 4.5362909369693565, + 5.827688377844654, 8.153848602466814, 13.58429025728446, + 40.744688103351834, + }; +}; + +template <> +struct WcMultipliers<256> { + static constexpr float kMultipliers[128] = { + 0.5000094125358878, 0.500084723455784, 0.5002354020255269, + 0.5004615618093246, 0.5007633734146156, 0.5011410648064231, + 0.5015949217281668, 0.502125288230386, 0.5027325673091954, + 0.5034172216566842, 0.5041797745258774, 0.5050208107132756, + 0.5059409776624396, 0.5069409866925212, 0.5080216143561264, + 0.509183703931388, 0.5104281670536573, 0.5117559854927805, + 0.5131682130825206, 0.5146659778093218, 0.516250484068288, + 0.5179230150949777, 0.5196849355823947, 0.5215376944933958, + 0.5234828280796439, 0.52552196311921, 0.5276568203859896, + 0.5298892183652453, 0.5322210772308335, 0.5346544231010253, + 0.537191392591309, 0.5398342376841637, 0.5425853309375497, + 0.545447171055775, 0.5484223888484947, 0.551513753605893, + 0.554724179920619, 0.5580567349898085, 0.5615146464335654, + 0.5651013106696203, 0.5688203018875696, 0.5726753816701664, + 0.5766705093136241, 0.5808098529038624, 0.5850978012111273, + 0.58953897647151, 0.5941382481306648, 0.5989007476325463, + 0.6038318843443582, 0.6089373627182432, 0.614223200800649, + 0.6196957502119484, 0.6253617177319102, 0.6312281886412079, + 0.6373026519855411, 0.6435930279473415, 0.6501076975307724, + 0.6568555347890955, 0.6638459418498757, 0.6710888870233562, + 0.6785949463131795, 0.6863753486870501, 0.6944420255086364, + 0.7028076645818034, 0.7114857693151208, 0.7204907235796304, + 0.7298378629074134, 0.7395435527641373, 0.749625274727372, + 0.7601017215162176, 0.7709929019493761, 0.7823202570613161, + 0.7941067887834509, 0.8063772028037925, 0.8191580674598145, + 0.83247799080191, 0.8463678182968619, 0.860860854031955, + 0.8759931087426972, 0.8918035785352535, 0.9083345588266809, + 0.9256319988042384, 0.9437459026371479, 0.962730784794803, + 0.9826461881778968, 1.0035572754078206, 1.0255355056139732, + 1.048659411496106, 1.0730154944316674, 1.0986992590905857, + 1.1258164135986009, 1.1544842669978943, 1.184833362908442, + 1.217009397314603, 1.2511754798461228, 1.287514812536712, + 1.326233878832723, 1.3675662599582539, 1.411777227500661, + 1.459169302866857, 1.5100890297227016, 1.5649352798258847, + 1.6241695131835794, 1.6883285509131505, 1.7580406092704062, + 1.8340456094306077, 1.9172211551275689, 2.0086161135167564, + 2.1094945286246385, 2.22139377701127, 2.346202662531156, + 2.486267909203593, 2.644541877144861, 2.824791402350551, + 3.0318994541759925, 3.2723115884254845, 3.5547153325075804, + 3.891107790700307, 4.298537526449054, 4.802076008665048, + 5.440166215091329, 6.274908408039339, 7.413566756422303, + 9.058751453879703, 11.644627325175037, 16.300023088031555, + 27.163977662448232, 81.48784219222516, + }; +}; + +// Apply the DCT algorithm-intrinsic constants to DCTResampleScale. +template <size_t FROM, size_t TO> +constexpr float DCTTotalResampleScale(size_t x) { + return DCTResampleScales<FROM, TO>::kScales[x]; +} + +} // namespace jxl + +#endif // LIB_JXL_DCT_SCALES_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dct_test.cc b/third_party/jpeg-xl/lib/jxl/dct_test.cc new file mode 100644 index 0000000000..e4982e2f45 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_test.cc @@ -0,0 +1,390 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <string.h> + +#include <cmath> +#include <numeric> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dct_test.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> +#include <hwy/tests/hwy_gtest.h> + +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/dct_for_test.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/image.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// Computes the in-place NxN DCT of block. +// Requires that block is HWY_ALIGN'ed. +// +// Performs ComputeTransposedScaledDCT and then transposes and scales it to +// obtain "vanilla" DCT. +template <size_t N> +void ComputeDCT(float block[N * N]) { + HWY_ALIGN float tmp_block[N * N]; + HWY_ALIGN float scratch_space[4 * N * N]; + ComputeScaledDCT<N, N>()(DCTFrom(block, N), tmp_block, scratch_space); + + // Untranspose. + Transpose<N, N>::Run(DCTFrom(tmp_block, N), DCTTo(block, N)); +} + +// Computes the in-place 8x8 iDCT of block. +// Requires that block is HWY_ALIGN'ed. +template <int N> +void ComputeIDCT(float block[N * N]) { + HWY_ALIGN float tmp_block[N * N]; + HWY_ALIGN float scratch_space[4 * N * N]; + // Untranspose. + Transpose<N, N>::Run(DCTFrom(block, N), DCTTo(tmp_block, N)); + + ComputeScaledIDCT<N, N>()(tmp_block, DCTTo(block, N), scratch_space); +} + +template <size_t N> +void TransposeTestT(float accuracy) { + constexpr size_t kBlockSize = N * N; + HWY_ALIGN float src[kBlockSize]; + DCTTo to_src(src, N); + for (size_t y = 0; y < N; ++y) { + for (size_t x = 0; x < N; ++x) { + to_src.Write(y * N + x, y, x); + } + } + HWY_ALIGN float dst[kBlockSize]; + Transpose<N, N>::Run(DCTFrom(src, N), DCTTo(dst, N)); + DCTFrom from_dst(dst, N); + for (size_t y = 0; y < N; ++y) { + for (size_t x = 0; x < N; ++x) { + float expected = x * N + y; + float actual = from_dst.Read(y, x); + EXPECT_NEAR(expected, actual, accuracy) << "x = " << x << ", y = " << y; + } + } +} + +void TransposeTest() { + TransposeTestT<8>(1e-7f); + TransposeTestT<16>(1e-7f); + TransposeTestT<32>(1e-7f); +} + +template <size_t N> +void ColumnDctRoundtripT(float accuracy) { + constexpr size_t kBlockSize = N * N; + // Though we are only interested in single column result, dct.h has built-in + // limit on minimal number of columns processed. So, to be safe, we do + // regular 8x8 block transformation. On the bright side - we could check all + // 8 basis vectors at once. + HWY_ALIGN float block[kBlockSize]; + HWY_ALIGN float scratch[3 * kBlockSize]; + DCTTo to(block, N); + DCTFrom from(block, N); + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < N; ++j) { + to.Write((i == j) ? 1.0f : 0.0f, i, j); + } + } + + // Running (I)DCT on the same memory block seems to trigger a compiler bug on + // ARMv7 with clang6. + HWY_ALIGN float tmp[kBlockSize]; + DCTTo to_tmp(tmp, N); + DCTFrom from_tmp(tmp, N); + + DCT1D<N, N>()(from, to_tmp, scratch); + IDCT1D<N, N>()(from_tmp, to, scratch); + + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < N; ++j) { + float expected = (i == j) ? 1.0f : 0.0f; + float actual = from.Read(i, j); + EXPECT_NEAR(expected, actual, accuracy) << " i=" << i << ", j=" << j; + } + } +} + +void ColumnDctRoundtrip() { + ColumnDctRoundtripT<8>(1e-6f); + ColumnDctRoundtripT<16>(1e-6f); + ColumnDctRoundtripT<32>(1e-6f); +} + +template <size_t N> +void TestDctAccuracy(float accuracy, size_t start = 0, size_t end = N * N) { + constexpr size_t kBlockSize = N * N; + for (size_t i = start; i < end; i++) { + HWY_ALIGN float fast[kBlockSize] = {0.0f}; + double slow[kBlockSize] = {0.0}; + fast[i] = 1.0; + slow[i] = 1.0; + DCTSlow<N>(slow); + ComputeDCT<N>(fast); + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(fast[k], slow[k], accuracy / N) + << "i = " << i << ", k = " << k << ", N = " << N; + } + } +} + +template <size_t N> +void TestIdctAccuracy(float accuracy, size_t start = 0, size_t end = N * N) { + constexpr size_t kBlockSize = N * N; + for (size_t i = start; i < end; i++) { + HWY_ALIGN float fast[kBlockSize] = {0.0f}; + double slow[kBlockSize] = {0.0}; + fast[i] = 1.0; + slow[i] = 1.0; + IDCTSlow<N>(slow); + ComputeIDCT<N>(fast); + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(fast[k], slow[k], accuracy * N) + << "i = " << i << ", k = " << k << ", N = " << N; + } + } +} + +template <size_t N> +void TestInverseT(float accuracy) { + test::ThreadPoolForTests pool(N < 32 ? 0 : 8); + enum { kBlockSize = N * N }; + EXPECT_TRUE(RunOnPool( + &pool, 0, kBlockSize, ThreadPool::NoInit, + [accuracy](const uint32_t task, size_t /*thread*/) { + const size_t i = static_cast<size_t>(task); + HWY_ALIGN float x[kBlockSize] = {0.0f}; + x[i] = 1.0; + + ComputeIDCT<N>(x); + ComputeDCT<N>(x); + + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(x[k], (k == i) ? 1.0f : 0.0f, accuracy) + << "i = " << i << ", k = " << k; + } + }, + "TestInverse")); +} + +void InverseTest() { + TestInverseT<8>(1e-6f); + TestInverseT<16>(1e-6f); + TestInverseT<32>(3e-6f); +} + +template <size_t N> +void TestDctTranspose(float accuracy, size_t start = 0, size_t end = N * N) { + constexpr size_t kBlockSize = N * N; + for (size_t i = start; i < end; i++) { + for (size_t j = 0; j < kBlockSize; ++j) { + // We check that <e_i, Me_j> = <M^\dagger{}e_i, e_j>. + // That means (Me_j)_i = (M^\dagger{}e_i)_j + + // x := Me_j + HWY_ALIGN float x[kBlockSize] = {0.0f}; + x[j] = 1.0; + ComputeIDCT<N>(x); + // y := M^\dagger{}e_i + HWY_ALIGN float y[kBlockSize] = {0.0f}; + y[i] = 1.0; + ComputeDCT<N>(y); + + EXPECT_NEAR(x[i] / N, y[j] * N, accuracy) << "i = " << i << ", j = " << j; + } + } +} + +template <size_t N> +void TestSlowInverse(float accuracy, size_t start = 0, size_t end = N * N) { + constexpr size_t kBlockSize = N * N; + for (size_t i = start; i < end; i++) { + double x[kBlockSize] = {0.0f}; + x[i] = 1.0; + + DCTSlow<N>(x); + IDCTSlow<N>(x); + + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(x[k], (k == i) ? 1.0f : 0.0f, accuracy) + << "i = " << i << ", k = " << k; + } + } +} + +template <size_t ROWS, size_t COLS> +void TestRectInverseT(float accuracy) { + constexpr size_t kBlockSize = ROWS * COLS; + for (size_t i = 0; i < kBlockSize; ++i) { + HWY_ALIGN float x[kBlockSize] = {0.0f}; + HWY_ALIGN float out[kBlockSize] = {0.0f}; + x[i] = 1.0; + HWY_ALIGN float coeffs[kBlockSize] = {0.0f}; + HWY_ALIGN float scratch_space[kBlockSize * 5]; + + ComputeScaledDCT<ROWS, COLS>()(DCTFrom(x, COLS), coeffs, scratch_space); + ComputeScaledIDCT<ROWS, COLS>()(coeffs, DCTTo(out, COLS), scratch_space); + + for (size_t k = 0; k < kBlockSize; ++k) { + EXPECT_NEAR(out[k], (k == i) ? 1.0f : 0.0f, accuracy) + << "i = " << i << ", k = " << k << " ROWS = " << ROWS + << " COLS = " << COLS; + } + } +} + +void TestRectInverse() { + TestRectInverseT<16, 32>(1e-6f); + TestRectInverseT<8, 32>(1e-6f); + TestRectInverseT<8, 16>(1e-6f); + TestRectInverseT<4, 8>(1e-6f); + TestRectInverseT<2, 4>(1e-6f); + TestRectInverseT<1, 4>(1e-6f); + TestRectInverseT<1, 2>(1e-6f); + + TestRectInverseT<32, 16>(1e-6f); + TestRectInverseT<32, 8>(1e-6f); + TestRectInverseT<16, 8>(1e-6f); + TestRectInverseT<8, 4>(1e-6f); + TestRectInverseT<4, 2>(1e-6f); + TestRectInverseT<4, 1>(1e-6f); + TestRectInverseT<2, 1>(1e-6f); +} + +template <size_t ROWS, size_t COLS> +void TestRectTransposeT(float accuracy) { + constexpr size_t kBlockSize = ROWS * COLS; + HWY_ALIGN float scratch_space[kBlockSize * 5]; + for (size_t px = 0; px < COLS; ++px) { + for (size_t py = 0; py < ROWS; ++py) { + HWY_ALIGN float x1[kBlockSize] = {0.0f}; + HWY_ALIGN float x2[kBlockSize] = {0.0f}; + HWY_ALIGN float coeffs1[kBlockSize] = {0.0f}; + HWY_ALIGN float coeffs2[kBlockSize] = {0.0f}; + x1[py * COLS + px] = 1; + x2[px * ROWS + py] = 1; + + constexpr size_t OUT_ROWS = ROWS < COLS ? ROWS : COLS; + constexpr size_t OUT_COLS = ROWS < COLS ? COLS : ROWS; + + ComputeScaledDCT<ROWS, COLS>()(DCTFrom(x1, COLS), coeffs1, scratch_space); + ComputeScaledDCT<COLS, ROWS>()(DCTFrom(x2, ROWS), coeffs2, scratch_space); + + for (size_t x = 0; x < OUT_COLS; ++x) { + for (size_t y = 0; y < OUT_ROWS; ++y) { + EXPECT_NEAR(coeffs1[y * OUT_COLS + x], coeffs2[y * OUT_COLS + x], + accuracy) + << " px = " << px << ", py = " << py << ", x = " << x + << ", y = " << y; + } + } + } + } +} + +void TestRectTranspose() { + TestRectTransposeT<16, 32>(1e-6f); + TestRectTransposeT<8, 32>(1e-6f); + TestRectTransposeT<8, 16>(1e-6f); + TestRectTransposeT<4, 8>(1e-6f); + TestRectTransposeT<2, 4>(1e-6f); + TestRectTransposeT<1, 4>(1e-6f); + TestRectTransposeT<1, 2>(1e-6f); + + // Identical to 8, 16 + // TestRectTranspose<16, 8>(1e-6f); +} + +void TestDctAccuracyShard(size_t shard) { + if (shard == 0) { + TestDctAccuracy<1>(1.1E-7f); + TestDctAccuracy<2>(1.1E-7f); + TestDctAccuracy<4>(1.1E-7f); + TestDctAccuracy<8>(1.1E-7f); + TestDctAccuracy<16>(1.3E-7f); + } + TestDctAccuracy<32>(1.1E-7f, 32 * shard, 32 * (shard + 1)); +} + +void TestIdctAccuracyShard(size_t shard) { + if (shard == 0) { + TestIdctAccuracy<1>(1E-7f); + TestIdctAccuracy<2>(1E-7f); + TestIdctAccuracy<4>(1E-7f); + TestIdctAccuracy<8>(1E-7f); + TestIdctAccuracy<16>(1E-7f); + } + TestIdctAccuracy<32>(1E-7f, 32 * shard, 32 * (shard + 1)); +} + +void TestDctTransposeShard(size_t shard) { + if (shard == 0) { + TestDctTranspose<8>(1E-6f); + TestDctTranspose<16>(1E-6f); + } + TestDctTranspose<32>(3E-6f, 32 * shard, 32 * (shard + 1)); +} + +void TestSlowInverseShard(size_t shard) { + if (shard == 0) { + TestSlowInverse<1>(1E-5f); + TestSlowInverse<2>(1E-5f); + TestSlowInverse<4>(1E-5f); + TestSlowInverse<8>(1E-5f); + TestSlowInverse<16>(1E-5f); + } + TestSlowInverse<32>(1E-5f, 32 * shard, 32 * (shard + 1)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class TransposeTest : public hwy::TestWithParamTarget {}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(TransposeTest); + +HWY_EXPORT_AND_TEST_P(TransposeTest, TransposeTest); +HWY_EXPORT_AND_TEST_P(TransposeTest, InverseTest); +HWY_EXPORT_AND_TEST_P(TransposeTest, ColumnDctRoundtrip); +HWY_EXPORT_AND_TEST_P(TransposeTest, TestRectInverse); +HWY_EXPORT_AND_TEST_P(TransposeTest, TestRectTranspose); + +// Tests in the DctShardedTest class are sharded for N=32. +class DctShardedTest : public ::hwy::TestWithParamTargetAndT<uint32_t> {}; + +std::vector<uint32_t> ShardRange(uint32_t n) { +#ifdef JXL_DISABLE_SLOW_TESTS + JXL_ASSERT(n > 6); + std::vector<uint32_t> ret = {0, 1, 3, 5, n - 1}; +#else + std::vector<uint32_t> ret(n); + std::iota(ret.begin(), ret.end(), 0); +#endif // JXL_DISABLE_SLOW_TESTS + return ret; +} + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P_T(DctShardedTest, + ::testing::ValuesIn(ShardRange(32))); + +HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestDctAccuracyShard); +HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestIdctAccuracyShard); +HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestDctTransposeShard); +HWY_EXPORT_AND_TEST_P_T(DctShardedTest, TestSlowInverseShard); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dct_util.h b/third_party/jpeg-xl/lib/jxl/dct_util.h new file mode 100644 index 0000000000..2f29449677 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dct_util.h @@ -0,0 +1,85 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DCT_UTIL_H_ +#define LIB_JXL_DCT_UTIL_H_ + +#include <stddef.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +union ACPtr { + int32_t* ptr32; + int16_t* ptr16; + ACPtr() = default; + explicit ACPtr(int16_t* p) : ptr16(p) {} + explicit ACPtr(int32_t* p) : ptr32(p) {} +}; + +union ConstACPtr { + const int32_t* ptr32; + const int16_t* ptr16; + ConstACPtr() = default; + explicit ConstACPtr(const int16_t* p) : ptr16(p) {} + explicit ConstACPtr(const int32_t* p) : ptr32(p) {} +}; + +enum class ACType { k16 = 0, k32 = 1 }; + +class ACImage { + public: + virtual ~ACImage() = default; + virtual ACType Type() const = 0; + virtual ACPtr PlaneRow(size_t c, size_t y, size_t xbase) = 0; + virtual ConstACPtr PlaneRow(size_t c, size_t y, size_t xbase) const = 0; + virtual size_t PixelsPerRow() const = 0; + virtual void ZeroFill() = 0; + virtual void ZeroFillPlane(size_t c) = 0; + virtual bool IsEmpty() const = 0; +}; + +template <typename T> +class ACImageT final : public ACImage { + public: + ACImageT() = default; + ACImageT(size_t xsize, size_t ysize) { + static_assert( + std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value, + "ACImage must be either 32- or 16- bit"); + img_ = Image3<T>(xsize, ysize); + } + ACType Type() const override { + return sizeof(T) == 2 ? ACType::k16 : ACType::k32; + } + ACPtr PlaneRow(size_t c, size_t y, size_t xbase) override { + return ACPtr(img_.PlaneRow(c, y) + xbase); + } + ConstACPtr PlaneRow(size_t c, size_t y, size_t xbase) const override { + return ConstACPtr(img_.PlaneRow(c, y) + xbase); + } + + size_t PixelsPerRow() const override { return img_.PixelsPerRow(); } + + void ZeroFill() override { ZeroFillImage(&img_); } + + void ZeroFillPlane(size_t c) override { ZeroFillImage(&img_.Plane(c)); } + + bool IsEmpty() const override { + return img_.xsize() == 0 || img_.ysize() == 0; + } + + private: + Image3<T> img_; +}; + +} // namespace jxl + +#endif // LIB_JXL_DCT_UTIL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_ans.cc b/third_party/jpeg-xl/lib/jxl/dec_ans.cc new file mode 100644 index 0000000000..29d41c8062 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_ans.cc @@ -0,0 +1,363 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_ans.h" + +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/ans_common.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_context_map.h" +#include "lib/jxl/fields.h" + +namespace jxl { +namespace { + +// Decodes a number in the range [0..255], by reading 1 - 11 bits. +inline int DecodeVarLenUint8(BitReader* input) { + if (input->ReadFixedBits<1>()) { + int nbits = static_cast<int>(input->ReadFixedBits<3>()); + if (nbits == 0) { + return 1; + } else { + return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits); + } + } + return 0; +} + +// Decodes a number in the range [0..65535], by reading 1 - 21 bits. +inline int DecodeVarLenUint16(BitReader* input) { + if (input->ReadFixedBits<1>()) { + int nbits = static_cast<int>(input->ReadFixedBits<4>()); + if (nbits == 0) { + return 1; + } else { + return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits); + } + } + return 0; +} + +Status ReadHistogram(int precision_bits, std::vector<int32_t>* counts, + BitReader* input) { + int simple_code = input->ReadBits(1); + if (simple_code == 1) { + int i; + int symbols[2] = {0}; + int max_symbol = 0; + const int num_symbols = input->ReadBits(1) + 1; + for (i = 0; i < num_symbols; ++i) { + symbols[i] = DecodeVarLenUint8(input); + if (symbols[i] > max_symbol) max_symbol = symbols[i]; + } + counts->resize(max_symbol + 1); + if (num_symbols == 1) { + (*counts)[symbols[0]] = 1 << precision_bits; + } else { + if (symbols[0] == symbols[1]) { // corrupt data + return false; + } + (*counts)[symbols[0]] = input->ReadBits(precision_bits); + (*counts)[symbols[1]] = (1 << precision_bits) - (*counts)[symbols[0]]; + } + } else { + int is_flat = input->ReadBits(1); + if (is_flat == 1) { + int alphabet_size = DecodeVarLenUint8(input) + 1; + *counts = CreateFlatHistogram(alphabet_size, 1 << precision_bits); + return true; + } + + uint32_t shift; + { + // TODO(veluca): speed up reading with table lookups. + int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1); + int log = 0; + for (; log < upper_bound_log; log++) { + if (input->ReadFixedBits<1>() == 0) break; + } + shift = (input->ReadBits(log) | (1 << log)) - 1; + if (shift > ANS_LOG_TAB_SIZE + 1) { + return JXL_FAILURE("Invalid shift value"); + } + } + + int length = DecodeVarLenUint8(input) + 3; + counts->resize(length); + int total_count = 0; + + static const uint8_t huff[128][2] = { + {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, + {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, + }; + + std::vector<int> logcounts(counts->size()); + int omit_log = -1; + int omit_pos = -1; + // This array remembers which symbols have an RLE length. + std::vector<int> same(counts->size(), 0); + for (size_t i = 0; i < logcounts.size(); ++i) { + input->Refill(); // for PeekFixedBits + Advance + int idx = input->PeekFixedBits<7>(); + input->Consume(huff[idx][0]); + logcounts[i] = huff[idx][1]; + // The RLE symbol. + if (logcounts[i] == ANS_LOG_TAB_SIZE + 1) { + int rle_length = DecodeVarLenUint8(input); + same[i] = rle_length + 5; + i += rle_length + 3; + continue; + } + if (logcounts[i] > omit_log) { + omit_log = logcounts[i]; + omit_pos = i; + } + } + // Invalid input, e.g. due to invalid usage of RLE. + if (omit_pos < 0) return JXL_FAILURE("Invalid histogram."); + if (static_cast<size_t>(omit_pos) + 1 < logcounts.size() && + logcounts[omit_pos + 1] == ANS_TAB_SIZE + 1) { + return JXL_FAILURE("Invalid histogram."); + } + int prev = 0; + int numsame = 0; + for (size_t i = 0; i < logcounts.size(); ++i) { + if (same[i]) { + // RLE sequence, let this loop output the same count for the next + // iterations. + numsame = same[i] - 1; + prev = i > 0 ? (*counts)[i - 1] : 0; + } + if (numsame > 0) { + (*counts)[i] = prev; + numsame--; + } else { + int code = logcounts[i]; + // omit_pos may not be negative at this point (checked before). + if (i == static_cast<size_t>(omit_pos)) { + continue; + } else if (code == 0) { + continue; + } else if (code == 1) { + (*counts)[i] = 1; + } else { + int bitcount = GetPopulationCountPrecision(code - 1, shift); + (*counts)[i] = (1 << (code - 1)) + + (input->ReadBits(bitcount) << (code - 1 - bitcount)); + } + } + total_count += (*counts)[i]; + } + (*counts)[omit_pos] = (1 << precision_bits) - total_count; + if ((*counts)[omit_pos] <= 0) { + // The histogram we've read sums to more than total_count (including at + // least 1 for the omitted value). + return JXL_FAILURE("Invalid histogram count."); + } + } + return true; +} + +} // namespace + +Status DecodeANSCodes(const size_t num_histograms, + const size_t max_alphabet_size, BitReader* in, + ANSCode* result) { + result->degenerate_symbols.resize(num_histograms, -1); + if (result->use_prefix_code) { + JXL_ASSERT(max_alphabet_size <= 1 << PREFIX_MAX_BITS); + result->huffman_data.resize(num_histograms); + std::vector<uint16_t> alphabet_sizes(num_histograms); + for (size_t c = 0; c < num_histograms; c++) { + alphabet_sizes[c] = DecodeVarLenUint16(in) + 1; + if (alphabet_sizes[c] > max_alphabet_size) { + return JXL_FAILURE("Alphabet size is too long: %u", alphabet_sizes[c]); + } + } + for (size_t c = 0; c < num_histograms; c++) { + if (alphabet_sizes[c] > 1) { + if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) { + if (!in->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for huffman code"); + } + return JXL_FAILURE("Invalid huffman tree number %" PRIuS + ", alphabet size %u", + c, alphabet_sizes[c]); + } + } else { + // 0-bit codes does not require extension tables. + result->huffman_data[c].table_.clear(); + result->huffman_data[c].table_.resize(1u << kHuffmanTableBits); + } + for (const auto& h : result->huffman_data[c].table_) { + if (h.bits <= kHuffmanTableBits) { + result->UpdateMaxNumBits(c, h.value); + } + } + } + } else { + JXL_ASSERT(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE); + result->alias_tables = + AllocateArray(num_histograms * (1 << result->log_alpha_size) * + sizeof(AliasTable::Entry)); + AliasTable::Entry* alias_tables = + reinterpret_cast<AliasTable::Entry*>(result->alias_tables.get()); + for (size_t c = 0; c < num_histograms; ++c) { + std::vector<int32_t> counts; + if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) { + return JXL_FAILURE("Invalid histogram bitstream."); + } + if (counts.size() > max_alphabet_size) { + return JXL_FAILURE("Alphabet size is too long: %" PRIuS, counts.size()); + } + while (!counts.empty() && counts.back() == 0) { + counts.pop_back(); + } + for (size_t s = 0; s < counts.size(); s++) { + if (counts[s] != 0) { + result->UpdateMaxNumBits(c, s); + } + } + // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol. + int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1); + for (int s = 0; s < degenerate_symbol; ++s) { + if (counts[s] != 0) { + degenerate_symbol = -1; + break; + } + } + result->degenerate_symbols[c] = degenerate_symbol; + InitAliasTable(counts, ANS_TAB_SIZE, result->log_alpha_size, + alias_tables + c * (1 << result->log_alpha_size)); + } + } + return true; +} +Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config, + BitReader* br) { + br->Refill(); + size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1)); + size_t msb_in_token = 0, lsb_in_token = 0; + if (split_exponent != log_alpha_size) { + // otherwise, msb/lsb don't matter. + size_t nbits = CeilLog2Nonzero(split_exponent + 1); + msb_in_token = br->ReadBits(nbits); + if (msb_in_token > split_exponent) { + // This could be invalid here already and we need to check this before + // we use its value to read more bits. + return JXL_FAILURE("Invalid HybridUintConfig"); + } + nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1); + lsb_in_token = br->ReadBits(nbits); + } + if (lsb_in_token + msb_in_token > split_exponent) { + return JXL_FAILURE("Invalid HybridUintConfig"); + } + *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token); + return true; +} + +Status DecodeUintConfigs(size_t log_alpha_size, + std::vector<HybridUintConfig>* uint_config, + BitReader* br) { + // TODO(veluca): RLE? + for (size_t i = 0; i < uint_config->size(); i++) { + JXL_RETURN_IF_ERROR( + DecodeUintConfig(log_alpha_size, &(*uint_config)[i], br)); + } + return true; +} + +LZ77Params::LZ77Params() { Bundle::Init(this); } +Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled)); + if (!visitor->Conditional(enabled)) return true; + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096), + BitsOffset(15, 8), 224, &min_symbol)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5), + BitsOffset(8, 9), 3, &min_length)); + return true; +} + +void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) { + HybridUintConfig* cfg = &uint_config[ctx]; + // LZ77 symbols use a different uint config. + if (lz77.enabled && lz77.nonserialized_distance_context != ctx && + symbol >= lz77.min_symbol) { + symbol -= lz77.min_symbol; + cfg = &lz77.length_uint_config; + } + size_t split_token = cfg->split_token; + size_t msb_in_token = cfg->msb_in_token; + size_t lsb_in_token = cfg->lsb_in_token; + size_t split_exponent = cfg->split_exponent; + if (symbol < split_token) { + max_num_bits = std::max(max_num_bits, split_exponent); + return; + } + uint32_t n_extra_bits = + split_exponent - (msb_in_token + lsb_in_token) + + ((symbol - split_token) >> (msb_in_token + lsb_in_token)); + size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1; + max_num_bits = std::max(max_num_bits, total_bits); +} + +Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, + std::vector<uint8_t>* context_map, bool disallow_lz77) { + JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77)); + if (code->lz77.enabled) { + num_contexts++; + JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8, + &code->lz77.length_uint_config, br)); + } + if (code->lz77.enabled && disallow_lz77) { + return JXL_FAILURE("Using LZ77 when explicitly disallowed"); + } + size_t num_histograms = 1; + context_map->resize(num_contexts); + if (num_contexts > 1) { + JXL_RETURN_IF_ERROR(DecodeContextMap(context_map, &num_histograms, br)); + } + JXL_DEBUG_V( + 4, "Decoded context map of size %" PRIuS " and %" PRIuS " histograms", + num_contexts, num_histograms); + code->lz77.nonserialized_distance_context = context_map->back(); + code->use_prefix_code = br->ReadFixedBits<1>(); + if (code->use_prefix_code) { + code->log_alpha_size = PREFIX_MAX_BITS; + } else { + code->log_alpha_size = br->ReadFixedBits<2>() + 5; + } + code->uint_config.resize(num_histograms); + JXL_RETURN_IF_ERROR( + DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br)); + const size_t max_alphabet_size = 1 << code->log_alpha_size; + JXL_RETURN_IF_ERROR( + DecodeANSCodes(num_histograms, max_alphabet_size, br, code)); + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_ans.h b/third_party/jpeg-xl/lib/jxl/dec_ans.h new file mode 100644 index 0000000000..57faad25a7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_ans.h @@ -0,0 +1,505 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_ANS_H_ +#define LIB_JXL_DEC_ANS_H_ + +// Library to decode the ANS population counts from the bit-stream and build a +// decoding table from them. + +#include <stddef.h> +#include <stdint.h> + +#include <cstring> +#include <vector> + +#include "lib/jxl/ans_common.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/cache_aligned.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_huffman.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +class ANSSymbolReader; + +// Experiments show that best performance is typically achieved for a +// split-exponent of 3 or 4. Trend seems to be that '4' is better +// for large-ish pictures, and '3' better for rather small-ish pictures. +// This is plausible - the more special symbols we have, the better +// statistics we need to get a benefit out of them. + +// Our hybrid-encoding scheme has dedicated tokens for the smallest +// (1 << split_exponents) numbers, and for the rest +// encodes (number of bits) + (msb_in_token sub-leading binary digits) + +// (lsb_in_token lowest binary digits) in the token, with the remaining bits +// then being encoded as data. +// +// Example with split_exponent = 4, msb_in_token = 2, lsb_in_token = 0. +// +// Numbers N in [0 .. 15]: +// These get represented as (token=N, bits=''). +// Numbers N >= 16: +// If n is such that 2**n <= N < 2**(n+1), +// and m = N - 2**n is the 'mantissa', +// these get represented as: +// (token=split_token + +// ((n - split_exponent) * 4) + +// (m >> (n - msb_in_token)), +// bits=m & (1 << (n - msb_in_token)) - 1) +// Specifically, we would get: +// N = 0 - 15: (token=N, nbits=0, bits='') +// N = 16 (10000): (token=16, nbits=2, bits='00') +// N = 17 (10001): (token=16, nbits=2, bits='01') +// N = 20 (10100): (token=17, nbits=2, bits='00') +// N = 24 (11000): (token=18, nbits=2, bits='00') +// N = 28 (11100): (token=19, nbits=2, bits='00') +// N = 32 (100000): (token=20, nbits=3, bits='000') +// N = 65535: (token=63, nbits=13, bits='1111111111111') +struct HybridUintConfig { + uint32_t split_exponent; + uint32_t split_token; + uint32_t msb_in_token; + uint32_t lsb_in_token; + JXL_INLINE void Encode(uint32_t value, uint32_t* JXL_RESTRICT token, + uint32_t* JXL_RESTRICT nbits, + uint32_t* JXL_RESTRICT bits) const { + if (value < split_token) { + *token = value; + *nbits = 0; + *bits = 0; + } else { + uint32_t n = FloorLog2Nonzero(value); + uint32_t m = value - (1 << n); + *token = split_token + + ((n - split_exponent) << (msb_in_token + lsb_in_token)) + + ((m >> (n - msb_in_token)) << lsb_in_token) + + (m & ((1 << lsb_in_token) - 1)); + *nbits = n - msb_in_token - lsb_in_token; + *bits = (value >> lsb_in_token) & ((1UL << *nbits) - 1); + } + } + + explicit HybridUintConfig(uint32_t split_exponent = 4, + uint32_t msb_in_token = 2, + uint32_t lsb_in_token = 0) + : split_exponent(split_exponent), + split_token(1 << split_exponent), + msb_in_token(msb_in_token), + lsb_in_token(lsb_in_token) { + JXL_DASSERT(split_exponent >= msb_in_token + lsb_in_token); + } +}; + +struct LZ77Params : public Fields { + LZ77Params(); + JXL_FIELDS_NAME(LZ77Params) + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + bool enabled; + + // Symbols above min_symbol use a special hybrid uint encoding and + // represent a length, to be added to min_length. + uint32_t min_symbol; + uint32_t min_length; + + // Not serialized by VisitFields. + HybridUintConfig length_uint_config{0, 0, 0}; + + size_t nonserialized_distance_context; +}; + +static constexpr size_t kWindowSize = 1 << 20; +static constexpr size_t kNumSpecialDistances = 120; +// Table of special distance codes from WebP lossless. +static constexpr int8_t kSpecialDistances[kNumSpecialDistances][2] = { + {0, 1}, {1, 0}, {1, 1}, {-1, 1}, {0, 2}, {2, 0}, {1, 2}, {-1, 2}, + {2, 1}, {-2, 1}, {2, 2}, {-2, 2}, {0, 3}, {3, 0}, {1, 3}, {-1, 3}, + {3, 1}, {-3, 1}, {2, 3}, {-2, 3}, {3, 2}, {-3, 2}, {0, 4}, {4, 0}, + {1, 4}, {-1, 4}, {4, 1}, {-4, 1}, {3, 3}, {-3, 3}, {2, 4}, {-2, 4}, + {4, 2}, {-4, 2}, {0, 5}, {3, 4}, {-3, 4}, {4, 3}, {-4, 3}, {5, 0}, + {1, 5}, {-1, 5}, {5, 1}, {-5, 1}, {2, 5}, {-2, 5}, {5, 2}, {-5, 2}, + {4, 4}, {-4, 4}, {3, 5}, {-3, 5}, {5, 3}, {-5, 3}, {0, 6}, {6, 0}, + {1, 6}, {-1, 6}, {6, 1}, {-6, 1}, {2, 6}, {-2, 6}, {6, 2}, {-6, 2}, + {4, 5}, {-4, 5}, {5, 4}, {-5, 4}, {3, 6}, {-3, 6}, {6, 3}, {-6, 3}, + {0, 7}, {7, 0}, {1, 7}, {-1, 7}, {5, 5}, {-5, 5}, {7, 1}, {-7, 1}, + {4, 6}, {-4, 6}, {6, 4}, {-6, 4}, {2, 7}, {-2, 7}, {7, 2}, {-7, 2}, + {3, 7}, {-3, 7}, {7, 3}, {-7, 3}, {5, 6}, {-5, 6}, {6, 5}, {-6, 5}, + {8, 0}, {4, 7}, {-4, 7}, {7, 4}, {-7, 4}, {8, 1}, {8, 2}, {6, 6}, + {-6, 6}, {8, 3}, {5, 7}, {-5, 7}, {7, 5}, {-7, 5}, {8, 4}, {6, 7}, + {-6, 7}, {7, 6}, {-7, 6}, {8, 5}, {7, 7}, {-7, 7}, {8, 6}, {8, 7}}; + +struct ANSCode { + CacheAlignedUniquePtr alias_tables; + std::vector<HuffmanDecodingData> huffman_data; + std::vector<HybridUintConfig> uint_config; + std::vector<int> degenerate_symbols; + bool use_prefix_code; + uint8_t log_alpha_size; // for ANS. + LZ77Params lz77; + // Maximum number of bits necessary to represent the result of a + // ReadHybridUint call done with this ANSCode. + size_t max_num_bits = 0; + void UpdateMaxNumBits(size_t ctx, size_t symbol); +}; + +class ANSSymbolReader { + public: + // Invalid symbol reader, to be overwritten. + ANSSymbolReader() = default; + ANSSymbolReader(const ANSCode* code, BitReader* JXL_RESTRICT br, + size_t distance_multiplier = 0) + : alias_tables_( + reinterpret_cast<AliasTable::Entry*>(code->alias_tables.get())), + huffman_data_(code->huffman_data.data()), + use_prefix_code_(code->use_prefix_code), + configs(code->uint_config.data()) { + if (!use_prefix_code_) { + state_ = static_cast<uint32_t>(br->ReadFixedBits<32>()); + log_alpha_size_ = code->log_alpha_size; + log_entry_size_ = ANS_LOG_TAB_SIZE - code->log_alpha_size; + entry_size_minus_1_ = (1 << log_entry_size_) - 1; + } else { + state_ = (ANS_SIGNATURE << 16u); + } + if (!code->lz77.enabled) return; + // a std::vector incurs unacceptable decoding speed loss because of + // initialization. + lz77_window_storage_ = AllocateArray(kWindowSize * sizeof(uint32_t)); + lz77_window_ = reinterpret_cast<uint32_t*>(lz77_window_storage_.get()); + lz77_ctx_ = code->lz77.nonserialized_distance_context; + lz77_length_uint_ = code->lz77.length_uint_config; + lz77_threshold_ = code->lz77.min_symbol; + lz77_min_length_ = code->lz77.min_length; + num_special_distances_ = + distance_multiplier == 0 ? 0 : kNumSpecialDistances; + for (size_t i = 0; i < num_special_distances_; i++) { + int dist = kSpecialDistances[i][0]; + dist += static_cast<int>(distance_multiplier) * kSpecialDistances[i][1]; + if (dist < 1) dist = 1; + special_distances_[i] = dist; + } + } + + JXL_INLINE size_t ReadSymbolANSWithoutRefill(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); + + const AliasTable::Entry* table = + &alias_tables_[histo_idx << log_alpha_size_]; + const AliasTable::Symbol symbol = + AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); + state_ = symbol.freq * (state_ >> ANS_LOG_TAB_SIZE) + symbol.offset; + +#if 1 + // Branchless version is about equally fast on SKX. + const uint32_t new_state = + (state_ << 16u) | static_cast<uint32_t>(br->PeekFixedBits<16>()); + const bool normalize = state_ < (1u << 16u); + state_ = normalize ? new_state : state_; + br->Consume(normalize ? 16 : 0); +#else + if (JXL_UNLIKELY(state_ < (1u << 16u))) { + state_ = (state_ << 16u) | br->PeekFixedBits<16>(); + br->Consume(16); + } +#endif + const uint32_t next_res = state_ & (ANS_TAB_SIZE - 1u); + AliasTable::Prefetch(table, next_res, log_entry_size_); + + return symbol.value; + } + + JXL_INLINE size_t ReadSymbolHuffWithoutRefill(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + return huffman_data_[histo_idx].ReadSymbol(br); + } + + JXL_INLINE size_t ReadSymbolWithoutRefill(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + // TODO(veluca): hoist if in hotter loops. + if (JXL_UNLIKELY(use_prefix_code_)) { + return ReadSymbolHuffWithoutRefill(histo_idx, br); + } + return ReadSymbolANSWithoutRefill(histo_idx, br); + } + + JXL_INLINE size_t ReadSymbol(const size_t histo_idx, + BitReader* JXL_RESTRICT br) { + br->Refill(); + return ReadSymbolWithoutRefill(histo_idx, br); + } + +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + bool CheckANSFinalState() const { return true; } +#else + bool CheckANSFinalState() const { return state_ == (ANS_SIGNATURE << 16u); } +#endif + + template <typename BitReader> + static JXL_INLINE uint32_t ReadHybridUintConfig( + const HybridUintConfig& config, size_t token, BitReader* br) { + size_t split_token = config.split_token; + size_t msb_in_token = config.msb_in_token; + size_t lsb_in_token = config.lsb_in_token; + size_t split_exponent = config.split_exponent; + // Fast-track version of hybrid integer decoding. + if (token < split_token) return token; + uint32_t nbits = split_exponent - (msb_in_token + lsb_in_token) + + ((token - split_token) >> (msb_in_token + lsb_in_token)); + // Max amount of bits for ReadBits is 32 and max valid left shift is 29 + // bits. However, for speed no error is propagated here, instead limit the + // nbits size. If nbits > 29, the code stream is invalid, but no error is + // returned. + // Note that in most cases we will emit an error if the histogram allows + // representing numbers that would cause invalid shifts, but we need to + // keep this check as when LZ77 is enabled it might make sense to have an + // histogram that could in principle cause invalid shifts. + nbits &= 31u; + uint32_t low = token & ((1 << lsb_in_token) - 1); + token >>= lsb_in_token; + const size_t bits = br->PeekBits(nbits); + br->Consume(nbits); + size_t ret = (((((1 << msb_in_token) | (token & ((1 << msb_in_token) - 1))) + << nbits) | + bits) + << lsb_in_token) | + low; + // TODO(eustas): mark BitReader as unhealthy if nbits > 29 or ret does not + // fit uint32_t + return static_cast<uint32_t>(ret); + } + + // Takes a *clustered* idx. Can only use if HuffRleOnly() is true. + JXL_INLINE void ReadHybridUintClusteredHuffRleOnly(size_t ctx, + BitReader* JXL_RESTRICT br, + uint32_t* value, + uint32_t* run) { + JXL_DASSERT(HuffRleOnly()); + br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits + size_t token = ReadSymbolHuffWithoutRefill(ctx, br); + if (JXL_UNLIKELY(token >= lz77_threshold_)) { + *run = + ReadHybridUintConfig(lz77_length_uint_, token - lz77_threshold_, br) + + lz77_min_length_ - 1; + return; + } + *value = ReadHybridUintConfig(configs[ctx], token, br); + } + bool HuffRleOnly() { + if (lz77_window_ == nullptr) return false; + if (!use_prefix_code_) return false; + for (size_t i = 0; i < kHuffmanTableBits; i++) { + if (huffman_data_[lz77_ctx_].table_[i].bits) return false; + if (huffman_data_[lz77_ctx_].table_[i].value != 1) return false; + } + if (configs[lz77_ctx_].split_token > 1) return false; + return true; + } + bool UsesLZ77() { return lz77_window_ != nullptr; } + + // Takes a *clustered* idx. Inlined, for use in hot paths. + template <bool uses_lz77> + JXL_INLINE size_t ReadHybridUintClusteredInlined(size_t ctx, + BitReader* JXL_RESTRICT br) { + if (uses_lz77) { + if (JXL_UNLIKELY(num_to_copy_ > 0)) { + size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; + num_to_copy_--; + lz77_window_[(num_decoded_++) & kWindowMask] = ret; + return ret; + } + } + + br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits + size_t token = ReadSymbolWithoutRefill(ctx, br); + if (uses_lz77) { + if (JXL_UNLIKELY(token >= lz77_threshold_)) { + num_to_copy_ = ReadHybridUintConfig(lz77_length_uint_, + token - lz77_threshold_, br) + + lz77_min_length_; + br->Refill(); // covers ReadSymbolWithoutRefill + PeekBits + // Distance code. + size_t token = ReadSymbolWithoutRefill(lz77_ctx_, br); + size_t distance = ReadHybridUintConfig(configs[lz77_ctx_], token, br); + if (JXL_LIKELY(distance < num_special_distances_)) { + distance = special_distances_[distance]; + } else { + distance = distance + 1 - num_special_distances_; + } + if (JXL_UNLIKELY(distance > num_decoded_)) { + distance = num_decoded_; + } + if (JXL_UNLIKELY(distance > kWindowSize)) { + distance = kWindowSize; + } + copy_pos_ = num_decoded_ - distance; + if (JXL_UNLIKELY(distance == 0)) { + JXL_DASSERT(lz77_window_ != nullptr); + // distance 0 -> num_decoded_ == copy_pos_ == 0 + size_t to_fill = std::min<size_t>(num_to_copy_, kWindowSize); + memset(lz77_window_, 0, to_fill * sizeof(lz77_window_[0])); + } + // TODO(eustas): overflow; mark BitReader as unhealthy + if (num_to_copy_ < lz77_min_length_) return 0; + // the code below is the same as doing this: + // return ReadHybridUintClustered<uses_lz77>(ctx, br); + // but gcc doesn't like recursive inlining + + size_t ret = lz77_window_[(copy_pos_++) & kWindowMask]; + num_to_copy_--; + lz77_window_[(num_decoded_++) & kWindowMask] = ret; + return ret; + } + } + size_t ret = ReadHybridUintConfig(configs[ctx], token, br); + if (uses_lz77 && lz77_window_) + lz77_window_[(num_decoded_++) & kWindowMask] = ret; + return ret; + } + + // same but not inlined + template <bool uses_lz77> + size_t ReadHybridUintClustered(size_t ctx, BitReader* JXL_RESTRICT br) { + return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); + } + + // inlined only in the no-lz77 case + template <bool uses_lz77> + JXL_INLINE size_t + ReadHybridUintClusteredMaybeInlined(size_t ctx, BitReader* JXL_RESTRICT br) { + if (uses_lz77) { + return ReadHybridUintClustered<uses_lz77>(ctx, br); + } else { + return ReadHybridUintClusteredInlined<uses_lz77>(ctx, br); + } + } + + // inlined, for use in hot paths + template <bool uses_lz77> + JXL_INLINE size_t + ReadHybridUintInlined(size_t ctx, BitReader* JXL_RESTRICT br, + const std::vector<uint8_t>& context_map) { + return ReadHybridUintClustered<uses_lz77>(context_map[ctx], br); + } + + // not inlined, for use in non-hot paths + size_t ReadHybridUint(size_t ctx, BitReader* JXL_RESTRICT br, + const std::vector<uint8_t>& context_map) { + return ReadHybridUintClustered</*uses_lz77=*/true>(context_map[ctx], br); + } + + // ctx is a *clustered* context! + // This function will modify the ANS state as if `count` symbols have been + // decoded. + bool IsSingleValueAndAdvance(size_t ctx, uint32_t* value, size_t count) { + // TODO(veluca): No optimization for Huffman mode yet. + if (use_prefix_code_) return false; + // TODO(eustas): propagate "degenerate_symbol" to simplify this method. + const uint32_t res = state_ & (ANS_TAB_SIZE - 1u); + const AliasTable::Entry* table = &alias_tables_[ctx << log_alpha_size_]; + AliasTable::Symbol symbol = + AliasTable::Lookup(table, res, log_entry_size_, entry_size_minus_1_); + if (symbol.freq != ANS_TAB_SIZE) return false; + if (configs[ctx].split_token <= symbol.value) return false; + if (symbol.value >= lz77_threshold_) return false; + *value = symbol.value; + if (lz77_window_) { + for (size_t i = 0; i < count; i++) { + lz77_window_[(num_decoded_++) & kWindowMask] = symbol.value; + } + } + return true; + } + + static constexpr size_t kMaxCheckpointInterval = 512; + struct Checkpoint { + uint32_t state; + uint32_t num_to_copy; + uint32_t copy_pos; + uint32_t num_decoded; + uint32_t lz77_window[kMaxCheckpointInterval]; + }; + void Save(Checkpoint* checkpoint) { + checkpoint->state = state_; + checkpoint->num_decoded = num_decoded_; + checkpoint->num_to_copy = num_to_copy_; + checkpoint->copy_pos = copy_pos_; + if (lz77_window_) { + size_t win_start = num_decoded_ & kWindowMask; + size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; + if (win_end > win_start) { + memcpy(checkpoint->lz77_window, lz77_window_ + win_start, + (win_end - win_start) * sizeof(*lz77_window_)); + } else { + memcpy(checkpoint->lz77_window, lz77_window_ + win_start, + (kWindowSize - win_start) * sizeof(*lz77_window_)); + memcpy(checkpoint->lz77_window + (kWindowSize - win_start), + lz77_window_, win_end * sizeof(*lz77_window_)); + } + } + } + void Restore(const Checkpoint& checkpoint) { + state_ = checkpoint.state; + JXL_DASSERT(num_decoded_ <= + checkpoint.num_decoded + kMaxCheckpointInterval); + num_decoded_ = checkpoint.num_decoded; + num_to_copy_ = checkpoint.num_to_copy; + copy_pos_ = checkpoint.copy_pos; + if (lz77_window_) { + size_t win_start = num_decoded_ & kWindowMask; + size_t win_end = (num_decoded_ + kMaxCheckpointInterval) & kWindowMask; + if (win_end > win_start) { + memcpy(lz77_window_ + win_start, checkpoint.lz77_window, + (win_end - win_start) * sizeof(*lz77_window_)); + } else { + memcpy(lz77_window_ + win_start, checkpoint.lz77_window, + (kWindowSize - win_start) * sizeof(*lz77_window_)); + memcpy(lz77_window_, checkpoint.lz77_window + (kWindowSize - win_start), + win_end * sizeof(*lz77_window_)); + } + } + } + + private: + const AliasTable::Entry* JXL_RESTRICT alias_tables_; // not owned + const HuffmanDecodingData* huffman_data_; + bool use_prefix_code_; + uint32_t state_ = ANS_SIGNATURE << 16u; + const HybridUintConfig* JXL_RESTRICT configs; + uint32_t log_alpha_size_{}; + uint32_t log_entry_size_{}; + uint32_t entry_size_minus_1_{}; + + // LZ77 structures and constants. + static constexpr size_t kWindowMask = kWindowSize - 1; + CacheAlignedUniquePtr lz77_window_storage_; + uint32_t* lz77_window_ = nullptr; + uint32_t num_decoded_ = 0; + uint32_t num_to_copy_ = 0; + uint32_t copy_pos_ = 0; + uint32_t lz77_ctx_ = 0; + uint32_t lz77_min_length_ = 0; + uint32_t lz77_threshold_ = 1 << 20; // bigger than any symbol. + HybridUintConfig lz77_length_uint_; + uint32_t special_distances_[kNumSpecialDistances]{}; + uint32_t num_special_distances_{}; +}; + +Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, + std::vector<uint8_t>* context_map, + bool disallow_lz77 = false); + +// Exposed for tests. +Status DecodeUintConfigs(size_t log_alpha_size, + std::vector<HybridUintConfig>* uint_config, + BitReader* br); + +} // namespace jxl + +#endif // LIB_JXL_DEC_ANS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_bit_reader.h b/third_party/jpeg-xl/lib/jxl/dec_bit_reader.h new file mode 100644 index 0000000000..ea71d759fa --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_bit_reader.h @@ -0,0 +1,352 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_BIT_READER_H_ +#define LIB_JXL_DEC_BIT_READER_H_ + +// Bounds-checked bit reader; 64-bit buffer with support for deferred refills +// and switching to reading byte-aligned words. + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcpy + +#ifdef __BMI2__ +#include <immintrin.h> +#endif + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Reads bits previously written to memory by BitWriter. Uses unaligned 8-byte +// little-endian loads. +class BitReader { + public: + static constexpr size_t kMaxBitsPerCall = 56; + + // Constructs an invalid BitReader, to be overwritten before usage. + BitReader() + : buf_(0), + bits_in_buf_(0), + next_byte_{nullptr}, + end_minus_8_{nullptr}, + first_byte_(nullptr) {} + BitReader(const BitReader&) = delete; + + // bytes need not be aligned nor padded! + template <class ArrayLike> + explicit BitReader(const ArrayLike& bytes) + : buf_(0), + bits_in_buf_(0), + next_byte_(bytes.data()), + // Assumes first_byte_ >= 8. + end_minus_8_(bytes.data() - 8 + bytes.size()), + first_byte_(bytes.data()) { + Refill(); + } + ~BitReader() { + // Close() must be called before destroying an initialized bit reader. + // Invalid bit readers will have a nullptr in first_byte_. + JXL_ASSERT(close_called_ || !first_byte_); + } + + // Move operator needs to invalidate the other BitReader such that it is + // irrelevant if we call Close() on it or not. + BitReader& operator=(BitReader&& other) noexcept { + // Ensure the current instance was already closed, before we overwrite it + // with other. + JXL_ASSERT(close_called_ || !first_byte_); + + JXL_DASSERT(!other.close_called_); + buf_ = other.buf_; + bits_in_buf_ = other.bits_in_buf_; + next_byte_ = other.next_byte_; + end_minus_8_ = other.end_minus_8_; + first_byte_ = other.first_byte_; + overread_bytes_ = other.overread_bytes_; + close_called_ = other.close_called_; + + other.first_byte_ = nullptr; + other.next_byte_ = nullptr; + return *this; + } + BitReader& operator=(const BitReader& other) = delete; + + // For time-critical reads, refills can be shared by multiple reads. + // Based on variant 4 (plus bounds-checking), see + // fgiesen.wordpress.com/2018/02/20/reading-bits-in-far-too-many-ways-part-2/ + JXL_INLINE void Refill() { + if (JXL_UNLIKELY(next_byte_ > end_minus_8_)) { + BoundsCheckedRefill(); + } else { + // It's safe to load 64 bits; insert valid (possibly nonzero) bits above + // bits_in_buf_. The shift requires bits_in_buf_ < 64. + buf_ |= LoadLE64(next_byte_) << bits_in_buf_; + + // Advance by bytes fully absorbed into the buffer. + next_byte_ += (63 - bits_in_buf_) >> 3; + + // We absorbed a multiple of 8 bits, so the lower 3 bits of bits_in_buf_ + // must remain unchanged, otherwise the next refill's shifted bits will + // not align with buf_. Set the three upper bits so the result >= 56. + bits_in_buf_ |= 56; + JXL_DASSERT(56 <= bits_in_buf_ && bits_in_buf_ < 64); + } + } + + // Returns the bits that would be returned by Read without calling Advance(). + // It is legal to PEEK at more bits than present in the bitstream (required + // by Huffman), and those bits will be zero. + template <size_t N> + JXL_INLINE uint64_t PeekFixedBits() const { + static_assert(N <= kMaxBitsPerCall, "Reading too many bits in one call."); + JXL_DASSERT(!close_called_); + return buf_ & ((1ULL << N) - 1); + } + + JXL_INLINE uint64_t PeekBits(size_t nbits) const { + JXL_DASSERT(nbits <= kMaxBitsPerCall); + JXL_DASSERT(!close_called_); + + // Slightly faster but requires BMI2. It is infeasible to make the many + // callers reside between begin/end_target, especially because only the + // callers in dec_ans are time-critical. Therefore only enabled if the + // entire binary is compiled for (and thus requires) BMI2. +#if defined(__BMI2__) && defined(__x86_64__) + return _bzhi_u64(buf_, nbits); +#else + const uint64_t mask = (1ULL << nbits) - 1; + return buf_ & mask; +#endif + } + + // Removes bits from the buffer. Need not match the previous Peek size, but + // the buffer must contain at least num_bits (this prevents consuming more + // than the total number of bits). + JXL_INLINE void Consume(size_t num_bits) { + JXL_DASSERT(!close_called_); + JXL_DASSERT(bits_in_buf_ >= num_bits); +#ifdef JXL_CRASH_ON_ERROR + // When JXL_CRASH_ON_ERROR is defined, it is a fatal error to read more bits + // than available in the stream. A non-zero overread_bytes_ implies that + // next_byte_ is already at the end of the stream, so we don't need to + // check that. + JXL_ASSERT(bits_in_buf_ >= num_bits + overread_bytes_ * kBitsPerByte); +#endif + bits_in_buf_ -= num_bits; + buf_ >>= num_bits; + } + + JXL_INLINE uint64_t ReadBits(size_t nbits) { + JXL_DASSERT(!close_called_); + Refill(); + const uint64_t bits = PeekBits(nbits); + Consume(nbits); + return bits; + } + + template <size_t N> + JXL_INLINE uint64_t ReadFixedBits() { + JXL_DASSERT(!close_called_); + Refill(); + const uint64_t bits = PeekFixedBits<N>(); + Consume(N); + return bits; + } + + // Equivalent to calling ReadFixedBits(1) `skip` times, but much faster. + // `skip` is typically large. + void SkipBits(size_t skip) { + JXL_DASSERT(!close_called_); + // Buffer is large enough - don't zero buf_ below. + if (JXL_UNLIKELY(skip <= bits_in_buf_)) { + Consume(skip); + return; + } + + // First deduct what we can satisfy from the buffer + skip -= bits_in_buf_; + bits_in_buf_ = 0; + // Not enough to call Advance - that may leave some bits in the buffer + // which were previously ABOVE bits_in_buf. + buf_ = 0; + + // Skip whole bytes + const size_t whole_bytes = skip / kBitsPerByte; + skip %= kBitsPerByte; + if (JXL_UNLIKELY(whole_bytes > + static_cast<size_t>(end_minus_8_ + 8 - next_byte_))) { + // This is already an overflow condition (skipping past the end of the bit + // stream). However if we increase next_byte_ too much we risk overflowing + // that value and potentially making it valid again (next_byte_ < end). + // This will set next_byte_ to the end of the stream and still consume + // some bits in overread_bytes_, however the TotalBitsConsumed() will be + // incorrect (still larger than the TotalBytes()). + next_byte_ = end_minus_8_ + 8; + skip += kBitsPerByte; + } else { + next_byte_ += whole_bytes; + } + + Refill(); + Consume(skip); + } + + size_t TotalBitsConsumed() const { + const size_t bytes_read = static_cast<size_t>(next_byte_ - first_byte_); + return (bytes_read + overread_bytes_) * kBitsPerByte - bits_in_buf_; + } + + Status JumpToByteBoundary() { + const size_t remainder = TotalBitsConsumed() % kBitsPerByte; + if (remainder == 0) return true; + if (JXL_UNLIKELY(ReadBits(kBitsPerByte - remainder) != 0)) { + return JXL_FAILURE("Non-zero padding bits"); + } + return true; + } + + // For interoperability with other bitreaders (for resuming at + // non-byte-aligned positions). + const uint8_t* FirstByte() const { return first_byte_; } + size_t TotalBytes() const { + return static_cast<size_t>(end_minus_8_ + 8 - first_byte_); + } + + // Returns span of the remaining (unconsumed) bytes, e.g. for passing to + // external decoders such as Brotli. + Span<const uint8_t> GetSpan() const { + JXL_DASSERT(first_byte_ != nullptr); + JXL_ASSERT(TotalBitsConsumed() % kBitsPerByte == 0); + const size_t offset = TotalBitsConsumed() / kBitsPerByte; // no remainder + JXL_ASSERT(offset <= TotalBytes()); + return Bytes(first_byte_ + offset, TotalBytes() - offset); + } + + // Returns whether all the bits read so far have been within the input bounds. + // When reading past the EOF, the Read*() and Consume() functions return zeros + // but flag a failure when calling Close() without checking this function. + Status AllReadsWithinBounds() { + // Mark up to which point the user checked the out of bounds condition. If + // the user handles the condition at higher level (e.g. fetch more bytes + // from network, return a custom JXL_FAILURE, ...), Close() should not + // output a debug error (which would break tests with JXL_CRASH_ON_ERROR + // even when legitimately handling the situation at higher level). This is + // used by Bundle::CanRead. + checked_out_of_bounds_bits_ = TotalBitsConsumed(); + if (TotalBitsConsumed() > TotalBytes() * kBitsPerByte) { + return false; + } + return true; + } + + // Close the bit reader and return whether all the previous reads were + // successful. Close must be called once. + Status Close() { + JXL_DASSERT(!close_called_); + close_called_ = true; + if (!first_byte_) return true; + if (TotalBitsConsumed() > checked_out_of_bounds_bits_ && + TotalBitsConsumed() > TotalBytes() * kBitsPerByte) { + return JXL_FAILURE("Read more bits than available in the bit_reader"); + } + return true; + } + + private: + // Separate function avoids inlining this relatively cold code into callers. + JXL_NOINLINE void BoundsCheckedRefill() { + const uint8_t* end = end_minus_8_ + 8; + + // Read whole bytes until we have [56, 64) bits (same as LoadLE64) + for (; bits_in_buf_ < 64 - kBitsPerByte; bits_in_buf_ += kBitsPerByte) { + if (next_byte_ >= end) break; + buf_ |= static_cast<uint64_t>(*next_byte_++) << bits_in_buf_; + } + JXL_DASSERT(bits_in_buf_ < 64); + + // Add extra bytes as 0 at the end of the stream in the bit_buffer_. If + // these bits are read, Close() will return a failure. + size_t extra_bytes = (63 - bits_in_buf_) / kBitsPerByte; + overread_bytes_ += extra_bytes; + bits_in_buf_ += extra_bytes * kBitsPerByte; + + JXL_DASSERT(bits_in_buf_ < 64); + JXL_DASSERT(bits_in_buf_ >= 56); + } + + JXL_NOINLINE uint32_t BoundsCheckedReadByteAlignedWord() { + if (next_byte_ + 1 < end_minus_8_ + 8) { + uint32_t ret = LoadLE16(next_byte_); + next_byte_ += 2; + return ret; + } + overread_bytes_ += 2; + return 0; + } + + uint64_t buf_; + size_t bits_in_buf_; // [0, 64) + const uint8_t* JXL_RESTRICT next_byte_; + const uint8_t* end_minus_8_; // for refill bounds check + const uint8_t* first_byte_; // for GetSpan + + // Number of bytes past the end that were loaded into the buf_. These bytes + // are not read from memory, but instead assumed 0. It is an error (likely due + // to an invalid stream) to Consume() more bits than specified in the range + // passed to the constructor. + uint64_t overread_bytes_{0}; + bool close_called_{false}; + + uint64_t checked_out_of_bounds_bits_{0}; +}; + +// Closes a BitReader when the BitReaderScopedCloser goes out of scope. When +// closing the bit reader, if the status result was failure it sets this failure +// to the passed variable pointer. Typical usage. +// +// Status ret = true; +// { +// BitReader reader(...); +// BitReaderScopedCloser reader_closer(&reader, &ret); +// +// // ... code that can return errors here ... +// } +// // ... more code that doesn't use the BitReader. +// return ret; + +class BitReaderScopedCloser { + public: + BitReaderScopedCloser(BitReader* reader, Status* status) + : reader_(reader), status_(status) { + JXL_DASSERT(reader_ != nullptr); + JXL_DASSERT(status_ != nullptr); + } + ~BitReaderScopedCloser() { + if (reader_ != nullptr) { + Status close_ret = reader_->Close(); + if (!close_ret) *status_ = close_ret; + } + } + void CloseAndSuppressError() { + JXL_ASSERT(reader_ != nullptr); + (void)reader_->Close(); + reader_ = nullptr; + } + BitReaderScopedCloser(const BitReaderScopedCloser&) = delete; + + private: + BitReader* reader_; + Status* status_; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_BIT_READER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_cache.cc b/third_party/jpeg-xl/lib/jxl/dec_cache.cc new file mode 100644 index 0000000000..8d12bce02e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_cache.cc @@ -0,0 +1,264 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_cache.h" + +#include "lib/jxl/blending.h" +#include "lib/jxl/common.h" // JXL_HIGH_PRECISION +#include "lib/jxl/render_pipeline/stage_blending.h" +#include "lib/jxl/render_pipeline/stage_chroma_upsampling.h" +#include "lib/jxl/render_pipeline/stage_cms.h" +#include "lib/jxl/render_pipeline/stage_epf.h" +#include "lib/jxl/render_pipeline/stage_from_linear.h" +#include "lib/jxl/render_pipeline/stage_gaborish.h" +#include "lib/jxl/render_pipeline/stage_noise.h" +#include "lib/jxl/render_pipeline/stage_patches.h" +#include "lib/jxl/render_pipeline/stage_splines.h" +#include "lib/jxl/render_pipeline/stage_spot.h" +#include "lib/jxl/render_pipeline/stage_to_linear.h" +#include "lib/jxl/render_pipeline/stage_tone_mapping.h" +#include "lib/jxl/render_pipeline/stage_upsampling.h" +#include "lib/jxl/render_pipeline/stage_write.h" +#include "lib/jxl/render_pipeline/stage_xyb.h" +#include "lib/jxl/render_pipeline/stage_ycbcr.h" + +namespace jxl { + +Status PassesDecoderState::PreparePipeline(const FrameHeader& frame_header, + ImageBundle* decoded, + PipelineOptions options) { + size_t num_c = 3 + frame_header.nonserialized_metadata->m.num_extra_channels; + if (options.render_noise && (frame_header.flags & FrameHeader::kNoise) != 0) { + num_c += 3; + } + + if (frame_header.CanBeReferenced()) { + // Necessary so that SetInputSizes() can allocate output buffers as needed. + frame_storage_for_referencing = ImageBundle(decoded->metadata()); + } + + RenderPipeline::Builder builder(num_c); + + if (options.use_slow_render_pipeline) { + builder.UseSimpleImplementation(); + } + + if (!frame_header.chroma_subsampling.Is444()) { + for (size_t c = 0; c < 3; c++) { + if (frame_header.chroma_subsampling.HShift(c) != 0) { + builder.AddStage(GetChromaUpsamplingStage(c, /*horizontal=*/true)); + } + if (frame_header.chroma_subsampling.VShift(c) != 0) { + builder.AddStage(GetChromaUpsamplingStage(c, /*horizontal=*/false)); + } + } + } + + if (frame_header.loop_filter.gab) { + builder.AddStage(GetGaborishStage(frame_header.loop_filter)); + } + + { + const LoopFilter& lf = frame_header.loop_filter; + if (lf.epf_iters >= 3) { + builder.AddStage(GetEPFStage(lf, sigma, 0)); + } + if (lf.epf_iters >= 1) { + builder.AddStage(GetEPFStage(lf, sigma, 1)); + } + if (lf.epf_iters >= 2) { + builder.AddStage(GetEPFStage(lf, sigma, 2)); + } + } + + bool late_ec_upsample = frame_header.upsampling != 1; + for (auto ecups : frame_header.extra_channel_upsampling) { + if (ecups != frame_header.upsampling) { + // If patches are applied, either frame_header.upsampling == 1 or + // late_ec_upsample is true. + late_ec_upsample = false; + } + } + + if (!late_ec_upsample) { + for (size_t ec = 0; ec < frame_header.extra_channel_upsampling.size(); + ec++) { + if (frame_header.extra_channel_upsampling[ec] != 1) { + builder.AddStage(GetUpsamplingStage( + frame_header.nonserialized_metadata->transform_data, 3 + ec, + CeilLog2Nonzero(frame_header.extra_channel_upsampling[ec]))); + } + } + } + + if ((frame_header.flags & FrameHeader::kPatches) != 0) { + builder.AddStage( + GetPatchesStage(&shared->image_features.patches, + 3 + shared->metadata->m.num_extra_channels)); + } + if ((frame_header.flags & FrameHeader::kSplines) != 0) { + builder.AddStage(GetSplineStage(&shared->image_features.splines)); + } + + if (frame_header.upsampling != 1) { + size_t nb_channels = + 3 + + (late_ec_upsample ? frame_header.extra_channel_upsampling.size() : 0); + for (size_t c = 0; c < nb_channels; c++) { + builder.AddStage(GetUpsamplingStage( + frame_header.nonserialized_metadata->transform_data, c, + CeilLog2Nonzero(frame_header.upsampling))); + } + } + if (options.render_noise && (frame_header.flags & FrameHeader::kNoise) != 0) { + builder.AddStage(GetConvolveNoiseStage(num_c - 3)); + builder.AddStage(GetAddNoiseStage(shared->image_features.noise_params, + shared->cmap, num_c - 3)); + } + if (frame_header.dc_level != 0) { + builder.AddStage(GetWriteToImage3FStage( + &shared_storage.dc_frames[frame_header.dc_level - 1])); + } + + if (frame_header.CanBeReferenced() && + frame_header.save_before_color_transform) { + builder.AddStage(GetWriteToImageBundleStage( + &frame_storage_for_referencing, output_encoding_info.color_encoding)); + } + + bool has_alpha = false; + size_t alpha_c = 0; + for (size_t i = 0; i < decoded->metadata()->extra_channel_info.size(); i++) { + if (decoded->metadata()->extra_channel_info[i].type == + ExtraChannel::kAlpha) { + has_alpha = true; + alpha_c = 3 + i; + break; + } + } + + if (fast_xyb_srgb8_conversion) { +#if !JXL_HIGH_PRECISION + JXL_ASSERT(!NeedsBlending(frame_header)); + JXL_ASSERT(!frame_header.CanBeReferenced() || + frame_header.save_before_color_transform); + JXL_ASSERT(!options.render_spotcolors || + !decoded->metadata()->Find(ExtraChannel::kSpotColor)); + bool is_rgba = (main_output.format.num_channels == 4); + uint8_t* rgb_output = reinterpret_cast<uint8_t*>(main_output.buffer); + builder.AddStage(GetFastXYBTosRGB8Stage(rgb_output, main_output.stride, + width, height, is_rgba, has_alpha, + alpha_c)); +#endif + } else { + bool linear = false; + if (frame_header.color_transform == ColorTransform::kYCbCr) { + builder.AddStage(GetYCbCrStage()); + } else if (frame_header.color_transform == ColorTransform::kXYB) { + builder.AddStage(GetXYBStage(output_encoding_info)); + if (output_encoding_info.color_encoding.GetColorSpace() != + ColorSpace::kXYB) { + linear = true; + } + } // Nothing to do for kNone. + + if (options.coalescing && NeedsBlending(frame_header)) { + if (linear) { + builder.AddStage(GetFromLinearStage(output_encoding_info)); + linear = false; + } + builder.AddStage(GetBlendingStage(frame_header, this, + output_encoding_info.color_encoding)); + } + + if (options.coalescing && frame_header.CanBeReferenced() && + !frame_header.save_before_color_transform) { + if (linear) { + builder.AddStage(GetFromLinearStage(output_encoding_info)); + linear = false; + } + builder.AddStage(GetWriteToImageBundleStage( + &frame_storage_for_referencing, output_encoding_info.color_encoding)); + } + + if (options.render_spotcolors && + frame_header.nonserialized_metadata->m.Find(ExtraChannel::kSpotColor)) { + for (size_t i = 0; i < decoded->metadata()->extra_channel_info.size(); + i++) { + // Don't use Find() because there may be multiple spot color channels. + const ExtraChannelInfo& eci = + decoded->metadata()->extra_channel_info[i]; + if (eci.type == ExtraChannel::kSpotColor) { + builder.AddStage(GetSpotColorStage(3 + i, eci.spot_color)); + } + } + } + + auto tone_mapping_stage = GetToneMappingStage(output_encoding_info); + if (tone_mapping_stage) { + if (!linear) { + auto to_linear_stage = GetToLinearStage(output_encoding_info); + if (!to_linear_stage) { + if (!output_encoding_info.cms_set) { + return JXL_FAILURE("Cannot tonemap this colorspace without a CMS"); + } + auto cms_stage = GetCmsStage(output_encoding_info); + if (cms_stage) { + builder.AddStage(std::move(cms_stage)); + } + } else { + builder.AddStage(std::move(to_linear_stage)); + } + linear = true; + } + builder.AddStage(std::move(tone_mapping_stage)); + } + + if (linear) { + const size_t channels_src = + (output_encoding_info.orig_color_encoding.IsCMYK() + ? 4 + : output_encoding_info.orig_color_encoding.Channels()); + const size_t channels_dst = + output_encoding_info.color_encoding.Channels(); + bool mixing_color_and_grey = (channels_dst != channels_src); + if ((output_encoding_info.color_encoding_is_original) || + (!output_encoding_info.cms_set) || mixing_color_and_grey) { + // in those cases we only need a linear stage in other cases we attempt + // to obtain an cms stage: the cases are + // - output_encoding_info.color_encoding_is_original: no cms stage + // needed because it would be a no-op + // - !output_encoding_info.cms_set: can't use the cms, so no point in + // trying to add a cms stage + // - mixing_color_and_grey: cms stage can't handle that + // TODO(firsching): remove "mixing_color_and_grey" condition after + // adding support for greyscale to cms stage. + builder.AddStage(GetFromLinearStage(output_encoding_info)); + } else { + if (!output_encoding_info.linear_color_encoding.CreateICC()) { + return JXL_FAILURE("Failed to create ICC"); + } + auto cms_stage = GetCmsStage(output_encoding_info); + if (cms_stage) { + builder.AddStage(std::move(cms_stage)); + } + } + linear = false; + } + + if (main_output.callback.IsPresent() || main_output.buffer) { + builder.AddStage(GetWriteToOutputStage(main_output, width, height, + has_alpha, unpremul_alpha, alpha_c, + undo_orientation, extra_output)); + } else { + builder.AddStage(GetWriteToImageBundleStage( + decoded, output_encoding_info.color_encoding)); + } + } + render_pipeline = std::move(builder).Finalize(shared->frame_dim); + return render_pipeline->IsInitialized(); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_cache.h b/third_party/jpeg-xl/lib/jxl/dec_cache.h new file mode 100644 index 0000000000..d4cc7a1957 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_cache.h @@ -0,0 +1,265 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_CACHE_H_ +#define LIB_JXL_DEC_CACHE_H_ + +#include <jxl/decode.h> +#include <jxl/types.h> +#include <stdint.h> + +#include <algorithm> +#include <atomic> +#include <cmath> +#include <hwy/base.h> // HWY_ALIGN_MAX +#include <memory> +#include <vector> + +#include "hwy/aligned_allocator.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/common.h" // kMaxNumPasses +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/render_pipeline/render_pipeline.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" +#include "lib/jxl/render_pipeline/stage_upsampling.h" + +namespace jxl { + +constexpr size_t kSigmaBorder = 1; +constexpr size_t kSigmaPadding = 2; + +struct PixelCallback { + PixelCallback() = default; + PixelCallback(JxlImageOutInitCallback init, JxlImageOutRunCallback run, + JxlImageOutDestroyCallback destroy, void* init_opaque) + : init(init), run(run), destroy(destroy), init_opaque(init_opaque) { +#if JXL_ENABLE_ASSERT + const bool has_init = init != nullptr; + const bool has_run = run != nullptr; + const bool has_destroy = destroy != nullptr; + JXL_ASSERT(has_init == has_run && has_run == has_destroy); +#endif + } + + bool IsPresent() const { return run != nullptr; } + + void* Init(size_t num_threads, size_t num_pixels) const { + return init(init_opaque, num_threads, num_pixels); + } + + JxlImageOutInitCallback init = nullptr; + JxlImageOutRunCallback run = nullptr; + JxlImageOutDestroyCallback destroy = nullptr; + void* init_opaque = nullptr; +}; + +struct ImageOutput { + // Pixel format of the output pixels, used for buffer and callback output. + JxlPixelFormat format; + // Output bit depth for unsigned data types, used for float to int conversion. + size_t bits_per_sample; + // Callback for line-by-line output. + PixelCallback callback; + // Pixel buffer for image output. + void* buffer; + size_t buffer_size; + // Length of a row of image_buffer in bytes (based on oriented width). + size_t stride; +}; + +// Per-frame decoder state. All the images here should be accessed through a +// group rect (either with block units or pixel units). +struct PassesDecoderState { + PassesSharedState shared_storage; + // Allows avoiding copies for encoder loop. + const PassesSharedState* JXL_RESTRICT shared = &shared_storage; + + // 8x upsampling stage for DC. + std::unique_ptr<RenderPipelineStage> upsampler8x; + + // For ANS decoding. + std::vector<ANSCode> code; + std::vector<std::vector<uint8_t>> context_map; + + // Multiplier to be applied to the quant matrices of the x channel. + float x_dm_multiplier; + float b_dm_multiplier; + + // Sigma values for EPF. + ImageF sigma; + + // Image dimensions before applying undo_orientation. + size_t width; + size_t height; + ImageOutput main_output; + std::vector<ImageOutput> extra_output; + + // Whether to use int16 float-XYB-to-uint8-srgb conversion. + bool fast_xyb_srgb8_conversion; + + // If true, the RGBA output will be unpremultiplied before writing to the + // output. + bool unpremul_alpha; + + // The render pipeline will apply this orientation to bring the image to the + // intended display orientation. + Orientation undo_orientation; + + // Used for seeding noise. + size_t visible_frame_index = 0; + size_t nonvisible_frame_index = 0; + + // Keep track of the transform types used. + std::atomic<uint32_t> used_acs{0}; + + // Storage for coefficients if in "accumulate" mode. + std::unique_ptr<ACImage> coefficients = make_unique<ACImageT<int32_t>>(0, 0); + + // Rendering pipeline. + std::unique_ptr<RenderPipeline> render_pipeline; + + // Storage for the current frame if it can be referenced by future frames. + ImageBundle frame_storage_for_referencing; + + struct PipelineOptions { + bool use_slow_render_pipeline; + bool coalescing; + bool render_spotcolors; + bool render_noise; + }; + + Status PreparePipeline(const FrameHeader& frame_header, ImageBundle* decoded, + PipelineOptions options); + + // Information for colour conversions. + OutputEncodingInfo output_encoding_info; + + // Initializes decoder-specific structures using information from *shared. + Status Init(const FrameHeader& frame_header) { + x_dm_multiplier = std::pow(1 / (1.25f), frame_header.x_qm_scale - 2.0f); + b_dm_multiplier = std::pow(1 / (1.25f), frame_header.b_qm_scale - 2.0f); + + main_output.callback = PixelCallback(); + main_output.buffer = nullptr; + extra_output.clear(); + + fast_xyb_srgb8_conversion = false; + unpremul_alpha = false; + undo_orientation = Orientation::kIdentity; + + used_acs = 0; + + upsampler8x = GetUpsamplingStage(shared->metadata->transform_data, 0, 3); + if (frame_header.loop_filter.epf_iters > 0) { + sigma = ImageF(shared->frame_dim.xsize_blocks + 2 * kSigmaPadding, + shared->frame_dim.ysize_blocks + 2 * kSigmaPadding); + } + return true; + } + + // Initialize the decoder state after all of DC is decoded. + Status InitForAC(size_t num_passes, ThreadPool* pool) { + shared_storage.coeff_order_size = 0; + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + if (((1 << o) & used_acs) == 0) continue; + uint8_t ord = kStrategyOrder[o]; + shared_storage.coeff_order_size = + std::max(kCoeffOrderOffset[3 * (ord + 1)] * kDCTBlockSize, + shared_storage.coeff_order_size); + } + size_t sz = num_passes * shared_storage.coeff_order_size; + if (sz > shared_storage.coeff_orders.size()) { + shared_storage.coeff_orders.resize(sz); + } + return true; + } +}; + +// Temp images required for decoding a single group. Reduces memory allocations +// for large images because we only initialize min(#threads, #groups) instances. +struct GroupDecCache { + void InitOnce(size_t num_passes, size_t used_acs) { + for (size_t i = 0; i < num_passes; i++) { + if (num_nzeroes[i].xsize() == 0) { + // Allocate enough for a whole group - partial groups on the + // right/bottom border just use a subset. The valid size is passed via + // Rect. + + num_nzeroes[i] = Image3I(kGroupDimInBlocks, kGroupDimInBlocks); + } + } + size_t max_block_area = 0; + + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + AcStrategy acs = AcStrategy::FromRawStrategy(o); + if ((used_acs & (1 << o)) == 0) continue; + size_t area = + acs.covered_blocks_x() * acs.covered_blocks_y() * kDCTBlockSize; + max_block_area = std::max(area, max_block_area); + } + + if (max_block_area > max_block_area_) { + max_block_area_ = max_block_area; + // We need 3x float blocks for dequantized coefficients and 1x for scratch + // space for transforms. + float_memory_ = hwy::AllocateAligned<float>(max_block_area_ * 7); + // We need 3x int32 or int16 blocks for quantized coefficients. + int32_memory_ = hwy::AllocateAligned<int32_t>(max_block_area_ * 3); + int16_memory_ = hwy::AllocateAligned<int16_t>(max_block_area_ * 3); + } + + dec_group_block = float_memory_.get(); + scratch_space = dec_group_block + max_block_area_ * 3; + dec_group_qblock = int32_memory_.get(); + dec_group_qblock16 = int16_memory_.get(); + } + + void InitDCBufferOnce() { + if (dc_buffer.xsize() == 0) { + dc_buffer = ImageF(kGroupDimInBlocks + kRenderPipelineXOffset * 2, + kGroupDimInBlocks + 4); + } + } + + // Scratch space used by DecGroupImpl(). + float* dec_group_block; + int32_t* dec_group_qblock; + int16_t* dec_group_qblock16; + + // For TransformToPixels. + float* scratch_space; + // Note that scratch_space is never used at the same time as dec_group_qblock. + // Moreover, only one of dec_group_qblock16 is ever used. + // TODO(veluca): figure out if we can save allocations. + + // AC decoding + Image3I num_nzeroes[kMaxNumPasses]; + + // Buffer for DC upsampling. + ImageF dc_buffer; + + private: + hwy::AlignedFreeUniquePtr<float[]> float_memory_; + hwy::AlignedFreeUniquePtr<int32_t[]> int32_memory_; + hwy::AlignedFreeUniquePtr<int16_t[]> int16_memory_; + size_t max_block_area_ = 0; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_CACHE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_context_map.cc b/third_party/jpeg-xl/lib/jxl/dec_context_map.cc new file mode 100644 index 0000000000..2c936722da --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_context_map.cc @@ -0,0 +1,89 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_context_map.h" + +#include <algorithm> +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/inverse_mtf-inl.h" + +namespace jxl { + +namespace { + +Status VerifyContextMap(const std::vector<uint8_t>& context_map, + const size_t num_htrees) { + std::vector<bool> have_htree(num_htrees); + size_t num_found = 0; + for (const uint8_t htree : context_map) { + if (htree >= num_htrees) { + return JXL_FAILURE("Invalid histogram index in context map."); + } + if (!have_htree[htree]) { + have_htree[htree] = true; + ++num_found; + } + } + if (num_found != num_htrees) { + return JXL_FAILURE("Incomplete context map."); + } + return true; +} + +} // namespace + +Status DecodeContextMap(std::vector<uint8_t>* context_map, size_t* num_htrees, + BitReader* input) { + bool is_simple = input->ReadFixedBits<1>(); + if (is_simple) { + int bits_per_entry = input->ReadFixedBits<2>(); + if (bits_per_entry != 0) { + for (size_t i = 0; i < context_map->size(); i++) { + (*context_map)[i] = input->ReadBits(bits_per_entry); + } + } else { + std::fill(context_map->begin(), context_map->end(), 0); + } + } else { + bool use_mtf = input->ReadFixedBits<1>(); + ANSCode code; + std::vector<uint8_t> sink_ctx_map; + // Usage of LZ77 is disallowed if decoding only two symbols. This doesn't + // make sense in non-malicious bitstreams, and could cause a stack overflow + // in malicious bitstreams by making every context map require its own + // context map. + JXL_RETURN_IF_ERROR( + DecodeHistograms(input, 1, &code, &sink_ctx_map, + /*disallow_lz77=*/context_map->size() <= 2)); + ANSSymbolReader reader(&code, input); + size_t i = 0; + uint32_t maxsym = 0; + while (i < context_map->size()) { + uint32_t sym = reader.ReadHybridUintInlined</*uses_lz77=*/true>( + 0, input, sink_ctx_map); + maxsym = sym > maxsym ? sym : maxsym; + (*context_map)[i] = sym; + i++; + } + if (maxsym >= kMaxClusters) { + return JXL_FAILURE("Invalid cluster ID"); + } + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("Invalid context map"); + } + if (use_mtf) { + InverseMoveToFrontTransform(context_map->data(), context_map->size()); + } + } + *num_htrees = *std::max_element(context_map->begin(), context_map->end()) + 1; + return VerifyContextMap(*context_map, *num_htrees); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_context_map.h b/third_party/jpeg-xl/lib/jxl/dec_context_map.h new file mode 100644 index 0000000000..95b8a0ca92 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_context_map.h @@ -0,0 +1,30 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_CONTEXT_MAP_H_ +#define LIB_JXL_DEC_CONTEXT_MAP_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/dec_bit_reader.h" + +namespace jxl { + +// Context map uses uint8_t. +constexpr size_t kMaxClusters = 256; + +// Reads the context map from the bit stream. On calling this function, +// context_map->size() must be the number of possible context ids. +// Sets *num_htrees to the number of different histogram ids in +// *context_map. +Status DecodeContextMap(std::vector<uint8_t>* context_map, size_t* num_htrees, + BitReader* input); + +} // namespace jxl + +#endif // LIB_JXL_DEC_CONTEXT_MAP_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_external_image.cc b/third_party/jpeg-xl/lib/jxl/dec_external_image.cc new file mode 100644 index 0000000000..06cd573378 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_external_image.cc @@ -0,0 +1,482 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_external_image.h" + +#include <jxl/types.h> +#include <string.h> + +#include <algorithm> +#include <array> +#include <functional> +#include <utility> +#include <vector> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_external_image.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Clamp; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::NearestInt; + +// TODO(jon): check if this can be replaced by a FloatToU16 function +void FloatToU32(const float* in, uint32_t* out, size_t num, float mul, + size_t bits_per_sample) { + const HWY_FULL(float) d; + const hwy::HWY_NAMESPACE::Rebind<uint32_t, decltype(d)> du; + + // Unpoison accessing partially-uninitialized vectors with memory sanitizer. + // This is because we run NearestInt() on the vector, which triggers MSAN even + // it is safe to do so since the values are not mixed between lanes. + const size_t num_round_up = RoundUpTo(num, Lanes(d)); + msan::UnpoisonMemory(in + num, sizeof(in[0]) * (num_round_up - num)); + + const auto one = Set(d, 1.0f); + const auto scale = Set(d, mul); + for (size_t x = 0; x < num; x += Lanes(d)) { + auto v = Load(d, in + x); + // Clamp turns NaN to 'min'. + v = Clamp(v, Zero(d), one); + auto i = NearestInt(Mul(v, scale)); + Store(BitCast(du, i), du, out + x); + } + + // Poison back the output. + msan::PoisonMemory(out + num, sizeof(out[0]) * (num_round_up - num)); +} + +void FloatToF16(const float* in, hwy::float16_t* out, size_t num) { + const HWY_FULL(float) d; + const hwy::HWY_NAMESPACE::Rebind<hwy::float16_t, decltype(d)> du; + + // Unpoison accessing partially-uninitialized vectors with memory sanitizer. + // This is because we run DemoteTo() on the vector which triggers msan. + const size_t num_round_up = RoundUpTo(num, Lanes(d)); + msan::UnpoisonMemory(in + num, sizeof(in[0]) * (num_round_up - num)); + + for (size_t x = 0; x < num; x += Lanes(d)) { + auto v = Load(d, in + x); + auto v16 = DemoteTo(du, v); + Store(v16, du, out + x); + } + + // Poison back the output. + msan::PoisonMemory(out + num, sizeof(out[0]) * (num_round_up - num)); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { +namespace { + +// Stores a float in big endian +void StoreBEFloat(float value, uint8_t* p) { + uint32_t u; + memcpy(&u, &value, 4); + StoreBE32(u, p); +} + +// Stores a float in little endian +void StoreLEFloat(float value, uint8_t* p) { + uint32_t u; + memcpy(&u, &value, 4); + StoreLE32(u, p); +} + +// The orientation may not be identity. +// TODO(lode): SIMDify where possible +template <typename T> +Status UndoOrientation(jxl::Orientation undo_orientation, const Plane<T>& image, + Plane<T>& out, jxl::ThreadPool* pool) { + const size_t xsize = image.xsize(); + const size_t ysize = image.ysize(); + + if (undo_orientation == Orientation::kFlipHorizontal) { + out = Plane<T>(xsize, ysize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + T* JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[xsize - x - 1] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kRotate180) { + out = Plane<T>(xsize, ysize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + T* JXL_RESTRICT row_out = out.Row(ysize - y - 1); + for (size_t x = 0; x < xsize; ++x) { + row_out[xsize - x - 1] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kFlipVertical) { + out = Plane<T>(xsize, ysize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + T* JXL_RESTRICT row_out = out.Row(ysize - y - 1); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kTranspose) { + out = Plane<T>(ysize, xsize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(x)[y] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kRotate90) { + out = Plane<T>(ysize, xsize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(x)[ysize - y - 1] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kAntiTranspose) { + out = Plane<T>(ysize, xsize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(xsize - x - 1)[ysize - y - 1] = row_in[x]; + } + }, + "UndoOrientation")); + } else if (undo_orientation == Orientation::kRotate270) { + out = Plane<T>(ysize, xsize); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const int64_t y = task; + const T* JXL_RESTRICT row_in = image.Row(y); + for (size_t x = 0; x < xsize; ++x) { + out.Row(xsize - x - 1)[y] = row_in[x]; + } + }, + "UndoOrientation")); + } + return true; +} +} // namespace + +HWY_EXPORT(FloatToU32); +HWY_EXPORT(FloatToF16); + +namespace { + +using StoreFuncType = void(uint32_t value, uint8_t* dest); +template <StoreFuncType StoreFunc> +void StoreUintRow(uint32_t* JXL_RESTRICT* rows_u32, size_t num_channels, + size_t xsize, size_t bytes_per_sample, + uint8_t* JXL_RESTRICT out) { + for (size_t x = 0; x < xsize; ++x) { + for (size_t c = 0; c < num_channels; c++) { + StoreFunc(rows_u32[c][x], + out + (num_channels * x + c) * bytes_per_sample); + } + } +} + +template <void(StoreFunc)(float, uint8_t*)> +void StoreFloatRow(const float* JXL_RESTRICT* rows_in, size_t num_channels, + size_t xsize, uint8_t* JXL_RESTRICT out) { + for (size_t x = 0; x < xsize; ++x) { + for (size_t c = 0; c < num_channels; c++) { + StoreFunc(rows_in[c][x], out + (num_channels * x + c) * sizeof(float)); + } + } +} + +void JXL_INLINE Store8(uint32_t value, uint8_t* dest) { *dest = value & 0xff; } + +} // namespace + +Status ConvertChannelsToExternal(const ImageF* in_channels[], + size_t num_channels, size_t bits_per_sample, + bool float_out, JxlEndianness endianness, + size_t stride, jxl::ThreadPool* pool, + void* out_image, size_t out_size, + const PixelCallback& out_callback, + jxl::Orientation undo_orientation) { + JXL_DASSERT(num_channels != 0 && num_channels <= kConvertMaxChannels); + JXL_DASSERT(in_channels[0] != nullptr); + JXL_CHECK(float_out ? bits_per_sample == 16 || bits_per_sample == 32 + : bits_per_sample > 0 && bits_per_sample <= 16); + if (!!out_image == out_callback.IsPresent()) { + return JXL_FAILURE( + "Must provide either an out_image or an out_callback, but not both."); + } + std::vector<const ImageF*> channels; + channels.assign(in_channels, in_channels + num_channels); + + const size_t bytes_per_channel = DivCeil(bits_per_sample, jxl::kBitsPerByte); + const size_t bytes_per_pixel = num_channels * bytes_per_channel; + + std::vector<std::vector<uint8_t>> row_out_callback; + const auto FreeCallbackOpaque = [&out_callback](void* p) { + out_callback.destroy(p); + }; + std::unique_ptr<void, decltype(FreeCallbackOpaque)> out_run_opaque( + nullptr, FreeCallbackOpaque); + auto InitOutCallback = [&](size_t num_threads) -> Status { + if (out_callback.IsPresent()) { + out_run_opaque.reset(out_callback.Init(num_threads, stride)); + JXL_RETURN_IF_ERROR(out_run_opaque != nullptr); + row_out_callback.resize(num_threads); + for (size_t i = 0; i < num_threads; ++i) { + row_out_callback[i].resize(stride); + } + } + return true; + }; + + // Channels used to store the transformed original channels if needed. + ImageF temp_channels[kConvertMaxChannels]; + if (undo_orientation != Orientation::kIdentity) { + for (size_t c = 0; c < num_channels; ++c) { + if (channels[c]) { + JXL_RETURN_IF_ERROR(UndoOrientation(undo_orientation, *channels[c], + temp_channels[c], pool)); + channels[c] = &(temp_channels[c]); + } + } + } + + // First channel may not be nullptr. + size_t xsize = channels[0]->xsize(); + size_t ysize = channels[0]->ysize(); + if (stride < bytes_per_pixel * xsize) { + return JXL_FAILURE("stride is smaller than scanline width in bytes: %" PRIuS + " vs %" PRIuS, + stride, bytes_per_pixel * xsize); + } + if (!out_callback.IsPresent() && + out_size < (ysize - 1) * stride + bytes_per_pixel * xsize) { + return JXL_FAILURE("out_size is too small to store image"); + } + + const bool little_endian = + endianness == JXL_LITTLE_ENDIAN || + (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian()); + + // Handle the case where a channel is nullptr by creating a single row with + // ones to use instead. + ImageF ones; + for (size_t c = 0; c < num_channels; ++c) { + if (!channels[c]) { + ones = ImageF(xsize, 1); + FillImage(1.0f, &ones); + break; + } + } + + if (float_out) { + if (bits_per_sample == 16) { + bool swap_endianness = little_endian != IsLittleEndian(); + Plane<hwy::float16_t> f16_cache; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), + [&](size_t num_threads) { + f16_cache = + Plane<hwy::float16_t>(xsize, num_channels * num_threads); + return InitOutCallback(num_threads); + }, + [&](const uint32_t task, const size_t thread) { + const int64_t y = task; + const float* JXL_RESTRICT row_in[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0); + } + hwy::float16_t* JXL_RESTRICT row_f16[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_f16[c] = f16_cache.Row(c + thread * num_channels); + HWY_DYNAMIC_DISPATCH(FloatToF16) + (row_in[c], row_f16[c], xsize); + } + uint8_t* row_out = + out_callback.IsPresent() + ? row_out_callback[thread].data() + : &(reinterpret_cast<uint8_t*>(out_image))[stride * y]; + // interleave the one scanline + hwy::float16_t* row_f16_out = + reinterpret_cast<hwy::float16_t*>(row_out); + for (size_t x = 0; x < xsize; x++) { + for (size_t c = 0; c < num_channels; c++) { + row_f16_out[x * num_channels + c] = row_f16[c][x]; + } + } + if (swap_endianness) { + size_t size = xsize * num_channels * 2; + for (size_t i = 0; i < size; i += 2) { + std::swap(row_out[i + 0], row_out[i + 1]); + } + } + if (out_callback.IsPresent()) { + out_callback.run(out_run_opaque.get(), thread, 0, y, xsize, + row_out); + } + }, + "ConvertF16")); + } else if (bits_per_sample == 32) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), + [&](size_t num_threads) { return InitOutCallback(num_threads); }, + [&](const uint32_t task, const size_t thread) { + const int64_t y = task; + uint8_t* row_out = + out_callback.IsPresent() + ? row_out_callback[thread].data() + : &(reinterpret_cast<uint8_t*>(out_image))[stride * y]; + const float* JXL_RESTRICT row_in[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0); + } + if (little_endian) { + StoreFloatRow<StoreLEFloat>(row_in, num_channels, xsize, row_out); + } else { + StoreFloatRow<StoreBEFloat>(row_in, num_channels, xsize, row_out); + } + if (out_callback.IsPresent()) { + out_callback.run(out_run_opaque.get(), thread, 0, y, xsize, + row_out); + } + }, + "ConvertFloat")); + } else { + return JXL_FAILURE("float other than 16-bit and 32-bit not supported"); + } + } else { + // Multiplier to convert from floating point 0-1 range to the integer + // range. + float mul = (1ull << bits_per_sample) - 1; + Plane<uint32_t> u32_cache; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, static_cast<uint32_t>(ysize), + [&](size_t num_threads) { + u32_cache = Plane<uint32_t>(xsize, num_channels * num_threads); + return InitOutCallback(num_threads); + }, + [&](const uint32_t task, const size_t thread) { + const int64_t y = task; + uint8_t* row_out = + out_callback.IsPresent() + ? row_out_callback[thread].data() + : &(reinterpret_cast<uint8_t*>(out_image))[stride * y]; + const float* JXL_RESTRICT row_in[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0); + } + uint32_t* JXL_RESTRICT row_u32[kConvertMaxChannels]; + for (size_t c = 0; c < num_channels; c++) { + row_u32[c] = u32_cache.Row(c + thread * num_channels); + // row_u32[] is a per-thread temporary row storage, this isn't + // intended to be initialized on a previous run. + msan::PoisonMemory(row_u32[c], xsize * sizeof(row_u32[c][0])); + HWY_DYNAMIC_DISPATCH(FloatToU32) + (row_in[c], row_u32[c], xsize, mul, bits_per_sample); + } + if (bits_per_sample <= 8) { + StoreUintRow<Store8>(row_u32, num_channels, xsize, 1, row_out); + } else { + if (little_endian) { + StoreUintRow<StoreLE16>(row_u32, num_channels, xsize, 2, row_out); + } else { + StoreUintRow<StoreBE16>(row_u32, num_channels, xsize, 2, row_out); + } + } + if (out_callback.IsPresent()) { + out_callback.run(out_run_opaque.get(), thread, 0, y, xsize, + row_out); + } + }, + "ConvertUint")); + } + return true; +} + +Status ConvertToExternal(const jxl::ImageBundle& ib, size_t bits_per_sample, + bool float_out, size_t num_channels, + JxlEndianness endianness, size_t stride, + jxl::ThreadPool* pool, void* out_image, + size_t out_size, const PixelCallback& out_callback, + jxl::Orientation undo_orientation, + bool unpremul_alpha) { + bool want_alpha = num_channels == 2 || num_channels == 4; + size_t color_channels = num_channels <= 2 ? 1 : 3; + + const Image3F* color = &ib.color(); + // Undo premultiplied alpha. + Image3F unpremul; + if (ib.AlphaIsPremultiplied() && ib.HasAlpha() && unpremul_alpha) { + unpremul = Image3F(color->xsize(), color->ysize()); + CopyImageTo(*color, &unpremul); + for (size_t y = 0; y < unpremul.ysize(); y++) { + UnpremultiplyAlpha(unpremul.PlaneRow(0, y), unpremul.PlaneRow(1, y), + unpremul.PlaneRow(2, y), ib.alpha().Row(y), + unpremul.xsize()); + } + color = &unpremul; + } + + const ImageF* channels[kConvertMaxChannels]; + size_t c = 0; + for (; c < color_channels; c++) { + channels[c] = &color->Plane(c); + } + if (want_alpha) { + channels[c++] = ib.HasAlpha() ? &ib.alpha() : nullptr; + } + JXL_ASSERT(num_channels == c); + + return ConvertChannelsToExternal( + channels, num_channels, bits_per_sample, float_out, endianness, stride, + pool, out_image, out_size, out_callback, undo_orientation); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dec_external_image.h b/third_party/jpeg-xl/lib/jxl/dec_external_image.h new file mode 100644 index 0000000000..7a29981694 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_external_image.h @@ -0,0 +1,65 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_EXTERNAL_IMAGE_H_ +#define LIB_JXL_DEC_EXTERNAL_IMAGE_H_ + +// Interleaved image for color transforms and Codec. + +#include <jxl/types.h> +#include <stddef.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" + +namespace jxl { + +// Maximum number of channels for the ConvertChannelsToExternal function. +const size_t kConvertMaxChannels = 4; + +// Converts a list of channels to an interleaved image, applying transformations +// when needed. +// The input channels are given as a (non-const!) array of channel pointers and +// interleaved in that order. +// +// Note: if a pointer in channels[] is nullptr, a 1.0 value will be used +// instead. This is useful for handling when a user requests an alpha channel +// from an image that doesn't have one. The first channel in the list may not +// be nullptr, since it is used to determine the image size. +Status ConvertChannelsToExternal(const ImageF* in_channels[], + size_t num_channels, size_t bits_per_sample, + bool float_out, JxlEndianness endianness, + size_t stride, jxl::ThreadPool* pool, + void* out_image, size_t out_size, + const PixelCallback& out_callback, + jxl::Orientation undo_orientation); + +// Converts ib to interleaved void* pixel buffer with the given format. +// bits_per_sample: must be 16 or 32 if float_out is true, and at most 16 +// if it is false. No bit packing is done. +// num_channels: must be 1, 2, 3 or 4 for gray, gray+alpha, RGB, RGB+alpha. +// This supports the features needed for the C API and does not perform +// color space conversion. +// TODO(lode): support rectangle crop. +// stride_out is output scanline size in bytes, must be >= +// output_xsize * output_bytes_per_pixel. +// undo_orientation is an EXIF orientation to undo. Depending on the +// orientation, the output xsize and ysize are swapped compared to input +// xsize and ysize. +Status ConvertToExternal(const jxl::ImageBundle& ib, size_t bits_per_sample, + bool float_out, size_t num_channels, + JxlEndianness endianness, size_t stride_out, + jxl::ThreadPool* thread_pool, void* out_image, + size_t out_size, const PixelCallback& out_callback, + jxl::Orientation undo_orientation, + bool unpremul_alpha = false); + +} // namespace jxl + +#endif // LIB_JXL_DEC_EXTERNAL_IMAGE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_external_image_gbench.cc b/third_party/jpeg-xl/lib/jxl/dec_external_image_gbench.cc new file mode 100644 index 0000000000..c87a4d5f36 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_external_image_gbench.cc @@ -0,0 +1,56 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { +namespace { + +// Decoder case, interleaves an internal float image. +void BM_DecExternalImage_ConvertImageRGBA(benchmark::State& state) { + const size_t kNumIter = 5; + size_t xsize = state.range(); + size_t ysize = state.range(); + size_t num_channels = 4; + + ImageMetadata im; + im.SetAlphaBits(8); + ImageBundle ib(&im); + Image3F color(xsize, ysize); + ZeroFillImage(&color); + ib.SetFromImage(std::move(color), ColorEncoding::SRGB()); + ImageF alpha(xsize, ysize); + ZeroFillImage(&alpha); + ib.SetAlpha(std::move(alpha)); + + const size_t bytes_per_row = xsize * num_channels; + std::vector<uint8_t> interleaved(bytes_per_row * ysize); + + for (auto _ : state) { + for (size_t i = 0; i < kNumIter; ++i) { + JXL_CHECK(ConvertToExternal( + ib, + /*bits_per_sample=*/8, + /*float_out=*/false, num_channels, JXL_NATIVE_ENDIAN, + /*stride*/ bytes_per_row, + /*thread_pool=*/nullptr, interleaved.data(), interleaved.size(), + /*out_callback=*/{}, + /*undo_orientation=*/jxl::Orientation::kIdentity)); + } + } + + // Pixels per second. + state.SetItemsProcessed(kNumIter * state.iterations() * xsize * ysize); + state.SetBytesProcessed(kNumIter * state.iterations() * interleaved.size()); +} + +BENCHMARK(BM_DecExternalImage_ConvertImageRGBA) + ->RangeMultiplier(2) + ->Range(256, 2048); + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_frame.cc b/third_party/jpeg-xl/lib/jxl/dec_frame.cc new file mode 100644 index 0000000000..918dbe7c37 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_frame.cc @@ -0,0 +1,890 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_frame.h" + +#include <jxl/decode.h> +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <atomic> +#include <cstdlib> +#include <memory> +#include <utility> +#include <vector> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" // kMaxNumPasses +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_group.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/dec_noise.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jpeg/jpeg_data.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/render_pipeline/render_pipeline.h" +#include "lib/jxl/splines.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +Status DecodeGlobalDCInfo(BitReader* reader, bool is_jpeg, + PassesDecoderState* state, ThreadPool* pool) { + JXL_RETURN_IF_ERROR(state->shared_storage.quantizer.Decode(reader)); + + JXL_RETURN_IF_ERROR( + DecodeBlockCtxMap(reader, &state->shared_storage.block_ctx_map)); + + JXL_RETURN_IF_ERROR(state->shared_storage.cmap.DecodeDC(reader)); + + // Pre-compute info for decoding a group. + if (is_jpeg) { + state->shared_storage.quantizer.ClearDCMul(); // Don't dequant DC + } + + state->shared_storage.ac_strategy.FillInvalid(); + return true; +} +} // namespace + +Status DecodeFrame(PassesDecoderState* dec_state, ThreadPool* JXL_RESTRICT pool, + const uint8_t* next_in, size_t avail_in, + FrameHeader* frame_header, ImageBundle* decoded, + const CodecMetadata& metadata, + bool use_slow_rendering_pipeline) { + FrameDecoder frame_decoder(dec_state, metadata, pool, + use_slow_rendering_pipeline); + + BitReader reader(Bytes(next_in, avail_in)); + JXL_RETURN_IF_ERROR(frame_decoder.InitFrame(&reader, decoded, + /*is_preview=*/false)); + JXL_RETURN_IF_ERROR(frame_decoder.InitFrameOutput()); + if (frame_header) { + *frame_header = frame_decoder.GetFrameHeader(); + } + + JXL_RETURN_IF_ERROR(reader.AllReadsWithinBounds()); + size_t header_bytes = reader.TotalBitsConsumed() / kBitsPerByte; + JXL_RETURN_IF_ERROR(reader.Close()); + + size_t processed_bytes = header_bytes; + Status close_ok = true; + std::vector<std::unique_ptr<BitReader>> section_readers; + { + std::vector<std::unique_ptr<BitReaderScopedCloser>> section_closers; + std::vector<FrameDecoder::SectionInfo> section_info; + std::vector<FrameDecoder::SectionStatus> section_status; + size_t pos = header_bytes; + size_t index = 0; + for (auto toc_entry : frame_decoder.Toc()) { + JXL_RETURN_IF_ERROR(pos + toc_entry.size <= avail_in); + auto br = make_unique<BitReader>(Bytes(next_in + pos, toc_entry.size)); + section_info.emplace_back( + FrameDecoder::SectionInfo{br.get(), toc_entry.id, index++}); + section_closers.emplace_back( + make_unique<BitReaderScopedCloser>(br.get(), &close_ok)); + section_readers.emplace_back(std::move(br)); + pos += toc_entry.size; + } + section_status.resize(section_info.size()); + JXL_RETURN_IF_ERROR(frame_decoder.ProcessSections( + section_info.data(), section_info.size(), section_status.data())); + for (size_t i = 0; i < section_status.size(); i++) { + JXL_RETURN_IF_ERROR(section_status[i] == FrameDecoder::kDone); + processed_bytes += frame_decoder.Toc()[i].size; + } + } + JXL_RETURN_IF_ERROR(close_ok); + JXL_RETURN_IF_ERROR(frame_decoder.FinalizeFrame()); + decoded->SetDecodedBytes(processed_bytes); + return true; +} + +Status FrameDecoder::InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded, + bool is_preview) { + decoded_ = decoded; + JXL_ASSERT(is_finalized_); + + // Reset the dequantization matrices to their default values. + dec_state_->shared_storage.matrices = DequantMatrices(); + + frame_header_.nonserialized_is_preview = is_preview; + JXL_ASSERT(frame_header_.nonserialized_metadata != nullptr); + JXL_RETURN_IF_ERROR(ReadFrameHeader(br, &frame_header_)); + frame_dim_ = frame_header_.ToFrameDimensions(); + JXL_DEBUG_V(2, "FrameHeader: %s", frame_header_.DebugString().c_str()); + + const size_t num_passes = frame_header_.passes.num_passes; + const size_t num_groups = frame_dim_.num_groups; + + // If the previous frame was not a kRegularFrame, `decoded` may have different + // dimensions; must reset to avoid errors. + decoded->RemoveColor(); + decoded->ClearExtraChannels(); + + decoded->duration = frame_header_.animation_frame.duration; + + if (!frame_header_.nonserialized_is_preview && + (frame_header_.is_last || frame_header_.animation_frame.duration > 0) && + (frame_header_.frame_type == kRegularFrame || + frame_header_.frame_type == kSkipProgressive)) { + ++dec_state_->visible_frame_index; + dec_state_->nonvisible_frame_index = 0; + } else { + ++dec_state_->nonvisible_frame_index; + } + + // Read TOC. + const size_t toc_entries = + NumTocEntries(num_groups, frame_dim_.num_dc_groups, num_passes); + std::vector<uint32_t> sizes; + std::vector<coeff_order_t> permutation; + JXL_RETURN_IF_ERROR(ReadToc(toc_entries, br, &sizes, &permutation)); + bool have_permutation = !permutation.empty(); + toc_.resize(toc_entries); + section_sizes_sum_ = 0; + for (size_t i = 0; i < toc_entries; ++i) { + toc_[i].size = sizes[i]; + size_t index = have_permutation ? permutation[i] : i; + toc_[index].id = i; + if (section_sizes_sum_ + toc_[i].size < section_sizes_sum_) { + return JXL_FAILURE("group offset overflow"); + } + section_sizes_sum_ += toc_[i].size; + } + + if (JXL_DEBUG_V_LEVEL >= 3) { + for (size_t i = 0; i < toc_entries; ++i) { + JXL_DEBUG_V(3, "TOC entry %" PRIuS " size %" PRIuS " id %" PRIuS "", i, + toc_[i].size, toc_[i].id); + } + } + + JXL_DASSERT((br->TotalBitsConsumed() % kBitsPerByte) == 0); + const size_t group_codes_begin = br->TotalBitsConsumed() / kBitsPerByte; + JXL_DASSERT(!toc_.empty()); + + // Overflow check. + if (group_codes_begin + section_sizes_sum_ < group_codes_begin) { + return JXL_FAILURE("Invalid group codes"); + } + + if (!frame_header_.chroma_subsampling.Is444() && + !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) && + frame_header_.encoding == FrameEncoding::kVarDCT) { + return JXL_FAILURE( + "Non-444 chroma subsampling is not allowed when adaptive DC " + "smoothing is enabled"); + } + return true; +} + +Status FrameDecoder::InitFrameOutput() { + JXL_RETURN_IF_ERROR( + InitializePassesSharedState(frame_header_, &dec_state_->shared_storage)); + JXL_RETURN_IF_ERROR(dec_state_->Init(frame_header_)); + modular_frame_decoder_.Init(frame_dim_); + + if (decoded_->IsJPEG()) { + if (frame_header_.encoding == FrameEncoding::kModular) { + return JXL_FAILURE("Cannot output JPEG from Modular"); + } + jpeg::JPEGData* jpeg_data = decoded_->jpeg_data.get(); + size_t num_components = jpeg_data->components.size(); + if (num_components != 1 && num_components != 3) { + return JXL_FAILURE("Invalid number of components"); + } + if (frame_header_.nonserialized_metadata->m.xyb_encoded) { + return JXL_FAILURE("Cannot decode to JPEG an XYB image"); + } + auto jpeg_c_map = JpegOrder(ColorTransform::kYCbCr, num_components == 1); + decoded_->jpeg_data->width = frame_dim_.xsize; + decoded_->jpeg_data->height = frame_dim_.ysize; + for (size_t c = 0; c < num_components; c++) { + auto& component = jpeg_data->components[jpeg_c_map[c]]; + component.width_in_blocks = + frame_dim_.xsize_blocks >> frame_header_.chroma_subsampling.HShift(c); + component.height_in_blocks = + frame_dim_.ysize_blocks >> frame_header_.chroma_subsampling.VShift(c); + component.h_samp_factor = + 1 << frame_header_.chroma_subsampling.RawHShift(c); + component.v_samp_factor = + 1 << frame_header_.chroma_subsampling.RawVShift(c); + component.coeffs.resize(component.width_in_blocks * + component.height_in_blocks * jxl::kDCTBlockSize); + } + } + + // Clear the state. + decoded_dc_global_ = false; + decoded_ac_global_ = false; + is_finalized_ = false; + finalized_dc_ = false; + num_sections_done_ = 0; + decoded_dc_groups_.clear(); + decoded_dc_groups_.resize(frame_dim_.num_dc_groups); + decoded_passes_per_ac_group_.clear(); + decoded_passes_per_ac_group_.resize(frame_dim_.num_groups, 0); + processed_section_.clear(); + processed_section_.resize(toc_.size()); + allocated_ = false; + return true; +} + +Status FrameDecoder::ProcessDCGlobal(BitReader* br) { + PassesSharedState& shared = dec_state_->shared_storage; + if (frame_header_.flags & FrameHeader::kPatches) { + bool uses_extra_channels = false; + JXL_RETURN_IF_ERROR(shared.image_features.patches.Decode( + br, frame_dim_.xsize_padded, frame_dim_.ysize_padded, + &uses_extra_channels)); + if (uses_extra_channels && frame_header_.upsampling != 1) { + for (size_t ecups : frame_header_.extra_channel_upsampling) { + if (ecups != frame_header_.upsampling) { + return JXL_FAILURE( + "Cannot use extra channels in patches if color channels are " + "subsampled differently from extra channels"); + } + } + } + } else { + shared.image_features.patches.Clear(); + } + shared.image_features.splines.Clear(); + if (frame_header_.flags & FrameHeader::kSplines) { + JXL_RETURN_IF_ERROR(shared.image_features.splines.Decode( + br, frame_dim_.xsize * frame_dim_.ysize)); + } + if (frame_header_.flags & FrameHeader::kNoise) { + JXL_RETURN_IF_ERROR(DecodeNoise(br, &shared.image_features.noise_params)); + } + JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.DecodeDC(br)); + + if (frame_header_.encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR( + jxl::DecodeGlobalDCInfo(br, decoded_->IsJPEG(), dec_state_, pool_)); + } + // Splines' draw cache uses the color correlation map. + if (frame_header_.flags & FrameHeader::kSplines) { + JXL_RETURN_IF_ERROR(shared.image_features.splines.InitializeDrawCache( + frame_dim_.xsize_upsampled, frame_dim_.ysize_upsampled, + dec_state_->shared->cmap)); + } + Status dec_status = modular_frame_decoder_.DecodeGlobalInfo( + br, frame_header_, /*allow_truncated_group=*/false); + if (dec_status.IsFatalError()) return dec_status; + if (dec_status) { + decoded_dc_global_ = true; + } + return dec_status; +} + +Status FrameDecoder::ProcessDCGroup(size_t dc_group_id, BitReader* br) { + const size_t gx = dc_group_id % frame_dim_.xsize_dc_groups; + const size_t gy = dc_group_id / frame_dim_.xsize_dc_groups; + const LoopFilter& lf = frame_header_.loop_filter; + if (frame_header_.encoding == FrameEncoding::kVarDCT && + !(frame_header_.flags & FrameHeader::kUseDcFrame)) { + JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeVarDCTDC( + frame_header_, dc_group_id, br, dec_state_)); + } + const Rect mrect(gx * frame_dim_.dc_group_dim, gy * frame_dim_.dc_group_dim, + frame_dim_.dc_group_dim, frame_dim_.dc_group_dim); + JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup( + frame_header_, mrect, br, 3, 1000, + ModularStreamId::ModularDC(dc_group_id), + /*zerofill=*/false, nullptr, nullptr, + /*allow_truncated=*/false)); + if (frame_header_.encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeAcMetadata( + frame_header_, dc_group_id, br, dec_state_)); + } else if (lf.epf_iters > 0) { + FillImage(kInvSigmaNum / lf.epf_sigma_for_modular, &dec_state_->sigma); + } + decoded_dc_groups_[dc_group_id] = uint8_t{true}; + return true; +} + +void FrameDecoder::FinalizeDC() { + // Do Adaptive DC smoothing if enabled. This *must* happen between all the + // ProcessDCGroup and ProcessACGroup. + if (frame_header_.encoding == FrameEncoding::kVarDCT && + !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) && + !(frame_header_.flags & FrameHeader::kUseDcFrame)) { + AdaptiveDCSmoothing(dec_state_->shared->quantizer.MulDC(), + &dec_state_->shared_storage.dc_storage, pool_); + } + + finalized_dc_ = true; +} + +Status FrameDecoder::AllocateOutput() { + if (allocated_) return true; + modular_frame_decoder_.MaybeDropFullImage(); + decoded_->origin = frame_header_.frame_origin; + JXL_RETURN_IF_ERROR( + dec_state_->InitForAC(frame_header_.passes.num_passes, nullptr)); + allocated_ = true; + return true; +} + +Status FrameDecoder::ProcessACGlobal(BitReader* br) { + JXL_CHECK(finalized_dc_); + + // Decode AC group. + if (frame_header_.encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.Decode( + br, &modular_frame_decoder_)); + JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.EnsureComputed( + dec_state_->used_acs)); + + size_t num_histo_bits = + CeilLog2Nonzero(dec_state_->shared->frame_dim.num_groups); + dec_state_->shared_storage.num_histograms = + 1 + br->ReadBits(num_histo_bits); + + JXL_DEBUG_V(3, + "Processing AC global with %d passes and %" PRIuS + " sets of histograms", + frame_header_.passes.num_passes, + dec_state_->shared_storage.num_histograms); + + dec_state_->code.resize(kMaxNumPasses); + dec_state_->context_map.resize(kMaxNumPasses); + // Read coefficient orders and histograms. + size_t max_num_bits_ac = 0; + for (size_t i = 0; i < frame_header_.passes.num_passes; i++) { + uint16_t used_orders = U32Coder::Read(kOrderEnc, br); + JXL_RETURN_IF_ERROR(DecodeCoeffOrders( + used_orders, dec_state_->used_acs, + &dec_state_->shared_storage + .coeff_orders[i * dec_state_->shared_storage.coeff_order_size], + br)); + size_t num_contexts = + dec_state_->shared->num_histograms * + dec_state_->shared_storage.block_ctx_map.NumACContexts(); + JXL_RETURN_IF_ERROR(DecodeHistograms( + br, num_contexts, &dec_state_->code[i], &dec_state_->context_map[i])); + // Add extra values to enable the cheat in hot loop of DecodeACVarBlock. + dec_state_->context_map[i].resize( + num_contexts + kZeroDensityContextLimit - kZeroDensityContextCount); + max_num_bits_ac = + std::max(max_num_bits_ac, dec_state_->code[i].max_num_bits); + } + max_num_bits_ac += CeilLog2Nonzero(frame_header_.passes.num_passes); + // 16-bit buffer for decoding to JPEG are not implemented. + // TODO(veluca): figure out the exact limit - 16 should still work with + // 16-bit buffers, but we are excluding it for safety. + bool use_16_bit = max_num_bits_ac < 16 && !decoded_->IsJPEG(); + bool store = frame_header_.passes.num_passes > 1; + size_t xs = store ? kGroupDim * kGroupDim : 0; + size_t ys = store ? frame_dim_.num_groups : 0; + if (use_16_bit) { + dec_state_->coefficients = make_unique<ACImageT<int16_t>>(xs, ys); + } else { + dec_state_->coefficients = make_unique<ACImageT<int32_t>>(xs, ys); + } + if (store) { + dec_state_->coefficients->ZeroFill(); + } + } + + // Set JPEG decoding data. + if (decoded_->IsJPEG()) { + decoded_->color_transform = frame_header_.color_transform; + decoded_->chroma_subsampling = frame_header_.chroma_subsampling; + const std::vector<QuantEncoding>& qe = + dec_state_->shared_storage.matrices.encodings(); + if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW || + std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) { + return JXL_FAILURE( + "Quantization table is not a JPEG quantization table."); + } + jpeg::JPEGData* jpeg_data = decoded_->jpeg_data.get(); + size_t num_components = jpeg_data->components.size(); + bool is_gray = (num_components == 1); + auto jpeg_c_map = JpegOrder(frame_header_.color_transform, is_gray); + size_t qt_set = 0; + for (size_t c = 0; c < num_components; c++) { + // TODO(eustas): why 1-st quant table for gray? + size_t quant_c = is_gray ? 1 : c; + size_t qpos = jpeg_data->components[jpeg_c_map[c]].quant_idx; + JXL_CHECK(qpos != jpeg_data->quant.size()); + qt_set |= 1 << qpos; + for (size_t x = 0; x < 8; x++) { + for (size_t y = 0; y < 8; y++) { + jpeg_data->quant[qpos].values[x * 8 + y] = + (*qe[0].qraw.qtable)[quant_c * 64 + y * 8 + x]; + } + } + } + for (size_t i = 0; i < jpeg_data->quant.size(); i++) { + if (qt_set & (1 << i)) continue; + if (i == 0) return JXL_FAILURE("First quant table unused."); + // Unused quant table is set to copy of previous quant table + for (size_t j = 0; j < 64; j++) { + jpeg_data->quant[i].values[j] = jpeg_data->quant[i - 1].values[j]; + } + } + } + decoded_ac_global_ = true; + return true; +} + +Status FrameDecoder::ProcessACGroup(size_t ac_group_id, + BitReader* JXL_RESTRICT* br, + size_t num_passes, size_t thread, + bool force_draw, bool dc_only) { + size_t group_dim = frame_dim_.group_dim; + const size_t gx = ac_group_id % frame_dim_.xsize_groups; + const size_t gy = ac_group_id / frame_dim_.xsize_groups; + const size_t x = gx * group_dim; + const size_t y = gy * group_dim; + JXL_DEBUG_V(3, + "Processing AC group %" PRIuS "(%" PRIuS ",%" PRIuS + ") group_dim: %" PRIuS " decoded passes: %u new passes: %" PRIuS, + ac_group_id, gx, gy, group_dim, + decoded_passes_per_ac_group_[ac_group_id], num_passes); + + RenderPipelineInput render_pipeline_input = + dec_state_->render_pipeline->GetInputBuffers(ac_group_id, thread); + + bool should_run_pipeline = true; + + if (frame_header_.encoding == FrameEncoding::kVarDCT) { + group_dec_caches_[thread].InitOnce(frame_header_.passes.num_passes, + dec_state_->used_acs); + JXL_RETURN_IF_ERROR(DecodeGroup(frame_header_, br, num_passes, ac_group_id, + dec_state_, &group_dec_caches_[thread], + thread, render_pipeline_input, decoded_, + decoded_passes_per_ac_group_[ac_group_id], + force_draw, dc_only, &should_run_pipeline)); + } + + // don't limit to image dimensions here (is done in DecodeGroup) + const Rect mrect(x, y, group_dim, group_dim); + bool modular_ready = false; + size_t pass0 = decoded_passes_per_ac_group_[ac_group_id]; + size_t pass1 = + force_draw ? frame_header_.passes.num_passes : pass0 + num_passes; + for (size_t i = pass0; i < pass1; ++i) { + int minShift, maxShift; + frame_header_.passes.GetDownsamplingBracket(i, minShift, maxShift); + bool modular_pass_ready = true; + if (i < pass0 + num_passes) { + JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup( + frame_header_, mrect, br[i - pass0], minShift, maxShift, + ModularStreamId::ModularAC(ac_group_id, i), + /*zerofill=*/false, dec_state_, &render_pipeline_input, + /*allow_truncated=*/false, &modular_pass_ready)); + } else { + JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup( + frame_header_, mrect, nullptr, minShift, maxShift, + ModularStreamId::ModularAC(ac_group_id, i), /*zerofill=*/true, + dec_state_, &render_pipeline_input, + /*allow_truncated=*/false, &modular_pass_ready)); + } + if (modular_pass_ready) modular_ready = true; + } + decoded_passes_per_ac_group_[ac_group_id] += num_passes; + + if ((frame_header_.flags & FrameHeader::kNoise) != 0) { + size_t noise_c_start = + 3 + frame_header_.nonserialized_metadata->m.num_extra_channels; + // When the color channels are downsampled, we need to generate more noise + // input for the current group than just the group dimensions. + std::pair<ImageF*, Rect> rects[3]; + for (size_t iy = 0; iy < frame_header_.upsampling; iy++) { + for (size_t ix = 0; ix < frame_header_.upsampling; ix++) { + for (size_t c = 0; c < 3; c++) { + auto r = render_pipeline_input.GetBuffer(noise_c_start + c); + rects[c].first = r.first; + size_t x1 = r.second.x0() + r.second.xsize(); + size_t y1 = r.second.y0() + r.second.ysize(); + rects[c].second = Rect(r.second.x0() + ix * group_dim, + r.second.y0() + iy * group_dim, group_dim, + group_dim, x1, y1); + } + Random3Planes(dec_state_->visible_frame_index, + dec_state_->nonvisible_frame_index, + (gx * frame_header_.upsampling + ix) * group_dim, + (gy * frame_header_.upsampling + iy) * group_dim, + rects[0], rects[1], rects[2]); + } + } + } + + if (!modular_frame_decoder_.UsesFullImage() && !decoded_->IsJPEG()) { + if (should_run_pipeline && modular_ready) { + render_pipeline_input.Done(); + } else if (force_draw) { + return JXL_FAILURE("Modular group decoding failed."); + } + } + return true; +} + +void FrameDecoder::MarkSections(const SectionInfo* sections, size_t num, + SectionStatus* section_status) { + num_sections_done_ += num; + for (size_t i = 0; i < num; i++) { + if (section_status[i] != SectionStatus::kDone) { + processed_section_[sections[i].id] = false; + num_sections_done_--; + } + } +} + +Status FrameDecoder::ProcessSections(const SectionInfo* sections, size_t num, + SectionStatus* section_status) { + if (num == 0) return true; // Nothing to process + std::fill(section_status, section_status + num, SectionStatus::kSkipped); + size_t dc_global_sec = num; + size_t ac_global_sec = num; + std::vector<size_t> dc_group_sec(frame_dim_.num_dc_groups, num); + std::vector<std::vector<size_t>> ac_group_sec( + frame_dim_.num_groups, + std::vector<size_t>(frame_header_.passes.num_passes, num)); + // This keeps track of the number of ac passes we want to process during this + // call of ProcessSections. + std::vector<size_t> desired_num_ac_passes(frame_dim_.num_groups); + bool single_section = + frame_dim_.num_groups == 1 && frame_header_.passes.num_passes == 1; + if (single_section) { + JXL_ASSERT(num == 1); + JXL_ASSERT(sections[0].id == 0); + if (processed_section_[0] == false) { + processed_section_[0] = true; + ac_group_sec[0].resize(1); + dc_global_sec = ac_global_sec = dc_group_sec[0] = ac_group_sec[0][0] = 0; + desired_num_ac_passes[0] = 1; + } else { + section_status[0] = SectionStatus::kDuplicate; + } + } else { + size_t ac_global_index = frame_dim_.num_dc_groups + 1; + for (size_t i = 0; i < num; i++) { + JXL_ASSERT(sections[i].id < processed_section_.size()); + if (processed_section_[sections[i].id]) { + section_status[i] = SectionStatus::kDuplicate; + continue; + } + if (sections[i].id == 0) { + dc_global_sec = i; + } else if (sections[i].id < ac_global_index) { + dc_group_sec[sections[i].id - 1] = i; + } else if (sections[i].id == ac_global_index) { + ac_global_sec = i; + } else { + size_t ac_idx = sections[i].id - ac_global_index - 1; + size_t acg = ac_idx % frame_dim_.num_groups; + size_t acp = ac_idx / frame_dim_.num_groups; + if (acp >= frame_header_.passes.num_passes) { + return JXL_FAILURE("Invalid section ID"); + } + ac_group_sec[acg][acp] = i; + } + processed_section_[sections[i].id] = true; + } + // Count number of new passes per group. + for (size_t g = 0; g < ac_group_sec.size(); g++) { + size_t j = 0; + for (; j + decoded_passes_per_ac_group_[g] < + frame_header_.passes.num_passes; + j++) { + if (ac_group_sec[g][j + decoded_passes_per_ac_group_[g]] == num) { + break; + } + } + desired_num_ac_passes[g] = j; + } + } + if (dc_global_sec != num) { + Status dc_global_status = ProcessDCGlobal(sections[dc_global_sec].br); + if (dc_global_status.IsFatalError()) return dc_global_status; + if (dc_global_status) { + section_status[dc_global_sec] = SectionStatus::kDone; + } else { + section_status[dc_global_sec] = SectionStatus::kPartial; + } + } + + std::atomic<bool> has_error{false}; + if (decoded_dc_global_) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool_, 0, dc_group_sec.size(), ThreadPool::NoInit, + [this, &dc_group_sec, &num, §ions, §ion_status, &has_error]( + size_t i, size_t thread) { + if (dc_group_sec[i] != num) { + if (!ProcessDCGroup(i, sections[dc_group_sec[i]].br)) { + has_error = true; + } else { + section_status[dc_group_sec[i]] = SectionStatus::kDone; + } + } + }, + "DecodeDCGroup")); + } + if (has_error) return JXL_FAILURE("Error in DC group"); + + if (*std::min_element(decoded_dc_groups_.begin(), decoded_dc_groups_.end()) && + !finalized_dc_) { + PassesDecoderState::PipelineOptions pipeline_options; + pipeline_options.use_slow_render_pipeline = use_slow_rendering_pipeline_; + pipeline_options.coalescing = coalescing_; + pipeline_options.render_spotcolors = render_spotcolors_; + pipeline_options.render_noise = true; + JXL_RETURN_IF_ERROR( + dec_state_->PreparePipeline(frame_header_, decoded_, pipeline_options)); + FinalizeDC(); + JXL_RETURN_IF_ERROR(AllocateOutput()); + if (progressive_detail_ >= JxlProgressiveDetail::kDC) { + MarkSections(sections, num, section_status); + return true; + } + } + + if (finalized_dc_ && ac_global_sec != num && !decoded_ac_global_) { + JXL_RETURN_IF_ERROR(ProcessACGlobal(sections[ac_global_sec].br)); + section_status[ac_global_sec] = SectionStatus::kDone; + } + + if (progressive_detail_ >= JxlProgressiveDetail::kLastPasses) { + // Mark that we only want the next progression pass. + size_t target_complete_passes = NextNumPassesToPause(); + for (size_t i = 0; i < ac_group_sec.size(); i++) { + desired_num_ac_passes[i] = + std::min(desired_num_ac_passes[i], + target_complete_passes - decoded_passes_per_ac_group_[i]); + } + } + + if (decoded_ac_global_) { + // Mark all the AC groups that we received as not complete yet. + for (size_t i = 0; i < ac_group_sec.size(); i++) { + if (desired_num_ac_passes[i] != 0) { + dec_state_->render_pipeline->ClearDone(i); + } + } + + JXL_RETURN_IF_ERROR(RunOnPool( + pool_, 0, ac_group_sec.size(), + [this](size_t num_threads) { + return PrepareStorage(num_threads, + decoded_passes_per_ac_group_.size()); + }, + [this, &ac_group_sec, &desired_num_ac_passes, &num, §ions, + §ion_status, &has_error](size_t g, size_t thread) { + if (desired_num_ac_passes[g] == 0) { + // no new AC pass, nothing to do + return; + } + (void)num; + size_t first_pass = decoded_passes_per_ac_group_[g]; + BitReader* JXL_RESTRICT readers[kMaxNumPasses]; + for (size_t i = 0; i < desired_num_ac_passes[g]; i++) { + JXL_ASSERT(ac_group_sec[g][first_pass + i] != num); + readers[i] = sections[ac_group_sec[g][first_pass + i]].br; + } + if (!ProcessACGroup(g, readers, desired_num_ac_passes[g], + GetStorageLocation(thread, g), + /*force_draw=*/false, /*dc_only=*/false)) { + has_error = true; + } else { + for (size_t i = 0; i < desired_num_ac_passes[g]; i++) { + section_status[ac_group_sec[g][first_pass + i]] = + SectionStatus::kDone; + } + } + }, + "DecodeGroup")); + } + if (has_error) return JXL_FAILURE("Error in AC group"); + + MarkSections(sections, num, section_status); + return true; +} + +Status FrameDecoder::Flush() { + bool has_blending = frame_header_.blending_info.mode != BlendMode::kReplace || + frame_header_.custom_size_or_origin; + for (const auto& blending_info_ec : + frame_header_.extra_channel_blending_info) { + if (blending_info_ec.mode != BlendMode::kReplace) has_blending = true; + } + // No early Flush() if blending is enabled. + if (has_blending && !is_finalized_) { + return false; + } + // No early Flush() - nothing to do - if the frame is a kSkipProgressive + // frame. + if (frame_header_.frame_type == FrameType::kSkipProgressive && + !is_finalized_) { + return true; + } + if (decoded_->IsJPEG()) { + // Nothing to do. + return true; + } + JXL_RETURN_IF_ERROR(AllocateOutput()); + + uint32_t completely_decoded_ac_pass = *std::min_element( + decoded_passes_per_ac_group_.begin(), decoded_passes_per_ac_group_.end()); + if (completely_decoded_ac_pass < frame_header_.passes.num_passes) { + // We don't have all AC yet: force a draw of all the missing areas. + // Mark all sections as not complete. + for (size_t i = 0; i < decoded_passes_per_ac_group_.size(); i++) { + if (decoded_passes_per_ac_group_[i] < frame_header_.passes.num_passes) { + dec_state_->render_pipeline->ClearDone(i); + } + } + std::atomic<bool> has_error{false}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool_, 0, decoded_passes_per_ac_group_.size(), + [this](const size_t num_threads) { + return PrepareStorage(num_threads, + decoded_passes_per_ac_group_.size()); + }, + [this, &has_error](const uint32_t g, size_t thread) { + if (decoded_passes_per_ac_group_[g] == + frame_header_.passes.num_passes) { + // This group was drawn already, nothing to do. + return; + } + BitReader* JXL_RESTRICT readers[kMaxNumPasses] = {}; + bool ok = ProcessACGroup( + g, readers, /*num_passes=*/0, GetStorageLocation(thread, g), + /*force_draw=*/true, /*dc_only=*/!decoded_ac_global_); + if (!ok) has_error = true; + }, + "ForceDrawGroup")); + if (has_error) { + return JXL_FAILURE("Drawing groups failed"); + } + } + + // undo global modular transforms and copy int pixel buffers to float ones + JXL_RETURN_IF_ERROR(modular_frame_decoder_.FinalizeDecoding( + frame_header_, dec_state_, pool_, is_finalized_)); + + return true; +} + +int FrameDecoder::SavedAs(const FrameHeader& header) { + if (header.frame_type == FrameType::kDCFrame) { + // bits 16, 32, 64, 128 for DC level + return 16 << (header.dc_level - 1); + } else if (header.CanBeReferenced()) { + // bits 1, 2, 4 and 8 for the references + return 1 << header.save_as_reference; + } + + return 0; +} + +bool FrameDecoder::HasEverything() const { + if (!decoded_dc_global_) return false; + if (!decoded_ac_global_) return false; + for (auto& have_dc_group : decoded_dc_groups_) { + if (!have_dc_group) return false; + } + for (auto& nb_passes : decoded_passes_per_ac_group_) { + if (nb_passes < frame_header_.passes.num_passes) return false; + } + return true; +} + +int FrameDecoder::References() const { + if (is_finalized_) { + return 0; + } + if (!HasEverything()) return 0; + + int result = 0; + + // Blending + if (frame_header_.frame_type == FrameType::kRegularFrame || + frame_header_.frame_type == FrameType::kSkipProgressive) { + bool cropped = frame_header_.custom_size_or_origin; + if (cropped || frame_header_.blending_info.mode != BlendMode::kReplace) { + result |= (1 << frame_header_.blending_info.source); + } + const auto& extra = frame_header_.extra_channel_blending_info; + for (size_t i = 0; i < extra.size(); ++i) { + if (cropped || extra[i].mode != BlendMode::kReplace) { + result |= (1 << extra[i].source); + } + } + } + + // Patches + if (frame_header_.flags & FrameHeader::kPatches) { + result |= dec_state_->shared->image_features.patches.GetReferences(); + } + + // DC Level + if (frame_header_.flags & FrameHeader::kUseDcFrame) { + // Reads from the next dc level + int dc_level = frame_header_.dc_level + 1; + // bits 16, 32, 64, 128 for DC level + result |= (16 << (dc_level - 1)); + } + + return result; +} + +Status FrameDecoder::FinalizeFrame() { + if (is_finalized_) { + return JXL_FAILURE("FinalizeFrame called multiple times"); + } + is_finalized_ = true; + if (decoded_->IsJPEG()) { + // Nothing to do. + return true; + } + + // undo global modular transforms and copy int pixel buffers to float ones + JXL_RETURN_IF_ERROR( + modular_frame_decoder_.FinalizeDecoding(frame_header_, dec_state_, pool_, + /*inplace=*/true)); + + if (frame_header_.CanBeReferenced()) { + auto& info = dec_state_->shared_storage + .reference_frames[frame_header_.save_as_reference]; + info.frame = std::move(dec_state_->frame_storage_for_referencing); + info.ib_is_in_xyb = frame_header_.save_before_color_transform; + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_frame.h b/third_party/jpeg-xl/lib/jxl/dec_frame.h new file mode 100644 index 0000000000..09bdbc9675 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_frame.h @@ -0,0 +1,335 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_FRAME_H_ +#define LIB_JXL_DEC_FRAME_H_ + +#include <jxl/decode.h> +#include <jxl/types.h> +#include <stdint.h> + +#include <algorithm> +#include <cstddef> +#include <limits> +#include <utility> +#include <vector> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" // JXL_HIGH_PRECISION +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" + +namespace jxl { + +// Decodes a frame. Groups may be processed in parallel by `pool`. +// `metadata` is the metadata that applies to all frames of the codestream +// `decoded->metadata` must already be set and must match metadata.m. +// Used in the encoder to model decoder behaviour, and in tests. +Status DecodeFrame(PassesDecoderState* dec_state, ThreadPool* JXL_RESTRICT pool, + const uint8_t* next_in, size_t avail_in, + FrameHeader* frame_header, ImageBundle* decoded, + const CodecMetadata& metadata, + bool use_slow_rendering_pipeline = false); + +// TODO(veluca): implement "forced drawing". +class FrameDecoder { + public: + // All parameters must outlive the FrameDecoder. + FrameDecoder(PassesDecoderState* dec_state, const CodecMetadata& metadata, + ThreadPool* pool, bool use_slow_rendering_pipeline) + : dec_state_(dec_state), + pool_(pool), + frame_header_(&metadata), + use_slow_rendering_pipeline_(use_slow_rendering_pipeline) {} + + void SetRenderSpotcolors(bool rsc) { render_spotcolors_ = rsc; } + void SetCoalescing(bool c) { coalescing_ = c; } + + // Read FrameHeader and table of contents from the given BitReader. + Status InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded, + bool is_preview); + + // Checks frame dimensions for their limits, and sets the output + // image buffer. + Status InitFrameOutput(); + + struct SectionInfo { + BitReader* JXL_RESTRICT br; + // Logical index of the section, regardless of any permutation that may be + // applied in the table of contents or of the physical position in the file. + size_t id; + // Index of the section in the order of the bytes inside the frame. + size_t index; + }; + + struct TocEntry { + size_t size; + size_t id; + }; + + enum SectionStatus { + // Processed correctly. + kDone = 0, + // Skipped because other required sections were not yet processed. + kSkipped = 1, + // Skipped because the section was already processed. + kDuplicate = 2, + // Only partially decoded: the section will need to be processed again. + kPartial = 3, + }; + + // Processes `num` sections; each SectionInfo contains the index + // of the section and a BitReader that only contains the data of the section. + // `section_status` should point to `num` elements, and will be filled with + // information about whether each section was processed or not. + // A section is a part of the encoded file that is indexed by the TOC. + Status ProcessSections(const SectionInfo* sections, size_t num, + SectionStatus* section_status); + + // Flushes all the data decoded so far to pixels. + Status Flush(); + + // Runs final operations once a frame data is decoded. + // Must be called exactly once per frame, after all calls to ProcessSections. + Status FinalizeFrame(); + + // Returns dependencies of this frame on reference ids as a bit mask: bits 0-3 + // indicate reference frame 0-3 for patches and blending, bits 4-7 indicate DC + // frames this frame depends on. Only returns a valid result after all calls + // to ProcessSections are finished and before FinalizeFrame. + int References() const; + + // Returns reference id of storage location where this frame is stored as a + // bit flag, or 0 if not stored. + // Matches the bit mask used for GetReferences: bits 0-3 indicate it is stored + // for patching or blending, bits 4-7 indicate DC frame. + // Unlike References, can be ran at any time as + // soon as the frame header is known. + static int SavedAs(const FrameHeader& header); + + uint64_t SumSectionSizes() const { return section_sizes_sum_; } + const std::vector<TocEntry>& Toc() const { return toc_; } + + const FrameHeader& GetFrameHeader() const { return frame_header_; } + + // Returns whether a DC image has been decoded, accessible at low resolution + // at passes.shared_storage.dc_storage + bool HasDecodedDC() const { return finalized_dc_; } + bool HasDecodedAll() const { return toc_.size() == num_sections_done_; } + + size_t NumCompletePasses() const { + return *std::min_element(decoded_passes_per_ac_group_.begin(), + decoded_passes_per_ac_group_.end()); + } + + // If enabled, ProcessSections will stop and return true when the DC + // sections have been processed, instead of starting the AC sections. This + // will only occur if supported (that is, flushing will produce a valid + // 1/8th*1/8th resolution image). The return value of true then does not mean + // all sections have been processed, use HasDecodedDC and HasDecodedAll + // to check the true finished state. + // Returns the progressive detail that will be effective for the frame. + JxlProgressiveDetail SetPauseAtProgressive(JxlProgressiveDetail prog_detail) { + bool single_section = + frame_dim_.num_groups == 1 && frame_header_.passes.num_passes == 1; + if (frame_header_.frame_type != kSkipProgressive && + // If there's only one group and one pass, there is no separate section + // for DC and the entire full resolution image is available at once. + !single_section && + // If extra channels are encoded with modular without squeeze, they + // don't support DC. If the are encoded with squeeze, DC works in theory + // but the implementation may not yet correctly support this for Flush. + // Therefore, can't correctly pause for a progressive step if there is + // an extra channel (including alpha channel) + // TODO(firsching): Check if this is still the case. + decoded_->metadata()->extra_channel_info.empty() && + // DC is not guaranteed to be available in modular mode and may be a + // black image. If squeeze is used, it may be available depending on the + // current implementation. + // TODO(lode): do return DC if it's known that flushing at this point + // will produce a valid 1/8th downscaled image with modular encoding. + frame_header_.encoding == FrameEncoding::kVarDCT) { + progressive_detail_ = prog_detail; + } else { + progressive_detail_ = JxlProgressiveDetail::kFrames; + } + if (progressive_detail_ >= JxlProgressiveDetail::kPasses) { + for (size_t i = 1; i < frame_header_.passes.num_passes; ++i) { + passes_to_pause_.push_back(i); + } + } else if (progressive_detail_ >= JxlProgressiveDetail::kLastPasses) { + for (size_t i = 0; i < frame_header_.passes.num_downsample; ++i) { + passes_to_pause_.push_back(frame_header_.passes.last_pass[i] + 1); + } + // The format does not guarantee that these values are sorted. + std::sort(passes_to_pause_.begin(), passes_to_pause_.end()); + } + return progressive_detail_; + } + + size_t NextNumPassesToPause() const { + auto it = std::upper_bound(passes_to_pause_.begin(), passes_to_pause_.end(), + NumCompletePasses()); + return (it != passes_to_pause_.end() ? *it + : std::numeric_limits<size_t>::max()); + } + + // Sets the pixel callback or image buffer where the pixels will be decoded. + // + // @param undo_orientation: if true, indicates the frame decoder should apply + // the exif orientation to bring the image to the intended display + // orientation. + void SetImageOutput(const PixelCallback& pixel_callback, void* image_buffer, + size_t image_buffer_size, size_t xsize, size_t ysize, + JxlPixelFormat format, size_t bits_per_sample, + bool unpremul_alpha, bool undo_orientation) const { + dec_state_->width = xsize; + dec_state_->height = ysize; + dec_state_->main_output.format = format; + dec_state_->main_output.bits_per_sample = bits_per_sample; + dec_state_->main_output.callback = pixel_callback; + dec_state_->main_output.buffer = image_buffer; + dec_state_->main_output.buffer_size = image_buffer_size; + dec_state_->main_output.stride = GetStride(xsize, format); + const jxl::ExtraChannelInfo* alpha = + decoded_->metadata()->Find(jxl::ExtraChannel::kAlpha); + if (alpha && alpha->alpha_associated && unpremul_alpha) { + dec_state_->unpremul_alpha = true; + } + if (undo_orientation) { + dec_state_->undo_orientation = decoded_->metadata()->GetOrientation(); + if (static_cast<int>(dec_state_->undo_orientation) > 4) { + std::swap(dec_state_->width, dec_state_->height); + } + } + dec_state_->extra_output.clear(); +#if !JXL_HIGH_PRECISION + if (dec_state_->main_output.buffer && + (format.data_type == JXL_TYPE_UINT8) && (format.num_channels >= 3) && + !dec_state_->unpremul_alpha && + (dec_state_->undo_orientation == Orientation::kIdentity) && + decoded_->metadata()->xyb_encoded && + dec_state_->output_encoding_info.color_encoding.IsSRGB() && + dec_state_->output_encoding_info.all_default_opsin && + (dec_state_->output_encoding_info.desired_intensity_target == + dec_state_->output_encoding_info.orig_intensity_target) && + HasFastXYBTosRGB8() && frame_header_.needs_color_transform()) { + dec_state_->fast_xyb_srgb8_conversion = true; + } +#endif + } + + void AddExtraChannelOutput(void* buffer, size_t buffer_size, size_t xsize, + JxlPixelFormat format, size_t bits_per_sample) { + ImageOutput out; + out.format = format; + out.bits_per_sample = bits_per_sample; + out.buffer = buffer; + out.buffer_size = buffer_size; + out.stride = GetStride(xsize, format); + dec_state_->extra_output.push_back(out); + } + + private: + Status ProcessDCGlobal(BitReader* br); + Status ProcessDCGroup(size_t dc_group_id, BitReader* br); + void FinalizeDC(); + Status AllocateOutput(); + Status ProcessACGlobal(BitReader* br); + Status ProcessACGroup(size_t ac_group_id, BitReader* JXL_RESTRICT* br, + size_t num_passes, size_t thread, bool force_draw, + bool dc_only); + void MarkSections(const SectionInfo* sections, size_t num, + SectionStatus* section_status); + + // Allocates storage for parallel decoding using up to `num_threads` threads + // of up to `num_tasks` tasks. The value of `thread` passed to + // `GetStorageLocation` must be smaller than the `num_threads` value passed + // here. The value of `task` passed to `GetStorageLocation` must be smaller + // than the value of `num_tasks` passed here. + Status PrepareStorage(size_t num_threads, size_t num_tasks) { + size_t storage_size = std::min(num_threads, num_tasks); + if (storage_size > group_dec_caches_.size()) { + group_dec_caches_.resize(storage_size); + } + use_task_id_ = num_threads > num_tasks; + bool use_group_ids = (modular_frame_decoder_.UsesFullImage() && + (frame_header_.encoding == FrameEncoding::kVarDCT || + (frame_header_.flags & FrameHeader::kNoise))); + if (dec_state_->render_pipeline) { + JXL_RETURN_IF_ERROR(dec_state_->render_pipeline->PrepareForThreads( + storage_size, use_group_ids)); + } + return true; + } + + size_t GetStorageLocation(size_t thread, size_t task) { + if (use_task_id_) return task; + return thread; + } + + static size_t BytesPerChannel(JxlDataType data_type) { + return (data_type == JXL_TYPE_UINT8 ? 1u + : data_type == JXL_TYPE_FLOAT ? 4u + : 2u); + } + + static size_t GetStride(const size_t xsize, JxlPixelFormat format) { + size_t stride = + (xsize * BytesPerChannel(format.data_type) * format.num_channels); + if (format.align > 1) { + stride = (jxl::DivCeil(stride, format.align) * format.align); + } + return stride; + } + + PassesDecoderState* dec_state_; + ThreadPool* pool_; + std::vector<TocEntry> toc_; + uint64_t section_sizes_sum_; + // TODO(veluca): figure out the duplication between these and dec_state_. + FrameHeader frame_header_; + FrameDimensions frame_dim_; + ImageBundle* decoded_; + ModularFrameDecoder modular_frame_decoder_; + bool render_spotcolors_ = true; + bool coalescing_ = true; + + std::vector<uint8_t> processed_section_; + std::vector<uint8_t> decoded_passes_per_ac_group_; + std::vector<uint8_t> decoded_dc_groups_; + bool decoded_dc_global_; + bool decoded_ac_global_; + bool HasEverything() const; + bool finalized_dc_ = true; + size_t num_sections_done_ = 0; + bool is_finalized_ = true; + bool allocated_ = false; + + std::vector<GroupDecCache> group_dec_caches_; + + // Whether or not the task id should be used for storage indexing, instead of + // the thread id. + bool use_task_id_ = false; + + // Testing setting: whether or not to use the slow rendering pipeline. + bool use_slow_rendering_pipeline_; + + JxlProgressiveDetail progressive_detail_ = kFrames; + // Number of completed passes where section decoding should pause. + // Used for progressive details at least kLastPasses. + std::vector<int> passes_to_pause_; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_FRAME_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_group.cc b/third_party/jpeg-xl/lib/jxl/dec_group.cc new file mode 100644 index 0000000000..186318e63d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_group.cc @@ -0,0 +1,793 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_group.h" + +#include <stdint.h> +#include <string.h> + +#include <algorithm> +#include <memory> +#include <utility> + +#include "lib/jxl/frame_header.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/common.h" // kMaxNumPasses +#include "lib/jxl/convolve.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_transforms-inl.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer-inl.h" +#include "lib/jxl/quantizer.h" + +#ifndef LIB_JXL_DEC_GROUP_CC +#define LIB_JXL_DEC_GROUP_CC +namespace jxl { + +struct AuxOut; + +// Interface for reading groups for DecodeGroupImpl. +class GetBlock { + public: + virtual void StartRow(size_t by) = 0; + virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, + size_t size, size_t log2_covered_blocks, + ACPtr block[3], ACType ac_type) = 0; + virtual ~GetBlock() {} +}; + +// Controls whether DecodeGroupImpl renders to pixels or not. +enum DrawMode { + // Render to pixels. + kDraw = 0, + // Don't render to pixels. + kDontDraw = 1, +}; + +} // namespace jxl +#endif // LIB_JXL_DEC_GROUP_CC + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::ShiftRight; + +using D = HWY_FULL(float); +using DU = HWY_FULL(uint32_t); +using DI = HWY_FULL(int32_t); +using DI16 = Rebind<int16_t, DI>; +constexpr D d; +constexpr DI di; +constexpr DI16 di16; + +// TODO(veluca): consider SIMDfying. +void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) { + for (size_t x = 0; x < 8; x++) { + for (size_t y = x + 1; y < 8; y++) { + std::swap(block[y * 8 + x], block[x * 8 + y]); + } + } +} + +template <ACType ac_type> +void DequantLane(Vec<D> scaled_dequant_x, Vec<D> scaled_dequant_y, + Vec<D> scaled_dequant_b, + const float* JXL_RESTRICT dequant_matrices, size_t size, + size_t k, Vec<D> x_cc_mul, Vec<D> b_cc_mul, + const float* JXL_RESTRICT biases, ACPtr qblock[3], + float* JXL_RESTRICT block) { + const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x); + const auto y_mul = + Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y); + const auto b_mul = + Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b); + + Vec<DI> quantized_x_int; + Vec<DI> quantized_y_int; + Vec<DI> quantized_b_int; + if (ac_type == ACType::k16) { + Rebind<int16_t, DI> di16; + quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k)); + quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k)); + quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k)); + } else { + quantized_x_int = Load(di, qblock[0].ptr32 + k); + quantized_y_int = Load(di, qblock[1].ptr32 + k); + quantized_b_int = Load(di, qblock[2].ptr32 + k); + } + + const auto dequant_x_cc = + Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul); + const auto dequant_y = + Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul); + const auto dequant_b_cc = + Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul); + + const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc); + const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc); + Store(dequant_x, d, block + k); + Store(dequant_y, d, block + size + k); + Store(dequant_b, d, block + 2 * size + k); +} + +template <ACType ac_type> +void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant, + float x_dm_multiplier, float b_dm_multiplier, Vec<D> x_cc_mul, + Vec<D> b_cc_mul, size_t kind, size_t size, + const Quantizer& quantizer, size_t covered_blocks, + const size_t* sbx, + const float* JXL_RESTRICT* JXL_RESTRICT dc_row, + size_t dc_stride, const float* JXL_RESTRICT biases, + ACPtr qblock[3], float* JXL_RESTRICT block, + float* JXL_RESTRICT scratch) { + const auto scaled_dequant_s = inv_global_scale / quant; + + const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier); + const auto scaled_dequant_y = Set(d, scaled_dequant_s); + const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier); + + const float* dequant_matrices = quantizer.DequantMatrix(kind, 0); + + for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) { + DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b, + dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases, + qblock, block); + } + for (size_t c = 0; c < 3; c++) { + LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride, + block + c * size, scratch); + } +} + +Status DecodeGroupImpl(const FrameHeader& frame_header, + GetBlock* JXL_RESTRICT get_block, + GroupDecCache* JXL_RESTRICT group_dec_cache, + PassesDecoderState* JXL_RESTRICT dec_state, + size_t thread, size_t group_idx, + RenderPipelineInput& render_pipeline_input, + ImageBundle* decoded, DrawMode draw) { + // TODO(veluca): investigate cache usage in this function. + const Rect block_rect = + dec_state->shared->frame_dim.BlockGroupRect(group_idx); + const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy; + + const size_t xsize_blocks = block_rect.xsize(); + const size_t ysize_blocks = block_rect.ysize(); + + const size_t dc_stride = dec_state->shared->dc->PixelsPerRow(); + + const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale(); + + const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling; + + size_t idct_stride[3]; + for (size_t c = 0; c < 3; c++) { + idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow(); + } + + HWY_ALIGN int32_t scaled_qtable[64 * 3]; + + ACType ac_type = dec_state->coefficients->Type(); + auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16> + : DequantBlock<ACType::k32>; + // Whether or not coefficients should be stored for future usage, and/or read + // from past usage. + bool accumulate = !dec_state->coefficients->IsEmpty(); + // Offset of the current block in the group. + size_t offset = 0; + + std::array<int, 3> jpeg_c_map; + bool jpeg_is_gray = false; + std::array<int, 3> dcoff = {}; + + // TODO(veluca): all of this should be done only once per image. + if (decoded->IsJPEG()) { + if (!dec_state->shared->cmap.IsJPEGCompatible()) { + return JXL_FAILURE("The CfL map is not JPEG-compatible"); + } + jpeg_is_gray = (decoded->jpeg_data->components.size() == 1); + jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray); + const std::vector<QuantEncoding>& qe = + dec_state->shared->matrices.encodings(); + if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW || + std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) { + return JXL_FAILURE( + "Quantization table is not a JPEG quantization table."); + } + for (size_t c = 0; c < 3; c++) { + if (frame_header.color_transform == ColorTransform::kNone) { + dcoff[c] = 1024 / (*qe[0].qraw.qtable)[64 * c]; + } + for (size_t i = 0; i < 64; i++) { + // Transpose the matrix, as it will be used on the transposed block. + int n = qe[0].qraw.qtable->at(64 + i); + int d = qe[0].qraw.qtable->at(64 * c + i); + if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) { + return JXL_FAILURE("Invalid JPEG quantization table"); + } + scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] = + (1 << kCFLFixedPointPrecision) * n / d; + } + } + } + + size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)}; + size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)}; + Rect r[3]; + for (size_t i = 0; i < 3; i++) { + r[i] = + Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i], + block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]); + if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(), + dec_state->shared->dc->Plane(i).ysize()})) { + return JXL_FAILURE("Frame dimensions are too big for the image."); + } + } + + for (size_t by = 0; by < ysize_blocks; ++by) { + get_block->StartRow(by); + size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]}; + + const int32_t* JXL_RESTRICT row_quant = + block_rect.ConstRow(dec_state->shared->raw_quant_field, by); + + const float* JXL_RESTRICT dc_rows[3] = { + r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]), + r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]), + r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]), + }; + + const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks; + AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by); + + const int8_t* JXL_RESTRICT row_cmap[3] = { + dec_state->shared->cmap.ytox_map.ConstRow(ty), + nullptr, + dec_state->shared->cmap.ytob_map.ConstRow(ty), + }; + + float* JXL_RESTRICT idct_row[3]; + int16_t* JXL_RESTRICT jpeg_row[3]; + for (size_t c = 0; c < 3; c++) { + idct_row[c] = render_pipeline_input.GetBuffer(c).second.Row( + render_pipeline_input.GetBuffer(c).first, sby[c] * kBlockDim); + if (decoded->IsJPEG()) { + auto& component = decoded->jpeg_data->components[jpeg_c_map[c]]; + jpeg_row[c] = + component.coeffs.data() + + (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) * + kDCTBlockSize; + } + } + + size_t bx = 0; + for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks); + tx++) { + size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks; + auto x_cc_mul = + Set(d, dec_state->shared->cmap.YtoXRatio(row_cmap[0][abs_tx])); + auto b_cc_mul = + Set(d, dec_state->shared->cmap.YtoBRatio(row_cmap[2][abs_tx])); + // Increment bx by llf_x because those iterations would otherwise + // immediately continue (!IsFirstBlock). Reduces mispredictions. + for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) { + size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]}; + AcStrategy acs = acs_row[bx]; + const size_t llf_x = acs.covered_blocks_x(); + + // Can only happen in the second or lower rows of a varblock. + if (JXL_UNLIKELY(!acs.IsFirstBlock())) { + bx += llf_x; + continue; + } + const size_t log2_covered_blocks = acs.log2_covered_blocks(); + + const size_t covered_blocks = 1 << log2_covered_blocks; + const size_t size = covered_blocks * kDCTBlockSize; + + ACPtr qblock[3]; + if (accumulate) { + for (size_t c = 0; c < 3; c++) { + qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset); + } + } else { + // No point in reading from bitstream without accumulating and not + // drawing. + JXL_ASSERT(draw == kDraw); + if (ac_type == ACType::k16) { + memset(group_dec_cache->dec_group_qblock16, 0, + size * 3 * sizeof(int16_t)); + for (size_t c = 0; c < 3; c++) { + qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size; + } + } else { + memset(group_dec_cache->dec_group_qblock, 0, + size * 3 * sizeof(int32_t)); + for (size_t c = 0; c < 3; c++) { + qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size; + } + } + } + JXL_RETURN_IF_ERROR(get_block->LoadBlock( + bx, by, acs, size, log2_covered_blocks, qblock, ac_type)); + offset += size; + if (draw == kDontDraw) { + bx += llf_x; + continue; + } + + if (JXL_UNLIKELY(decoded->IsJPEG())) { + if (acs.Strategy() != AcStrategy::Type::DCT) { + return JXL_FAILURE( + "Can only decode to JPEG if only DCT-8 is used."); + } + + HWY_ALIGN int32_t transposed_dct_y[64]; + for (size_t c : {1, 0, 2}) { + // Propagate only Y for grayscale. + if (jpeg_is_gray && c != 1) { + continue; + } + if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) { + continue; + } + int16_t* JXL_RESTRICT jpeg_pos = + jpeg_row[c] + sbx[c] * kDCTBlockSize; + // JPEG XL is transposed, JPEG is not. + auto transposed_dct = qblock[c].ptr32; + Transpose8x8InPlace(transposed_dct); + // No CfL - no need to store the y block converted to integers. + if (!cs.Is444() || + (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) { + for (size_t i = 0; i < 64; i += Lanes(d)) { + const auto ini = Load(di, transposed_dct + i); + const auto ini16 = DemoteTo(di16, ini); + StoreU(ini16, di16, jpeg_pos + i); + } + } else if (c == 1) { + // Y channel: save for restoring X/B, but nothing else to do. + for (size_t i = 0; i < 64; i += Lanes(d)) { + const auto ini = Load(di, transposed_dct + i); + Store(ini, di, transposed_dct_y + i); + const auto ini16 = DemoteTo(di16, ini); + StoreU(ini16, di16, jpeg_pos + i); + } + } else { + // transposed_dct_y contains the y channel block, transposed. + const auto scale = Set( + di, dec_state->shared->cmap.RatioJPEG(row_cmap[c][abs_tx])); + const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1)); + for (int i = 0; i < 64; i += Lanes(d)) { + auto in = Load(di, transposed_dct + i); + auto in_y = Load(di, transposed_dct_y + i); + auto qt = Load(di, scaled_qtable + c * size + i); + auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>( + Add(Mul(qt, scale), round)); + auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>( + Add(Mul(in_y, coeff_scale), round)); + StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i); + } + } + jpeg_pos[0] = + Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047); + } + } else { + HWY_ALIGN float* const block = group_dec_cache->dec_group_block; + // Dequantize and add predictions. + dequant_block( + acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier, + dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(), + size, dec_state->shared->quantizer, + acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows, + dc_stride, + dec_state->output_encoding_info.opsin_params.quant_biases, qblock, + block, group_dec_cache->scratch_space); + + for (size_t c : {1, 0, 2}) { + if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) { + continue; + } + // IDCT + float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim; + TransformToPixels(acs.Strategy(), block + c * size, idct_pos, + idct_stride[c], group_dec_cache->scratch_space); + } + } + bx += llf_x; + } + } + } + return true; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +namespace { +// Decode quantized AC coefficients of DCT blocks. +// LLF components in the output block will not be modified. +template <ACType ac_type, bool uses_lz77> +Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks, + int32_t* JXL_RESTRICT row_nzeros, + const int32_t* JXL_RESTRICT row_nzeros_top, + size_t nzeros_stride, size_t c, size_t bx, size_t by, + size_t lbx, AcStrategy acs, + const coeff_order_t* JXL_RESTRICT coeff_order, + BitReader* JXL_RESTRICT br, + ANSSymbolReader* JXL_RESTRICT decoder, + const std::vector<uint8_t>& context_map, + const uint8_t* qdc_row, const int32_t* qf_row, + const BlockCtxMap& block_ctx_map, ACPtr block, + size_t shift = 0) { + // Equal to number of LLF coefficients. + const size_t covered_blocks = 1 << log2_covered_blocks; + const size_t size = covered_blocks * kDCTBlockSize; + int32_t predicted_nzeros = + PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32); + + size_t ord = kStrategyOrder[acs.RawStrategy()]; + const coeff_order_t* JXL_RESTRICT order = + &coeff_order[CoeffOrderOffset(ord, c)]; + + size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c); + const int32_t nzero_ctx = + block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset; + + size_t nzeros = + decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map); + if (nzeros > size - covered_blocks) { + return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS + " 8x8 blocks", + nzeros, covered_blocks); + } + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + row_nzeros[bx + x + y * nzeros_stride] = + (nzeros + covered_blocks - 1) >> log2_covered_blocks; + } + } + + const size_t histo_offset = + ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx); + + size_t prev = (nzeros > size / 16 ? 0 : 1); + for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) { + const size_t ctx = + histo_offset + ZeroDensityContext(nzeros, k, covered_blocks, + log2_covered_blocks, prev); + const size_t u_coeff = + decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map); + // Hand-rolled version of UnpackSigned, shifting before the conversion to + // signed integer to avoid undefined behavior of shifting negative numbers. + const size_t magnitude = u_coeff >> 1; + const size_t neg_sign = (~u_coeff) & 1; + const intptr_t coeff = + static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift); + if (ac_type == ACType::k16) { + block.ptr16[order[k]] += coeff; + } else { + block.ptr32[order[k]] += coeff; + } + prev = static_cast<size_t>(u_coeff != 0); + nzeros -= prev; + } + if (JXL_UNLIKELY(nzeros != 0)) { + return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS + ", should be 0. Block (%" PRIuS ", %" PRIuS + "), channel %" PRIuS, + nzeros, bx, by, c); + } + + return true; +} + +// Structs used by DecodeGroupImpl to get a quantized block. +// GetBlockFromBitstream uses ANS decoding (and thus keeps track of row +// pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient +// image provided by the encoder. + +struct GetBlockFromBitstream : public GetBlock { + void StartRow(size_t by) override { + qf_row = rect.ConstRow(*qf, by); + for (size_t c = 0; c < 3; c++) { + size_t sby = by >> vshift[c]; + quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0(); + for (size_t i = 0; i < num_passes; i++) { + row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby); + row_nzeros_top[i][c] = + sby == 0 + ? nullptr + : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1); + } + } + } + + Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size, + size_t log2_covered_blocks, ACPtr block[3], + ACType ac_type) override { + ; + for (size_t c : {1, 0, 2}) { + size_t sbx = bx >> hshift[c]; + size_t sby = by >> vshift[c]; + if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) { + continue; + } + + for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) { + auto decode_ac_varblock = + decoders[pass].UsesLZ77() + ? (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 1> + : DecodeACVarBlock<ACType::k32, 1>) + : (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 0> + : DecodeACVarBlock<ACType::k32, 0>); + JXL_RETURN_IF_ERROR(decode_ac_varblock( + ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c], + row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs, + &coeff_orders[pass * coeff_order_size], readers[pass], + &decoders[pass], context_map[pass], quant_dc_row, qf_row, + *block_ctx_map, block[c], shift_for_pass[pass])); + } + } + return true; + } + + Status Init(const FrameHeader& frame_header, + BitReader* JXL_RESTRICT* JXL_RESTRICT readers, size_t num_passes, + size_t group_idx, size_t histo_selector_bits, const Rect& rect, + GroupDecCache* JXL_RESTRICT group_dec_cache, + PassesDecoderState* dec_state, size_t first_pass) { + for (size_t i = 0; i < 3; i++) { + hshift[i] = frame_header.chroma_subsampling.HShift(i); + vshift[i] = frame_header.chroma_subsampling.VShift(i); + } + this->coeff_order_size = dec_state->shared->coeff_order_size; + this->coeff_orders = + dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size; + this->context_map = dec_state->context_map.data() + first_pass; + this->readers = readers; + this->num_passes = num_passes; + this->shift_for_pass = frame_header.passes.shift + first_pass; + this->group_dec_cache = group_dec_cache; + this->rect = rect; + block_ctx_map = &dec_state->shared->block_ctx_map; + qf = &dec_state->shared->raw_quant_field; + quant_dc = &dec_state->shared->quant_dc; + + for (size_t pass = 0; pass < num_passes; pass++) { + // Select which histogram set to use among those of the current pass. + size_t cur_histogram = 0; + if (histo_selector_bits != 0) { + cur_histogram = readers[pass]->ReadBits(histo_selector_bits); + } + if (cur_histogram >= dec_state->shared->num_histograms) { + return JXL_FAILURE("Invalid histogram selector"); + } + ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts(); + + decoders[pass] = + ANSSymbolReader(&dec_state->code[pass + first_pass], readers[pass]); + } + nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow(); + for (size_t i = 0; i < num_passes; i++) { + JXL_ASSERT( + nzeros_stride == + static_cast<size_t>(group_dec_cache->num_nzeroes[i].PixelsPerRow())); + } + return true; + } + + const uint32_t* shift_for_pass = nullptr; // not owned + const coeff_order_t* JXL_RESTRICT coeff_orders; + size_t coeff_order_size; + const std::vector<uint8_t>* JXL_RESTRICT context_map; + ANSSymbolReader decoders[kMaxNumPasses]; + BitReader* JXL_RESTRICT* JXL_RESTRICT readers; + size_t num_passes; + size_t ctx_offset[kMaxNumPasses]; + size_t nzeros_stride; + int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3]; + const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3]; + GroupDecCache* JXL_RESTRICT group_dec_cache; + const BlockCtxMap* block_ctx_map; + const ImageI* qf; + const ImageB* quant_dc; + const int32_t* qf_row; + const uint8_t* quant_dc_row; + Rect rect; + size_t hshift[3], vshift[3]; +}; + +struct GetBlockFromEncoder : public GetBlock { + void StartRow(size_t by) override {} + + Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size, + size_t log2_covered_blocks, ACPtr block[3], + ACType ac_type) override { + JXL_DASSERT(ac_type == ACType::k32); + for (size_t c = 0; c < 3; c++) { + // for each pass + for (size_t i = 0; i < quantized_ac->size(); i++) { + for (size_t k = 0; k < size; k++) { + // TODO(veluca): SIMD. + block[c].ptr32[k] += + rows[i][c][offset + k] * (1 << shift_for_pass[i]); + } + } + } + offset += size; + return true; + } + + GetBlockFromEncoder(const std::vector<std::unique_ptr<ACImage>>& ac, + size_t group_idx, const uint32_t* shift_for_pass) + : quantized_ac(&ac), shift_for_pass(shift_for_pass) { + // TODO(veluca): not supported with chroma subsampling. + for (size_t i = 0; i < quantized_ac->size(); i++) { + JXL_CHECK((*quantized_ac)[i]->Type() == ACType::k32); + for (size_t c = 0; c < 3; c++) { + rows[i][c] = (*quantized_ac)[i]->PlaneRow(c, group_idx, 0).ptr32; + } + } + } + + const std::vector<std::unique_ptr<ACImage>>* JXL_RESTRICT quantized_ac; + size_t offset = 0; + const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3]; + const uint32_t* shift_for_pass = nullptr; // not owned +}; + +HWY_EXPORT(DecodeGroupImpl); + +} // namespace + +Status DecodeGroup(const FrameHeader& frame_header, + BitReader* JXL_RESTRICT* JXL_RESTRICT readers, + size_t num_passes, size_t group_idx, + PassesDecoderState* JXL_RESTRICT dec_state, + GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread, + RenderPipelineInput& render_pipeline_input, + ImageBundle* JXL_RESTRICT decoded, size_t first_pass, + bool force_draw, bool dc_only, bool* should_run_pipeline) { + DrawMode draw = + (num_passes + first_pass == frame_header.passes.num_passes) || force_draw + ? kDraw + : kDontDraw; + + if (should_run_pipeline) { + *should_run_pipeline = draw != kDontDraw; + } + + if (draw == kDraw && num_passes == 0 && first_pass == 0) { + group_dec_cache->InitDCBufferOnce(); + const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling; + for (size_t c : {0, 1, 2}) { + size_t hs = cs.HShift(c); + size_t vs = cs.VShift(c); + // We reuse filter_input_storage here as it is not currently in use. + const Rect src_rect_precs = + dec_state->shared->frame_dim.BlockGroupRect(group_idx); + const Rect src_rect = + Rect(src_rect_precs.x0() >> hs, src_rect_precs.y0() >> vs, + src_rect_precs.xsize() >> hs, src_rect_precs.ysize() >> vs); + const Rect copy_rect(kRenderPipelineXOffset, 2, src_rect.xsize(), + src_rect.ysize()); + CopyImageToWithPadding(src_rect, dec_state->shared->dc->Plane(c), 2, + copy_rect, &group_dec_cache->dc_buffer); + // Mirrorpad. Interleaving left and right padding ensures that padding + // works out correctly even for images with DC size of 1. + for (size_t y = 0; y < src_rect.ysize() + 4; y++) { + size_t xend = kRenderPipelineXOffset + + (dec_state->shared->dc->Plane(c).xsize() >> hs) - + src_rect.x0(); + for (size_t ix = 0; ix < 2; ix++) { + if (src_rect.x0() == 0) { + group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset - ix - 1] = + group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset + ix]; + } + if (src_rect.x0() + src_rect.xsize() + 2 >= + (dec_state->shared->dc->xsize() >> hs)) { + group_dec_cache->dc_buffer.Row(y)[xend + ix] = + group_dec_cache->dc_buffer.Row(y)[xend - ix - 1]; + } + } + } + Rect dst_rect = render_pipeline_input.GetBuffer(c).second; + ImageF* upsampling_dst = render_pipeline_input.GetBuffer(c).first; + JXL_ASSERT(dst_rect.IsInside(*upsampling_dst)); + + RenderPipelineStage::RowInfo input_rows(1, std::vector<float*>(5)); + RenderPipelineStage::RowInfo output_rows(1, std::vector<float*>(8)); + for (size_t y = src_rect.y0(); y < src_rect.y0() + src_rect.ysize(); + y++) { + for (ssize_t iy = 0; iy < 5; iy++) { + input_rows[0][iy] = group_dec_cache->dc_buffer.Row( + Mirror(ssize_t(y) + iy - 2, + dec_state->shared->dc->Plane(c).ysize() >> vs) + + 2 - src_rect.y0()); + } + for (size_t iy = 0; iy < 8; iy++) { + output_rows[0][iy] = + dst_rect.Row(upsampling_dst, ((y - src_rect.y0()) << 3) + iy) - + kRenderPipelineXOffset; + } + // Arguments set to 0/nullptr are not used. + dec_state->upsampler8x->ProcessRow(input_rows, output_rows, + /*xextra=*/0, src_rect.xsize(), 0, 0, + thread); + } + } + return true; + } + + size_t histo_selector_bits = 0; + if (dc_only) { + JXL_ASSERT(num_passes == 0); + } else { + JXL_ASSERT(dec_state->shared->num_histograms > 0); + histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms); + } + + auto get_block = jxl::make_unique<GetBlockFromBitstream>(); + JXL_RETURN_IF_ERROR(get_block->Init( + frame_header, readers, num_passes, group_idx, histo_selector_bits, + dec_state->shared->frame_dim.BlockGroupRect(group_idx), group_dec_cache, + dec_state, first_pass)); + + JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)( + frame_header, get_block.get(), group_dec_cache, dec_state, thread, + group_idx, render_pipeline_input, decoded, draw)); + + for (size_t pass = 0; pass < num_passes; pass++) { + if (!get_block->decoders[pass].CheckANSFinalState()) { + return JXL_FAILURE("ANS checksum failure."); + } + } + return true; +} + +Status DecodeGroupForRoundtrip(const FrameHeader& frame_header, + const std::vector<std::unique_ptr<ACImage>>& ac, + size_t group_idx, + PassesDecoderState* JXL_RESTRICT dec_state, + GroupDecCache* JXL_RESTRICT group_dec_cache, + size_t thread, + RenderPipelineInput& render_pipeline_input, + ImageBundle* JXL_RESTRICT decoded, + AuxOut* aux_out) { + GetBlockFromEncoder get_block(ac, group_idx, frame_header.passes.shift); + group_dec_cache->InitOnce( + /*num_passes=*/0, + /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1); + + return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)( + frame_header, &get_block, group_dec_cache, dec_state, thread, group_idx, + render_pipeline_input, decoded, kDraw); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dec_group.h b/third_party/jpeg-xl/lib/jxl/dec_group.h new file mode 100644 index 0000000000..7f3ba5e868 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_group.h @@ -0,0 +1,47 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_GROUP_H_ +#define LIB_JXL_DEC_GROUP_H_ + +#include <cstddef> +#include <memory> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/render_pipeline/render_pipeline.h" + +namespace jxl { + +struct AuxOut; + +Status DecodeGroup(const FrameHeader& frame_header, + BitReader* JXL_RESTRICT* JXL_RESTRICT readers, + size_t num_passes, size_t group_idx, + PassesDecoderState* JXL_RESTRICT dec_state, + GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread, + RenderPipelineInput& render_pipeline_input, + ImageBundle* JXL_RESTRICT decoded, size_t first_pass, + bool force_draw, bool dc_only, bool* should_run_pipeline); + +Status DecodeGroupForRoundtrip(const FrameHeader& frame_header, + const std::vector<std::unique_ptr<ACImage>>& ac, + size_t group_idx, + PassesDecoderState* JXL_RESTRICT dec_state, + GroupDecCache* JXL_RESTRICT group_dec_cache, + size_t thread, + RenderPipelineInput& render_pipeline_input, + ImageBundle* JXL_RESTRICT decoded, + AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_DEC_GROUP_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_group_border.cc b/third_party/jpeg-xl/lib/jxl/dec_group_border.cc new file mode 100644 index 0000000000..4bee3ae6ef --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_group_border.cc @@ -0,0 +1,184 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_group_border.h" + +#include <atomic> + +namespace jxl { + +void GroupBorderAssigner::Init(const FrameDimensions& frame_dim) { + frame_dim_ = frame_dim; + size_t num_corners = + (frame_dim_.xsize_groups + 1) * (frame_dim_.ysize_groups + 1); + counters_.reset(new std::atomic<uint8_t>[num_corners]); + // Initialize counters. + for (size_t y = 0; y < frame_dim_.ysize_groups + 1; y++) { + for (size_t x = 0; x < frame_dim_.xsize_groups + 1; x++) { + // Counters at image borders don't have anything on the other side, we + // pre-fill their value to have more uniform handling afterwards. + uint8_t init_value = 0; + if (x == 0) { + init_value |= kTopLeft | kBottomLeft; + } + if (x == frame_dim_.xsize_groups) { + init_value |= kTopRight | kBottomRight; + } + if (y == 0) { + init_value |= kTopLeft | kTopRight; + } + if (y == frame_dim_.ysize_groups) { + init_value |= kBottomLeft | kBottomRight; + } + counters_[y * (frame_dim_.xsize_groups + 1) + x] = init_value; + } + } +} + +void GroupBorderAssigner::ClearDone(size_t group_id) { + size_t x = group_id % frame_dim_.xsize_groups; + size_t y = group_id / frame_dim_.xsize_groups; + size_t top_left_idx = y * (frame_dim_.xsize_groups + 1) + x; + size_t top_right_idx = y * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_right_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_left_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x; + counters_[top_left_idx].fetch_and(~kBottomRight); + counters_[top_right_idx].fetch_and(~kBottomLeft); + counters_[bottom_left_idx].fetch_and(~kTopRight); + counters_[bottom_right_idx].fetch_and(~kTopLeft); +} + +// Looking at each corner between groups, we can guarantee that the four +// involved groups will agree between each other regarding the order in which +// each of the four groups terminated. Thus, the last of the four groups +// gets the responsibility of handling the corner. For borders, every border +// is assigned to its top corner (for vertical borders) or to its left corner +// (for horizontal borders): the order as seen on those corners will decide who +// handles that border. + +void GroupBorderAssigner::GroupDone(size_t group_id, size_t padx, size_t pady, + Rect* rects_to_finalize, + size_t* num_to_finalize) { + size_t x = group_id % frame_dim_.xsize_groups; + size_t y = group_id / frame_dim_.xsize_groups; + Rect block_rect(x * frame_dim_.group_dim / kBlockDim, + y * frame_dim_.group_dim / kBlockDim, + frame_dim_.group_dim / kBlockDim, + frame_dim_.group_dim / kBlockDim, frame_dim_.xsize_blocks, + frame_dim_.ysize_blocks); + + size_t top_left_idx = y * (frame_dim_.xsize_groups + 1) + x; + size_t top_right_idx = y * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_right_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x + 1; + size_t bottom_left_idx = (y + 1) * (frame_dim_.xsize_groups + 1) + x; + + auto fetch_status = [this](size_t idx, uint8_t bit) { + // Note that the acq-rel semantics of this fetch are actually needed to + // ensure that the pixel data of the group is already written to memory. + size_t status = counters_[idx].fetch_or(bit); + JXL_DASSERT((bit & status) == 0); + return bit | status; + }; + + size_t top_left_status = fetch_status(top_left_idx, kBottomRight); + size_t top_right_status = fetch_status(top_right_idx, kBottomLeft); + size_t bottom_right_status = fetch_status(bottom_right_idx, kTopLeft); + size_t bottom_left_status = fetch_status(bottom_left_idx, kTopRight); + + size_t x1 = block_rect.x0() + block_rect.xsize(); + size_t y1 = block_rect.y0() + block_rect.ysize(); + + bool is_last_group_x = frame_dim_.xsize_groups == x + 1; + bool is_last_group_y = frame_dim_.ysize_groups == y + 1; + + // Start of border of neighbouring group, end of border of this group, start + // of border of this group (on the other side), end of border of next group. + size_t xpos[4] = { + block_rect.x0() == 0 ? 0 : block_rect.x0() * kBlockDim - padx, + block_rect.x0() == 0 + ? 0 + : std::min(frame_dim_.xsize, block_rect.x0() * kBlockDim + padx), + is_last_group_x ? frame_dim_.xsize : x1 * kBlockDim - padx, + std::min(frame_dim_.xsize, x1 * kBlockDim + padx)}; + size_t ypos[4] = { + block_rect.y0() == 0 ? 0 : block_rect.y0() * kBlockDim - pady, + block_rect.y0() == 0 + ? 0 + : std::min(frame_dim_.ysize, block_rect.y0() * kBlockDim + pady), + is_last_group_y ? frame_dim_.ysize : y1 * kBlockDim - pady, + std::min(frame_dim_.ysize, y1 * kBlockDim + pady)}; + + *num_to_finalize = 0; + auto append_rect = [&](size_t x0, size_t x1, size_t y0, size_t y1) { + Rect rect(xpos[x0], ypos[y0], xpos[x1] - xpos[x0], ypos[y1] - ypos[y0]); + if (rect.xsize() == 0 || rect.ysize() == 0) return; + JXL_DASSERT(*num_to_finalize < kMaxToFinalize); + rects_to_finalize[(*num_to_finalize)++] = rect; + }; + + // Because of how group borders are assigned, it is impossible that we need to + // process the left and right side of some area but not the center area. Thus, + // we compute the first/last part to process in every horizontal strip and + // merge them together. We first collect a mask of what parts should be + // processed. + // We do this horizontally rather than vertically because horizontal borders + // are larger. + bool available_parts_mask[3][3] = {}; // [x][y] + // Center + available_parts_mask[1][1] = true; + // Corners + if (top_left_status == 0xF) available_parts_mask[0][0] = true; + if (top_right_status == 0xF) available_parts_mask[2][0] = true; + if (bottom_right_status == 0xF) available_parts_mask[2][2] = true; + if (bottom_left_status == 0xF) available_parts_mask[0][2] = true; + // Other borders + if (top_left_status & kTopRight) available_parts_mask[1][0] = true; + if (top_left_status & kBottomLeft) available_parts_mask[0][1] = true; + if (top_right_status & kBottomRight) available_parts_mask[2][1] = true; + if (bottom_left_status & kBottomRight) available_parts_mask[1][2] = true; + + // Collect horizontal ranges. + constexpr size_t kNoSegment = 3; + std::pair<size_t, size_t> horizontal_segments[3] = {{kNoSegment, kNoSegment}, + {kNoSegment, kNoSegment}, + {kNoSegment, kNoSegment}}; + for (size_t y = 0; y < 3; y++) { + for (size_t x = 0; x < 3; x++) { + if (!available_parts_mask[x][y]) continue; + JXL_DASSERT(horizontal_segments[y].second == kNoSegment || + horizontal_segments[y].second == x); + JXL_DASSERT((horizontal_segments[y].first == kNoSegment) == + (horizontal_segments[y].second == kNoSegment)); + if (horizontal_segments[y].first == kNoSegment) { + horizontal_segments[y].first = x; + } + horizontal_segments[y].second = x + 1; + } + } + if (horizontal_segments[0] == horizontal_segments[1] && + horizontal_segments[0] == horizontal_segments[2]) { + append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0, + 3); + } else if (horizontal_segments[0] == horizontal_segments[1]) { + append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0, + 2); + append_rect(horizontal_segments[2].first, horizontal_segments[2].second, 2, + 3); + } else if (horizontal_segments[1] == horizontal_segments[2]) { + append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0, + 1); + append_rect(horizontal_segments[1].first, horizontal_segments[1].second, 1, + 3); + } else { + append_rect(horizontal_segments[0].first, horizontal_segments[0].second, 0, + 1); + append_rect(horizontal_segments[1].first, horizontal_segments[1].second, 1, + 2); + append_rect(horizontal_segments[2].first, horizontal_segments[2].second, 2, + 3); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_group_border.h b/third_party/jpeg-xl/lib/jxl/dec_group_border.h new file mode 100644 index 0000000000..cb3ecbefae --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_group_border.h @@ -0,0 +1,47 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_GROUP_BORDER_H_ +#define LIB_JXL_DEC_GROUP_BORDER_H_ + +#include <stddef.h> + +#include <atomic> + +#include "lib/jxl/base/arch_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/image.h" + +namespace jxl { + +class GroupBorderAssigner { + public: + // Prepare the GroupBorderAssigner to handle a given frame. + void Init(const FrameDimensions& frame_dim); + // Marks a group as done, and returns the (at most 3) rects to run + // FinalizeImageRect on. `block_rect` must be the rect corresponding + // to the given `group_id`, measured in blocks. + void GroupDone(size_t group_id, size_t padx, size_t pady, + Rect* rects_to_finalize, size_t* num_to_finalize); + // Marks a group as not-done, for running re-paints. + void ClearDone(size_t group_id); + + static constexpr size_t kMaxToFinalize = 3; + + private: + FrameDimensions frame_dim_; + std::unique_ptr<std::atomic<uint8_t>[]> counters_; + + // Constants to identify group positions relative to the corners. + static constexpr uint8_t kTopLeft = 0x01; + static constexpr uint8_t kTopRight = 0x02; + static constexpr uint8_t kBottomRight = 0x04; + static constexpr uint8_t kBottomLeft = 0x08; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_GROUP_BORDER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_huffman.cc b/third_party/jpeg-xl/lib/jxl/dec_huffman.cc new file mode 100644 index 0000000000..05b275773a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_huffman.cc @@ -0,0 +1,255 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_huffman.h" + +#include <string.h> /* for memset */ + +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/huffman_table.h" + +namespace jxl { + +static const int kCodeLengthCodes = 18; +static const uint8_t kCodeLengthCodeOrder[kCodeLengthCodes] = { + 1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15, +}; +static const uint8_t kDefaultCodeLength = 8; +static const uint8_t kCodeLengthRepeatCode = 16; + +int ReadHuffmanCodeLengths(const uint8_t* code_length_code_lengths, + int num_symbols, uint8_t* code_lengths, + BitReader* br) { + int symbol = 0; + uint8_t prev_code_len = kDefaultCodeLength; + int repeat = 0; + uint8_t repeat_code_len = 0; + int space = 32768; + HuffmanCode table[32]; + + uint16_t counts[16] = {0}; + for (int i = 0; i < kCodeLengthCodes; ++i) { + ++counts[code_length_code_lengths[i]]; + } + if (!BuildHuffmanTable(table, 5, code_length_code_lengths, kCodeLengthCodes, + &counts[0])) { + return 0; + } + + while (symbol < num_symbols && space > 0) { + const HuffmanCode* p = table; + uint8_t code_len; + br->Refill(); + p += br->PeekFixedBits<5>(); + br->Consume(p->bits); + code_len = (uint8_t)p->value; + if (code_len < kCodeLengthRepeatCode) { + repeat = 0; + code_lengths[symbol++] = code_len; + if (code_len != 0) { + prev_code_len = code_len; + space -= 32768u >> code_len; + } + } else { + const int extra_bits = code_len - 14; + int old_repeat; + int repeat_delta; + uint8_t new_len = 0; + if (code_len == kCodeLengthRepeatCode) { + new_len = prev_code_len; + } + if (repeat_code_len != new_len) { + repeat = 0; + repeat_code_len = new_len; + } + old_repeat = repeat; + if (repeat > 0) { + repeat -= 2; + repeat <<= extra_bits; + } + repeat += (int)br->ReadBits(extra_bits) + 3; + repeat_delta = repeat - old_repeat; + if (symbol + repeat_delta > num_symbols) { + return 0; + } + memset(&code_lengths[symbol], repeat_code_len, (size_t)repeat_delta); + symbol += repeat_delta; + if (repeat_code_len != 0) { + space -= repeat_delta << (15 - repeat_code_len); + } + } + } + if (space != 0) { + return 0; + } + memset(&code_lengths[symbol], 0, (size_t)(num_symbols - symbol)); + return true; +} + +static JXL_INLINE bool ReadSimpleCode(size_t alphabet_size, BitReader* br, + HuffmanCode* table) { + size_t max_bits = + (alphabet_size > 1u) ? FloorLog2Nonzero(alphabet_size - 1u) + 1 : 0; + + size_t num_symbols = br->ReadFixedBits<2>() + 1; + + uint16_t symbols[4] = {0}; + for (size_t i = 0; i < num_symbols; ++i) { + uint16_t symbol = br->ReadBits(max_bits); + if (symbol >= alphabet_size) { + return false; + } + symbols[i] = symbol; + } + + for (size_t i = 0; i < num_symbols - 1; ++i) { + for (size_t j = i + 1; j < num_symbols; ++j) { + if (symbols[i] == symbols[j]) return false; + } + } + + // 4 symbols have to option to encode. + if (num_symbols == 4) num_symbols += br->ReadFixedBits<1>(); + + const auto swap_symbols = [&symbols](size_t i, size_t j) { + uint16_t t = symbols[j]; + symbols[j] = symbols[i]; + symbols[i] = t; + }; + + size_t table_size = 1; + switch (num_symbols) { + case 1: + table[0] = {0, symbols[0]}; + break; + case 2: + if (symbols[0] > symbols[1]) swap_symbols(0, 1); + table[0] = {1, symbols[0]}; + table[1] = {1, symbols[1]}; + table_size = 2; + break; + case 3: + if (symbols[1] > symbols[2]) swap_symbols(1, 2); + table[0] = {1, symbols[0]}; + table[2] = {1, symbols[0]}; + table[1] = {2, symbols[1]}; + table[3] = {2, symbols[2]}; + table_size = 4; + break; + case 4: { + for (size_t i = 0; i < 3; ++i) { + for (size_t j = i + 1; j < 4; ++j) { + if (symbols[i] > symbols[j]) swap_symbols(i, j); + } + } + table[0] = {2, symbols[0]}; + table[2] = {2, symbols[1]}; + table[1] = {2, symbols[2]}; + table[3] = {2, symbols[3]}; + table_size = 4; + break; + } + case 5: { + if (symbols[2] > symbols[3]) swap_symbols(2, 3); + table[0] = {1, symbols[0]}; + table[1] = {2, symbols[1]}; + table[2] = {1, symbols[0]}; + table[3] = {3, symbols[2]}; + table[4] = {1, symbols[0]}; + table[5] = {2, symbols[1]}; + table[6] = {1, symbols[0]}; + table[7] = {3, symbols[3]}; + table_size = 8; + break; + } + default: { + // Unreachable. + return false; + } + } + + const uint32_t goal_size = 1u << kHuffmanTableBits; + while (table_size != goal_size) { + memcpy(&table[table_size], &table[0], + (size_t)table_size * sizeof(table[0])); + table_size <<= 1; + } + + return true; +} + +bool HuffmanDecodingData::ReadFromBitStream(size_t alphabet_size, + BitReader* br) { + if (alphabet_size > (1 << PREFIX_MAX_BITS)) return false; + + /* simple_code_or_skip is used as follows: + 1 for simple code; + 0 for no skipping, 2 skips 2 code lengths, 3 skips 3 code lengths */ + uint32_t simple_code_or_skip = br->ReadFixedBits<2>(); + if (simple_code_or_skip == 1u) { + table_.resize(1u << kHuffmanTableBits); + return ReadSimpleCode(alphabet_size, br, table_.data()); + } + + std::vector<uint8_t> code_lengths(alphabet_size, 0); + uint8_t code_length_code_lengths[kCodeLengthCodes] = {0}; + int space = 32; + int num_codes = 0; + /* Static Huffman code for the code length code lengths */ + static const HuffmanCode huff[16] = { + {2, 0}, {2, 4}, {2, 3}, {3, 2}, {2, 0}, {2, 4}, {2, 3}, {4, 1}, + {2, 0}, {2, 4}, {2, 3}, {3, 2}, {2, 0}, {2, 4}, {2, 3}, {4, 5}, + }; + for (size_t i = simple_code_or_skip; i < kCodeLengthCodes && space > 0; ++i) { + const int code_len_idx = kCodeLengthCodeOrder[i]; + const HuffmanCode* p = huff; + uint8_t v; + br->Refill(); + p += br->PeekFixedBits<4>(); + br->Consume(p->bits); + v = (uint8_t)p->value; + code_length_code_lengths[code_len_idx] = v; + if (v != 0) { + space -= (32u >> v); + ++num_codes; + } + } + bool ok = (num_codes == 1 || space == 0) && + ReadHuffmanCodeLengths(code_length_code_lengths, alphabet_size, + &code_lengths[0], br); + + if (!ok) return false; + uint16_t counts[16] = {0}; + for (size_t i = 0; i < alphabet_size; ++i) { + ++counts[code_lengths[i]]; + } + table_.resize(alphabet_size + 376); + uint32_t table_size = + BuildHuffmanTable(table_.data(), kHuffmanTableBits, &code_lengths[0], + alphabet_size, &counts[0]); + table_.resize(table_size); + return (table_size > 0); +} + +// Decodes the next Huffman coded symbol from the bit-stream. +uint16_t HuffmanDecodingData::ReadSymbol(BitReader* br) const { + size_t n_bits; + const HuffmanCode* table = table_.data(); + table += br->PeekBits(kHuffmanTableBits); + n_bits = table->bits; + if (n_bits > kHuffmanTableBits) { + br->Consume(kHuffmanTableBits); + n_bits -= kHuffmanTableBits; + table += table->value; + table += br->PeekBits(n_bits); + } + br->Consume(table->bits); + return table->value; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_huffman.h b/third_party/jpeg-xl/lib/jxl/dec_huffman.h new file mode 100644 index 0000000000..162c3e309c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_huffman.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_HUFFMAN_H_ +#define LIB_JXL_DEC_HUFFMAN_H_ + +#include <memory> +#include <vector> + +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/huffman_table.h" + +namespace jxl { + +static constexpr size_t kHuffmanTableBits = 8u; + +struct HuffmanDecodingData { + // Decodes the Huffman code lengths from the bit-stream and fills in the + // pre-allocated table with the corresponding 2-level Huffman decoding table. + // Returns false if the Huffman code lengths can not de decoded. + bool ReadFromBitStream(size_t alphabet_size, BitReader* br); + + uint16_t ReadSymbol(BitReader* br) const; + + std::vector<HuffmanCode> table_; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_HUFFMAN_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_modular.cc b/third_party/jpeg-xl/lib/jxl/dec_modular.cc new file mode 100644 index 0000000000..4fcba489e2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_modular.cc @@ -0,0 +1,779 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_modular.h" + +#include <stdint.h> + +#include <atomic> +#include <sstream> +#include <vector> + +#include "lib/jxl/frame_header.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::Rebind; + +void MultiplySum(const size_t xsize, + const pixel_type* const JXL_RESTRICT row_in, + const pixel_type* const JXL_RESTRICT row_in_Y, + const float factor, float* const JXL_RESTRICT row_out) { + const HWY_FULL(float) df; + const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float + const auto factor_v = Set(df, factor); + for (size_t x = 0; x < xsize; x += Lanes(di)) { + const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x)); + const auto out = Mul(ConvertTo(df, in), factor_v); + Store(out, df, row_out + x); + } +} + +void RgbFromSingle(const size_t xsize, + const pixel_type* const JXL_RESTRICT row_in, + const float factor, float* out_r, float* out_g, + float* out_b) { + const HWY_FULL(float) df; + const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float + + const auto factor_v = Set(df, factor); + for (size_t x = 0; x < xsize; x += Lanes(di)) { + const auto in = Load(di, row_in + x); + const auto out = Mul(ConvertTo(df, in), factor_v); + Store(out, df, out_r + x); + Store(out, df, out_g + x); + Store(out, df, out_b + x); + } +} + +void SingleFromSingle(const size_t xsize, + const pixel_type* const JXL_RESTRICT row_in, + const float factor, float* row_out) { + const HWY_FULL(float) df; + const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float + + const auto factor_v = Set(df, factor); + for (size_t x = 0; x < xsize; x += Lanes(di)) { + const auto in = Load(di, row_in + x); + const auto out = Mul(ConvertTo(df, in), factor_v); + Store(out, df, row_out + x); + } +} +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(MultiplySum); // Local function +HWY_EXPORT(RgbFromSingle); // Local function +HWY_EXPORT(SingleFromSingle); // Local function + +// Slow conversion using double precision multiplication, only +// needed when the bit depth is too high for single precision +void SingleFromSingleAccurate(const size_t xsize, + const pixel_type* const JXL_RESTRICT row_in, + const double factor, float* row_out) { + for (size_t x = 0; x < xsize; x++) { + row_out[x] = row_in[x] * factor; + } +} + +// convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int +// back to binary32 float +void int_to_float(const pixel_type* const JXL_RESTRICT row_in, + float* const JXL_RESTRICT row_out, const size_t xsize, + const int bits, const int exp_bits) { + if (bits == 32) { + JXL_ASSERT(sizeof(pixel_type) == sizeof(float)); + JXL_ASSERT(exp_bits == 8); + memcpy(row_out, row_in, xsize * sizeof(float)); + return; + } + int exp_bias = (1 << (exp_bits - 1)) - 1; + int sign_shift = bits - 1; + int mant_bits = bits - exp_bits - 1; + int mant_shift = 23 - mant_bits; + for (size_t x = 0; x < xsize; ++x) { + uint32_t f; + memcpy(&f, &row_in[x], 4); + int signbit = (f >> sign_shift); + f &= (1 << sign_shift) - 1; + if (f == 0) { + row_out[x] = (signbit ? -0.f : 0.f); + continue; + } + int exp = (f >> mant_bits); + int mantissa = (f & ((1 << mant_bits) - 1)); + mantissa <<= mant_shift; + // Try to normalize only if there is space for maneuver. + if (exp == 0 && exp_bits < 8) { + // subnormal number + while ((mantissa & 0x800000) == 0) { + mantissa <<= 1; + exp--; + } + exp++; + // remove leading 1 because it is implicit now + mantissa &= 0x7fffff; + } + exp -= exp_bias; + // broke up the arbitrary float into its parts, now reassemble into + // binary32 + exp += 127; + JXL_ASSERT(exp >= 0); + f = (signbit ? 0x80000000 : 0); + f |= (exp << 23); + f |= mantissa; + memcpy(&row_out[x], &f, 4); + } +} + +#if JXL_DEBUG_V_LEVEL >= 1 +std::string ModularStreamId::DebugString() const { + std::ostringstream os; + os << (kind == kGlobalData ? "ModularGlobal" + : kind == kVarDCTDC ? "VarDCTDC" + : kind == kModularDC ? "ModularDC" + : kind == kACMetadata ? "ACMeta" + : kind == kQuantTable ? "QuantTable" + : kind == kModularAC ? "ModularAC" + : ""); + if (kind == kVarDCTDC || kind == kModularDC || kind == kACMetadata || + kind == kModularAC) { + os << " group " << group_id; + } + if (kind == kModularAC) { + os << " pass " << pass_id; + } + if (kind == kQuantTable) { + os << " " << quant_table_id; + } + return os.str(); +} +#endif + +Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader, + const FrameHeader& frame_header, + bool allow_truncated_group) { + bool decode_color = frame_header.encoding == FrameEncoding::kModular; + const auto& metadata = frame_header.nonserialized_metadata->m; + bool is_gray = metadata.color_encoding.IsGray(); + size_t nb_chans = 3; + if (is_gray && frame_header.color_transform == ColorTransform::kNone) { + nb_chans = 1; + } + do_color = decode_color; + size_t nb_extra = metadata.extra_channel_info.size(); + bool has_tree = reader->ReadBits(1); + if (!allow_truncated_group || + reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte) { + if (has_tree) { + size_t tree_size_limit = + std::min(static_cast<size_t>(1 << 22), + 1024 + frame_dim.xsize * frame_dim.ysize * + (nb_chans + nb_extra) / 16); + JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit)); + JXL_RETURN_IF_ERROR( + DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map)); + } + } + if (!do_color) nb_chans = 0; + + bool fp = metadata.bit_depth.floating_point_sample; + + // bits_per_sample is just metadata for XYB images. + if (metadata.bit_depth.bits_per_sample >= 32 && do_color && + frame_header.color_transform != ColorTransform::kXYB) { + if (metadata.bit_depth.bits_per_sample == 32 && fp == false) { + return JXL_FAILURE("uint32_t not supported in dec_modular"); + } else if (metadata.bit_depth.bits_per_sample > 32) { + return JXL_FAILURE("bits_per_sample > 32 not supported"); + } + } + + Image gi(frame_dim.xsize, frame_dim.ysize, metadata.bit_depth.bits_per_sample, + nb_chans + nb_extra); + + all_same_shift = true; + if (frame_header.color_transform == ColorTransform::kYCbCr) { + for (size_t c = 0; c < nb_chans; c++) { + gi.channel[c].hshift = frame_header.chroma_subsampling.HShift(c); + gi.channel[c].vshift = frame_header.chroma_subsampling.VShift(c); + size_t xsize_shifted = + DivCeil(frame_dim.xsize, 1 << gi.channel[c].hshift); + size_t ysize_shifted = + DivCeil(frame_dim.ysize, 1 << gi.channel[c].vshift); + gi.channel[c].shrink(xsize_shifted, ysize_shifted); + if (gi.channel[c].hshift != gi.channel[0].hshift || + gi.channel[c].vshift != gi.channel[0].vshift) + all_same_shift = false; + } + } + + for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) { + size_t ecups = frame_header.extra_channel_upsampling[ec]; + gi.channel[c].shrink(DivCeil(frame_dim.xsize_upsampled, ecups), + DivCeil(frame_dim.ysize_upsampled, ecups)); + gi.channel[c].hshift = gi.channel[c].vshift = + CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling); + if (gi.channel[c].hshift != gi.channel[0].hshift || + gi.channel[c].vshift != gi.channel[0].vshift) + all_same_shift = false; + } + + JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s", + gi.DebugString().c_str()); + ModularOptions options; + options.max_chan_size = frame_dim.group_dim; + options.group_dim = frame_dim.group_dim; + Status dec_status = ModularGenericDecompress( + reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim), + &options, + /*undo_transforms=*/false, &tree, &code, &context_map, + allow_truncated_group); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status); + if (dec_status.IsFatalError()) { + return JXL_FAILURE("Failed to decode global modular info"); + } + + // TODO(eustas): are we sure this can be done after partial decode? + have_something = false; + for (size_t c = 0; c < gi.channel.size(); c++) { + Channel& gic = gi.channel[c]; + if (c >= gi.nb_meta_channels && gic.w <= frame_dim.group_dim && + gic.h <= frame_dim.group_dim) + have_something = true; + } + // move global transforms to groups if possible + if (!have_something && all_same_shift) { + if (gi.transform.size() == 1 && gi.transform[0].id == TransformId::kRCT) { + global_transform = gi.transform; + gi.transform.clear(); + // TODO(jon): also move no-delta-palette out (trickier though) + } + } + full_image = std::move(gi); + JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s", + full_image.DebugString().c_str()); + return dec_status; +} + +void ModularFrameDecoder::MaybeDropFullImage() { + if (full_image.transform.empty() && !have_something && all_same_shift) { + use_full_image = false; + JXL_DEBUG_V(6, "Dropping full image"); + for (auto& ch : full_image.channel) { + // keep metadata on channels around, but dealloc their planes + ch.plane = Plane<pixel_type>(); + } + } +} + +Status ModularFrameDecoder::DecodeGroup( + const FrameHeader& frame_header, const Rect& rect, BitReader* reader, + int minShift, int maxShift, const ModularStreamId& stream, bool zerofill, + PassesDecoderState* dec_state, RenderPipelineInput* render_pipeline_input, + bool allow_truncated, bool* should_run_pipeline) { + JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s", + stream.DebugString().c_str(), Description(rect).c_str(), minShift, + maxShift, zerofill ? "using zerofill" : ""); + JXL_DASSERT(stream.kind == ModularStreamId::kModularDC || + stream.kind == ModularStreamId::kModularAC); + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + Image gi(xsize, ysize, full_image.bitdepth, 0); + // start at the first bigger-than-groupsize non-metachannel + size_t c = full_image.nb_meta_channels; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break; + } + size_t beginc = c; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + int shift = std::min(fc.hshift, fc.vshift); + if (shift > maxShift) continue; + if (shift < minShift) continue; + Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, + rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); + if (r.xsize() == 0 || r.ysize() == 0) continue; + if (zerofill && use_full_image) { + for (size_t y = 0; y < r.ysize(); ++y) { + pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y); + memset(row_out, 0, r.xsize() * sizeof(*row_out)); + } + } else { + Channel gc(r.xsize(), r.ysize()); + if (zerofill) ZeroFillImage(&gc.plane); + gc.hshift = fc.hshift; + gc.vshift = fc.vshift; + gi.channel.emplace_back(std::move(gc)); + } + } + if (zerofill && use_full_image) return true; + // Return early if there's nothing to decode. Otherwise there might be + // problems later (in ModularImageToDecodedRect). + if (gi.channel.empty()) { + if (dec_state && should_run_pipeline) { + const auto* metadata = frame_header.nonserialized_metadata; + if (do_color || metadata->m.num_extra_channels > 0) { + // Signal to FrameDecoder that we do not have some of the required input + // for the render pipeline. + *should_run_pipeline = false; + } + } + JXL_DEBUG_V(6, "Nothing to decode, returning early."); + return true; + } + ModularOptions options; + if (!zerofill) { + auto status = ModularGenericDecompress( + reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options, + /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated); + if (!allow_truncated) JXL_RETURN_IF_ERROR(status); + if (status.IsFatalError()) return status; + } + // Undo global transforms that have been pushed to the group level + if (!use_full_image) { + JXL_ASSERT(render_pipeline_input); + for (auto t : global_transform) { + JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header)); + } + JXL_RETURN_IF_ERROR(ModularImageToDecodedRect( + frame_header, gi, dec_state, nullptr, *render_pipeline_input, + Rect(0, 0, gi.w, gi.h))); + return true; + } + int gic = 0; + for (c = beginc; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + int shift = std::min(fc.hshift, fc.vshift); + if (shift > maxShift) continue; + if (shift < minShift) continue; + Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, + rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); + if (r.xsize() == 0 || r.ysize() == 0) continue; + JXL_ASSERT(use_full_image); + CopyImageTo(/*rect_from=*/Rect(0, 0, r.xsize(), r.ysize()), + /*from=*/gi.channel[gic].plane, + /*rect_to=*/r, /*to=*/&fc.plane); + gic++; + } + return true; +} + +Status ModularFrameDecoder::DecodeVarDCTDC(const FrameHeader& frame_header, + size_t group_id, BitReader* reader, + PassesDecoderState* dec_state) { + const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id); + JXL_DEBUG_V(6, "Decoding VarDCT DC with rect %s", Description(r).c_str()); + // TODO(eustas): investigate if we could reduce the impact of + // EvalRationalPolynomial; generally speaking, the limit is + // 2**(128/(3*magic)), where 128 comes from IEEE 754 exponent, + // 3 comes from XybToRgb that cubes the values, and "magic" is + // the sum of all other contributions. 2**18 is known to lead + // to NaN on input found by fuzzing (see commit message). + Image image(r.xsize(), r.ysize(), full_image.bitdepth, 3); + size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim); + reader->Refill(); + size_t extra_precision = reader->ReadFixedBits<2>(); + float mul = 1.0f / (1 << extra_precision); + ModularOptions options; + for (size_t c = 0; c < 3; c++) { + Channel& ch = image.channel[c < 2 ? c ^ 1 : c]; + ch.w >>= frame_header.chroma_subsampling.HShift(c); + ch.h >>= frame_header.chroma_subsampling.VShift(c); + ch.shrink(); + } + if (!ModularGenericDecompress( + reader, image, /*header=*/nullptr, stream_id, &options, + /*undo_transforms=*/true, &tree, &code, &context_map)) { + return JXL_FAILURE("Failed to decode VarDCT DC group"); + } + DequantDC(r, &dec_state->shared_storage.dc_storage, + &dec_state->shared_storage.quant_dc, image, + dec_state->shared->quantizer.MulDC(), mul, + dec_state->shared->cmap.DCFactors(), + frame_header.chroma_subsampling, dec_state->shared->block_ctx_map); + return true; +} + +Status ModularFrameDecoder::DecodeAcMetadata(const FrameHeader& frame_header, + size_t group_id, BitReader* reader, + PassesDecoderState* dec_state) { + const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id); + JXL_DEBUG_V(6, "Decoding AcMetadata with rect %s", Description(r).c_str()); + size_t upper_bound = r.xsize() * r.ysize(); + reader->Refill(); + size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1; + size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim); + // YToX, YToB, ACS + QF, EPF + Image image(r.xsize(), r.ysize(), full_image.bitdepth, 4); + static_assert(kColorTileDimInBlocks == 8, "Color tile size changed"); + Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3); + image.channel[0] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[1] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[2] = Channel(count, 2, 0, 0); + ModularOptions options; + if (!ModularGenericDecompress( + reader, image, /*header=*/nullptr, stream_id, &options, + /*undo_transforms=*/true, &tree, &code, &context_map)) { + return JXL_FAILURE("Failed to decode AC metadata"); + } + ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr, + &dec_state->shared_storage.cmap.ytox_map); + ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane, cr, + &dec_state->shared_storage.cmap.ytob_map); + size_t num = 0; + bool is444 = frame_header.chroma_subsampling.Is444(); + auto& ac_strategy = dec_state->shared_storage.ac_strategy; + size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize()); + size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize()); + uint32_t local_used_acs = 0; + for (size_t iy = 0; iy < r.ysize(); iy++) { + size_t y = r.y0() + iy; + int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy); + uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy); + int32_t* row_in_1 = image.channel[2].plane.Row(0); + int32_t* row_in_2 = image.channel[2].plane.Row(1); + int32_t* row_in_3 = image.channel[3].plane.Row(iy); + for (size_t ix = 0; ix < r.xsize(); ix++) { + size_t x = r.x0() + ix; + int sharpness = row_in_3[ix]; + if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) { + return JXL_FAILURE("Corrupted sharpness field"); + } + row_epf[ix] = sharpness; + if (ac_strategy.IsValid(x, y)) { + continue; + } + + if (num >= count) return JXL_FAILURE("Corrupted stream"); + + if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) { + return JXL_FAILURE("Invalid AC strategy"); + } + local_used_acs |= 1u << row_in_1[num]; + AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]); + if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) && + !is444) { + return JXL_FAILURE( + "AC strategy not compatible with chroma subsampling"); + } + // Ensure that blocks do not overflow *AC* groups. + size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks; + size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks; + size_t next_x_dct_block = x + acs.covered_blocks_x(); + size_t next_y_dct_block = y + acs.covered_blocks_y(); + if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) { + return JXL_FAILURE("Invalid AC strategy, x overflow"); + } + if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) { + return JXL_FAILURE("Invalid AC strategy, y overflow"); + } + JXL_RETURN_IF_ERROR( + ac_strategy.SetNoBoundsCheck(x, y, AcStrategy::Type(row_in_1[num]))); + row_qf[ix] = 1 + std::max<int32_t>(0, std::min(Quantizer::kQuantMax - 1, + row_in_2[num])); + num++; + } + } + dec_state->used_acs |= local_used_acs; + if (frame_header.loop_filter.epf_iters > 0) { + ComputeSigma(frame_header.loop_filter, r, dec_state); + } + return true; +} + +Status ModularFrameDecoder::ModularImageToDecodedRect( + const FrameHeader& frame_header, Image& gi, PassesDecoderState* dec_state, + jxl::ThreadPool* pool, RenderPipelineInput& render_pipeline_input, + Rect modular_rect) { + const auto* metadata = frame_header.nonserialized_metadata; + JXL_CHECK(gi.transform.empty()); + + auto get_row = [&](size_t c, size_t y) { + const auto& buffer = render_pipeline_input.GetBuffer(c); + return buffer.second.Row(buffer.first, y); + }; + + size_t c = 0; + if (do_color) { + const bool rgb_from_gray = + metadata->m.color_encoding.IsGray() && + frame_header.color_transform == ColorTransform::kNone; + const bool fp = metadata->m.bit_depth.floating_point_sample && + frame_header.color_transform != ColorTransform::kXYB; + for (; c < 3; c++) { + double factor = full_image.bitdepth < 32 + ? 1.0 / ((1u << full_image.bitdepth) - 1) + : 0; + size_t c_in = c; + if (frame_header.color_transform == ColorTransform::kXYB) { + factor = dec_state->shared->matrices.DCQuants()[c]; + // XYB is encoded as YX(B-Y) + if (c < 2) c_in = 1 - c; + } else if (rgb_from_gray) { + c_in = 0; + } + JXL_ASSERT(c_in < gi.channel.size()); + Channel& ch_in = gi.channel[c_in]; + // TODO(eustas): could we detect it on earlier stage? + if (ch_in.w == 0 || ch_in.h == 0) { + return JXL_FAILURE("Empty image"); + } + JXL_CHECK(ch_in.hshift <= 3 && ch_in.vshift <= 3); + Rect r = render_pipeline_input.GetBuffer(c).second; + Rect mr(modular_rect.x0() >> ch_in.hshift, + modular_rect.y0() >> ch_in.vshift, + DivCeil(modular_rect.xsize(), 1 << ch_in.hshift), + DivCeil(modular_rect.ysize(), 1 << ch_in.vshift)); + mr = mr.Crop(ch_in.plane); + size_t xsize_shifted = r.xsize(); + size_t ysize_shifted = r.ysize(); + if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) { + return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS + "x%" PRIuS + " modular channel into " + "a %" PRIuS "x%" PRIuS " rect", + mr.xsize(), mr.ysize(), r.xsize(), r.ysize()); + } + if (frame_header.color_transform == ColorTransform::kXYB && c == 2) { + JXL_ASSERT(!fp); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + const pixel_type* const JXL_RESTRICT row_in = + mr.Row(&ch_in.plane, y); + const pixel_type* const JXL_RESTRICT row_in_Y = + mr.Row(&gi.channel[0].plane, y); + float* const JXL_RESTRICT row_out = get_row(c, y); + HWY_DYNAMIC_DISPATCH(MultiplySum) + (xsize_shifted, row_in, row_in_Y, factor, row_out); + }, + "ModularIntToFloat")); + } else if (fp) { + int bits = metadata->m.bit_depth.bits_per_sample; + int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + const pixel_type* const JXL_RESTRICT row_in = + mr.Row(&ch_in.plane, y); + if (rgb_from_gray) { + for (size_t cc = 0; cc < 3; cc++) { + float* const JXL_RESTRICT row_out = get_row(cc, y); + int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits); + } + } else { + float* const JXL_RESTRICT row_out = get_row(c, y); + int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits); + } + }, + "ModularIntToFloat_losslessfloat")); + } else { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + const pixel_type* const JXL_RESTRICT row_in = + mr.Row(&ch_in.plane, y); + if (rgb_from_gray) { + if (full_image.bitdepth < 23) { + HWY_DYNAMIC_DISPATCH(RgbFromSingle) + (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y), + get_row(2, y)); + } else { + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + get_row(0, y)); + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + get_row(1, y)); + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + get_row(2, y)); + } + } else { + float* const JXL_RESTRICT row_out = get_row(c, y); + if (full_image.bitdepth < 23) { + HWY_DYNAMIC_DISPATCH(SingleFromSingle) + (xsize_shifted, row_in, factor, row_out); + } else { + SingleFromSingleAccurate(xsize_shifted, row_in, factor, + row_out); + } + } + }, + "ModularIntToFloat")); + } + if (rgb_from_gray) { + break; + } + } + if (rgb_from_gray) { + c = 1; + } + } + size_t num_extra_channels = metadata->m.num_extra_channels; + for (size_t ec = 0; ec < num_extra_channels; ec++, c++) { + const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec]; + int bits = eci.bit_depth.bits_per_sample; + int exp_bits = eci.bit_depth.exponent_bits_per_sample; + bool fp = eci.bit_depth.floating_point_sample; + JXL_ASSERT(fp || bits < 32); + const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1)); + JXL_ASSERT(c < gi.channel.size()); + Channel& ch_in = gi.channel[c]; + Rect r = render_pipeline_input.GetBuffer(3 + ec).second; + Rect mr(modular_rect.x0() >> ch_in.hshift, + modular_rect.y0() >> ch_in.vshift, + DivCeil(modular_rect.xsize(), 1 << ch_in.hshift), + DivCeil(modular_rect.ysize(), 1 << ch_in.vshift)); + mr = mr.Crop(ch_in.plane); + if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) { + return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS + "x%" PRIuS + " modular channel into " + "a %" PRIuS "x%" PRIuS " rect", + mr.xsize(), mr.ysize(), r.xsize(), r.ysize()); + } + for (size_t y = 0; y < r.ysize(); ++y) { + float* const JXL_RESTRICT row_out = + r.Row(render_pipeline_input.GetBuffer(3 + ec).first, y); + const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y); + if (fp) { + int_to_float(row_in, row_out, r.xsize(), bits, exp_bits); + } else { + if (full_image.bitdepth < 23) { + HWY_DYNAMIC_DISPATCH(SingleFromSingle) + (r.xsize(), row_in, factor, row_out); + } else { + SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out); + } + } + } + } + return true; +} + +Status ModularFrameDecoder::FinalizeDecoding(const FrameHeader& frame_header, + PassesDecoderState* dec_state, + jxl::ThreadPool* pool, + bool inplace) { + if (!use_full_image) return true; + Image gi = (inplace ? std::move(full_image) : full_image.clone()); + size_t xsize = gi.w; + size_t ysize = gi.h; + + JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s", + gi.DebugString().c_str()); + + // Don't use threads if total image size is smaller than a group + if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr; + + // Undo the global transforms + gi.undo_transforms(global_header.wp_header, pool); + JXL_DASSERT(global_transform.empty()); + if (gi.error) return JXL_FAILURE("Undoing transforms failed"); + + for (size_t i = 0; i < dec_state->shared->frame_dim.num_groups; i++) { + dec_state->render_pipeline->ClearDone(i); + } + std::atomic<bool> has_error{false}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, dec_state->shared->frame_dim.num_groups, + [&](size_t num_threads) { + bool use_group_ids = (frame_header.encoding == FrameEncoding::kVarDCT || + (frame_header.flags & FrameHeader::kNoise)); + return dec_state->render_pipeline->PrepareForThreads(num_threads, + use_group_ids); + }, + [&](const uint32_t group, size_t thread_id) { + RenderPipelineInput input = + dec_state->render_pipeline->GetInputBuffers(group, thread_id); + if (!ModularImageToDecodedRect( + frame_header, gi, dec_state, nullptr, input, + dec_state->shared->frame_dim.GroupRect(group))) { + has_error = true; + return; + } + input.Done(); + }, + "ModularToRect")); + if (has_error) { + return JXL_FAILURE("Error producing input to render pipeline"); + } + return true; +} + +static constexpr const float kAlmostZero = 1e-8f; + +Status ModularFrameDecoder::DecodeQuantTable( + size_t required_size_x, size_t required_size_y, BitReader* br, + QuantEncoding* encoding, size_t idx, + ModularFrameDecoder* modular_frame_decoder) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den)); + if (encoding->qraw.qtable_den < kAlmostZero) { + // qtable[] values are already checked for <= 0 so the denominator may not + // be negative. + return JXL_FAILURE("Invalid qtable_den: value too small"); + } + Image image(required_size_x, required_size_y, 8, 3); + ModularOptions options; + if (modular_frame_decoder) { + JXL_RETURN_IF_ERROR(ModularGenericDecompress( + br, image, /*header=*/nullptr, + ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim), + &options, /*undo_transforms=*/true, &modular_frame_decoder->tree, + &modular_frame_decoder->code, &modular_frame_decoder->context_map)); + } else { + JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr, + 0, &options, + /*undo_transforms=*/true)); + } + if (!encoding->qraw.qtable) { + encoding->qraw.qtable = new std::vector<int>(); + } + encoding->qraw.qtable->resize(required_size_x * required_size_y * 3); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < required_size_y; y++) { + int32_t* JXL_RESTRICT row = image.channel[c].Row(y); + for (size_t x = 0; x < required_size_x; x++) { + (*encoding->qraw.qtable)[c * required_size_x * required_size_y + + y * required_size_x + x] = row[x]; + if (row[x] <= 0) { + return JXL_FAILURE("Invalid raw quantization table"); + } + } + } + } + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dec_modular.h b/third_party/jpeg-xl/lib/jxl/dec_modular.h new file mode 100644 index 0000000000..58a6562740 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_modular.h @@ -0,0 +1,143 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_MODULAR_H_ +#define LIB_JXL_DEC_MODULAR_H_ + +#include <stddef.h> + +#include <string> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +struct ModularStreamId { + enum Kind { + kGlobalData, + kVarDCTDC, + kModularDC, + kACMetadata, + kQuantTable, + kModularAC + }; + Kind kind; + size_t quant_table_id; + size_t group_id; // DC or AC group id. + size_t pass_id; // Only for kModularAC. + size_t ID(const FrameDimensions& frame_dim) const { + size_t id = 0; + switch (kind) { + case kGlobalData: + id = 0; + break; + case kVarDCTDC: + id = 1 + group_id; + break; + case kModularDC: + id = 1 + frame_dim.num_dc_groups + group_id; + break; + case kACMetadata: + id = 1 + 2 * frame_dim.num_dc_groups + group_id; + break; + case kQuantTable: + id = 1 + 3 * frame_dim.num_dc_groups + quant_table_id; + break; + case kModularAC: + id = 1 + 3 * frame_dim.num_dc_groups + DequantMatrices::kNum + + frame_dim.num_groups * pass_id + group_id; + break; + }; + return id; + } + static ModularStreamId Global() { + return ModularStreamId{kGlobalData, 0, 0, 0}; + } + static ModularStreamId VarDCTDC(size_t group_id) { + return ModularStreamId{kVarDCTDC, 0, group_id, 0}; + } + static ModularStreamId ModularDC(size_t group_id) { + return ModularStreamId{kModularDC, 0, group_id, 0}; + } + static ModularStreamId ACMetadata(size_t group_id) { + return ModularStreamId{kACMetadata, 0, group_id, 0}; + } + static ModularStreamId QuantTable(size_t quant_table_id) { + JXL_ASSERT(quant_table_id < DequantMatrices::kNum); + return ModularStreamId{kQuantTable, quant_table_id, 0, 0}; + } + static ModularStreamId ModularAC(size_t group_id, size_t pass_id) { + return ModularStreamId{kModularAC, 0, group_id, pass_id}; + } + static size_t Num(const FrameDimensions& frame_dim, size_t passes) { + return ModularAC(0, passes).ID(frame_dim); + } + std::string DebugString() const; +}; + +class ModularFrameDecoder { + public: + void Init(const FrameDimensions& frame_dim) { this->frame_dim = frame_dim; } + Status DecodeGlobalInfo(BitReader* reader, const FrameHeader& frame_header, + bool allow_truncated_group); + Status DecodeGroup(const FrameHeader& frame_header, const Rect& rect, + BitReader* reader, int minShift, int maxShift, + const ModularStreamId& stream, bool zerofill, + PassesDecoderState* dec_state, + RenderPipelineInput* render_pipeline_input, + bool allow_truncated, bool* should_run_pipeline = nullptr); + // Decodes a VarDCT DC group (`group_id`) from the given `reader`. + Status DecodeVarDCTDC(const FrameHeader& frame_header, size_t group_id, + BitReader* reader, PassesDecoderState* dec_state); + // Decodes a VarDCT AC Metadata group (`group_id`) from the given `reader`. + Status DecodeAcMetadata(const FrameHeader& frame_header, size_t group_id, + BitReader* reader, PassesDecoderState* dec_state); + // Decodes a RAW quant table from `br` into the given `encoding`, of size + // `required_size_x x required_size_y`. If `modular_frame_decoder` is passed, + // its global tree is used, otherwise no global tree is used. + static Status DecodeQuantTable(size_t required_size_x, size_t required_size_y, + BitReader* br, QuantEncoding* encoding, + size_t idx, + ModularFrameDecoder* modular_frame_decoder); + // if inplace is true, this can only be called once + // if it is false, it can be called multiple times (e.g. for progressive + // steps) + Status FinalizeDecoding(const FrameHeader& frame_header, + PassesDecoderState* dec_state, jxl::ThreadPool* pool, + bool inplace); + bool have_dc() const { return have_something; } + void MaybeDropFullImage(); + bool UsesFullImage() const { return use_full_image; } + + private: + Status ModularImageToDecodedRect(const FrameHeader& frame_header, Image& gi, + PassesDecoderState* dec_state, + jxl::ThreadPool* pool, + RenderPipelineInput& render_pipeline_input, + Rect modular_rect); + + Image full_image; + std::vector<Transform> global_transform; + FrameDimensions frame_dim; + bool do_color; + bool have_something; + bool use_full_image = true; + bool all_same_shift; + Tree tree; + ANSCode code; + std::vector<uint8_t> context_map; + GroupHeader global_header; +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_MODULAR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_noise.cc b/third_party/jpeg-xl/lib/jxl/dec_noise.cc new file mode 100644 index 0000000000..ae46b1062f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_noise.cc @@ -0,0 +1,129 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_noise.h" + +#include <stdint.h> +#include <stdlib.h> + +#include <algorithm> +#include <numeric> +#include <utility> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_noise.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/xorshift128plus-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Or; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Vec; + +using D = HWY_CAPPED(float, kBlockDim); +using DI = hwy::HWY_NAMESPACE::Rebind<int, D>; +using DI8 = hwy::HWY_NAMESPACE::Repartition<uint8_t, D>; + +// Converts one vector's worth of random bits to floats in [1, 2). +// NOTE: as the convolution kernel sums to 0, it doesn't matter if inputs are in +// [0, 1) or in [1, 2). +void BitsToFloat(const uint32_t* JXL_RESTRICT random_bits, + float* JXL_RESTRICT floats) { + const HWY_FULL(float) df; + const HWY_FULL(uint32_t) du; + + const auto bits = Load(du, random_bits); + // 1.0 + 23 random mantissa bits = [1, 2) + const auto rand12 = BitCast(df, Or(ShiftRight<9>(bits), Set(du, 0x3F800000))); + Store(rand12, df, floats); +} + +void RandomImage(Xorshift128Plus* rng, const Rect& rect, + ImageF* JXL_RESTRICT noise) { + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + + // May exceed the vector size, hence we have two loops over x below. + constexpr size_t kFloatsPerBatch = + Xorshift128Plus::N * sizeof(uint64_t) / sizeof(float); + HWY_ALIGN uint64_t batch[Xorshift128Plus::N] = {}; + + const HWY_FULL(float) df; + const size_t N = Lanes(df); + + for (size_t y = 0; y < ysize; ++y) { + float* JXL_RESTRICT row = rect.Row(noise, y); + + size_t x = 0; + // Only entire batches (avoids exceeding the image padding). + for (; x + kFloatsPerBatch < xsize; x += kFloatsPerBatch) { + rng->Fill(batch); + for (size_t i = 0; i < kFloatsPerBatch; i += Lanes(df)) { + BitsToFloat(reinterpret_cast<const uint32_t*>(batch) + i, row + x + i); + } + } + + // Any remaining pixels, rounded up to vectors (safe due to padding). + rng->Fill(batch); + size_t batch_pos = 0; // < kFloatsPerBatch + for (; x < xsize; x += N) { + BitsToFloat(reinterpret_cast<const uint32_t*>(batch) + batch_pos, + row + x); + batch_pos += N; + } + } +} +void Random3Planes(size_t visible_frame_index, size_t nonvisible_frame_index, + size_t x0, size_t y0, const std::pair<ImageF*, Rect>& plane0, + const std::pair<ImageF*, Rect>& plane1, + const std::pair<ImageF*, Rect>& plane2) { + HWY_ALIGN Xorshift128Plus rng(visible_frame_index, nonvisible_frame_index, x0, + y0); + RandomImage(&rng, plane0.second, plane0.first); + RandomImage(&rng, plane1.second, plane1.first); + RandomImage(&rng, plane2.second, plane2.first); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(Random3Planes); +void Random3Planes(size_t visible_frame_index, size_t nonvisible_frame_index, + size_t x0, size_t y0, const std::pair<ImageF*, Rect>& plane0, + const std::pair<ImageF*, Rect>& plane1, + const std::pair<ImageF*, Rect>& plane2) { + return HWY_DYNAMIC_DISPATCH(Random3Planes)(visible_frame_index, + nonvisible_frame_index, x0, y0, + plane0, plane1, plane2); +} + +void DecodeFloatParam(float precision, float* val, BitReader* br) { + const int absval_quant = br->ReadFixedBits<10>(); + *val = absval_quant / precision; +} + +Status DecodeNoise(BitReader* br, NoiseParams* noise_params) { + for (float& i : noise_params->lut) { + DecodeFloatParam(kNoisePrecision, &i, br); + } + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dec_noise.h b/third_party/jpeg-xl/lib/jxl/dec_noise.h new file mode 100644 index 0000000000..ac05866470 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_noise.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_NOISE_H_ +#define LIB_JXL_DEC_NOISE_H_ + +// Noise synthesis. Currently disabled. + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" +#include "lib/jxl/noise.h" + +namespace jxl { + +void Random3Planes(size_t visible_frame_index, size_t nonvisible_frame_index, + size_t x0, size_t y0, const std::pair<ImageF*, Rect>& plane0, + const std::pair<ImageF*, Rect>& plane1, + const std::pair<ImageF*, Rect>& plane2); + +// Must only call if FrameHeader.flags.kNoise. +Status DecodeNoise(BitReader* br, NoiseParams* noise_params); + +} // namespace jxl + +#endif // LIB_JXL_DEC_NOISE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc new file mode 100644 index 0000000000..0ae2223252 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.cc @@ -0,0 +1,361 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_patch_dictionary.h" + +#include <stdint.h> +#include <stdlib.h> +#include <sys/types.h> + +#include <algorithm> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/blending.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" // kMaxNumReferenceFrames +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/pack_signed.h" +#include "lib/jxl/patch_dictionary_internal.h" + +namespace jxl { + +Status PatchDictionary::Decode(BitReader* br, size_t xsize, size_t ysize, + bool* uses_extra_channels) { + positions_.clear(); + std::vector<uint8_t> context_map; + ANSCode code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kNumPatchDictionaryContexts, &code, &context_map)); + ANSSymbolReader decoder(&code, br); + + auto read_num = [&](size_t context) { + size_t r = decoder.ReadHybridUint(context, br, context_map); + return r; + }; + + size_t num_ref_patch = read_num(kNumRefPatchContext); + // Limit max memory usage of patches to about 66 bytes per pixel (assuming 8 + // bytes per size_t) + const size_t num_pixels = xsize * ysize; + const size_t max_ref_patches = 1024 + num_pixels / 4; + const size_t max_patches = max_ref_patches * 4; + const size_t max_blending_infos = max_patches * 4; + if (num_ref_patch > max_ref_patches) { + return JXL_FAILURE("Too many patches in dictionary"); + } + size_t num_ec = shared_->metadata->m.num_extra_channels; + + size_t total_patches = 0; + size_t next_size = 1; + + for (size_t id = 0; id < num_ref_patch; id++) { + PatchReferencePosition ref_pos; + ref_pos.ref = read_num(kReferenceFrameContext); + if (ref_pos.ref >= kMaxNumReferenceFrames || + shared_->reference_frames[ref_pos.ref].frame.xsize() == 0) { + return JXL_FAILURE("Invalid reference frame ID"); + } + if (!shared_->reference_frames[ref_pos.ref].ib_is_in_xyb) { + return JXL_FAILURE( + "Patches cannot use frames saved post color transforms"); + } + const ImageBundle& ib = shared_->reference_frames[ref_pos.ref].frame; + ref_pos.x0 = read_num(kPatchReferencePositionContext); + ref_pos.y0 = read_num(kPatchReferencePositionContext); + ref_pos.xsize = read_num(kPatchSizeContext) + 1; + ref_pos.ysize = read_num(kPatchSizeContext) + 1; + if (ref_pos.x0 + ref_pos.xsize > ib.xsize()) { + return JXL_FAILURE("Invalid position specified in reference frame"); + } + if (ref_pos.y0 + ref_pos.ysize > ib.ysize()) { + return JXL_FAILURE("Invalid position specified in reference frame"); + } + size_t id_count = read_num(kPatchCountContext); + if (id_count > max_patches) { + return JXL_FAILURE("Too many patches in dictionary"); + } + id_count++; + total_patches += id_count; + if (total_patches > max_patches) { + return JXL_FAILURE("Too many patches in dictionary"); + } + if (next_size < total_patches) { + next_size *= 2; + next_size = std::min<size_t>(next_size, max_patches); + } + if (next_size * (num_ec + 1) > max_blending_infos) { + return JXL_FAILURE("Too many patches in dictionary"); + } + positions_.reserve(next_size); + blendings_.reserve(next_size * (num_ec + 1)); + for (size_t i = 0; i < id_count; i++) { + PatchPosition pos; + pos.ref_pos_idx = ref_positions_.size(); + if (i == 0) { + pos.x = read_num(kPatchPositionContext); + pos.y = read_num(kPatchPositionContext); + } else { + ssize_t deltax = UnpackSigned(read_num(kPatchOffsetContext)); + if (deltax < 0 && static_cast<size_t>(-deltax) > positions_.back().x) { + return JXL_FAILURE("Invalid patch: negative x coordinate (%" PRIuS + " base x %" PRIdS " delta x)", + positions_.back().x, deltax); + } + pos.x = positions_.back().x + deltax; + ssize_t deltay = UnpackSigned(read_num(kPatchOffsetContext)); + if (deltay < 0 && static_cast<size_t>(-deltay) > positions_.back().y) { + return JXL_FAILURE("Invalid patch: negative y coordinate (%" PRIuS + " base y %" PRIdS " delta y)", + positions_.back().y, deltay); + } + pos.y = positions_.back().y + deltay; + } + if (pos.x + ref_pos.xsize > xsize) { + return JXL_FAILURE("Invalid patch x: at %" PRIuS " + %" PRIuS + " > %" PRIuS, + pos.x, ref_pos.xsize, xsize); + } + if (pos.y + ref_pos.ysize > ysize) { + return JXL_FAILURE("Invalid patch y: at %" PRIuS " + %" PRIuS + " > %" PRIuS, + pos.y, ref_pos.ysize, ysize); + } + for (size_t j = 0; j < num_ec + 1; j++) { + uint32_t blend_mode = read_num(kPatchBlendModeContext); + if (blend_mode >= uint32_t(PatchBlendMode::kNumBlendModes)) { + return JXL_FAILURE("Invalid patch blend mode: %u", blend_mode); + } + PatchBlending info; + info.mode = static_cast<PatchBlendMode>(blend_mode); + if (UsesAlpha(info.mode)) { + *uses_extra_channels = true; + } + if (info.mode != PatchBlendMode::kNone && j > 0) { + *uses_extra_channels = true; + } + if (UsesAlpha(info.mode) && + shared_->metadata->m.extra_channel_info.size() > 1) { + info.alpha_channel = read_num(kPatchAlphaChannelContext); + if (info.alpha_channel >= + shared_->metadata->m.extra_channel_info.size()) { + return JXL_FAILURE( + "Invalid alpha channel for blending: %u out of %u\n", + info.alpha_channel, + (uint32_t)shared_->metadata->m.extra_channel_info.size()); + } + } else { + info.alpha_channel = 0; + } + if (UsesClamp(info.mode)) { + info.clamp = read_num(kPatchClampContext); + } else { + info.clamp = false; + } + blendings_.push_back(info); + } + positions_.push_back(std::move(pos)); + } + ref_positions_.emplace_back(std::move(ref_pos)); + } + positions_.shrink_to_fit(); + + if (!decoder.CheckANSFinalState()) { + return JXL_FAILURE("ANS checksum failure."); + } + + ComputePatchTree(); + return true; +} + +int PatchDictionary::GetReferences() const { + int result = 0; + for (size_t i = 0; i < ref_positions_.size(); ++i) { + result |= (1 << static_cast<int>(ref_positions_[i].ref)); + } + return result; +} + +namespace { +struct PatchInterval { + size_t idx; + size_t y0, y1; +}; +} // namespace + +void PatchDictionary::ComputePatchTree() { + patch_tree_.clear(); + num_patches_.clear(); + sorted_patches_y0_.clear(); + sorted_patches_y1_.clear(); + if (positions_.empty()) { + return; + } + // Create a y-interval for each patch. + std::vector<PatchInterval> intervals(positions_.size()); + for (size_t i = 0; i < positions_.size(); ++i) { + const auto& pos = positions_[i]; + intervals[i].idx = i; + intervals[i].y0 = pos.y; + intervals[i].y1 = pos.y + ref_positions_[pos.ref_pos_idx].ysize; + } + auto sort_by_y0 = [&intervals](size_t start, size_t end) { + std::sort(intervals.data() + start, intervals.data() + end, + [](const PatchInterval& i0, const PatchInterval& i1) { + return i0.y0 < i1.y0; + }); + }; + auto sort_by_y1 = [&intervals](size_t start, size_t end) { + std::sort(intervals.data() + start, intervals.data() + end, + [](const PatchInterval& i0, const PatchInterval& i1) { + return i0.y1 < i1.y1; + }); + }; + // Count the number of patches for each row. + sort_by_y1(0, intervals.size()); + num_patches_.resize(intervals.back().y1); + for (auto iv : intervals) { + for (size_t y = iv.y0; y < iv.y1; ++y) num_patches_[y]++; + } + PatchTreeNode root; + root.start = 0; + root.num = intervals.size(); + patch_tree_.push_back(root); + size_t next = 0; + while (next < patch_tree_.size()) { + auto& node = patch_tree_[next]; + size_t start = node.start; + size_t end = node.start + node.num; + // Choose the y_center for this node to be the median of interval starts. + sort_by_y0(start, end); + size_t middle_idx = start + node.num / 2; + node.y_center = intervals[middle_idx].y0; + // Divide the intervals in [start, end) into three groups: + // * those completely to the right of y_center: [right_start, end) + // * those overlapping y_center: [left_end, right_start) + // * those completely to the left of y_center: [start, left_end) + size_t right_start = middle_idx; + while (right_start < end && intervals[right_start].y0 == node.y_center) { + ++right_start; + } + sort_by_y1(start, right_start); + size_t left_end = right_start; + while (left_end > start && intervals[left_end - 1].y1 > node.y_center) { + --left_end; + } + // Fill in sorted_patches_y0_ and sorted_patches_y1_ for the current node. + node.num = right_start - left_end; + node.start = sorted_patches_y0_.size(); + for (ssize_t i = static_cast<ssize_t>(right_start) - 1; + i >= static_cast<ssize_t>(left_end); --i) { + sorted_patches_y1_.push_back({intervals[i].y1, intervals[i].idx}); + } + sort_by_y0(left_end, right_start); + for (size_t i = left_end; i < right_start; ++i) { + sorted_patches_y0_.push_back({intervals[i].y0, intervals[i].idx}); + } + // Create the left and right nodes (if not empty). + node.left_child = node.right_child = -1; + if (left_end > start) { + PatchTreeNode left; + left.start = start; + left.num = left_end - left.start; + patch_tree_[next].left_child = patch_tree_.size(); + patch_tree_.push_back(left); + } + if (right_start < end) { + PatchTreeNode right; + right.start = right_start; + right.num = end - right.start; + patch_tree_[next].right_child = patch_tree_.size(); + patch_tree_.push_back(right); + } + ++next; + } +} + +std::vector<size_t> PatchDictionary::GetPatchesForRow(size_t y) const { + std::vector<size_t> result; + if (y < num_patches_.size() && num_patches_[y] > 0) { + result.reserve(num_patches_[y]); + for (ssize_t tree_idx = 0; tree_idx != -1;) { + JXL_DASSERT(tree_idx < (ssize_t)patch_tree_.size()); + const auto& node = patch_tree_[tree_idx]; + if (y <= node.y_center) { + for (size_t i = 0; i < node.num; ++i) { + const auto& p = sorted_patches_y0_[node.start + i]; + if (y < p.first) break; + result.push_back(p.second); + } + tree_idx = y < node.y_center ? node.left_child : -1; + } else { + for (size_t i = 0; i < node.num; ++i) { + const auto& p = sorted_patches_y1_[node.start + i]; + if (y >= p.first) break; + result.push_back(p.second); + } + tree_idx = node.right_child; + } + } + // Ensure that he relative order of patches that affect the same pixels is + // preserved. This is important for patches that have a blend mode + // different from kAdd. + std::sort(result.begin(), result.end()); + } + return result; +} + +// Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed +// to be located at position (x0, y) in the frame. +void PatchDictionary::AddOneRow(float* const* inout, size_t y, size_t x0, + size_t xsize) const { + size_t num_ec = shared_->metadata->m.num_extra_channels; + std::vector<const float*> fg_ptrs(3 + num_ec); + for (size_t pos_idx : GetPatchesForRow(y)) { + const size_t blending_idx = pos_idx * (num_ec + 1); + const PatchPosition& pos = positions_[pos_idx]; + const PatchReferencePosition& ref_pos = ref_positions_[pos.ref_pos_idx]; + size_t by = pos.y; + size_t bx = pos.x; + size_t patch_xsize = ref_pos.xsize; + JXL_DASSERT(y >= by); + JXL_DASSERT(y < by + ref_pos.ysize); + size_t iy = y - by; + size_t ref = ref_pos.ref; + if (bx >= x0 + xsize) continue; + if (bx + patch_xsize < x0) continue; + size_t patch_x0 = std::max(bx, x0); + size_t patch_x1 = std::min(bx + patch_xsize, x0 + xsize); + for (size_t c = 0; c < 3; c++) { + fg_ptrs[c] = shared_->reference_frames[ref].frame.color().ConstPlaneRow( + c, ref_pos.y0 + iy) + + ref_pos.x0 + x0 - bx; + } + for (size_t i = 0; i < num_ec; i++) { + fg_ptrs[3 + i] = + shared_->reference_frames[ref].frame.extra_channels()[i].ConstRow( + ref_pos.y0 + iy) + + ref_pos.x0 + x0 - bx; + } + PerformBlending(inout, fg_ptrs.data(), inout, patch_x0 - x0, + patch_x1 - patch_x0, blendings_[blending_idx], + blendings_.data() + blending_idx + 1, + shared_->metadata->m.extra_channel_info); + } +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.h b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.h new file mode 100644 index 0000000000..aac6111ae6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_patch_dictionary.h @@ -0,0 +1,149 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_PATCH_DICTIONARY_H_ +#define LIB_JXL_DEC_PATCH_DICTIONARY_H_ + +// Chooses reference patches, and avoids encoding them once per occurrence. + +#include <stddef.h> +#include <string.h> +#include <sys/types.h> + +#include <tuple> +#include <vector> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" + +namespace jxl { + +enum class PatchBlendMode : uint8_t { + // The new values are the old ones. Useful to skip some channels. + kNone = 0, + // The new values (in the crop) replace the old ones: sample = new + kReplace = 1, + // The new values (in the crop) get added to the old ones: sample = old + new + kAdd = 2, + // The new values (in the crop) get multiplied by the old ones: + // sample = old * new + // This blend mode is only supported if BlendColorSpace is kEncoded. The + // range of the new value matters for multiplication purposes, and its + // nominal range of 0..1 is computed the same way as this is done for the + // alpha values in kBlend and kAlphaWeightedAdd. + kMul = 3, + // The new values (in the crop) replace the old ones if alpha>0: + // For first alpha channel: + // alpha = old + new * (1 - old) + // For other channels if !alpha_associated: + // sample = ((1 - new_alpha) * old * old_alpha + new_alpha * new) / alpha + // For other channels if alpha_associated: + // sample = (1 - new_alpha) * old + new + // The alpha formula applies to the alpha used for the division in the other + // channels formula, and applies to the alpha channel itself if its + // blend_channel value matches itself. + // If using kBlendAbove, new is the patch and old is the original image; if + // using kBlendBelow, the meaning is inverted. + kBlendAbove = 4, + kBlendBelow = 5, + // The new values (in the crop) are added to the old ones if alpha>0: + // For first alpha channel: sample = sample = old + new * (1 - old) + // For other channels: sample = old + alpha * new + kAlphaWeightedAddAbove = 6, + kAlphaWeightedAddBelow = 7, + kNumBlendModes, +}; + +inline bool UsesAlpha(PatchBlendMode mode) { + return mode == PatchBlendMode::kBlendAbove || + mode == PatchBlendMode::kBlendBelow || + mode == PatchBlendMode::kAlphaWeightedAddAbove || + mode == PatchBlendMode::kAlphaWeightedAddBelow; +} +inline bool UsesClamp(PatchBlendMode mode) { + return UsesAlpha(mode) || mode == PatchBlendMode::kMul; +} + +struct PatchBlending { + PatchBlendMode mode; + uint32_t alpha_channel; + bool clamp; +}; + +// Position and size of the patch in the reference frame. +struct PatchReferencePosition { + size_t ref, x0, y0, xsize, ysize; +}; + +struct PatchPosition { + // Position of top-left corner of the patch in the image. + size_t x, y; + size_t ref_pos_idx; +}; + +struct PassesSharedState; + +// Encoder-side helper class to encode the PatchesDictionary. +class PatchDictionaryEncoder; + +class PatchDictionary { + public: + PatchDictionary() = default; + + void SetPassesSharedState(const PassesSharedState* shared) { + shared_ = shared; + } + + bool HasAny() const { return !positions_.empty(); } + + Status Decode(BitReader* br, size_t xsize, size_t ysize, + bool* uses_extra_channels); + + void Clear() { + positions_.clear(); + ComputePatchTree(); + } + + // Adds patches to a segment of `xsize` pixels, starting at `inout`, assumed + // to be located at position (x0, y) in the frame. + void AddOneRow(float* const* inout, size_t y, size_t x0, size_t xsize) const; + + // Returns dependencies of this patch dictionary on reference frame ids as a + // bit mask: bits 0-3 indicate reference frame 0-3. + int GetReferences() const; + + std::vector<size_t> GetPatchesForRow(size_t y) const; + + private: + friend class PatchDictionaryEncoder; + + const PassesSharedState* shared_; + std::vector<PatchPosition> positions_; + std::vector<PatchReferencePosition> ref_positions_; + std::vector<PatchBlending> blendings_; + + // Interval tree on the y coordinates of the patches. + struct PatchTreeNode { + ssize_t left_child; + ssize_t right_child; + size_t y_center; + // Range of patches in sorted_patches_y0_ and sorted_patches_y1_ that + // contain the row y_center. + size_t start; + size_t num; + }; + std::vector<PatchTreeNode> patch_tree_; + // Number of patches for each row. + std::vector<size_t> num_patches_; + std::vector<std::pair<size_t, size_t>> sorted_patches_y0_; + std::vector<std::pair<size_t, size_t>> sorted_patches_y1_; + + void ComputePatchTree(); +}; + +} // namespace jxl + +#endif // LIB_JXL_DEC_PATCH_DICTIONARY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_transforms-inl.h b/third_party/jpeg-xl/lib/jxl/dec_transforms-inl.h new file mode 100644 index 0000000000..9c90550625 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_transforms-inl.h @@ -0,0 +1,827 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_DEC_TRANSFORMS_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DEC_TRANSFORMS_INL_H_ +#undef LIB_JXL_DEC_TRANSFORMS_INL_H_ +#else +#define LIB_JXL_DEC_TRANSFORMS_INL_H_ +#endif + +#include <stddef.h> + +#include <hwy/highway.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/dct_scales.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::MulAdd; + +// Computes the lowest-frequency LF_ROWSxLF_COLS-sized square in output, which +// is a DCT_ROWS*DCT_COLS-sized DCT block, by doing a ROWS*COLS DCT on the +// input block. +template <size_t DCT_ROWS, size_t DCT_COLS, size_t LF_ROWS, size_t LF_COLS, + size_t ROWS, size_t COLS> +JXL_INLINE void ReinterpretingDCT(const float* input, const size_t input_stride, + float* output, const size_t output_stride, + float* JXL_RESTRICT block, + float* JXL_RESTRICT scratch_space) { + static_assert(LF_ROWS == ROWS, + "ReinterpretingDCT should only be called with LF == N"); + static_assert(LF_COLS == COLS, + "ReinterpretingDCT should only be called with LF == N"); + ComputeScaledDCT<ROWS, COLS>()(DCTFrom(input, input_stride), block, + scratch_space); + if (ROWS < COLS) { + for (size_t y = 0; y < LF_ROWS; y++) { + for (size_t x = 0; x < LF_COLS; x++) { + output[y * output_stride + x] = + block[y * COLS + x] * DCTTotalResampleScale<ROWS, DCT_ROWS>(y) * + DCTTotalResampleScale<COLS, DCT_COLS>(x); + } + } + } else { + for (size_t y = 0; y < LF_COLS; y++) { + for (size_t x = 0; x < LF_ROWS; x++) { + output[y * output_stride + x] = + block[y * ROWS + x] * DCTTotalResampleScale<COLS, DCT_COLS>(y) * + DCTTotalResampleScale<ROWS, DCT_ROWS>(x); + } + } + } +} + +template <size_t S> +void IDCT2TopBlock(const float* block, size_t stride_out, float* out) { + static_assert(kBlockDim % S == 0, "S should be a divisor of kBlockDim"); + static_assert(S % 2 == 0, "S should be even"); + float temp[kDCTBlockSize]; + constexpr size_t num_2x2 = S / 2; + for (size_t y = 0; y < num_2x2; y++) { + for (size_t x = 0; x < num_2x2; x++) { + float c00 = block[y * kBlockDim + x]; + float c01 = block[y * kBlockDim + num_2x2 + x]; + float c10 = block[(y + num_2x2) * kBlockDim + x]; + float c11 = block[(y + num_2x2) * kBlockDim + num_2x2 + x]; + float r00 = c00 + c01 + c10 + c11; + float r01 = c00 + c01 - c10 - c11; + float r10 = c00 - c01 + c10 - c11; + float r11 = c00 - c01 - c10 + c11; + temp[y * 2 * kBlockDim + x * 2] = r00; + temp[y * 2 * kBlockDim + x * 2 + 1] = r01; + temp[(y * 2 + 1) * kBlockDim + x * 2] = r10; + temp[(y * 2 + 1) * kBlockDim + x * 2 + 1] = r11; + } + } + for (size_t y = 0; y < S; y++) { + for (size_t x = 0; x < S; x++) { + out[y * stride_out + x] = temp[y * kBlockDim + x]; + } + } +} + +void AFVIDCT4x4(const float* JXL_RESTRICT coeffs, float* JXL_RESTRICT pixels) { + HWY_ALIGN static constexpr float k4x4AFVBasis[16][16] = { + { + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.25, + }, + { + 0.876902929799142f, + 0.2206518106944235f, + -0.10140050393753763f, + -0.1014005039375375f, + 0.2206518106944236f, + -0.10140050393753777f, + -0.10140050393753772f, + -0.10140050393753763f, + -0.10140050393753758f, + -0.10140050393753769f, + -0.1014005039375375f, + -0.10140050393753768f, + -0.10140050393753768f, + -0.10140050393753759f, + -0.10140050393753763f, + -0.10140050393753741f, + }, + { + 0.0, + 0.0, + 0.40670075830260755f, + 0.44444816619734445f, + 0.0, + 0.0, + 0.19574399372042936f, + 0.2929100136981264f, + -0.40670075830260716f, + -0.19574399372042872f, + 0.0, + 0.11379074460448091f, + -0.44444816619734384f, + -0.29291001369812636f, + -0.1137907446044814f, + 0.0, + }, + { + 0.0, + 0.0, + -0.21255748058288748f, + 0.3085497062849767f, + 0.0, + 0.4706702258572536f, + -0.1621205195722993f, + 0.0, + -0.21255748058287047f, + -0.16212051957228327f, + -0.47067022585725277f, + -0.1464291867126764f, + 0.3085497062849487f, + 0.0, + -0.14642918671266536f, + 0.4251149611657548f, + }, + { + 0.0, + -0.7071067811865474f, + 0.0, + 0.0, + 0.7071067811865476f, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + }, + { + -0.4105377591765233f, + 0.6235485373547691f, + -0.06435071657946274f, + -0.06435071657946266f, + 0.6235485373547694f, + -0.06435071657946284f, + -0.0643507165794628f, + -0.06435071657946274f, + -0.06435071657946272f, + -0.06435071657946279f, + -0.06435071657946266f, + -0.06435071657946277f, + -0.06435071657946277f, + -0.06435071657946273f, + -0.06435071657946274f, + -0.0643507165794626f, + }, + { + 0.0, + 0.0, + -0.4517556589999482f, + 0.15854503551840063f, + 0.0, + -0.04038515160822202f, + 0.0074182263792423875f, + 0.39351034269210167f, + -0.45175565899994635f, + 0.007418226379244351f, + 0.1107416575309343f, + 0.08298163094882051f, + 0.15854503551839705f, + 0.3935103426921022f, + 0.0829816309488214f, + -0.45175565899994796f, + }, + { + 0.0, + 0.0, + -0.304684750724869f, + 0.5112616136591823f, + 0.0, + 0.0, + -0.290480129728998f, + -0.06578701549142804f, + 0.304684750724884f, + 0.2904801297290076f, + 0.0, + -0.23889773523344604f, + -0.5112616136592012f, + 0.06578701549142545f, + 0.23889773523345467f, + 0.0, + }, + { + 0.0, + 0.0, + 0.3017929516615495f, + 0.25792362796341184f, + 0.0, + 0.16272340142866204f, + 0.09520022653475037f, + 0.0, + 0.3017929516615503f, + 0.09520022653475055f, + -0.16272340142866173f, + -0.35312385449816297f, + 0.25792362796341295f, + 0.0, + -0.3531238544981624f, + -0.6035859033230976f, + }, + { + 0.0, + 0.0, + 0.40824829046386274f, + 0.0, + 0.0, + 0.0, + 0.0, + -0.4082482904638628f, + -0.4082482904638635f, + 0.0, + 0.0, + -0.40824829046386296f, + 0.0, + 0.4082482904638634f, + 0.408248290463863f, + 0.0, + }, + { + 0.0, + 0.0, + 0.1747866975480809f, + 0.0812611176717539f, + 0.0, + 0.0, + -0.3675398009862027f, + -0.307882213957909f, + -0.17478669754808135f, + 0.3675398009862011f, + 0.0, + 0.4826689115059883f, + -0.08126111767175039f, + 0.30788221395790305f, + -0.48266891150598584f, + 0.0, + }, + { + 0.0, + 0.0, + -0.21105601049335784f, + 0.18567180916109802f, + 0.0, + 0.0, + 0.49215859013738733f, + -0.38525013709251915f, + 0.21105601049335806f, + -0.49215859013738905f, + 0.0, + 0.17419412659916217f, + -0.18567180916109904f, + 0.3852501370925211f, + -0.1741941265991621f, + 0.0, + }, + { + 0.0, + 0.0, + -0.14266084808807264f, + -0.3416446842253372f, + 0.0, + 0.7367497537172237f, + 0.24627107722075148f, + -0.08574019035519306f, + -0.14266084808807344f, + 0.24627107722075137f, + 0.14883399227113567f, + -0.04768680350229251f, + -0.3416446842253373f, + -0.08574019035519267f, + -0.047686803502292804f, + -0.14266084808807242f, + }, + { + 0.0, + 0.0, + -0.13813540350758585f, + 0.3302282550303788f, + 0.0, + 0.08755115000587084f, + -0.07946706605909573f, + -0.4613374887461511f, + -0.13813540350758294f, + -0.07946706605910261f, + 0.49724647109535086f, + 0.12538059448563663f, + 0.3302282550303805f, + -0.4613374887461554f, + 0.12538059448564315f, + -0.13813540350758452f, + }, + { + 0.0, + 0.0, + -0.17437602599651067f, + 0.0702790691196284f, + 0.0, + -0.2921026642334881f, + 0.3623817333531167f, + 0.0, + -0.1743760259965108f, + 0.36238173335311646f, + 0.29210266423348785f, + -0.4326608024727445f, + 0.07027906911962818f, + 0.0, + -0.4326608024727457f, + 0.34875205199302267f, + }, + { + 0.0, + 0.0, + 0.11354987314994337f, + -0.07417504595810355f, + 0.0, + 0.19402893032594343f, + -0.435190496523228f, + 0.21918684838857466f, + 0.11354987314994257f, + -0.4351904965232251f, + 0.5550443808910661f, + -0.25468277124066463f, + -0.07417504595810233f, + 0.2191868483885728f, + -0.25468277124066413f, + 0.1135498731499429f, + }, + }; + + const HWY_CAPPED(float, 16) d; + for (size_t i = 0; i < 16; i += Lanes(d)) { + auto pixel = Zero(d); + for (size_t j = 0; j < 16; j++) { + auto cf = Set(d, coeffs[j]); + auto basis = Load(d, k4x4AFVBasis[j] + i); + pixel = MulAdd(cf, basis, pixel); + } + Store(pixel, d, pixels + i); + } +} + +template <size_t afv_kind> +void AFVTransformToPixels(const float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT pixels, size_t pixels_stride) { + HWY_ALIGN float scratch_space[4 * 8 * 4]; + size_t afv_x = afv_kind & 1; + size_t afv_y = afv_kind / 2; + float dcs[3] = {}; + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + dcs[0] = (block00 + block10 + block01) * 4.0f; + dcs[1] = (block00 + block10 - block01); + dcs[2] = block00 - block10; + // IAFV: (even, even) positions. + HWY_ALIGN float coeff[4 * 4]; + coeff[0] = dcs[0]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + coeff[iy * 4 + ix] = coefficients[iy * 2 * 8 + ix * 2]; + } + } + HWY_ALIGN float block[4 * 8]; + AFVIDCT4x4(coeff, block); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + pixels[(iy + afv_y * 4) * pixels_stride + afv_x * 4 + ix] = + block[(afv_y == 1 ? 3 - iy : iy) * 4 + (afv_x == 1 ? 3 - ix : ix)]; + } + } + // IDCT4x4 in (odd, even) positions. + block[0] = dcs[1]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 4 + ix] = coefficients[iy * 2 * 8 + ix * 2 + 1]; + } + } + ComputeScaledIDCT<4, 4>()( + block, + DCTTo(pixels + afv_y * 4 * pixels_stride + (afv_x == 1 ? 0 : 4), + pixels_stride), + scratch_space); + // IDCT4x8. + block[0] = dcs[2]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 8 + ix] = coefficients[(1 + iy * 2) * 8 + ix]; + } + } + ComputeScaledIDCT<4, 8>()( + block, + DCTTo(pixels + (afv_y == 1 ? 0 : 4) * pixels_stride, pixels_stride), + scratch_space); +} + +HWY_MAYBE_UNUSED void TransformToPixels(const AcStrategy::Type strategy, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT pixels, + size_t pixels_stride, + float* scratch_space) { + using Type = AcStrategy::Type; + switch (strategy) { + case Type::IDENTITY: { + float dcs[4] = {}; + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + dcs[0] = block00 + block01 + block10 + block11; + dcs[1] = block00 + block01 - block10 - block11; + dcs[2] = block00 - block01 + block10 - block11; + dcs[3] = block00 - block01 - block10 + block11; + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + float block_dc = dcs[y * 2 + x]; + float residual_sum = 0; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + residual_sum += coefficients[(y + iy * 2) * 8 + x + ix * 2]; + } + } + pixels[(4 * y + 1) * pixels_stride + 4 * x + 1] = + block_dc - residual_sum * (1.0f / 16); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 1 && iy == 1) continue; + pixels[(y * 4 + iy) * pixels_stride + x * 4 + ix] = + coefficients[(y + iy * 2) * 8 + x + ix * 2] + + pixels[(4 * y + 1) * pixels_stride + 4 * x + 1]; + } + } + pixels[y * 4 * pixels_stride + x * 4] = + coefficients[(y + 2) * 8 + x + 2] + + pixels[(4 * y + 1) * pixels_stride + 4 * x + 1]; + } + } + break; + } + case Type::DCT8X4: { + float dcs[2] = {}; + float block0 = coefficients[0]; + float block1 = coefficients[8]; + dcs[0] = block0 + block1; + dcs[1] = block0 - block1; + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 8]; + block[0] = dcs[x]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 8 + ix] = coefficients[(x + iy * 2) * 8 + ix]; + } + } + ComputeScaledIDCT<8, 4>()(block, DCTTo(pixels + x * 4, pixels_stride), + scratch_space); + } + break; + } + case Type::DCT4X8: { + float dcs[2] = {}; + float block0 = coefficients[0]; + float block1 = coefficients[8]; + dcs[0] = block0 + block1; + dcs[1] = block0 - block1; + for (size_t y = 0; y < 2; y++) { + HWY_ALIGN float block[4 * 8]; + block[0] = dcs[y]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 8 + ix] = coefficients[(y + iy * 2) * 8 + ix]; + } + } + ComputeScaledIDCT<4, 8>()( + block, DCTTo(pixels + y * 4 * pixels_stride, pixels_stride), + scratch_space); + } + break; + } + case Type::DCT4X4: { + float dcs[4] = {}; + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + dcs[0] = block00 + block01 + block10 + block11; + dcs[1] = block00 + block01 - block10 - block11; + dcs[2] = block00 - block01 + block10 - block11; + dcs[3] = block00 - block01 - block10 + block11; + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 4]; + block[0] = dcs[y * 2 + x]; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 0 && iy == 0) continue; + block[iy * 4 + ix] = coefficients[(y + iy * 2) * 8 + x + ix * 2]; + } + } + ComputeScaledIDCT<4, 4>()( + block, + DCTTo(pixels + y * 4 * pixels_stride + x * 4, pixels_stride), + scratch_space); + } + } + break; + } + case Type::DCT2X2: { + HWY_ALIGN float coeffs[kDCTBlockSize]; + memcpy(coeffs, coefficients, sizeof(float) * kDCTBlockSize); + IDCT2TopBlock<2>(coeffs, kBlockDim, coeffs); + IDCT2TopBlock<4>(coeffs, kBlockDim, coeffs); + IDCT2TopBlock<8>(coeffs, kBlockDim, coeffs); + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + pixels[y * pixels_stride + x] = coeffs[y * kBlockDim + x]; + } + } + break; + } + case Type::DCT16X16: { + ComputeScaledIDCT<16, 16>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT16X8: { + ComputeScaledIDCT<16, 8>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT8X16: { + ComputeScaledIDCT<8, 16>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X8: { + ComputeScaledIDCT<32, 8>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT8X32: { + ComputeScaledIDCT<8, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X16: { + ComputeScaledIDCT<32, 16>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT16X32: { + ComputeScaledIDCT<16, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X32: { + ComputeScaledIDCT<32, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT: { + ComputeScaledIDCT<8, 8>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::AFV0: { + AFVTransformToPixels<0>(coefficients, pixels, pixels_stride); + break; + } + case Type::AFV1: { + AFVTransformToPixels<1>(coefficients, pixels, pixels_stride); + break; + } + case Type::AFV2: { + AFVTransformToPixels<2>(coefficients, pixels, pixels_stride); + break; + } + case Type::AFV3: { + AFVTransformToPixels<3>(coefficients, pixels, pixels_stride); + break; + } + case Type::DCT64X32: { + ComputeScaledIDCT<64, 32>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT32X64: { + ComputeScaledIDCT<32, 64>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT64X64: { + ComputeScaledIDCT<64, 64>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT128X64: { + ComputeScaledIDCT<128, 64>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT64X128: { + ComputeScaledIDCT<64, 128>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT128X128: { + ComputeScaledIDCT<128, 128>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT256X128: { + ComputeScaledIDCT<256, 128>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT128X256: { + ComputeScaledIDCT<128, 256>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::DCT256X256: { + ComputeScaledIDCT<256, 256>()(coefficients, DCTTo(pixels, pixels_stride), + scratch_space); + break; + } + case Type::kNumValidStrategies: + JXL_UNREACHABLE("Invalid strategy"); + } +} + +HWY_MAYBE_UNUSED void LowestFrequenciesFromDC(const AcStrategy::Type strategy, + const float* dc, size_t dc_stride, + float* llf, + float* JXL_RESTRICT scratch) { + using Type = AcStrategy::Type; + HWY_ALIGN float warm_block[4 * 4]; + HWY_ALIGN float warm_scratch_space[4 * 4 * 4]; + switch (strategy) { + case Type::DCT16X8: { + ReinterpretingDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/kBlockDim, + /*LF_ROWS=*/2, /*LF_COLS=*/1, /*ROWS=*/2, /*COLS=*/1>( + dc, dc_stride, llf, 2 * kBlockDim, warm_block, warm_scratch_space); + break; + } + case Type::DCT8X16: { + ReinterpretingDCT</*DCT_ROWS=*/kBlockDim, /*DCT_COLS=*/2 * kBlockDim, + /*LF_ROWS=*/1, /*LF_COLS=*/2, /*ROWS=*/1, /*COLS=*/2>( + dc, dc_stride, llf, 2 * kBlockDim, warm_block, warm_scratch_space); + break; + } + case Type::DCT16X16: { + ReinterpretingDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/2 * kBlockDim, + /*LF_ROWS=*/2, /*LF_COLS=*/2, /*ROWS=*/2, /*COLS=*/2>( + dc, dc_stride, llf, 2 * kBlockDim, warm_block, warm_scratch_space); + break; + } + case Type::DCT32X8: { + ReinterpretingDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/kBlockDim, + /*LF_ROWS=*/4, /*LF_COLS=*/1, /*ROWS=*/4, /*COLS=*/1>( + dc, dc_stride, llf, 4 * kBlockDim, warm_block, warm_scratch_space); + break; + } + case Type::DCT8X32: { + ReinterpretingDCT</*DCT_ROWS=*/kBlockDim, /*DCT_COLS=*/4 * kBlockDim, + /*LF_ROWS=*/1, /*LF_COLS=*/4, /*ROWS=*/1, /*COLS=*/4>( + dc, dc_stride, llf, 4 * kBlockDim, warm_block, warm_scratch_space); + break; + } + case Type::DCT32X16: { + ReinterpretingDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/2 * kBlockDim, + /*LF_ROWS=*/4, /*LF_COLS=*/2, /*ROWS=*/4, /*COLS=*/2>( + dc, dc_stride, llf, 4 * kBlockDim, warm_block, warm_scratch_space); + break; + } + case Type::DCT16X32: { + ReinterpretingDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim, + /*LF_ROWS=*/2, /*LF_COLS=*/4, /*ROWS=*/2, /*COLS=*/4>( + dc, dc_stride, llf, 4 * kBlockDim, warm_block, warm_scratch_space); + break; + } + case Type::DCT32X32: { + ReinterpretingDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim, + /*LF_ROWS=*/4, /*LF_COLS=*/4, /*ROWS=*/4, /*COLS=*/4>( + dc, dc_stride, llf, 4 * kBlockDim, warm_block, warm_scratch_space); + break; + } + case Type::DCT64X32: { + ReinterpretingDCT</*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim, + /*LF_ROWS=*/8, /*LF_COLS=*/4, /*ROWS=*/8, /*COLS=*/4>( + dc, dc_stride, llf, 8 * kBlockDim, scratch, scratch + 8 * 4); + break; + } + case Type::DCT32X64: { + ReinterpretingDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim, + /*LF_ROWS=*/4, /*LF_COLS=*/8, /*ROWS=*/4, /*COLS=*/8>( + dc, dc_stride, llf, 8 * kBlockDim, scratch, scratch + 4 * 8); + break; + } + case Type::DCT64X64: { + ReinterpretingDCT</*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim, + /*LF_ROWS=*/8, /*LF_COLS=*/8, /*ROWS=*/8, /*COLS=*/8>( + dc, dc_stride, llf, 8 * kBlockDim, scratch, scratch + 8 * 8); + break; + } + case Type::DCT128X64: { + ReinterpretingDCT</*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/8, /*ROWS=*/16, /*COLS=*/8>( + dc, dc_stride, llf, 16 * kBlockDim, scratch, scratch + 16 * 8); + break; + } + case Type::DCT64X128: { + ReinterpretingDCT</*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/8, /*LF_COLS=*/16, /*ROWS=*/8, /*COLS=*/16>( + dc, dc_stride, llf, 16 * kBlockDim, scratch, scratch + 8 * 16); + break; + } + case Type::DCT128X128: { + ReinterpretingDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/16, /*ROWS=*/16, /*COLS=*/16>( + dc, dc_stride, llf, 16 * kBlockDim, scratch, scratch + 16 * 16); + break; + } + case Type::DCT256X128: { + ReinterpretingDCT< + /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/32, /*LF_COLS=*/16, /*ROWS=*/32, /*COLS=*/16>( + dc, dc_stride, llf, 32 * kBlockDim, scratch, scratch + 32 * 16); + break; + } + case Type::DCT128X256: { + ReinterpretingDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/32, /*ROWS=*/16, /*COLS=*/32>( + dc, dc_stride, llf, 32 * kBlockDim, scratch, scratch + 16 * 32); + break; + } + case Type::DCT256X256: { + ReinterpretingDCT< + /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim, + /*LF_ROWS=*/32, /*LF_COLS=*/32, /*ROWS=*/32, /*COLS=*/32>( + dc, dc_stride, llf, 32 * kBlockDim, scratch, scratch + 32 * 32); + break; + } + case Type::DCT: + case Type::DCT2X2: + case Type::DCT4X4: + case Type::DCT4X8: + case Type::DCT8X4: + case Type::AFV0: + case Type::AFV1: + case Type::AFV2: + case Type::AFV3: + case Type::IDENTITY: + llf[0] = dc[0]; + break; + case Type::kNumValidStrategies: + JXL_UNREACHABLE("Invalid strategy"); + }; +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_DEC_TRANSFORMS_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.cc b/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.cc new file mode 100644 index 0000000000..2d40740262 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.cc @@ -0,0 +1,42 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_transforms_testonly.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_transforms_testonly.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_transforms-inl.h" + +namespace jxl { + +#if HWY_ONCE +HWY_EXPORT(TransformToPixels); +void TransformToPixels(AcStrategy::Type strategy, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT pixels, size_t pixels_stride, + float* scratch_space) { + return HWY_DYNAMIC_DISPATCH(TransformToPixels)(strategy, coefficients, pixels, + pixels_stride, scratch_space); +} + +HWY_EXPORT(LowestFrequenciesFromDC); +void LowestFrequenciesFromDC(const jxl::AcStrategy::Type strategy, + const float* dc, size_t dc_stride, float* llf, + float* JXL_RESTRICT scratch) { + return HWY_DYNAMIC_DISPATCH(LowestFrequenciesFromDC)(strategy, dc, dc_stride, + llf, scratch); +} + +HWY_EXPORT(AFVIDCT4x4); +void AFVIDCT4x4(const float* JXL_RESTRICT coeffs, float* JXL_RESTRICT pixels) { + return HWY_DYNAMIC_DISPATCH(AFVIDCT4x4)(coeffs, pixels); +} +#endif // HWY_ONCE + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.h b/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.h new file mode 100644 index 0000000000..f68481fda9 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_transforms_testonly.h @@ -0,0 +1,33 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_TRANSFORMS_TESTONLY_H_ +#define LIB_JXL_DEC_TRANSFORMS_TESTONLY_H_ + +// Facade for (non-inlined) inverse integral transforms. + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +void TransformToPixels(AcStrategy::Type strategy, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT pixels, size_t pixels_stride, + float* JXL_RESTRICT scratch_space); + +// Equivalent of the above for DC image. +void LowestFrequenciesFromDC(const jxl::AcStrategy::Type strategy, + const float* dc, size_t dc_stride, float* llf, + float* JXL_RESTRICT scratch); + +void AFVIDCT4x4(const float* JXL_RESTRICT coeffs, float* JXL_RESTRICT pixels); + +} // namespace jxl + +#endif // LIB_JXL_DEC_TRANSFORMS_TESTONLY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_xyb-inl.h b/third_party/jpeg-xl/lib/jxl/dec_xyb-inl.h new file mode 100644 index 0000000000..495693b257 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_xyb-inl.h @@ -0,0 +1,346 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// XYB -> linear sRGB helper function. + +#if defined(LIB_JXL_DEC_XYB_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_DEC_XYB_INL_H_ +#undef LIB_JXL_DEC_XYB_INL_H_ +#else +#define LIB_JXL_DEC_XYB_INL_H_ +#endif + +#include <hwy/highway.h> + +#include "lib/jxl/dec_xyb.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Broadcast; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Sub; + +// Inverts the pixel-wise RGB->XYB conversion in OpsinDynamicsImage() (including +// the gamma mixing and simple gamma). Avoids clamping to [0, 1] - out of (sRGB) +// gamut values may be in-gamut after transforming to a wider space. +// "inverse_matrix" points to 9 broadcasted vectors, which are the 3x3 entries +// of the (row-major) opsin absorbance matrix inverse. Pre-multiplying its +// entries by c is equivalent to multiplying linear_* by c afterwards. +template <class D, class V> +HWY_INLINE HWY_MAYBE_UNUSED void XybToRgb(D d, const V opsin_x, const V opsin_y, + const V opsin_b, + const OpsinParams& opsin_params, + V* const HWY_RESTRICT linear_r, + V* const HWY_RESTRICT linear_g, + V* const HWY_RESTRICT linear_b) { +#if HWY_TARGET == HWY_SCALAR + const auto neg_bias_r = Set(d, opsin_params.opsin_biases[0]); + const auto neg_bias_g = Set(d, opsin_params.opsin_biases[1]); + const auto neg_bias_b = Set(d, opsin_params.opsin_biases[2]); +#else + const auto neg_bias_rgb = LoadDup128(d, opsin_params.opsin_biases); + const auto neg_bias_r = Broadcast<0>(neg_bias_rgb); + const auto neg_bias_g = Broadcast<1>(neg_bias_rgb); + const auto neg_bias_b = Broadcast<2>(neg_bias_rgb); +#endif + + // Color space: XYB -> RGB + auto gamma_r = Add(opsin_y, opsin_x); + auto gamma_g = Sub(opsin_y, opsin_x); + auto gamma_b = opsin_b; + + gamma_r = Sub(gamma_r, Set(d, opsin_params.opsin_biases_cbrt[0])); + gamma_g = Sub(gamma_g, Set(d, opsin_params.opsin_biases_cbrt[1])); + gamma_b = Sub(gamma_b, Set(d, opsin_params.opsin_biases_cbrt[2])); + + // Undo gamma compression: linear = gamma^3 for efficiency. + const auto gamma_r2 = Mul(gamma_r, gamma_r); + const auto gamma_g2 = Mul(gamma_g, gamma_g); + const auto gamma_b2 = Mul(gamma_b, gamma_b); + const auto mixed_r = MulAdd(gamma_r2, gamma_r, neg_bias_r); + const auto mixed_g = MulAdd(gamma_g2, gamma_g, neg_bias_g); + const auto mixed_b = MulAdd(gamma_b2, gamma_b, neg_bias_b); + + const float* HWY_RESTRICT inverse_matrix = opsin_params.inverse_opsin_matrix; + + // Unmix (multiply by 3x3 inverse_matrix) + // TODO(eustas): ref would be more readable than pointer + *linear_r = Mul(LoadDup128(d, &inverse_matrix[0 * 4]), mixed_r); + *linear_g = Mul(LoadDup128(d, &inverse_matrix[3 * 4]), mixed_r); + *linear_b = Mul(LoadDup128(d, &inverse_matrix[6 * 4]), mixed_r); + *linear_r = MulAdd(LoadDup128(d, &inverse_matrix[1 * 4]), mixed_g, *linear_r); + *linear_g = MulAdd(LoadDup128(d, &inverse_matrix[4 * 4]), mixed_g, *linear_g); + *linear_b = MulAdd(LoadDup128(d, &inverse_matrix[7 * 4]), mixed_g, *linear_b); + *linear_r = MulAdd(LoadDup128(d, &inverse_matrix[2 * 4]), mixed_b, *linear_r); + *linear_g = MulAdd(LoadDup128(d, &inverse_matrix[5 * 4]), mixed_b, *linear_g); + *linear_b = MulAdd(LoadDup128(d, &inverse_matrix[8 * 4]), mixed_b, *linear_b); +} + +static inline HWY_MAYBE_UNUSED bool HasFastXYBTosRGB8() { +#if HWY_TARGET == HWY_NEON + return true; +#else + return false; +#endif +} + +static inline HWY_MAYBE_UNUSED void FastXYBTosRGB8(const float* input[4], + uint8_t* output, + bool is_rgba, size_t xsize) { + // This function is very NEON-specific. As such, it uses intrinsics directly. +#if HWY_TARGET == HWY_NEON + // WARNING: doing fixed point arithmetic correctly is very complicated. + // Changes to this function should be thoroughly tested. + + // Note that the input is assumed to have 13 bits of mantissa, and the output + // will have 14 bits. + auto srgb_tf = [&](int16x8_t v16) { + int16x8_t clz = vclzq_s16(v16); + // Convert to [0.25, 0.5) range. + int16x8_t v025_05_16 = vqshlq_s16(v16, vqsubq_s16(clz, vdupq_n_s16(2))); + + // third degree polynomial approximation between 0.25 and 0.5 + // of 1.055/2^(7/2.4) * x^(1/2.4) / 32. + // poly ~ ((0.95x-1.75)*x+1.72)*x+0.29 + // We actually compute ~ ((0.47x-0.87)*x+0.86)*(2x)+0.29 as 1.75 and 1.72 + // overflow our fixed point representation. + + int16x8_t twov = vqaddq_s16(v025_05_16, v025_05_16); + + // 0.47 * x + int16x8_t step1 = vqrdmulhq_n_s16(v025_05_16, 15706); + // - 0.87 + int16x8_t step2 = vsubq_s16(step1, vdupq_n_s16(28546)); + // * x + int16x8_t step3 = vqrdmulhq_s16(step2, v025_05_16); + // + 0.86 + int16x8_t step4 = vaddq_s16(step3, vdupq_n_s16(28302)); + // * 2x + int16x8_t step5 = vqrdmulhq_s16(step4, twov); + // + 0.29 + int16x8_t mul16 = vaddq_s16(step5, vdupq_n_s16(9485)); + + int16x8_t exp16 = vsubq_s16(vdupq_n_s16(11), clz); + // Compute 2**(1/2.4*exp16)/32. Values of exp16 that would overflow are + // capped to 1. + // Generated with the following Python script: + // a = [] + // b = [] + // + // for i in range(0, 16): + // v = 2**(5/12.*i) + // v /= 16 + // v *= 256 * 128 + // v = int(v) + // a.append(v // 256) + // b.append(v % 256) + // + // print(", ".join("0x%02x" % x for x in a)) + // + // print(", ".join("0x%02x" % x for x in b)) + + HWY_ALIGN constexpr uint8_t k2to512powersm1div32_high[16] = { + 0x08, 0x0a, 0x0e, 0x13, 0x19, 0x21, 0x2d, 0x3c, + 0x50, 0x6b, 0x8f, 0x8f, 0x8f, 0x8f, 0x8f, 0x8f, + }; + HWY_ALIGN constexpr uint8_t k2to512powersm1div32_low[16] = { + 0x00, 0xad, 0x41, 0x06, 0x65, 0xe7, 0x41, 0x68, + 0xa2, 0xa2, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + }; + // Using the highway implementation here since vqtbl1q is aarch64-only. + using hwy::HWY_NAMESPACE::Vec128; + uint8x16_t pow_low = + TableLookupBytes( + Vec128<uint8_t, 16>(vld1q_u8(k2to512powersm1div32_low)), + Vec128<uint8_t, 16>(vreinterpretq_u8_s16(exp16))) + .raw; + uint8x16_t pow_high = + TableLookupBytes( + Vec128<uint8_t, 16>(vld1q_u8(k2to512powersm1div32_high)), + Vec128<uint8_t, 16>(vreinterpretq_u8_s16(exp16))) + .raw; + int16x8_t pow16 = vreinterpretq_s16_u16(vsliq_n_u16( + vreinterpretq_u16_u8(pow_low), vreinterpretq_u16_u8(pow_high), 8)); + + // approximation of v * 12.92, divided by 2 + // Note that our input is using 13 mantissa bits instead of 15. + int16x8_t v16_linear = vrshrq_n_s16(vmulq_n_s16(v16, 826), 5); + // 1.055*pow(v, 1/2.4) - 0.055, divided by 2 + auto v16_pow = vsubq_s16(vqrdmulhq_s16(mul16, pow16), vdupq_n_s16(901)); + // > 0.0031308f (note that v16 has 13 mantissa bits) + return vbslq_s16(vcgeq_s16(v16, vdupq_n_s16(26)), v16_pow, v16_linear); + }; + + const float* JXL_RESTRICT row_in_x = input[0]; + const float* JXL_RESTRICT row_in_y = input[1]; + const float* JXL_RESTRICT row_in_b = input[2]; + const float* JXL_RESTRICT row_in_a = input[3]; + for (size_t x = 0; x < xsize; x += 8) { + // Normal ranges for xyb for in-gamut sRGB colors: + // x: -0.015386 0.028100 + // y: 0.000000 0.845308 + // b: 0.000000 0.845308 + + // We actually want x * 8 to have some extra precision. + // TODO(veluca): consider different approaches here, like vld1q_f32_x2. + float32x4_t opsin_x_left = vld1q_f32(row_in_x + x); + int16x4_t opsin_x16_times8_left = + vqmovn_s32(vcvtq_n_s32_f32(opsin_x_left, 18)); + float32x4_t opsin_x_right = + vld1q_f32(row_in_x + x + (x + 4 < xsize ? 4 : 0)); + int16x4_t opsin_x16_times8_right = + vqmovn_s32(vcvtq_n_s32_f32(opsin_x_right, 18)); + int16x8_t opsin_x16_times8 = + vcombine_s16(opsin_x16_times8_left, opsin_x16_times8_right); + + float32x4_t opsin_y_left = vld1q_f32(row_in_y + x); + int16x4_t opsin_y16_left = vqmovn_s32(vcvtq_n_s32_f32(opsin_y_left, 15)); + float32x4_t opsin_y_right = + vld1q_f32(row_in_y + x + (x + 4 < xsize ? 4 : 0)); + int16x4_t opsin_y16_right = vqmovn_s32(vcvtq_n_s32_f32(opsin_y_right, 15)); + int16x8_t opsin_y16 = vcombine_s16(opsin_y16_left, opsin_y16_right); + + float32x4_t opsin_b_left = vld1q_f32(row_in_b + x); + int16x4_t opsin_b16_left = vqmovn_s32(vcvtq_n_s32_f32(opsin_b_left, 15)); + float32x4_t opsin_b_right = + vld1q_f32(row_in_b + x + (x + 4 < xsize ? 4 : 0)); + int16x4_t opsin_b16_right = vqmovn_s32(vcvtq_n_s32_f32(opsin_b_right, 15)); + int16x8_t opsin_b16 = vcombine_s16(opsin_b16_left, opsin_b16_right); + + int16x8_t neg_bias16 = vdupq_n_s16(-124); // -0.0037930732552754493 + int16x8_t neg_bias_cbrt16 = vdupq_n_s16(-5110); // -0.155954201 + int16x8_t neg_bias_half16 = vdupq_n_s16(-62); + + // Color space: XYB -> RGB + // Compute ((y+x-bias_cbrt)^3-(y-x-bias_cbrt)^3)/2, + // ((y+x-bias_cbrt)^3+(y-x-bias_cbrt)^3)/2+bias, (b-bias_cbrt)^3+bias. + // Note that ignoring x2 in the formulas below (as x << y) results in + // errors of at least 3 in the final sRGB values. + int16x8_t opsin_yp16 = vqsubq_s16(opsin_y16, neg_bias_cbrt16); + int16x8_t ysq16 = vqrdmulhq_s16(opsin_yp16, opsin_yp16); + int16x8_t twentyfourx16 = vmulq_n_s16(opsin_x16_times8, 3); + int16x8_t twentyfourxy16 = vqrdmulhq_s16(opsin_yp16, twentyfourx16); + int16x8_t threexsq16 = + vrshrq_n_s16(vqrdmulhq_s16(opsin_x16_times8, twentyfourx16), 6); + + // We can ignore x^3 here. Note that this is multiplied by 8. + int16x8_t mixed_rmg16 = vqrdmulhq_s16(twentyfourxy16, opsin_yp16); + + int16x8_t mixed_rpg_sos_half = vhaddq_s16(ysq16, threexsq16); + int16x8_t mixed_rpg16 = vhaddq_s16( + vqrdmulhq_s16(opsin_yp16, mixed_rpg_sos_half), neg_bias_half16); + + int16x8_t gamma_b16 = vqsubq_s16(opsin_b16, neg_bias_cbrt16); + int16x8_t gamma_bsq16 = vqrdmulhq_s16(gamma_b16, gamma_b16); + int16x8_t gamma_bcb16 = vqrdmulhq_s16(gamma_bsq16, gamma_b16); + int16x8_t mixed_b16 = vqaddq_s16(gamma_bcb16, neg_bias16); + // mixed_rpg and mixed_b are in 0-1 range. + // mixed_rmg has a smaller range (-0.035 to 0.035 for valid sRGB). Note + // that at this point it is already multiplied by 8. + + // We multiply all the mixed values by 1/4 (i.e. shift them to 13-bit + // fixed point) to ensure intermediate quantities are in range. Note that + // r-g is not shifted, and was x8 before here; this corresponds to a x32 + // overall multiplicative factor and ensures that all the matrix constants + // are in 0-1 range. + // Similarly, mixed_rpg16 is already multiplied by 1/4 because of the two + // vhadd + using neg_bias_half. + mixed_b16 = vshrq_n_s16(mixed_b16, 2); + + // Unmix (multiply by 3x3 inverse_matrix) + // For increased precision, we use a matrix for converting from + // ((mixed_r - mixed_g)/2, (mixed_r + mixed_g)/2, mixed_b) to rgb. This + // avoids cancellation effects when computing (y+x)^3-(y-x)^3. + // We compute mixed_rpg - mixed_b because the (1+c)*mixed_rpg - c * + // mixed_b pattern is repeated frequently in the code below. This allows + // us to save a multiply per channel, and removes the presence of + // some constants above 1. Moreover, mixed_rmg - mixed_b is in (-1, 1) + // range, so the subtraction is safe. + // All the magic-looking constants here are derived by computing the + // inverse opsin matrix for the transformation modified as described + // above. + + // Precomputation common to multiple color values. + int16x8_t mixed_rpgmb16 = vqsubq_s16(mixed_rpg16, mixed_b16); + int16x8_t mixed_rpgmb_times_016 = vqrdmulhq_n_s16(mixed_rpgmb16, 5394); + int16x8_t mixed_rg16 = vqaddq_s16(mixed_rpgmb_times_016, mixed_rpg16); + + // R + int16x8_t linear_r16 = + vqaddq_s16(mixed_rg16, vqrdmulhq_n_s16(mixed_rmg16, 21400)); + + // G + int16x8_t linear_g16 = + vqaddq_s16(mixed_rg16, vqrdmulhq_n_s16(mixed_rmg16, -7857)); + + // B + int16x8_t linear_b16 = vqrdmulhq_n_s16(mixed_rpgmb16, -30996); + linear_b16 = vqaddq_s16(linear_b16, mixed_b16); + linear_b16 = vqaddq_s16(linear_b16, vqrdmulhq_n_s16(mixed_rmg16, -6525)); + + // Apply SRGB transfer function. + int16x8_t r = srgb_tf(linear_r16); + int16x8_t g = srgb_tf(linear_g16); + int16x8_t b = srgb_tf(linear_b16); + + uint8x8_t r8 = + vqmovun_s16(vrshrq_n_s16(vsubq_s16(r, vshrq_n_s16(r, 8)), 6)); + uint8x8_t g8 = + vqmovun_s16(vrshrq_n_s16(vsubq_s16(g, vshrq_n_s16(g, 8)), 6)); + uint8x8_t b8 = + vqmovun_s16(vrshrq_n_s16(vsubq_s16(b, vshrq_n_s16(b, 8)), 6)); + + size_t n = xsize - x; + if (is_rgba) { + float32x4_t a_f32_left = + row_in_a ? vld1q_f32(row_in_a + x) : vdupq_n_f32(1.0f); + float32x4_t a_f32_right = + row_in_a ? vld1q_f32(row_in_a + x + (x + 4 < xsize ? 4 : 0)) + : vdupq_n_f32(1.0f); + int16x4_t a16_left = vqmovn_s32(vcvtq_n_s32_f32(a_f32_left, 8)); + int16x4_t a16_right = vqmovn_s32(vcvtq_n_s32_f32(a_f32_right, 8)); + uint8x8_t a8 = vqmovun_s16(vcombine_s16(a16_left, a16_right)); + uint8_t* buf = output + 4 * x; + uint8x8x4_t data = {r8, g8, b8, a8}; + if (n >= 8) { + vst4_u8(buf, data); + } else { + uint8_t tmp[8 * 4]; + vst4_u8(tmp, data); + memcpy(buf, tmp, n * 4); + } + } else { + uint8_t* buf = output + 3 * x; + uint8x8x3_t data = {r8, g8, b8}; + if (n >= 8) { + vst3_u8(buf, data); + } else { + uint8_t tmp[8 * 3]; + vst3_u8(tmp, data); + memcpy(buf, tmp, n * 3); + } + } + } +#else + (void)input; + (void)output; + (void)is_rgba; + (void)xsize; + JXL_UNREACHABLE("Unreachable"); +#endif +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_DEC_XYB_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/dec_xyb.cc b/third_party/jpeg-xl/lib/jxl/dec_xyb.cc new file mode 100644 index 0000000000..7010f0d813 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_xyb.cc @@ -0,0 +1,328 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/dec_xyb.h" + +#include <string.h> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/dec_xyb.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/matrix_ops.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/cms/jxl_cms_internal.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_group_border.h" +#include "lib/jxl/dec_xyb-inl.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/sanitizers.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::MulAdd; + +void OpsinToLinearInplace(Image3F* JXL_RESTRICT inout, ThreadPool* pool, + const OpsinParams& opsin_params) { + JXL_CHECK_IMAGE_INITIALIZED(*inout, Rect(*inout)); + + const size_t xsize = inout->xsize(); // not padded + JXL_CHECK(RunOnPool( + pool, 0, inout->ysize(), ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + + // Faster than adding via ByteOffset at end of loop. + float* JXL_RESTRICT row0 = inout->PlaneRow(0, y); + float* JXL_RESTRICT row1 = inout->PlaneRow(1, y); + float* JXL_RESTRICT row2 = inout->PlaneRow(2, y); + + const HWY_FULL(float) d; + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto in_opsin_x = Load(d, row0 + x); + const auto in_opsin_y = Load(d, row1 + x); + const auto in_opsin_b = Load(d, row2 + x); + auto linear_r = Undefined(d); + auto linear_g = Undefined(d); + auto linear_b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, opsin_params, + &linear_r, &linear_g, &linear_b); + + Store(linear_r, d, row0 + x); + Store(linear_g, d, row1 + x); + Store(linear_b, d, row2 + x); + } + }, + "OpsinToLinear")); +} + +// Same, but not in-place. +void OpsinToLinear(const Image3F& opsin, const Rect& rect, ThreadPool* pool, + Image3F* JXL_RESTRICT linear, + const OpsinParams& opsin_params) { + JXL_ASSERT(SameSize(rect, *linear)); + JXL_CHECK_IMAGE_INITIALIZED(opsin, rect); + + JXL_CHECK(RunOnPool( + pool, 0, static_cast<int>(rect.ysize()), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = static_cast<size_t>(task); + + // Faster than adding via ByteOffset at end of loop. + const float* JXL_RESTRICT row_opsin_0 = rect.ConstPlaneRow(opsin, 0, y); + const float* JXL_RESTRICT row_opsin_1 = rect.ConstPlaneRow(opsin, 1, y); + const float* JXL_RESTRICT row_opsin_2 = rect.ConstPlaneRow(opsin, 2, y); + float* JXL_RESTRICT row_linear_0 = linear->PlaneRow(0, y); + float* JXL_RESTRICT row_linear_1 = linear->PlaneRow(1, y); + float* JXL_RESTRICT row_linear_2 = linear->PlaneRow(2, y); + + const HWY_FULL(float) d; + + for (size_t x = 0; x < rect.xsize(); x += Lanes(d)) { + const auto in_opsin_x = Load(d, row_opsin_0 + x); + const auto in_opsin_y = Load(d, row_opsin_1 + x); + const auto in_opsin_b = Load(d, row_opsin_2 + x); + auto linear_r = Undefined(d); + auto linear_g = Undefined(d); + auto linear_b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, opsin_params, + &linear_r, &linear_g, &linear_b); + + Store(linear_r, d, row_linear_0 + x); + Store(linear_g, d, row_linear_1 + x); + Store(linear_b, d, row_linear_2 + x); + } + }, + "OpsinToLinear(Rect)")); + JXL_CHECK_IMAGE_INITIALIZED(*linear, rect); +} + +// Transform YCbCr to RGB. +// Could be performed in-place (i.e. Y, Cb and Cr could alias R, B and B). +void YcbcrToRgb(const Image3F& ycbcr, Image3F* rgb, const Rect& rect) { + JXL_CHECK_IMAGE_INITIALIZED(ycbcr, rect); + const HWY_CAPPED(float, kBlockDim) df; + const size_t S = Lanes(df); // Step. + + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + if ((xsize == 0) || (ysize == 0)) return; + + // Full-range BT.601 as defined by JFIF Clause 7: + // https://www.itu.int/rec/T-REC-T.871-201105-I/en + const auto c128 = Set(df, 128.0f / 255); + const auto crcr = Set(df, 1.402f); + const auto cgcb = Set(df, -0.114f * 1.772f / 0.587f); + const auto cgcr = Set(df, -0.299f * 1.402f / 0.587f); + const auto cbcb = Set(df, 1.772f); + + for (size_t y = 0; y < ysize; y++) { + const float* y_row = rect.ConstPlaneRow(ycbcr, 1, y); + const float* cb_row = rect.ConstPlaneRow(ycbcr, 0, y); + const float* cr_row = rect.ConstPlaneRow(ycbcr, 2, y); + float* r_row = rect.PlaneRow(rgb, 0, y); + float* g_row = rect.PlaneRow(rgb, 1, y); + float* b_row = rect.PlaneRow(rgb, 2, y); + for (size_t x = 0; x < xsize; x += S) { + const auto y_vec = Add(Load(df, y_row + x), c128); + const auto cb_vec = Load(df, cb_row + x); + const auto cr_vec = Load(df, cr_row + x); + const auto r_vec = MulAdd(crcr, cr_vec, y_vec); + const auto g_vec = MulAdd(cgcr, cr_vec, MulAdd(cgcb, cb_vec, y_vec)); + const auto b_vec = MulAdd(cbcb, cb_vec, y_vec); + Store(r_vec, df, r_row + x); + Store(g_vec, df, g_row + x); + Store(b_vec, df, b_row + x); + } + } + JXL_CHECK_IMAGE_INITIALIZED(*rgb, rect); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(OpsinToLinearInplace); +void OpsinToLinearInplace(Image3F* JXL_RESTRICT inout, ThreadPool* pool, + const OpsinParams& opsin_params) { + return HWY_DYNAMIC_DISPATCH(OpsinToLinearInplace)(inout, pool, opsin_params); +} + +HWY_EXPORT(OpsinToLinear); +void OpsinToLinear(const Image3F& opsin, const Rect& rect, ThreadPool* pool, + Image3F* JXL_RESTRICT linear, + const OpsinParams& opsin_params) { + return HWY_DYNAMIC_DISPATCH(OpsinToLinear)(opsin, rect, pool, linear, + opsin_params); +} + +HWY_EXPORT(YcbcrToRgb); +void YcbcrToRgb(const Image3F& ycbcr, Image3F* rgb, const Rect& rect) { + return HWY_DYNAMIC_DISPATCH(YcbcrToRgb)(ycbcr, rgb, rect); +} + +HWY_EXPORT(HasFastXYBTosRGB8); +bool HasFastXYBTosRGB8() { return HWY_DYNAMIC_DISPATCH(HasFastXYBTosRGB8)(); } + +HWY_EXPORT(FastXYBTosRGB8); +void FastXYBTosRGB8(const float* input[4], uint8_t* output, bool is_rgba, + size_t xsize) { + return HWY_DYNAMIC_DISPATCH(FastXYBTosRGB8)(input, output, is_rgba, xsize); +} + +void OpsinParams::Init(float intensity_target) { + InitSIMDInverseMatrix(GetOpsinAbsorbanceInverseMatrix(), inverse_opsin_matrix, + intensity_target); + memcpy(opsin_biases, jxl::cms::kNegOpsinAbsorbanceBiasRGB.data(), + sizeof(jxl::cms::kNegOpsinAbsorbanceBiasRGB)); + memcpy(quant_biases, kDefaultQuantBias, sizeof(kDefaultQuantBias)); + for (size_t c = 0; c < 4; c++) { + opsin_biases_cbrt[c] = cbrtf(opsin_biases[c]); + } +} + +bool CanOutputToColorEncoding(const ColorEncoding& c_desired) { + if (!c_desired.HaveFields()) { + return false; + } + // TODO(veluca): keep in sync with dec_reconstruct.cc + const auto& tf = c_desired.Tf(); + if (!tf.IsPQ() && !tf.IsSRGB() && !tf.have_gamma && !tf.IsLinear() && + !tf.IsHLG() && !tf.IsDCI() && !tf.Is709()) { + return false; + } + if (c_desired.IsGray() && c_desired.GetWhitePointType() != WhitePoint::kD65) { + // TODO(veluca): figure out what should happen here. + return false; + } + return true; +} + +Status OutputEncodingInfo::SetFromMetadata(const CodecMetadata& metadata) { + orig_color_encoding = metadata.m.color_encoding; + orig_intensity_target = metadata.m.IntensityTarget(); + desired_intensity_target = orig_intensity_target; + const auto& im = metadata.transform_data.opsin_inverse_matrix; + memcpy(orig_inverse_matrix, im.inverse_matrix, sizeof(orig_inverse_matrix)); + default_transform = im.all_default; + xyb_encoded = metadata.m.xyb_encoded; + std::copy(std::begin(im.opsin_biases), std::end(im.opsin_biases), + opsin_params.opsin_biases); + for (int i = 0; i < 3; ++i) { + opsin_params.opsin_biases_cbrt[i] = cbrtf(opsin_params.opsin_biases[i]); + } + opsin_params.opsin_biases_cbrt[3] = opsin_params.opsin_biases[3] = 1; + std::copy(std::begin(im.quant_biases), std::end(im.quant_biases), + opsin_params.quant_biases); + bool orig_ok = CanOutputToColorEncoding(orig_color_encoding); + bool orig_grey = orig_color_encoding.IsGray(); + return SetColorEncoding(!xyb_encoded || orig_ok + ? orig_color_encoding + : ColorEncoding::LinearSRGB(orig_grey)); +} + +Status OutputEncodingInfo::MaybeSetColorEncoding( + const ColorEncoding& c_desired) { + if (c_desired.GetColorSpace() == ColorSpace::kXYB && + ((color_encoding.GetColorSpace() == ColorSpace::kRGB && + color_encoding.GetPrimariesType() != Primaries::kSRGB) || + color_encoding.Tf().IsPQ())) { + return false; + } + if (!xyb_encoded && !CanOutputToColorEncoding(c_desired)) { + return false; + } + return SetColorEncoding(c_desired); +} + +Status OutputEncodingInfo::SetColorEncoding(const ColorEncoding& c_desired) { + color_encoding = c_desired; + linear_color_encoding = color_encoding; + linear_color_encoding.Tf().SetTransferFunction(TransferFunction::kLinear); + color_encoding_is_original = orig_color_encoding.SameColorEncoding(c_desired); + + // Compute the opsin inverse matrix and luminances based on primaries and + // white point. + float inverse_matrix[9]; + bool inverse_matrix_is_default = default_transform; + memcpy(inverse_matrix, orig_inverse_matrix, sizeof(inverse_matrix)); + constexpr float kSRGBLuminances[3] = {0.2126, 0.7152, 0.0722}; + memcpy(luminances, kSRGBLuminances, sizeof(luminances)); + if ((c_desired.GetPrimariesType() != Primaries::kSRGB || + c_desired.GetWhitePointType() != WhitePoint::kD65) && + !c_desired.IsGray()) { + float srgb_to_xyzd50[9]; + const auto& srgb = ColorEncoding::SRGB(/*is_gray=*/false); + PrimariesCIExy p = srgb.GetPrimaries(); + CIExy w = srgb.GetWhitePoint(); + JXL_CHECK(PrimariesToXYZD50(p.r.x, p.r.y, p.g.x, p.g.y, p.b.x, p.b.y, w.x, + w.y, srgb_to_xyzd50)); + float original_to_xyz[3][3]; + p = c_desired.GetPrimaries(); + w = c_desired.GetWhitePoint(); + if (!PrimariesToXYZ(p.r.x, p.r.y, p.g.x, p.g.y, p.b.x, p.b.y, w.x, w.y, + &original_to_xyz[0][0])) { + return JXL_FAILURE("PrimariesToXYZ failed"); + } + memcpy(luminances, original_to_xyz[1], sizeof luminances); + if (xyb_encoded) { + float adapt_to_d50[9]; + if (!AdaptToXYZD50(c_desired.GetWhitePoint().x, + c_desired.GetWhitePoint().y, adapt_to_d50)) { + return JXL_FAILURE("AdaptToXYZD50 failed"); + } + float xyzd50_to_original[9]; + Mul3x3Matrix(adapt_to_d50, &original_to_xyz[0][0], xyzd50_to_original); + JXL_RETURN_IF_ERROR(Inv3x3Matrix(xyzd50_to_original)); + float srgb_to_original[9]; + Mul3x3Matrix(xyzd50_to_original, srgb_to_xyzd50, srgb_to_original); + Mul3x3Matrix(srgb_to_original, orig_inverse_matrix, inverse_matrix); + inverse_matrix_is_default = false; + } + } + + if (c_desired.IsGray()) { + float tmp_inv_matrix[9]; + memcpy(tmp_inv_matrix, inverse_matrix, sizeof(inverse_matrix)); + float srgb_to_luma[9]; + memcpy(&srgb_to_luma[0], luminances, sizeof(luminances)); + memcpy(&srgb_to_luma[3], luminances, sizeof(luminances)); + memcpy(&srgb_to_luma[6], luminances, sizeof(luminances)); + Mul3x3Matrix(srgb_to_luma, tmp_inv_matrix, inverse_matrix); + } + + // The internal XYB color space uses absolute luminance, so we scale back the + // opsin inverse matrix to relative luminance where 1.0 corresponds to the + // original intensity target. + if (xyb_encoded) { + InitSIMDInverseMatrix(inverse_matrix, opsin_params.inverse_opsin_matrix, + orig_intensity_target); + all_default_opsin = (std::abs(orig_intensity_target - 255.0) <= 0.1f && + inverse_matrix_is_default); + } + + // Set the inverse gamma based on color space transfer function. + const auto& tf = c_desired.Tf(); + inverse_gamma = (tf.have_gamma ? tf.GetGamma() + : tf.IsDCI() ? 1.0f / 2.6f + : 1.0); + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/dec_xyb.h b/third_party/jpeg-xl/lib/jxl/dec_xyb.h new file mode 100644 index 0000000000..ddfd555632 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/dec_xyb.h @@ -0,0 +1,98 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DEC_XYB_H_ +#define LIB_JXL_DEC_XYB_H_ + +// XYB -> linear sRGB. + +#include <jxl/cms_interface.h> + +#include <cstddef> +#include <cstdint> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_metadata.h" + +namespace jxl { + +// Parameters for XYB->sRGB conversion. +struct OpsinParams { + float inverse_opsin_matrix[9 * 4]; + float opsin_biases[4]; + float opsin_biases_cbrt[4]; + float quant_biases[4]; + void Init(float intensity_target); +}; + +struct OutputEncodingInfo { + // + // Fields depending only on image metadata + // + ColorEncoding orig_color_encoding; + // Used for the HLG OOTF and PQ tone mapping. + float orig_intensity_target; + // Opsin inverse matrix taken from the metadata. + float orig_inverse_matrix[9]; + bool default_transform; + bool xyb_encoded; + // + // Fields depending on output color encoding + // + // The requested color encoding. + ColorEncoding color_encoding; + // This is expected as the output of the conversion from XYB. + // It is equal to `color_encoding`, but with a linear tone response curve. + ColorEncoding linear_color_encoding; + bool color_encoding_is_original; + // Contains an opsin matrix that converts to the primaries of the output + // encoding. + OpsinParams opsin_params; + bool all_default_opsin; + // Used for Gamma and DCI transfer functions. + float inverse_gamma; + // Luminances of color_encoding's primaries, used for the HLG inverse OOTF and + // for PQ tone mapping. + // Default to sRGB's. + float luminances[3]; + // Used for the HLG inverse OOTF and PQ tone mapping. + float desired_intensity_target; + bool cms_set = false; + JxlCmsInterface color_management_system; + + Status SetFromMetadata(const CodecMetadata& metadata); + Status MaybeSetColorEncoding(const ColorEncoding& c_desired); + + private: + Status SetColorEncoding(const ColorEncoding& c_desired); +}; + +// Converts `inout` (not padded) from opsin to linear sRGB in-place. Called from +// per-pass postprocessing, hence parallelized. +void OpsinToLinearInplace(Image3F* JXL_RESTRICT inout, ThreadPool* pool, + const OpsinParams& opsin_params); + +// Converts `opsin:rect` (opsin may be padded, rect.x0 must be vector-aligned) +// to linear sRGB. Called from whole-frame encoder, hence parallelized. +void OpsinToLinear(const Image3F& opsin, const Rect& rect, ThreadPool* pool, + Image3F* JXL_RESTRICT linear, + const OpsinParams& opsin_params); + +// Bt.601 to match JPEG/JFIF. Inputs are _signed_ YCbCr values suitable for DCT, +// see F.1.1.3 of T.81 (because our data type is float, there is no need to add +// a bias to make the values unsigned). +void YcbcrToRgb(const Image3F& ycbcr, Image3F* rgb, const Rect& rect); + +bool HasFastXYBTosRGB8(); +void FastXYBTosRGB8(const float* input[4], uint8_t* output, bool is_rgba, + size_t xsize); + +} // namespace jxl + +#endif // LIB_JXL_DEC_XYB_H_ diff --git a/third_party/jpeg-xl/lib/jxl/decode.cc b/third_party/jpeg-xl/lib/jxl/decode.cc new file mode 100644 index 0000000000..b674d1ba88 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/decode.cc @@ -0,0 +1,2841 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/decode.h> +#include <jxl/types.h> +#include <jxl/version.h> + +#include <algorithm> +#include <array> +#include <functional> +#include <memory> +#include <utility> +#include <vector> + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/padded_bytes.h" + +// JPEGXL_ENABLE_BOXES, JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/common.h" + +#if JPEGXL_ENABLE_BOXES || JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/box_content_decoder.h" +#endif +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/dec_modular.h" +#if JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/decode_to_jpeg.h" +#endif +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/icc_codec.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/memory_manager_internal.h" +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/toc.h" + +namespace { + +// Checks if a + b > size, taking possible integer overflow into account. +bool OutOfBounds(size_t a, size_t b, size_t size) { + size_t pos = a + b; + if (pos > size) return true; + if (pos < a) return true; // overflow happened + return false; +} + +JXL_INLINE size_t InitialBasicInfoSizeHint() { + // Amount of bytes before the start of the codestream in the container format, + // assuming that the codestream is the first box after the signature and + // filetype boxes. 12 bytes signature box + 20 bytes filetype box + 16 bytes + // codestream box length + name + optional XLBox length. + const size_t container_header_size = 48; + + // Worst-case amount of bytes for basic info of the JPEG XL codestream header, + // that is all information up to and including extra_channel_bits. Up to + // around 2 bytes signature + 8 bytes SizeHeader + 31 bytes ColorEncoding + 4 + // bytes rest of ImageMetadata + 5 bytes part of ImageMetadata2. + // TODO(lode): recompute and update this value when alpha_bits is moved to + // extra channels info. + const size_t max_codestream_basic_info_size = 50; + + return container_header_size + max_codestream_basic_info_size; +} + +// Debug-printing failure macro similar to JXL_FAILURE, but for the status code +// JXL_DEC_ERROR +#ifdef JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(format, ...) \ + (::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \ + ::jxl::Abort(), JXL_DEC_ERROR) +#else // JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(format, ...) \ + (((JXL_DEBUG_ON_ERROR) && \ + ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__)), \ + JXL_DEC_ERROR) +#endif // JXL_CRASH_ON_ERROR + +// Error caused by bad input (invalid file) rather than incorrect API usage. +// For now there is no way to distinguish these two types of errors yet. +#define JXL_INPUT_ERROR(format, ...) JXL_API_ERROR(format, ##__VA_ARGS__) + +JxlDecoderStatus ConvertStatus(JxlDecoderStatus status) { return status; } + +JxlDecoderStatus ConvertStatus(jxl::Status status) { + return status ? JXL_DEC_SUCCESS : JXL_DEC_ERROR; +} + +#define JXL_API_RETURN_IF_ERROR(expr) \ + { \ + JxlDecoderStatus status_ = ConvertStatus(expr); \ + if (status_ != JXL_DEC_SUCCESS) return status_; \ + } + +JxlSignature ReadSignature(const uint8_t* buf, size_t len, size_t* pos) { + if (*pos >= len) return JXL_SIG_NOT_ENOUGH_BYTES; + + buf += *pos; + len -= *pos; + + // JPEG XL codestream: 0xff 0x0a + if (len >= 1 && buf[0] == 0xff) { + if (len < 2) { + return JXL_SIG_NOT_ENOUGH_BYTES; + } else if (buf[1] == jxl::kCodestreamMarker) { + *pos += 2; + return JXL_SIG_CODESTREAM; + } else { + return JXL_SIG_INVALID; + } + } + + // JPEG XL container + if (len >= 1 && buf[0] == 0) { + if (len < 12) { + return JXL_SIG_NOT_ENOUGH_BYTES; + } else if (buf[1] == 0 && buf[2] == 0 && buf[3] == 0xC && buf[4] == 'J' && + buf[5] == 'X' && buf[6] == 'L' && buf[7] == ' ' && + buf[8] == 0xD && buf[9] == 0xA && buf[10] == 0x87 && + buf[11] == 0xA) { + *pos += 12; + return JXL_SIG_CONTAINER; + } else { + return JXL_SIG_INVALID; + } + } + + return JXL_SIG_INVALID; +} + +} // namespace + +uint32_t JxlDecoderVersion(void) { + return JPEGXL_MAJOR_VERSION * 1000000 + JPEGXL_MINOR_VERSION * 1000 + + JPEGXL_PATCH_VERSION; +} + +JxlSignature JxlSignatureCheck(const uint8_t* buf, size_t len) { + size_t pos = 0; + return ReadSignature(buf, len, &pos); +} + +namespace { + +size_t BitsPerChannel(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_UINT8: + return 8; + case JXL_TYPE_UINT16: + return 16; + case JXL_TYPE_FLOAT: + return 32; + case JXL_TYPE_FLOAT16: + return 16; + default: + return 0; // signals unhandled JxlDataType + } +} + +template <typename T> +uint32_t GetBitDepth(JxlBitDepth bit_depth, const T& metadata, + JxlPixelFormat format) { + if (bit_depth.type == JXL_BIT_DEPTH_FROM_PIXEL_FORMAT) { + return BitsPerChannel(format.data_type); + } else if (bit_depth.type == JXL_BIT_DEPTH_FROM_CODESTREAM) { + return metadata.bit_depth.bits_per_sample; + } else if (bit_depth.type == JXL_BIT_DEPTH_CUSTOM) { + return bit_depth.bits_per_sample; + } + return 0; +} + +enum class DecoderStage : uint32_t { + kInited, // Decoder created, no JxlDecoderProcessInput called yet + kStarted, // Running JxlDecoderProcessInput calls + kCodestreamFinished, // Codestream done, but other boxes could still occur. + // This stage can also occur before having seen the + // entire codestream if the user didn't subscribe to any + // codestream events at all, e.g. only to box events, + // or, the user only subscribed to basic info, and only + // the header of the codestream was parsed. + kError, // Error occurred, decoder object no longer usable +}; + +enum class FrameStage : uint32_t { + kHeader, // Must parse frame header. + kTOC, // Must parse TOC + kFull, // Must parse full pixels +}; + +enum class BoxStage : uint32_t { + kHeader, // Parsing box header of the next box, or start of non-container + // stream + kFtyp, // The ftyp box + kSkip, // Box whose contents are skipped + kCodestream, // Handling codestream box contents, or non-container stream + kPartialCodestream, // Handling the extra header of partial codestream box + kJpegRecon, // Handling jpeg reconstruction box +}; + +enum class JpegReconStage : uint32_t { + kNone, // Not outputting + kSettingMetadata, // Ready to output, must set metadata to the jpeg_data + kOutputting, // Currently outputting the JPEG bytes +}; + +/* +Given list of frame references to storage slots, and storage slots in which this +frame is saved, computes which frames are required to decode the frame at the +given index and any frames after it. The frames on which this depends are +returned as a vector of their indices, in no particular order. The given index +must be smaller than saved_as.size(), and references.size() must equal +saved_as.size(). Any frames beyond saved_as and references are considered +unknown future frames and must be treated as if something depends on them. +*/ +std::vector<size_t> GetFrameDependencies(size_t index, + const std::vector<int>& saved_as, + const std::vector<int>& references) { + JXL_ASSERT(references.size() == saved_as.size()); + JXL_ASSERT(index < references.size()); + + std::vector<size_t> result; + + constexpr size_t kNumStorage = 8; + + // value which indicates nothing is stored in this storage slot + const size_t invalid = references.size(); + // for each of the 8 storage slots, a vector that translates frame index to + // frame stored in this storage slot at this point, that is, the last + // frame that was stored in this slot before or at this index. + std::array<std::vector<size_t>, kNumStorage> storage; + for (size_t s = 0; s < kNumStorage; ++s) { + storage[s].resize(saved_as.size()); + int mask = 1 << s; + size_t id = invalid; + for (size_t i = 0; i < saved_as.size(); ++i) { + if (saved_as[i] & mask) { + id = i; + } + storage[s][i] = id; + } + } + + std::vector<char> seen(index + 1, 0); + std::vector<size_t> stack; + stack.push_back(index); + seen[index] = 1; + + // For frames after index, assume they can depend on any of the 8 storage + // slots, so push the frame for each stored reference to the stack and result. + // All frames after index are treated as having unknown references and with + // the possibility that there are more frames after the last known. + // TODO(lode): take values of saved_as and references after index, and a + // input flag indicating if they are all frames of the image, to further + // optimize this. + for (size_t s = 0; s < kNumStorage; ++s) { + size_t frame_ref = storage[s][index]; + if (frame_ref == invalid) continue; + if (seen[frame_ref]) continue; + stack.push_back(frame_ref); + seen[frame_ref] = 1; + result.push_back(frame_ref); + } + + while (!stack.empty()) { + size_t frame_index = stack.back(); + stack.pop_back(); + if (frame_index == 0) continue; // first frame cannot have references + for (size_t s = 0; s < kNumStorage; ++s) { + int mask = 1 << s; + if (!(references[frame_index] & mask)) continue; + size_t frame_ref = storage[s][frame_index - 1]; + if (frame_ref == invalid) continue; + if (seen[frame_ref]) continue; + stack.push_back(frame_ref); + seen[frame_ref] = 1; + result.push_back(frame_ref); + } + } + + return result; +} + +// Parameters for user-requested extra channel output. +struct ExtraChannelOutput { + JxlPixelFormat format; + void* buffer; + size_t buffer_size; +}; + +} // namespace + +namespace jxl { + +typedef struct JxlDecoderFrameIndexBoxEntryStruct { + // OFFi: offset of start byte of this frame compared to start + // byte of previous frame from this index in the JPEG XL codestream. For the + // first frame, this is the offset from the first byte of the JPEG XL + // codestream. + uint64_t OFFi; + // Ti: duration in ticks between the start of this frame and + // the start of the next frame in the index. If this is the last frame in the + // index, this is the duration in ticks between the start of this frame and + // the end of the stream. A tick lasts TNUM / TDEN seconds. + uint32_t Ti; + // Fi: amount of frames the next frame in the index occurs + // after this frame. If this is the last frame in the index, this is the + // amount of frames after this frame in the remainder of the stream. Only + // frames that are presented by the decoder are counted for this purpose, this + // excludes frames that are not intended for display but for compositing with + // other frames, such as frames that aren't the last frame with a duration of + // 0 ticks. + uint32_t Fi; +} JxlDecoderFrameIndexBoxEntry; + +typedef struct JxlDecoderFrameIndexBoxStruct { + int64_t NF() const { return entries.size(); } + int32_t TNUM = 1; + int32_t TDEN = 1000; + + std::vector<JxlDecoderFrameIndexBoxEntry> entries; + + // That way we can ensure that every index box will have the first frame. + // If the API user decides to mark it as an indexed frame, we call + // the AddFrame again, this time with requested. + void AddFrame(uint64_t OFFi, uint32_t Ti, uint32_t Fi) { + JxlDecoderFrameIndexBoxEntry e; + e.OFFi = OFFi; + e.Ti = Ti; + e.Fi = Fi; + entries.push_back(e); + } +} JxlDecoderFrameIndexBox; + +} // namespace jxl + +// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) +struct JxlDecoderStruct { + JxlDecoderStruct() = default; + + JxlMemoryManager memory_manager; + std::unique_ptr<jxl::ThreadPool> thread_pool; + + DecoderStage stage; + + // Status of progression, internal. + bool got_signature; + // Indicates we know that we've seen the last codestream box: either this + // was a jxlc box, or a jxlp box that has its index indicated as last by + // having its most significant bit set, or no boxes are used at all. This + // does not indicate the full codestream has already been seen, only the + // last box of it has been initiated. + bool last_codestream_seen; + bool got_codestream_signature; + bool got_basic_info; + bool got_transform_data; // To skip everything before ICC. + bool got_all_headers; // Codestream metadata headers. + bool post_headers; // Already decoding pixels. + jxl::ICCReader icc_reader; + jxl::JxlDecoderFrameIndexBox frame_index_box; + // This means either we actually got the preview image, or determined we + // cannot get it or there is none. + bool got_preview_image; + bool preview_frame; + + // Position of next_in in the original file including box format if present + // (as opposed to position in the codestream) + size_t file_pos; + + size_t box_contents_begin; + size_t box_contents_end; + size_t box_contents_size; + size_t box_size; + size_t header_size; + // Either a final box that runs until EOF, or the case of no container format + // at all. + bool box_contents_unbounded; + + JxlBoxType box_type; + JxlBoxType box_decoded_type; // Underlying type for brob boxes + // Set to true right after a JXL_DEC_BOX event only. + bool box_event; + bool decompress_boxes; + + bool box_out_buffer_set; + // Whether the out buffer is set for the current box, if the user did not yet + // release the buffer while the next box is encountered, this will be set to + // false. If this is false, no JXL_DEC_NEED_MORE_INPUT is emitted + // (irrespective of the value of box_out_buffer_set), because not setting + // output indicates the user does not wish the data of this box. + bool box_out_buffer_set_current_box; + uint8_t* box_out_buffer; + size_t box_out_buffer_size; + // which byte of the full box content the start of the out buffer points to + size_t box_out_buffer_begin; + // which byte of box_out_buffer to write to next + size_t box_out_buffer_pos; + + // Settings + bool keep_orientation; + bool unpremul_alpha; + bool render_spotcolors; + bool coalescing; + float desired_intensity_target; + + // Bitfield, for which informative events (JXL_DEC_BASIC_INFO, etc...) the + // decoder returns a status. By default, do not return for any of the events, + // only return when the decoder cannot continue because it needs more input or + // output data. + int events_wanted; + int orig_events_wanted; + + // Fields for reading the basic info from the header. + size_t basic_info_size_hint; + bool have_container; + size_t box_count; + + // The level of progressive detail in frame decoding. + JxlProgressiveDetail prog_detail = kDC; + // The progressive detail of the current frame. + JxlProgressiveDetail frame_prog_detail; + // The intended downsampling ratio for the current progression step. + size_t downsampling_target; + + // Set to true if either an image out buffer or an image out callback was set. + bool image_out_buffer_set; + + // Owned by the caller, buffer for preview or full resolution image. + void* image_out_buffer; + JxlImageOutInitCallback image_out_init_callback; + JxlImageOutRunCallback image_out_run_callback; + JxlImageOutDestroyCallback image_out_destroy_callback; + void* image_out_init_opaque; + struct SimpleImageOutCallback { + JxlImageOutCallback callback; + void* opaque; + }; + SimpleImageOutCallback simple_image_out_callback; + + size_t image_out_size; + + JxlPixelFormat image_out_format; + JxlBitDepth image_out_bit_depth; + + // For extra channels. Empty if no extra channels are requested, and they are + // reset each frame + std::vector<ExtraChannelOutput> extra_channel_output; + + jxl::CodecMetadata metadata; + // Same as metadata.m, except for the color_encoding, which is set to the + // output encoding. + jxl::ImageMetadata image_metadata; + std::unique_ptr<jxl::ImageBundle> ib; + + std::unique_ptr<jxl::PassesDecoderState> passes_state; + std::unique_ptr<jxl::FrameDecoder> frame_dec; + size_t next_section; + std::vector<char> section_processed; + + // headers and TOC for the current frame. When got_toc is true, this is + // always the frame header of the last frame of the current still series, + // that is, the displayed frame. + std::unique_ptr<jxl::FrameHeader> frame_header; + + size_t remaining_frame_size; + FrameStage frame_stage; + bool dc_frame_progression_done; + // The currently processed frame is the last of the current composite still, + // and so must be returned as pixels + bool is_last_of_still; + // The currently processed frame is the last of the codestream + bool is_last_total; + // How many frames to skip. + size_t skip_frames; + // Skipping the current frame. May be false if skip_frames was just set to + // a positive value while already processing a current frame, then + // skipping_frame will be enabled only for the next frame. + bool skipping_frame; + + // Amount of internal frames and external frames started. External frames are + // user-visible frames, internal frames includes all external frames and + // also invisible frames such as patches, blending-only and dc_level frames. + size_t internal_frames; + size_t external_frames; + + // For each internal frame, which storage locations it references, and which + // storage locations it is stored in, using the bit mask as defined in + // FrameDecoder::References and FrameDecoder::SaveAs. + std::vector<int> frame_references; + std::vector<int> frame_saved_as; + + // Translates external frame index to internal frame index. The external + // index is the index of user-visible frames. The internal index can be larger + // since non-visible frames (such as frames with patches, ...) are included. + std::vector<size_t> frame_external_to_internal; + + // Whether the frame with internal index is required to decode the frame + // being skipped to or any frames after that. If no skipping is active, + // this vector is ignored. If the current internal frame index is beyond this + // vector, it must be treated as a required frame. + std::vector<char> frame_required; + + // Codestream input data is copied here temporarily when the decoder needs + // more input bytes to process the next part of the stream. We copy the input + // data in order to be able to release it all through the API it when + // returning JXL_DEC_NEED_MORE_INPUT. + std::vector<uint8_t> codestream_copy; + // Number of bytes at the end of codestream_copy that were not yet consumed + // by calling AdvanceInput(). + size_t codestream_unconsumed; + // Position in the codestream_copy vector that the decoder already finished + // processing. It can be greater than the current size of codestream_copy in + // case where the decoder skips some parts of the frame that were not yet + // provided. + size_t codestream_pos; + // Number of bits after codestream_pos that were already processed. + size_t codestream_bits_ahead; + + BoxStage box_stage; + +#if JPEGXL_ENABLE_BOXES + jxl::JxlBoxContentDecoder box_content_decoder; +#endif +#if JPEGXL_ENABLE_TRANSCODE_JPEG + jxl::JxlToJpegDecoder jpeg_decoder; + // Decodes Exif or XMP metadata for JPEG reconstruction + jxl::JxlBoxContentDecoder metadata_decoder; + std::vector<uint8_t> exif_metadata; + std::vector<uint8_t> xmp_metadata; + // must store JPEG reconstruction metadata from the current box + // 0 = not stored, 1 = currently storing, 2 = finished + int store_exif; + int store_xmp; + size_t recon_out_buffer_pos; + size_t recon_exif_size; // Expected exif size as read from the jbrd box + size_t recon_xmp_size; // Expected exif size as read from the jbrd box + JpegReconStage recon_output_jpeg; + + bool JbrdNeedMoreBoxes() const { + // jbrd box wants exif but exif box not yet seen + if (store_exif < 2 && recon_exif_size > 0) return true; + // jbrd box wants xmp but xmp box not yet seen + if (store_xmp < 2 && recon_xmp_size > 0) return true; + return false; + } +#endif + + const uint8_t* next_in; + size_t avail_in; + bool input_closed; + + void AdvanceInput(size_t size) { + JXL_DASSERT(avail_in >= size); + next_in += size; + avail_in -= size; + file_pos += size; + } + + size_t AvailableCodestream() const { + size_t avail_codestream = avail_in; + if (!box_contents_unbounded) { + avail_codestream = + std::min<size_t>(avail_codestream, box_contents_end - file_pos); + } + return avail_codestream; + } + + void AdvanceCodestream(size_t size) { + size_t avail_codestream = AvailableCodestream(); + if (codestream_copy.empty()) { + if (size <= avail_codestream) { + AdvanceInput(size); + } else { + codestream_pos = size - avail_codestream; + AdvanceInput(avail_codestream); + } + } else { + codestream_pos += size; + if (codestream_pos + codestream_unconsumed >= codestream_copy.size()) { + size_t advance = std::min( + codestream_unconsumed, + codestream_unconsumed + codestream_pos - codestream_copy.size()); + AdvanceInput(advance); + codestream_pos -= std::min(codestream_pos, codestream_copy.size()); + codestream_unconsumed = 0; + codestream_copy.clear(); + } + } + } + + JxlDecoderStatus RequestMoreInput() { + if (codestream_copy.empty()) { + size_t avail_codestream = AvailableCodestream(); + codestream_copy.insert(codestream_copy.end(), next_in, + next_in + avail_codestream); + AdvanceInput(avail_codestream); + } else { + AdvanceInput(codestream_unconsumed); + codestream_unconsumed = 0; + } + return JXL_DEC_NEED_MORE_INPUT; + } + + JxlDecoderStatus GetCodestreamInput(jxl::Span<const uint8_t>* span) { + if (codestream_copy.empty() && codestream_pos > 0) { + size_t avail_codestream = AvailableCodestream(); + size_t skip = std::min<size_t>(codestream_pos, avail_codestream); + AdvanceInput(skip); + codestream_pos -= skip; + if (codestream_pos > 0) { + return RequestMoreInput(); + } + } + JXL_ASSERT(codestream_pos <= codestream_copy.size()); + JXL_ASSERT(codestream_unconsumed <= codestream_copy.size()); + size_t avail_codestream = AvailableCodestream(); + if (codestream_copy.empty()) { + if (avail_codestream == 0) { + return RequestMoreInput(); + } + *span = jxl::Bytes(next_in, avail_codestream); + return JXL_DEC_SUCCESS; + } else { + codestream_copy.insert(codestream_copy.end(), + next_in + codestream_unconsumed, + next_in + avail_codestream); + codestream_unconsumed = avail_codestream; + *span = jxl::Bytes(codestream_copy.data() + codestream_pos, + codestream_copy.size() - codestream_pos); + return JXL_DEC_SUCCESS; + } + } + + // Whether the decoder can use more codestream input for a purpose it needs. + // This returns false if the user didn't subscribe to any events that + // require the codestream (e.g. only subscribed to metadata boxes), or all + // parts of the codestream that are subscribed to (e.g. only basic info) have + // already occurred. + bool CanUseMoreCodestreamInput() const { + // The decoder can set this to finished early if all relevant events were + // processed, so this check works. + return stage != DecoderStage::kCodestreamFinished; + } + + // If set then some operations will fail, if those would require + // allocating large objects. Actual memory usage might be two orders of + // magnitude bigger. + // TODO(eustas): remove once there is working API for memory / CPU limit. + size_t memory_limit_base = 0; + size_t cpu_limit_base = 0; + size_t used_cpu_base = 0; +}; + +namespace { + +bool CheckSizeLimit(JxlDecoder* dec, size_t xsize, size_t ysize) { + if (!dec->memory_limit_base) return true; + if (xsize == 0 || ysize == 0) return true; + if (xsize >= dec->memory_limit_base || ysize >= dec->memory_limit_base) { + return false; + } + // Rough estimate of real row length. + xsize = jxl::DivCeil(xsize, 32) * 32; + size_t num_pixels = xsize * ysize; + if (num_pixels / xsize != ysize) return false; // overflow + if (num_pixels > dec->memory_limit_base) return false; + return true; +} + +} // namespace + +// Resets the state that must be reset for both Rewind and Reset +void JxlDecoderRewindDecodingState(JxlDecoder* dec) { + dec->stage = DecoderStage::kInited; + dec->got_signature = false; + dec->last_codestream_seen = false; + dec->got_codestream_signature = false; + dec->got_basic_info = false; + dec->got_transform_data = false; + dec->got_all_headers = false; + dec->post_headers = false; + dec->icc_reader.Reset(); + dec->got_preview_image = false; + dec->preview_frame = false; + dec->file_pos = 0; + dec->box_contents_begin = 0; + dec->box_contents_end = 0; + dec->box_contents_size = 0; + dec->box_size = 0; + dec->header_size = 0; + dec->box_contents_unbounded = false; + memset(dec->box_type, 0, sizeof(dec->box_type)); + memset(dec->box_decoded_type, 0, sizeof(dec->box_decoded_type)); + dec->box_event = false; + dec->box_stage = BoxStage::kHeader; + dec->box_out_buffer_set = false; + dec->box_out_buffer_set_current_box = false; + dec->box_out_buffer = nullptr; + dec->box_out_buffer_size = 0; + dec->box_out_buffer_begin = 0; + dec->box_out_buffer_pos = 0; + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + dec->exif_metadata.clear(); + dec->xmp_metadata.clear(); + dec->store_exif = 0; + dec->store_xmp = 0; + dec->recon_out_buffer_pos = 0; + dec->recon_exif_size = 0; + dec->recon_xmp_size = 0; + dec->recon_output_jpeg = JpegReconStage::kNone; +#endif + + dec->events_wanted = dec->orig_events_wanted; + dec->basic_info_size_hint = InitialBasicInfoSizeHint(); + dec->have_container = 0; + dec->box_count = 0; + dec->downsampling_target = 8; + dec->image_out_buffer_set = false; + dec->image_out_buffer = nullptr; + dec->image_out_init_callback = nullptr; + dec->image_out_run_callback = nullptr; + dec->image_out_destroy_callback = nullptr; + dec->image_out_init_opaque = nullptr; + dec->image_out_size = 0; + dec->image_out_bit_depth.type = JXL_BIT_DEPTH_FROM_PIXEL_FORMAT; + dec->extra_channel_output.clear(); + dec->next_in = 0; + dec->avail_in = 0; + dec->input_closed = false; + + dec->passes_state.reset(nullptr); + dec->frame_dec.reset(nullptr); + dec->next_section = 0; + dec->section_processed.clear(); + + dec->ib.reset(); + dec->metadata = jxl::CodecMetadata(); + dec->image_metadata = dec->metadata.m; + dec->frame_header.reset(new jxl::FrameHeader(&dec->metadata)); + + dec->codestream_copy.clear(); + dec->codestream_unconsumed = 0; + dec->codestream_pos = 0; + dec->codestream_bits_ahead = 0; + + dec->frame_stage = FrameStage::kHeader; + dec->remaining_frame_size = 0; + dec->is_last_of_still = false; + dec->is_last_total = false; + dec->skip_frames = 0; + dec->skipping_frame = false; + dec->internal_frames = 0; + dec->external_frames = 0; +} + +void JxlDecoderReset(JxlDecoder* dec) { + JxlDecoderRewindDecodingState(dec); + + dec->thread_pool.reset(); + dec->keep_orientation = false; + dec->unpremul_alpha = false; + dec->render_spotcolors = true; + dec->coalescing = true; + dec->desired_intensity_target = 0; + dec->orig_events_wanted = 0; + dec->events_wanted = 0; + dec->frame_references.clear(); + dec->frame_saved_as.clear(); + dec->frame_external_to_internal.clear(); + dec->frame_required.clear(); + dec->decompress_boxes = false; +} + +JxlDecoder* JxlDecoderCreate(const JxlMemoryManager* memory_manager) { + JxlMemoryManager local_memory_manager; + if (!jxl::MemoryManagerInit(&local_memory_manager, memory_manager)) + return nullptr; + + void* alloc = + jxl::MemoryManagerAlloc(&local_memory_manager, sizeof(JxlDecoder)); + if (!alloc) return nullptr; + // Placement new constructor on allocated memory + JxlDecoder* dec = new (alloc) JxlDecoder(); + dec->memory_manager = local_memory_manager; + +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + if (!memory_manager) { + dec->memory_limit_base = 53 << 16; + // Allow 5 x max_image_size processing units; every frame is accounted + // as W x H CPU processing units, so there could be numerous small frames + // or few larger ones. + dec->cpu_limit_base = 5 * dec->memory_limit_base; + } +#endif + + JxlDecoderReset(dec); + + return dec; +} + +void JxlDecoderDestroy(JxlDecoder* dec) { + if (dec) { + JxlMemoryManager local_memory_manager = dec->memory_manager; + // Call destructor directly since custom free function is used. + dec->~JxlDecoder(); + jxl::MemoryManagerFree(&local_memory_manager, dec); + } +} + +void JxlDecoderRewind(JxlDecoder* dec) { JxlDecoderRewindDecodingState(dec); } + +void JxlDecoderSkipFrames(JxlDecoder* dec, size_t amount) { + // Increment amount, rather than set it: making the amount smaller is + // impossible because the decoder may already have skipped frames required to + // decode earlier frames, and making the amount larger compared to an existing + // amount is impossible because if JxlDecoderSkipFrames is called in the + // middle of already skipping frames, the user cannot know how many frames + // have already been skipped internally so far so an absolute value cannot + // be defined. + dec->skip_frames += amount; + + dec->frame_required.clear(); + size_t next_frame = dec->external_frames + dec->skip_frames; + + // A frame that has been seen before a rewind + if (next_frame < dec->frame_external_to_internal.size()) { + size_t internal_index = dec->frame_external_to_internal[next_frame]; + if (internal_index < dec->frame_saved_as.size()) { + std::vector<size_t> deps = GetFrameDependencies( + internal_index, dec->frame_saved_as, dec->frame_references); + + dec->frame_required.resize(internal_index + 1, 0); + for (size_t i = 0; i < deps.size(); i++) { + JXL_ASSERT(deps[i] < dec->frame_required.size()); + dec->frame_required[deps[i]] = 1; + } + } + } +} + +JxlDecoderStatus JxlDecoderSkipCurrentFrame(JxlDecoder* dec) { + if (dec->frame_stage != FrameStage::kFull) { + return JXL_API_ERROR("JxlDecoderSkipCurrentFrame called at the wrong time"); + } + JXL_DASSERT(dec->frame_dec); + dec->frame_stage = FrameStage::kHeader; + dec->AdvanceCodestream(dec->remaining_frame_size); + if (dec->is_last_of_still) { + dec->image_out_buffer_set = false; + } + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus +JxlDecoderSetParallelRunner(JxlDecoder* dec, JxlParallelRunner parallel_runner, + void* parallel_runner_opaque) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR( + "JxlDecoderSetParallelRunner must be called before starting"); + } + dec->thread_pool.reset( + new jxl::ThreadPool(parallel_runner, parallel_runner_opaque)); + return JXL_DEC_SUCCESS; +} + +size_t JxlDecoderSizeHintBasicInfo(const JxlDecoder* dec) { + if (dec->got_basic_info) return 0; + return dec->basic_info_size_hint; +} + +JxlDecoderStatus JxlDecoderSubscribeEvents(JxlDecoder* dec, int events_wanted) { + if (dec->stage != DecoderStage::kInited) { + return JXL_DEC_ERROR; // Cannot subscribe to events after having started. + } + if (events_wanted & 63) { + return JXL_DEC_ERROR; // Can only subscribe to informative events. + } + dec->events_wanted = events_wanted; + dec->orig_events_wanted = events_wanted; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetKeepOrientation(JxlDecoder* dec, + JXL_BOOL skip_reorientation) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR("Must set keep_orientation option before starting"); + } + dec->keep_orientation = !!skip_reorientation; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetUnpremultiplyAlpha(JxlDecoder* dec, + JXL_BOOL unpremul_alpha) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR("Must set unpremul_alpha option before starting"); + } + dec->unpremul_alpha = !!unpremul_alpha; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetRenderSpotcolors(JxlDecoder* dec, + JXL_BOOL render_spotcolors) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR("Must set render_spotcolors option before starting"); + } + dec->render_spotcolors = !!render_spotcolors; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetCoalescing(JxlDecoder* dec, JXL_BOOL coalescing) { + if (dec->stage != DecoderStage::kInited) { + return JXL_API_ERROR("Must set coalescing option before starting"); + } + dec->coalescing = !!coalescing; + return JXL_DEC_SUCCESS; +} + +namespace { +// helper function to get the dimensions of the current image buffer +void GetCurrentDimensions(const JxlDecoder* dec, size_t& xsize, size_t& ysize) { + if (dec->frame_header->nonserialized_is_preview) { + xsize = dec->metadata.oriented_preview_xsize(dec->keep_orientation); + ysize = dec->metadata.oriented_preview_ysize(dec->keep_orientation); + return; + } + xsize = dec->metadata.oriented_xsize(dec->keep_orientation); + ysize = dec->metadata.oriented_ysize(dec->keep_orientation); + if (!dec->coalescing) { + const auto frame_dim = dec->frame_header->ToFrameDimensions(); + xsize = frame_dim.xsize_upsampled; + ysize = frame_dim.ysize_upsampled; + if (!dec->keep_orientation && + static_cast<int>(dec->metadata.m.GetOrientation()) > 4) { + std::swap(xsize, ysize); + } + } +} +} // namespace + +namespace jxl { +namespace { + +template <class T> +bool CanRead(Span<const uint8_t> data, BitReader* reader, T* JXL_RESTRICT t) { + // Use a copy of the bit reader because CanRead advances bits. + BitReader reader2(data); + reader2.SkipBits(reader->TotalBitsConsumed()); + bool result = Bundle::CanRead(&reader2, t); + JXL_ASSERT(reader2.Close()); + return result; +} + +// Returns JXL_DEC_SUCCESS if the full bundle was successfully read, status +// indicating either error or need more input otherwise. +template <class T> +JxlDecoderStatus ReadBundle(JxlDecoder* dec, Span<const uint8_t> data, + BitReader* reader, T* JXL_RESTRICT t) { + if (!CanRead(data, reader, t)) { + return dec->RequestMoreInput(); + } + if (!Bundle::Read(reader, t)) { + return JXL_DEC_ERROR; + } + return JXL_DEC_SUCCESS; +} + +std::unique_ptr<BitReader, std::function<void(BitReader*)>> GetBitReader( + Span<const uint8_t> span) { + BitReader* reader = new BitReader(span); + return std::unique_ptr<BitReader, std::function<void(BitReader*)>>( + reader, [](BitReader* reader) { + // We can't allow Close to abort the program if the reader is out of + // bounds, or all return paths in the code, even those that already + // return failure, would have to manually call AllReadsWithinBounds(). + // Invalid JXL codestream should not cause program to quit. + (void)reader->AllReadsWithinBounds(); + (void)reader->Close(); + delete reader; + }); +} + +JxlDecoderStatus JxlDecoderReadBasicInfo(JxlDecoder* dec) { + if (!dec->got_codestream_signature) { + // Check and skip the codestream signature + Span<const uint8_t> span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + if (span.size() < 2) { + return dec->RequestMoreInput(); + } + if (span.data()[0] != 0xff || span.data()[1] != jxl::kCodestreamMarker) { + return JXL_INPUT_ERROR("invalid signature"); + } + dec->got_codestream_signature = true; + dec->AdvanceCodestream(2); + } + + Span<const uint8_t> span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + auto reader = GetBitReader(span); + JXL_API_RETURN_IF_ERROR( + ReadBundle(dec, span, reader.get(), &dec->metadata.size)); + JXL_API_RETURN_IF_ERROR( + ReadBundle(dec, span, reader.get(), &dec->metadata.m)); + size_t total_bits = reader->TotalBitsConsumed(); + dec->AdvanceCodestream(total_bits / jxl::kBitsPerByte); + dec->codestream_bits_ahead = total_bits % jxl::kBitsPerByte; + dec->got_basic_info = true; + dec->basic_info_size_hint = 0; + dec->image_metadata = dec->metadata.m; + JXL_DEBUG_V(2, "Decoded BasicInfo: %s", dec->metadata.DebugString().c_str()); + + if (!CheckSizeLimit(dec, dec->metadata.size.xsize(), + dec->metadata.size.ysize())) { + return JXL_INPUT_ERROR("image is too large"); + } + + return JXL_DEC_SUCCESS; +} + +// Reads all codestream headers (but not frame headers) +JxlDecoderStatus JxlDecoderReadAllHeaders(JxlDecoder* dec) { + if (!dec->got_transform_data) { + Span<const uint8_t> span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + auto reader = GetBitReader(span); + reader->SkipBits(dec->codestream_bits_ahead); + dec->metadata.transform_data.nonserialized_xyb_encoded = + dec->metadata.m.xyb_encoded; + JXL_API_RETURN_IF_ERROR( + ReadBundle(dec, span, reader.get(), &dec->metadata.transform_data)); + size_t total_bits = reader->TotalBitsConsumed(); + dec->AdvanceCodestream(total_bits / jxl::kBitsPerByte); + dec->codestream_bits_ahead = total_bits % jxl::kBitsPerByte; + dec->got_transform_data = true; + } + + Span<const uint8_t> span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + auto reader = GetBitReader(span); + reader->SkipBits(dec->codestream_bits_ahead); + + if (dec->metadata.m.color_encoding.WantICC()) { + jxl::Status status = + dec->icc_reader.Init(reader.get(), dec->memory_limit_base); + // Always check AllReadsWithinBounds, not all the C++ decoder implementation + // handles reader out of bounds correctly yet (e.g. context map). Not + // checking AllReadsWithinBounds can cause reader->Close() to trigger an + // assert, but we don't want library to quit program for invalid codestream. + if (!reader->AllReadsWithinBounds() || + status.code() == StatusCode::kNotEnoughBytes) { + return dec->RequestMoreInput(); + } + if (!status) { + // Other non-successful status is an error + return JXL_DEC_ERROR; + } + PaddedBytes decoded_icc; + status = dec->icc_reader.Process(reader.get(), &decoded_icc); + if (status.code() == StatusCode::kNotEnoughBytes) { + return dec->RequestMoreInput(); + } + if (!status) { + // Other non-successful status is an error + return JXL_DEC_ERROR; + } + if (decoded_icc.empty()) { + return JXL_DEC_ERROR; + } + IccBytes icc; + Bytes(decoded_icc).AppendTo(&icc); + dec->metadata.m.color_encoding.SetICCRaw(std::move(icc)); + } + + dec->got_all_headers = true; + JXL_API_RETURN_IF_ERROR(reader->JumpToByteBoundary()); + + dec->AdvanceCodestream(reader->TotalBitsConsumed() / jxl::kBitsPerByte); + dec->codestream_bits_ahead = 0; + + if (!dec->passes_state) { + dec->passes_state.reset(new jxl::PassesDecoderState()); + } + + JXL_API_RETURN_IF_ERROR( + dec->passes_state->output_encoding_info.SetFromMetadata(dec->metadata)); + if (dec->desired_intensity_target > 0) { + dec->passes_state->output_encoding_info.desired_intensity_target = + dec->desired_intensity_target; + } + dec->image_metadata = dec->metadata.m; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderProcessSections(JxlDecoder* dec) { + Span<const uint8_t> span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + const auto& toc = dec->frame_dec->Toc(); + size_t pos = 0; + std::vector<jxl::FrameDecoder::SectionInfo> section_info; + std::vector<jxl::FrameDecoder::SectionStatus> section_status; + for (size_t i = dec->next_section; i < toc.size(); ++i) { + if (dec->section_processed[i]) { + pos += toc[i].size; + continue; + } + size_t id = toc[i].id; + size_t size = toc[i].size; + if (OutOfBounds(pos, size, span.size())) { + break; + } + auto br = new jxl::BitReader(jxl::Bytes(span.data() + pos, size)); + section_info.emplace_back(jxl::FrameDecoder::SectionInfo{br, id, i}); + section_status.emplace_back(); + pos += size; + } + jxl::Status status = dec->frame_dec->ProcessSections( + section_info.data(), section_info.size(), section_status.data()); + bool out_of_bounds = false; + for (const auto& info : section_info) { + if (!info.br->AllReadsWithinBounds()) { + // Mark out of bounds section, but keep closing and deleting the next + // ones as well. + out_of_bounds = true; + } + JXL_ASSERT(info.br->Close()); + delete info.br; + } + if (out_of_bounds) { + // If any bit reader indicates out of bounds, it's an error, not just + // needing more input, since we ensure only bit readers containing + // a complete section are provided to the FrameDecoder. + return JXL_INPUT_ERROR("frame out of bounds"); + } + if (!status) { + return JXL_INPUT_ERROR("frame processing failed"); + } + for (size_t i = 0; i < section_status.size(); ++i) { + auto status = section_status[i]; + if (status == jxl::FrameDecoder::kDone) { + dec->section_processed[section_info[i].index] = 1; + } else if (status != jxl::FrameDecoder::kSkipped) { + return JXL_INPUT_ERROR("unexpected section status"); + } + } + size_t completed_prefix_bytes = 0; + while (dec->next_section < dec->section_processed.size() && + dec->section_processed[dec->next_section] == 1) { + completed_prefix_bytes += toc[dec->next_section].size; + ++dec->next_section; + } + dec->remaining_frame_size -= completed_prefix_bytes; + dec->AdvanceCodestream(completed_prefix_bytes); + return JXL_DEC_SUCCESS; +} + +// TODO(eustas): no CodecInOut -> no image size reinforcement -> possible OOM. +JxlDecoderStatus JxlDecoderProcessCodestream(JxlDecoder* dec) { + // If no parallel runner is set, use the default + // TODO(lode): move this initialization to an appropriate location once the + // runner is used to decode pixels. + if (!dec->thread_pool) { + dec->thread_pool.reset(new jxl::ThreadPool(nullptr, nullptr)); + } + + // No matter what events are wanted, the basic info is always required. + if (!dec->got_basic_info) { + JxlDecoderStatus status = JxlDecoderReadBasicInfo(dec); + if (status != JXL_DEC_SUCCESS) return status; + } + + if (dec->events_wanted & JXL_DEC_BASIC_INFO) { + dec->events_wanted &= ~JXL_DEC_BASIC_INFO; + return JXL_DEC_BASIC_INFO; + } + + if (!dec->events_wanted) { + dec->stage = DecoderStage::kCodestreamFinished; + return JXL_DEC_SUCCESS; + } + + if (!dec->got_all_headers) { + JxlDecoderStatus status = JxlDecoderReadAllHeaders(dec); + if (status != JXL_DEC_SUCCESS) return status; + } + + if (dec->events_wanted & JXL_DEC_COLOR_ENCODING) { + dec->events_wanted &= ~JXL_DEC_COLOR_ENCODING; + return JXL_DEC_COLOR_ENCODING; + } + + if (!dec->events_wanted) { + dec->stage = DecoderStage::kCodestreamFinished; + return JXL_DEC_SUCCESS; + } + + dec->post_headers = true; + + if (!dec->got_preview_image && dec->metadata.m.have_preview) { + dec->preview_frame = true; + } + + // Handle frames + for (;;) { + bool parse_frames = + (dec->events_wanted & + (JXL_DEC_PREVIEW_IMAGE | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + if (!parse_frames) { + break; + } + if (dec->frame_stage == FrameStage::kHeader && dec->is_last_total) { + break; + } + if (dec->frame_stage == FrameStage::kHeader) { +#if JPEGXL_ENABLE_TRANSCODE_JPEG + if (dec->recon_output_jpeg == JpegReconStage::kSettingMetadata || + dec->recon_output_jpeg == JpegReconStage::kOutputting) { + // The image bundle contains the JPEG reconstruction frame, but the + // decoder is still waiting to decode an EXIF or XMP box. It's not + // implemented to decode additional frames during this, and a JPEG + // reconstruction image should have only one frame. + return JXL_API_ERROR( + "cannot decode a next frame after JPEG reconstruction frame"); + } +#endif + if (!dec->ib) { + dec->ib.reset(new jxl::ImageBundle(&dec->image_metadata)); + } +#if JPEGXL_ENABLE_TRANSCODE_JPEG + // If JPEG reconstruction is wanted and possible, set the jpeg_data of + // the ImageBundle. + if (!dec->jpeg_decoder.SetImageBundleJpegData(dec->ib.get())) + return JXL_DEC_ERROR; +#endif + dec->frame_dec.reset(new FrameDecoder( + dec->passes_state.get(), dec->metadata, dec->thread_pool.get(), + /*use_slow_rendering_pipeline=*/false)); + dec->frame_header.reset(new FrameHeader(&dec->metadata)); + Span<const uint8_t> span; + JXL_API_RETURN_IF_ERROR(dec->GetCodestreamInput(&span)); + auto reader = GetBitReader(span); + jxl::Status status = dec->frame_dec->InitFrame( + reader.get(), dec->ib.get(), dec->preview_frame); + if (!reader->AllReadsWithinBounds() || + status.code() == StatusCode::kNotEnoughBytes) { + return dec->RequestMoreInput(); + } else if (!status) { + return JXL_INPUT_ERROR("invalid frame header"); + } + dec->AdvanceCodestream(reader->TotalBitsConsumed() / kBitsPerByte); + *dec->frame_header = dec->frame_dec->GetFrameHeader(); + jxl::FrameDimensions frame_dim = dec->frame_header->ToFrameDimensions(); + if (!CheckSizeLimit(dec, frame_dim.xsize_upsampled_padded, + frame_dim.ysize_upsampled_padded)) { + return JXL_INPUT_ERROR("frame is too large"); + } + bool output_needed = + (dec->preview_frame ? (dec->events_wanted & JXL_DEC_PREVIEW_IMAGE) + : (dec->events_wanted & JXL_DEC_FULL_IMAGE)); + if (output_needed) { + JXL_API_RETURN_IF_ERROR(dec->frame_dec->InitFrameOutput()); + } + if (dec->cpu_limit_base != 0) { + // No overflow, checked in CheckSizeLimit. + size_t num_pixels = frame_dim.xsize * frame_dim.ysize; + if (dec->used_cpu_base + num_pixels < dec->used_cpu_base) { + return JXL_INPUT_ERROR("image too large"); + } + dec->used_cpu_base += num_pixels; + if (dec->used_cpu_base > dec->cpu_limit_base) { + return JXL_INPUT_ERROR("image too large"); + } + } + dec->remaining_frame_size = dec->frame_dec->SumSectionSizes(); + + dec->frame_stage = FrameStage::kTOC; + if (dec->preview_frame) { + if (!(dec->events_wanted & JXL_DEC_PREVIEW_IMAGE)) { + dec->frame_stage = FrameStage::kHeader; + dec->AdvanceCodestream(dec->remaining_frame_size); + dec->got_preview_image = true; + dec->preview_frame = false; + } + continue; + } + + int saved_as = FrameDecoder::SavedAs(*dec->frame_header); + // is last in entire codestream + dec->is_last_total = dec->frame_header->is_last; + // is last of current still + dec->is_last_of_still = + dec->is_last_total || dec->frame_header->animation_frame.duration > 0; + // is kRegularFrame and coalescing is disabled + dec->is_last_of_still |= + (!dec->coalescing && + dec->frame_header->frame_type == FrameType::kRegularFrame); + const size_t internal_frame_index = dec->internal_frames; + const size_t external_frame_index = dec->external_frames; + if (dec->is_last_of_still) dec->external_frames++; + dec->internal_frames++; + + if (dec->skip_frames > 0) { + dec->skipping_frame = true; + if (dec->is_last_of_still) { + dec->skip_frames--; + } + } else { + dec->skipping_frame = false; + } + + if (external_frame_index >= dec->frame_external_to_internal.size()) { + dec->frame_external_to_internal.push_back(internal_frame_index); + JXL_ASSERT(dec->frame_external_to_internal.size() == + external_frame_index + 1); + } + + if (internal_frame_index >= dec->frame_saved_as.size()) { + dec->frame_saved_as.push_back(saved_as); + JXL_ASSERT(dec->frame_saved_as.size() == internal_frame_index + 1); + + // add the value 0xff (which means all references) to new slots: we only + // know the references of the frame at FinalizeFrame, and fill in the + // correct values there. As long as this information is not known, the + // worst case where the frame depends on all storage slots is assumed. + dec->frame_references.push_back(0xff); + JXL_ASSERT(dec->frame_references.size() == internal_frame_index + 1); + } + + if (dec->skipping_frame) { + // Whether this frame could be referenced by any future frame: either + // because it's a frame saved for blending or patches, or because it's + // a DC frame. + bool referenceable = + dec->frame_header->CanBeReferenced() || + dec->frame_header->frame_type == FrameType::kDCFrame; + if (internal_frame_index < dec->frame_required.size() && + !dec->frame_required[internal_frame_index]) { + referenceable = false; + } + if (!referenceable) { + // Skip all decoding for this frame, since the user is skipping this + // frame and no future frames can reference it. + dec->frame_stage = FrameStage::kHeader; + dec->AdvanceCodestream(dec->remaining_frame_size); + continue; + } + } + + if ((dec->events_wanted & JXL_DEC_FRAME) && dec->is_last_of_still) { + // Only return this for the last of a series of stills: patches frames + // etc... before this one do not contain the correct information such + // as animation timing, ... + if (!dec->skipping_frame) { + return JXL_DEC_FRAME; + } + } + } + + if (dec->frame_stage == FrameStage::kTOC) { + dec->frame_dec->SetRenderSpotcolors(dec->render_spotcolors); + dec->frame_dec->SetCoalescing(dec->coalescing); + + if (!dec->preview_frame && + (dec->events_wanted & JXL_DEC_FRAME_PROGRESSION)) { + dec->frame_prog_detail = + dec->frame_dec->SetPauseAtProgressive(dec->prog_detail); + } else { + dec->frame_prog_detail = JxlProgressiveDetail::kFrames; + } + dec->dc_frame_progression_done = 0; + + dec->next_section = 0; + dec->section_processed.clear(); + dec->section_processed.resize(dec->frame_dec->Toc().size(), 0); + + // If we don't need pixels, we can skip actually decoding the frames. + if (dec->preview_frame || (dec->events_wanted & JXL_DEC_FULL_IMAGE)) { + dec->frame_stage = FrameStage::kFull; + } else if (!dec->is_last_total) { + dec->frame_stage = FrameStage::kHeader; + dec->AdvanceCodestream(dec->remaining_frame_size); + continue; + } else { + break; + } + } + + if (dec->frame_stage == FrameStage::kFull) { + if (!dec->image_out_buffer_set) { + if (dec->preview_frame) { + return JXL_DEC_NEED_PREVIEW_OUT_BUFFER; + } + if ( +#if JPEGXL_ENABLE_TRANSCODE_JPEG + (!dec->jpeg_decoder.IsOutputSet() || + dec->ib->jpeg_data == nullptr) && +#endif + dec->is_last_of_still && !dec->skipping_frame) { + // TODO(lode): remove the dec->is_last_of_still condition if the + // frame decoder needs the image buffer as working space for decoding + // non-visible or blending frames too + return JXL_DEC_NEED_IMAGE_OUT_BUFFER; + } + } + + if (dec->image_out_buffer_set) { + size_t xsize, ysize; + GetCurrentDimensions(dec, xsize, ysize); + size_t bits_per_sample = GetBitDepth( + dec->image_out_bit_depth, dec->metadata.m, dec->image_out_format); + dec->frame_dec->SetImageOutput( + PixelCallback{ + dec->image_out_init_callback, dec->image_out_run_callback, + dec->image_out_destroy_callback, dec->image_out_init_opaque}, + reinterpret_cast<uint8_t*>(dec->image_out_buffer), + dec->image_out_size, xsize, ysize, dec->image_out_format, + bits_per_sample, dec->unpremul_alpha, !dec->keep_orientation); + for (size_t i = 0; i < dec->extra_channel_output.size(); ++i) { + const auto& extra = dec->extra_channel_output[i]; + size_t ec_bits_per_sample = + GetBitDepth(dec->image_out_bit_depth, + dec->metadata.m.extra_channel_info[i], extra.format); + dec->frame_dec->AddExtraChannelOutput(extra.buffer, extra.buffer_size, + xsize, extra.format, + ec_bits_per_sample); + } + } + + size_t next_num_passes_to_pause = dec->frame_dec->NextNumPassesToPause(); + + JXL_API_RETURN_IF_ERROR(JxlDecoderProcessSections(dec)); + + bool all_sections_done = dec->frame_dec->HasDecodedAll(); + bool got_dc_only = !all_sections_done && dec->frame_dec->HasDecodedDC(); + + if (dec->frame_prog_detail >= JxlProgressiveDetail::kDC && + !dec->dc_frame_progression_done && got_dc_only) { + dec->dc_frame_progression_done = true; + dec->downsampling_target = 8; + return JXL_DEC_FRAME_PROGRESSION; + } + + bool new_progression_step_done = + dec->frame_dec->NumCompletePasses() >= next_num_passes_to_pause; + + if (!all_sections_done && + dec->frame_prog_detail >= JxlProgressiveDetail::kLastPasses && + new_progression_step_done) { + dec->downsampling_target = + dec->frame_header->passes.GetDownsamplingTargetForCompletedPasses( + dec->frame_dec->NumCompletePasses()); + return JXL_DEC_FRAME_PROGRESSION; + } + + if (!all_sections_done) { + // Not all sections have been processed yet + return dec->RequestMoreInput(); + } + + if (!dec->preview_frame) { + size_t internal_index = dec->internal_frames - 1; + JXL_ASSERT(dec->frame_references.size() > internal_index); + // Always fill this in, even if it was already written, it could be that + // this frame was skipped before and set to 255, while only now we know + // the true value. + dec->frame_references[internal_index] = dec->frame_dec->References(); + } + + if (!dec->frame_dec->FinalizeFrame()) { + return JXL_INPUT_ERROR("decoding frame failed"); + } +#if JPEGXL_ENABLE_TRANSCODE_JPEG + // If jpeg output was requested, we merely return the JXL_DEC_FULL_IMAGE + // status without outputting pixels. + if (dec->jpeg_decoder.IsOutputSet() && dec->ib->jpeg_data != nullptr) { + dec->frame_stage = FrameStage::kHeader; + dec->recon_output_jpeg = JpegReconStage::kSettingMetadata; + return JXL_DEC_FULL_IMAGE; + } +#endif + if (dec->preview_frame || dec->is_last_of_still) { + dec->image_out_buffer_set = false; + dec->extra_channel_output.clear(); + } + } + + dec->frame_stage = FrameStage::kHeader; + + // The pixels have been output or are not needed, do not keep them in + // memory here. + dec->ib.reset(); + if (dec->preview_frame) { + dec->got_preview_image = true; + dec->preview_frame = false; + dec->events_wanted &= ~JXL_DEC_PREVIEW_IMAGE; + return JXL_DEC_PREVIEW_IMAGE; + } else if (dec->is_last_of_still && + (dec->events_wanted & JXL_DEC_FULL_IMAGE) && + !dec->skipping_frame) { + return JXL_DEC_FULL_IMAGE; + } + } + + dec->stage = DecoderStage::kCodestreamFinished; + // Return success, this means there is nothing more to do. + return JXL_DEC_SUCCESS; +} + +} // namespace +} // namespace jxl + +JxlDecoderStatus JxlDecoderSetInput(JxlDecoder* dec, const uint8_t* data, + size_t size) { + if (dec->next_in) { + return JXL_API_ERROR("already set input, use JxlDecoderReleaseInput first"); + } + if (dec->input_closed) { + return JXL_API_ERROR("input already closed"); + } + + dec->next_in = data; + dec->avail_in = size; + return JXL_DEC_SUCCESS; +} + +size_t JxlDecoderReleaseInput(JxlDecoder* dec) { + size_t result = dec->avail_in; + dec->next_in = nullptr; + dec->avail_in = 0; + return result; +} + +void JxlDecoderCloseInput(JxlDecoder* dec) { dec->input_closed = true; } + +JxlDecoderStatus JxlDecoderSetJPEGBuffer(JxlDecoder* dec, uint8_t* data, + size_t size) { +#if JPEGXL_ENABLE_TRANSCODE_JPEG + // JPEG reconstruction buffer can only set and updated before or during the + // first frame, the reconstruction box refers to the first frame and in + // theory multi-frame images should not be used with a jbrd box. + if (dec->internal_frames > 1) { + return JXL_API_ERROR("JPEG reconstruction only works for the first frame"); + } + if (dec->jpeg_decoder.IsOutputSet()) { + return JXL_API_ERROR("Already set JPEG buffer"); + } + return dec->jpeg_decoder.SetOutputBuffer(data, size); +#else + return JXL_API_ERROR("JPEG reconstruction is not supported."); +#endif +} + +size_t JxlDecoderReleaseJPEGBuffer(JxlDecoder* dec) { +#if JPEGXL_ENABLE_TRANSCODE_JPEG + return dec->jpeg_decoder.ReleaseOutputBuffer(); +#else + return JXL_API_ERROR("JPEG reconstruction is not supported."); +#endif +} + +// Parses the header of the box, outputting the 4-character type and the box +// size, including header size, as stored in the box header. +// @param in current input bytes. +// @param size available input size. +// @param pos position in the input, must begin at the header of the box. +// @param file_pos position of pos since the start of the JXL file, rather than +// the current input, used for integer overflow checking. +// @param type the output box type. +// @param box_size output the total box size, including header, in bytes, or 0 +// if it's a final unbounded box. +// @param header_size output size of the box header. +// @return JXL_DEC_SUCCESS if the box header was fully parsed. In that case the +// parsing position must be incremented by header_size bytes. +// JXL_DEC_NEED_MORE_INPUT if not enough input bytes available, in that case +// header_size indicates a lower bound for the known size the header has to be +// at least. JXL_DEC_ERROR if the box header is invalid. +static JxlDecoderStatus ParseBoxHeader(const uint8_t* in, size_t size, + size_t pos, size_t file_pos, + JxlBoxType type, uint64_t* box_size, + uint64_t* header_size) { + if (OutOfBounds(pos, 8, size)) { + *header_size = 8; + return JXL_DEC_NEED_MORE_INPUT; + } + size_t box_start = pos; + // Box size, including this header itself. + *box_size = LoadBE32(in + pos); + pos += 4; + memcpy(type, in + pos, 4); + pos += 4; + if (*box_size == 1) { + *header_size = 16; + if (OutOfBounds(pos, 8, size)) return JXL_DEC_NEED_MORE_INPUT; + *box_size = LoadBE64(in + pos); + pos += 8; + } + *header_size = pos - box_start; + if (*box_size > 0 && *box_size < *header_size) { + return JXL_INPUT_ERROR("invalid box size"); + } + if (file_pos + *box_size < file_pos) { + return JXL_INPUT_ERROR("Box size overflow"); + } + return JXL_DEC_SUCCESS; +} + +// This includes handling the codestream if it is not a box-based jxl file. +static JxlDecoderStatus HandleBoxes(JxlDecoder* dec) { + // Box handling loop + for (;;) { + if (dec->box_stage != BoxStage::kHeader) { + dec->AdvanceInput(dec->header_size); + dec->header_size = 0; +#if JPEGXL_ENABLE_BOXES + if ((dec->events_wanted & JXL_DEC_BOX) && + dec->box_out_buffer_set_current_box) { + uint8_t* next_out = dec->box_out_buffer + dec->box_out_buffer_pos; + size_t avail_out = dec->box_out_buffer_size - dec->box_out_buffer_pos; + + JxlDecoderStatus box_result = dec->box_content_decoder.Process( + dec->next_in, dec->avail_in, + dec->file_pos - dec->box_contents_begin, &next_out, &avail_out); + size_t produced = + next_out - (dec->box_out_buffer + dec->box_out_buffer_pos); + dec->box_out_buffer_pos += produced; + + // Don't return JXL_DEC_NEED_MORE_INPUT: the box stages below, instead, + // handle the input progression, and the above only outputs the part of + // the box seen so far. + if (box_result != JXL_DEC_SUCCESS && + box_result != JXL_DEC_NEED_MORE_INPUT) { + return box_result; + } + } +#endif +#if JPEGXL_ENABLE_TRANSCODE_JPEG + if (dec->store_exif == 1 || dec->store_xmp == 1) { + std::vector<uint8_t>& metadata = + (dec->store_exif == 1) ? dec->exif_metadata : dec->xmp_metadata; + for (;;) { + if (metadata.empty()) metadata.resize(64); + uint8_t* orig_next_out = metadata.data() + dec->recon_out_buffer_pos; + uint8_t* next_out = orig_next_out; + size_t avail_out = metadata.size() - dec->recon_out_buffer_pos; + JxlDecoderStatus box_result = dec->metadata_decoder.Process( + dec->next_in, dec->avail_in, + dec->file_pos - dec->box_contents_begin, &next_out, &avail_out); + size_t produced = next_out - orig_next_out; + dec->recon_out_buffer_pos += produced; + if (box_result == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + metadata.resize(metadata.size() * 2); + } else if (box_result == JXL_DEC_NEED_MORE_INPUT) { + break; // box stage handling below will handle this instead + } else if (box_result == JXL_DEC_SUCCESS) { + size_t needed_size = (dec->store_exif == 1) ? dec->recon_exif_size + : dec->recon_xmp_size; + if (dec->box_contents_unbounded && + dec->recon_out_buffer_pos < needed_size) { + // Unbounded box, but we know the expected size due to the jbrd + // box's data. Treat this as the JXL_DEC_NEED_MORE_INPUT case. + break; + } else { + metadata.resize(dec->recon_out_buffer_pos); + if (dec->store_exif == 1) dec->store_exif = 2; + if (dec->store_xmp == 1) dec->store_xmp = 2; + break; + } + } else { + // error + return box_result; + } + } + } +#endif + } +#if JPEGXL_ENABLE_TRANSCODE_JPEG + if (dec->recon_output_jpeg == JpegReconStage::kSettingMetadata && + !dec->JbrdNeedMoreBoxes()) { + jxl::jpeg::JPEGData* jpeg_data = dec->ib->jpeg_data.get(); + if (dec->recon_exif_size) { + JxlDecoderStatus status = jxl::JxlToJpegDecoder::SetExif( + dec->exif_metadata.data(), dec->exif_metadata.size(), jpeg_data); + if (status != JXL_DEC_SUCCESS) return status; + } + if (dec->recon_xmp_size) { + JxlDecoderStatus status = jxl::JxlToJpegDecoder::SetXmp( + dec->xmp_metadata.data(), dec->xmp_metadata.size(), jpeg_data); + if (status != JXL_DEC_SUCCESS) return status; + } + dec->recon_output_jpeg = JpegReconStage::kOutputting; + } + + if (dec->recon_output_jpeg == JpegReconStage::kOutputting && + !dec->JbrdNeedMoreBoxes()) { + JxlDecoderStatus status = + dec->jpeg_decoder.WriteOutput(*dec->ib->jpeg_data); + if (status != JXL_DEC_SUCCESS) return status; + dec->recon_output_jpeg = JpegReconStage::kNone; + dec->ib.reset(); + if (dec->events_wanted & JXL_DEC_FULL_IMAGE) { + // Return the full image event here now, this may be delayed if this + // could only be done after decoding an exif or xmp box after the + // codestream. + return JXL_DEC_FULL_IMAGE; + } + } +#endif + + if (dec->box_stage == BoxStage::kHeader) { + if (!dec->have_container) { + if (dec->stage == DecoderStage::kCodestreamFinished) + return JXL_DEC_SUCCESS; + dec->box_stage = BoxStage::kCodestream; + dec->box_contents_unbounded = true; + continue; + } + if (dec->avail_in == 0) { + if (dec->stage != DecoderStage::kCodestreamFinished) { + // Not yet seen (all) codestream boxes. + return JXL_DEC_NEED_MORE_INPUT; + } +#if JPEGXL_ENABLE_TRANSCODE_JPEG + if (dec->JbrdNeedMoreBoxes()) { + return JXL_DEC_NEED_MORE_INPUT; + } +#endif + if (dec->input_closed) { + return JXL_DEC_SUCCESS; + } + if (!(dec->events_wanted & JXL_DEC_BOX)) { + // All codestream and jbrd metadata boxes finished, and no individual + // boxes requested by user, so no need to request any more input. + // This returns success for backwards compatibility, when + // JxlDecoderCloseInput and JXL_DEC_BOX did not exist, as well + // as for efficiency. + return JXL_DEC_SUCCESS; + } + // Even though we are exactly at a box end, there still may be more + // boxes. The user may call JxlDecoderCloseInput to indicate the input + // is finished and get success instead. + return JXL_DEC_NEED_MORE_INPUT; + } + + bool boxed_codestream_done = + ((dec->events_wanted & JXL_DEC_BOX) && + dec->stage == DecoderStage::kCodestreamFinished && +#if JPEGXL_ENABLE_TRANSCODE_JPEG + !dec->JbrdNeedMoreBoxes() && +#endif + dec->last_codestream_seen); + if (boxed_codestream_done && dec->avail_in >= 2 && + dec->next_in[0] == 0xff && + dec->next_in[1] == jxl::kCodestreamMarker) { + // We detected the start of the next naked codestream, so we can return + // success here. + return JXL_DEC_SUCCESS; + } + + uint64_t box_size, header_size; + JxlDecoderStatus status = + ParseBoxHeader(dec->next_in, dec->avail_in, 0, dec->file_pos, + dec->box_type, &box_size, &header_size); + if (status != JXL_DEC_SUCCESS) { + if (status == JXL_DEC_NEED_MORE_INPUT) { + dec->basic_info_size_hint = + InitialBasicInfoSizeHint() + header_size - dec->file_pos; + } + return status; + } + if (memcmp(dec->box_type, "brob", 4) == 0) { + if (dec->avail_in < header_size + 4) { + return JXL_DEC_NEED_MORE_INPUT; + } + memcpy(dec->box_decoded_type, dec->next_in + header_size, + sizeof(dec->box_decoded_type)); + } else { + memcpy(dec->box_decoded_type, dec->box_type, + sizeof(dec->box_decoded_type)); + } + + // Box order validity checks + // The signature box at box_count == 1 is not checked here since that's + // already done at the beginning. + dec->box_count++; + if (boxed_codestream_done && memcmp(dec->box_type, "JXL ", 4) == 0) { + // We detected the start of the next boxed stream, so we can return + // success here. + return JXL_DEC_SUCCESS; + } + if (dec->box_count == 2 && memcmp(dec->box_type, "ftyp", 4) != 0) { + return JXL_INPUT_ERROR("the second box must be the ftyp box"); + } + if (memcmp(dec->box_type, "ftyp", 4) == 0 && dec->box_count != 2) { + return JXL_INPUT_ERROR("the ftyp box must come second"); + } + + dec->box_contents_unbounded = (box_size == 0); + dec->box_contents_begin = dec->file_pos + header_size; + dec->box_contents_end = + dec->box_contents_unbounded ? 0 : (dec->file_pos + box_size); + dec->box_contents_size = + dec->box_contents_unbounded ? 0 : (box_size - header_size); + dec->box_size = box_size; + dec->header_size = header_size; +#if JPEGXL_ENABLE_TRANSCODE_JPEG + if (dec->orig_events_wanted & JXL_DEC_JPEG_RECONSTRUCTION) { + // Initiate storing of Exif or XMP data for JPEG reconstruction + if (dec->store_exif == 0 && + memcmp(dec->box_decoded_type, "Exif", 4) == 0) { + dec->store_exif = 1; + dec->recon_out_buffer_pos = 0; + } + if (dec->store_xmp == 0 && + memcmp(dec->box_decoded_type, "xml ", 4) == 0) { + dec->store_xmp = 1; + dec->recon_out_buffer_pos = 0; + } + } +#endif +#if JPEGXL_ENABLE_BOXES + if (dec->events_wanted & JXL_DEC_BOX) { + bool decompress = + dec->decompress_boxes && memcmp(dec->box_type, "brob", 4) == 0; + dec->box_content_decoder.StartBox( + decompress, dec->box_contents_unbounded, dec->box_contents_size); + } +#endif +#if JPEGXL_ENABLE_TRANSCODE_JPEG + if (dec->store_exif == 1 || dec->store_xmp == 1) { + bool brob = memcmp(dec->box_type, "brob", 4) == 0; + dec->metadata_decoder.StartBox(brob, dec->box_contents_unbounded, + dec->box_contents_size); + } +#endif + if (memcmp(dec->box_type, "ftyp", 4) == 0) { + dec->box_stage = BoxStage::kFtyp; + } else if (memcmp(dec->box_type, "jxlc", 4) == 0) { + if (dec->last_codestream_seen) { + return JXL_INPUT_ERROR("there can only be one jxlc box"); + } + dec->last_codestream_seen = true; + dec->box_stage = BoxStage::kCodestream; + } else if (memcmp(dec->box_type, "jxlp", 4) == 0) { + dec->box_stage = BoxStage::kPartialCodestream; +#if JPEGXL_ENABLE_TRANSCODE_JPEG + } else if ((dec->orig_events_wanted & JXL_DEC_JPEG_RECONSTRUCTION) && + memcmp(dec->box_type, "jbrd", 4) == 0) { + if (!(dec->events_wanted & JXL_DEC_JPEG_RECONSTRUCTION)) { + return JXL_INPUT_ERROR( + "multiple JPEG reconstruction boxes not supported"); + } + dec->box_stage = BoxStage::kJpegRecon; +#endif + } else { + dec->box_stage = BoxStage::kSkip; + } + + if (dec->events_wanted & JXL_DEC_BOX) { + dec->box_event = true; + dec->box_out_buffer_set_current_box = false; + return JXL_DEC_BOX; + } + } else if (dec->box_stage == BoxStage::kFtyp) { + if (dec->box_contents_size < 12) { + return JXL_INPUT_ERROR("file type box too small"); + } + if (dec->avail_in < 4) return JXL_DEC_NEED_MORE_INPUT; + if (memcmp(dec->next_in, "jxl ", 4) != 0) { + return JXL_INPUT_ERROR("file type box major brand must be \"jxl \""); + } + dec->AdvanceInput(4); + dec->box_stage = BoxStage::kSkip; + } else if (dec->box_stage == BoxStage::kPartialCodestream) { + if (dec->last_codestream_seen) { + return JXL_INPUT_ERROR("cannot have jxlp box after last jxlp box"); + } + // TODO(lode): error if box is unbounded but last bit not set + if (dec->avail_in < 4) return JXL_DEC_NEED_MORE_INPUT; + if (!dec->box_contents_unbounded && dec->box_contents_size < 4) { + return JXL_INPUT_ERROR("jxlp box too small to contain index"); + } + size_t jxlp_index = LoadBE32(dec->next_in); + // The high bit of jxlp_index indicates whether this is the last + // jxlp box. + if (jxlp_index & 0x80000000) { + dec->last_codestream_seen = true; + } + dec->AdvanceInput(4); + dec->box_stage = BoxStage::kCodestream; + } else if (dec->box_stage == BoxStage::kCodestream) { + JxlDecoderStatus status = jxl::JxlDecoderProcessCodestream(dec); +#if JPEGXL_ENABLE_TRANSCODE_JPEG + if (status == JXL_DEC_FULL_IMAGE) { + if (dec->recon_output_jpeg != JpegReconStage::kNone) { + continue; + } + } +#endif + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (dec->file_pos == dec->box_contents_end && + !dec->box_contents_unbounded) { + dec->box_stage = BoxStage::kHeader; + continue; + } + } + + if (status == JXL_DEC_SUCCESS) { +#if JPEGXL_ENABLE_TRANSCODE_JPEG + if (dec->JbrdNeedMoreBoxes()) { + dec->box_stage = BoxStage::kSkip; + continue; + } +#endif + if (dec->box_contents_unbounded) { + // Last box reached and codestream done, nothing more to do. + break; + } + if (dec->events_wanted & JXL_DEC_BOX) { + // Codestream done, but there may be more other boxes. + dec->box_stage = BoxStage::kSkip; + continue; + } + } + return status; +#if JPEGXL_ENABLE_TRANSCODE_JPEG + } else if (dec->box_stage == BoxStage::kJpegRecon) { + if (!dec->jpeg_decoder.IsParsingBox()) { + // This is a new JPEG reconstruction metadata box. + dec->jpeg_decoder.StartBox(dec->box_contents_unbounded, + dec->box_contents_size); + } + const uint8_t* next_in = dec->next_in; + size_t avail_in = dec->avail_in; + JxlDecoderStatus recon_result = + dec->jpeg_decoder.Process(&next_in, &avail_in); + size_t consumed = next_in - dec->next_in; + dec->AdvanceInput(consumed); + if (recon_result == JXL_DEC_JPEG_RECONSTRUCTION) { + jxl::jpeg::JPEGData* jpeg_data = dec->jpeg_decoder.GetJpegData(); + size_t num_exif = jxl::JxlToJpegDecoder::NumExifMarkers(*jpeg_data); + size_t num_xmp = jxl::JxlToJpegDecoder::NumXmpMarkers(*jpeg_data); + if (num_exif) { + if (num_exif > 1) { + return JXL_INPUT_ERROR( + "multiple exif markers for JPEG reconstruction not supported"); + } + if (JXL_DEC_SUCCESS != jxl::JxlToJpegDecoder::ExifBoxContentSize( + *jpeg_data, &dec->recon_exif_size)) { + return JXL_INPUT_ERROR("invalid jbrd exif size"); + } + } + if (num_xmp) { + if (num_xmp > 1) { + return JXL_INPUT_ERROR( + "multiple XMP markers for JPEG reconstruction not supported"); + } + if (JXL_DEC_SUCCESS != jxl::JxlToJpegDecoder::XmlBoxContentSize( + *jpeg_data, &dec->recon_xmp_size)) { + return JXL_INPUT_ERROR("invalid jbrd XMP size"); + } + } + + dec->box_stage = BoxStage::kHeader; + // If successful JPEG reconstruction, return the success if the user + // cares about it, otherwise continue. + if (dec->events_wanted & JXL_DEC_JPEG_RECONSTRUCTION) { + dec->events_wanted &= ~JXL_DEC_JPEG_RECONSTRUCTION; + return JXL_DEC_JPEG_RECONSTRUCTION; + } + } else { + // If anything else, return the result. + return recon_result; + } +#endif + } else if (dec->box_stage == BoxStage::kSkip) { + if (dec->box_contents_unbounded) { + if (dec->input_closed) { + return JXL_DEC_SUCCESS; + } + if (!(dec->box_out_buffer_set)) { + // An unbounded box is always the last box. Not requesting box data, + // so return success even if JxlDecoderCloseInput was not called for + // backwards compatibility as well as efficiency since this box is + // being skipped. + return JXL_DEC_SUCCESS; + } + // Arbitrarily more bytes may follow, only JxlDecoderCloseInput can + // mark the end. + dec->AdvanceInput(dec->avail_in); + return JXL_DEC_NEED_MORE_INPUT; + } + // Amount of remaining bytes in the box that is being skipped. + size_t remaining = dec->box_contents_end - dec->file_pos; + if (dec->avail_in < remaining) { + // Indicate how many more bytes needed starting from next_in. + dec->basic_info_size_hint = + InitialBasicInfoSizeHint() + dec->box_contents_end - dec->file_pos; + // Don't have the full box yet, skip all we have so far + dec->AdvanceInput(dec->avail_in); + return JXL_DEC_NEED_MORE_INPUT; + } else { + // Full box available, skip all its remaining bytes + dec->AdvanceInput(remaining); + dec->box_stage = BoxStage::kHeader; + } + } else { + JXL_DASSERT(false); // unknown box stage + } + } + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderProcessInput(JxlDecoder* dec) { + if (dec->stage == DecoderStage::kInited) { + dec->stage = DecoderStage::kStarted; + } + if (dec->stage == DecoderStage::kError) { + return JXL_API_ERROR( + "Cannot keep using decoder after it encountered an error, use " + "JxlDecoderReset to reset it"); + } + + if (!dec->got_signature) { + JxlSignature sig = JxlSignatureCheck(dec->next_in, dec->avail_in); + if (sig == JXL_SIG_INVALID) return JXL_INPUT_ERROR("invalid signature"); + if (sig == JXL_SIG_NOT_ENOUGH_BYTES) { + if (dec->input_closed) { + return JXL_INPUT_ERROR("file too small for signature"); + } + return JXL_DEC_NEED_MORE_INPUT; + } + + dec->got_signature = true; + + if (sig == JXL_SIG_CONTAINER) { + dec->have_container = 1; + } else { + dec->last_codestream_seen = true; + } + } + + JxlDecoderStatus status = HandleBoxes(dec); + + if (status == JXL_DEC_NEED_MORE_INPUT && dec->input_closed) { + return JXL_INPUT_ERROR("premature end of input"); + } + + // Even if the box handling returns success, certain types of + // data may be missing. + if (status == JXL_DEC_SUCCESS) { + if (dec->CanUseMoreCodestreamInput()) { + return JXL_INPUT_ERROR("codestream never finished"); + } +#if JPEGXL_ENABLE_TRANSCODE_JPEG + if (dec->JbrdNeedMoreBoxes()) { + return JXL_INPUT_ERROR("missing metadata boxes for jpeg reconstruction"); + } +#endif + } + + return status; +} + +// To ensure ABI forward-compatibility, this struct has a constant size. +static_assert(sizeof(JxlBasicInfo) == 204, + "JxlBasicInfo struct size should remain constant"); + +JxlDecoderStatus JxlDecoderGetBasicInfo(const JxlDecoder* dec, + JxlBasicInfo* info) { + if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT; + + if (info) { + memset(info, 0, sizeof(*info)); + + const jxl::ImageMetadata& meta = dec->metadata.m; + + info->have_container = dec->have_container; + info->xsize = dec->metadata.size.xsize(); + info->ysize = dec->metadata.size.ysize(); + info->uses_original_profile = !meta.xyb_encoded; + + info->bits_per_sample = meta.bit_depth.bits_per_sample; + info->exponent_bits_per_sample = meta.bit_depth.exponent_bits_per_sample; + + info->have_preview = meta.have_preview; + info->have_animation = meta.have_animation; + info->orientation = static_cast<JxlOrientation>(meta.orientation); + + if (!dec->keep_orientation) { + if (info->orientation >= JXL_ORIENT_TRANSPOSE) { + std::swap(info->xsize, info->ysize); + } + info->orientation = JXL_ORIENT_IDENTITY; + } + + info->intensity_target = meta.IntensityTarget(); + if (dec->desired_intensity_target > 0) { + info->intensity_target = dec->desired_intensity_target; + } + info->min_nits = meta.tone_mapping.min_nits; + info->relative_to_max_display = meta.tone_mapping.relative_to_max_display; + info->linear_below = meta.tone_mapping.linear_below; + + const jxl::ExtraChannelInfo* alpha = meta.Find(jxl::ExtraChannel::kAlpha); + if (alpha != nullptr) { + info->alpha_bits = alpha->bit_depth.bits_per_sample; + info->alpha_exponent_bits = alpha->bit_depth.exponent_bits_per_sample; + info->alpha_premultiplied = alpha->alpha_associated; + } else { + info->alpha_bits = 0; + info->alpha_exponent_bits = 0; + info->alpha_premultiplied = 0; + } + + info->num_color_channels = + meta.color_encoding.GetColorSpace() == jxl::ColorSpace::kGray ? 1 : 3; + + info->num_extra_channels = meta.num_extra_channels; + + if (info->have_preview) { + info->preview.xsize = dec->metadata.m.preview_size.xsize(); + info->preview.ysize = dec->metadata.m.preview_size.ysize(); + } + + if (info->have_animation) { + info->animation.tps_numerator = dec->metadata.m.animation.tps_numerator; + info->animation.tps_denominator = + dec->metadata.m.animation.tps_denominator; + info->animation.num_loops = dec->metadata.m.animation.num_loops; + info->animation.have_timecodes = dec->metadata.m.animation.have_timecodes; + } + + if (meta.have_intrinsic_size) { + info->intrinsic_xsize = dec->metadata.m.intrinsic_size.xsize(); + info->intrinsic_ysize = dec->metadata.m.intrinsic_size.ysize(); + } else { + info->intrinsic_xsize = info->xsize; + info->intrinsic_ysize = info->ysize; + } + } + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetExtraChannelInfo(const JxlDecoder* dec, + size_t index, + JxlExtraChannelInfo* info) { + if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT; + + const std::vector<jxl::ExtraChannelInfo>& channels = + dec->metadata.m.extra_channel_info; + + if (index >= channels.size()) return JXL_DEC_ERROR; // out of bounds + const jxl::ExtraChannelInfo& channel = channels[index]; + + info->type = static_cast<JxlExtraChannelType>(channel.type); + info->bits_per_sample = channel.bit_depth.bits_per_sample; + info->exponent_bits_per_sample = + channel.bit_depth.floating_point_sample + ? channel.bit_depth.exponent_bits_per_sample + : 0; + info->dim_shift = channel.dim_shift; + info->name_length = channel.name.size(); + info->alpha_premultiplied = channel.alpha_associated; + info->spot_color[0] = channel.spot_color[0]; + info->spot_color[1] = channel.spot_color[1]; + info->spot_color[2] = channel.spot_color[2]; + info->spot_color[3] = channel.spot_color[3]; + info->cfa_channel = channel.cfa_channel; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetExtraChannelName(const JxlDecoder* dec, + size_t index, char* name, + size_t size) { + if (!dec->got_basic_info) return JXL_DEC_NEED_MORE_INPUT; + + const std::vector<jxl::ExtraChannelInfo>& channels = + dec->metadata.m.extra_channel_info; + + if (index >= channels.size()) return JXL_DEC_ERROR; // out of bounds + const jxl::ExtraChannelInfo& channel = channels[index]; + + // Also need null-termination character + if (channel.name.size() + 1 > size) return JXL_DEC_ERROR; + + memcpy(name, channel.name.c_str(), channel.name.size() + 1); + + return JXL_DEC_SUCCESS; +} + +namespace { + +// Gets the jxl::ColorEncoding for the desired target, and checks errors. +// Returns the object regardless of whether the actual color space is in ICC, +// but ensures that if the color encoding is not the encoding from the +// codestream header metadata, it cannot require ICC profile. +JxlDecoderStatus GetColorEncodingForTarget( + const JxlDecoder* dec, JxlColorProfileTarget target, + const jxl::ColorEncoding** encoding) { + if (!dec->got_all_headers) return JXL_DEC_NEED_MORE_INPUT; + *encoding = nullptr; + if (target == JXL_COLOR_PROFILE_TARGET_DATA && dec->metadata.m.xyb_encoded) { + *encoding = &dec->passes_state->output_encoding_info.color_encoding; + } else { + *encoding = &dec->metadata.m.color_encoding; + } + return JXL_DEC_SUCCESS; +} +} // namespace + +JxlDecoderStatus JxlDecoderGetColorAsEncodedProfile( + const JxlDecoder* dec, JxlColorProfileTarget target, + JxlColorEncoding* color_encoding) { + const jxl::ColorEncoding* jxl_color_encoding = nullptr; + JxlDecoderStatus status = + GetColorEncodingForTarget(dec, target, &jxl_color_encoding); + if (status) return status; + + if (jxl_color_encoding->WantICC()) + return JXL_DEC_ERROR; // Indicate no encoded profile available. + + if (color_encoding) { + *color_encoding = jxl_color_encoding->ToExternal(); + } + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetICCProfileSize(const JxlDecoder* dec, + JxlColorProfileTarget target, + size_t* size) { + const jxl::ColorEncoding* jxl_color_encoding = nullptr; + JxlDecoderStatus status = + GetColorEncodingForTarget(dec, target, &jxl_color_encoding); + if (status != JXL_DEC_SUCCESS) return status; + + if (jxl_color_encoding->WantICC()) { + jxl::ColorSpace color_space = + dec->metadata.m.color_encoding.GetColorSpace(); + if (color_space == jxl::ColorSpace::kUnknown || + color_space == jxl::ColorSpace::kXYB) { + // This indicates there's no ICC profile available + // TODO(lode): for the XYB case, do we want to craft an ICC profile that + // represents XYB as an RGB profile? It may be possible, but not with + // only 1D transfer functions. + return JXL_DEC_ERROR; + } + } + + if (size) { + *size = jxl_color_encoding->ICC().size(); + } + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetColorAsICCProfile(const JxlDecoder* dec, + JxlColorProfileTarget target, + uint8_t* icc_profile, + size_t size) { + size_t wanted_size; + // This also checks the NEED_MORE_INPUT and the unknown/xyb cases + JxlDecoderStatus status = + JxlDecoderGetICCProfileSize(dec, target, &wanted_size); + if (status != JXL_DEC_SUCCESS) return status; + if (size < wanted_size) return JXL_API_ERROR("ICC profile output too small"); + + const jxl::ColorEncoding* jxl_color_encoding = nullptr; + status = GetColorEncodingForTarget(dec, target, &jxl_color_encoding); + if (status != JXL_DEC_SUCCESS) return status; + + memcpy(icc_profile, jxl_color_encoding->ICC().data(), + jxl_color_encoding->ICC().size()); + + return JXL_DEC_SUCCESS; +} + +namespace { + +// Returns the amount of bits needed for getting memory buffer size, and does +// all error checking required for size checking and format validity. +JxlDecoderStatus PrepareSizeCheck(const JxlDecoder* dec, + const JxlPixelFormat* format, size_t* bits) { + if (!dec->got_basic_info) { + // Don't know image dimensions yet, cannot check for valid size. + return JXL_DEC_NEED_MORE_INPUT; + } + if (!dec->coalescing && + (!dec->frame_header || dec->frame_stage == FrameStage::kHeader)) { + return JXL_API_ERROR("Don't know frame dimensions yet"); + } + if (format->num_channels > 4) { + return JXL_API_ERROR("More than 4 channels not supported"); + } + + *bits = BitsPerChannel(format->data_type); + + if (*bits == 0) { + return JXL_API_ERROR("Invalid/unsupported data type"); + } + + return JXL_DEC_SUCCESS; +} + +} // namespace + +size_t JxlDecoderGetIntendedDownsamplingRatio(JxlDecoder* dec) { + return dec->downsampling_target; +} + +JxlDecoderStatus JxlDecoderFlushImage(JxlDecoder* dec) { + if (!dec->image_out_buffer_set) return JXL_DEC_ERROR; + if (dec->frame_stage != FrameStage::kFull) { + return JXL_DEC_ERROR; + } + JXL_DASSERT(dec->frame_dec); + if (!dec->frame_dec->HasDecodedDC()) { + // FrameDecoder::Flush currently requires DC to have been decoded already + // to work correctly. + return JXL_DEC_ERROR; + } + + if (!dec->frame_dec->Flush()) { + return JXL_DEC_ERROR; + } + + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus JxlDecoderSetCms(JxlDecoder* dec, + const JxlCmsInterface cms) { + if (!dec->passes_state) { + dec->passes_state.reset(new jxl::PassesDecoderState()); + } + dec->passes_state->output_encoding_info.color_management_system = cms; + dec->passes_state->output_encoding_info.cms_set = true; + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus JxlDecoderPreviewOutBufferSize( + const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size) { + size_t bits; + JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits); + if (status != JXL_DEC_SUCCESS) return status; + if (format->num_channels < 3 && + !dec->image_metadata.color_encoding.IsGray()) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + + size_t xsize = dec->metadata.oriented_preview_xsize(dec->keep_orientation); + size_t ysize = dec->metadata.oriented_preview_ysize(dec->keep_orientation); + + size_t row_size = + jxl::DivCeil(xsize * format->num_channels * bits, jxl::kBitsPerByte); + size_t last_row_size = row_size; + if (format->align > 1) { + row_size = jxl::DivCeil(row_size, format->align) * format->align; + } + *size = row_size * (ysize - 1) + last_row_size; + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus JxlDecoderSetPreviewOutBuffer( + JxlDecoder* dec, const JxlPixelFormat* format, void* buffer, size_t size) { + if (!dec->got_basic_info || !dec->metadata.m.have_preview || + !(dec->orig_events_wanted & JXL_DEC_PREVIEW_IMAGE)) { + return JXL_API_ERROR("No preview out buffer needed at this time"); + } + if (format->num_channels < 3 && + !dec->image_metadata.color_encoding.IsGray()) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + + size_t min_size; + // This also checks whether the format is valid and supported and basic info + // is available. + JxlDecoderStatus status = + JxlDecoderPreviewOutBufferSize(dec, format, &min_size); + if (status != JXL_DEC_SUCCESS) return status; + + if (size < min_size) return JXL_DEC_ERROR; + + dec->image_out_buffer_set = true; + dec->image_out_buffer = buffer; + dec->image_out_size = size; + dec->image_out_format = *format; + + return JXL_DEC_SUCCESS; +} + +JXL_EXPORT JxlDecoderStatus JxlDecoderImageOutBufferSize( + const JxlDecoder* dec, const JxlPixelFormat* format, size_t* size) { + size_t bits; + JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits); + if (status != JXL_DEC_SUCCESS) return status; + if (format->num_channels < 3 && + !dec->image_metadata.color_encoding.IsGray()) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + size_t xsize, ysize; + GetCurrentDimensions(dec, xsize, ysize); + size_t row_size = + jxl::DivCeil(xsize * format->num_channels * bits, jxl::kBitsPerByte); + if (format->align > 1) { + row_size = jxl::DivCeil(row_size, format->align) * format->align; + } + *size = row_size * ysize; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetImageOutBuffer(JxlDecoder* dec, + const JxlPixelFormat* format, + void* buffer, size_t size) { + if (!dec->got_basic_info || !(dec->orig_events_wanted & JXL_DEC_FULL_IMAGE)) { + return JXL_API_ERROR("No image out buffer needed at this time"); + } + if (dec->image_out_buffer_set && !!dec->image_out_run_callback) { + return JXL_API_ERROR( + "Cannot change from image out callback to image out buffer"); + } + if (format->num_channels < 3 && + !dec->image_metadata.color_encoding.IsGray()) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + size_t min_size; + // This also checks whether the format is valid and supported and basic info + // is available. + JxlDecoderStatus status = + JxlDecoderImageOutBufferSize(dec, format, &min_size); + if (status != JXL_DEC_SUCCESS) return status; + + if (size < min_size) return JXL_DEC_ERROR; + + dec->image_out_buffer_set = true; + dec->image_out_buffer = buffer; + dec->image_out_size = size; + dec->image_out_format = *format; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderExtraChannelBufferSize(const JxlDecoder* dec, + const JxlPixelFormat* format, + size_t* size, + uint32_t index) { + if (!dec->got_basic_info || !(dec->orig_events_wanted & JXL_DEC_FULL_IMAGE)) { + return JXL_API_ERROR("No extra channel buffer needed at this time"); + } + + if (index >= dec->metadata.m.num_extra_channels) { + return JXL_API_ERROR("Invalid extra channel index"); + } + + size_t num_channels = 1; // Do not use format's num_channels + + size_t bits; + JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits); + if (status != JXL_DEC_SUCCESS) return status; + + size_t xsize, ysize; + GetCurrentDimensions(dec, xsize, ysize); + size_t row_size = + jxl::DivCeil(xsize * num_channels * bits, jxl::kBitsPerByte); + if (format->align > 1) { + row_size = jxl::DivCeil(row_size, format->align) * format->align; + } + *size = row_size * ysize; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetExtraChannelBuffer(JxlDecoder* dec, + const JxlPixelFormat* format, + void* buffer, size_t size, + uint32_t index) { + size_t min_size; + // This also checks whether the format and index are valid and supported and + // basic info is available. + JxlDecoderStatus status = + JxlDecoderExtraChannelBufferSize(dec, format, &min_size, index); + if (status != JXL_DEC_SUCCESS) return status; + + if (size < min_size) return JXL_DEC_ERROR; + + if (dec->extra_channel_output.size() <= index) { + dec->extra_channel_output.resize(dec->metadata.m.num_extra_channels, + {{}, nullptr, 0}); + } + // Guaranteed correct thanks to check in JxlDecoderExtraChannelBufferSize. + JXL_ASSERT(index < dec->extra_channel_output.size()); + + dec->extra_channel_output[index].format = *format; + dec->extra_channel_output[index].format.num_channels = 1; + dec->extra_channel_output[index].buffer = buffer; + dec->extra_channel_output[index].buffer_size = size; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetImageOutCallback(JxlDecoder* dec, + const JxlPixelFormat* format, + JxlImageOutCallback callback, + void* opaque) { + dec->simple_image_out_callback.callback = callback; + dec->simple_image_out_callback.opaque = opaque; + const auto init_callback = + +[](void* init_opaque, size_t num_threads, size_t num_pixels_per_thread) { + // No initialization to do, just reuse init_opaque as run_opaque. + return init_opaque; + }; + const auto run_callback = + +[](void* run_opaque, size_t thread_id, size_t x, size_t y, + size_t num_pixels, const void* pixels) { + const auto* const simple_callback = + static_cast<const JxlDecoder::SimpleImageOutCallback*>(run_opaque); + simple_callback->callback(simple_callback->opaque, x, y, num_pixels, + pixels); + }; + const auto destroy_callback = +[](void* run_opaque) {}; + return JxlDecoderSetMultithreadedImageOutCallback( + dec, format, init_callback, run_callback, + /*destroy_callback=*/destroy_callback, &dec->simple_image_out_callback); +} + +JxlDecoderStatus JxlDecoderSetMultithreadedImageOutCallback( + JxlDecoder* dec, const JxlPixelFormat* format, + JxlImageOutInitCallback init_callback, JxlImageOutRunCallback run_callback, + JxlImageOutDestroyCallback destroy_callback, void* init_opaque) { + if (dec->image_out_buffer_set && !!dec->image_out_buffer) { + return JXL_API_ERROR( + "Cannot change from image out buffer to image out callback"); + } + + if (init_callback == nullptr || run_callback == nullptr || + destroy_callback == nullptr) { + return JXL_API_ERROR("All callbacks are required"); + } + + // Perform error checking for invalid format. + size_t bits_sink; + JxlDecoderStatus status = PrepareSizeCheck(dec, format, &bits_sink); + if (status != JXL_DEC_SUCCESS) return status; + + dec->image_out_buffer_set = true; + dec->image_out_init_callback = init_callback; + dec->image_out_run_callback = run_callback; + dec->image_out_destroy_callback = destroy_callback; + dec->image_out_init_opaque = init_opaque; + dec->image_out_format = *format; + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetFrameHeader(const JxlDecoder* dec, + JxlFrameHeader* header) { + if (!dec->frame_header || dec->frame_stage == FrameStage::kHeader) { + return JXL_API_ERROR("no frame header available"); + } + const auto& metadata = dec->metadata.m; + memset(header, 0, sizeof(*header)); + if (metadata.have_animation) { + header->duration = dec->frame_header->animation_frame.duration; + if (metadata.animation.have_timecodes) { + header->timecode = dec->frame_header->animation_frame.timecode; + } + } + header->name_length = dec->frame_header->name.size(); + header->is_last = dec->frame_header->is_last; + size_t xsize, ysize; + GetCurrentDimensions(dec, xsize, ysize); + header->layer_info.xsize = xsize; + header->layer_info.ysize = ysize; + if (!dec->coalescing && dec->frame_header->custom_size_or_origin) { + header->layer_info.crop_x0 = dec->frame_header->frame_origin.x0; + header->layer_info.crop_y0 = dec->frame_header->frame_origin.y0; + header->layer_info.have_crop = JXL_TRUE; + } else { + header->layer_info.crop_x0 = 0; + header->layer_info.crop_y0 = 0; + header->layer_info.have_crop = JXL_FALSE; + } + if (!dec->keep_orientation && !dec->coalescing) { + // orient the crop offset + size_t W = dec->metadata.oriented_xsize(false); + size_t H = dec->metadata.oriented_ysize(false); + if (metadata.orientation > 4) { + std::swap(header->layer_info.crop_x0, header->layer_info.crop_y0); + } + size_t o = (metadata.orientation - 1) & 3; + if (o > 0 && o < 3) { + header->layer_info.crop_x0 = W - xsize - header->layer_info.crop_x0; + } + if (o > 1) { + header->layer_info.crop_y0 = H - ysize - header->layer_info.crop_y0; + } + } + if (dec->coalescing) { + header->layer_info.blend_info.blendmode = JXL_BLEND_REPLACE; + header->layer_info.blend_info.source = 0; + header->layer_info.blend_info.alpha = 0; + header->layer_info.blend_info.clamp = JXL_FALSE; + header->layer_info.save_as_reference = 0; + } else { + header->layer_info.blend_info.blendmode = + static_cast<JxlBlendMode>(dec->frame_header->blending_info.mode); + header->layer_info.blend_info.source = + dec->frame_header->blending_info.source; + header->layer_info.blend_info.alpha = + dec->frame_header->blending_info.alpha_channel; + header->layer_info.blend_info.clamp = + dec->frame_header->blending_info.clamp; + header->layer_info.save_as_reference = dec->frame_header->save_as_reference; + } + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetExtraChannelBlendInfo(const JxlDecoder* dec, + size_t index, + JxlBlendInfo* blend_info) { + if (!dec->frame_header || dec->frame_stage == FrameStage::kHeader) { + return JXL_API_ERROR("no frame header available"); + } + const auto& metadata = dec->metadata.m; + if (index >= metadata.num_extra_channels) { + return JXL_API_ERROR("Invalid extra channel index"); + } + blend_info->blendmode = static_cast<JxlBlendMode>( + dec->frame_header->extra_channel_blending_info[index].mode); + blend_info->source = + dec->frame_header->extra_channel_blending_info[index].source; + blend_info->alpha = + dec->frame_header->extra_channel_blending_info[index].alpha_channel; + blend_info->clamp = + dec->frame_header->extra_channel_blending_info[index].clamp; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetFrameName(const JxlDecoder* dec, char* name, + size_t size) { + if (!dec->frame_header || dec->frame_stage == FrameStage::kHeader) { + return JXL_API_ERROR("no frame header available"); + } + if (size < dec->frame_header->name.size() + 1) { + return JXL_API_ERROR("too small frame name output buffer"); + } + memcpy(name, dec->frame_header->name.c_str(), + dec->frame_header->name.size() + 1); + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetPreferredColorProfile( + JxlDecoder* dec, const JxlColorEncoding* color_encoding) { + return JxlDecoderSetOutputColorProfile(dec, color_encoding, + /*icc_data=*/nullptr, /*icc_size=*/0); +} + +JxlDecoderStatus JxlDecoderSetOutputColorProfile( + JxlDecoder* dec, const JxlColorEncoding* color_encoding, + const uint8_t* icc_data, size_t icc_size) { + if ((color_encoding != nullptr) && (icc_data != nullptr)) { + return JXL_API_ERROR("cannot set both color_encoding and icc_data"); + } + if ((color_encoding == nullptr) && (icc_data == nullptr)) { + return JXL_API_ERROR("one of color_encoding and icc_data must be set"); + } + if (!dec->got_all_headers) { + return JXL_API_ERROR("color info not yet available"); + } + if (dec->post_headers) { + return JXL_API_ERROR("too late to set the color encoding"); + } + if ((!dec->passes_state->output_encoding_info.cms_set) && + (icc_data != nullptr)) { + return JXL_API_ERROR( + "must set color management system via JxlDecoderSetCms"); + } + auto& output_encoding = dec->passes_state->output_encoding_info; + if (color_encoding) { + if (dec->image_metadata.color_encoding.IsGray() && + color_encoding->color_space != JXL_COLOR_SPACE_GRAY && + dec->image_out_buffer_set && dec->image_out_format.num_channels < 3) { + return JXL_API_ERROR("Number of channels is too low for color output"); + } + if (color_encoding->color_space == JXL_COLOR_SPACE_UNKNOWN) { + return JXL_API_ERROR("Unknown output colorspace"); + } + jxl::ColorEncoding c_out; + JXL_API_RETURN_IF_ERROR(c_out.FromExternal(*color_encoding)); + JXL_API_RETURN_IF_ERROR(!c_out.ICC().empty()); + if (!c_out.SameColorEncoding(output_encoding.color_encoding)) { + JXL_API_RETURN_IF_ERROR(output_encoding.MaybeSetColorEncoding(c_out)); + dec->image_metadata.color_encoding = output_encoding.color_encoding; + } + return JXL_DEC_SUCCESS; + } + // icc_data != nullptr + // TODO(firsching): implement setting output color profile from icc_data. + jxl::ColorEncoding c_dst; + std::vector<uint8_t> padded_icc; + padded_icc.assign(icc_data, icc_data + icc_size); + if (!c_dst.SetICC(std::move(padded_icc), + &output_encoding.color_management_system)) { + return JXL_API_ERROR( + "setting output color profile from icc_data not yet implemented."); + } + JXL_API_RETURN_IF_ERROR( + (int)output_encoding.MaybeSetColorEncoding(std::move(c_dst))); + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetDesiredIntensityTarget( + JxlDecoder* dec, float desired_intensity_target) { + if (desired_intensity_target < 0) { + return JXL_API_ERROR("negative intensity target requested"); + } + dec->desired_intensity_target = desired_intensity_target; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetBoxBuffer(JxlDecoder* dec, uint8_t* data, + size_t size) { + if (dec->box_out_buffer_set) { + return JXL_API_ERROR("must release box buffer before setting it again"); + } + if (!dec->box_event) { + return JXL_API_ERROR("can only set box buffer after box event"); + } + + dec->box_out_buffer_set = true; + dec->box_out_buffer_set_current_box = true; + dec->box_out_buffer = data; + dec->box_out_buffer_size = size; + dec->box_out_buffer_pos = 0; + return JXL_DEC_SUCCESS; +} + +size_t JxlDecoderReleaseBoxBuffer(JxlDecoder* dec) { + if (!dec->box_out_buffer_set) { + return 0; + } + size_t result = dec->box_out_buffer_size - dec->box_out_buffer_pos; + dec->box_out_buffer_set = false; + dec->box_out_buffer = nullptr; + dec->box_out_buffer_size = 0; + if (!dec->box_out_buffer_set_current_box) { + dec->box_out_buffer_begin = 0; + } else { + dec->box_out_buffer_begin += dec->box_out_buffer_pos; + } + dec->box_out_buffer_set_current_box = false; + return result; +} + +JxlDecoderStatus JxlDecoderSetDecompressBoxes(JxlDecoder* dec, + JXL_BOOL decompress) { + // TODO(lode): return error if libbrotli is not compiled in the jxl decoding + // library + dec->decompress_boxes = decompress; + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetBoxType(JxlDecoder* dec, JxlBoxType type, + JXL_BOOL decompressed) { + if (!dec->box_event) { + return JXL_API_ERROR("can only get box info after JXL_DEC_BOX event"); + } + if (decompressed) { + memcpy(type, dec->box_decoded_type, sizeof(dec->box_decoded_type)); + } else { + memcpy(type, dec->box_type, sizeof(dec->box_type)); + } + + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderGetBoxSizeRaw(const JxlDecoder* dec, + uint64_t* size) { + if (!dec->box_event) { + return JXL_API_ERROR("can only get box info after JXL_DEC_BOX event"); + } + if (size) { + *size = dec->box_size; + } + return JXL_DEC_SUCCESS; +} + +JxlDecoderStatus JxlDecoderSetProgressiveDetail(JxlDecoder* dec, + JxlProgressiveDetail detail) { + if (detail != kDC && detail != kLastPasses && detail != kPasses) { + return JXL_API_ERROR( + "Values other than kDC (%d), kLastPasses (%d) and kPasses (%d), " + "like %d are not implemented.", + kDC, kLastPasses, kPasses, detail); + } + dec->prog_detail = detail; + return JXL_DEC_SUCCESS; +} + +namespace { + +template <typename T> +JxlDecoderStatus VerifyOutputBitDepth(JxlBitDepth bit_depth, const T& metadata, + JxlPixelFormat format) { + uint32_t bits_per_sample = GetBitDepth(bit_depth, metadata, format); + if (bits_per_sample == 0) return JXL_API_ERROR("Invalid output bit depth"); + if (format.data_type == JXL_TYPE_UINT8 && bits_per_sample > 8) { + return JXL_API_ERROR("Invalid bit depth %u for uint8 output", + bits_per_sample); + } else if (format.data_type == JXL_TYPE_UINT16 && bits_per_sample > 16) { + return JXL_API_ERROR("Invalid bit depth %u for uint16 output", + bits_per_sample); + } + return JXL_DEC_SUCCESS; +} + +} // namespace + +JxlDecoderStatus JxlDecoderSetImageOutBitDepth(JxlDecoder* dec, + const JxlBitDepth* bit_depth) { + if (!dec->image_out_buffer_set) { + return JXL_API_ERROR("No image out buffer was set."); + } + JXL_API_RETURN_IF_ERROR( + VerifyOutputBitDepth(*bit_depth, dec->metadata.m, dec->image_out_format)); + dec->image_out_bit_depth = *bit_depth; + return JXL_DEC_SUCCESS; +} diff --git a/third_party/jpeg-xl/lib/jxl/decode_test.cc b/third_party/jpeg-xl/lib/jxl/decode_test.cc new file mode 100644 index 0000000000..caee6dbc56 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/decode_test.cc @@ -0,0 +1,5615 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/decode.h" + +#include <jxl/cms.h> +#include <jxl/codestream_header.h> +#include <jxl/color_encoding.h> +#include <jxl/decode.h> +#include <jxl/decode_cxx.h> +#include <jxl/memory_manager.h> +#include <jxl/parallel_runner.h> +#include <jxl/resizable_parallel_runner.h> +#include <jxl/resizable_parallel_runner_cxx.h> +#include <jxl/thread_parallel_runner.h> +#include <jxl/thread_parallel_runner_cxx.h> +#include <jxl/types.h> + +#include <algorithm> +#include <cstdint> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <ostream> +#include <set> +#include <sstream> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "lib/extras/dec/color_description.h" +#include "lib/extras/enc/encode.h" +#include "lib/extras/enc/jpg.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/cms/color_encoding_cms.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/enc_icc_codec.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_progressive_split.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/padded_bytes.h" +#include "lib/jxl/test_image.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" +#include "lib/jxl/toc.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace { +void AppendU32BE(uint32_t u32, std::vector<uint8_t>* bytes) { + bytes->push_back(u32 >> 24); + bytes->push_back(u32 >> 16); + bytes->push_back(u32 >> 8); + bytes->push_back(u32 >> 0); +} + +// What type of codestream format in the boxes to use for testing +enum CodeStreamBoxFormat { + // Do not use box format at all, only pure codestream + kCSBF_None, + // Have a single codestream box, with its actual size given in the box + kCSBF_Single, + // Have a single codestream box, with box size 0 (final box running to end) + kCSBF_Single_Zero_Terminated, + // Single codestream box, with another unknown box behind it + kCSBF_Single_Other, + // Have multiple partial codestream boxes + kCSBF_Multi, + // Have multiple partial codestream boxes, with final box size 0 (running + // to end) + kCSBF_Multi_Zero_Terminated, + // Have multiple partial codestream boxes, terminated by non-codestream box + kCSBF_Multi_Other_Terminated, + // Have multiple partial codestream boxes, terminated by non-codestream box + // that has its size set to 0 (running to end) + kCSBF_Multi_Other_Zero_Terminated, + // Have multiple partial codestream boxes, and the first one has a content + // of zero length + kCSBF_Multi_First_Empty, + // Have multiple partial codestream boxes, and the last one has a content + // of zero length and there is an unknown empty box at the end + kCSBF_Multi_Last_Empty_Other, + // Have a compressed exif box before a regular codestream box + kCSBF_Brob_Exif, + // Not a value but used for counting amount of enum entries + kCSBF_NUM_ENTRIES, +}; + +// Unknown boxes for testing +static const char* unk1_box_type = "unk1"; +static const char* unk1_box_contents = "abcdefghijklmnopqrstuvwxyz"; +static const size_t unk1_box_size = strlen(unk1_box_contents); +static const char* unk2_box_type = "unk2"; +static const char* unk2_box_contents = "0123456789"; +static const size_t unk2_box_size = strlen(unk2_box_contents); +static const char* unk3_box_type = "unk3"; +static const char* unk3_box_contents = "ABCDEF123456"; +static const size_t unk3_box_size = strlen(unk3_box_contents); +// Box with brob-compressed exif, including header +static const uint8_t* box_brob_exif = reinterpret_cast<const uint8_t*>( + "\0\0\0@brobExif\241\350\2\300\177\244v\2525\304\360\27=?\267{" + "\33\37\314\332\214QX17PT\"\256\0\0\202s\214\313t\333\310\320k\20\276\30" + "\204\277l$\326c#\1\b"); +size_t box_brob_exif_size = 64; +// The uncompressed Exif data from the brob box +static const uint8_t* exif_uncompressed = reinterpret_cast<const uint8_t*>( + "\0\0\0\0MM\0*" + "\0\0\0\b\0\5\1\22\0\3\0\0\0\1\0\5\0\0\1\32\0\5\0\0\0\1\0\0\0J\1\33\0\5\0\0" + "\0\1\0\0\0R\1(" + "\0\3\0\0\0\1\0\1\0\0\2\23\0\3\0\0\0\1\0\1\0\0\0\0\0\0\0\0\0\1\0\0\0\1\0\0" + "\0\1\0\0\0\1"); +size_t exif_uncompressed_size = 94; + +// Returns an ICC profile output by the JPEG XL decoder for RGB_D65_SRG_Rel_Lin, +// but with, on purpose, rXYZ, bXYZ and gXYZ (the RGB primaries) switched to a +// different order to ensure the profile does not match any known profile, so +// the encoder cannot encode it in a compact struct instead. +jxl::IccBytes GetIccTestProfile() { + const uint8_t* profile = reinterpret_cast<const uint8_t*>( + "\0\0\3\200lcms\0040\0\0mntrRGB XYZ " + "\a\344\0\a\0\27\0\21\0$" + "\0\37acspAPPL\0\0\0\1\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\1\0\0\366" + "\326\0\1\0\0\0\0\323-lcms\372c\207\36\227\200{" + "\2\232s\255\327\340\0\n\26\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0\rdesc\0\0\1 " + "\0\0\0Bcprt\0\0\1d\0\0\1\0wtpt\0\0\2d\0\0\0\24chad\0\0\2x\0\0\0," + "bXYZ\0\0\2\244\0\0\0\24gXYZ\0\0\2\270\0\0\0\24rXYZ\0\0\2\314\0\0\0\24rTR" + "C\0\0\2\340\0\0\0 gTRC\0\0\2\340\0\0\0 bTRC\0\0\2\340\0\0\0 " + "chrm\0\0\3\0\0\0\0$dmnd\0\0\3$\0\0\0(" + "dmdd\0\0\3L\0\0\0002mluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0&" + "\0\0\0\34\0R\0G\0B\0_\0D\0006\0005\0_\0S\0R\0G\0_\0R\0e\0l\0_" + "\0L\0i\0n\0\0mluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0\344\0\0\0\34\0C\0o\0" + "p\0y\0r\0i\0g\0h\0t\0 \0002\0000\0001\08\0 \0G\0o\0o\0g\0l\0e\0 " + "\0L\0L\0C\0,\0 \0C\0C\0-\0B\0Y\0-\0S\0A\0 \0003\0.\0000\0 " + "\0U\0n\0p\0o\0r\0t\0e\0d\0 " + "\0l\0i\0c\0e\0n\0s\0e\0(\0h\0t\0t\0p\0s\0:\0/\0/" + "\0c\0r\0e\0a\0t\0i\0v\0e\0c\0o\0m\0m\0o\0n\0s\0.\0o\0r\0g\0/" + "\0l\0i\0c\0e\0n\0s\0e\0s\0/\0b\0y\0-\0s\0a\0/\0003\0.\0000\0/" + "\0l\0e\0g\0a\0l\0c\0o\0d\0e\0)XYZ " + "\0\0\0\0\0\0\366\326\0\1\0\0\0\0\323-" + "sf32\0\0\0\0\0\1\fB\0\0\5\336\377\377\363%" + "\0\0\a\223\0\0\375\220\377\377\373\241\377\377\375\242\0\0\3\334\0\0\300" + "nXYZ \0\0\0\0\0\0o\240\0\08\365\0\0\3\220XYZ " + "\0\0\0\0\0\0$\237\0\0\17\204\0\0\266\304XYZ " + "\0\0\0\0\0\0b\227\0\0\267\207\0\0\30\331para\0\0\0\0\0\3\0\0\0\1\0\0\0\1" + "\0\0\0\0\0\0\0\1\0\0\0\0\0\0chrm\0\0\0\0\0\3\0\0\0\0\243\327\0\0T|" + "\0\0L\315\0\0\231\232\0\0&" + "g\0\0\17\\mluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0\f\0\0\0\34\0G\0o\0o\0g" + "\0l\0emluc\0\0\0\0\0\0\0\1\0\0\0\fenUS\0\0\0\26\0\0\0\34\0I\0m\0a\0g\0e" + "\0 \0c\0o\0d\0e\0c\0\0"); + size_t profile_size = 896; + jxl::IccBytes icc_profile; + icc_profile.assign(profile, profile + profile_size); + return icc_profile; +} + +} // namespace + +namespace jxl { +namespace { + +void AppendTestBox(const char* type, const char* contents, size_t contents_size, + bool unbounded, std::vector<uint8_t>* bytes) { + AppendU32BE(contents_size + 8, bytes); + bytes->push_back(type[0]); + bytes->push_back(type[1]); + bytes->push_back(type[2]); + bytes->push_back(type[3]); + const uint8_t* contents_u = reinterpret_cast<const uint8_t*>(contents); + Bytes(contents_u, contents_size).AppendTo(bytes); +} + +enum PreviewMode { + kNoPreview, + kSmallPreview, + kBigPreview, + kNumPreviewModes, +}; + +void GeneratePreview(PreviewMode preview_mode, ImageBundle* ib) { + if (preview_mode == kSmallPreview) { + ib->ShrinkTo(ib->xsize() / 7, ib->ysize() / 7); + } else if (preview_mode == kBigPreview) { + auto upsample7 = [&](const ImageF& in, ImageF* out) { + for (size_t y = 0; y < out->ysize(); ++y) { + for (size_t x = 0; x < out->xsize(); ++x) { + out->Row(y)[x] = in.ConstRow(y / 7)[x / 7]; + } + } + }; + Image3F preview(ib->xsize() * 7, ib->ysize() * 7); + for (size_t c = 0; c < 3; ++c) { + upsample7(ib->color()->Plane(c), &preview.Plane(c)); + } + std::vector<ImageF> extra_channels; + for (size_t i = 0; i < ib->extra_channels().size(); ++i) { + ImageF ec(ib->xsize() * 7, ib->ysize() * 7); + upsample7(ib->extra_channels()[i], &ec); + extra_channels.emplace_back(std::move(ec)); + } + ib->RemoveColor(); + ib->ClearExtraChannels(); + ib->SetFromImage(std::move(preview), ib->c_current()); + ib->SetExtraChannels(std::move(extra_channels)); + } +} + +struct TestCodestreamParams { + CompressParams cparams; + CodeStreamBoxFormat box_format = kCSBF_None; + JxlOrientation orientation = JXL_ORIENT_IDENTITY; + PreviewMode preview_mode = kNoPreview; + bool add_intrinsic_size = false; + bool add_icc_profile = false; + float intensity_target = 0.0; + std::string color_space; + std::vector<uint8_t>* jpeg_codestream = nullptr; +}; + +// Input pixels always given as 16-bit RGBA, 8 bytes per pixel. +// include_alpha determines if the encoded image should contain the alpha +// channel. +// add_icc_profile: if false, encodes the image as sRGB using the JXL fields, +// for grayscale or RGB images. If true, encodes the image using the ICC profile +// returned by GetIccTestProfile, without the JXL fields, this requires the +// image is RGB, not grayscale. +// Providing jpeg_codestream will populate the jpeg_codestream with compressed +// JPEG bytes, and make it possible to reconstruct those exact JPEG bytes using +// the return value _if_ add_container indicates a box format. +std::vector<uint8_t> CreateTestJXLCodestream( + Span<const uint8_t> pixels, size_t xsize, size_t ysize, size_t num_channels, + const TestCodestreamParams& params) { + // Compress the pixels with JPEG XL. + bool grayscale = (num_channels <= 2); + bool include_alpha = !(num_channels & 1) && params.jpeg_codestream == nullptr; + size_t bitdepth = params.jpeg_codestream == nullptr ? 16 : 8; + CodecInOut io; + io.SetSize(xsize, ysize); + ColorEncoding color_encoding; + if (params.add_icc_profile) { + // the hardcoded ICC profile we attach requires RGB. + EXPECT_EQ(false, grayscale); + EXPECT_TRUE(params.color_space.empty()); + EXPECT_TRUE(color_encoding.SetICC(GetIccTestProfile(), JxlGetDefaultCms())); + } else if (!params.color_space.empty()) { + JxlColorEncoding c; + EXPECT_TRUE(jxl::ParseDescription(params.color_space, &c)); + EXPECT_TRUE(color_encoding.FromExternal(c)); + EXPECT_EQ(color_encoding.IsGray(), grayscale); + } else { + color_encoding = jxl::ColorEncoding::SRGB(/*is_gray=*/grayscale); + } + io.metadata.m.SetUintSamples(bitdepth); + if (include_alpha) { + io.metadata.m.SetAlphaBits(bitdepth); + } + if (params.intensity_target != 0) { + io.metadata.m.SetIntensityTarget(params.intensity_target); + } + JxlPixelFormat format = {static_cast<uint32_t>(num_channels), JXL_TYPE_UINT16, + JXL_BIG_ENDIAN, 0}; + // Make the grayscale-ness of the io metadata color_encoding and the packed + // image match. + io.metadata.m.color_encoding = color_encoding; + EXPECT_TRUE(ConvertFromExternal(pixels, xsize, ysize, color_encoding, + /*bits_per_sample=*/16, format, + /* pool */ nullptr, &io.Main())); + std::vector<uint8_t> jpeg_data; + if (params.jpeg_codestream != nullptr) { + if (jxl::extras::CanDecode(jxl::extras::Codec::kJPG)) { + std::vector<uint8_t> jpeg_bytes; + extras::PackedPixelFile ppf; + extras::PackedFrame frame(xsize, ysize, format); + JXL_ASSERT(frame.color.pixels_size == pixels.size()); + memcpy(frame.color.pixels(0, 0, 0), pixels.data(), pixels.size()); + ppf.frames.emplace_back(std::move(frame)); + ppf.info.xsize = xsize; + ppf.info.ysize = ysize; + ppf.info.num_color_channels = grayscale ? 1 : 3; + ppf.info.bits_per_sample = 16; + auto encoder = extras::GetJPEGEncoder(); + encoder->SetOption("quality", "70"); + extras::EncodedImage encoded; + EXPECT_TRUE(encoder->Encode(ppf, &encoded)); + jpeg_bytes = encoded.bitstreams[0]; + Bytes(jpeg_bytes).AppendTo(params.jpeg_codestream); + EXPECT_TRUE(jxl::jpeg::DecodeImageJPG( + jxl::Bytes(jpeg_bytes.data(), jpeg_bytes.size()), &io)); + EXPECT_TRUE( + EncodeJPEGData(*io.Main().jpeg_data, &jpeg_data, params.cparams)); + io.metadata.m.xyb_encoded = false; + } else { + JXL_ABORT( + "unable to create reconstructible JPEG without JPEG support enabled"); + } + } + if (params.preview_mode) { + io.preview_frame = io.Main().Copy(); + GeneratePreview(params.preview_mode, &io.preview_frame); + io.metadata.m.have_preview = true; + EXPECT_TRUE(io.metadata.m.preview_size.Set(io.preview_frame.xsize(), + io.preview_frame.ysize())); + } + if (params.add_intrinsic_size) { + EXPECT_TRUE(io.metadata.m.intrinsic_size.Set(xsize / 3, ysize / 3)); + } + io.metadata.m.orientation = params.orientation; + std::vector<uint8_t> compressed; + EXPECT_TRUE(test::EncodeFile(params.cparams, &io, &compressed)); + CodeStreamBoxFormat add_container = params.box_format; + if (add_container != kCSBF_None) { + // Header with signature box and ftyp box. + const uint8_t header[] = {0, 0, 0, 0xc, 0x4a, 0x58, 0x4c, 0x20, + 0xd, 0xa, 0x87, 0xa, 0, 0, 0, 0x14, + 0x66, 0x74, 0x79, 0x70, 0x6a, 0x78, 0x6c, 0x20, + 0, 0, 0, 0, 0x6a, 0x78, 0x6c, 0x20}; + + bool is_multi = add_container == kCSBF_Multi || + add_container == kCSBF_Multi_Zero_Terminated || + add_container == kCSBF_Multi_Other_Terminated || + add_container == kCSBF_Multi_Other_Zero_Terminated || + add_container == kCSBF_Multi_First_Empty || + add_container == kCSBF_Multi_Last_Empty_Other; + + if (is_multi) { + size_t third = compressed.size() / 3; + std::vector<uint8_t> compressed0(compressed.data(), + compressed.data() + third); + std::vector<uint8_t> compressed1(compressed.data() + third, + compressed.data() + 2 * third); + std::vector<uint8_t> compressed2(compressed.data() + 2 * third, + compressed.data() + compressed.size()); + + std::vector<uint8_t> c; + Bytes(header).AppendTo(&c); + if (params.jpeg_codestream != nullptr) { + jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false, + &c); + Bytes(jpeg_data).AppendTo(&c); + } + uint32_t jxlp_index = 0; + if (add_container == kCSBF_Multi_First_Empty) { + // Empty placeholder codestream part + AppendU32BE(12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++, &c); + } + // First codestream part + AppendU32BE(compressed0.size() + 12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++, &c); + Bytes(compressed0).AppendTo(&c); + // A few non-codestream boxes in between + AppendTestBox(unk1_box_type, unk1_box_contents, unk1_box_size, false, &c); + AppendTestBox(unk2_box_type, unk2_box_contents, unk2_box_size, false, &c); + // Empty placeholder codestream part + AppendU32BE(12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++, &c); + // Second codestream part + AppendU32BE(compressed1.size() + 12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++, &c); + Bytes(compressed1).AppendTo(&c); + // Third (last) codestream part + AppendU32BE(add_container == kCSBF_Multi_Zero_Terminated + ? 0 + : (compressed2.size() + 12), + &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + if (add_container != kCSBF_Multi_Last_Empty_Other) { + AppendU32BE(jxlp_index++ | 0x80000000, &c); + } else { + AppendU32BE(jxlp_index++, &c); + } + Bytes(compressed2).AppendTo(&c); + if (add_container == kCSBF_Multi_Last_Empty_Other) { + // Empty placeholder codestream part + AppendU32BE(12, &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('p'); + AppendU32BE(jxlp_index++ | 0x80000000, &c); + AppendTestBox(unk3_box_type, unk3_box_contents, unk3_box_size, false, + &c); + } + if (add_container == kCSBF_Multi_Other_Terminated) { + AppendTestBox(unk3_box_type, unk3_box_contents, unk3_box_size, false, + &c); + } + if (add_container == kCSBF_Multi_Other_Zero_Terminated) { + AppendTestBox(unk3_box_type, unk3_box_contents, unk3_box_size, true, + &c); + } + compressed.swap(c); + } else { + std::vector<uint8_t> c; + Bytes(header).AppendTo(&c); + if (params.jpeg_codestream != nullptr) { + jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false, + &c); + Bytes(jpeg_data).AppendTo(&c); + } + if (add_container == kCSBF_Brob_Exif) { + Bytes(box_brob_exif, box_brob_exif_size).AppendTo(&c); + } + AppendU32BE(add_container == kCSBF_Single_Zero_Terminated + ? 0 + : (compressed.size() + 8), + &c); + c.push_back('j'); + c.push_back('x'); + c.push_back('l'); + c.push_back('c'); + Bytes(compressed).AppendTo(&c); + if (add_container == kCSBF_Single_Other) { + AppendTestBox(unk1_box_type, unk1_box_contents, unk1_box_size, false, + &c); + } + compressed.swap(c); + } + } + + return compressed; +} + +JxlDecoderStatus ProcessInputIgnoreBoxes(JxlDecoder* dec) { + JxlDecoderStatus status = JXL_DEC_BOX; + while (status == JXL_DEC_BOX) { + status = JxlDecoderProcessInput(dec); + } + return status; +} + +// Decodes one-shot with the API for non-streaming decoding tests. +std::vector<uint8_t> DecodeWithAPI(JxlDecoder* dec, + Span<const uint8_t> compressed, + const JxlPixelFormat& format, + bool use_callback, bool set_buffer_early, + bool use_resizable_runner, + bool require_boxes, bool expect_success, + std::vector<uint8_t>* icc = nullptr) { + JxlThreadParallelRunnerPtr runner_fixed; + JxlResizableParallelRunnerPtr runner_resizable; + JxlParallelRunner runner_fn; + void* runner; + + if (use_resizable_runner) { + runner_resizable = JxlResizableParallelRunnerMake(nullptr); + runner = runner_resizable.get(); + runner_fn = JxlResizableParallelRunner; + } else { + size_t hw_threads = JxlThreadParallelRunnerDefaultNumWorkerThreads(); + runner_fixed = + JxlThreadParallelRunnerMake(nullptr, std::min<size_t>(hw_threads, 16)); + runner = runner_fixed.get(); + runner_fn = JxlThreadParallelRunner; + } + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, runner_fn, runner)); + + auto process_input = + require_boxes ? ProcessInputIgnoreBoxes : JxlDecoderProcessInput; + + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | (set_buffer_early ? JXL_DEC_FRAME : 0) | + JXL_DEC_PREVIEW_IMAGE | JXL_DEC_FULL_IMAGE | + (require_boxes ? JXL_DEC_BOX : 0) | + (icc != nullptr ? JXL_DEC_COLOR_ENCODING : 0))); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, process_input(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + if (use_resizable_runner) { + JxlResizableParallelRunnerSetThreads( + runner, + JxlResizableParallelRunnerSuggestThreads(info.xsize, info.ysize)); + } + + std::vector<uint8_t> pixels(buffer_size); + size_t bytes_per_pixel = format.num_channels * + test::GetDataBits(format.data_type) / + jxl::kBitsPerByte; + size_t stride = bytes_per_pixel * info.xsize; + if (format.align > 1) { + stride = jxl::DivCeil(stride, format.align) * format.align; + } + auto callback = [&](size_t x, size_t y, size_t num_pixels, + const void* pixels_row) { + memcpy(pixels.data() + stride * y + bytes_per_pixel * x, pixels_row, + num_pixels * bytes_per_pixel); + }; + + JxlDecoderStatus status = process_input(dec); + + if (status == JXL_DEC_COLOR_ENCODING) { + size_t icc_size = 0; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_DATA, + &icc_size)); + icc->resize(icc_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsICCProfile(dec, JXL_COLOR_PROFILE_TARGET_DATA, + icc->data(), icc_size)); + + status = process_input(dec); + } + + std::vector<uint8_t> preview; + if (status == JXL_DEC_NEED_PREVIEW_OUT_BUFFER) { + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + preview.resize(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreviewOutBuffer(dec, &format, preview.data(), + preview.size())); + EXPECT_EQ(JXL_DEC_PREVIEW_IMAGE, process_input(dec)); + + status = process_input(dec); + } + + if (set_buffer_early) { + EXPECT_EQ(JXL_DEC_FRAME, status); + } else { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, status); + } + + if (use_callback) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutCallback( + dec, &format, + [](void* opaque, size_t x, size_t y, size_t xsize, + const void* pixels_row) { + auto cb = static_cast<decltype(&callback)>(opaque); + (*cb)(x, y, xsize, pixels_row); + }, + /*opaque=*/&callback)); + } else { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + } + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, process_input(dec)); + + // After the full image was output, JxlDecoderProcessInput should return + // success to indicate all is done, unless we requested boxes and the last + // box was not a terminal unbounded box, in which case it should ask for + // more input. + JxlDecoderStatus expected_status = + expect_success ? JXL_DEC_SUCCESS : JXL_DEC_NEED_MORE_INPUT; + EXPECT_EQ(expected_status, process_input(dec)); + + return pixels; +} + +// Decodes one-shot with the API for non-streaming decoding tests. +std::vector<uint8_t> DecodeWithAPI(Span<const uint8_t> compressed, + const JxlPixelFormat& format, + bool use_callback, bool set_buffer_early, + bool use_resizable_runner, + bool require_boxes, bool expect_success) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + std::vector<uint8_t> pixels = + DecodeWithAPI(dec, compressed, format, use_callback, set_buffer_early, + use_resizable_runner, require_boxes, expect_success); + JxlDecoderDestroy(dec); + return pixels; +} + +} // namespace +} // namespace jxl + +//////////////////////////////////////////////////////////////////////////////// + +TEST(DecodeTest, JxlSignatureCheckTest) { + std::vector<std::pair<int, std::vector<uint8_t>>> tests = { + // No JPEGXL header starts with 'a'. + {JXL_SIG_INVALID, {'a'}}, + {JXL_SIG_INVALID, {'a', 'b', 'c', 'd', 'e', 'f'}}, + + // Empty file is not enough bytes. + {JXL_SIG_NOT_ENOUGH_BYTES, {}}, + + // JPEGXL headers. + {JXL_SIG_NOT_ENOUGH_BYTES, {0xff}}, // Part of a signature. + {JXL_SIG_INVALID, {0xff, 0xD8}}, // JPEG-1 + {JXL_SIG_CODESTREAM, {0xff, 0x0a}}, + + // JPEGXL container file. + {JXL_SIG_CONTAINER, + {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xD, 0xA, 0x87, 0xA}}, + // Ending with invalid byte. + {JXL_SIG_INVALID, {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xD, 0xA, 0x87, 0}}, + // Part of signature. + {JXL_SIG_NOT_ENOUGH_BYTES, + {0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xD, 0xA, 0x87}}, + {JXL_SIG_NOT_ENOUGH_BYTES, {0}}, + }; + for (const auto& test : tests) { + EXPECT_EQ(test.first, + JxlSignatureCheck(test.second.data(), test.second.size())) + << "Where test data is " << ::testing::PrintToString(test.second); + } +} + +TEST(DecodeTest, DefaultAllocTest) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, CustomAllocTest) { + struct CalledCounters { + int allocs = 0; + int frees = 0; + } counters; + + JxlMemoryManager mm; + mm.opaque = &counters; + mm.alloc = [](void* opaque, size_t size) { + reinterpret_cast<CalledCounters*>(opaque)->allocs++; + return malloc(size); + }; + mm.free = [](void* opaque, void* address) { + reinterpret_cast<CalledCounters*>(opaque)->frees++; + free(address); + }; + + JxlDecoder* dec = JxlDecoderCreate(&mm); + EXPECT_NE(nullptr, dec); + EXPECT_LE(1, counters.allocs); + EXPECT_EQ(0, counters.frees); + JxlDecoderDestroy(dec); + EXPECT_LE(1, counters.frees); +} + +// TODO(lode): add multi-threaded test when multithreaded pixel decoding from +// API is implemented. +TEST(DecodeTest, DefaultParallelRunnerTest) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, nullptr, nullptr)); + JxlDecoderDestroy(dec); +} + +// Creates the header of a JPEG XL file with various custom parameters for +// testing. +// xsize, ysize: image dimensions to store in the SizeHeader, max 512. +// bits_per_sample, orientation: a selection of header parameters to test with. +// orientation: image orientation to set in the metadata +// alpha_bits: if non-0, alpha extra channel bits to set in the metadata. Also +// gives the alpha channel the name "alpha_test" +// have_container: add box container format around the codestream. +// metadata_default: if true, ImageMetadata is set to default and +// bits_per_sample, orientation and alpha_bits are ignored. +// insert_box: insert an extra box before the codestream box, making the header +// farther away from the front than is ideal. Only used if have_container. +std::vector<uint8_t> GetTestHeader(size_t xsize, size_t ysize, + size_t bits_per_sample, size_t orientation, + size_t alpha_bits, bool xyb_encoded, + bool have_container, bool metadata_default, + bool insert_extra_box, + const jxl::IccBytes& icc_profile) { + jxl::BitWriter writer; + jxl::BitWriter::Allotment allotment(&writer, 65536); // Large enough + + if (have_container) { + const std::vector<uint8_t> signature_box = {0, 0, 0, 0xc, 'J', 'X', + 'L', ' ', 0xd, 0xa, 0x87, 0xa}; + const std::vector<uint8_t> filetype_box = { + 0, 0, 0, 0x14, 'f', 't', 'y', 'p', 'j', 'x', + 'l', ' ', 0, 0, 0, 0, 'j', 'x', 'l', ' '}; + const std::vector<uint8_t> extra_box_header = {0, 0, 0, 0xff, + 't', 'e', 's', 't'}; + // Beginning of codestream box, with an arbitrary size certainly large + // enough to contain the header + const std::vector<uint8_t> codestream_box_header = {0, 0, 0, 0xff, + 'j', 'x', 'l', 'c'}; + + for (size_t i = 0; i < signature_box.size(); i++) { + writer.Write(8, signature_box[i]); + } + for (size_t i = 0; i < filetype_box.size(); i++) { + writer.Write(8, filetype_box[i]); + } + if (insert_extra_box) { + for (size_t i = 0; i < extra_box_header.size(); i++) { + writer.Write(8, extra_box_header[i]); + } + for (size_t i = 0; i < 255 - 8; i++) { + writer.Write(8, 0); + } + } + for (size_t i = 0; i < codestream_box_header.size(); i++) { + writer.Write(8, codestream_box_header[i]); + } + } + + // JXL signature + writer.Write(8, 0xff); + writer.Write(8, 0x0a); + + // SizeHeader + jxl::CodecMetadata metadata; + EXPECT_TRUE(metadata.size.Set(xsize, ysize)); + EXPECT_TRUE(WriteSizeHeader(metadata.size, &writer, 0, nullptr)); + + if (!metadata_default) { + metadata.m.SetUintSamples(bits_per_sample); + metadata.m.orientation = orientation; + metadata.m.SetAlphaBits(alpha_bits); + metadata.m.xyb_encoded = xyb_encoded; + if (alpha_bits != 0) { + metadata.m.extra_channel_info[0].name = "alpha_test"; + } + } + + if (!icc_profile.empty()) { + jxl::IccBytes copy = icc_profile; + EXPECT_TRUE( + metadata.m.color_encoding.SetICC(std::move(copy), JxlGetDefaultCms())); + } + + EXPECT_TRUE(jxl::Bundle::Write(metadata.m, &writer, 0, nullptr)); + metadata.transform_data.nonserialized_xyb_encoded = metadata.m.xyb_encoded; + EXPECT_TRUE(jxl::Bundle::Write(metadata.transform_data, &writer, 0, nullptr)); + + if (!icc_profile.empty()) { + EXPECT_TRUE(metadata.m.color_encoding.WantICC()); + EXPECT_TRUE(jxl::WriteICC(icc_profile, &writer, 0, nullptr)); + } + + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + return std::vector<uint8_t>( + writer.GetSpan().data(), + writer.GetSpan().data() + writer.GetSpan().size()); +} + +TEST(DecodeTest, BasicInfoTest) { + size_t xsize[2] = {50, 33}; + size_t ysize[2] = {50, 77}; + size_t bits_per_sample[2] = {8, 23}; + size_t orientation[2] = {3, 5}; + size_t alpha_bits[2] = {0, 8}; + JXL_BOOL have_container[2] = {0, 1}; + bool xyb_encoded = false; + + std::vector<std::vector<uint8_t>> test_samples; + // Test with direct codestream + test_samples.push_back(GetTestHeader( + xsize[0], ysize[0], bits_per_sample[0], orientation[0], alpha_bits[0], + xyb_encoded, have_container[0], /*metadata_default=*/false, + /*insert_extra_box=*/false, {})); + // Test with container and different parameters + test_samples.push_back(GetTestHeader( + xsize[1], ysize[1], bits_per_sample[1], orientation[1], alpha_bits[1], + xyb_encoded, have_container[1], /*metadata_default=*/false, + /*insert_extra_box=*/false, {})); + + for (size_t i = 0; i < test_samples.size(); ++i) { + const std::vector<uint8_t>& data = test_samples[i]; + // Test decoding too small header first, until we reach the final byte. + for (size_t size = 0; size <= data.size(); ++size) { + // Test with a new decoder for each tested byte size. + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + const uint8_t* next_in = data.data(); + size_t avail_in = size; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + + JxlBasicInfo info; + bool have_basic_info = !JxlDecoderGetBasicInfo(dec, &info); + + if (size == data.size()) { + EXPECT_EQ(JXL_DEC_BASIC_INFO, status); + + // All header bytes given so the decoder must have the basic info. + EXPECT_EQ(true, have_basic_info); + EXPECT_EQ(have_container[i], info.have_container); + EXPECT_EQ(alpha_bits[i], info.alpha_bits); + // Orientations 5..8 swap the dimensions + if (orientation[i] >= 5) { + EXPECT_EQ(xsize[i], info.ysize); + EXPECT_EQ(ysize[i], info.xsize); + } else { + EXPECT_EQ(xsize[i], info.xsize); + EXPECT_EQ(ysize[i], info.ysize); + } + // The API should set the orientation to identity by default since it + // already applies the transformation internally by default. + EXPECT_EQ(1u, info.orientation); + + EXPECT_EQ(3u, info.num_color_channels); + + if (alpha_bits[i] != 0) { + // Expect an extra channel + EXPECT_EQ(1u, info.num_extra_channels); + JxlExtraChannelInfo extra; + EXPECT_EQ(0, JxlDecoderGetExtraChannelInfo(dec, 0, &extra)); + EXPECT_EQ(alpha_bits[i], extra.bits_per_sample); + EXPECT_EQ(JXL_CHANNEL_ALPHA, extra.type); + EXPECT_EQ(0, extra.alpha_premultiplied); + // Verify the name "alpha_test" given to the alpha channel + EXPECT_EQ(10u, extra.name_length); + char name[11]; + EXPECT_EQ(0, + JxlDecoderGetExtraChannelName(dec, 0, name, sizeof(name))); + EXPECT_EQ(std::string("alpha_test"), std::string(name)); + } else { + EXPECT_EQ(0u, info.num_extra_channels); + } + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + } else { + // If we did not give the full header, the basic info should not be + // available. Allow a few bytes of slack due to some bits for default + // opsinmatrix/extension bits. + if (size + 2 < data.size()) { + EXPECT_EQ(false, have_basic_info); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, status); + } + } + + // Test that decoder doesn't allow setting a setting required at beginning + // unless it's reset + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + JxlDecoderReset(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + + JxlDecoderDestroy(dec); + } + } +} + +TEST(DecodeTest, BufferSizeTest) { + size_t xsize = 33; + size_t ysize = 77; + size_t bits_per_sample = 8; + size_t orientation = 1; + size_t alpha_bits = 8; + bool have_container = false; + bool xyb_encoded = false; + + std::vector<uint8_t> header = + GetTestHeader(xsize, ysize, bits_per_sample, orientation, alpha_bits, + xyb_encoded, have_container, /*metadata_default=*/false, + /*insert_extra_box=*/false, {}); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + const uint8_t* next_in = header.data(); + size_t avail_in = header.size(); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + EXPECT_EQ(JXL_DEC_BASIC_INFO, status); + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + + JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + size_t image_out_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &image_out_size)); + EXPECT_EQ(xsize * ysize * 4, image_out_size); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, BasicInfoSizeHintTest) { + // Test on a file where the size hint is too small initially due to inserting + // a box before the codestream (something that is normally not recommended) + size_t xsize = 50; + size_t ysize = 50; + size_t bits_per_sample = 16; + size_t orientation = 1; + size_t alpha_bits = 0; + bool xyb_encoded = false; + std::vector<uint8_t> data = GetTestHeader( + xsize, ysize, bits_per_sample, orientation, alpha_bits, xyb_encoded, + /*have_container=*/true, /*metadata_default=*/false, + /*insert_extra_box=*/true, {}); + + JxlDecoderStatus status; + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + + size_t hint0 = JxlDecoderSizeHintBasicInfo(dec); + // Test that the test works as intended: we construct a file on purpose to + // be larger than the first hint by having that extra box. + EXPECT_LT(hint0, data.size()); + const uint8_t* next_in = data.data(); + // Do as if we have only as many bytes as indicated by the hint available + size_t avail_in = std::min(hint0, data.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + status = JxlDecoderProcessInput(dec); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, status); + // Basic info cannot be available yet due to the extra inserted box. + EXPECT_EQ(false, !JxlDecoderGetBasicInfo(dec, nullptr)); + + size_t num_read = avail_in - JxlDecoderReleaseInput(dec); + EXPECT_LT(num_read, data.size()); + + size_t hint1 = JxlDecoderSizeHintBasicInfo(dec); + // The hint must be larger than the previous hint (taking already processed + // bytes into account, the hint is a hint for the next avail_in) since the + // decoder now knows there is a box in between. + EXPECT_GT(hint1 + num_read, hint0); + avail_in = std::min<size_t>(hint1, data.size() - num_read); + next_in += num_read; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + status = JxlDecoderProcessInput(dec); + EXPECT_EQ(JXL_DEC_BASIC_INFO, status); + JxlBasicInfo info; + // We should have the basic info now, since we only added one box in-between, + // and the decoder should have known its size, its implementation can return + // a correct hint. + EXPECT_EQ(true, !JxlDecoderGetBasicInfo(dec, &info)); + + // Also test if the basic info is correct. + EXPECT_EQ(1, info.have_container); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + EXPECT_EQ(orientation, info.orientation); + EXPECT_EQ(bits_per_sample, info.bits_per_sample); + + JxlDecoderDestroy(dec); +} + +std::vector<uint8_t> GetIccTestHeader(const jxl::IccBytes& icc_profile, + bool xyb_encoded) { + size_t xsize = 50; + size_t ysize = 50; + size_t bits_per_sample = 16; + size_t orientation = 1; + size_t alpha_bits = 0; + return GetTestHeader(xsize, ysize, bits_per_sample, orientation, alpha_bits, + xyb_encoded, + /*have_container=*/false, /*metadata_default=*/false, + /*insert_extra_box=*/false, icc_profile); +} + +// Tests the case where pixels and metadata ICC profile are the same +TEST(DecodeTest, IccProfileTestOriginal) { + jxl::IccBytes icc_profile = GetIccTestProfile(); + bool xyb_encoded = false; + std::vector<uint8_t> data = GetIccTestHeader(icc_profile, xyb_encoded); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), data.size())); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + + // Expect the opposite of xyb_encoded for uses_original_profile + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(JXL_TRUE, info.uses_original_profile); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + // the encoded color profile expected to be not available, since the image + // has an ICC profile instead + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderGetColorAsEncodedProfile( + dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + size_t dec_profile_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + &dec_profile_size)); + + // Check that can get return status with NULL size + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + nullptr)); + + // The profiles must be equal. This requires they have equal size, and if + // they do, we can get the profile and compare the contents. + EXPECT_EQ(icc_profile.size(), dec_profile_size); + if (icc_profile.size() == dec_profile_size) { + jxl::IccBytes icc_profile2(icc_profile.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + icc_profile2.data(), icc_profile2.size())); + EXPECT_EQ(icc_profile, icc_profile2); + } + + // the data is not xyb_encoded, so same result expected for the pixel data + // color profile + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderGetColorAsEncodedProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, nullptr)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_DATA, + &dec_profile_size)); + EXPECT_EQ(icc_profile.size(), dec_profile_size); + + JxlDecoderDestroy(dec); +} + +// Tests the case where pixels and metadata ICC profile are different +TEST(DecodeTest, IccProfileTestXybEncoded) { + jxl::IccBytes icc_profile = GetIccTestProfile(); + bool xyb_encoded = true; + std::vector<uint8_t> data = GetIccTestHeader(icc_profile, xyb_encoded); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), data.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + + // Expect the opposite of xyb_encoded for uses_original_profile + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(JXL_FALSE, info.uses_original_profile); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + // the encoded color profile expected to be not available, since the image + // has an ICC profile instead + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderGetColorAsEncodedProfile( + dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, nullptr)); + + // Check that can get return status with NULL size + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + nullptr)); + + size_t dec_profile_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + &dec_profile_size)); + + // The profiles must be equal. This requires they have equal size, and if + // they do, we can get the profile and compare the contents. + EXPECT_EQ(icc_profile.size(), dec_profile_size); + if (icc_profile.size() == dec_profile_size) { + jxl::IccBytes icc_profile2(icc_profile.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + icc_profile2.data(), icc_profile2.size())); + EXPECT_EQ(icc_profile, icc_profile2); + } + + // Data is xyb_encoded, so the data profile is a different profile, encoded + // as structured profile. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsEncodedProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, nullptr)); + JxlColorEncoding pixel_encoding; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding)); + EXPECT_EQ(JXL_PRIMARIES_SRGB, pixel_encoding.primaries); + // The API returns LINEAR by default when the colorspace cannot be represented + // by enum values. + EXPECT_EQ(JXL_TRANSFER_FUNCTION_LINEAR, pixel_encoding.transfer_function); + + // Test the same but with integer format. + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding)); + EXPECT_EQ(JXL_PRIMARIES_SRGB, pixel_encoding.primaries); + EXPECT_EQ(JXL_TRANSFER_FUNCTION_LINEAR, pixel_encoding.transfer_function); + + // Test after setting the preferred color profile to non-linear sRGB: + // for XYB images with ICC profile, this setting is expected to take effect. + jxl::ColorEncoding temp_jxl_srgb = jxl::ColorEncoding::SRGB(false); + JxlColorEncoding pixel_encoding_srgb = temp_jxl_srgb.ToExternal(); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreferredColorProfile(dec, &pixel_encoding_srgb)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding)); + EXPECT_EQ(JXL_TRANSFER_FUNCTION_SRGB, pixel_encoding.transfer_function); + + // The decoder can also output this as a generated ICC profile anyway, and + // we're certain that it will differ from the above defined profile since + // the sRGB data should not have swapped R/G/B primaries. + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_DATA, + &dec_profile_size)); + // We don't need to dictate exactly what size the generated ICC profile + // must be (since there are many ways to represent the same color space), + // but it should not be zero. + EXPECT_NE(0u, dec_profile_size); + jxl::IccBytes icc_profile2(dec_profile_size); + if (0 != dec_profile_size) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile2.data(), icc_profile2.size())); + // expected not equal + EXPECT_NE(icc_profile, icc_profile2); + } + + // Test setting another different preferred profile, to verify that the + // returned JXL_COLOR_PROFILE_TARGET_DATA ICC profile is correctly + // updated. + + jxl::ColorEncoding temp_jxl_linear = jxl::ColorEncoding::LinearSRGB(false); + JxlColorEncoding pixel_encoding_linear = temp_jxl_linear.ToExternal(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreferredColorProfile(dec, &pixel_encoding_linear)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, &pixel_encoding)); + EXPECT_EQ(JXL_TRANSFER_FUNCTION_LINEAR, pixel_encoding.transfer_function); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_DATA, + &dec_profile_size)); + EXPECT_NE(0u, dec_profile_size); + jxl::IccBytes icc_profile3(dec_profile_size); + if (0 != dec_profile_size) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile3.data(), icc_profile3.size())); + // expected not equal to the previously set preferred profile. + EXPECT_NE(icc_profile2, icc_profile3); + } + + JxlDecoderDestroy(dec); +} + +// Test decoding ICC from partial files byte for byte. +// This test must pass also if JXL_CRASH_ON_ERROR is enabled, that is, the +// decoding of the ANS histogram and stream of the encoded ICC profile must also +// handle the case of not enough input bytes with StatusCode::kNotEnoughBytes +// rather than fatal error status codes. +TEST(DecodeTest, ICCPartialTest) { + jxl::IccBytes icc_profile = GetIccTestProfile(); + std::vector<uint8_t> data = GetIccTestHeader(icc_profile, false); + + const uint8_t* next_in = data.data(); + size_t avail_in = 0; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING)); + + bool seen_basic_info = false; + bool seen_color_encoding = false; + size_t total_size = 0; + + for (;;) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (total_size >= data.size()) { + // End of partial codestream with codestrema headers and ICC profile + // reached, it should not require more input since full image is not + // requested + FAIL(); + break; + } + size_t increment = 1; + if (total_size + increment > data.size()) { + increment = data.size() - total_size; + } + total_size += increment; + avail_in += increment; + } else if (status == JXL_DEC_BASIC_INFO) { + EXPECT_FALSE(seen_basic_info); + seen_basic_info = true; + } else if (status == JXL_DEC_COLOR_ENCODING) { + EXPECT_TRUE(seen_basic_info); + EXPECT_FALSE(seen_color_encoding); + seen_color_encoding = true; + + // Sanity check that the ICC profile was decoded correctly + size_t dec_profile_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize( + dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, &dec_profile_size)); + EXPECT_EQ(icc_profile.size(), dec_profile_size); + + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_TRUE(seen_color_encoding); + break; + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + EXPECT_TRUE(seen_basic_info); + EXPECT_TRUE(seen_color_encoding); + + JxlDecoderDestroy(dec); +} + +struct PixelTestConfig { + // Input image definition. + bool grayscale; + bool include_alpha; + size_t xsize; + size_t ysize; + jxl::PreviewMode preview_mode; + bool add_intrinsic_size; + // Output format. + JxlEndianness endianness; + JxlDataType data_type; + uint32_t output_channels; + // Container options. + CodeStreamBoxFormat add_container; + // Decoding mode. + bool use_callback; + bool set_buffer_early; + bool use_resizable_runner; + // Exif orientation, 1-8 + JxlOrientation orientation; + bool keep_orientation; + size_t upsampling; +}; + +class DecodeTestParam : public ::testing::TestWithParam<PixelTestConfig> {}; + +TEST_P(DecodeTestParam, PixelTest) { + PixelTestConfig config = GetParam(); + JxlDecoder* dec = JxlDecoderCreate(NULL); + + if (config.keep_orientation) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetKeepOrientation(dec, JXL_TRUE)); + } + + size_t num_pixels = config.xsize * config.ysize; + uint32_t orig_channels = + (config.grayscale ? 1 : 3) + (config.include_alpha ? 1 : 0); + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(config.xsize, config.ysize, orig_channels, 0); + JxlPixelFormat format_orig = {orig_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, + 0}; + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.cparams.resampling = config.upsampling; + params.cparams.ec_resampling = config.upsampling; + params.box_format = config.add_container; + params.orientation = config.orientation; + params.preview_mode = config.preview_mode; + params.add_intrinsic_size = config.add_intrinsic_size; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), config.xsize, config.ysize, + orig_channels, params); + + JxlPixelFormat format = {config.output_channels, config.data_type, + config.endianness, 0}; + + bool swap_xy = !config.keep_orientation && (config.orientation > 4); + size_t xsize = swap_xy ? config.ysize : config.xsize; + size_t ysize = swap_xy ? config.xsize : config.ysize; + + std::vector<uint8_t> pixels2 = + jxl::DecodeWithAPI(dec, jxl::Bytes(compressed.data(), compressed.size()), + format, config.use_callback, config.set_buffer_early, + config.use_resizable_runner, /*require_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * config.output_channels * + jxl::test::GetDataBits(config.data_type) / jxl::kBitsPerByte, + pixels2.size()); + + // If an orientation transformation is expected, to compare the pixels, also + // apply this transformation to the original pixels. ConvertToExternal is + // used to achieve this, with a temporary conversion to CodecInOut and back. + if (config.orientation > 1 && !config.keep_orientation) { + jxl::Span<const uint8_t> bytes(pixels.data(), pixels.size()); + jxl::ColorEncoding color_encoding = + jxl::ColorEncoding::SRGB(config.grayscale); + + jxl::CodecInOut io; + if (config.include_alpha) io.metadata.m.SetAlphaBits(16); + io.metadata.m.color_encoding = color_encoding; + io.SetSize(config.xsize, config.ysize); + + EXPECT_TRUE(ConvertFromExternal(bytes, config.xsize, config.ysize, + color_encoding, 16, format_orig, nullptr, + &io.Main())); + + for (size_t i = 0; i < pixels.size(); i++) pixels[i] = 0; + EXPECT_TRUE(ConvertToExternal( + io.Main(), 16, + /*float_out=*/false, orig_channels, JXL_BIG_ENDIAN, + xsize * 2 * orig_channels, nullptr, pixels.data(), pixels.size(), + /*out_callback=*/{}, + static_cast<jxl::Orientation>(config.orientation))); + } + if (config.upsampling == 1) { + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } else { + // resampling is of course not lossless, so as a rough check: + // count pixels that are more than off-by-25 in the 8-bit value of one of + // the channels + EXPECT_LE( + jxl::test::ComparePixels( + pixels.data(), pixels2.data(), xsize, ysize, format_orig, format, + 50.0 * (config.data_type == JXL_TYPE_UINT8 ? 1.0 : 256.0)), + 300u); + } + + JxlDecoderDestroy(dec); +} + +std::vector<PixelTestConfig> GeneratePixelTests() { + std::vector<PixelTestConfig> all_tests; + struct ChannelInfo { + bool grayscale; + bool include_alpha; + size_t output_channels; + }; + ChannelInfo ch_info[] = { + {false, true, 4}, // RGBA -> RGBA + {true, false, 1}, // G -> G + {true, true, 1}, // GA -> G + {true, true, 2}, // GA -> GA + {false, false, 3}, // RGB -> RGB + {false, true, 3}, // RGBA -> RGB + {false, false, 4}, // RGB -> RGBA + }; + + struct OutputFormat { + JxlEndianness endianness; + JxlDataType data_type; + }; + OutputFormat out_formats[] = { + {JXL_NATIVE_ENDIAN, JXL_TYPE_UINT8}, + {JXL_LITTLE_ENDIAN, JXL_TYPE_UINT16}, + {JXL_BIG_ENDIAN, JXL_TYPE_UINT16}, + {JXL_NATIVE_ENDIAN, JXL_TYPE_FLOAT16}, + {JXL_LITTLE_ENDIAN, JXL_TYPE_FLOAT}, + {JXL_BIG_ENDIAN, JXL_TYPE_FLOAT}, + }; + + auto make_test = [&](ChannelInfo ch, size_t xsize, size_t ysize, + jxl::PreviewMode preview_mode, bool intrinsic_size, + CodeStreamBoxFormat box, JxlOrientation orientation, + bool keep_orientation, OutputFormat format, + bool use_callback, bool set_buffer_early, + bool resizable_runner, size_t upsampling) { + PixelTestConfig c; + c.grayscale = ch.grayscale; + c.include_alpha = ch.include_alpha; + c.preview_mode = preview_mode; + c.add_intrinsic_size = intrinsic_size; + c.xsize = xsize; + c.ysize = ysize; + c.add_container = (CodeStreamBoxFormat)box; + c.output_channels = ch.output_channels; + c.data_type = format.data_type; + c.endianness = format.endianness; + c.use_callback = use_callback; + c.set_buffer_early = set_buffer_early; + c.use_resizable_runner = resizable_runner; + c.orientation = orientation; + c.keep_orientation = keep_orientation; + c.upsampling = upsampling; + all_tests.push_back(c); + }; + + // Test output formats and methods. + for (ChannelInfo ch : ch_info) { + for (int use_callback = 0; use_callback <= 1; use_callback++) { + for (size_t upsampling : {1, 2, 4, 8}) { + for (OutputFormat fmt : out_formats) { + make_test(ch, 301, 33, jxl::kNoPreview, + /*add_intrinsic_size=*/false, + CodeStreamBoxFormat::kCSBF_None, JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, fmt, use_callback, + /*set_buffer_early=*/false, /*resizable_runner=*/false, + upsampling); + } + } + } + } + // Test codestream formats. + for (size_t box = 1; box < kCSBF_NUM_ENTRIES; ++box) { + make_test(ch_info[0], 77, 33, jxl::kNoPreview, + /*add_intrinsic_size=*/false, (CodeStreamBoxFormat)box, + JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, + /*set_buffer_early=*/false, /*resizable_runner=*/false, 1); + } + // Test previews. + for (int preview_mode = 0; preview_mode < jxl::kNumPreviewModes; + preview_mode++) { + make_test(ch_info[0], 77, 33, (jxl::PreviewMode)preview_mode, + /*add_intrinsic_size=*/false, CodeStreamBoxFormat::kCSBF_None, + JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/false, + /*resizable_runner=*/false, 1); + } + // Test intrinsic sizes. + for (int add_intrinsic_size = 0; add_intrinsic_size <= 1; + add_intrinsic_size++) { + make_test(ch_info[0], 55, 34, jxl::kNoPreview, add_intrinsic_size, + CodeStreamBoxFormat::kCSBF_None, JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/false, + /*resizable_runner=*/false, 1); + } + // Test setting buffers early. + make_test(ch_info[0], 300, 33, jxl::kNoPreview, + /*add_intrinsic_size=*/false, CodeStreamBoxFormat::kCSBF_None, + JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/true, + /*resizable_runner=*/false, 1); + + // Test using the resizable runner + for (size_t i = 0; i < 4; i++) { + make_test(ch_info[0], 300 << i, 33 << i, jxl::kNoPreview, + /*add_intrinsic_size=*/false, CodeStreamBoxFormat::kCSBF_None, + JXL_ORIENT_IDENTITY, + /*keep_orientation=*/false, out_formats[0], + /*use_callback=*/false, /*set_buffer_early=*/false, + /*resizable_runner=*/true, 1); + } + + // Test orientations. + for (int orientation = 2; orientation <= 8; ++orientation) { + for (int keep_orientation = 0; keep_orientation <= 1; keep_orientation++) { + for (int use_callback = 0; use_callback <= 1; use_callback++) { + for (ChannelInfo ch : ch_info) { + for (OutputFormat fmt : out_formats) { + make_test(ch, 280, 12, jxl::kNoPreview, + /*add_intrinsic_size=*/false, + CodeStreamBoxFormat::kCSBF_None, + static_cast<JxlOrientation>(orientation), + /*keep_orientation=*/keep_orientation, fmt, + /*use_callback=*/use_callback, /*set_buffer_early=*/true, + /*resizable_runner=*/false, 1); + } + } + } + } + } + + return all_tests; +} + +std::ostream& operator<<(std::ostream& os, const PixelTestConfig& c) { + os << c.xsize << "x" << c.ysize; + const char* colors[] = {"", "G", "GA", "RGB", "RGBA"}; + os << colors[(c.grayscale ? 1 : 3) + (c.include_alpha ? 1 : 0)]; + os << "to"; + os << colors[c.output_channels]; + switch (c.data_type) { + case JXL_TYPE_UINT8: + os << "u8"; + break; + case JXL_TYPE_UINT16: + os << "u16"; + break; + case JXL_TYPE_FLOAT: + os << "f32"; + break; + case JXL_TYPE_FLOAT16: + os << "f16"; + break; + default: + JXL_ASSERT(false); + }; + if (jxl::test::GetDataBits(c.data_type) > jxl::kBitsPerByte) { + if (c.endianness == JXL_NATIVE_ENDIAN) { + // add nothing + } else if (c.endianness == JXL_BIG_ENDIAN) { + os << "BE"; + } else if (c.endianness == JXL_LITTLE_ENDIAN) { + os << "LE"; + } + } + if (c.add_container != CodeStreamBoxFormat::kCSBF_None) { + os << "Box"; + os << (size_t)c.add_container; + } + if (c.preview_mode == jxl::kSmallPreview) os << "Preview"; + if (c.preview_mode == jxl::kBigPreview) os << "BigPreview"; + if (c.add_intrinsic_size) os << "IntrinicSize"; + if (c.use_callback) os << "Callback"; + if (c.set_buffer_early) os << "EarlyBuffer"; + if (c.use_resizable_runner) os << "ResizableRunner"; + if (c.orientation != 1) os << "O" << c.orientation; + if (c.keep_orientation) os << "Keep"; + if (c.upsampling > 1) os << "x" << c.upsampling; + return os; +} + +std::string PixelTestDescription( + const testing::TestParamInfo<DecodeTestParam::ParamType>& info) { + std::stringstream name; + name << info.param; + return name.str(); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(DecodeTest, DecodeTestParam, + testing::ValuesIn(GeneratePixelTests()), + PixelTestDescription); + +TEST(DecodeTest, PixelTestWithICCProfileLossless) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.add_icc_profile = true; + // For variation: some have container and no preview, others have preview + // and no container. + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 4, params); + + for (uint32_t channels = 3; channels <= 4; ++channels) { + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + dec, jxl::Bytes(compressed.data(), compressed.size()), format, + /*use_callback=*/false, /*set_buffer_early=*/false, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels, pixels2.size()); + EXPECT_EQ(0u, + jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } + { + JxlPixelFormat format = {channels, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}; + + // Test with the container for one of the pixel formats. + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + dec, jxl::Bytes(compressed.data(), compressed.size()), format, + /*use_callback=*/true, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 2, pixels2.size()); + EXPECT_EQ(0u, + jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } + + { + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + dec, jxl::Bytes(compressed.data(), compressed.size()), format, + /*use_callback=*/false, /*set_buffer_early=*/false, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + EXPECT_EQ(0u, + jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } + } + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, PixelTestWithICCProfileLossy) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::TestCodestreamParams params; + params.add_icc_profile = true; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, params); + uint32_t channels = 3; + + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + + std::vector<uint8_t> icc_data; + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + dec, jxl::Bytes(compressed.data(), compressed.size()), format, + /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true, /*icc=*/&icc_data); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + + // The input pixels use the profile matching GetIccTestProfile, since we set + // add_icc_profile for CreateTestJXLCodestream to true. + jxl::ColorEncoding color_encoding0; + EXPECT_TRUE(color_encoding0.SetICC(GetIccTestProfile(), JxlGetDefaultCms())); + jxl::Span<const uint8_t> span0(pixels.data(), pixels.size()); + jxl::CodecInOut io0; + io0.SetSize(xsize, ysize); + EXPECT_TRUE(ConvertFromExternal(span0, xsize, ysize, color_encoding0, + /*bits_per_sample=*/16, format_orig, + /*pool=*/nullptr, &io0.Main())); + + jxl::ColorEncoding color_encoding1; + jxl::IccBytes icc; + jxl::Bytes(icc_data).AppendTo(&icc); + EXPECT_TRUE(color_encoding1.SetICC(std::move(icc), JxlGetDefaultCms())); + jxl::Span<const uint8_t> span1(pixels2.data(), pixels2.size()); + jxl::CodecInOut io1; + io1.SetSize(xsize, ysize); + EXPECT_TRUE(ConvertFromExternal(span1, xsize, ysize, color_encoding1, + /*bits_per_sample=*/32, format, + /*pool=*/nullptr, &io1.Main())); + + jxl::ButteraugliParams ba; + EXPECT_THAT( + ButteraugliDistance(io0.frames, io1.frames, ba, *JxlGetDefaultCms(), + /*distmap=*/nullptr, nullptr), + IsSlightlyBelow(0.56f)); + + JxlDecoderDestroy(dec); +} + +std::string ColorDescription(JxlColorEncoding c) { + jxl::ColorEncoding color_encoding; + EXPECT_TRUE(color_encoding.FromExternal(c)); + return Description(color_encoding); +} + +std::string GetOrigProfile(JxlDecoder* dec) { + JxlColorEncoding c; + JxlColorProfileTarget target = JXL_COLOR_PROFILE_TARGET_ORIGINAL; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile(dec, target, &c)); + return ColorDescription(c); +} + +std::string GetDataProfile(JxlDecoder* dec) { + JxlColorEncoding c; + JxlColorProfileTarget target = JXL_COLOR_PROFILE_TARGET_DATA; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetColorAsEncodedProfile(dec, target, &c)); + return ColorDescription(c); +} + +double ButteraugliDistance(size_t xsize, size_t ysize, + const std::vector<uint8_t>& pixels_in, + const jxl::ColorEncoding& color_in, + float intensity_in, + const std::vector<uint8_t>& pixels_out, + const jxl::ColorEncoding& color_out, + float intensity_out) { + jxl::CodecInOut in; + in.metadata.m.color_encoding = color_in; + in.metadata.m.SetIntensityTarget(intensity_in); + JxlPixelFormat format_in = {static_cast<uint32_t>(color_in.Channels()), + JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + EXPECT_TRUE(jxl::ConvertFromExternal( + jxl::Bytes(pixels_in.data(), pixels_in.size()), xsize, ysize, color_in, + /*bits_per_sample=*/16, format_in, + /*pool=*/nullptr, &in.Main())); + jxl::CodecInOut out; + out.metadata.m.color_encoding = color_out; + out.metadata.m.SetIntensityTarget(intensity_out); + JxlPixelFormat format_out = {static_cast<uint32_t>(color_out.Channels()), + JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + EXPECT_TRUE(jxl::ConvertFromExternal( + jxl::Bytes(pixels_out.data(), pixels_out.size()), xsize, ysize, color_out, + /*bits_per_sample=*/16, format_out, + /*pool=*/nullptr, &out.Main())); + return ButteraugliDistance(in.frames, out.frames, jxl::ButteraugliParams(), + *JxlGetDefaultCms(), nullptr, nullptr); +} + +class DecodeAllEncodingsTest + : public ::testing::TestWithParam<jxl::test::ColorEncodingDescriptor> {}; +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + DecodeAllEncodingsTestInstantiation, DecodeAllEncodingsTest, + ::testing::ValuesIn(jxl::test::AllEncodings())); +TEST_P(DecodeAllEncodingsTest, PreserveOriginalProfileTest) { + size_t xsize = 123, ysize = 77; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + int events = JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | JXL_DEC_FULL_IMAGE; + const auto& cdesc = GetParam(); + jxl::ColorEncoding c_in = jxl::test::ColorEncodingFromDescriptor(cdesc); + if (c_in.GetRenderingIntent() != jxl::RenderingIntent::kRelative) return; + std::string color_space_in = Description(c_in); + float intensity_in = c_in.Tf().IsPQ() ? 10000 : 255; + printf("Testing input color space %s\n", color_space_in.c_str()); + jxl::TestCodestreamParams params; + params.color_space = color_space_in; + params.intensity_target = intensity_in; + std::vector<uint8_t> data = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, params); + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), data.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + EXPECT_FALSE(info.uses_original_profile); + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + EXPECT_EQ(GetOrigProfile(dec), color_space_in); + EXPECT_EQ(GetDataProfile(dec), color_space_in); + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + std::vector<uint8_t> out(pixels.size()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, out.data(), out.size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + double dist = ButteraugliDistance(xsize, ysize, pixels, c_in, intensity_in, + out, c_in, intensity_in); + EXPECT_LT(dist, 1.29); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); +} + +namespace { +void SetPreferredColorProfileTest( + const jxl::test::ColorEncodingDescriptor& from, bool icc_dst, + bool use_cms) { + size_t xsize = 123, ysize = 77; + int events = JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | JXL_DEC_FULL_IMAGE; + jxl::ColorEncoding c_in = jxl::test::ColorEncodingFromDescriptor(from); + if (c_in.GetRenderingIntent() != jxl::RenderingIntent::kRelative) return; + if (c_in.GetWhitePointType() != jxl::WhitePoint::kD65) return; + uint32_t num_channels = c_in.Channels(); + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::string color_space_in = Description(c_in); + float intensity_in = c_in.Tf().IsPQ() ? 10000 : 255; + jxl::TestCodestreamParams params; + params.color_space = color_space_in; + params.intensity_target = intensity_in; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + auto all_encodings = jxl::test::AllEncodings(); + // TODO(firsching): understand why XYB does not work together with icc_dst. + if (!icc_dst) { + all_encodings.push_back( + {jxl::ColorSpace::kXYB, jxl::WhitePoint::kD65, jxl::Primaries::kCustom, + jxl::TransferFunction::kUnknown, jxl::RenderingIntent::kPerceptual}); + } + for (const auto& c1 : all_encodings) { + jxl::ColorEncoding c_out = jxl::test::ColorEncodingFromDescriptor(c1); + float intensity_out = intensity_in; + if (c_out.GetColorSpace() != jxl::ColorSpace::kXYB) { + if (c_out.GetRenderingIntent() != jxl::RenderingIntent::kRelative) { + continue; + } + if ((c_in.GetPrimariesType() == jxl::Primaries::k2100 && + c_out.GetPrimariesType() != jxl::Primaries::k2100) || + (c_in.GetPrimariesType() == jxl::Primaries::kP3 && + c_out.GetPrimariesType() == jxl::Primaries::kSRGB)) { + // Converting to a narrower gamut does not work without gamut mapping. + continue; + } + } + if (c_out.Tf().IsHLG() && intensity_out > 300) { + // The Linear->HLG OOTF function at this intensity level can push + // saturated colors out of gamut, so we would need gamut mapping in + // this case too. + continue; + } + std::string color_space_out = Description(c_out); + if (color_space_in == color_space_out) continue; + printf("Testing input color space %s with output color space %s\n", + color_space_in.c_str(), color_space_out.c_str()); + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, data.data(), data.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + EXPECT_FALSE(info.uses_original_profile); + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + EXPECT_EQ(GetOrigProfile(dec), color_space_in); + JxlColorEncoding encoding_out; + EXPECT_TRUE(jxl::ParseDescription(color_space_out, &encoding_out)); + if (c_out.GetColorSpace() == jxl::ColorSpace::kXYB && + (c_in.GetPrimariesType() != jxl::Primaries::kSRGB || + c_in.Tf().IsPQ())) { + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderSetPreferredColorProfile(dec, &encoding_out)); + JxlDecoderDestroy(dec); + continue; + } + if (use_cms) { + JxlDecoderSetCms(dec, *JxlGetDefaultCms()); + } + if (icc_dst) { + jxl::ColorEncoding internal_encoding_out; + EXPECT_TRUE(internal_encoding_out.FromExternal(encoding_out)); + EXPECT_TRUE(internal_encoding_out.CreateICC()); + std::vector<uint8_t> rewritten_icc = internal_encoding_out.ICC(); + + EXPECT_EQ(use_cms ? JXL_DEC_SUCCESS : JXL_DEC_ERROR, + JxlDecoderSetOutputColorProfile( + dec, nullptr, rewritten_icc.data(), rewritten_icc.size())); + if (!use_cms) { + // continue if we don't have a cms here + JxlDecoderDestroy(dec); + continue; + } + } else { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreferredColorProfile(dec, &encoding_out)); + } + EXPECT_EQ(GetOrigProfile(dec), color_space_in); + if (icc_dst) { + } else { + EXPECT_EQ(GetDataProfile(dec), color_space_out); + } + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + size_t buffer_size; + JxlPixelFormat out_format = format; + out_format.num_channels = c_out.Channels(); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &out_format, &buffer_size)); + std::vector<uint8_t> out(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &out_format, out.data(), out.size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + double dist = ButteraugliDistance(xsize, ysize, pixels, c_in, intensity_in, + out, c_out, intensity_out); + + if (c_in.GetWhitePointType() == c_out.GetWhitePointType()) { + EXPECT_LT(dist, 1.29); + } else { + EXPECT_LT(dist, 4.0); + } + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); + } +} +} // namespace + +TEST(DecodeTest, SetPreferredColorProfileTestFromGray) { + jxl::test::ColorEncodingDescriptor gray = { + jxl::ColorSpace::kGray, jxl::WhitePoint::kD65, jxl::Primaries::kSRGB, + jxl::TransferFunction::kSRGB, jxl::RenderingIntent::kRelative}; + SetPreferredColorProfileTest(gray, true, true); + SetPreferredColorProfileTest(gray, false, true); + SetPreferredColorProfileTest(gray, true, false); + SetPreferredColorProfileTest(gray, false, false); +} + +static std::string DecodeAllEncodingsVariantsTestName( + const ::testing::TestParamInfo< + std::tuple<jxl::test::ColorEncodingDescriptor, bool, bool>>& info) { + const auto& encoding = std::get<0>(info.param); + bool icc_dst = std::get<1>(info.param); + bool use_cms = std::get<2>(info.param); + + std::string encoding_name = + Description(ColorEncodingFromDescriptor(encoding)); + + return "From_" + encoding_name + + (icc_dst ? "_with_icc_dst" : "_without_icc_dst") + + (use_cms ? "_with_cms" : "_without_cms"); +} + +class DecodeAllEncodingsVariantsTest + : public ::testing::TestWithParam< + std::tuple<jxl::test::ColorEncodingDescriptor, bool, bool>> {}; +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + DecodeAllEncodingsVariantsTestInstantiation, DecodeAllEncodingsVariantsTest, + ::testing::Combine(::testing::ValuesIn(jxl::test::AllEncodings()), + ::testing::Bool(), ::testing::Bool()), + DecodeAllEncodingsVariantsTestName); +TEST_P(DecodeAllEncodingsVariantsTest, SetPreferredColorProfileTest) { + const auto& from = std::get<0>(GetParam()); + bool icc_dst = std::get<1>(GetParam()); + bool use_cms = std::get<2>(GetParam()); + SetPreferredColorProfileTest(from, icc_dst, use_cms); +} + +void DecodeImageWithColorEncoding(const std::vector<uint8_t>& compressed, + jxl::ColorEncoding& color_encoding, + bool with_cms, std::vector<uint8_t>& out, + JxlBasicInfo& info) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + int events = JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | JXL_DEC_FULL_IMAGE; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + std::string color_space_in = GetOrigProfile(dec); + if (with_cms) { + JxlDecoderSetCms(dec, *JxlGetDefaultCms()); + EXPECT_TRUE(color_encoding.CreateICC()); + std::vector<uint8_t> rewritten_icc = color_encoding.ICC(); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetOutputColorProfile( + dec, nullptr, rewritten_icc.data(), rewritten_icc.size())); + } else { + JxlColorEncoding external_color_encoding = color_encoding.ToExternal(); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetOutputColorProfile( + dec, &external_color_encoding, nullptr, 0)); + } + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + size_t buffer_size; + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + JxlPixelFormat out_format = format; + out_format.num_channels = color_encoding.Channels(); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &out_format, &buffer_size)); + out.resize(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &out_format, out.data(), out.size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); +} + +class DecodeAllEncodingsWithCMSTest + : public ::testing::TestWithParam<jxl::test::ColorEncodingDescriptor> {}; + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + AllEncodings, DecodeAllEncodingsWithCMSTest, + testing::ValuesIn(jxl::test::AllEncodings())); + +TEST_P(DecodeAllEncodingsWithCMSTest, DecodeWithCMS) { + auto all_encodings = jxl::test::AllEncodings(); + uint32_t num_channels = 3; + size_t xsize = 177, ysize = 123; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + + jxl::ColorEncoding color_encoding = + jxl::test::ColorEncodingFromDescriptor(GetParam()); + fprintf(stderr, "color_description: %s\n", + Description(color_encoding).c_str()); + + std::vector<uint8_t> out_with_cms; + JxlBasicInfo info_with_cms; + DecodeImageWithColorEncoding(data, color_encoding, true, out_with_cms, + info_with_cms); + + std::vector<uint8_t> out_without_cms; + JxlBasicInfo info_without_cms; + DecodeImageWithColorEncoding(data, color_encoding, false, out_without_cms, + info_without_cms); + + EXPECT_EQ(info_with_cms.xsize, info_without_cms.xsize); + EXPECT_EQ(info_with_cms.ysize, info_without_cms.ysize); + EXPECT_EQ(out_with_cms.size(), out_without_cms.size()); + double dist = ButteraugliDistance(xsize, ysize, out_with_cms, color_encoding, + 255, out_without_cms, color_encoding, 255); + + EXPECT_LT(dist, .1); +} + +// Tests the case of lossy sRGB image without alpha channel, decoded to RGB8 +// and to RGBA8 +TEST(DecodeTest, PixelTestOpaqueSrgbLossy) { + for (unsigned channels = 3; channels <= 4; channels++) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, + jxl::TestCodestreamParams()); + + JxlPixelFormat format = {channels, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + dec, jxl::Bytes(compressed.data(), compressed.size()), format, + /*use_callback=*/true, /*set_buffer_early=*/false, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success*/ true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels, pixels2.size()); + + jxl::ColorEncoding color_encoding0 = jxl::ColorEncoding::SRGB(false); + jxl::Span<const uint8_t> span0(pixels.data(), pixels.size()); + jxl::CodecInOut io0; + io0.SetSize(xsize, ysize); + EXPECT_TRUE(ConvertFromExternal(span0, xsize, ysize, color_encoding0, + /*bits_per_sample=*/16, format_orig, + /*pool=*/nullptr, &io0.Main())); + + jxl::ColorEncoding color_encoding1 = jxl::ColorEncoding::SRGB(false); + jxl::Span<const uint8_t> span1(pixels2.data(), pixels2.size()); + jxl::CodecInOut io1; + EXPECT_TRUE(ConvertFromExternal(span1, xsize, ysize, color_encoding1, + /*bits_per_sample=*/8, format, + /*pool=*/nullptr, &io1.Main())); + + jxl::ButteraugliParams ba; + EXPECT_THAT( + ButteraugliDistance(io0.frames, io1.frames, ba, *JxlGetDefaultCms(), + /*distmap=*/nullptr, nullptr), + IsSlightlyBelow(0.65f)); + + JxlDecoderDestroy(dec); + } +} + +// Opaque image with noise enabled, decoded to RGB8 and RGBA8. +TEST(DecodeTest, PixelTestOpaqueSrgbLossyNoise) { + for (unsigned channels = 3; channels <= 4; channels++) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + + size_t xsize = 512, ysize = 300; + size_t num_pixels = xsize * ysize; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::TestCodestreamParams params; + params.cparams.noise = jxl::Override::kOn; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, params); + + JxlPixelFormat format = {channels, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + dec, jxl::Bytes(compressed.data(), compressed.size()), format, + /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + JxlDecoderReset(dec); + EXPECT_EQ(num_pixels * channels, pixels2.size()); + + jxl::ColorEncoding color_encoding0 = jxl::ColorEncoding::SRGB(false); + jxl::Span<const uint8_t> span0(pixels.data(), pixels.size()); + jxl::CodecInOut io0; + io0.SetSize(xsize, ysize); + EXPECT_TRUE(ConvertFromExternal(span0, xsize, ysize, color_encoding0, + /*bits_per_sample=*/16, format_orig, + /*pool=*/nullptr, &io0.Main())); + + jxl::ColorEncoding color_encoding1 = jxl::ColorEncoding::SRGB(false); + jxl::Span<const uint8_t> span1(pixels2.data(), pixels2.size()); + jxl::CodecInOut io1; + EXPECT_TRUE(ConvertFromExternal(span1, xsize, ysize, color_encoding1, + /*bits_per_sample=*/8, format, + /*pool=*/nullptr, &io1.Main())); + + jxl::ButteraugliParams ba; + EXPECT_THAT( + ButteraugliDistance(io0.frames, io1.frames, ba, *JxlGetDefaultCms(), + /*distmap=*/nullptr, nullptr), + IsSlightlyBelow(1.3f)); + + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, ProcessEmptyInputWithBoxes) { + size_t xsize = 123, ysize = 77; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + jxl::CompressParams cparams; + uint32_t channels = 3; + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + JxlDecoder* dec = JxlDecoderCreate(NULL); + jxl::TestCodestreamParams params; + params.box_format = (CodeStreamBoxFormat)i; + printf("Testing empty input with box format %d\n", (int)params.box_format); + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, params); + const int events = + JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE | JXL_DEC_COLOR_ENCODING; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events)); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + const size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, compressed.size()); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, ExtraBytesAfterCompressedStream) { + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + jxl::CompressParams cparams; + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + CodeStreamBoxFormat box_format = (CodeStreamBoxFormat)i; + if (box_format == kCSBF_Multi_Other_Zero_Terminated) continue; + printf("Testing with box format %d\n", (int)box_format); + size_t last_unknown_box_size = 0; + if (box_format == kCSBF_Single_Other) { + last_unknown_box_size = unk1_box_size + 8; + } else if (box_format == kCSBF_Multi_Other_Terminated) { + last_unknown_box_size = unk3_box_size + 8; + } else if (box_format == kCSBF_Multi_Last_Empty_Other) { + // If boxes are not required, the decoder won't consume the last empty + // jxlp box. + last_unknown_box_size = 12 + unk3_box_size + 8; + } + jxl::TestCodestreamParams params; + params.box_format = box_format; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, params); + // Add some more bytes after compressed data. + compressed.push_back(0); + compressed.push_back(1); + compressed.push_back(2); + JxlDecoder* dec = JxlDecoderCreate(NULL); + uint32_t channels = 3; + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + dec, jxl::Bytes(compressed.data(), compressed.size()), format, + /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + size_t unconsumed_bytes = JxlDecoderReleaseInput(dec); + EXPECT_EQ(last_unknown_box_size + 3, unconsumed_bytes); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, ExtraBytesAfterCompressedStreamRequireBoxes) { + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + jxl::CompressParams cparams; + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + CodeStreamBoxFormat box_format = (CodeStreamBoxFormat)i; + if (box_format == kCSBF_Multi_Other_Zero_Terminated) continue; + printf("Testing with box format %d\n", (int)box_format); + bool expect_success = (box_format == kCSBF_None || + box_format == kCSBF_Single_Zero_Terminated || + box_format == kCSBF_Multi_Zero_Terminated); + jxl::TestCodestreamParams params; + params.box_format = box_format; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, params); + // Add some more bytes after compressed data. + compressed.push_back(0); + compressed.push_back(1); + compressed.push_back(2); + JxlDecoder* dec = JxlDecoderCreate(NULL); + uint32_t channels = 3; + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + dec, jxl::Bytes(compressed.data(), compressed.size()), format, + /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/true, expect_success); + size_t unconsumed_bytes = JxlDecoderReleaseInput(dec); + EXPECT_EQ(3, unconsumed_bytes); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, ConcatenatedCompressedStreams) { + size_t xsize = 123, ysize = 77; + size_t num_pixels = xsize * ysize; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + jxl::CompressParams cparams; + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + CodeStreamBoxFormat first_box_format = (CodeStreamBoxFormat)i; + if (first_box_format == kCSBF_Multi_Other_Zero_Terminated) continue; + jxl::TestCodestreamParams params1; + params1.box_format = first_box_format; + std::vector<uint8_t> compressed1 = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, params1); + for (int j = 0; j < kCSBF_NUM_ENTRIES; ++j) { + CodeStreamBoxFormat second_box_format = (CodeStreamBoxFormat)j; + if (second_box_format == kCSBF_Multi_Other_Zero_Terminated) continue; + printf("Testing with box format pair %d, %d\n", (int)first_box_format, + (int)second_box_format); + jxl::TestCodestreamParams params2; + params2.box_format = second_box_format; + std::vector<uint8_t> compressed2 = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, params2); + std::vector<uint8_t> concat; + jxl::Bytes(compressed1).AppendTo(&concat); + jxl::Bytes(compressed2).AppendTo(&concat); + uint32_t channels = 3; + JxlPixelFormat format = {channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, 0}; + size_t remaining = concat.size(); + for (int part = 0; part < 2; ++part) { + printf(" Decoding part %d\n", part + 1); + JxlDecoder* dec = JxlDecoderCreate(NULL); + size_t pos = concat.size() - remaining; + bool expect_success = + (part == 0 || second_box_format == kCSBF_None || + second_box_format == kCSBF_Single_Zero_Terminated || + second_box_format == kCSBF_Multi_Zero_Terminated); + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + dec, jxl::Bytes(concat.data() + pos, remaining), format, + /*use_callback=*/false, /*set_buffer_early=*/true, + /*use_resizable_runner=*/false, /*require_boxes=*/true, + expect_success); + EXPECT_EQ(num_pixels * channels * 4, pixels2.size()); + remaining = JxlDecoderReleaseInput(dec); + JxlDecoderDestroy(dec); + } + EXPECT_EQ(0, remaining); + } + } +} + +void TestPartialStream(bool reconstructible_jpeg) { + size_t xsize = 123, ysize = 77; + uint32_t channels = 4; + if (reconstructible_jpeg) { + channels = 3; + } + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, channels, 0); + JxlPixelFormat format_orig = {channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::TestCodestreamParams params; + if (reconstructible_jpeg) { + params.cparams.color_transform = jxl::ColorTransform::kNone; + } else { + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + } + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + std::vector<uint8_t> jpeg_output(64); + size_t used_jpeg_output = 0; + + std::vector<std::vector<uint8_t>> codestreams(kCSBF_NUM_ENTRIES); + std::vector<std::vector<uint8_t>> jpeg_codestreams(kCSBF_NUM_ENTRIES); + for (size_t i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + params.box_format = (CodeStreamBoxFormat)i; + if (reconstructible_jpeg) { + params.jpeg_codestream = &jpeg_codestreams[i]; + } + codestreams[i] = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, channels, params); + } + + // Test multiple step sizes, to test different combinations of the streaming + // box parsing. + std::vector<size_t> increments = {1, 3, 17, 23, 120, 700, 1050}; + + for (size_t index = 0; index < increments.size(); index++) { + for (size_t i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + if (reconstructible_jpeg && + (CodeStreamBoxFormat)i == CodeStreamBoxFormat::kCSBF_None) { + continue; + } + const std::vector<uint8_t>& data = codestreams[i]; + const uint8_t* next_in = data.data(); + size_t avail_in = 0; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE | + JXL_DEC_JPEG_RECONSTRUCTION)); + + bool seen_basic_info = false; + bool seen_full_image = false; + bool seen_jpeg_recon = false; + + size_t total_size = 0; + + for (;;) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (total_size >= data.size()) { + // End of test data reached, it should have successfully decoded the + // image now. + FAIL(); + break; + } + + size_t increment = increments[index]; + // End of the file reached, should be the final test. + if (total_size + increment > data.size()) { + increment = data.size() - total_size; + } + total_size += increment; + avail_in += increment; + } else if (status == JXL_DEC_BASIC_INFO) { + // This event should happen exactly once + EXPECT_FALSE(seen_basic_info); + if (seen_basic_info) break; + seen_basic_info = true; + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + } else if (status == JXL_DEC_JPEG_RECONSTRUCTION) { + EXPECT_FALSE(seen_basic_info); + EXPECT_FALSE(seen_full_image); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec, jpeg_output.data(), + jpeg_output.size())); + seen_jpeg_recon = true; + } else if (status == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { + EXPECT_TRUE(seen_jpeg_recon); + used_jpeg_output = + jpeg_output.size() - JxlDecoderReleaseJPEGBuffer(dec); + jpeg_output.resize(jpeg_output.size() * 2); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer( + dec, jpeg_output.data() + used_jpeg_output, + jpeg_output.size() - used_jpeg_output)); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer( + dec, &format_orig, pixels2.data(), pixels2.size())); + } else if (status == JXL_DEC_FULL_IMAGE) { + // This event should happen exactly once + EXPECT_FALSE(seen_full_image); + if (seen_full_image) break; + // This event should happen after basic info + EXPECT_TRUE(seen_basic_info); + seen_full_image = true; + if (reconstructible_jpeg) { + used_jpeg_output = + jpeg_output.size() - JxlDecoderReleaseJPEGBuffer(dec); + EXPECT_EQ(used_jpeg_output, jpeg_codestreams[i].size()); + EXPECT_EQ(0, memcmp(jpeg_output.data(), jpeg_codestreams[i].data(), + used_jpeg_output)); + } else { + EXPECT_EQ(pixels, pixels2); + } + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_TRUE(seen_full_image); + break; + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + // Ensure the decoder emitted the basic info and full image events + EXPECT_TRUE(seen_basic_info); + EXPECT_TRUE(seen_full_image); + + JxlDecoderDestroy(dec); + } + } +} + +// Tests the return status when trying to decode pixels on incomplete file: it +// should return JXL_DEC_NEED_MORE_INPUT, not error. +TEST(DecodeTest, PixelPartialTest) { TestPartialStream(false); } + +// Tests the return status when trying to decode JPEG bytes on incomplete file. +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGPartialTest)) { + TEST_LIBJPEG_SUPPORT(); + TestPartialStream(true); +} + +// The DC event still exists, but is no longer implemented, it is deprecated. +TEST(DecodeTest, DCNotGettableTest) { + // 1x1 pixel JXL image + std::string compressed( + "\377\n\0\20\260\23\0H\200(" + "\0\334\0U\17\0\0\250P\31e\334\340\345\\\317\227\37:," + "\246m\\gh\253m\vK\22E\306\261I\252C&pH\22\353 " + "\363\6\22\bp\0\200\237\34\231W2d\255$\1", + 68); + + JxlDecoder* dec = JxlDecoderCreate(NULL); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput( + dec, reinterpret_cast<const uint8_t*>(compressed.data()), + compressed.size())); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + + // Since the image is only 1x1 pixel, there is only 1 group, the decoder is + // unable to get DC size from this, and will not return the DC at all. Since + // no full image is requested either, it is expected to return success. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, PreviewTest) { + size_t xsize = 77, ysize = 120; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + JxlPixelFormat format_orig = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + for (jxl::PreviewMode mode : {jxl::kSmallPreview, jxl::kBigPreview}) { + jxl::TestCodestreamParams params; + params.preview_mode = mode; + + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 3, params); + + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_PREVIEW_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + + jxl::ColorEncoding c_srgb = jxl::ColorEncoding::SRGB(false); + jxl::CodecInOut io0; + EXPECT_TRUE(jxl::ConvertFromExternal( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, c_srgb, + /*bits_per_sample=*/16, format_orig, /*pool=*/nullptr, &io0.Main())); + GeneratePreview(params.preview_mode, &io0.Main()); + + size_t xsize_preview = io0.Main().xsize(); + size_t ysize_preview = io0.Main().ysize(); + EXPECT_EQ(xsize_preview, info.preview.xsize); + EXPECT_EQ(ysize_preview, info.preview.ysize); + EXPECT_EQ(xsize_preview * ysize_preview * 3, buffer_size); + + EXPECT_EQ(JXL_DEC_NEED_PREVIEW_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + std::vector<uint8_t> preview(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreviewOutBuffer(dec, &format, preview.data(), + preview.size())); + + EXPECT_EQ(JXL_DEC_PREVIEW_IMAGE, JxlDecoderProcessInput(dec)); + + jxl::CodecInOut io1; + EXPECT_TRUE( + jxl::ConvertFromExternal(jxl::Bytes(preview.data(), preview.size()), + xsize_preview, ysize_preview, c_srgb, + /*bits_per_sample=*/8, format, + /*pool=*/nullptr, &io1.Main())); + + jxl::ButteraugliParams ba; + // TODO(lode): this ButteraugliDistance silently returns 0 (dangerous for + // tests) if xsize or ysize is < 8, no matter how different the images, a + // tiny size that could happen for a preview. ButteraugliDiffmap does + // support smaller than 8x8, but jxl's ButteraugliDistance does not. Perhaps + // move butteraugli's <8x8 handling from ButteraugliDiffmap to + // ButteraugliComparator::Diffmap in butteraugli.cc. + EXPECT_LE( + ButteraugliDistance(io0.frames, io1.frames, ba, *JxlGetDefaultCms(), + /*distmap=*/nullptr, nullptr), + mode == jxl::kSmallPreview ? 0.7f : 1.2f); + + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, AlignTest) { + size_t xsize = 123, ysize = 77; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 4, params); + + size_t align = 17; + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, align}; + // On purpose not using jxl::RoundUpTo to test it independently. + size_t expected_line_bytes = (1 * 3 * xsize + align - 1) / align * align; + + for (int use_callback = 0; use_callback <= 1; ++use_callback) { + std::vector<uint8_t> pixels2 = jxl::DecodeWithAPI( + jxl::Bytes(compressed.data(), compressed.size()), format, use_callback, + /*set_buffer_early=*/false, + /*use_resizable_runner=*/false, /*require_boxes=*/false, + /*expect_success=*/true); + EXPECT_EQ(expected_line_bytes * ysize, pixels2.size()); + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format)); + } +} + +TEST(DecodeTest, AnimationTest) { + size_t xsize = 123, ysize = 77; + static const size_t num_frames = 2; + std::vector<uint8_t> frames[2]; + frames[0] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + frames[1] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 1); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector<uint32_t> frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + + EXPECT_TRUE(ConvertFromExternal( + jxl::Bytes(frames[i].data(), frames[i].size()), xsize, ysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &bundle)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + std::vector<uint8_t> compressed; + EXPECT_TRUE(jxl::test::EncodeFile(cparams, &io, &compressed)); + + // Decode and test the animation frames + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + std::vector<uint8_t> pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + EXPECT_EQ(0u, frame_header.name_length); + // For now, test with empty name, there's currently no easy way to encode + // a jxl file with a frame name because ImageBundle doesn't have a + // jxl::FrameHeader to set the name in. We can test the null termination + // character though. + char name; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameName(dec, &name, 1)); + EXPECT_EQ(0, name); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, AnimationTestStreaming) { + size_t xsize = 123, ysize = 77; + static const size_t num_frames = 2; + std::vector<uint8_t> frames[2]; + frames[0] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 0); + frames[1] = jxl::test::GetSomeTestImage(xsize, ysize, 3, 1); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector<uint32_t> frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + + EXPECT_TRUE(ConvertFromExternal( + jxl::Bytes(frames[i].data(), frames[i].size()), xsize, ysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &bundle)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + std::vector<uint8_t> compressed; + EXPECT_TRUE(jxl::test::EncodeFile(cparams, &io, &compressed)); + + // Decode and test the animation frames + + const size_t step_size = 16; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = 0; + size_t frame_headers_seen = 0; + size_t frames_seen = 0; + bool seen_basic_info = false; + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + std::vector<uint8_t> frames2[2]; + for (size_t i = 0; i < num_frames; ++i) { + frames2[i].resize(frames[i].size()); + } + + size_t total_in = 0; + size_t loop_count = 0; + + for (;;) { + if (loop_count++ > compressed.size()) { + fprintf(stderr, "Too many loops\n"); + FAIL(); + break; + } + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + auto status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + + if (status == JXL_DEC_SUCCESS) { + break; + } else if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + if (total_in >= compressed.size()) { + fprintf(stderr, "Already gave all input data\n"); + FAIL(); + break; + } + size_t amount = step_size; + if (total_in + amount > compressed.size()) { + amount = compressed.size() - total_in; + } + avail_in += amount; + total_in += amount; + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, frames2[frames_seen].data(), + frames2[frames_seen].size())); + } else if (status == JXL_DEC_BASIC_INFO) { + EXPECT_EQ(false, seen_basic_info); + seen_basic_info = true; + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + } else if (status == JXL_DEC_FRAME) { + EXPECT_EQ(true, seen_basic_info); + frame_headers_seen++; + } else if (status == JXL_DEC_FULL_IMAGE) { + frames_seen++; + EXPECT_EQ(frame_headers_seen, frames_seen); + } else { + fprintf(stderr, "Unexpected status: %d\n", (int)status); + FAIL(); + } + } + + EXPECT_EQ(true, seen_basic_info); + EXPECT_EQ(num_frames, frames_seen); + EXPECT_EQ(num_frames, frame_headers_seen); + for (size_t i = 0; i < num_frames; ++i) { + EXPECT_EQ(frames[i], frames2[i]); + } + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, ExtraChannelTest) { + size_t xsize = 55, ysize = 257; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 4, params); + + size_t align = 17; + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, align}; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(1u, info.num_extra_channels); + EXPECT_EQ(JXL_FALSE, info.alpha_premultiplied); + + JxlExtraChannelInfo extra_info; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetExtraChannelInfo(dec, 0, &extra_info)); + EXPECT_EQ(0, extra_info.type); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + size_t extra_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderExtraChannelBufferSize(dec, &format, &extra_size, 0)); + + std::vector<uint8_t> image(buffer_size); + std::vector<uint8_t> extra(extra_size); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, image.data(), image.size())); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetExtraChannelBuffer( + dec, &format, extra.data(), extra.size(), 0)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + + // After the full image was output, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); + + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), image.data(), xsize, + ysize, format_orig, format)); + + // Compare the extracted extra channel with the original alpha channel + + std::vector<uint8_t> alpha(pixels.size() / 4); + for (size_t i = 0; i < pixels.size(); i += 8) { + size_t index_alpha = i / 4; + alpha[index_alpha + 0] = pixels[i + 6]; + alpha[index_alpha + 1] = pixels[i + 7]; + } + JxlPixelFormat format_alpha = format; + format_alpha.num_channels = 1; + JxlPixelFormat format_orig_alpha = format_orig; + format_orig_alpha.num_channels = 1; + + EXPECT_EQ(0u, + jxl::test::ComparePixels(alpha.data(), extra.data(), xsize, ysize, + format_orig_alpha, format_alpha)); +} + +TEST(DecodeTest, SkipCurrentFrameTest) { + size_t xsize = 90, ysize = 120; + constexpr size_t num_frames = 7; + std::vector<uint8_t> frames[num_frames]; + for (size_t i = 0; i < num_frames; i++) { + frames[i] = jxl::test::GetSomeTestImage(xsize, ysize, 3, i); + } + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector<uint32_t> frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + if (i & 1) { + // Mark some frames as referenceable, others not. + bundle.use_for_next_frame = true; + } + + EXPECT_TRUE(ConvertFromExternal( + jxl::Bytes(frames[i].data(), frames[i].size()), xsize, ysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &bundle)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.speed_tier = jxl::SpeedTier::kThunder; + std::vector<uint8_t> compressed; + jxl::PassDefinition passes[] = {{2, 0, 4}, {4, 0, 4}, {8, 2, 2}, {8, 0, 1}}; + jxl::ProgressiveMode progressive_mode{passes}; + cparams.custom_progressive_mode = &progressive_mode; + EXPECT_TRUE(jxl::test::EncodeFile(cparams, &io, &compressed)); + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | + JXL_DEC_FRAME_PROGRESSION | + JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetProgressiveDetail(dec, kLastPasses)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + printf("Decoding frame %d\n", (int)i); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSkipCurrentFrame(dec)); + std::vector<uint8_t> pixels(buffer_size); + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSkipCurrentFrame(dec)); + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + if (i == 2) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSkipCurrentFrame(dec)); + continue; + } + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, JxlDecoderProcessInput(dec)); + EXPECT_EQ(8, JxlDecoderGetIntendedDownsamplingRatio(dec)); + if (i == 3) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSkipCurrentFrame(dec)); + continue; + } + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, JxlDecoderProcessInput(dec)); + EXPECT_EQ(4, JxlDecoderGetIntendedDownsamplingRatio(dec)); + if (i == 4) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSkipCurrentFrame(dec)); + continue; + } + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, JxlDecoderProcessInput(dec)); + EXPECT_EQ(2, JxlDecoderGetIntendedDownsamplingRatio(dec)); + if (i == 5) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSkipCurrentFrame(dec)); + continue; + } + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSkipCurrentFrame(dec)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, SkipFrameTest) { + size_t xsize = 90, ysize = 120; + constexpr size_t num_frames = 16; + std::vector<uint8_t> frames[num_frames]; + for (size_t i = 0; i < num_frames; i++) { + frames[i] = jxl::test::GetSomeTestImage(xsize, ysize, 3, i); + } + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector<uint32_t> frame_durations(num_frames); + for (size_t i = 0; i < num_frames; ++i) { + frame_durations[i] = 5 + i; + } + + for (size_t i = 0; i < num_frames; ++i) { + jxl::ImageBundle bundle(&io.metadata.m); + if (i & 1) { + // Mark some frames as referenceable, others not. + bundle.use_for_next_frame = true; + } + + EXPECT_TRUE(ConvertFromExternal( + jxl::Bytes(frames[i].data(), frames[i].size()), xsize, ysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &bundle)); + bundle.duration = frame_durations[i]; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + std::vector<uint8_t> compressed; + EXPECT_TRUE(jxl::test::EncodeFile(cparams, &io, &compressed)); + + // Decode and test the animation frames + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + if (i == 3) { + JxlDecoderSkipFrames(dec, 5); + i += 5; + } + std::vector<uint8_t> pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + // Test rewinding the decoder and skipping different frames + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames; ++i) { + int test_skipping = (i == 9) ? 3 : 0; + std::vector<uint8_t> pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Since this is after JXL_DEC_FRAME but before JXL_DEC_FULL_IMAGE, this + // should only skip the next frame, not the currently processed one. + if (test_skipping) JxlDecoderSkipFrames(dec, test_skipping); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + + if (test_skipping) i += test_skipping; + } + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, SkipFrameWithBlendingTest) { + size_t xsize = 90, ysize = 120; + constexpr size_t num_frames = 16; + std::vector<uint8_t> frames[num_frames]; + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames); + io.SetSize(xsize, ysize); + + std::vector<uint32_t> frame_durations(num_frames); + + for (size_t i = 0; i < num_frames; ++i) { + if (i < 5) { + std::vector<uint8_t> frame_internal = + jxl::test::GetSomeTestImage(xsize, ysize, 3, i * 2 + 1); + // An internal frame with 0 duration, and use_for_next_frame, this is a + // frame that is not rendered and not output by the API, but on which the + // rendered frames depend + jxl::ImageBundle bundle_internal(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal( + jxl::Bytes(frame_internal.data(), frame_internal.size()), xsize, + ysize, jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &bundle_internal)); + bundle_internal.duration = 0; + bundle_internal.use_for_next_frame = true; + io.frames.push_back(std::move(bundle_internal)); + } + + std::vector<uint8_t> frame = + jxl::test::GetSomeTestImage(xsize, ysize, 3, i * 2); + // Actual rendered frame + frame_durations[i] = 5 + i; + jxl::ImageBundle bundle(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal(jxl::Bytes(frame.data(), frame.size()), + xsize, ysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &bundle)); + bundle.duration = frame_durations[i]; + // Create some variation in which frames depend on which. + if (i != 3 && i != 9 && i != 10) { + bundle.use_for_next_frame = true; + } + if (i != 12) { + bundle.blend = true; + // Choose a blend mode that depends on the pixels of the saved frame and + // doesn't use alpha + bundle.blendmode = jxl::BlendMode::kMul; + } + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + std::vector<uint8_t> compressed; + EXPECT_TRUE(jxl::test::EncodeFile(cparams, &io, &compressed)); + + // Independently decode all frames without any skipping, to create the + // expected blended frames, for the actual tests below to compare with. + { + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetParallelRunner( + dec, JxlThreadParallelRunner, runner)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + for (size_t i = 0; i < num_frames; ++i) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + frames[i].resize(xsize * ysize * 6); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, frames[i].data(), + frames[i].size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); + } + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetParallelRunner(dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + std::vector<uint8_t> pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + + // Test rewinding mid-way, not decoding all frames. + if (i == 8) { + break; + } + } + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames; ++i) { + if (i == 3) { + JxlDecoderSkipFrames(dec, 5); + i += 5; + } + std::vector<uint8_t> pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + // Test rewinding the decoder and skipping different frames + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames; ++i) { + int test_skipping = (i == 9) ? 3 : 0; + std::vector<uint8_t> pixels(buffer_size); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Since this is after JXL_DEC_FRAME but before JXL_DEC_FULL_IMAGE, this + // should only skip the next frame, not the currently processed one. + if (test_skipping) JxlDecoderSkipFrames(dec, test_skipping); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ(frame_durations[i], frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels.data(), pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + xsize, ysize, format, format)); + + if (test_skipping) i += test_skipping; + } + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, SkipFrameWithAlphaBlendingTest) { + size_t xsize = 90, ysize = 120; + constexpr size_t num_frames = 16; + std::vector<uint8_t> frames[num_frames + 5]; + JxlPixelFormat format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.have_animation = true; + io.frames.clear(); + io.frames.reserve(num_frames + 5); + io.SetSize(xsize, ysize); + + std::vector<uint32_t> frame_durations_c; + std::vector<uint32_t> frame_durations_nc; + std::vector<uint32_t> frame_xsize, frame_ysize, frame_x0, frame_y0; + + for (size_t i = 0; i < num_frames; ++i) { + size_t cropxsize = 1 + xsize * 2 / (i + 1); + size_t cropysize = 1 + ysize * 3 / (i + 2); + int cropx0 = i * 3 - 8; + int cropy0 = i * 4 - 7; + if (i < 5) { + std::vector<uint8_t> frame_internal = + jxl::test::GetSomeTestImage(xsize / 2, ysize / 2, 4, i * 2 + 1); + // An internal frame with 0 duration, and use_for_next_frame, this is a + // frame that is not rendered and not output by default by the API, but on + // which the rendered frames depend + jxl::ImageBundle bundle_internal(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal( + jxl::Bytes(frame_internal.data(), frame_internal.size()), xsize / 2, + ysize / 2, jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &bundle_internal)); + bundle_internal.duration = 0; + bundle_internal.use_for_next_frame = true; + bundle_internal.origin = {13, 17}; + io.frames.push_back(std::move(bundle_internal)); + frame_durations_nc.push_back(0); + frame_xsize.push_back(xsize / 2); + frame_ysize.push_back(ysize / 2); + frame_x0.push_back(13); + frame_y0.push_back(17); + } + + std::vector<uint8_t> frame = + jxl::test::GetSomeTestImage(cropxsize, cropysize, 4, i * 2); + // Actual rendered frame + jxl::ImageBundle bundle(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal(jxl::Bytes(frame.data(), frame.size()), + cropxsize, cropysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &bundle)); + bundle.duration = 5 + i; + frame_durations_nc.push_back(5 + i); + frame_durations_c.push_back(5 + i); + frame_xsize.push_back(cropxsize); + frame_ysize.push_back(cropysize); + frame_x0.push_back(cropx0); + frame_y0.push_back(cropy0); + bundle.origin = {cropx0, cropy0}; + // Create some variation in which frames depend on which. + if (i != 3 && i != 9 && i != 10) { + bundle.use_for_next_frame = true; + } + if (i != 12) { + bundle.blend = true; + bundle.blendmode = jxl::BlendMode::kBlend; + } + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams.SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + std::vector<uint8_t> compressed; + EXPECT_TRUE(jxl::test::EncodeFile(cparams, &io, &compressed)); + // try both with and without coalescing + for (auto coalescing : {JXL_TRUE, JXL_FALSE}) { + // Independently decode all frames without any skipping, to create the + // expected blended frames, for the actual tests below to compare with. + { + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec, coalescing)); + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetParallelRunner( + dec, JxlThreadParallelRunner, runner)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + for (size_t i = 0; i < num_frames + (coalescing ? 0 : 5); ++i) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + if (coalescing) { + EXPECT_EQ(xsize * ysize * 8, buffer_size); + } else { + EXPECT_EQ(frame_xsize[i] * frame_ysize[i] * 8, buffer_size); + } + frames[i].resize(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, frames[i].data(), + frames[i].size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); + } + + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec, coalescing)); + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetParallelRunner( + dec, JxlThreadParallelRunner, runner)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | + JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + + for (size_t i = 0; i < num_frames; ++i) { + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + std::vector<uint8_t> pixels(buffer_size); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ((coalescing ? frame_durations_c[i] : frame_durations_nc[i]), + frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames, frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels.data(), + pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + if (coalescing) { + EXPECT_EQ(frame_header.layer_info.xsize, xsize); + } else { + EXPECT_EQ(frame_header.layer_info.xsize, frame_xsize[i]); + } + if (coalescing) { + EXPECT_EQ(frame_header.layer_info.ysize, ysize); + } else { + EXPECT_EQ(frame_header.layer_info.ysize, frame_ysize[i]); + } + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + frame_header.layer_info.xsize, + frame_header.layer_info.ysize, + format, format)); + + // Test rewinding mid-way, not decoding all frames. + if (i == 8) { + break; + } + } + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames + (coalescing ? 0 : 5); ++i) { + if (i == 3) { + JxlDecoderSkipFrames(dec, 5); + i += 5; + } + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + std::vector<uint8_t> pixels(buffer_size); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ((coalescing ? frame_durations_c[i] : frame_durations_nc[i]), + frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames + (coalescing ? 0 : 5), + frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels.data(), + pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + if (coalescing) { + EXPECT_EQ(frame_header.layer_info.xsize, xsize); + EXPECT_EQ(frame_header.layer_info.ysize, ysize); + EXPECT_EQ(frame_header.layer_info.crop_x0, 0); + EXPECT_EQ(frame_header.layer_info.crop_y0, 0); + } else { + EXPECT_EQ(frame_header.layer_info.xsize, frame_xsize[i]); + EXPECT_EQ(frame_header.layer_info.ysize, frame_ysize[i]); + EXPECT_EQ(frame_header.layer_info.crop_x0, frame_x0[i]); + EXPECT_EQ(frame_header.layer_info.crop_y0, frame_y0[i]); + EXPECT_EQ(frame_header.layer_info.blend_info.blendmode, + i != 12 + 5 && frame_header.duration != 0 + ? 2 + : 0); // kBlend or the default kReplace + } + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + frame_header.layer_info.xsize, + frame_header.layer_info.ysize, + format, format)); + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + // Test rewinding the decoder and skipping different frames + + JxlDecoderRewind(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents( + dec, JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + + for (size_t i = 0; i < num_frames + (coalescing ? 0 : 5); ++i) { + int test_skipping = (i == 9) ? 3 : 0; + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + std::vector<uint8_t> pixels(buffer_size); + + // Since this is after JXL_DEC_FRAME but before JXL_DEC_FULL_IMAGE, this + // should only skip the next frame, not the currently processed one. + if (test_skipping) JxlDecoderSkipFrames(dec, test_skipping); + + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec, &frame_header)); + EXPECT_EQ((coalescing ? frame_durations_c[i] : frame_durations_nc[i]), + frame_header.duration); + + EXPECT_EQ(i + 1 == num_frames + (coalescing ? 0 : 5), + frame_header.is_last); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels.data(), + pixels.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[i].data(), pixels.data(), + frame_header.layer_info.xsize, + frame_header.layer_info.ysize, + format, format)); + + if (test_skipping) i += test_skipping; + } + + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, OrientedCroppedFrameTest) { + const auto test = [](bool keep_orientation, uint32_t orientation, + uint32_t resampling) { + size_t xsize = 90, ysize = 120; + JxlPixelFormat format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + size_t oxsize = (!keep_orientation && orientation > 4 ? ysize : xsize); + size_t oysize = (!keep_orientation && orientation > 4 ? xsize : ysize); + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetUintSamples(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.orientation = orientation; + io.frames.clear(); + io.SetSize(xsize, ysize); + + for (size_t i = 0; i < 3; ++i) { + size_t cropxsize = 1 + xsize * 2 / (i + 1); + size_t cropysize = 1 + ysize * 3 / (i + 2); + int cropx0 = i * 3 - 8; + int cropy0 = i * 4 - 7; + + std::vector<uint8_t> frame = + jxl::test::GetSomeTestImage(cropxsize, cropysize, 4, i * 2); + jxl::ImageBundle bundle(&io.metadata.m); + EXPECT_TRUE(ConvertFromExternal( + jxl::Bytes(frame.data(), frame.size()), cropxsize, cropysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/false), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &bundle)); + bundle.origin = {cropx0, cropy0}; + bundle.use_for_next_frame = true; + io.frames.push_back(std::move(bundle)); + } + + jxl::CompressParams cparams; + cparams + .SetLossless(); // Lossless to verify pixels exactly after roundtrip. + cparams.speed_tier = jxl::SpeedTier::kThunder; + cparams.resampling = resampling; + std::vector<uint8_t> compressed; + EXPECT_TRUE(jxl::test::EncodeFile(cparams, &io, &compressed)); + + // 0 is merged frame as decoded with coalescing enabled (default) + // 1-3 are non-coalesced frames as decoded with coalescing disabled + // 4 is the manually merged frame + std::vector<uint8_t> frames[5]; + frames[4].resize(xsize * ysize * 8, 0); + + // try both with and without coalescing + for (auto coalescing : {JXL_TRUE, JXL_FALSE}) { + // Independently decode all frames without any skipping, to create the + // expected blended frames, for the actual tests below to compare with. + { + JxlDecoder* dec = JxlDecoderCreate(NULL); + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec, coalescing)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetKeepOrientation(dec, keep_orientation)); + void* runner = JxlThreadParallelRunnerCreate( + NULL, JxlThreadParallelRunnerDefaultNumWorkerThreads()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetParallelRunner( + dec, JxlThreadParallelRunner, runner)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + for (size_t i = (coalescing ? 0 : 1); i < (coalescing ? 1 : 4); ++i) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + JxlFrameHeader frame_header; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetFrameHeader(dec, &frame_header)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + if (coalescing) { + EXPECT_EQ(xsize * ysize * 8, buffer_size); + } else { + EXPECT_EQ(frame_header.layer_info.xsize * + frame_header.layer_info.ysize * 8, + buffer_size); + } + frames[i].resize(buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, frames[i].data(), + frames[i].size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_EQ(frame_header.layer_info.blend_info.blendmode, + JXL_BLEND_REPLACE); + if (coalescing) { + EXPECT_EQ(frame_header.layer_info.xsize, oxsize); + EXPECT_EQ(frame_header.layer_info.ysize, oysize); + EXPECT_EQ(frame_header.layer_info.crop_x0, 0); + EXPECT_EQ(frame_header.layer_info.crop_y0, 0); + } else { + // manually merge this layer + int x0 = frame_header.layer_info.crop_x0; + int y0 = frame_header.layer_info.crop_y0; + int w = frame_header.layer_info.xsize; + int h = frame_header.layer_info.ysize; + for (int y = 0; y < static_cast<int>(oysize); y++) { + if (y < y0 || y >= y0 + h) continue; + // pointers do whole 16-bit RGBA pixels at a time + uint64_t* row_merged = static_cast<uint64_t*>( + (void*)(frames[4].data() + y * oxsize * 8)); + uint64_t* row_layer = static_cast<uint64_t*>( + (void*)(frames[i].data() + (y - y0) * w * 8)); + for (int x = 0; x < static_cast<int>(oxsize); x++) { + if (x < x0 || x >= x0 + w) continue; + row_merged[x] = row_layer[x - x0]; + } + } + } + } + + // After all frames were decoded, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlThreadParallelRunnerDestroy(runner); + JxlDecoderDestroy(dec); + } + } + + EXPECT_EQ(0u, jxl::test::ComparePixels(frames[0].data(), frames[4].data(), + oxsize, oysize, format, format)); + }; + + for (bool keep_orientation : {true, false}) { + for (uint32_t orientation = 1; orientation <= 8; orientation++) { + for (uint32_t resampling : {1, 2, 4, 8}) { + SCOPED_TRACE(testing::Message() + << "keep_orientation: " << keep_orientation << ", " + << "orientation: " << orientation << ", " + << "resampling: " << resampling); + test(keep_orientation, orientation, resampling); + } + } + } +} + +struct FramePositions { + size_t frame_start; + size_t header_end; + size_t toc_end; + std::vector<size_t> section_end; +}; + +struct StreamPositions { + size_t codestream_start; + size_t codestream_end; + size_t basic_info; + size_t jbrd_end = 0; + std::vector<size_t> box_start; + std::vector<FramePositions> frames; +}; + +void AnalyzeCodestream(const std::vector<uint8_t>& data, + StreamPositions* streampos) { + // Unbox data to codestream and mark where it is broken up by boxes. + std::vector<uint8_t> codestream; + std::vector<std::pair<size_t, size_t>> breakpoints; + bool codestream_end = false; + ASSERT_LE(2, data.size()); + if (data[0] == 0xff && data[1] == 0x0a) { + codestream = std::vector<uint8_t>(data.begin(), data.end()); + streampos->codestream_start = 0; + } else { + const uint8_t* in = data.data(); + size_t pos = 0; + while (pos < data.size()) { + ASSERT_LE(pos + 8, data.size()); + streampos->box_start.push_back(pos); + size_t box_size = LoadBE32(in + pos); + if (box_size == 0) box_size = data.size() - pos; + ASSERT_LE(pos + box_size, data.size()); + if (memcmp(in + pos + 4, "jxlc", 4) == 0) { + EXPECT_TRUE(codestream.empty()); + streampos->codestream_start = pos + 8; + codestream.insert(codestream.end(), in + pos + 8, in + pos + box_size); + codestream_end = true; + } else if (memcmp(in + pos + 4, "jxlp", 4) == 0) { + codestream_end = (LoadBE32(in + pos + 8) & 0x80000000); + if (codestream.empty()) { + streampos->codestream_start = pos + 12; + } else if (box_size > 12 || !codestream_end) { + breakpoints.push_back({codestream.size(), 12}); + } + codestream.insert(codestream.end(), in + pos + 12, in + pos + box_size); + } else if (memcmp(in + pos + 4, "jbrd", 4) == 0) { + EXPECT_TRUE(codestream.empty()); + streampos->jbrd_end = pos + box_size; + } else if (!codestream.empty() && !codestream_end) { + breakpoints.push_back({codestream.size(), box_size}); + } + pos += box_size; + } + ASSERT_EQ(pos, data.size()); + } + // Translate codestream positions to boxed stream positions. + size_t offset = streampos->codestream_start; + size_t bp = 0; + auto add_offset = [&](size_t pos) { + while (bp < breakpoints.size() && pos >= breakpoints[bp].first) { + offset += breakpoints[bp++].second; + } + return pos + offset; + }; + // Analyze the unboxed codestream. + jxl::BitReader br(jxl::Bytes(codestream.data(), codestream.size())); + ASSERT_EQ(br.ReadFixedBits<16>(), 0x0AFF); + jxl::CodecMetadata metadata; + ASSERT_TRUE(ReadSizeHeader(&br, &metadata.size)); + ASSERT_TRUE(ReadImageMetadata(&br, &metadata.m)); + streampos->basic_info = + add_offset(br.TotalBitsConsumed() / jxl::kBitsPerByte); + metadata.transform_data.nonserialized_xyb_encoded = metadata.m.xyb_encoded; + ASSERT_TRUE(jxl::Bundle::Read(&br, &metadata.transform_data)); + if (metadata.m.color_encoding.WantICC()) { + std::vector<uint8_t> icc; + ASSERT_TRUE(jxl::test::ReadICC(&br, &icc)); + ASSERT_TRUE(!icc.empty()); + metadata.m.color_encoding.SetICCRaw(std::move(icc)); + } + ASSERT_TRUE(br.JumpToByteBoundary()); + bool has_preview = metadata.m.have_preview; + while (br.TotalBitsConsumed() < br.TotalBytes() * jxl::kBitsPerByte) { + FramePositions p; + p.frame_start = add_offset(br.TotalBitsConsumed() / jxl::kBitsPerByte); + jxl::FrameHeader frame_header(&metadata); + if (has_preview) { + frame_header.nonserialized_is_preview = true; + has_preview = false; + } + ASSERT_TRUE(ReadFrameHeader(&br, &frame_header)); + p.header_end = + add_offset(jxl::DivCeil(br.TotalBitsConsumed(), jxl::kBitsPerByte)); + jxl::FrameDimensions frame_dim = frame_header.ToFrameDimensions(); + uint64_t groups_total_size; + const size_t toc_entries = + jxl::NumTocEntries(frame_dim.num_groups, frame_dim.num_dc_groups, + frame_header.passes.num_passes); + std::vector<uint64_t> section_offsets; + std::vector<uint32_t> section_sizes; + ASSERT_TRUE(ReadGroupOffsets(toc_entries, &br, §ion_offsets, + §ion_sizes, &groups_total_size)); + EXPECT_EQ(br.TotalBitsConsumed() % jxl::kBitsPerByte, 0); + size_t sections_start = br.TotalBitsConsumed() / jxl::kBitsPerByte; + p.toc_end = add_offset(sections_start); + for (size_t i = 0; i < toc_entries; ++i) { + size_t end = sections_start + section_offsets[i] + section_sizes[i]; + p.section_end.push_back(add_offset(end)); + } + br.SkipBits(groups_total_size * jxl::kBitsPerByte); + streampos->frames.push_back(p); + } + streampos->codestream_end = add_offset(codestream.size()); + EXPECT_EQ(br.TotalBitsConsumed(), br.TotalBytes() * jxl::kBitsPerByte); + EXPECT_TRUE(br.Close()); +} + +enum ExpectedFlushState { NO_FLUSH, SAME_FLUSH, NEW_FLUSH }; +struct Breakpoint { + size_t file_pos; + ExpectedFlushState expect_flush; +}; + +void VerifyProgression(size_t xsize, size_t ysize, uint32_t num_channels, + const std::vector<uint8_t>& pixels, + const std::vector<uint8_t>& data, + std::vector<Breakpoint> breakpoints) { + // Size large enough for multiple groups, required to have progressive stages. + ASSERT_LT(256, xsize); + ASSERT_LT(256, ysize); + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + int bp = 0; + const uint8_t* next_in = data.data(); + size_t avail_in = breakpoints[bp].file_pos; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + double prev_dist = 1.0; + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + printf("bp: %d status: 0x%x\n", bp, (int)status); + if (status == JXL_DEC_BASIC_INFO) { + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + // Output buffer/callback not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels2.data(), + pixels2.size())); + } else if (status == JXL_DEC_FRAME) { + // Nothing to do. + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_EQ(bp + 1, breakpoints.size()); + break; + } else if (status == JXL_DEC_NEED_MORE_INPUT || + status == JXL_DEC_FULL_IMAGE) { + if (breakpoints[bp].expect_flush == NO_FLUSH) { + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + } else { + if (status != JXL_DEC_FULL_IMAGE) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + } + double dist = jxl::test::DistanceRMS(pixels2.data(), pixels.data(), + xsize, ysize, format); + if (breakpoints[bp].expect_flush == NEW_FLUSH) { + EXPECT_LT(dist, prev_dist); + prev_dist = dist; + } else { + EXPECT_EQ(dist, prev_dist); + } + } + if (status == JXL_DEC_FULL_IMAGE) { + EXPECT_EQ(bp + 1, breakpoints.size()); + continue; + } + ASSERT_LT(++bp, breakpoints.size()); + next_in += avail_in - JxlDecoderReleaseInput(dec); + avail_in = breakpoints[bp].file_pos - (next_in - data.data()); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + } else { + printf("Unexpected status: 0x%x\n", (int)status); + FAIL(); // unexpected returned status + } + } + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, ProgressionTest) { + size_t xsize = 508, ysize = 470; + uint32_t num_channels = 3; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.cparams.progressive_dc = 1; + params.preview_mode = jxl::kSmallPreview; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector<FramePositions>& fp = streampos.frames; + // We have preview, dc frame and regular frame. + EXPECT_EQ(3, fp.size()); + EXPECT_EQ(7, fp[2].section_end.size()); + EXPECT_EQ(data.size(), fp[2].section_end[6]); + std::vector<Breakpoint> breakpoints{ + {fp[0].frame_start, NO_FLUSH}, // headers + {fp[1].frame_start, NO_FLUSH}, // preview + {fp[2].frame_start, NO_FLUSH}, // dc frame + {fp[2].section_end[0], NO_FLUSH}, // DC global + {fp[2].section_end[1] - 1, NO_FLUSH}, // partial DC group + {fp[2].section_end[1], NEW_FLUSH}, // DC group + {fp[2].section_end[2], SAME_FLUSH}, // AC global + {fp[2].section_end[3], NEW_FLUSH}, // AC group 0 + {fp[2].section_end[4] - 1, SAME_FLUSH}, // partial AC group 1 + {fp[2].section_end[4], NEW_FLUSH}, // AC group 1 + {fp[2].section_end[5], NEW_FLUSH}, // AC group 2 + {data.size() - 1, SAME_FLUSH}, // partial AC group 3 + {data.size(), NEW_FLUSH}}; // full image + VerifyProgression(xsize, ysize, num_channels, pixels, data, breakpoints); +} + +TEST(DecodeTest, ProgressionTestLosslessAlpha) { + size_t xsize = 508, ysize = 470; + uint32_t num_channels = 4; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.cparams.responsive = 1; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector<FramePositions>& fp = streampos.frames; + // We have preview, dc frame and regular frame. + EXPECT_EQ(1, fp.size()); + EXPECT_EQ(7, fp[0].section_end.size()); + EXPECT_EQ(data.size(), fp[0].section_end[6]); + std::vector<Breakpoint> breakpoints{ + {fp[0].frame_start, NO_FLUSH}, // headers + {fp[0].section_end[0] - 1, NO_FLUSH}, // partial DC global + {fp[0].section_end[0], NEW_FLUSH}, // DC global + {fp[0].section_end[1], SAME_FLUSH}, // DC group + {fp[0].section_end[2], SAME_FLUSH}, // AC global + {fp[0].section_end[3], NEW_FLUSH}, // AC group 0 + {fp[0].section_end[4] - 1, SAME_FLUSH}, // partial AC group 1 + {fp[0].section_end[4], NEW_FLUSH}, // AC group 1 + {fp[0].section_end[5], NEW_FLUSH}, // AC group 2 + {data.size() - 1, SAME_FLUSH}, // partial AC group 3 + {data.size(), NEW_FLUSH}}; // full image + VerifyProgression(xsize, ysize, num_channels, pixels, data, breakpoints); +} + +void VerifyFilePosition(size_t expected_pos, const std::vector<uint8_t>& data, + JxlDecoder* dec) { + size_t remaining = JxlDecoderReleaseInput(dec); + size_t pos = data.size() - remaining; + EXPECT_EQ(expected_pos, pos); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, data.data() + pos, remaining)); +} + +TEST(DecodeTest, InputHandlingTestOneShot) { + size_t xsize = 508, ysize = 470; + uint32_t num_channels = 3; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + printf("Testing with box format %d\n", i); + jxl::TestCodestreamParams params; + params.cparams.progressive_dc = 1; + params.preview_mode = jxl::kSmallPreview; + params.box_format = (CodeStreamBoxFormat)i; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector<FramePositions>& fp = streampos.frames; + // We have preview, dc frame and regular frame. + EXPECT_EQ(3, fp.size()); + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + int kNumEvents = 6; + int events[] = { + JXL_DEC_BASIC_INFO, JXL_DEC_COLOR_ENCODING, JXL_DEC_PREVIEW_IMAGE, + JXL_DEC_FRAME, JXL_DEC_FULL_IMAGE, JXL_DEC_FRAME_PROGRESSION, + }; + size_t end_positions[] = { + streampos.basic_info, fp[0].frame_start, + fp[1].frame_start, fp[2].toc_end, + streampos.codestream_end, streampos.codestream_end}; + int events_wanted = 0; + for (int j = 0; j < kNumEvents; ++j) { + events_wanted |= events[j]; + size_t end_pos = end_positions[j]; + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events_wanted)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, data.data(), data.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.basic_info, data, dec); + if (j >= 1) { + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[0].frame_start, data, dec); + } + if (j >= 2) { + EXPECT_EQ(JXL_DEC_NEED_PREVIEW_OUT_BUFFER, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[0].toc_end, data, dec); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + EXPECT_GE(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreviewOutBuffer(dec, &format, pixels2.data(), + buffer_size)); + EXPECT_EQ(JXL_DEC_PREVIEW_IMAGE, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[1].frame_start, data, dec); + } + if (j >= 3) { + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[2].toc_end, data, dec); + if (j >= 5) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetProgressiveDetail(dec, kDC)); + } + } + if (j >= 4) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[2].toc_end, data, dec); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels2.data(), + pixels2.size())); + if (j >= 5) { + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[2].section_end[1], data, dec); + } + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.codestream_end, data, dec); + } + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + VerifyFilePosition(end_pos, data, dec); + JxlDecoderDestroy(dec); + } + } +} + +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(InputHandlingTestJPEGOneshot)) { + TEST_LIBJPEG_SUPPORT(); + size_t xsize = 123; + size_t ysize = 77; + size_t channels = 3; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, channels, /*seed=*/0); + for (int i = 1; i < kCSBF_NUM_ENTRIES; ++i) { + printf("Testing with box format %d\n", i); + std::vector<uint8_t> jpeg_codestream; + jxl::TestCodestreamParams params; + params.cparams.color_transform = jxl::ColorTransform::kNone; + params.jpeg_codestream = &jpeg_codestream; + params.preview_mode = jxl::kSmallPreview; + params.box_format = (CodeStreamBoxFormat)i; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, channels, params); + JxlPixelFormat format = {3, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector<FramePositions>& fp = streampos.frames; + // We have preview and regular frame. + EXPECT_EQ(2, fp.size()); + EXPECT_LT(0, streampos.jbrd_end); + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + int kNumEvents = 6; + int events[] = {JXL_DEC_BASIC_INFO, JXL_DEC_JPEG_RECONSTRUCTION, + JXL_DEC_COLOR_ENCODING, JXL_DEC_PREVIEW_IMAGE, + JXL_DEC_FRAME, JXL_DEC_FULL_IMAGE}; + size_t end_positions[] = {streampos.basic_info, streampos.basic_info, + fp[0].frame_start, fp[1].frame_start, + fp[1].toc_end, streampos.codestream_end}; + int events_wanted = 0; + for (int j = 0; j < kNumEvents; ++j) { + printf("j = %d\n", j); + events_wanted |= events[j]; + size_t end_pos = end_positions[j]; + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events_wanted)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, data.data(), data.size())); + if (j >= 1) { + EXPECT_EQ(JXL_DEC_JPEG_RECONSTRUCTION, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.jbrd_end, data, dec); + } + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.basic_info, data, dec); + if (j >= 2) { + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[0].frame_start, data, dec); + } + if (j >= 3) { + EXPECT_EQ(JXL_DEC_NEED_PREVIEW_OUT_BUFFER, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[0].toc_end, data, dec); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + EXPECT_GE(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreviewOutBuffer(dec, &format, pixels2.data(), + buffer_size)); + EXPECT_EQ(JXL_DEC_PREVIEW_IMAGE, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[1].frame_start, data, dec); + } + if (j >= 4) { + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[1].toc_end, data, dec); + } + if (j >= 5) { + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + VerifyFilePosition(fp[1].toc_end, data, dec); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels2.data(), + pixels2.size())); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + VerifyFilePosition(streampos.codestream_end, data, dec); + } + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + VerifyFilePosition(end_pos, data, dec); + JxlDecoderDestroy(dec); + } + } +} + +TEST(DecodeTest, InputHandlingTestStreaming) { + size_t xsize = 508, ysize = 470; + uint32_t num_channels = 3; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + for (int i = 0; i < kCSBF_NUM_ENTRIES; ++i) { + printf("Testing with box format %d\n", i); + fflush(stdout); + jxl::TestCodestreamParams params; + params.cparams.progressive_dc = 1; + params.box_format = (CodeStreamBoxFormat)i; + params.preview_mode = jxl::kSmallPreview; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + StreamPositions streampos; + AnalyzeCodestream(data, &streampos); + const std::vector<FramePositions>& fp = streampos.frames; + // We have preview, dc frame and regular frame. + EXPECT_EQ(3, fp.size()); + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + int events_wanted = + (JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | JXL_DEC_PREVIEW_IMAGE | + JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE | JXL_DEC_FRAME_PROGRESSION | + JXL_DEC_BOX); + for (size_t increment : {1, 7, 27, 1024}) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, events_wanted)); + size_t file_pos = 0; + size_t box_index = 0; + size_t avail_in = 0; + for (;;) { + const uint8_t* next_in = data.data() + file_pos; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + size_t consumed = avail_in - remaining; + file_pos += consumed; + avail_in += increment; + avail_in = std::min<size_t>(avail_in, data.size() - file_pos); + if (status == JXL_DEC_BASIC_INFO) { + EXPECT_EQ(file_pos, streampos.basic_info); + } else if (status == JXL_DEC_COLOR_ENCODING) { + EXPECT_EQ(file_pos, streampos.frames[0].frame_start); + } else if (status == JXL_DEC_NEED_PREVIEW_OUT_BUFFER) { + EXPECT_EQ(file_pos, streampos.frames[0].toc_end); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderPreviewOutBufferSize(dec, &format, &buffer_size)); + EXPECT_GE(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetPreviewOutBuffer(dec, &format, pixels2.data(), + buffer_size)); + } else if (status == JXL_DEC_PREVIEW_IMAGE) { + EXPECT_EQ(file_pos, streampos.frames[1].frame_start); + } else if (status == JXL_DEC_FRAME) { + EXPECT_EQ(file_pos, streampos.frames[2].toc_end); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetProgressiveDetail(dec, kDC)); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(file_pos, streampos.frames[2].toc_end); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, pixels2.data(), + pixels2.size())); + } else if (status == JXL_DEC_FRAME_PROGRESSION) { + EXPECT_EQ(file_pos, streampos.frames[2].section_end[1]); + } else if (status == JXL_DEC_FULL_IMAGE) { + EXPECT_EQ(file_pos, streampos.codestream_end); + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_EQ(file_pos, streampos.codestream_end); + break; + } else if (status == JXL_DEC_NEED_MORE_INPUT) { + EXPECT_LT(remaining, 12); + if ((i == kCSBF_None && file_pos >= 2) || + (box_index > 0 && box_index < streampos.box_start.size() && + file_pos >= streampos.box_start[box_index - 1] + 12 && + file_pos < streampos.box_start[box_index])) { + EXPECT_EQ(remaining, 0); + } + if (file_pos == data.size()) break; + } else if (status == JXL_DEC_BOX) { + ASSERT_LT(box_index, streampos.box_start.size()); + EXPECT_EQ(file_pos, streampos.box_start[box_index++]); + } else { + printf("Unexpected status: 0x%x\n", (int)status); + FAIL(); + } + } + JxlDecoderDestroy(dec); + } + } +} + +TEST(DecodeTest, FlushTest) { + // Size large enough for multiple groups, required to have progressive + // stages + size_t xsize = 333, ysize = 300; + uint32_t num_channels = 3; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.preview_mode = jxl::kSmallPreview; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() - 1; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output buffer not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels2.data(), pixels2.size())); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + // Crude test of actual pixel data: pixel threshold of about 4% (2560/65535). + // 29000 pixels can be above the threshold + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 29000u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + // Lower threshold for the final (still lossy) image + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 11000u); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, FlushTestImageOutCallback) { + // Size large enough for multiple groups, required to have progressive + // stages + size_t xsize = 333, ysize = 300; + uint32_t num_channels = 3; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.preview_mode = jxl::kSmallPreview; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + size_t bytes_per_pixel = format.num_channels * 2; + size_t stride = bytes_per_pixel * xsize; + auto callback = [&](size_t x, size_t y, size_t num_pixels, + const void* pixels_row) { + memcpy(pixels2.data() + stride * y + bytes_per_pixel * x, pixels_row, + num_pixels * bytes_per_pixel); + }; + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() - 1; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output callback not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutCallback( + dec, &format, + [](void* opaque, size_t x, size_t y, + size_t xsize, const void* pixels_row) { + auto cb = + static_cast<decltype(&callback)>(opaque); + (*cb)(x, y, xsize, pixels_row); + }, + /*opaque=*/&callback)); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + // Crude test of actual pixel data: pixel threshold of about 4% (2560/65535). + // 29000 pixels can be above the threshold + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 29000u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + // Lower threshold for the final (still lossy) image + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 11000u); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, FlushTestLossyProgressiveAlpha) { + // Size large enough for multiple groups, required to have progressive + // stages + size_t xsize = 333, ysize = 300; + uint32_t num_channels = 4; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.preview_mode = jxl::kSmallPreview; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() - 1; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output buffer not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels2.data(), pixels2.size())); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 30000u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 11000u); + + JxlDecoderDestroy(dec); +} +TEST(DecodeTest, FlushTestLossyProgressiveAlphaUpsampling) { + size_t xsize = 533, ysize = 401; + uint32_t num_channels = 4; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.cparams.resampling = 2; + params.cparams.ec_resampling = 4; + params.preview_mode = jxl::kSmallPreview; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() * 2 / 3; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output buffer not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels2.data(), pixels2.size())); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 125000u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 70000u); + + JxlDecoderDestroy(dec); +} +TEST(DecodeTest, FlushTestLosslessProgressiveAlpha) { + // Size large enough for multiple groups, required to have progressive + // stages + size_t xsize = 333, ysize = 300; + uint32_t num_channels = 4; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + jxl::TestCodestreamParams params; + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.cparams.responsive = 1; + params.cparams.modular_group_size_shift = 1; + params.preview_mode = jxl::kSmallPreview; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + + // Ensure that the first part contains at least the full DC of the image, + // otherwise flush does not work. + size_t first_part = data.size() / 2; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data(), first_part)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + + // Output buffer not yet set + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderFlushImage(dec)); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels2.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, pixels2.data(), pixels2.size())); + + // Must process input further until we get JXL_DEC_NEED_MORE_INPUT, even if + // data was already input before, since the processing of the frame only + // happens at the JxlDecoderProcessInput call after JXL_DEC_FRAME. + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format, 2560.0), + 2700u); + + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + + size_t consumed = first_part - JxlDecoderReleaseInput(dec); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, data.data() + consumed, + data.size() - consumed)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + EXPECT_LE(jxl::test::ComparePixels(pixels2.data(), pixels.data(), xsize, + ysize, format, format), + 0u); + + JxlDecoderDestroy(dec); +} + +class DecodeProgressiveTest : public ::testing::TestWithParam<int> {}; +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(DecodeProgressiveTestInstantiation, + DecodeProgressiveTest, + ::testing::Range(0, 8)); +TEST_P(DecodeProgressiveTest, ProgressiveEventTest) { + const int params = GetParam(); + int single_group = params & 1; + int lossless = (params >> 1) & 1; + uint32_t num_channels = 3 + ((params >> 2) & 1); + std::set<JxlProgressiveDetail> progressive_details = {kDC, kLastPasses, + kPasses}; + for (auto prog_detail : progressive_details) { + // Only few combinations are expected to support outputting + // intermediate flushes for complete DC and complete passes. + // The test can be updated if more cases are expected to support it. + bool expect_flush = (num_channels & 1) && !lossless; + size_t xsize, ysize; + if (single_group) { + // An image smaller than 256x256 ensures it contains only 1 group. + xsize = 99; + ysize = 100; + } else { + xsize = 277; + ysize = 280; + } + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, num_channels, 0); + JxlPixelFormat format = {num_channels, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + jxl::ColorEncoding color_encoding = jxl::ColorEncoding::SRGB(false); + jxl::CodecInOut io; + EXPECT_TRUE(jxl::ConvertFromExternal( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, color_encoding, + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &io.Main())); + jxl::TestCodestreamParams params; + if (lossless) { + params.cparams.SetLossless(); + } else { + params.cparams.butteraugli_distance = 0.5f; + } + jxl::PassDefinition passes[] = { + {2, 0, 4}, {4, 0, 4}, {8, 2, 2}, {8, 1, 2}, {8, 0, 1}}; + const int kNumPasses = 5; + jxl::ProgressiveMode progressive_mode{passes}; + params.cparams.custom_progressive_mode = &progressive_mode; + std::vector<uint8_t> data = + jxl::CreateTestJXLCodestream(jxl::Bytes(pixels.data(), pixels.size()), + xsize, ysize, num_channels, params); + + for (size_t increment : {(size_t)1, data.size()}) { + printf( + "Testing with single_group=%d, lossless=%d, " + "num_channels=%d, prog_detail=%d, increment=%d\n", + single_group, lossless, (int)num_channels, (int)prog_detail, + (int)increment); + std::vector<std::vector<uint8_t>> passes(kNumPasses + 1); + for (int i = 0; i <= kNumPasses; ++i) { + passes[i].resize(pixels.size()); + } + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME | + JXL_DEC_FULL_IMAGE | JXL_DEC_FRAME_PROGRESSION)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSetProgressiveDetail(dec, kFrames)); + EXPECT_EQ(JXL_DEC_ERROR, + JxlDecoderSetProgressiveDetail(dec, kDCProgressive)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSetProgressiveDetail(dec, kDCGroups)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderSetProgressiveDetail(dec, kGroups)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetProgressiveDetail(dec, prog_detail)); + + uint8_t* next_in = data.data(); + size_t avail_in = 0; + size_t pos = 0; + + auto process_input = [&]() { + for (;;) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, next_in, avail_in)); + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + if (status == JXL_DEC_NEED_MORE_INPUT && pos < data.size()) { + size_t chunk = std::min<size_t>(increment, data.size() - pos); + pos += chunk; + avail_in += chunk; + continue; + } + return status; + } + }; + + EXPECT_EQ(JXL_DEC_BASIC_INFO, process_input()); + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + + EXPECT_EQ(JXL_DEC_FRAME, process_input()); + + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(pixels.size(), buffer_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, passes[kNumPasses].data(), + passes[kNumPasses].size())); + + auto next_pass = [&](int pass) { + if (prog_detail <= kDC) return kNumPasses; + if (prog_detail <= kLastPasses) { + return std::min(pass + 2, kNumPasses); + } + return pass + 1; + }; + + if (expect_flush) { + // Return a particular downsampling ratio only after the last + // pass for that downsampling was processed. + int expected_downsampling_ratios[] = {8, 8, 4, 4, 2}; + for (int p = 0; p < kNumPasses; p = next_pass(p)) { + EXPECT_EQ(JXL_DEC_FRAME_PROGRESSION, process_input()); + EXPECT_EQ(expected_downsampling_ratios[p], + JxlDecoderGetIntendedDownsamplingRatio(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderFlushImage(dec)); + passes[p] = passes[kNumPasses]; + } + } + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, process_input()); + EXPECT_EQ(JXL_DEC_SUCCESS, process_input()); + + JxlDecoderDestroy(dec); + + if (!expect_flush) { + continue; + } + jxl::ButteraugliParams ba; + std::vector<float> distances(kNumPasses + 1); + for (int p = 0;; p = next_pass(p)) { + jxl::CodecInOut io1; + EXPECT_TRUE(jxl::ConvertFromExternal( + jxl::Bytes(passes[p].data(), passes[p].size()), xsize, ysize, + color_encoding, + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, &io1.Main())); + distances[p] = ButteraugliDistance( + io.frames, io1.frames, ba, *JxlGetDefaultCms(), nullptr, nullptr); + if (p == kNumPasses) break; + } + const float kMaxDistance[kNumPasses + 1] = {30.0f, 20.0f, 10.0f, + 5.0f, 3.0f, 2.0f}; + EXPECT_LT(distances[kNumPasses], kMaxDistance[kNumPasses]); + for (int p = 0; p < kNumPasses;) { + int next_p = next_pass(p); + EXPECT_LT(distances[p], kMaxDistance[p]); + // Verify that the returned pass image is actually not the + // same as the next pass image, by checking that it has a bit + // worse butteraugli score. + EXPECT_LT(distances[next_p] * 1.1f, distances[p]); + p = next_p; + } + } + } +} + +void VerifyJPEGReconstruction(jxl::Span<const uint8_t> container, + jxl::Span<const uint8_t> jpeg_bytes) { + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec.get(), JXL_DEC_JPEG_RECONSTRUCTION | JXL_DEC_FULL_IMAGE)); + JxlDecoderSetInput(dec.get(), container.data(), container.size()); + EXPECT_EQ(JXL_DEC_JPEG_RECONSTRUCTION, JxlDecoderProcessInput(dec.get())); + std::vector<uint8_t> reconstructed_buffer(128); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data(), + reconstructed_buffer.size())); + size_t used = 0; + JxlDecoderStatus process_result = JXL_DEC_JPEG_NEED_MORE_OUTPUT; + while (process_result == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + reconstructed_buffer.resize(reconstructed_buffer.size() * 2); + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data() + used, + reconstructed_buffer.size() - used)); + process_result = JxlDecoderProcessInput(dec.get()); + } + ASSERT_EQ(JXL_DEC_FULL_IMAGE, process_result); + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + ASSERT_EQ(used, jpeg_bytes.size()); + EXPECT_EQ(0, memcmp(reconstructed_buffer.data(), jpeg_bytes.data(), used)); +} + +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGReconstructTestCodestream)) { + TEST_LIBJPEG_SUPPORT(); + size_t xsize = 123; + size_t ysize = 77; + size_t channels = 3; + std::vector<uint8_t> pixels = + jxl::test::GetSomeTestImage(xsize, ysize, channels, /*seed=*/0); + std::vector<uint8_t> jpeg_codestream; + jxl::TestCodestreamParams params; + params.cparams.color_transform = jxl::ColorTransform::kNone; + params.box_format = kCSBF_Single; + params.jpeg_codestream = &jpeg_codestream; + params.preview_mode = jxl::kSmallPreview; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, channels, params); + VerifyJPEGReconstruction(jxl::Bytes(compressed), jxl::Bytes(jpeg_codestream)); +} + +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGReconstructionTest)) { + const std::string jpeg_path = "jxl/flower/flower.png.im_q85_420.jpg"; + const std::vector<uint8_t> orig = jxl::test::ReadTestData(jpeg_path); + jxl::CodecInOut orig_io; + ASSERT_TRUE(jxl::jpeg::DecodeImageJPG(jxl::Bytes(orig), &orig_io)); + orig_io.metadata.m.xyb_encoded = false; + jxl::BitWriter writer; + ASSERT_TRUE(WriteCodestreamHeaders(&orig_io.metadata, &writer, nullptr)); + writer.ZeroPadToByte(); + jxl::CompressParams cparams; + cparams.color_transform = jxl::ColorTransform::kNone; + ASSERT_TRUE(jxl::EncodeFrame(cparams, jxl::FrameInfo{}, &orig_io.metadata, + orig_io.Main(), *JxlGetDefaultCms(), + /*pool=*/nullptr, &writer, + /*aux_out=*/nullptr)); + + std::vector<uint8_t> jpeg_data; + ASSERT_TRUE( + EncodeJPEGData(*orig_io.Main().jpeg_data.get(), &jpeg_data, cparams)); + std::vector<uint8_t> container; + jxl::Bytes(jxl::kContainerHeader).AppendTo(&container); + jxl::AppendBoxHeader(jxl::MakeBoxType("jbrd"), jpeg_data.size(), false, + &container); + jxl::Bytes(jpeg_data).AppendTo(&container); + jxl::AppendBoxHeader(jxl::MakeBoxType("jxlc"), 0, true, &container); + jxl::PaddedBytes codestream = std::move(writer).TakeBytes(); + jxl::Bytes(codestream).AppendTo(&container); + VerifyJPEGReconstruction(jxl::Bytes(container), jxl::Bytes(orig)); +} + +TEST(DecodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGReconstructionMetadataTest)) { + const std::string jpeg_path = "jxl/jpeg_reconstruction/1x1_exif_xmp.jpg"; + const std::string jxl_path = "jxl/jpeg_reconstruction/1x1_exif_xmp.jxl"; + const std::vector<uint8_t> jpeg = jxl::test::ReadTestData(jpeg_path); + const std::vector<uint8_t> jxl = jxl::test::ReadTestData(jxl_path); + VerifyJPEGReconstruction(jxl::Bytes(jxl), jxl::Bytes(jpeg)); +} + +TEST(DecodeTest, ContinueFinalNonEssentialBoxTest) { + size_t xsize = 80, ysize = 90; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + jxl::TestCodestreamParams params; + params.box_format = kCSBF_Multi_Other_Terminated; + params.add_icc_profile = true; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 4, params); + StreamPositions streampos; + AnalyzeCodestream(compressed, &streampos); + + // The non-essential final box size including 8-byte header + size_t final_box_size = unk3_box_size + 8; + size_t last_box_begin = compressed.size() - final_box_size; + // Verify that the test is indeed setup correctly to be at the beginning of + // the 'unkn' box header. + ASSERT_EQ(compressed[last_box_begin + 3], final_box_size); + ASSERT_EQ(compressed[last_box_begin + 4], 'u'); + ASSERT_EQ(compressed[last_box_begin + 5], 'n'); + ASSERT_EQ(compressed[last_box_begin + 6], 'k'); + ASSERT_EQ(compressed[last_box_begin + 7], '3'); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | JXL_DEC_FRAME)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), last_box_begin)); + + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_FRAME, JxlDecoderProcessInput(dec)); + // The decoder returns success despite not having seen the final unknown box + // yet. This is because calling JxlDecoderCloseInput is not mandatory for + // backwards compatibility, so it doesn't know more bytes follow, the current + // bytes ended at a perfectly valid place. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + size_t remaining = JxlDecoderReleaseInput(dec); + // Since the test was set up to end exactly at the boundary of the final + // codestream box, and the decoder returned success, all bytes are expected to + // be consumed until the end of the frame header. + EXPECT_EQ(remaining, last_box_begin - streampos.frames[0].toc_end); + + // Now set the remaining non-codestream box as input. + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data() + last_box_begin, + compressed.size() - last_box_begin)); + // Even though JxlDecoderProcessInput already returned JXL_DEC_SUCCESS before, + // when calling it again now after setting more input, success is expected, no + // event occurs but the box has been successfully skipped. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +namespace { +bool BoxTypeEquals(const std::string& type_string, JxlBoxType type) { + return type_string.size() == 4 && type_string[0] == type[0] && + type_string[1] == type[1] && type_string[2] == type[2] && + type_string[3] == type[3]; +} +} // namespace + +TEST(DecodeTest, ExtentedBoxSizeTest) { + const std::string jxl_path = "jxl/boxes/square-extended-size-container.jxl"; + const std::vector<uint8_t> orig = jxl::test::ReadTestData(jxl_path); + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, JXL_DEC_BOX)); + + JxlBoxType type; + uint64_t box_size; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, orig.data(), orig.size())); + EXPECT_EQ(JXL_DEC_BOX, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_TRUE(BoxTypeEquals("JXL ", type)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + EXPECT_EQ(12, box_size); + EXPECT_EQ(JXL_DEC_BOX, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_TRUE(BoxTypeEquals("ftyp", type)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + EXPECT_EQ(20, box_size); + EXPECT_EQ(JXL_DEC_BOX, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_TRUE(BoxTypeEquals("jxlc", type)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + EXPECT_EQ(72, box_size); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, JXL_BOXES_TEST(BoxTest)) { + size_t xsize = 1, ysize = 1; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + jxl::TestCodestreamParams params; + params.box_format = kCSBF_Multi_Other_Terminated; + params.add_icc_profile = true; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 4, params); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, JXL_DEC_BOX)); + + std::vector<std::string> expected_box_types = { + "JXL ", "ftyp", "jxlp", "unk1", "unk2", "jxlp", "jxlp", "jxlp", "unk3"}; + + // Value 0 means to not test the size: codestream is not required to be a + // particular exact size. + std::vector<size_t> expected_box_sizes = {12, 20, 0, 34, 18, 0, 0, 0, 20}; + + JxlBoxType type; + uint64_t box_size; + std::vector<uint8_t> contents(50); + size_t expected_release_size = 0; + + // Cannot get these when decoding didn't start yet + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + + uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + for (size_t i = 0; i < expected_box_types.size(); i++) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + EXPECT_EQ(JXL_DEC_BOX, JxlDecoderProcessInput(dec)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxSizeRaw(dec, &box_size)); + EXPECT_TRUE(BoxTypeEquals(expected_box_types[i], type)); + if (expected_box_sizes[i]) { + EXPECT_EQ(expected_box_sizes[i], box_size); + } + + if (expected_release_size > 0) { + EXPECT_EQ(expected_release_size, JxlDecoderReleaseBoxBuffer(dec)); + expected_release_size = 0; + } + + if (type[0] == 'u' && type[1] == 'n' && type[2] == 'k') { + JxlDecoderSetBoxBuffer(dec, contents.data(), contents.size()); + size_t expected_box_contents_size = + type[3] == '1' ? unk1_box_size + : (type[3] == '2' ? unk2_box_size : unk3_box_size); + expected_release_size = contents.size() - expected_box_contents_size; + } + size_t consumed = avail_in - JxlDecoderReleaseInput(dec); + next_in += consumed; + avail_in -= consumed; + } + + // After the last DEC_BOX event, check that the input position is exactly at + // the stat of the box header. + EXPECT_EQ(avail_in, expected_box_sizes.back()); + + // Even though all input is given, the decoder cannot assume there aren't + // more boxes if the input was not closed. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec, next_in, avail_in)); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec)); + JxlDecoderCloseInput(dec); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); +} + +TEST(DecodeTest, JXL_BOXES_TEST(ExifBrobBoxTest)) { + size_t xsize = 1, ysize = 1; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + jxl::TestCodestreamParams params; + // Lossless to verify pixels exactly after roundtrip. + params.cparams.SetLossless(); + params.box_format = kCSBF_Brob_Exif; + params.add_icc_profile = true; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 4, params); + + // Test raw brob box, not brotli-decompressing + for (int streaming = 0; streaming < 2; ++streaming) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, JXL_DEC_BOX)); + if (!streaming) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + JxlDecoderCloseInput(dec); + } + // for streaming input case + const uint8_t* next_in = compressed.data(); + size_t avail_in = 0; + size_t total_in = 0; + size_t step_size = 64; + + std::vector<uint8_t> box_buffer; + size_t box_num_output; + bool seen_brob_begin = false; + bool seen_brob_end = false; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (streaming) { + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + size_t amount = step_size; + if (total_in + amount > compressed.size()) { + amount = compressed.size() - total_in; + } + avail_in += amount; + total_in += amount; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, next_in, avail_in)); + if (total_in == compressed.size()) JxlDecoderCloseInput(dec); + } else { + FAIL(); + break; + } + } else if (status == JXL_DEC_BOX || status == JXL_DEC_SUCCESS) { + if (!box_buffer.empty()) { + EXPECT_EQ(false, seen_brob_end); + seen_brob_end = true; + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + EXPECT_EQ(box_num_output, box_brob_exif_size - 8); + EXPECT_EQ( + 0, memcmp(box_buffer.data(), box_brob_exif + 8, box_num_output)); + box_buffer.clear(); + } + if (status == JXL_DEC_SUCCESS) break; + JxlBoxType type; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + if (BoxTypeEquals("brob", type)) { + EXPECT_EQ(false, seen_brob_begin); + seen_brob_begin = true; + box_buffer.resize(8); + JxlDecoderSetBoxBuffer(dec, box_buffer.data(), box_buffer.size()); + } + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + box_buffer.resize(box_buffer.size() * 2); + JxlDecoderSetBoxBuffer(dec, box_buffer.data() + box_num_output, + box_buffer.size() - box_num_output); + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + EXPECT_EQ(true, seen_brob_begin); + EXPECT_EQ(true, seen_brob_end); + + JxlDecoderDestroy(dec); + } + + // Test decompressed brob box + for (int streaming = 0; streaming < 2; ++streaming) { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSubscribeEvents(dec, JXL_DEC_BOX)); + if (!streaming) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + JxlDecoderCloseInput(dec); + } + // for streaming input case + const uint8_t* next_in = compressed.data(); + size_t avail_in = 0; + size_t total_in = 0; + size_t step_size = 64; + + std::vector<uint8_t> box_buffer; + size_t box_num_output; + bool seen_exif_begin = false; + bool seen_exif_end = false; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetDecompressBoxes(dec, JXL_TRUE)); + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + if (status == JXL_DEC_NEED_MORE_INPUT) { + if (streaming) { + size_t remaining = JxlDecoderReleaseInput(dec); + EXPECT_LE(remaining, avail_in); + next_in += avail_in - remaining; + avail_in = remaining; + size_t amount = step_size; + if (total_in + amount > compressed.size()) { + amount = compressed.size() - total_in; + } + avail_in += amount; + total_in += amount; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, next_in, avail_in)); + if (total_in == compressed.size()) JxlDecoderCloseInput(dec); + } else { + FAIL(); + break; + } + } else if (status == JXL_DEC_BOX || status == JXL_DEC_SUCCESS) { + if (!box_buffer.empty()) { + EXPECT_EQ(false, seen_exif_end); + seen_exif_end = true; + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + // Expect that the output has the same size and contents as the + // uncompressed exif data. Only check contents if the sizes match to + // avoid comparing uninitialized memory in the test. + EXPECT_EQ(box_num_output, exif_uncompressed_size); + if (box_num_output == exif_uncompressed_size) { + EXPECT_EQ(0, memcmp(box_buffer.data(), exif_uncompressed, + exif_uncompressed_size)); + } + box_buffer.clear(); + } + if (status == JXL_DEC_SUCCESS) break; + JxlBoxType type; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_TRUE)); + if (BoxTypeEquals("Exif", type)) { + EXPECT_EQ(false, seen_exif_begin); + seen_exif_begin = true; + box_buffer.resize(8); + JxlDecoderSetBoxBuffer(dec, box_buffer.data(), box_buffer.size()); + } + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + box_buffer.resize(box_buffer.size() * 2); + JxlDecoderSetBoxBuffer(dec, box_buffer.data() + box_num_output, + box_buffer.size() - box_num_output); + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + EXPECT_EQ(true, seen_exif_begin); + EXPECT_EQ(true, seen_exif_end); + + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, JXL_BOXES_TEST(PartialCodestreamBoxTest)) { + size_t xsize = 23, ysize = 81; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlPixelFormat format_orig = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + // Lossless to verify pixels exactly after roundtrip. + jxl::TestCodestreamParams params; + params.cparams.SetLossless(); + params.cparams.speed_tier = jxl::SpeedTier::kThunder; + params.box_format = kCSBF_Multi; + params.add_icc_profile = true; + std::vector<uint8_t> compressed = jxl::CreateTestJXLCodestream( + jxl::Bytes(pixels.data(), pixels.size()), xsize, ysize, 4, params); + + std::vector<uint8_t> extracted_codestream; + + { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE | JXL_DEC_BOX)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + JxlDecoderCloseInput(dec); + + size_t num_jxlp = 0; + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + std::vector<uint8_t> box_buffer; + size_t box_num_output; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + if (status == JXL_DEC_NEED_MORE_INPUT) { + FAIL(); + break; + } else if (status == JXL_DEC_BASIC_INFO) { + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format_orig, pixels2.data(), + pixels2.size())); + } else if (status == JXL_DEC_FULL_IMAGE) { + continue; + } else if (status == JXL_DEC_BOX || status == JXL_DEC_SUCCESS) { + if (!box_buffer.empty()) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + EXPECT_GE(box_num_output, 4); + // Do not insert the first 4 bytes, which are not part of the + // codestream, but the partial codestream box index + extracted_codestream.insert(extracted_codestream.end(), + box_buffer.begin() + 4, + box_buffer.begin() + box_num_output); + box_buffer.clear(); + } + if (status == JXL_DEC_SUCCESS) break; + JxlBoxType type; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec, type, JXL_FALSE)); + if (BoxTypeEquals("jxlp", type)) { + num_jxlp++; + box_buffer.resize(8); + JxlDecoderSetBoxBuffer(dec, box_buffer.data(), box_buffer.size()); + } + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + box_buffer.resize(box_buffer.size() * 2); + JxlDecoderSetBoxBuffer(dec, box_buffer.data() + box_num_output, + box_buffer.size() - box_num_output); + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + // The test file created with kCSBF_Multi is expected to have 4 jxlp boxes. + EXPECT_EQ(4, num_jxlp); + + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format_orig)); + + JxlDecoderDestroy(dec); + } + + // Now test whether the codestream extracted from the jxlp boxes can itself + // also be decoded and gives the same pixels + { + JxlDecoder* dec = JxlDecoderCreate(nullptr); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE | JXL_DEC_BOX)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, extracted_codestream.data(), + extracted_codestream.size())); + JxlDecoderCloseInput(dec); + + size_t num_boxes = 0; + + std::vector<uint8_t> pixels2; + pixels2.resize(pixels.size()); + + std::vector<uint8_t> box_buffer; + size_t box_num_output; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec); + if (status == JXL_DEC_NEED_MORE_INPUT) { + FAIL(); + break; + } else if (status == JXL_DEC_BASIC_INFO) { + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(info.xsize, xsize); + EXPECT_EQ(info.ysize, ysize); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format_orig, pixels2.data(), + pixels2.size())); + } else if (status == JXL_DEC_FULL_IMAGE) { + continue; + } else if (status == JXL_DEC_BOX) { + num_boxes++; + } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { + size_t remaining = JxlDecoderReleaseBoxBuffer(dec); + box_num_output = box_buffer.size() - remaining; + box_buffer.resize(box_buffer.size() * 2); + JxlDecoderSetBoxBuffer(dec, box_buffer.data() + box_num_output, + box_buffer.size() - box_num_output); + } else if (status == JXL_DEC_SUCCESS) { + break; + } else { + // We do not expect any other events or errors + FAIL(); + break; + } + } + + EXPECT_EQ(0, num_boxes); // The data does not use the container format. + EXPECT_EQ(0u, jxl::test::ComparePixels(pixels.data(), pixels2.data(), xsize, + ysize, format_orig, format_orig)); + + JxlDecoderDestroy(dec); + } +} + +TEST(DecodeTest, SpotColorTest) { + jxl::CodecInOut io; + size_t xsize = 55, ysize = 257; + io.metadata.m.color_encoding = jxl::ColorEncoding::LinearSRGB(); + jxl::Image3F main(xsize, ysize); + jxl::ImageF spot(xsize, ysize); + jxl::ZeroFillImage(&main); + jxl::ZeroFillImage(&spot); + + for (size_t y = 0; y < ysize; y++) { + float* JXL_RESTRICT rowm = main.PlaneRow(1, y); + float* JXL_RESTRICT rows = spot.Row(y); + for (size_t x = 0; x < xsize; x++) { + rowm[x] = (x + y) * (1.f / 255.f); + rows[x] = ((x ^ y) & 255) * (1.f / 255.f); + } + } + io.SetFromImage(std::move(main), jxl::ColorEncoding::LinearSRGB()); + jxl::ExtraChannelInfo info; + info.bit_depth.bits_per_sample = 8; + info.dim_shift = 0; + info.type = jxl::ExtraChannel::kSpotColor; + info.spot_color[0] = 0.5f; + info.spot_color[1] = 0.2f; + info.spot_color[2] = 1.f; + info.spot_color[3] = 0.5f; + + io.metadata.m.extra_channel_info.push_back(info); + std::vector<jxl::ImageF> ec; + ec.push_back(std::move(spot)); + io.frames[0].SetExtraChannels(std::move(ec)); + + jxl::CompressParams cparams; + cparams.speed_tier = jxl::SpeedTier::kLightning; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.butteraugli_distance = 0.f; + + std::vector<uint8_t> compressed; + EXPECT_TRUE(jxl::test::EncodeFile(cparams, &io, &compressed)); + + for (size_t render_spot = 0; render_spot < 2; render_spot++) { + JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + JxlDecoder* dec = JxlDecoderCreate(NULL); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec, JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE)); + if (!render_spot) { + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetRenderSpotcolors(dec, JXL_FALSE)); + } + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetInput(dec, compressed.data(), compressed.size())); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + JxlBasicInfo binfo; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &binfo)); + EXPECT_EQ(1u, binfo.num_extra_channels); + EXPECT_EQ(xsize, binfo.xsize); + EXPECT_EQ(ysize, binfo.ysize); + + JxlExtraChannelInfo extra_info; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetExtraChannelInfo(dec, 0, &extra_info)); + EXPECT_EQ((unsigned int)jxl::ExtraChannel::kSpotColor, extra_info.type); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + size_t extra_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderExtraChannelBufferSize(dec, &format, &extra_size, 0)); + + std::vector<uint8_t> image(buffer_size); + std::vector<uint8_t> extra(extra_size); + size_t bytes_per_pixel = format.num_channels * + jxl::test::GetDataBits(format.data_type) / + jxl::kBitsPerByte; + size_t stride = bytes_per_pixel * binfo.xsize; + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer( + dec, &format, image.data(), image.size())); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetExtraChannelBuffer(dec, &format, extra.data(), + extra.size(), 0)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + + // After the full image was output, JxlDecoderProcessInput should return + // success to indicate all is done. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + JxlDecoderDestroy(dec); + + for (size_t y = 0; y < ysize; y++) { + uint8_t* JXL_RESTRICT rowm = image.data() + stride * y; + uint8_t* JXL_RESTRICT rows = extra.data() + xsize * y; + for (size_t x = 0; x < xsize; x++) { + if (!render_spot) { + // if spot color isn't rendered, main image should be as we made it + // (red and blue are all zeroes) + + EXPECT_EQ(rowm[x * 3 + 0], 0); + EXPECT_EQ(rowm[x * 3 + 1], (x + y > 255 ? 255 : x + y)); + EXPECT_EQ(rowm[x * 3 + 2], 0); + } + if (render_spot) { + // if spot color is rendered, expect red and blue to look like the + // spot color channel + EXPECT_LT(abs(rowm[x * 3 + 0] - (rows[x] * 0.25f)), 1); + EXPECT_LT(abs(rowm[x * 3 + 2] - (rows[x] * 0.5f)), 1); + } + EXPECT_EQ(rows[x], ((x ^ y) & 255)); + } + } + } +} + +TEST(DecodeTest, CloseInput) { + std::vector<uint8_t> partial_file = {0xff}; + + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec.get(), + JXL_DEC_BASIC_INFO | JXL_DEC_FULL_IMAGE)); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetInput(dec.get(), partial_file.data(), + partial_file.size())); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec.get())); + EXPECT_EQ(JXL_DEC_NEED_MORE_INPUT, JxlDecoderProcessInput(dec.get())); + JxlDecoderCloseInput(dec.get()); + EXPECT_EQ(JXL_DEC_ERROR, JxlDecoderProcessInput(dec.get())); +} diff --git a/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.cc b/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.cc new file mode 100644 index 0000000000..36d19fe793 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.cc @@ -0,0 +1,182 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/decode_to_jpeg.h" + +#include <jxl/decode.h> + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" // JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/jpeg/dec_jpeg_data.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + +JxlDecoderStatus JxlToJpegDecoder::Process(const uint8_t** next_in, + size_t* avail_in) { + if (!inside_box_) { + JXL_UNREACHABLE( + "processing of JPEG reconstruction data outside JPEG reconstruction " + "box"); + } + Span<const uint8_t> to_decode; + if (box_until_eof_) { + // Until EOF means consume all data. + to_decode = Bytes(*next_in, *avail_in); + *next_in += *avail_in; + *avail_in = 0; + } else { + // Defined size means consume min(available, needed). + size_t avail_recon_in = + std::min<size_t>(*avail_in, box_size_ - buffer_.size()); + to_decode = Bytes(*next_in, avail_recon_in); + *next_in += avail_recon_in; + *avail_in -= avail_recon_in; + } + bool old_data_exists = !buffer_.empty(); + if (old_data_exists) { + // Append incoming data to buffer if we already had data in the buffer. + buffer_.insert(buffer_.end(), to_decode.data(), + to_decode.data() + to_decode.size()); + to_decode = Bytes(buffer_.data(), buffer_.size()); + } + if (!box_until_eof_ && to_decode.size() > box_size_) { + JXL_UNREACHABLE("JPEG reconstruction data to decode larger than expected"); + } + if (box_until_eof_ || to_decode.size() == box_size_) { + // If undefined size, or the right size, try to decode. + jpeg_data_ = make_unique<jpeg::JPEGData>(); + const auto status = jpeg::DecodeJPEGData(to_decode, jpeg_data_.get()); + if (status.IsFatalError()) return JXL_DEC_ERROR; + if (status) { + // Successful decoding, emit event after updating state to track that we + // are no longer parsing JPEG reconstruction data. + inside_box_ = false; + return JXL_DEC_JPEG_RECONSTRUCTION; + } + if (box_until_eof_) { + // Unsuccessful decoding and undefined size, assume incomplete data. Copy + // the data if we haven't already. + if (!old_data_exists) { + buffer_.insert(buffer_.end(), to_decode.data(), + to_decode.data() + to_decode.size()); + } + } else { + // Unsuccessful decoding of correct amount of data, assume error. + return JXL_DEC_ERROR; + } + } else { + // Not enough data, copy the data if we haven't already. + if (!old_data_exists) { + buffer_.insert(buffer_.end(), to_decode.data(), + to_decode.data() + to_decode.size()); + } + } + return JXL_DEC_NEED_MORE_INPUT; +} + +size_t JxlToJpegDecoder::NumExifMarkers(const jpeg::JPEGData& jpeg_data) { + size_t num = 0; + for (size_t i = 0; i < jpeg_data.app_data.size(); ++i) { + if (jpeg_data.app_marker_type[i] == jxl::jpeg::AppMarkerType::kExif) { + num++; + } + } + return num; +} + +size_t JxlToJpegDecoder::NumXmpMarkers(const jpeg::JPEGData& jpeg_data) { + size_t num = 0; + for (size_t i = 0; i < jpeg_data.app_data.size(); ++i) { + if (jpeg_data.app_marker_type[i] == jxl::jpeg::AppMarkerType::kXMP) { + num++; + } + } + return num; +} + +JxlDecoderStatus JxlToJpegDecoder::ExifBoxContentSize( + const jpeg::JPEGData& jpeg_data, size_t* size) { + for (size_t i = 0; i < jpeg_data.app_data.size(); ++i) { + if (jpeg_data.app_marker_type[i] == jxl::jpeg::AppMarkerType::kExif) { + if (jpeg_data.app_data[i].size() < 3 + sizeof(jpeg::kExifTag)) { + // too small for app marker header + return JXL_DEC_ERROR; + } + // The first 4 bytes are the TIFF header from the box contents, and are + // not included in the JPEG + *size = jpeg_data.app_data[i].size() + 4 - 3 - sizeof(jpeg::kExifTag); + return JXL_DEC_SUCCESS; + } + } + return JXL_DEC_ERROR; +} + +JxlDecoderStatus JxlToJpegDecoder::XmlBoxContentSize( + const jpeg::JPEGData& jpeg_data, size_t* size) { + for (size_t i = 0; i < jpeg_data.app_data.size(); ++i) { + if (jpeg_data.app_marker_type[i] == jxl::jpeg::AppMarkerType::kXMP) { + if (jpeg_data.app_data[i].size() < 3 + sizeof(jpeg::kXMPTag)) { + // too small for app marker header + return JXL_DEC_ERROR; + } + *size = jpeg_data.app_data[i].size() - 3 - sizeof(jpeg::kXMPTag); + return JXL_DEC_SUCCESS; + } + } + return JXL_DEC_ERROR; +} + +JxlDecoderStatus JxlToJpegDecoder::SetExif(const uint8_t* data, size_t size, + jpeg::JPEGData* jpeg_data) { + for (size_t i = 0; i < jpeg_data->app_data.size(); ++i) { + if (jpeg_data->app_marker_type[i] == jxl::jpeg::AppMarkerType::kExif) { + if (jpeg_data->app_data[i].size() != + size + 3 + sizeof(jpeg::kExifTag) - 4) + return JXL_DEC_ERROR; + // The first 9 bytes are used for JPEG marker header. + jpeg_data->app_data[i][0] = 0xE1; + // The second and third byte are already filled in correctly + memcpy(jpeg_data->app_data[i].data() + 3, jpeg::kExifTag, + sizeof(jpeg::kExifTag)); + // The first 4 bytes are the TIFF header from the box contents, and are + // not included in the JPEG + memcpy(jpeg_data->app_data[i].data() + 3 + sizeof(jpeg::kExifTag), + data + 4, size - 4); + return JXL_DEC_SUCCESS; + } + } + return JXL_DEC_ERROR; +} +JxlDecoderStatus JxlToJpegDecoder::SetXmp(const uint8_t* data, size_t size, + jpeg::JPEGData* jpeg_data) { + for (size_t i = 0; i < jpeg_data->app_data.size(); ++i) { + if (jpeg_data->app_marker_type[i] == jxl::jpeg::AppMarkerType::kXMP) { + if (jpeg_data->app_data[i].size() != size + 3 + sizeof(jpeg::kXMPTag)) + return JXL_DEC_ERROR; + // The first 9 bytes are used for JPEG marker header. + jpeg_data->app_data[i][0] = 0xE1; + // The second and third byte are already filled in correctly + memcpy(jpeg_data->app_data[i].data() + 3, jpeg::kXMPTag, + sizeof(jpeg::kXMPTag)); + memcpy(jpeg_data->app_data[i].data() + 3 + sizeof(jpeg::kXMPTag), data, + size); + return JXL_DEC_SUCCESS; + } + } + return JXL_DEC_ERROR; +} + +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.h b/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.h new file mode 100644 index 0000000000..8dd32d63b4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/decode_to_jpeg.h @@ -0,0 +1,220 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_DECODE_TO_JPEG_H_ +#define LIB_JXL_DECODE_TO_JPEG_H_ + +// JPEG XL to JPEG bytes decoder logic. The JxlToJpegDecoder class keeps track +// of the decoder state needed to parse the JPEG reconstruction box and provide +// the reconstructed JPEG to the output buffer. + +#include <jxl/decode.h> +#include <stdint.h> +#include <stdlib.h> + +#include <algorithm> +#include <cstring> +#include <memory> +#include <utility> +#include <vector> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/jpeg/jpeg_data.h" +#if JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +namespace jxl { + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + +class JxlToJpegDecoder { + public: + // Returns whether an output buffer is set. + bool IsOutputSet() const { return next_out_ != nullptr; } + + // Returns whether the decoder is parsing a boxa JPEG box was parsed. + bool IsParsingBox() const { return inside_box_; } + + // Sets the output buffer used when producing JPEG output. + JxlDecoderStatus SetOutputBuffer(uint8_t* data, size_t size) { + if (next_out_) return JXL_DEC_ERROR; + next_out_ = data; + avail_size_ = size; + return JXL_DEC_SUCCESS; + } + + // Releases the buffer set with SetOutputBuffer(). + size_t ReleaseOutputBuffer() { + size_t result = avail_size_; + next_out_ = nullptr; + avail_size_ = 0; + return result; + } + + void StartBox(bool box_until_eof, size_t contents_size) { + // A new box implies that we clear the buffer. + buffer_.clear(); + inside_box_ = true; + if (box_until_eof) { + box_until_eof_ = true; + } else { + box_size_ = contents_size; + } + } + + // Consumes data from next_in/avail_in to reconstruct JPEG data. + // Uses box_size_, inside_box_ and box_until_eof_ to calculate how much to + // consume. Potentially stores unparsed data in buffer_. + // Potentially populates jpeg_data_. Potentially updates inside_box_. + // Returns JXL_DEC_JPEG_RECONSTRUCTION when finished, JXL_DEC_NEED_MORE_INPUT + // if more input is needed, JXL_DEC_ERROR on parsing error. + JxlDecoderStatus Process(const uint8_t** next_in, size_t* avail_in); + + // Returns non-owned copy of the JPEGData, only after Process finished and + // the JPEGData was not yet moved to an image bundle with + // SetImageBundleJpegData. + jpeg::JPEGData* GetJpegData() { return jpeg_data_.get(); } + + // Returns how many exif or xmp app markers are present in the JPEG data. A + // return value higher than 1 would require multiple exif boxes or multiple + // xmp boxes in the container format, and this is not supported by the API and + // considered an error. May only be called after Process returned success. + static size_t NumExifMarkers(const jpeg::JPEGData& jpeg_data); + static size_t NumXmpMarkers(const jpeg::JPEGData& jpeg_data); + + // Returns box content size for metadata, using the known data from the app + // markers. + static JxlDecoderStatus ExifBoxContentSize(const jpeg::JPEGData& jpeg_data, + size_t* size); + static JxlDecoderStatus XmlBoxContentSize(const jpeg::JPEGData& jpeg_data, + size_t* size); + + // Returns JXL_DEC_ERROR if there is no exif/XMP marker or the data size + // does not match, or this function is called before Process returned + // success, JXL_DEC_SUCCESS otherwise. As input, provide the full box contents + // but not the box header. In case of exif, this includes the 4-byte TIFF + // header, even though it won't be copied into the JPEG. + static JxlDecoderStatus SetExif(const uint8_t* data, size_t size, + jpeg::JPEGData* jpeg_data); + static JxlDecoderStatus SetXmp(const uint8_t* data, size_t size, + jpeg::JPEGData* jpeg_data); + + // Sets the JpegData of the ImageBundle passed if there is anything to set. + // Releases the JpegData from this decoder if set. + Status SetImageBundleJpegData(ImageBundle* ib) { + if (IsOutputSet() && jpeg_data_ != nullptr) { + if (!jpeg::SetJPEGDataFromICC(ib->metadata()->color_encoding.ICC(), + jpeg_data_.get())) { + return false; + } + ib->jpeg_data = std::move(jpeg_data_); + } + return true; + } + + JxlDecoderStatus WriteOutput(const jpeg::JPEGData& jpeg_data) { + // Copy JPEG bytestream if desired. + uint8_t* tmp_next_out = next_out_; + size_t tmp_avail_size = avail_size_; + auto write = [&tmp_next_out, &tmp_avail_size](const uint8_t* buf, + size_t len) { + size_t to_write = std::min<size_t>(tmp_avail_size, len); + if (to_write != 0) memcpy(tmp_next_out, buf, to_write); + tmp_next_out += to_write; + tmp_avail_size -= to_write; + return to_write; + }; + Status write_result = jpeg::WriteJpeg(jpeg_data, write); + if (!write_result) { + if (tmp_avail_size == 0) { + return JXL_DEC_JPEG_NEED_MORE_OUTPUT; + } + return JXL_DEC_ERROR; + } + next_out_ = tmp_next_out; + avail_size_ = tmp_avail_size; + return JXL_DEC_SUCCESS; + } + + private: + // Content of the most recently parsed JPEG reconstruction box if any. + std::vector<uint8_t> buffer_; + + // Decoded content of the most recently parsed JPEG reconstruction box is + // stored here. + std::unique_ptr<jpeg::JPEGData> jpeg_data_; + + // True if the decoder is currently reading bytes inside a JPEG reconstruction + // box. + bool inside_box_ = false; + + // True if the JPEG reconstruction box had undefined size (all remaining + // bytes). + bool box_until_eof_ = false; + // Size of most recently parsed JPEG reconstruction box contents. + size_t box_size_ = 0; + + // Next bytes to write JPEG reconstruction to. + uint8_t* next_out_ = nullptr; + // Available bytes to write JPEG reconstruction to. + size_t avail_size_ = 0; +}; + +#else + +// Fake class that disables support for decoding JPEG XL to JPEG. +class JxlToJpegDecoder { + public: + bool IsOutputSet() const { return false; } + bool IsParsingBox() const { return false; } + + JxlDecoderStatus SetOutputBuffer(uint8_t* /* data */, size_t /* size */) { + return JXL_DEC_ERROR; + } + size_t ReleaseOutputBuffer() { return 0; } + + void StartBox(bool /* box_until_eof */, size_t /* contents_size */) {} + + JxlDecoderStatus Process(const uint8_t** next_in, size_t* avail_in) { + return JXL_DEC_ERROR; + } + jpeg::JPEGData* GetJpegData() { return nullptr; } + + Status SetImageBundleJpegData(ImageBundle* /* ib */) { return true; } + + static size_t NumExifMarkers(const jpeg::JPEGData& /*jpeg_data*/) { + return 0; + } + static size_t NumXmpMarkers(const jpeg::JPEGData& /*jpeg_data*/) { return 0; } + static size_t ExifBoxContentSize(const jpeg::JPEGData& /*jpeg_data*/, + size_t* /*size*/) { + return JXL_DEC_ERROR; + } + static size_t XmlBoxContentSize(const jpeg::JPEGData& /*jpeg_data*/, + size_t* /*size*/) { + return JXL_DEC_ERROR; + } + static JxlDecoderStatus SetExif(const uint8_t* /*data*/, size_t /*size*/, + jpeg::JPEGData* /*jpeg_data*/) { + return JXL_DEC_ERROR; + } + static JxlDecoderStatus SetXmp(const uint8_t* /*data*/, size_t /*size*/, + jpeg::JPEGData* /*jpeg_data*/) { + return JXL_DEC_ERROR; + } + + JxlDecoderStatus WriteOutput(const jpeg::JPEGData& /* jpeg_data */) { + return JXL_DEC_SUCCESS; + } +}; + +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +} // namespace jxl + +#endif // LIB_JXL_DECODE_TO_JPEG_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc new file mode 100644 index 0000000000..d3b5ad3269 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.cc @@ -0,0 +1,1150 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_ac_strategy.h" + +#include <stdint.h> +#include <string.h> + +#include <algorithm> +#include <cmath> +#include <cstdio> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_ac_strategy.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/fast_math-inl.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_transforms-inl.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_debug_image.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/simd_util.h" + +// Some of the floating point constants in this file and in other +// files in the libjxl project have been obtained using the +// tools/optimizer/simplex_fork.py tool. It is a variation of +// Nelder-Mead optimization, and we generally try to minimize +// BPP * pnorm aggregate as reported by the benchmark_xl tool, +// but occasionally the values are optimized by using additional +// constraints such as maintaining a certain density, or ratio of +// popularity of integral transforms. Jyrki visually reviews all +// such changes and often makes manual changes to maintain good +// visual quality to changes where butteraugli was not sufficiently +// sensitive to some kind of degradation. Unfortunately image quality +// is still more of an art than science. + +// Set JXL_DEBUG_AC_STRATEGY to 1 to enable debugging. +#ifndef JXL_DEBUG_AC_STRATEGY +#define JXL_DEBUG_AC_STRATEGY 0 +#endif + +// This must come before the begin/end_target, but HWY_ONCE is only true +// after that, so use an "include guard". +#ifndef LIB_JXL_ENC_AC_STRATEGY_ +#define LIB_JXL_ENC_AC_STRATEGY_ +// Parameters of the heuristic are marked with a OPTIMIZE comment. +namespace jxl { +namespace { + +// Debugging utilities. + +// Returns a linear sRGB color (as bytes) for each AC strategy. +const uint8_t* TypeColor(const uint8_t& raw_strategy) { + JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy)); + static_assert(AcStrategy::kNumValidStrategies == 27, "Change colors"); + static constexpr uint8_t kColors[][3] = { + {0xFF, 0xFF, 0x00}, // DCT8 + {0xFF, 0x80, 0x80}, // HORNUSS + {0xFF, 0x80, 0x80}, // DCT2x2 + {0xFF, 0x80, 0x80}, // DCT4x4 + {0x80, 0xFF, 0x00}, // DCT16x16 + {0x00, 0xC0, 0x00}, // DCT32x32 + {0xC0, 0xFF, 0x00}, // DCT16x8 + {0xC0, 0xFF, 0x00}, // DCT8x16 + {0x00, 0xFF, 0x00}, // DCT32x8 + {0x00, 0xFF, 0x00}, // DCT8x32 + {0x00, 0xFF, 0x00}, // DCT32x16 + {0x00, 0xFF, 0x00}, // DCT16x32 + {0xFF, 0x80, 0x00}, // DCT4x8 + {0xFF, 0x80, 0x00}, // DCT8x4 + {0xFF, 0xFF, 0x80}, // AFV0 + {0xFF, 0xFF, 0x80}, // AFV1 + {0xFF, 0xFF, 0x80}, // AFV2 + {0xFF, 0xFF, 0x80}, // AFV3 + {0x00, 0xC0, 0xFF}, // DCT64x64 + {0x00, 0xFF, 0xFF}, // DCT64x32 + {0x00, 0xFF, 0xFF}, // DCT32x64 + {0x00, 0x40, 0xFF}, // DCT128x128 + {0x00, 0x80, 0xFF}, // DCT128x64 + {0x00, 0x80, 0xFF}, // DCT64x128 + {0x00, 0x00, 0xC0}, // DCT256x256 + {0x00, 0x00, 0xFF}, // DCT256x128 + {0x00, 0x00, 0xFF}, // DCT128x256 + }; + return kColors[raw_strategy]; +} + +const uint8_t* TypeMask(const uint8_t& raw_strategy) { + JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy)); + static_assert(AcStrategy::kNumValidStrategies == 27, "Add masks"); + // implicitly, first row and column is made dark + static constexpr uint8_t kMask[][64] = { + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // DCT8 + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 1, 1, 1, 1, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 1, 0, 0, 1, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // HORNUSS + { + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 0, 1, 0, 1, 0, 1, 0, // + }, // 2x2 + { + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + }, // 4x4 + {}, // DCT16x16 (unused) + {}, // DCT32x32 (unused) + {}, // DCT16x8 (unused) + {}, // DCT8x16 (unused) + {}, // DCT32x8 (unused) + {}, // DCT8x32 (unused) + {}, // DCT32x16 (unused) + {}, // DCT16x32 (unused) + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // DCT4x8 + { + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + 0, 0, 0, 0, 1, 0, 0, 0, // + }, // DCT8x4 + { + 1, 1, 1, 1, 1, 0, 0, 0, // + 1, 1, 1, 1, 0, 0, 0, 0, // + 1, 1, 1, 0, 0, 0, 0, 0, // + 1, 1, 0, 0, 0, 0, 0, 0, // + 1, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // AFV0 + { + 0, 0, 0, 0, 1, 1, 1, 1, // + 0, 0, 0, 0, 0, 1, 1, 1, // + 0, 0, 0, 0, 0, 0, 1, 1, // + 0, 0, 0, 0, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + }, // AFV1 + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 1, 0, 0, 0, 0, 0, 0, 0, // + 1, 1, 0, 0, 0, 0, 0, 0, // + 1, 1, 1, 0, 0, 0, 0, 0, // + 1, 1, 1, 1, 0, 0, 0, 0, // + }, // AFV2 + { + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 0, // + 0, 0, 0, 0, 0, 0, 0, 1, // + 0, 0, 0, 0, 0, 0, 1, 1, // + 0, 0, 0, 0, 0, 1, 1, 1, // + }, // AFV3 + }; + return kMask[raw_strategy]; +} + +void DumpAcStrategy(const AcStrategyImage& ac_strategy, size_t xsize, + size_t ysize, const char* tag, AuxOut* aux_out, + const CompressParams& cparams) { + Image3F color_acs(xsize, ysize); + for (size_t y = 0; y < ysize; y++) { + float* JXL_RESTRICT rows[3] = { + color_acs.PlaneRow(0, y), + color_acs.PlaneRow(1, y), + color_acs.PlaneRow(2, y), + }; + const AcStrategyRow acs_row = ac_strategy.ConstRow(y / kBlockDim); + for (size_t x = 0; x < xsize; x++) { + AcStrategy acs = acs_row[x / kBlockDim]; + const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy()); + for (size_t c = 0; c < 3; c++) { + rows[c][x] = color[c] / 255.f; + } + } + } + size_t stride = color_acs.PixelsPerRow(); + for (size_t c = 0; c < 3; c++) { + for (size_t by = 0; by < DivCeil(ysize, kBlockDim); by++) { + float* JXL_RESTRICT row = color_acs.PlaneRow(c, by * kBlockDim); + const AcStrategyRow acs_row = ac_strategy.ConstRow(by); + for (size_t bx = 0; bx < DivCeil(xsize, kBlockDim); bx++) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy()); + const uint8_t* JXL_RESTRICT mask = TypeMask(acs.RawStrategy()); + if (acs.covered_blocks_x() == 1 && acs.covered_blocks_y() == 1) { + for (size_t iy = 0; iy < kBlockDim && by * kBlockDim + iy < ysize; + iy++) { + for (size_t ix = 0; ix < kBlockDim && bx * kBlockDim + ix < xsize; + ix++) { + if (mask[iy * kBlockDim + ix]) { + row[iy * stride + bx * kBlockDim + ix] = color[c] / 800.f; + } + } + } + } + // draw block edges + for (size_t ix = 0; ix < kBlockDim * acs.covered_blocks_x() && + bx * kBlockDim + ix < xsize; + ix++) { + row[0 * stride + bx * kBlockDim + ix] = color[c] / 350.f; + } + for (size_t iy = 0; iy < kBlockDim * acs.covered_blocks_y() && + by * kBlockDim + iy < ysize; + iy++) { + row[iy * stride + bx * kBlockDim + 0] = color[c] / 350.f; + } + } + } + } + DumpImage(cparams, tag, color_acs); +} + +} // namespace +} // namespace jxl +#endif // LIB_JXL_ENC_AC_STRATEGY_ + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::AbsDiff; +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::IfThenElseZero; +using hwy::HWY_NAMESPACE::IfThenZeroElse; +using hwy::HWY_NAMESPACE::Round; +using hwy::HWY_NAMESPACE::Sqrt; + +bool MultiBlockTransformCrossesHorizontalBoundary( + const AcStrategyImage& ac_strategy, size_t start_x, size_t y, + size_t end_x) { + if (start_x >= ac_strategy.xsize() || y >= ac_strategy.ysize()) { + return false; + } + if (y % 8 == 0) { + // Nothing crosses 64x64 boundaries, and the memory on the other side + // of the 64x64 block may still uninitialized. + return false; + } + end_x = std::min(end_x, ac_strategy.xsize()); + // The first multiblock might be before the start_x, let's adjust it + // to point to the first IsFirstBlock() == true block we find by backward + // tracing. + AcStrategyRow row = ac_strategy.ConstRow(y); + const size_t start_x_limit = start_x & ~7; + while (start_x != start_x_limit && !row[start_x].IsFirstBlock()) { + --start_x; + } + for (size_t x = start_x; x < end_x;) { + if (row[x].IsFirstBlock()) { + x += row[x].covered_blocks_x(); + } else { + return true; + } + } + return false; +} + +bool MultiBlockTransformCrossesVerticalBoundary( + const AcStrategyImage& ac_strategy, size_t x, size_t start_y, + size_t end_y) { + if (x >= ac_strategy.xsize() || start_y >= ac_strategy.ysize()) { + return false; + } + if (x % 8 == 0) { + // Nothing crosses 64x64 boundaries, and the memory on the other side + // of the 64x64 block may still uninitialized. + return false; + } + end_y = std::min(end_y, ac_strategy.ysize()); + // The first multiblock might be before the start_y, let's adjust it + // to point to the first IsFirstBlock() == true block we find by backward + // tracing. + const size_t start_y_limit = start_y & ~7; + while (start_y != start_y_limit && + !ac_strategy.ConstRow(start_y)[x].IsFirstBlock()) { + --start_y; + } + + for (size_t y = start_y; y < end_y;) { + AcStrategyRow row = ac_strategy.ConstRow(y); + if (row[x].IsFirstBlock()) { + y += row[x].covered_blocks_y(); + } else { + return true; + } + } + return false; +} + +float EstimateEntropy(const AcStrategy& acs, float entropy_mul, size_t x, + size_t y, const ACSConfig& config, + const float* JXL_RESTRICT cmap_factors, float* block, + float* scratch_space, uint32_t* quantized) { + const size_t size = (1 << acs.log2_covered_blocks()) * kDCTBlockSize; + + // Apply transform. + for (size_t c = 0; c < 3; c++) { + float* JXL_RESTRICT block_c = block + size * c; + TransformFromPixels(acs.Strategy(), &config.Pixel(c, x, y), + config.src_stride, block_c, scratch_space); + } + HWY_FULL(float) df; + + const size_t num_blocks = acs.covered_blocks_x() * acs.covered_blocks_y(); + // avoid large blocks when there is a lot going on in red-green. + float quant_norm16 = 0; + if (num_blocks == 1) { + // When it is only one 8x8, we don't need aggregation of values. + quant_norm16 = config.Quant(x / 8, y / 8); + } else if (num_blocks == 2) { + // Taking max instead of 8th norm seems to work + // better for smallest blocks up to 16x8. Jyrki couldn't get + // improvements in trying the same for 16x16 blocks. + if (acs.covered_blocks_y() == 2) { + quant_norm16 = + std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8, y / 8 + 1)); + } else { + quant_norm16 = + std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8 + 1, y / 8)); + } + } else { + // Load QF value, calculate empirical heuristic on masking field + // for weighting the information loss. Information loss manifests + // itself as ringing, and masking could hide it. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + float qval = config.Quant(x / 8 + ix, y / 8 + iy); + qval *= qval; + qval *= qval; + qval *= qval; + quant_norm16 += qval * qval; + } + } + quant_norm16 /= num_blocks; + quant_norm16 = FastPowf(quant_norm16, 1.0f / 16.0f); + } + const auto quant = Set(df, quant_norm16); + + // Compute entropy. + float entropy = 0.0f; + const HWY_CAPPED(float, 8) df8; + + auto mem_alloc = hwy::AllocateAligned<float>(AcStrategy::kMaxCoeffArea); + float* mem = mem_alloc.get(); + auto loss = Zero(df8); + for (size_t c = 0; c < 3; c++) { + const float* inv_matrix = config.dequant->InvMatrix(acs.RawStrategy(), c); + const float* matrix = config.dequant->Matrix(acs.RawStrategy(), c); + const auto cmap_factor = Set(df, cmap_factors[c]); + + auto entropy_v = Zero(df); + auto nzeros_v = Zero(df); + for (size_t i = 0; i < num_blocks * kDCTBlockSize; i += Lanes(df)) { + const auto in = Load(df, block + c * size + i); + const auto in_y = Mul(Load(df, block + size + i), cmap_factor); + const auto im = Load(df, inv_matrix + i); + const auto val = Mul(Sub(in, in_y), Mul(im, quant)); + const auto rval = Round(val); + const auto diff = Sub(val, rval); + const auto m = Load(df, matrix + i); + Store(Mul(m, diff), df, &mem[i]); + const auto q = Abs(rval); + const auto q_is_zero = Eq(q, Zero(df)); + // We used to have q * C here, but that cost model seems to + // be punishing large values more than necessary. Sqrt tries + // to avoid large values less aggressively. + entropy_v = Add(Sqrt(q), entropy_v); + nzeros_v = Add(nzeros_v, IfThenZeroElse(q_is_zero, Set(df, 1.0f))); + } + + { + auto lossc = Zero(df8); + TransformToPixels(acs.Strategy(), &mem[0], block, + acs.covered_blocks_x() * 8, scratch_space); + + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + for (size_t dy = 0; dy < kBlockDim; ++dy) { + for (size_t dx = 0; dx < kBlockDim; dx += Lanes(df8)) { + auto in = Load(df8, block + + (iy * kBlockDim + dy) * + (acs.covered_blocks_x() * kBlockDim) + + ix * kBlockDim + dx); + auto masku = Abs(Load( + df8, config.MaskingPtr1x1(x + ix * 8 + dx, y + iy * 8 + dy))); + in = Mul(masku, in); + in = Mul(in, in); + in = Mul(in, in); + in = Mul(in, in); + lossc = Add(lossc, in); + } + } + } + } + static const double kChannelMul[3] = { + 10.2, + 1.0, + 1.03, + }; + lossc = Mul(Set(df8, pow(kChannelMul[c], 8.0)), lossc); + loss = Add(loss, lossc); + } + entropy += config.cost_delta * GetLane(SumOfLanes(df, entropy_v)); + size_t num_nzeros = GetLane(SumOfLanes(df, nzeros_v)); + // Add #bit of num_nonzeros, as an estimate of the cost for encoding the + // number of non-zeros of the block. + size_t nbits = CeilLog2Nonzero(num_nzeros + 1) + 1; + // Also add #bit of #bit of num_nonzeros, to estimate the ANS cost, with a + // bias. + entropy += config.zeros_mul * (CeilLog2Nonzero(nbits + 17) + nbits); + } + float loss_scalar = + pow(GetLane(SumOfLanes(df8, loss)) / (num_blocks * kDCTBlockSize), + 1.0 / 8.0) * + (num_blocks * kDCTBlockSize) / quant_norm16; + float ret = entropy * entropy_mul; + ret += config.info_loss_multiplier * loss_scalar; + return ret; +} + +uint8_t FindBest8x8Transform(size_t x, size_t y, int encoding_speed_tier, + float butteraugli_target, const ACSConfig& config, + const float* JXL_RESTRICT cmap_factors, + AcStrategyImage* JXL_RESTRICT ac_strategy, + float* block, float* scratch_space, + uint32_t* quantized, float* entropy_out) { + struct TransformTry8x8 { + AcStrategy::Type type; + int encoding_speed_tier_max_limit; + double entropy_mul; + }; + static const TransformTry8x8 kTransforms8x8[] = { + { + AcStrategy::Type::DCT, + 9, + 0.8, + }, + { + AcStrategy::Type::DCT4X4, + 5, + 1.08, + }, + { + AcStrategy::Type::DCT2X2, + 5, + 0.95, + }, + { + AcStrategy::Type::DCT4X8, + 4, + 0.85931637428340035, + }, + { + AcStrategy::Type::DCT8X4, + 4, + 0.85931637428340035, + }, + { + AcStrategy::Type::IDENTITY, + 5, + 1.0427542510634957, + }, + { + AcStrategy::Type::AFV0, + 4, + 0.81779489591359944, + }, + { + AcStrategy::Type::AFV1, + 4, + 0.81779489591359944, + }, + { + AcStrategy::Type::AFV2, + 4, + 0.81779489591359944, + }, + { + AcStrategy::Type::AFV3, + 4, + 0.81779489591359944, + }, + }; + double best = 1e30; + uint8_t best_tx = kTransforms8x8[0].type; + for (auto tx : kTransforms8x8) { + if (tx.encoding_speed_tier_max_limit < encoding_speed_tier) { + continue; + } + AcStrategy acs = AcStrategy::FromRawStrategy(tx.type); + float entropy_mul = tx.entropy_mul / kTransforms8x8[0].entropy_mul; + if ((tx.type == AcStrategy::Type::DCT2X2 || + tx.type == AcStrategy::Type::IDENTITY) && + butteraugli_target < 5.0) { + static const float kFavor2X2AtHighQuality = 0.4; + float weight = pow((5.0f - butteraugli_target) / 5.0f, 2.0); + entropy_mul -= kFavor2X2AtHighQuality * weight; + } + if ((tx.type != AcStrategy::Type::DCT && + tx.type != AcStrategy::Type::DCT2X2 && + tx.type != AcStrategy::Type::IDENTITY) && + butteraugli_target > 4.0) { + static const float kAvoidEntropyOfTransforms = 0.5; + float mul = 1.0; + if (butteraugli_target < 12.0) { + mul *= (12.0 - 4.0) / (butteraugli_target - 4.0); + } + entropy_mul += kAvoidEntropyOfTransforms * mul; + } + float entropy = + EstimateEntropy(acs, entropy_mul, x, y, config, cmap_factors, block, + scratch_space, quantized); + if (entropy < best) { + best_tx = tx.type; + best = entropy; + } + } + *entropy_out = best; + return best_tx; +} + +// bx, by addresses the 64x64 block at 8x8 subresolution +// cx, cy addresses the left, upper 8x8 block position of the candidate +// transform. +void TryMergeAcs(AcStrategy::Type acs_raw, size_t bx, size_t by, size_t cx, + size_t cy, const ACSConfig& config, + const float* JXL_RESTRICT cmap_factors, + AcStrategyImage* JXL_RESTRICT ac_strategy, + const float entropy_mul, const uint8_t candidate_priority, + uint8_t* priority, float* JXL_RESTRICT entropy_estimate, + float* block, float* scratch_space, uint32_t* quantized) { + AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw); + float entropy_current = 0; + for (size_t iy = 0; iy < acs.covered_blocks_y(); ++iy) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ++ix) { + if (priority[(cy + iy) * 8 + (cx + ix)] >= candidate_priority) { + // Transform would reuse already allocated blocks and + // lead to invalid overlaps, for example DCT64X32 vs. + // DCT32X64. + return; + } + entropy_current += entropy_estimate[(cy + iy) * 8 + (cx + ix)]; + } + } + float entropy_candidate = + EstimateEntropy(acs, entropy_mul, (bx + cx) * 8, (by + cy) * 8, config, + cmap_factors, block, scratch_space, quantized); + if (entropy_candidate >= entropy_current) return; + // Accept the candidate. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + entropy_estimate[(cy + iy) * 8 + cx + ix] = 0; + priority[(cy + iy) * 8 + cx + ix] = candidate_priority; + } + } + ac_strategy->Set(bx + cx, by + cy, acs_raw); + entropy_estimate[cy * 8 + cx] = entropy_candidate; +} + +static void SetEntropyForTransform(size_t cx, size_t cy, + const AcStrategy::Type acs_raw, + float entropy, + float* JXL_RESTRICT entropy_estimate) { + const AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw); + for (size_t dy = 0; dy < acs.covered_blocks_y(); ++dy) { + for (size_t dx = 0; dx < acs.covered_blocks_x(); ++dx) { + entropy_estimate[(cy + dy) * 8 + cx + dx] = 0.0; + } + } + entropy_estimate[cy * 8 + cx] = entropy; +} + +AcStrategy::Type AcsSquare(size_t blocks) { + if (blocks == 2) { + return AcStrategy::Type::DCT16X16; + } else if (blocks == 4) { + return AcStrategy::Type::DCT32X32; + } else { + return AcStrategy::Type::DCT64X64; + } +} + +AcStrategy::Type AcsVerticalSplit(size_t blocks) { + if (blocks == 2) { + return AcStrategy::Type::DCT16X8; + } else if (blocks == 4) { + return AcStrategy::Type::DCT32X16; + } else { + return AcStrategy::Type::DCT64X32; + } +} + +AcStrategy::Type AcsHorizontalSplit(size_t blocks) { + if (blocks == 2) { + return AcStrategy::Type::DCT8X16; + } else if (blocks == 4) { + return AcStrategy::Type::DCT16X32; + } else { + return AcStrategy::Type::DCT32X64; + } +} + +// The following function tries to merge smaller transforms into +// squares and the rectangles originating from a single middle division +// (horizontal or vertical) fairly. +// +// This is now generalized to concern about squares +// of blocks X blocks size, where a block is 8x8 pixels. +void FindBestFirstLevelDivisionForSquare( + size_t blocks, bool allow_square_transform, size_t bx, size_t by, size_t cx, + size_t cy, const ACSConfig& config, const float* JXL_RESTRICT cmap_factors, + AcStrategyImage* JXL_RESTRICT ac_strategy, const float entropy_mul_JXK, + const float entropy_mul_JXJ, float* JXL_RESTRICT entropy_estimate, + float* block, float* scratch_space, uint32_t* quantized) { + // We denote J for the larger dimension here, and K for the smaller. + // For example, for 32x32 block splitting, J would be 32, K 16. + const size_t blocks_half = blocks / 2; + const AcStrategy::Type acs_rawJXK = AcsVerticalSplit(blocks); + const AcStrategy::Type acs_rawKXJ = AcsHorizontalSplit(blocks); + const AcStrategy::Type acs_rawJXJ = AcsSquare(blocks); + const AcStrategy acsJXK = AcStrategy::FromRawStrategy(acs_rawJXK); + const AcStrategy acsKXJ = AcStrategy::FromRawStrategy(acs_rawKXJ); + const AcStrategy acsJXJ = AcStrategy::FromRawStrategy(acs_rawJXJ); + AcStrategyRow row0 = ac_strategy->ConstRow(by + cy + 0); + AcStrategyRow row1 = ac_strategy->ConstRow(by + cy + blocks_half); + // Let's check if we can consider a JXJ block here at all. + // This is not necessary in the basic use of hierarchically merging + // blocks in the simplest possible way, but is needed when we try other + // 'floating' options of merging, possibly after a simple hierarchical + // merge has been explored. + if (MultiBlockTransformCrossesHorizontalBoundary(*ac_strategy, bx + cx, + by + cy, bx + cx + blocks) || + MultiBlockTransformCrossesHorizontalBoundary( + *ac_strategy, bx + cx, by + cy + blocks, bx + cx + blocks) || + MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx, by + cy, + by + cy + blocks) || + MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx + blocks, + by + cy, by + cy + blocks)) { + return; // not suitable for JxJ analysis, some transforms leak out. + } + // For floating transforms there may be + // already blocks selected that make either or both JXK and + // KXJ not feasible for this location. + const bool allow_JXK = !MultiBlockTransformCrossesVerticalBoundary( + *ac_strategy, bx + cx + blocks_half, by + cy, by + cy + blocks); + const bool allow_KXJ = !MultiBlockTransformCrossesHorizontalBoundary( + *ac_strategy, bx + cx, by + cy + blocks_half, bx + cx + blocks); + // Current entropies aggregated on NxN resolution. + float entropy[2][2] = {}; + for (size_t dy = 0; dy < blocks; ++dy) { + for (size_t dx = 0; dx < blocks; ++dx) { + entropy[dy / blocks_half][dx / blocks_half] += + entropy_estimate[(cy + dy) * 8 + (cx + dx)]; + } + } + float entropy_JXK_left = std::numeric_limits<float>::max(); + float entropy_JXK_right = std::numeric_limits<float>::max(); + float entropy_KXJ_top = std::numeric_limits<float>::max(); + float entropy_KXJ_bottom = std::numeric_limits<float>::max(); + float entropy_JXJ = std::numeric_limits<float>::max(); + if (allow_JXK) { + if (row0[bx + cx + 0].RawStrategy() != acs_rawJXK) { + entropy_JXK_left = EstimateEntropy( + acsJXK, entropy_mul_JXK, (bx + cx + 0) * 8, (by + cy + 0) * 8, config, + cmap_factors, block, scratch_space, quantized); + } + if (row0[bx + cx + blocks_half].RawStrategy() != acs_rawJXK) { + entropy_JXK_right = + EstimateEntropy(acsJXK, entropy_mul_JXK, (bx + cx + blocks_half) * 8, + (by + cy + 0) * 8, config, cmap_factors, block, + scratch_space, quantized); + } + } + if (allow_KXJ) { + if (row0[bx + cx].RawStrategy() != acs_rawKXJ) { + entropy_KXJ_top = EstimateEntropy( + acsKXJ, entropy_mul_JXK, (bx + cx + 0) * 8, (by + cy + 0) * 8, config, + cmap_factors, block, scratch_space, quantized); + } + if (row1[bx + cx].RawStrategy() != acs_rawKXJ) { + entropy_KXJ_bottom = + EstimateEntropy(acsKXJ, entropy_mul_JXK, (bx + cx + 0) * 8, + (by + cy + blocks_half) * 8, config, cmap_factors, + block, scratch_space, quantized); + } + } + if (allow_square_transform) { + // We control the exploration of the square transform separately so that + // we can turn it off at high decoding speeds for 32x32, but still allow + // exploring 16x32 and 32x16. + entropy_JXJ = EstimateEntropy(acsJXJ, entropy_mul_JXJ, (bx + cx + 0) * 8, + (by + cy + 0) * 8, config, cmap_factors, + block, scratch_space, quantized); + } + + // Test if this block should have JXK or KXJ transforms, + // because it can have only one or the other. + float costJxN = std::min(entropy_JXK_left, entropy[0][0] + entropy[1][0]) + + std::min(entropy_JXK_right, entropy[0][1] + entropy[1][1]); + float costNxJ = std::min(entropy_KXJ_top, entropy[0][0] + entropy[0][1]) + + std::min(entropy_KXJ_bottom, entropy[1][0] + entropy[1][1]); + if (entropy_JXJ < costJxN && entropy_JXJ < costNxJ) { + ac_strategy->Set(bx + cx, by + cy, acs_rawJXJ); + SetEntropyForTransform(cx, cy, acs_rawJXJ, entropy_JXJ, entropy_estimate); + } else if (costJxN < costNxJ) { + if (entropy_JXK_left < entropy[0][0] + entropy[1][0]) { + ac_strategy->Set(bx + cx, by + cy, acs_rawJXK); + SetEntropyForTransform(cx, cy, acs_rawJXK, entropy_JXK_left, + entropy_estimate); + } + if (entropy_JXK_right < entropy[0][1] + entropy[1][1]) { + ac_strategy->Set(bx + cx + blocks_half, by + cy, acs_rawJXK); + SetEntropyForTransform(cx + blocks_half, cy, acs_rawJXK, + entropy_JXK_right, entropy_estimate); + } + } else { + if (entropy_KXJ_top < entropy[0][0] + entropy[0][1]) { + ac_strategy->Set(bx + cx, by + cy, acs_rawKXJ); + SetEntropyForTransform(cx, cy, acs_rawKXJ, entropy_KXJ_top, + entropy_estimate); + } + if (entropy_KXJ_bottom < entropy[1][0] + entropy[1][1]) { + ac_strategy->Set(bx + cx, by + cy + blocks_half, acs_rawKXJ); + SetEntropyForTransform(cx, cy + blocks_half, acs_rawKXJ, + entropy_KXJ_bottom, entropy_estimate); + } + } +} + +void ProcessRectACS(const CompressParams& cparams, const ACSConfig& config, + const Rect& rect, const ColorCorrelationMap& cmap, + AcStrategyImage* ac_strategy) { + // Main philosophy here: + // 1. First find best 8x8 transform for each area. + // 2. Merging them into larger transforms where possibly, but + // starting from the smallest transforms (16x8 and 8x16). + // Additional complication: 16x8 and 8x16 are considered + // simultaneously and fairly against each other. + // We are looking at 64x64 squares since the YtoX and YtoB + // maps happen to be at that resolution, and having + // integral transforms cross these boundaries leads to + // additional complications. + const float butteraugli_target = cparams.butteraugli_distance; + const size_t dct_scratch_size = + 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; + // TODO(veluca): reuse allocations + auto mem = hwy::AllocateAligned<float>(5 * AcStrategy::kMaxCoeffArea + + dct_scratch_size); + auto qmem = hwy::AllocateAligned<uint32_t>(AcStrategy::kMaxCoeffArea); + uint32_t* JXL_RESTRICT quantized = qmem.get(); + float* JXL_RESTRICT block = mem.get(); + float* JXL_RESTRICT scratch_space = mem.get() + 3 * AcStrategy::kMaxCoeffArea; + size_t bx = rect.x0(); + size_t by = rect.y0(); + JXL_ASSERT(rect.xsize() <= 8); + JXL_ASSERT(rect.ysize() <= 8); + size_t tx = bx / kColorTileDimInBlocks; + size_t ty = by / kColorTileDimInBlocks; + const float cmap_factors[3] = { + cmap.YtoXRatio(cmap.ytox_map.ConstRow(ty)[tx]), + 0.0f, + cmap.YtoBRatio(cmap.ytob_map.ConstRow(ty)[tx]), + }; + if (cparams.speed_tier > SpeedTier::kHare) return; + // First compute the best 8x8 transform for each square. Later, we do not + // experiment with different combinations, but only use the best of the 8x8s + // when DCT8X8 is specified in the tree search. + // 8x8 transforms have 10 variants, but every larger transform is just a DCT. + float entropy_estimate[64] = {}; + // Favor all 8x8 transforms (against 16x8 and larger transforms)) at + // low butteraugli_target distances. + static const float k8x8mul1 = -0.4; + static const float k8x8mul2 = 1.0; + static const float k8x8base = 1.4; + const float mul8x8 = k8x8mul2 + k8x8mul1 / (butteraugli_target + k8x8base); + for (size_t iy = 0; iy < rect.ysize(); iy++) { + for (size_t ix = 0; ix < rect.xsize(); ix++) { + float entropy = 0.0; + const uint8_t best_of_8x8s = FindBest8x8Transform( + 8 * (bx + ix), 8 * (by + iy), static_cast<int>(cparams.speed_tier), + butteraugli_target, config, cmap_factors, ac_strategy, block, + scratch_space, quantized, &entropy); + ac_strategy->Set(bx + ix, by + iy, + static_cast<AcStrategy::Type>(best_of_8x8s)); + entropy_estimate[iy * 8 + ix] = entropy * mul8x8; + } + } + // Merge when a larger transform is better than the previously + // searched best combination of 8x8 transforms. + struct MergeTry { + AcStrategy::Type type; + uint8_t priority; + uint8_t decoding_speed_tier_max_limit; + uint8_t encoding_speed_tier_max_limit; + float entropy_mul; + }; + // These numbers need to be figured out manually and looking at + // ringing next to sky etc. Optimization will find larger numbers + // and produce more ringing than is ideal. Larger numbers will + // help stop ringing. + const float entropy_mul16X8 = 1.25; + const float entropy_mul16X16 = 1.35; + const float entropy_mul16X32 = 1.5; + const float entropy_mul32X32 = 1.5; + const float entropy_mul64X32 = 2.26; + const float entropy_mul64X64 = 2.26; + // TODO(jyrki): Consider this feedback in further changes: + // Also effectively when the multipliers for smaller blocks are + // below 1, this raises the bar for the bigger blocks even higher + // in that sense these constants are not independent (e.g. changing + // the constant for DCT16x32 by -5% (making it more likely) also + // means that DCT32x32 becomes harder to do when starting from + // two DCT16x32s). It might be better to make them more independent, + // e.g. by not applying the multiplier when storing the new entropy + // estimates in TryMergeToACSCandidate(). + const MergeTry kTransformsForMerge[9] = { + {AcStrategy::Type::DCT16X8, 2, 4, 5, entropy_mul16X8}, + {AcStrategy::Type::DCT8X16, 2, 4, 5, entropy_mul16X8}, + // FindBestFirstLevelDivisionForSquare looks for DCT16X16 and its + // subdivisions. {AcStrategy::Type::DCT16X16, 3, entropy_mul16X16}, + {AcStrategy::Type::DCT16X32, 4, 4, 4, entropy_mul16X32}, + {AcStrategy::Type::DCT32X16, 4, 4, 4, entropy_mul16X32}, + // FindBestFirstLevelDivisionForSquare looks for DCT32X32 and its + // subdivisions. {AcStrategy::Type::DCT32X32, 5, 1, 5, + // 0.9822994906548809f}, + {AcStrategy::Type::DCT64X32, 6, 1, 3, entropy_mul64X32}, + {AcStrategy::Type::DCT32X64, 6, 1, 3, entropy_mul64X32}, + // {AcStrategy::Type::DCT64X64, 8, 1, 3, 2.0846542128012948f}, + }; + /* + These sizes not yet included in merge heuristic: + set(AcStrategy::Type::DCT32X8, 0.0f, 2.261390410971102f); + set(AcStrategy::Type::DCT8X32, 0.0f, 2.261390410971102f); + set(AcStrategy::Type::DCT128X128, 0.0f, 1.0f); + set(AcStrategy::Type::DCT128X64, 0.0f, 0.73f); + set(AcStrategy::Type::DCT64X128, 0.0f, 0.73f); + set(AcStrategy::Type::DCT256X256, 0.0f, 1.0f); + set(AcStrategy::Type::DCT256X128, 0.0f, 0.73f); + set(AcStrategy::Type::DCT128X256, 0.0f, 0.73f); + */ + + // Priority is a tricky kludge to avoid collisions so that transforms + // don't overlap. + uint8_t priority[64] = {}; + bool enable_32x32 = cparams.decoding_speed_tier < 4; + for (auto tx : kTransformsForMerge) { + if (tx.decoding_speed_tier_max_limit < cparams.decoding_speed_tier) { + continue; + } + AcStrategy acs = AcStrategy::FromRawStrategy(tx.type); + + for (size_t cy = 0; cy + acs.covered_blocks_y() - 1 < rect.ysize(); + cy += acs.covered_blocks_y()) { + for (size_t cx = 0; cx + acs.covered_blocks_x() - 1 < rect.xsize(); + cx += acs.covered_blocks_x()) { + if (cy + 7 < rect.ysize() && cx + 7 < rect.xsize()) { + if (cparams.decoding_speed_tier < 4 && + tx.type == AcStrategy::Type::DCT32X64) { + // We handle both DCT8X16 and DCT16X8 at the same time. + if ((cy | cx) % 8 == 0) { + FindBestFirstLevelDivisionForSquare( + 8, true, bx, by, cx, cy, config, cmap_factors, ac_strategy, + tx.entropy_mul, entropy_mul64X64, entropy_estimate, block, + scratch_space, quantized); + } + continue; + } else if (tx.type == AcStrategy::Type::DCT32X16) { + // We handled both DCT8X16 and DCT16X8 at the same time, + // and that is above. The last column and last row, + // when the last column or last row is odd numbered, + // are still handled by TryMergeAcs. + continue; + } + } + if ((tx.type == AcStrategy::Type::DCT16X32 && cy % 4 != 0) || + (tx.type == AcStrategy::Type::DCT32X16 && cx % 4 != 0)) { + // already covered by FindBest32X32 + continue; + } + + if (cy + 3 < rect.ysize() && cx + 3 < rect.xsize()) { + if (tx.type == AcStrategy::Type::DCT16X32) { + // We handle both DCT8X16 and DCT16X8 at the same time. + if ((cy | cx) % 4 == 0) { + FindBestFirstLevelDivisionForSquare( + 4, enable_32x32, bx, by, cx, cy, config, cmap_factors, + ac_strategy, tx.entropy_mul, entropy_mul32X32, + entropy_estimate, block, scratch_space, quantized); + } + continue; + } else if (tx.type == AcStrategy::Type::DCT32X16) { + // We handled both DCT8X16 and DCT16X8 at the same time, + // and that is above. The last column and last row, + // when the last column or last row is odd numbered, + // are still handled by TryMergeAcs. + continue; + } + } + if ((tx.type == AcStrategy::Type::DCT16X32 && cy % 4 != 0) || + (tx.type == AcStrategy::Type::DCT32X16 && cx % 4 != 0)) { + // already covered by FindBest32X32 + continue; + } + if (cy + 1 < rect.ysize() && cx + 1 < rect.xsize()) { + if (tx.type == AcStrategy::Type::DCT8X16) { + // We handle both DCT8X16 and DCT16X8 at the same time. + if ((cy | cx) % 2 == 0) { + FindBestFirstLevelDivisionForSquare( + 2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy, + tx.entropy_mul, entropy_mul16X16, entropy_estimate, block, + scratch_space, quantized); + } + continue; + } else if (tx.type == AcStrategy::Type::DCT16X8) { + // We handled both DCT8X16 and DCT16X8 at the same time, + // and that is above. The last column and last row, + // when the last column or last row is odd numbered, + // are still handled by TryMergeAcs. + continue; + } + } + if ((tx.type == AcStrategy::Type::DCT8X16 && cy % 2 == 1) || + (tx.type == AcStrategy::Type::DCT16X8 && cx % 2 == 1)) { + // already covered by FindBestFirstLevelDivisionForSquare + continue; + } + // All other merge sizes are handled here. + // Some of the DCT16X8s and DCT8X16s will still leak through here + // when there is an odd number of 8x8 blocks, then the last row + // and column will get their DCT16X8s and DCT8X16s through the + // normal integral transform merging process. + TryMergeAcs(tx.type, bx, by, cx, cy, config, cmap_factors, ac_strategy, + tx.entropy_mul, tx.priority, &priority[0], entropy_estimate, + block, scratch_space, quantized); + } + } + } + if (cparams.speed_tier >= SpeedTier::kHare) { + return; + } + // Here we still try to do some non-aligned matching, find a few more + // 16X8, 8X16 and 16X16s between the non-2-aligned blocks. + for (size_t cy = 0; cy + 1 < rect.ysize(); ++cy) { + for (size_t cx = 0; cx + 1 < rect.xsize(); ++cx) { + if ((cy | cx) % 2 != 0) { + FindBestFirstLevelDivisionForSquare( + 2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy, + entropy_mul16X8, entropy_mul16X16, entropy_estimate, block, + scratch_space, quantized); + } + } + } + // Non-aligned matching for 32X32, 16X32 and 32X16. + size_t step = cparams.speed_tier >= SpeedTier::kTortoise ? 2 : 1; + for (size_t cy = 0; cy + 3 < rect.ysize(); cy += step) { + for (size_t cx = 0; cx + 3 < rect.xsize(); cx += step) { + if ((cy | cx) % 4 == 0) { + continue; // Already tried with loop above (DCT16X32 case). + } + FindBestFirstLevelDivisionForSquare( + 4, enable_32x32, bx, by, cx, cy, config, cmap_factors, ac_strategy, + entropy_mul16X32, entropy_mul32X32, entropy_estimate, block, + scratch_space, quantized); + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ProcessRectACS); + +void AcStrategyHeuristics::Init(const Image3F& src, const Rect& rect_in, + const ImageF& quant_field, const ImageF& mask, + const ImageF& mask1x1, + DequantMatrices* matrices) { + config.dequant = matrices; + + if (cparams.speed_tier >= SpeedTier::kCheetah) { + JXL_CHECK(matrices->EnsureComputed(1)); // DCT8 only + } else { + uint32_t acs_mask = 0; + // All transforms up to 64x64. + for (size_t i = 0; i < AcStrategy::DCT128X128; i++) { + acs_mask |= (1 << i); + } + JXL_CHECK(matrices->EnsureComputed(acs_mask)); + } + + // Image row pointers and strides. + config.quant_field_row = quant_field.Row(0); + config.quant_field_stride = quant_field.PixelsPerRow(); + if (mask.xsize() > 0 && mask.ysize() > 0) { + config.masking_field_row = mask.Row(0); + config.masking_field_stride = mask.PixelsPerRow(); + } + if (mask1x1.xsize() > 0 && mask1x1.ysize() > 0) { + config.masking1x1_field_row = mask1x1.Row(0); + config.masking1x1_field_stride = mask1x1.PixelsPerRow(); + } + + config.src_rows[0] = rect_in.ConstPlaneRow(src, 0, 0); + config.src_rows[1] = rect_in.ConstPlaneRow(src, 1, 0); + config.src_rows[2] = rect_in.ConstPlaneRow(src, 2, 0); + config.src_stride = src.PixelsPerRow(); + + // Entropy estimate is composed of two factors: + // - estimate of the number of bits that will be used by the block + // - information loss due to quantization + // The following constant controls the relative weights of these components. + config.info_loss_multiplier = 1.2; + config.zeros_mul = 9.3089059022677905; + config.cost_delta = 10.833273317067883; + + static const float kBias = 0.13731742964354549; + const float ratio = (cparams.butteraugli_distance + kBias) / (1.0f + kBias); + + static const float kPow1 = 0.33677806662454718; + static const float kPow2 = 0.50990926717963703; + static const float kPow3 = 0.36702940662370243; + config.info_loss_multiplier *= pow(ratio, kPow1); + config.zeros_mul *= pow(ratio, kPow2); + config.cost_delta *= pow(ratio, kPow3); +} + +void AcStrategyHeuristics::ProcessRect(const Rect& rect, + const ColorCorrelationMap& cmap, + AcStrategyImage* ac_strategy) { + // In Falcon mode, use DCT8 everywhere and uniform quantization. + if (cparams.speed_tier >= SpeedTier::kCheetah) { + ac_strategy->FillDCT8(rect); + return; + } + HWY_DYNAMIC_DISPATCH(ProcessRectACS) + (cparams, config, rect, cmap, ac_strategy); +} + +void AcStrategyHeuristics::Finalize(const FrameDimensions& frame_dim, + const AcStrategyImage& ac_strategy, + AuxOut* aux_out) { + // Accounting and debug output. + if (aux_out != nullptr) { + aux_out->num_small_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::IDENTITY) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT2X2) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT4X4); + aux_out->num_dct4x8_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT4X8) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT8X4); + aux_out->num_afv_blocks = ac_strategy.CountBlocks(AcStrategy::Type::AFV0) + + ac_strategy.CountBlocks(AcStrategy::Type::AFV1) + + ac_strategy.CountBlocks(AcStrategy::Type::AFV2) + + ac_strategy.CountBlocks(AcStrategy::Type::AFV3); + aux_out->num_dct8_blocks = ac_strategy.CountBlocks(AcStrategy::Type::DCT); + aux_out->num_dct8x16_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT8X16) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT16X8); + aux_out->num_dct8x32_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT8X32) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X8); + aux_out->num_dct16_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT16X16); + aux_out->num_dct16x32_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT16X32) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X16); + aux_out->num_dct32_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X32); + aux_out->num_dct32x64_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT32X64) + + ac_strategy.CountBlocks(AcStrategy::Type::DCT64X32); + aux_out->num_dct64_blocks = + ac_strategy.CountBlocks(AcStrategy::Type::DCT64X64); + } + + // if (JXL_DEBUG_AC_STRATEGY && WantDebugOutput(aux_out)) { + if (JXL_DEBUG_AC_STRATEGY && WantDebugOutput(cparams)) { + DumpAcStrategy(ac_strategy, frame_dim.xsize, frame_dim.ysize, "ac_strategy", + aux_out, cparams); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.h b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.h new file mode 100644 index 0000000000..9f6d92a6f7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_ac_strategy.h @@ -0,0 +1,75 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_AC_STRATEGY_H_ +#define LIB_JXL_ENC_AC_STRATEGY_H_ + +#include <cstddef> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quant_weights.h" + +// `FindBestAcStrategy` uses heuristics to choose which AC strategy should be +// used in each block, as well as the initial quantization field. + +namespace jxl { + +struct AuxOut; + +// AC strategy selection: utility struct. + +struct ACSConfig { + const DequantMatrices* JXL_RESTRICT dequant; + const float* JXL_RESTRICT quant_field_row; + size_t quant_field_stride; + const float* JXL_RESTRICT masking_field_row; + size_t masking_field_stride; + const float* JXL_RESTRICT masking1x1_field_row; + size_t masking1x1_field_stride; + const float* JXL_RESTRICT src_rows[3]; + size_t src_stride; + float info_loss_multiplier; + float cost_delta; + float zeros_mul; + const float& Pixel(size_t c, size_t x, size_t y) const { + return src_rows[c][y * src_stride + x]; + } + float Masking(size_t bx, size_t by) const { + JXL_DASSERT(masking_field_row[by * masking_field_stride + bx] > 0); + return masking_field_row[by * masking_field_stride + bx]; + } + const float* MaskingPtr1x1(size_t bx, size_t by) const { + JXL_DASSERT(masking1x1_field_row[by * masking1x1_field_stride + bx] > 0); + return &masking1x1_field_row[by * masking1x1_field_stride + bx]; + } + float Quant(size_t bx, size_t by) const { + JXL_DASSERT(quant_field_row[by * quant_field_stride + bx] > 0); + return quant_field_row[by * quant_field_stride + bx]; + } +}; + +struct AcStrategyHeuristics { + AcStrategyHeuristics(const CompressParams& cparams) : cparams(cparams) {} + void Init(const Image3F& src, const Rect& rect_in, const ImageF& quant_field, + const ImageF& mask, const ImageF& mask1x1, + DequantMatrices* matrices); + void ProcessRect(const Rect& rect, const ColorCorrelationMap& cmap, + AcStrategyImage* ac_strategy); + void Finalize(const FrameDimensions& frame_dim, + const AcStrategyImage& ac_strategy, AuxOut* aux_out); + const CompressParams& cparams; + ACSConfig config; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_AC_STRATEGY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc new file mode 100644 index 0000000000..ae4cd3bd3b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc @@ -0,0 +1,1170 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_adaptive_quantization.h" + +#include <stddef.h> +#include <stdlib.h> + +#include <algorithm> +#include <cmath> +#include <string> +#include <vector> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/fast_math-inl.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_group.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_debug_image.h" +#include "lib/jxl/enc_group.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/quant_weights.h" + +// Set JXL_DEBUG_ADAPTIVE_QUANTIZATION to 1 to enable debugging. +#ifndef JXL_DEBUG_ADAPTIVE_QUANTIZATION +#define JXL_DEBUG_ADAPTIVE_QUANTIZATION 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::AbsDiff; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::Sqrt; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +// The following functions modulate an exponent (out_val) and return the updated +// value. Their descriptor is limited to 8 lanes for 8x8 blocks. + +// Hack for mask estimation. Eventually replace this code with butteraugli's +// masking. +float ComputeMaskForAcStrategyUse(const float out_val) { + const float kMul = 1.0f; + const float kOffset = 0.001f; + return kMul / (out_val + kOffset); +} + +template <class D, class V> +V ComputeMask(const D d, const V out_val) { + const auto kBase = Set(d, -0.7647f); + const auto kMul4 = Set(d, 9.4708735624378946f); + const auto kMul2 = Set(d, 17.35036561631863f); + const auto kOffset2 = Set(d, 302.59587815579727f); + const auto kMul3 = Set(d, 6.7943250517376494f); + const auto kOffset3 = Set(d, 3.7179635626140772f); + const auto kOffset4 = Mul(Set(d, 0.25f), kOffset3); + const auto kMul0 = Set(d, 0.80061762862741759f); + const auto k1 = Set(d, 1.0f); + + // Avoid division by zero. + const auto v1 = Max(Mul(out_val, kMul0), Set(d, 1e-3f)); + const auto v2 = Div(k1, Add(v1, kOffset2)); + const auto v3 = Div(k1, MulAdd(v1, v1, kOffset3)); + const auto v4 = Div(k1, MulAdd(v1, v1, kOffset4)); + // TODO(jyrki): + // A log or two here could make sense. In butteraugli we have effectively + // log(log(x + C)) for this kind of use, as a single log is used in + // saturating visual masking and here the modulation values are exponential, + // another log would counter that. + return Add(kBase, MulAdd(kMul4, v4, MulAdd(kMul2, v2, Mul(kMul3, v3)))); +} + +// mul and mul2 represent a scaling difference between jxl and butteraugli. +static const float kSGmul = 226.77216153508914f; +static const float kSGmul2 = 1.0f / 73.377132366608819f; +static const float kLog2 = 0.693147181f; +// Includes correction factor for std::log -> log2. +static const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2; +static const float kSGVOffset = 7.7825991679894591f; + +template <bool invert, typename D, typename V> +V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) { + // The opsin space in jxl is the cubic root of photons, i.e., v * v * v + // is related to the number of photons. + // + // SimpleGamma(v * v * v) is the psychovisual space in butteraugli. + // This ratio allows quantization to move from jxl's opsin space to + // butteraugli's log-gamma space. + float kEpsilon = 1e-2; + v = ZeroIfNegative(v); + const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul); + const auto kVOffset = Set(d, kSGVOffset * kLog2 + kEpsilon); + const auto kDenMul = Set(d, kLog2 * kSGmul); + + const auto v2 = Mul(v, v); + + const auto num = MulAdd(kNumMul, v2, Set(d, kEpsilon)); + const auto den = MulAdd(Mul(kDenMul, v), v2, kVOffset); + return invert ? Div(num, den) : Div(den, num); +} + +template <bool invert = false> +static float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) { + using DScalar = HWY_CAPPED(float, 1); + auto vscalar = Load(DScalar(), &v); + return GetLane( + RatioOfDerivativesOfCubicRootToSimpleGamma<invert>(DScalar(), vscalar)); +} + +// TODO(veluca): this function computes an approximation of the derivative of +// SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or +// exact derivatives. For reference, SimpleGamma was: +/* +template <typename D, typename V> +V SimpleGamma(const D d, V v) { + // A simple HDR compatible gamma function. + const auto mul = Set(d, kSGmul); + const auto kRetMul = Set(d, kSGRetMul); + const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f); + const auto kVOffset = Set(d, kSGVOffset); + + v *= mul; + + // This should happen rarely, but may lead to a NaN, which is rather + // undesirable. Since negative photons don't exist we solve the NaNs by + // clamping here. + // TODO(veluca): with FastLog2f, this no longer leads to NaNs. + v = ZeroIfNegative(v); + return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd; +} +*/ + +template <class D, class V> +V GammaModulation(const D d, const size_t x, const size_t y, + const ImageF& xyb_x, const ImageF& xyb_y, const Rect& rect, + const V out_val) { + const float kBias = 0.16f; + JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[0]); + JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[1]); + JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[2]); + auto overall_ratio = Zero(d); + auto bias = Set(d, kBias); + auto half = Set(d, 0.5f); + for (size_t dy = 0; dy < 8; ++dy) { + const float* const JXL_RESTRICT row_in_x = rect.ConstRow(xyb_x, y + dy); + const float* const JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy); + for (size_t dx = 0; dx < 8; dx += Lanes(d)) { + const auto iny = Add(Load(d, row_in_y + x + dx), bias); + const auto inx = Load(d, row_in_x + x + dx); + const auto r = Sub(iny, inx); + const auto g = Add(iny, inx); + const auto ratio_r = + RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, r); + const auto ratio_g = + RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, g); + const auto avg_ratio = Mul(half, Add(ratio_r, ratio_g)); + + overall_ratio = Add(overall_ratio, avg_ratio); + } + } + overall_ratio = Mul(SumOfLanes(d, overall_ratio), Set(d, 1.0f / 64)); + // ideally -1.0, but likely optimal correction adds some entropy, so slightly + // less than that. + // ln(2) constant folded in because we want std::log but have FastLog2f. + static const float v = 0.14507933746197058f; + const auto kGam = Set(d, v * 0.693147180559945f); + return MulAdd(kGam, FastLog2f(d, overall_ratio), out_val); +} + +// Change precision in 8x8 blocks that have high frequency content. +template <class D, class V> +V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb, + const Rect& rect, const V out_val) { + // Zero out the invalid differences for the rightmost value per row. + const Rebind<uint32_t, D> du; + HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u, + ~0u, ~0u, ~0u, 0}; + + auto sum = Zero(d); // sum of absolute differences with right and below + + static const float valmin = 0.020602694503245016f; + auto valminv = Set(d, valmin); + for (size_t dy = 0; dy < 8; ++dy) { + const float* JXL_RESTRICT row_in = rect.ConstRow(xyb, y + dy) + x; + const float* JXL_RESTRICT row_in_next = + dy == 7 ? row_in : rect.ConstRow(xyb, y + dy + 1) + x; + + // In SCALAR, there is no guarantee of having extra row padding. + // Hence, we need to ensure we don't access pixels outside the row itself. + // In SIMD modes, however, rows are padded, so it's safe to access one + // garbage value after the row. The vector then gets masked with kMaskRight + // to remove the influence of that value. +#if HWY_TARGET != HWY_SCALAR + for (size_t dx = 0; dx < 8; dx += Lanes(d)) { +#else + for (size_t dx = 0; dx < 7; dx += Lanes(d)) { +#endif + const auto p = Load(d, row_in + dx); + const auto pr = LoadU(d, row_in + dx + 1); + const auto mask = BitCast(d, Load(du, kMaskRight + dx)); + sum = Add(sum, And(mask, Min(valminv, AbsDiff(p, pr)))); + + const auto pd = Load(d, row_in_next + dx); + sum = Add(sum, Min(valminv, AbsDiff(p, pd))); + } +#if HWY_TARGET == HWY_SCALAR + const auto p = Load(d, row_in + 7); + const auto pd = Load(d, row_in_next + 7); + sum = Add(sum, Min(valminv, AbsDiff(p, pd))); +#endif + } + // more negative value gives more bpp + static const float kOffset = -1.110929106987477; + static const float kMul = -0.38078920620238305; + sum = SumOfLanes(d, sum); + float scalar_sum = GetLane(sum); + scalar_sum += kOffset; + scalar_sum *= kMul; + return Add(Set(d, scalar_sum), out_val); +} + +void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x, + const ImageF& xyb_y, const ImageF& xyb_b, + const Rect& rect_in, const float scale, + const Rect& rect_out, ImageF* out) { + float base_level = 0.48f * scale; + float kDampenRampStart = 2.0f; + float kDampenRampEnd = 14.0f; + float dampen = 1.0f; + if (butteraugli_target >= kDampenRampStart) { + dampen = 1.0f - ((butteraugli_target - kDampenRampStart) / + (kDampenRampEnd - kDampenRampStart)); + if (dampen < 0) { + dampen = 0; + } + } + const float mul = scale * dampen; + const float add = (1.0f - dampen) * base_level; + for (size_t iy = rect_out.y0(); iy < rect_out.y1(); iy++) { + const size_t y = iy * 8; + float* const JXL_RESTRICT row_out = out->Row(iy); + const HWY_CAPPED(float, kBlockDim) df; + for (size_t ix = rect_out.x0(); ix < rect_out.x1(); ix++) { + size_t x = ix * 8; + auto out_val = Set(df, row_out[ix]); + out_val = ComputeMask(df, out_val); + out_val = HfModulation(df, x, y, xyb_y, rect_in, out_val); + out_val = GammaModulation(df, x, y, xyb_x, xyb_y, rect_in, out_val); + // We want multiplicative quantization field, so everything + // until this point has been modulating the exponent. + row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add; + } + } +} + +template <typename D, typename V> +V MaskingSqrt(const D d, V v) { + static const float kLogOffset = 27.97044946785558f; + static const float kMul = 211.53333281566171f; + const auto mul_v = Set(d, kMul * 1e8); + const auto offset_v = Set(d, kLogOffset); + return Mul(Set(d, 0.25f), Sqrt(MulAdd(v, Sqrt(mul_v), offset_v))); +} + +float MaskingSqrt(const float v) { + using DScalar = HWY_CAPPED(float, 1); + auto vscalar = Load(DScalar(), &v); + return GetLane(MaskingSqrt(DScalar(), vscalar)); +} + +void StoreMin4(const float v, float& min0, float& min1, float& min2, + float& min3) { + if (v < min3) { + if (v < min0) { + min3 = min2; + min2 = min1; + min1 = min0; + min0 = v; + } else if (v < min1) { + min3 = min2; + min2 = min1; + min1 = v; + } else if (v < min2) { + min3 = min2; + min2 = v; + } else { + min3 = v; + } + } +} + +// Look for smooth areas near the area of degradation. +// If the areas are generally smooth, don't do masking. +// Output is downsampled 2x. +void FuzzyErosion(const float butteraugli_target, const Rect& from_rect, + const ImageF& from, const Rect& to_rect, ImageF* to) { + const size_t xsize = from.xsize(); + const size_t ysize = from.ysize(); + constexpr int kStep = 1; + static_assert(kStep == 1, "Step must be 1"); + JXL_ASSERT(to_rect.xsize() * 2 == from_rect.xsize()); + JXL_ASSERT(to_rect.ysize() * 2 == from_rect.ysize()); + static const float kMulBase0 = 0.125; + static const float kMulBase1 = 0.10; + static const float kMulBase2 = 0.09; + static const float kMulBase3 = 0.06; + static const float kMulAdd0 = 0.0; + static const float kMulAdd1 = -0.10; + static const float kMulAdd2 = -0.09; + static const float kMulAdd3 = -0.06; + + float mul = 0.0; + if (butteraugli_target < 2.0f) { + mul = (2.0f - butteraugli_target) * (1.0f / 2.0f); + } + float kMul0 = kMulBase0 + mul * kMulAdd0; + float kMul1 = kMulBase1 + mul * kMulAdd1; + float kMul2 = kMulBase2 + mul * kMulAdd2; + float kMul3 = kMulBase3 + mul * kMulAdd3; + static const float kTotal = 0.29959705784054957; + float norm = kTotal / (kMul0 + kMul1 + kMul2 + kMul3); + kMul0 *= norm; + kMul1 *= norm; + kMul2 *= norm; + kMul3 *= norm; + + for (size_t fy = 0; fy < from_rect.ysize(); ++fy) { + size_t y = fy + from_rect.y0(); + size_t ym1 = y >= kStep ? y - kStep : y; + size_t yp1 = y + kStep < ysize ? y + kStep : y; + const float* rowt = from.Row(ym1); + const float* row = from.Row(y); + const float* rowb = from.Row(yp1); + float* row_out = to_rect.Row(to, fy / 2); + for (size_t fx = 0; fx < from_rect.xsize(); ++fx) { + size_t x = fx + from_rect.x0(); + size_t xm1 = x >= kStep ? x - kStep : x; + size_t xp1 = x + kStep < xsize ? x + kStep : x; + float min0 = row[x]; + float min1 = row[xm1]; + float min2 = row[xp1]; + float min3 = rowt[xm1]; + // Sort the first four values. + if (min0 > min1) std::swap(min0, min1); + if (min0 > min2) std::swap(min0, min2); + if (min0 > min3) std::swap(min0, min3); + if (min1 > min2) std::swap(min1, min2); + if (min1 > min3) std::swap(min1, min3); + if (min2 > min3) std::swap(min2, min3); + // The remaining five values of a 3x3 neighbourhood. + StoreMin4(rowt[x], min0, min1, min2, min3); + StoreMin4(rowt[xp1], min0, min1, min2, min3); + StoreMin4(rowb[xm1], min0, min1, min2, min3); + StoreMin4(rowb[x], min0, min1, min2, min3); + StoreMin4(rowb[xp1], min0, min1, min2, min3); + + float v = kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + kMul3 * min3; + if (fx % 2 == 0 && fy % 2 == 0) { + row_out[fx / 2] = v; + } else { + row_out[fx / 2] += v; + } + } + } +} + +struct AdaptiveQuantizationImpl { + void PrepareBuffers(size_t num_threads) { + diff_buffer = ImageF(kEncTileDim + 8, num_threads); + for (size_t i = pre_erosion.size(); i < num_threads; i++) { + pre_erosion.emplace_back(kEncTileDimInBlocks * 2 + 2, + kEncTileDimInBlocks * 2 + 2); + } + } + + void ComputeTile(float butteraugli_target, float scale, const Image3F& xyb, + const Rect& rect_in, const Rect& rect_out, const int thread, + ImageF* mask, ImageF* mask1x1) { + JXL_ASSERT(rect_in.x0() % 8 == 0); + JXL_ASSERT(rect_in.y0() % 8 == 0); + const size_t xsize = xyb.xsize(); + const size_t ysize = xyb.ysize(); + + // The XYB gamma is 3.0 to be able to decode faster with two muls. + // Butteraugli's gamma is matching the gamma of human eye, around 2.6. + // We approximate the gamma difference by adding one cubic root into + // the adaptive quantization. This gives us a total gamma of 2.6666 + // for quantization uses. + const float match_gamma_offset = 0.019; + + const HWY_FULL(float) df; + + size_t y_start_1x1 = rect_in.y0() + rect_out.y0() * 8; + size_t y_end_1x1 = y_start_1x1 + rect_out.ysize() * 8; + + size_t x_start_1x1 = rect_in.x0() + rect_out.x0() * 8; + size_t x_end_1x1 = x_start_1x1 + rect_out.xsize() * 8; + + if (rect_in.x0() != 0 && rect_out.x0() == 0) x_start_1x1 -= 2; + if (rect_in.x1() < xsize && rect_out.x1() * 8 == rect_in.xsize()) { + x_end_1x1 += 2; + } + if (rect_in.y0() != 0 && rect_out.y0() == 0) y_start_1x1 -= 2; + if (rect_in.y1() < ysize && rect_out.y1() * 8 == rect_in.ysize()) { + y_end_1x1 += 2; + } + + // Computes image (padded to multiple of 8x8) of local pixel differences. + // Subsample both directions by 4. + // 1x1 Laplacian of intensity. + for (size_t y = y_start_1x1; y < y_end_1x1; ++y) { + const size_t y2 = y + 1 < ysize ? y + 1 : y; + const size_t y1 = y > 0 ? y - 1 : y; + const float* row_in = xyb.ConstPlaneRow(1, y); + const float* row_in1 = xyb.ConstPlaneRow(1, y1); + const float* row_in2 = xyb.ConstPlaneRow(1, y2); + float* mask1x1_out = mask1x1->Row(y); + auto scalar_pixel1x1 = [&](size_t x) { + const size_t x2 = x + 1 < xsize ? x + 1 : x; + const size_t x1 = x > 0 ? x - 1 : x; + const float base = + 0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]); + const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma( + row_in[x] + match_gamma_offset); + float diff = fabs(gammac * (row_in[x] - base)); + static const double kScaler = 1.0; + diff *= kScaler; + diff = log1p(diff); + static const float kMul = 1.0; + static const float kOffset = 0.01; + mask1x1_out[x] = kMul / (diff + kOffset); + }; + for (size_t x = x_start_1x1; x < x_end_1x1; ++x) { + scalar_pixel1x1(x); + } + } + + size_t y_start = rect_in.y0() + rect_out.y0() * 8; + size_t y_end = y_start + rect_out.ysize() * 8; + + size_t x_start = rect_in.x0() + rect_out.x0() * 8; + size_t x_end = x_start + rect_out.xsize() * 8; + + if (x_start != 0) x_start -= 4; + if (x_end != xsize) x_end += 4; + if (y_start != 0) y_start -= 4; + if (y_end != ysize) y_end += 4; + pre_erosion[thread].ShrinkTo((x_end - x_start) / 4, (y_end - y_start) / 4); + + static const float limit = 0.2f; + for (size_t y = y_start; y < y_end; ++y) { + size_t y2 = y + 1 < ysize ? y + 1 : y; + size_t y1 = y > 0 ? y - 1 : y; + + const float* row_in = xyb.ConstPlaneRow(1, y); + const float* row_in1 = xyb.ConstPlaneRow(1, y1); + const float* row_in2 = xyb.ConstPlaneRow(1, y2); + float* JXL_RESTRICT row_out = diff_buffer.Row(thread); + + auto scalar_pixel = [&](size_t x) { + const size_t x2 = x + 1 < xsize ? x + 1 : x; + const size_t x1 = x > 0 ? x - 1 : x; + const float base = + 0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]); + const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma( + row_in[x] + match_gamma_offset); + float diff = gammac * (row_in[x] - base); + diff *= diff; + if (diff >= limit) { + diff = limit; + } + diff = MaskingSqrt(diff); + if ((y % 4) != 0) { + row_out[x - x_start] += diff; + } else { + row_out[x - x_start] = diff; + } + }; + + size_t x = x_start; + // First pixel of the row. + if (x_start == 0) { + scalar_pixel(x_start); + ++x; + } + // SIMD + const auto match_gamma_offset_v = Set(df, match_gamma_offset); + const auto quarter = Set(df, 0.25f); + for (; x + 1 + Lanes(df) < x_end; x += Lanes(df)) { + const auto in = LoadU(df, row_in + x); + const auto in_r = LoadU(df, row_in + x + 1); + const auto in_l = LoadU(df, row_in + x - 1); + const auto in_t = LoadU(df, row_in2 + x); + const auto in_b = LoadU(df, row_in1 + x); + auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b))); + auto gammacv = + RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/false>( + df, Add(in, match_gamma_offset_v)); + auto diff = Mul(gammacv, Sub(in, base)); + diff = Mul(diff, diff); + diff = Min(diff, Set(df, limit)); + diff = MaskingSqrt(df, diff); + if ((y & 3) != 0) { + diff = Add(diff, LoadU(df, row_out + x - x_start)); + } + StoreU(diff, df, row_out + x - x_start); + } + // Scalar + for (; x < x_end; ++x) { + scalar_pixel(x); + } + if (y % 4 == 3) { + float* row_dout = pre_erosion[thread].Row((y - y_start) / 4); + for (size_t x = 0; x < (x_end - x_start) / 4; x++) { + row_dout[x] = (row_out[x * 4] + row_out[x * 4 + 1] + + row_out[x * 4 + 2] + row_out[x * 4 + 3]) * + 0.25f; + } + } + } + Rect from_rect(x_start % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 0 : 1, + rect_out.xsize() * 2, rect_out.ysize() * 2); + FuzzyErosion(butteraugli_target, from_rect, pre_erosion[thread], rect_out, + &aq_map); + for (size_t y = 0; y < rect_out.ysize(); ++y) { + const float* aq_map_row = rect_out.ConstRow(aq_map, y); + float* mask_row = rect_out.Row(mask, y); + for (size_t x = 0; x < rect_out.xsize(); ++x) { + mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]); + } + } + PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1), + xyb.Plane(2), rect_in, scale, rect_out, &aq_map); + } + std::vector<ImageF> pre_erosion; + ImageF aq_map; + ImageF diff_buffer; +}; + +static void Blur1x1Masking(ThreadPool* pool, ImageF* mask1x1, + const Rect& rect) { + // Blur the mask1x1 to obtain the masking image. + // Before blurring it contains an image of absolute value of the + // Laplacian of the intensity channel. + static const float kFilterMask1x1[5] = { + static_cast<float>(0.25647067633737227), + static_cast<float>(0.2050056912354399075), + static_cast<float>(0.154082048668497307), + static_cast<float>(0.08149576591362004441), + static_cast<float>(0.0512750104812308467), + }; + double sum = + 1.0 + 4 * (kFilterMask1x1[0] + kFilterMask1x1[1] + kFilterMask1x1[2] + + kFilterMask1x1[4] + 2 * kFilterMask1x1[3]); + if (sum < 1e-5) { + sum = 1e-5; + } + const float normalize = static_cast<float>(1.0 / sum); + const float normalize_mul = normalize; + WeightsSymmetric5 weights = + WeightsSymmetric5{{HWY_REP4(normalize)}, + {HWY_REP4(normalize_mul * kFilterMask1x1[0])}, + {HWY_REP4(normalize_mul * kFilterMask1x1[2])}, + {HWY_REP4(normalize_mul * kFilterMask1x1[1])}, + {HWY_REP4(normalize_mul * kFilterMask1x1[4])}, + {HWY_REP4(normalize_mul * kFilterMask1x1[3])}}; + ImageF temp(rect.xsize(), rect.ysize()); + Symmetric5(*mask1x1, rect, weights, pool, &temp); + *mask1x1 = std::move(temp); +} + +ImageF AdaptiveQuantizationMap(const float butteraugli_target, + const Image3F& xyb, const Rect& rect, + float scale, ThreadPool* pool, ImageF* mask, + ImageF* mask1x1) { + JXL_DASSERT(rect.xsize() % kBlockDim == 0); + JXL_DASSERT(rect.ysize() % kBlockDim == 0); + AdaptiveQuantizationImpl impl; + const size_t xsize_blocks = rect.xsize() / kBlockDim; + const size_t ysize_blocks = rect.ysize() / kBlockDim; + impl.aq_map = ImageF(xsize_blocks, ysize_blocks); + *mask = ImageF(xsize_blocks, ysize_blocks); + *mask1x1 = ImageF(xyb.xsize(), xyb.ysize()); + JXL_CHECK(RunOnPool( + pool, 0, + DivCeil(xsize_blocks, kEncTileDimInBlocks) * + DivCeil(ysize_blocks, kEncTileDimInBlocks), + [&](const size_t num_threads) { + impl.PrepareBuffers(num_threads); + return true; + }, + [&](const uint32_t tid, const size_t thread) { + size_t n_enc_tiles = DivCeil(xsize_blocks, kEncTileDimInBlocks); + size_t tx = tid % n_enc_tiles; + size_t ty = tid / n_enc_tiles; + size_t by0 = ty * kEncTileDimInBlocks; + size_t by1 = std::min((ty + 1) * kEncTileDimInBlocks, ysize_blocks); + size_t bx0 = tx * kEncTileDimInBlocks; + size_t bx1 = std::min((tx + 1) * kEncTileDimInBlocks, xsize_blocks); + Rect rect_out(bx0, by0, bx1 - bx0, by1 - by0); + impl.ComputeTile(butteraugli_target, scale, xyb, rect, rect_out, thread, + mask, mask1x1); + }, + "AQ DiffPrecompute")); + + Blur1x1Masking(pool, mask1x1, rect); + return std::move(impl).aq_map; +} + +} // namespace + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(AdaptiveQuantizationMap); + +namespace { + +// If true, prints the quantization maps at each iteration. +constexpr bool FLAGS_dump_quant_state = false; + +void DumpHeatmap(const CompressParams& cparams, const AuxOut* aux_out, + const std::string& label, const ImageF& image, + float good_threshold, float bad_threshold) { + if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) { + Image3F heatmap = CreateHeatMapImage(image, good_threshold, bad_threshold); + char filename[200]; + snprintf(filename, sizeof(filename), "%s%05d", label.c_str(), + aux_out->num_butteraugli_iters); + DumpImage(cparams, filename, heatmap); + } +} + +void DumpHeatmaps(const CompressParams& cparams, const AuxOut* aux_out, + float ba_target, const ImageF& quant_field, + const ImageF& tile_heatmap, const ImageF& bt_diffmap) { + if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) { + if (!WantDebugOutput(cparams)) return; + ImageF inv_qmap(quant_field.xsize(), quant_field.ysize()); + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* JXL_RESTRICT row_q = quant_field.ConstRow(y); + float* JXL_RESTRICT row_inv_q = inv_qmap.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + row_inv_q[x] = 1.0f / row_q[x]; // never zero + } + } + DumpHeatmap(cparams, aux_out, "quant_heatmap", inv_qmap, 4.0f * ba_target, + 6.0f * ba_target); + DumpHeatmap(cparams, aux_out, "tile_heatmap", tile_heatmap, ba_target, + 1.5f * ba_target); + // matches heat maps produced by the command line tool. + DumpHeatmap(cparams, aux_out, "bt_diffmap", bt_diffmap, + ButteraugliFuzzyInverse(1.5), ButteraugliFuzzyInverse(0.5)); + } +} + +ImageF TileDistMap(const ImageF& distmap, int tile_size, int margin, + const AcStrategyImage& ac_strategy) { + const int tile_xsize = (distmap.xsize() + tile_size - 1) / tile_size; + const int tile_ysize = (distmap.ysize() + tile_size - 1) / tile_size; + ImageF tile_distmap(tile_xsize, tile_ysize); + size_t distmap_stride = tile_distmap.PixelsPerRow(); + for (int tile_y = 0; tile_y < tile_ysize; ++tile_y) { + AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(tile_y); + float* JXL_RESTRICT dist_row = tile_distmap.Row(tile_y); + for (int tile_x = 0; tile_x < tile_xsize; ++tile_x) { + AcStrategy acs = ac_strategy_row[tile_x]; + if (!acs.IsFirstBlock()) continue; + int this_tile_xsize = acs.covered_blocks_x() * tile_size; + int this_tile_ysize = acs.covered_blocks_y() * tile_size; + int y_begin = std::max<int>(0, tile_size * tile_y - margin); + int y_end = std::min<int>(distmap.ysize(), + tile_size * tile_y + this_tile_ysize + margin); + int x_begin = std::max<int>(0, tile_size * tile_x - margin); + int x_end = std::min<int>(distmap.xsize(), + tile_size * tile_x + this_tile_xsize + margin); + float dist_norm = 0.0; + double pixels = 0; + for (int y = y_begin; y < y_end; ++y) { + float ymul = 1.0; + constexpr float kBorderMul = 0.98f; + constexpr float kCornerMul = 0.7f; + if (margin != 0 && (y == y_begin || y == y_end - 1)) { + ymul = kBorderMul; + } + const float* const JXL_RESTRICT row = distmap.Row(y); + for (int x = x_begin; x < x_end; ++x) { + float xmul = ymul; + if (margin != 0 && (x == x_begin || x == x_end - 1)) { + if (xmul == 1.0) { + xmul = kBorderMul; + } else { + xmul = kCornerMul; + } + } + float v = row[x]; + v *= v; + v *= v; + v *= v; + v *= v; + dist_norm += xmul * v; + pixels += xmul; + } + } + if (pixels == 0) pixels = 1; + // 16th norm is less than the max norm, we reduce the difference + // with this normalization factor. + constexpr float kTileNorm = 1.2f; + const float tile_dist = + kTileNorm * std::pow(dist_norm / pixels, 1.0f / 16.0f); + dist_row[tile_x] = tile_dist; + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + dist_row[tile_x + distmap_stride * iy + ix] = tile_dist; + } + } + } + } + return tile_distmap; +} + +static const float kDcQuantPow = 0.83f; +static const float kDcQuant = 1.095924047623553f; +static const float kAcQuant = 0.7381485255235064f; + +// Computes the decoded image for a given set of compression parameters. +ImageBundle RoundtripImage(const FrameHeader& frame_header, + const Image3F& opsin, PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool) { + std::unique_ptr<PassesDecoderState> dec_state = + jxl::make_unique<PassesDecoderState>(); + JXL_CHECK(dec_state->output_encoding_info.SetFromMetadata( + *enc_state->shared.metadata)); + dec_state->shared = &enc_state->shared; + JXL_ASSERT(opsin.ysize() % kBlockDim == 0); + + const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim); + const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim); + const size_t num_groups = xsize_groups * ysize_groups; + + size_t num_special_frames = enc_state->special_frames.size(); + size_t num_passes = enc_state->progressive_splitter.GetNumPasses(); + ModularFrameEncoder modular_frame_encoder(frame_header, enc_state->cparams); + JXL_CHECK(InitializePassesEncoder(frame_header, opsin, Rect(opsin), cms, pool, + enc_state, &modular_frame_encoder, + nullptr)); + JXL_CHECK(dec_state->Init(frame_header)); + JXL_CHECK(dec_state->InitForAC(num_passes, pool)); + + ImageBundle decoded(&enc_state->shared.metadata->m); + decoded.origin = frame_header.frame_origin; + decoded.SetFromImage(Image3F(opsin.xsize(), opsin.ysize()), + dec_state->output_encoding_info.color_encoding); + + PassesDecoderState::PipelineOptions options; + options.use_slow_render_pipeline = false; + options.coalescing = false; + options.render_spotcolors = false; + options.render_noise = false; + + // Same as frame_header.nonserialized_metadata->m + const ImageMetadata& metadata = *decoded.metadata(); + + JXL_CHECK(dec_state->PreparePipeline(frame_header, &decoded, options)); + + hwy::AlignedUniquePtr<GroupDecCache[]> group_dec_caches; + const auto allocate_storage = [&](const size_t num_threads) -> Status { + JXL_RETURN_IF_ERROR( + dec_state->render_pipeline->PrepareForThreads(num_threads, + /*use_group_ids=*/false)); + group_dec_caches = hwy::MakeUniqueAlignedArray<GroupDecCache>(num_threads); + return true; + }; + const auto process_group = [&](const uint32_t group_index, + const size_t thread) { + if (frame_header.loop_filter.epf_iters > 0) { + ComputeSigma(frame_header.loop_filter, + dec_state->shared->frame_dim.BlockGroupRect(group_index), + dec_state.get()); + } + RenderPipelineInput input = + dec_state->render_pipeline->GetInputBuffers(group_index, thread); + JXL_CHECK(DecodeGroupForRoundtrip( + frame_header, enc_state->coeffs, group_index, dec_state.get(), + &group_dec_caches[thread], thread, input, &decoded, nullptr)); + for (size_t c = 0; c < metadata.num_extra_channels; c++) { + std::pair<ImageF*, Rect> ri = input.GetBuffer(3 + c); + FillPlane(0.0f, ri.first, ri.second); + } + input.Done(); + }; + JXL_CHECK(RunOnPool(pool, 0, num_groups, allocate_storage, process_group, + "AQ loop")); + + // Ensure we don't create any new special frames. + enc_state->special_frames.resize(num_special_frames); + + return decoded; +} + +constexpr int kMaxButteraugliIters = 4; + +void FindBestQuantization(const FrameHeader& frame_header, + const Image3F& linear, const Image3F& opsin, + ImageF& quant_field, PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out) { + const CompressParams& cparams = enc_state->cparams; + if (cparams.resampling > 1 && + cparams.original_butteraugli_distance <= 4.0 * cparams.resampling) { + // For downsampled opsin image, the butteraugli based adaptive quantization + // loop would only make the size bigger without improving the distance much, + // so in this case we enable it only for very high butteraugli targets. + return; + } + Quantizer& quantizer = enc_state->shared.quantizer; + ImageI& raw_quant_field = enc_state->shared.raw_quant_field; + + const float butteraugli_target = cparams.butteraugli_distance; + const float original_butteraugli = cparams.original_butteraugli_distance; + ButteraugliParams params; + params.intensity_target = 80.f; + JxlButteraugliComparator comparator(params, cms); + JXL_CHECK(comparator.SetLinearReferenceImage(linear)); + bool lower_is_better = + (comparator.GoodQualityScore() < comparator.BadQualityScore()); + const float initial_quant_dc = InitialQuantDC(butteraugli_target); + AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field), + original_butteraugli, &quant_field); + ImageF tile_distmap; + ImageF initial_quant_field(quant_field.xsize(), quant_field.ysize()); + CopyImageTo(quant_field, &initial_quant_field); + + float initial_qf_min, initial_qf_max; + ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max); + float initial_qf_ratio = initial_qf_max / initial_qf_min; + float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio); + float asymmetry = 2; + if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low; + float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low); + float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry); + + JXL_ASSERT(qf_higher / qf_lower < 253); + + constexpr int kOriginalComparisonRound = 1; + int iters = kMaxButteraugliIters; + if (cparams.speed_tier != SpeedTier::kTortoise) { + iters = 2; + } + for (int i = 0; i < iters + 1; ++i) { + if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) { + printf("\nQuantization field:\n"); + for (size_t y = 0; y < quant_field.ysize(); ++y) { + for (size_t x = 0; x < quant_field.xsize(); ++x) { + printf(" %.5f", quant_field.Row(y)[x]); + } + printf("\n"); + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); + ImageBundle dec_linear = + RoundtripImage(frame_header, opsin, enc_state, cms, pool); + float score; + ImageF diffmap; + JXL_CHECK(comparator.CompareWith(dec_linear, &diffmap, &score)); + if (!lower_is_better) { + score = -score; + ScaleImage(-1.0f, &diffmap); + } + tile_distmap = TileDistMap(diffmap, 8 * cparams.resampling, 0, + enc_state->shared.ac_strategy); + if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && WantDebugOutput(cparams)) { + DumpImage(cparams, ("dec" + ToString(i)).c_str(), *dec_linear.color()); + DumpHeatmaps(cparams, aux_out, butteraugli_target, quant_field, + tile_distmap, diffmap); + } + if (aux_out != nullptr) ++aux_out->num_butteraugli_iters; + if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) { + float minval, maxval; + ImageMinMax(quant_field, &minval, &maxval); + printf("\nButteraugli iter: %d/%d\n", i, kMaxButteraugliIters); + printf("Butteraugli distance: %f (target = %f)\n", score, + original_butteraugli); + printf("quant range: %f ... %f DC quant: %f\n", minval, maxval, + initial_quant_dc); + if (FLAGS_dump_quant_state) { + quantizer.DumpQuantizationMap(raw_quant_field); + } + } + + if (i == iters) break; + + double kPow[8] = { + 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + }; + double kPowMod[8] = { + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + }; + if (i == kOriginalComparisonRound) { + // Don't allow optimization to make the quant field a lot worse than + // what the initial guess was. This allows the AC field to have enough + // precision to reduce the oscillations due to the dc reconstruction. + double kInitMul = 0.6; + const double kOneMinusInitMul = 1.0 - kInitMul; + for (size_t y = 0; y < quant_field.ysize(); ++y) { + float* const JXL_RESTRICT row_q = quant_field.Row(y); + const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x]; + if (row_q[x] < clamp) { + row_q[x] = clamp; + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } + } + + double cur_pow = 0.0; + if (i < 7) { + cur_pow = kPow[i] + (original_butteraugli - 1.0) * kPowMod[i]; + if (cur_pow < 0) { + cur_pow = 0; + } + } + if (cur_pow == 0.0) { + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y); + float* const JXL_RESTRICT row_q = quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + const float diff = row_dist[x] / original_butteraugli; + if (diff > 1.0f) { + float old = row_q[x]; + row_q[x] *= diff; + int qf_old = old * quantizer.InvGlobalScale() + 0.5; + int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5; + if (qf_old == qf_new) { + row_q[x] = old + quantizer.Scale(); + } + } + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } else { + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y); + float* const JXL_RESTRICT row_q = quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + const float diff = row_dist[x] / original_butteraugli; + if (diff <= 1.0f) { + row_q[x] *= std::pow(diff, cur_pow); + } else { + float old = row_q[x]; + row_q[x] *= diff; + int qf_old = old * quantizer.InvGlobalScale() + 0.5; + int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5; + if (qf_old == qf_new) { + row_q[x] = old + quantizer.Scale(); + } + } + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); +} + +void FindBestQuantizationMaxError(const FrameHeader& frame_header, + const Image3F& opsin, ImageF& quant_field, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out) { + // TODO(szabadka): Make this work for non-opsin color spaces. + const CompressParams& cparams = enc_state->cparams; + Quantizer& quantizer = enc_state->shared.quantizer; + ImageI& raw_quant_field = enc_state->shared.raw_quant_field; + + // TODO(veluca): better choice of this value. + const float initial_quant_dc = + 16 * std::sqrt(0.1f / cparams.butteraugli_distance); + AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field), + cparams.original_butteraugli_distance, &quant_field); + + const float inv_max_err[3] = {1.0f / enc_state->cparams.max_error[0], + 1.0f / enc_state->cparams.max_error[1], + 1.0f / enc_state->cparams.max_error[2]}; + + for (int i = 0; i < kMaxButteraugliIters + 1; ++i) { + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); + if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) { + DumpXybImage(cparams, ("ops" + ToString(i)).c_str(), opsin); + } + ImageBundle decoded = + RoundtripImage(frame_header, opsin, enc_state, cms, pool); + if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) { + DumpXybImage(cparams, ("dec" + ToString(i)).c_str(), *decoded.color()); + } + for (size_t by = 0; by < enc_state->shared.frame_dim.ysize_blocks; by++) { + AcStrategyRow ac_strategy_row = + enc_state->shared.ac_strategy.ConstRow(by); + for (size_t bx = 0; bx < enc_state->shared.frame_dim.xsize_blocks; bx++) { + AcStrategy acs = ac_strategy_row[bx]; + if (!acs.IsFirstBlock()) continue; + float max_error = 0; + for (size_t c = 0; c < 3; c++) { + for (size_t y = by * kBlockDim; + y < (by + acs.covered_blocks_y()) * kBlockDim; y++) { + if (y >= decoded.ysize()) continue; + const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y); + const float* JXL_RESTRICT dec_row = + decoded.color()->ConstPlaneRow(c, y); + for (size_t x = bx * kBlockDim; + x < (bx + acs.covered_blocks_x()) * kBlockDim; x++) { + if (x >= decoded.xsize()) continue; + max_error = std::max( + std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error); + } + } + } + // Target an error between max_error/2 and max_error. + // If the error in the varblock is above the target, increase the qf to + // compensate. If the error is below the target, decrease the qf. + // However, to avoid an excessive increase of the qf, only do so if the + // error is less than half the maximum allowed error. + const float qf_mul = (max_error < 0.5f) ? max_error * 2.0f + : (max_error > 1.0f) ? max_error + : 1.0f; + for (size_t qy = by; qy < by + acs.covered_blocks_y(); qy++) { + float* JXL_RESTRICT quant_field_row = quant_field.Row(qy); + for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); qx++) { + quant_field_row[qx] *= qf_mul; + } + } + } + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); +} + +} // namespace + +void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect, + float butteraugli_target, ImageF* quant_field) { + // Replace the whole quant_field in non-8x8 blocks with the maximum of each + // 8x8 block. + size_t stride = quant_field->PixelsPerRow(); + + // At low distances it is great to use max, but mean works better + // at high distances. We interpolate between them for a distance + // range. + float mean_max_mixer = 1.0f; + { + static const float kLimit = 1.54138f; + static const float kMul = 0.56391f; + static const float kMin = 0.0f; + if (butteraugli_target > kLimit) { + mean_max_mixer -= (butteraugli_target - kLimit) * kMul; + if (mean_max_mixer < kMin) { + mean_max_mixer = kMin; + } + } + } + for (size_t y = 0; y < rect.ysize(); ++y) { + AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y); + float* JXL_RESTRICT quant_row = rect.Row(quant_field, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + AcStrategy acs = ac_strategy_row[x]; + if (!acs.IsFirstBlock()) continue; + JXL_ASSERT(x + acs.covered_blocks_x() <= quant_field->xsize()); + JXL_ASSERT(y + acs.covered_blocks_y() <= quant_field->ysize()); + float max = quant_row[x]; + float mean = 0.0; + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + mean += quant_row[x + ix + iy * stride]; + max = std::max(quant_row[x + ix + iy * stride], max); + } + } + mean /= acs.covered_blocks_y() * acs.covered_blocks_x(); + if (acs.covered_blocks_y() * acs.covered_blocks_x() >= 4) { + max *= mean_max_mixer; + max += (1.0f - mean_max_mixer) * mean; + } + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + quant_row[x + ix + iy * stride] = max; + } + } + } + } +} + +float InitialQuantDC(float butteraugli_target) { + const float kDcMul = 0.3; // Butteraugli target where non-linearity kicks in. + const float butteraugli_target_dc = std::max<float>( + 0.5f * butteraugli_target, + std::min<float>(butteraugli_target, + kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target, + kDcQuantPow))); + // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc. + // The maximum DC value might not be in the kXybRange because of inverse + // gaborish, so we add some slack to the maximum theoretical quant obtained + // this way (64). + return std::min(kDcQuant / butteraugli_target_dc, 50.f); +} + +ImageF InitialQuantField(const float butteraugli_target, const Image3F& opsin, + const Rect& rect, ThreadPool* pool, float rescale, + ImageF* mask, ImageF* mask1x1) { + const float quant_ac = kAcQuant / butteraugli_target; + return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)( + butteraugli_target, opsin, rect, quant_ac * rescale, pool, mask, mask1x1); +} + +void FindBestQuantizer(const FrameHeader& frame_header, const Image3F* linear, + const Image3F& opsin, ImageF& quant_field, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, double rescale) { + const CompressParams& cparams = enc_state->cparams; + if (cparams.max_error_mode) { + FindBestQuantizationMaxError(frame_header, opsin, quant_field, enc_state, + cms, pool, aux_out); + } else if (linear && cparams.speed_tier <= SpeedTier::kKitten) { + // Normal encoding to a butteraugli score. + FindBestQuantization(frame_header, *linear, opsin, quant_field, enc_state, + cms, pool, aux_out); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.h b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.h new file mode 100644 index 0000000000..6aa8b10df6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.h @@ -0,0 +1,56 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ADAPTIVE_QUANTIZATION_H_ +#define LIB_JXL_ENC_ADAPTIVE_QUANTIZATION_H_ + +#include <jxl/cms_interface.h> +#include <stddef.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" + +// Heuristics to find a good quantizer for a given image. InitialQuantField +// produces a quantization field (i.e. relative quantization amounts for each +// block) out of an opsin-space image. `InitialQuantField` uses heuristics, +// `FindBestQuantizer` (in non-fast mode) will run multiple encoding-decoding +// steps and try to improve the given quant field. + +namespace jxl { + +struct AuxOut; + +// Returns an image subsampled by kBlockDim in each direction. If the value +// at pixel (x,y) in the returned image is greater than 1.0, it means that +// more fine-grained quantization should be used in the corresponding block +// of the input image, while a value less than 1.0 indicates that less +// fine-grained quantization should be enough. Returns a mask, too, which +// can later be used to make better decisions about ac strategy. +ImageF InitialQuantField(float butteraugli_target, const Image3F& opsin, + const Rect& rect, ThreadPool* pool, float rescale, + ImageF* initial_quant_mask, + ImageF* initial_quant_mask1x1); + +float InitialQuantDC(float butteraugli_target); + +void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect, + float butteraugli_target, ImageF* quant_field); + +// Returns a quantizer that uses an adjusted version of the provided +// quant_field. Also computes the dequant_map corresponding to the given +// dequant_float_map and chosen quantization levels. +// `linear` is only used in Kitten mode or slower. +void FindBestQuantizer(const FrameHeader& frame_header, const Image3F* linear, + const Image3F& opsin, ImageF& quant_field, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, double rescale = 1.0); + +} // namespace jxl + +#endif // LIB_JXL_ENC_ADAPTIVE_QUANTIZATION_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_ans.cc b/third_party/jpeg-xl/lib/jxl/enc_ans.cc new file mode 100644 index 0000000000..3efa62d8e1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_ans.cc @@ -0,0 +1,1782 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_ans.h" + +#include <stdint.h> + +#include <algorithm> +#include <array> +#include <cmath> +#include <limits> +#include <numeric> +#include <type_traits> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "lib/jxl/ans_common.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/fast_math-inl.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_cluster.h" +#include "lib/jxl/enc_context_map.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/enc_huffman.h" +#include "lib/jxl/fields.h" + +namespace jxl { + +namespace { + +#if !JXL_IS_DEBUG_BUILD +constexpr +#endif + bool ans_fuzzer_friendly_ = false; + +static const int kMaxNumSymbolsForSmallCode = 4; + +void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table, + size_t alphabet_size, size_t log_alpha_size, + ANSEncSymbolInfo* info) { + size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size; + size_t entry_size_minus_1 = (1 << log_entry_size) - 1; + // create valid alias table for empty streams. + for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) { + const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s]; + info[s].freq_ = static_cast<uint16_t>(freq); +#ifdef USE_MULT_BY_RECIPROCAL + if (freq != 0) { + info[s].ifreq_ = + ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_; + } else { + info[s].ifreq_ = 1; // shouldn't matter (symbol shouldn't occur), but... + } +#endif + info[s].reverse_map_.resize(freq); + } + for (int i = 0; i < ANS_TAB_SIZE; i++) { + AliasTable::Symbol s = + AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1); + info[s.value].reverse_map_[s.offset] = i; + } +} + +float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts, + size_t len) { + float sum = 0.0f; + int total_histogram = 0; + int total_counts = 0; + for (size_t i = 0; i < len; ++i) { + total_histogram += histogram[i]; + total_counts += counts[i]; + if (histogram[i] > 0) { + JXL_ASSERT(counts[i] > 0); + // += histogram[i] * -log(counts[i]/total_counts) + sum += histogram[i] * + std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i])); + } + } + if (total_histogram > 0) { + // Used only in assert. + (void)total_counts; + JXL_ASSERT(total_counts == ANS_TAB_SIZE); + } + return sum; +} + +float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) { + const float flat_bits = std::max(FastLog2f(len), 0.0f); + float total_histogram = 0; + for (size_t i = 0; i < len; ++i) { + total_histogram += histogram[i]; + } + return total_histogram * flat_bits; +} + +// Static Huffman code for encoding logcounts. The last symbol is used as RLE +// sequence. +static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = { + 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7, +}; +static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = { + 17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65, +}; + +// Returns the difference between largest count that can be represented and is +// smaller than "count" and smallest representable count larger than "count". +static int SmallestIncrement(uint32_t count, uint32_t shift) { + int bits = count == 0 ? -1 : FloorLog2Nonzero(count); + int drop_bits = bits - GetPopulationCountPrecision(bits, shift); + return drop_bits < 0 ? 1 : (1 << drop_bits); +} + +template <bool minimize_error_of_sum> +bool RebalanceHistogram(const float* targets, int max_symbol, int table_size, + uint32_t shift, int* omit_pos, ANSHistBin* counts) { + int sum = 0; + float sum_nonrounded = 0.0; + int remainder_pos = 0; // if all of them are handled in first loop + int remainder_log = -1; + for (int n = 0; n < max_symbol; ++n) { + if (targets[n] > 0 && targets[n] < 1.0f) { + counts[n] = 1; + sum_nonrounded += targets[n]; + sum += counts[n]; + } + } + const float discount_ratio = + (table_size - sum) / (table_size - sum_nonrounded); + JXL_ASSERT(discount_ratio > 0); + JXL_ASSERT(discount_ratio <= 1.0f); + // Invariant for minimize_error_of_sum == true: + // abs(sum - sum_nonrounded) + // <= SmallestIncrement(max(targets[])) + max_symbol + for (int n = 0; n < max_symbol; ++n) { + if (targets[n] >= 1.0f) { + sum_nonrounded += targets[n]; + counts[n] = + static_cast<ANSHistBin>(targets[n] * discount_ratio); // truncate + if (counts[n] == 0) counts[n] = 1; + if (counts[n] == table_size) counts[n] = table_size - 1; + // Round the count to the closest nonzero multiple of SmallestIncrement + // (when minimize_error_of_sum is false) or one of two closest so as to + // keep the sum as close as possible to sum_nonrounded. + int inc = SmallestIncrement(counts[n], shift); + counts[n] -= counts[n] & (inc - 1); + // TODO(robryk): Should we rescale targets[n]? + const float target = + minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n]; + if (counts[n] == 0 || + (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) { + counts[n] += inc; + } + sum += counts[n]; + const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n])); + if (count_log > remainder_log) { + remainder_pos = n; + remainder_log = count_log; + } + } + } + JXL_ASSERT(remainder_pos != -1); + // NOTE: This is the only place where counts could go negative. We could + // detect that, return false and make ANSHistBin uint32_t. + counts[remainder_pos] -= sum - table_size; + *omit_pos = remainder_pos; + return counts[remainder_pos] > 0; +} + +Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length, + const int precision_bits, uint32_t shift, + int* num_symbols, int* symbols) { + const int32_t table_size = 1 << precision_bits; // target sum / table size + uint64_t total = 0; + int max_symbol = 0; + int symbol_count = 0; + for (int n = 0; n < length; ++n) { + total += counts[n]; + if (counts[n] > 0) { + if (symbol_count < kMaxNumSymbolsForSmallCode) { + symbols[symbol_count] = n; + } + ++symbol_count; + max_symbol = n + 1; + } + } + *num_symbols = symbol_count; + if (symbol_count == 0) { + return true; + } + if (symbol_count == 1) { + counts[symbols[0]] = table_size; + return true; + } + if (symbol_count > table_size) + return JXL_FAILURE("Too many entries in an ANS histogram"); + + const float norm = 1.f * table_size / total; + std::vector<float> targets(max_symbol); + for (size_t n = 0; n < targets.size(); ++n) { + targets[n] = norm * counts[n]; + } + if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift, + omit_pos, counts)) { + // Use an alternative rebalancing mechanism if the one above failed + // to create a histogram that is positive wherever the original one was. + if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift, + omit_pos, counts)) { + return JXL_FAILURE("Logic error: couldn't rebalance a histogram"); + } + } + return true; +} + +struct SizeWriter { + size_t size = 0; + void Write(size_t num, size_t bits) { size += num; } +}; + +template <typename Writer> +void StoreVarLenUint8(size_t n, Writer* writer) { + JXL_DASSERT(n <= 255); + if (n == 0) { + writer->Write(1, 0); + } else { + writer->Write(1, 1); + size_t nbits = FloorLog2Nonzero(n); + writer->Write(3, nbits); + writer->Write(nbits, n - (1ULL << nbits)); + } +} + +template <typename Writer> +void StoreVarLenUint16(size_t n, Writer* writer) { + JXL_DASSERT(n <= 65535); + if (n == 0) { + writer->Write(1, 0); + } else { + writer->Write(1, 1); + size_t nbits = FloorLog2Nonzero(n); + writer->Write(4, nbits); + writer->Write(nbits, n - (1ULL << nbits)); + } +} + +template <typename Writer> +bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size, + const int omit_pos, const int num_symbols, uint32_t shift, + const int* symbols, Writer* writer) { + bool ok = true; + if (num_symbols <= 2) { + // Small tree marker to encode 1-2 symbols. + writer->Write(1, 1); + if (num_symbols == 0) { + writer->Write(1, 0); + StoreVarLenUint8(0, writer); + } else { + writer->Write(1, num_symbols - 1); + for (int i = 0; i < num_symbols; ++i) { + StoreVarLenUint8(symbols[i], writer); + } + } + if (num_symbols == 2) { + writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]); + } + } else { + // Mark non-small tree. + writer->Write(1, 0); + // Mark non-flat histogram. + writer->Write(1, 0); + + // Precompute sequences for RLE encoding. Contains the number of identical + // values starting at a given index. Only contains the value at the first + // element of the series. + std::vector<uint32_t> same(alphabet_size, 0); + int last = 0; + for (int i = 1; i < alphabet_size; i++) { + // Store the sequence length once different symbol reached, or we're at + // the end, or the length is longer than we can encode, or we are at + // the omit_pos. We don't support including the omit_pos in an RLE + // sequence because this value may use a different amount of log2 bits + // than standard, it is too complex to handle in the decoder. + if (counts[i] != counts[last] || i + 1 == alphabet_size || + (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) { + same[last] = (i - last); + last = i + 1; + } + } + + int length = 0; + std::vector<int> logcounts(alphabet_size); + int omit_log = 0; + for (int i = 0; i < alphabet_size; ++i) { + JXL_ASSERT(counts[i] <= ANS_TAB_SIZE); + JXL_ASSERT(counts[i] >= 0); + if (i == omit_pos) { + length = i + 1; + } else if (counts[i] > 0) { + logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1; + length = i + 1; + if (i < omit_pos) { + omit_log = std::max(omit_log, logcounts[i] + 1); + } else { + omit_log = std::max(omit_log, logcounts[i]); + } + } + } + logcounts[omit_pos] = omit_log; + + // Elias gamma-like code for shift. Only difference is that if the number + // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip + // the terminating 0 in unary coding. + int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1); + int log = FloorLog2Nonzero(shift + 1); + writer->Write(log, (1 << log) - 1); + if (log != upper_bound_log) writer->Write(1, 0); + writer->Write(log, ((1 << log) - 1) & (shift + 1)); + + // Since num_symbols >= 3, we know that length >= 3, therefore we encode + // length - 3. + if (length - 3 > 255) { + // Pretend that everything is OK, but complain about correctness later. + StoreVarLenUint8(255, writer); + ok = false; + } else { + StoreVarLenUint8(length - 3, writer); + } + + // The logcount values are encoded with a static Huffman code. + static const size_t kMinReps = 4; + size_t rep = ANS_LOG_TAB_SIZE + 1; + for (int i = 0; i < length; ++i) { + if (i > 0 && same[i - 1] > kMinReps) { + // Encode the RLE symbol and skip the repeated ones. + writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]); + StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer); + i += same[i - 1] - 2; + continue; + } + writer->Write(kLogCountBitLengths[logcounts[i]], + kLogCountSymbols[logcounts[i]]); + } + for (int i = 0; i < length; ++i) { + if (i > 0 && same[i - 1] > kMinReps) { + // Skip symbols encoded by RLE. + i += same[i - 1] - 2; + continue; + } + if (logcounts[i] > 1 && i != omit_pos) { + int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift); + int drop_bits = logcounts[i] - 1 - bitcount; + JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0); + writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount)); + } + } + } + return ok; +} + +void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) { + // Mark non-small tree. + writer->Write(1, 0); + // Mark uniform histogram. + writer->Write(1, 1); + JXL_ASSERT(alphabet_size > 0); + // Encode alphabet size. + StoreVarLenUint8(alphabet_size - 1, writer); +} + +float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size, + uint32_t method) { + if (method == 0) { // Flat code + return ANS_LOG_TAB_SIZE + 2 + + EstimateDataBitsFlat(histogram, alphabet_size); + } + // Non-flat: shift = method-1. + uint32_t shift = method - 1; + std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size); + int omit_pos = 0; + int num_symbols; + int symbols[kMaxNumSymbolsForSmallCode] = {}; + JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size, + ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols)); + SizeWriter writer; + // Ignore the correctness, no real encoding happens at this stage. + (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift, + symbols, &writer); + return writer.size + + EstimateDataBits(histogram, counts.data(), alphabet_size); +} + +uint32_t ComputeBestMethod( + const ANSHistBin* histogram, size_t alphabet_size, float* cost, + HistogramParams::ANSHistogramStrategy ans_histogram_strategy) { + size_t method = 0; + float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0); + auto try_shift = [&](size_t shift) { + float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1); + if (c < fcost) { + method = shift + 1; + fcost = c; + } + }; + switch (ans_histogram_strategy) { + case HistogramParams::ANSHistogramStrategy::kPrecise: { + for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) { + try_shift(shift); + } + break; + } + case HistogramParams::ANSHistogramStrategy::kApproximate: { + for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) { + try_shift(shift); + } + break; + } + case HistogramParams::ANSHistogramStrategy::kFast: { + try_shift(0); + try_shift(ANS_LOG_TAB_SIZE / 2); + try_shift(ANS_LOG_TAB_SIZE); + break; + } + }; + *cost = fcost; + return method; +} + +} // namespace + +// Returns an estimate of the cost of encoding this histogram and the +// corresponding data. +size_t BuildAndStoreANSEncodingData( + HistogramParams::ANSHistogramStrategy ans_histogram_strategy, + const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size, + bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) { + if (use_prefix_code) { + if (alphabet_size <= 1) return 0; + std::vector<uint32_t> histo(alphabet_size); + for (size_t i = 0; i < alphabet_size; i++) { + histo[i] = histogram[i]; + JXL_CHECK(histogram[i] >= 0); + } + size_t cost = 0; + { + std::vector<uint8_t> depths(alphabet_size); + std::vector<uint16_t> bits(alphabet_size); + if (writer == nullptr) { + BitWriter tmp_writer; + BitWriter::Allotment allotment( + &tmp_writer, 8 * alphabet_size + 8); // safe upper bound + BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(), + bits.data(), &tmp_writer); + allotment.ReclaimAndCharge(&tmp_writer, 0, /*aux_out=*/nullptr); + cost = tmp_writer.BitsWritten(); + } else { + size_t start = writer->BitsWritten(); + BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(), + bits.data(), writer); + cost = writer->BitsWritten() - start; + } + for (size_t i = 0; i < alphabet_size; i++) { + info[i].bits = depths[i] == 0 ? 0 : bits[i]; + info[i].depth = depths[i]; + } + } + // Estimate data cost. + for (size_t i = 0; i < alphabet_size; i++) { + cost += histogram[i] * info[i].depth; + } + return cost; + } + JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE); + float cost; + uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost, + ans_histogram_strategy); + JXL_ASSERT(cost >= 0); + int num_symbols; + int symbols[kMaxNumSymbolsForSmallCode] = {}; + std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size); + if (!counts.empty()) { + size_t sum = 0; + for (size_t i = 0; i < counts.size(); i++) { + sum += counts[i]; + } + if (sum == 0) { + counts[0] = ANS_TAB_SIZE; + } + } + int omit_pos = 0; + uint32_t shift = method - 1; + if (method == 0) { + counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE); + } else { + JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size, + ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols)); + } + AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE]; + InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a); + ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info); + if (writer != nullptr) { + if (method == 0) { + EncodeFlatHistogram(alphabet_size, writer); + } else { + bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, + num_symbols, method - 1, symbols, writer); + (void)ok; + JXL_DASSERT(ok); + } + } + return cost; +} + +float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) { + float c; + ComputeBestMethod(data, alphabet_size, &c, + HistogramParams::ANSHistogramStrategy::kFast); + return c; +} + +template <typename Writer> +void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer, + size_t log_alpha_size) { + writer->Write(CeilLog2Nonzero(log_alpha_size + 1), + uint_config.split_exponent); + if (uint_config.split_exponent == log_alpha_size) { + return; // msb/lsb don't matter. + } + size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1); + writer->Write(nbits, uint_config.msb_in_token); + nbits = CeilLog2Nonzero(uint_config.split_exponent - + uint_config.msb_in_token + 1); + writer->Write(nbits, uint_config.lsb_in_token); +} +template <typename Writer> +void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config, + Writer* writer, size_t log_alpha_size) { + // TODO(veluca): RLE? + for (size_t i = 0; i < uint_config.size(); i++) { + EncodeUintConfig(uint_config[i], writer, log_alpha_size); + } +} +template void EncodeUintConfigs(const std::vector<HybridUintConfig>&, + BitWriter*, size_t); + +namespace { + +void ChooseUintConfigs(const HistogramParams& params, + const std::vector<std::vector<Token>>& tokens, + const std::vector<uint8_t>& context_map, + std::vector<Histogram>* clustered_histograms, + EntropyEncodingData* codes, size_t* log_alpha_size) { + codes->uint_config.resize(clustered_histograms->size()); + if (params.streaming_mode || + params.uint_method == HistogramParams::HybridUintMethod::kNone) { + return; + } + if (params.uint_method == HistogramParams::HybridUintMethod::k000) { + codes->uint_config.clear(); + codes->uint_config.resize(clustered_histograms->size(), + HybridUintConfig(0, 0, 0)); + return; + } + if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) { + codes->uint_config.clear(); + codes->uint_config.resize(clustered_histograms->size(), + HybridUintConfig(2, 0, 1)); + return; + } + + // Brute-force method that tries a few options. + std::vector<HybridUintConfig> configs; + if (params.uint_method == HistogramParams::HybridUintMethod::kBest) { + configs = { + HybridUintConfig(4, 2, 0), // default + HybridUintConfig(4, 1, 0), // less precise + HybridUintConfig(4, 2, 1), // add sign + HybridUintConfig(4, 2, 2), // add sign+parity + HybridUintConfig(4, 1, 2), // add parity but less msb + // Same as above, but more direct coding. + HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0), + HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2), + HybridUintConfig(5, 1, 2), + // Same as above, but less direct coding. + HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0), + HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2), + // For near-lossless. + HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4), + HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5), + HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0), + // Other + HybridUintConfig(0, 0, 0), // varlenuint + HybridUintConfig(2, 0, 1), // works well for ctx map + HybridUintConfig(7, 0, 0), // direct coding + HybridUintConfig(8, 0, 0), // direct coding + HybridUintConfig(9, 0, 0), // direct coding + HybridUintConfig(10, 0, 0), // direct coding + HybridUintConfig(11, 0, 0), // direct coding + HybridUintConfig(12, 0, 0), // direct coding + }; + } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) { + configs = { + HybridUintConfig(4, 2, 0), // default + HybridUintConfig(4, 1, 2), // add parity but less msb + HybridUintConfig(0, 0, 0), // smallest histograms + HybridUintConfig(2, 0, 1), // works well for ctx map + }; + } + + std::vector<float> costs(clustered_histograms->size(), + std::numeric_limits<float>::max()); + std::vector<uint32_t> extra_bits(clustered_histograms->size()); + std::vector<uint8_t> is_valid(clustered_histograms->size()); + size_t max_alpha = + codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE; + for (HybridUintConfig cfg : configs) { + std::fill(is_valid.begin(), is_valid.end(), true); + std::fill(extra_bits.begin(), extra_bits.end(), 0); + + for (size_t i = 0; i < clustered_histograms->size(); i++) { + (*clustered_histograms)[i].Clear(); + } + for (size_t i = 0; i < tokens.size(); ++i) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token token = tokens[i][j]; + // TODO(veluca): do not ignore lz77 commands. + if (token.is_lz77_length) continue; + size_t histo = context_map[token.context]; + uint32_t tok, nbits, bits; + cfg.Encode(token.value, &tok, &nbits, &bits); + if (tok >= max_alpha || + (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) { + is_valid[histo] = false; + continue; + } + extra_bits[histo] += nbits; + (*clustered_histograms)[histo].Add(tok); + } + } + + for (size_t i = 0; i < clustered_histograms->size(); i++) { + if (!is_valid[i]) continue; + float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i]; + // add signaling cost of the hybriduintconfig itself + cost += CeilLog2Nonzero(cfg.split_exponent + 1); + cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1); + if (cost < costs[i]) { + codes->uint_config[i] = cfg; + costs[i] = cost; + } + } + } + + // Rebuild histograms. + for (size_t i = 0; i < clustered_histograms->size(); i++) { + (*clustered_histograms)[i].Clear(); + } + *log_alpha_size = 4; + for (size_t i = 0; i < tokens.size(); ++i) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token token = tokens[i][j]; + uint32_t tok, nbits, bits; + size_t histo = context_map[token.context]; + (token.is_lz77_length ? codes->lz77.length_uint_config + : codes->uint_config[histo]) + .Encode(token.value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? codes->lz77.min_symbol : 0; + (*clustered_histograms)[histo].Add(tok); + while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++; + } + } +#if JXL_ENABLE_ASSERT + size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8; + JXL_ASSERT(*log_alpha_size <= max_log_alpha_size); +#endif +} + +Histogram HistogramFromSymbolInfo( + const std::vector<ANSEncSymbolInfo>& encoding_info, bool use_prefix_code) { + Histogram histo; + histo.data_.resize(DivCeil(encoding_info.size(), Histogram::kRounding) * + Histogram::kRounding); + histo.total_count_ = 0; + for (size_t i = 0; i < encoding_info.size(); ++i) { + const ANSEncSymbolInfo& info = encoding_info[i]; + int count = use_prefix_code + ? (info.depth ? (1u << (PREFIX_MAX_BITS - info.depth)) : 0) + : info.freq_; + histo.data_[i] = count; + histo.total_count_ += count; + } + return histo; +} + +class HistogramBuilder { + public: + explicit HistogramBuilder(const size_t num_contexts) + : histograms_(num_contexts) {} + + void VisitSymbol(int symbol, size_t histo_idx) { + JXL_DASSERT(histo_idx < histograms_.size()); + histograms_[histo_idx].Add(symbol); + } + + // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge. + size_t BuildAndStoreEntropyCodes( + const HistogramParams& params, + const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes, + std::vector<uint8_t>* context_map, BitWriter* writer, size_t layer, + AuxOut* aux_out) const { + const size_t prev_histograms = codes->encoding_info.size(); + size_t cost = 0; + std::vector<Histogram> clustered_histograms; + for (size_t i = 0; i < prev_histograms; ++i) { + clustered_histograms.push_back(HistogramFromSymbolInfo( + codes->encoding_info[i], codes->use_prefix_code)); + } + size_t context_offset = context_map->size(); + context_map->resize(context_offset + histograms_.size()); + if (histograms_.size() > 1) { + if (!ans_fuzzer_friendly_) { + std::vector<uint32_t> histogram_symbols; + ClusterHistograms(params, histograms_, kClustersLimit, + &clustered_histograms, &histogram_symbols); + for (size_t c = 0; c < histograms_.size(); ++c) { + (*context_map)[context_offset + c] = + static_cast<uint8_t>(histogram_symbols[c]); + } + } else { + JXL_ASSERT(codes->encoding_info.empty()); + fill(context_map->begin(), context_map->end(), 0); + size_t max_symbol = 0; + for (const Histogram& h : histograms_) { + max_symbol = std::max(h.data_.size(), max_symbol); + } + size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1); + clustered_histograms.resize(1); + clustered_histograms[0].Clear(); + for (size_t i = 0; i < num_symbols; i++) { + clustered_histograms[0].Add(i); + } + } + if (writer != nullptr) { + EncodeContextMap(*context_map, clustered_histograms.size(), writer, + layer, aux_out); + } + } else { + JXL_ASSERT(codes->encoding_info.empty()); + clustered_histograms.push_back(histograms_[0]); + } + if (aux_out != nullptr) { + for (size_t i = prev_histograms; i < clustered_histograms.size(); ++i) { + aux_out->layers[layer].clustered_entropy += + clustered_histograms[i].ShannonEntropy(); + } + } + size_t log_alpha_size = codes->lz77.enabled ? 8 : 7; // Sane default. + if (ans_fuzzer_friendly_) { + codes->uint_config.clear(); + codes->uint_config.resize(1, HybridUintConfig(7, 0, 0)); + } else { + ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms, + codes, &log_alpha_size); + } + if (log_alpha_size < 5) log_alpha_size = 5; + if (params.streaming_mode) { + // TODO(szabadka) Figure out if we can use lower values here. + log_alpha_size = 8; + } + SizeWriter size_writer; // Used if writer == nullptr to estimate costs. + cost += 1; + if (writer) writer->Write(1, codes->use_prefix_code); + + if (codes->use_prefix_code) { + log_alpha_size = PREFIX_MAX_BITS; + } else { + cost += 2; + } + if (writer == nullptr) { + EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size); + } else { + if (!codes->use_prefix_code) writer->Write(2, log_alpha_size - 5); + EncodeUintConfigs(codes->uint_config, writer, log_alpha_size); + } + if (codes->use_prefix_code) { + for (size_t c = 0; c < clustered_histograms.size(); ++c) { + size_t alphabet_size = clustered_histograms[c].alphabet_size(); + if (writer) { + StoreVarLenUint16(alphabet_size - 1, writer); + } else { + StoreVarLenUint16(alphabet_size - 1, &size_writer); + } + } + } + cost += size_writer.size; + for (size_t c = prev_histograms; c < clustered_histograms.size(); ++c) { + size_t alphabet_size = clustered_histograms[c].alphabet_size(); + codes->encoding_info.emplace_back(); + codes->encoding_info.back().resize(alphabet_size); + BitWriter* histo_writer = writer; + if (params.streaming_mode) { + codes->encoded_histograms.emplace_back(); + histo_writer = &codes->encoded_histograms.back(); + } + BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24); + cost += BuildAndStoreANSEncodingData( + params.ans_histogram_strategy, clustered_histograms[c].data_.data(), + alphabet_size, log_alpha_size, codes->use_prefix_code, + codes->encoding_info.back().data(), histo_writer); + allotment.FinishedHistogram(histo_writer); + allotment.ReclaimAndCharge(histo_writer, layer, aux_out); + if (params.streaming_mode) { + writer->AppendUnaligned(*histo_writer); + } + } + return cost; + } + + const Histogram& Histo(size_t i) const { return histograms_[i]; } + + private: + std::vector<Histogram> histograms_; +}; + +class SymbolCostEstimator { + public: + SymbolCostEstimator(size_t num_contexts, bool force_huffman, + const std::vector<std::vector<Token>>& tokens, + const LZ77Params& lz77) { + HistogramBuilder builder(num_contexts); + // Build histograms for estimating lz77 savings. + HybridUintConfig uint_config; + for (size_t i = 0; i < tokens.size(); ++i) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token token = tokens[i][j]; + uint32_t tok, nbits, bits; + (token.is_lz77_length ? lz77.length_uint_config : uint_config) + .Encode(token.value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? lz77.min_symbol : 0; + builder.VisitSymbol(tok, token.context); + } + } + max_alphabet_size_ = 0; + for (size_t i = 0; i < num_contexts; i++) { + max_alphabet_size_ = + std::max(max_alphabet_size_, builder.Histo(i).data_.size()); + } + bits_.resize(num_contexts * max_alphabet_size_); + // TODO(veluca): SIMD? + add_symbol_cost_.resize(num_contexts); + for (size_t i = 0; i < num_contexts; i++) { + float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f); + float total_cost = 0; + for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) { + size_t cnt = builder.Histo(i).data_[j]; + float cost = 0; + if (cnt != 0 && cnt != builder.Histo(i).total_count_) { + cost = -FastLog2f(cnt * inv_total); + if (force_huffman) cost = std::ceil(cost); + } else if (cnt == 0) { + cost = ANS_LOG_TAB_SIZE; // Highest possible cost. + } + bits_[i * max_alphabet_size_ + j] = cost; + total_cost += cost * builder.Histo(i).data_[j]; + } + // Penalty for adding a lz77 symbol to this contest (only used for static + // cost model). Higher penalty for contexts that have a very low + // per-symbol entropy. + add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total); + } + } + float Bits(size_t ctx, size_t sym) const { + return bits_[ctx * max_alphabet_size_ + sym]; + } + float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const { + uint32_t nbits, bits, tok; + lz77.length_uint_config.Encode(len, &tok, &nbits, &bits); + tok += lz77.min_symbol; + return nbits + Bits(ctx, tok); + } + float DistCost(size_t len, const LZ77Params& lz77) const { + uint32_t nbits, bits, tok; + HybridUintConfig().Encode(len, &tok, &nbits, &bits); + return nbits + Bits(lz77.nonserialized_distance_context, tok); + } + float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; } + + private: + size_t max_alphabet_size_; + std::vector<float> bits_; + std::vector<float> add_symbol_cost_; +}; + +void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts, + const std::vector<std::vector<Token>>& tokens, + LZ77Params& lz77, + std::vector<std::vector<Token>>& tokens_lz77) { + // TODO(veluca): tune heuristics here. + SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77); + float bit_decrease = 0; + size_t total_symbols = 0; + tokens_lz77.resize(tokens.size()); + std::vector<float> sym_cost; + HybridUintConfig uint_config; + for (size_t stream = 0; stream < tokens.size(); stream++) { + size_t distance_multiplier = + params.image_widths.size() > stream ? params.image_widths[stream] : 0; + const auto& in = tokens[stream]; + auto& out = tokens_lz77[stream]; + total_symbols += in.size(); + // Cumulative sum of bit costs. + sym_cost.resize(in.size() + 1); + for (size_t i = 0; i < in.size(); i++) { + uint32_t tok, nbits, unused_bits; + uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); + sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; + } + out.reserve(in.size()); + for (size_t i = 0; i < in.size(); i++) { + size_t num_to_copy = 0; + size_t distance_symbol = 0; // 1 for RLE. + if (distance_multiplier != 0) { + distance_symbol = 1; // Special distance 1 if enabled. + JXL_DASSERT(kSpecialDistances[1][0] == 1); + JXL_DASSERT(kSpecialDistances[1][1] == 0); + } + if (i > 0) { + for (; i + num_to_copy < in.size(); num_to_copy++) { + if (in[i + num_to_copy].value != in[i - 1].value) { + break; + } + } + } + if (num_to_copy == 0) { + out.push_back(in[i]); + continue; + } + float cost = sym_cost[i + num_to_copy] - sym_cost[i]; + // This subtraction might overflow, but that's OK. + size_t lz77_len = num_to_copy - lz77.min_length; + float lz77_cost = num_to_copy >= lz77.min_length + ? CeilLog2Nonzero(lz77_len + 1) + 1 + : 0; + if (num_to_copy < lz77.min_length || cost <= lz77_cost) { + for (size_t j = 0; j < num_to_copy; j++) { + out.push_back(in[i + j]); + } + i += num_to_copy - 1; + continue; + } + // Output the LZ77 length + out.emplace_back(in[i].context, lz77_len); + out.back().is_lz77_length = true; + i += num_to_copy - 1; + bit_decrease += cost - lz77_cost; + // Output the LZ77 copy distance. + out.emplace_back(lz77.nonserialized_distance_context, distance_symbol); + } + } + + if (bit_decrease > total_symbols * 0.2 + 16) { + lz77.enabled = true; + } +} + +// Hash chain for LZ77 matching +struct HashChain { + size_t size_; + std::vector<uint32_t> data_; + + unsigned hash_num_values_ = 32768; + unsigned hash_mask_ = hash_num_values_ - 1; + unsigned hash_shift_ = 5; + + std::vector<int> head; + std::vector<uint32_t> chain; + std::vector<int> val; + + // Speed up repetitions of zero + std::vector<int> headz; + std::vector<uint32_t> chainz; + std::vector<uint32_t> zeros; + uint32_t numzeros = 0; + + size_t window_size_; + size_t window_mask_; + size_t min_length_; + size_t max_length_; + + // Map of special distance codes. + std::unordered_map<int, int> special_dist_table_; + size_t num_special_distances_ = 0; + + uint32_t maxchainlength = 256; // window_size_ to allow all + + HashChain(const Token* data, size_t size, size_t window_size, + size_t min_length, size_t max_length, size_t distance_multiplier) + : size_(size), + window_size_(window_size), + window_mask_(window_size - 1), + min_length_(min_length), + max_length_(max_length) { + data_.resize(size); + for (size_t i = 0; i < size; i++) { + data_[i] = data[i].value; + } + + head.resize(hash_num_values_, -1); + val.resize(window_size_, -1); + chain.resize(window_size_); + for (uint32_t i = 0; i < window_size_; ++i) { + chain[i] = i; // same value as index indicates uninitialized + } + + zeros.resize(window_size_); + headz.resize(window_size_ + 1, -1); + chainz.resize(window_size_); + for (uint32_t i = 0; i < window_size_; ++i) { + chainz[i] = i; + } + // Translate distance to special distance code. + if (distance_multiplier) { + // Count down, so if due to small distance multiplier multiple distances + // map to the same code, the smallest code will be used in the end. + for (int i = kNumSpecialDistances - 1; i >= 0; --i) { + int xi = kSpecialDistances[i][0]; + int yi = kSpecialDistances[i][1]; + int distance = yi * distance_multiplier + xi; + // Ensure that we map distance 1 to the lowest symbols. + if (distance < 1) distance = 1; + special_dist_table_[distance] = i; + } + num_special_distances_ = kNumSpecialDistances; + } + } + + uint32_t GetHash(size_t pos) const { + uint32_t result = 0; + if (pos + 2 < size_) { + // TODO(lode): take the MSB's of the uint32_t values into account as well, + // given that the hash code itself is less than 32 bits. + result ^= (uint32_t)(data_[pos + 0] << 0u); + result ^= (uint32_t)(data_[pos + 1] << hash_shift_); + result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2)); + } else { + // No need to compute hash of last 2 bytes, the length 2 is too short. + return 0; + } + return result & hash_mask_; + } + + uint32_t CountZeros(size_t pos, uint32_t prevzeros) const { + size_t end = pos + window_size_; + if (end > size_) end = size_; + if (prevzeros > 0) { + if (prevzeros >= window_mask_ && data_[end - 1] == 0 && + end == pos + window_size_) { + return prevzeros; + } else { + return prevzeros - 1; + } + } + uint32_t num = 0; + while (pos + num < end && data_[pos + num] == 0) num++; + return num; + } + + void Update(size_t pos) { + uint32_t hashval = GetHash(pos); + uint32_t wpos = pos & window_mask_; + + val[wpos] = (int)hashval; + if (head[hashval] != -1) chain[wpos] = head[hashval]; + head[hashval] = wpos; + + if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0; + numzeros = CountZeros(pos, numzeros); + + zeros[wpos] = numzeros; + if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros]; + headz[numzeros] = wpos; + } + + void Update(size_t pos, size_t len) { + for (size_t i = 0; i < len; i++) { + Update(pos + i); + } + } + + template <typename CB> + void FindMatches(size_t pos, int max_dist, const CB& found_match) const { + uint32_t wpos = pos & window_mask_; + uint32_t hashval = GetHash(pos); + uint32_t hashpos = chain[wpos]; + + int prev_dist = 0; + int end = std::min<int>(pos + max_length_, size_); + uint32_t chainlength = 0; + uint32_t best_len = 0; + for (;;) { + int dist = (hashpos <= wpos) ? (wpos - hashpos) + : (wpos - hashpos + window_mask_ + 1); + if (dist < prev_dist) break; + prev_dist = dist; + uint32_t len = 0; + if (dist > 0) { + int i = pos; + int j = pos - dist; + if (numzeros > 3) { + int r = std::min<int>(numzeros - 1, zeros[hashpos]); + if (i + r >= end) r = end - i - 1; + i += r; + j += r; + } + while (i < end && data_[i] == data_[j]) { + i++; + j++; + } + len = i - pos; + // This can trigger even if the new length is slightly smaller than the + // best length, because it is possible for a slightly cheaper distance + // symbol to occur. + if (len >= min_length_ && len + 2 >= best_len) { + auto it = special_dist_table_.find(dist); + int dist_symbol = (it == special_dist_table_.end()) + ? (num_special_distances_ + dist - 1) + : it->second; + found_match(len, dist_symbol); + if (len > best_len) best_len = len; + } + } + + chainlength++; + if (chainlength >= maxchainlength) break; + + if (numzeros >= 3 && len > numzeros) { + if (hashpos == chainz[hashpos]) break; + hashpos = chainz[hashpos]; + if (zeros[hashpos] != numzeros) break; + } else { + if (hashpos == chain[hashpos]) break; + hashpos = chain[hashpos]; + if (val[hashpos] != (int)hashval) break; // outdated hash value + } + } + } + void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol, + size_t* result_len) const { + *result_dist_symbol = 0; + *result_len = 1; + FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) { + if (len > *result_len || + (len == *result_len && *result_dist_symbol > dist_symbol)) { + *result_len = len; + *result_dist_symbol = dist_symbol; + } + }); + } +}; + +float LenCost(size_t len) { + uint32_t nbits, bits, tok; + HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits); + constexpr float kCostTable[] = { + 2.797667318563126, 3.213177690381199, 2.5706009246743737, + 2.408392498667534, 2.829649191872326, 3.3923087753324577, + 4.029267451554331, 4.415576699706408, 4.509357574741465, + 9.21481543803004, 10.020590190114898, 11.858671627804766, + 12.45853300490526, 11.713105831990857, 12.561996324849314, + 13.775477692278367, 13.174027068768641, + }; + size_t table_size = sizeof kCostTable / sizeof *kCostTable; + if (tok >= table_size) tok = table_size - 1; + return kCostTable[tok] + nbits; +} + +// TODO(veluca): this does not take into account usage or non-usage of distance +// multipliers. +float DistCost(size_t dist) { + uint32_t nbits, bits, tok; + HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits); + constexpr float kCostTable[] = { + 6.368282626312716, 5.680793277090298, 8.347404197105247, + 7.641619201599141, 6.914328374119438, 7.959808291537444, + 8.70023120759855, 8.71378518934703, 9.379132523982769, + 9.110472749092708, 9.159029569270908, 9.430936766731973, + 7.278284055315169, 7.8278514904267755, 10.026641158289236, + 9.976049229827066, 9.64351607048908, 9.563403863480442, + 10.171474111762747, 10.45950155077234, 9.994813912104219, + 10.322524683741156, 8.465808729388186, 8.756254166066853, + 10.160930174662234, 10.247329273413435, 10.04090403724809, + 10.129398517544082, 9.342311691539546, 9.07608009102374, + 10.104799540677513, 10.378079384990906, 10.165828974075072, + 10.337595322341553, 7.940557464567944, 10.575665823319431, + 11.023344321751955, 10.736144698831827, 11.118277044595054, + 7.468468230648442, 10.738305230932939, 10.906980780216568, + 10.163468216353817, 10.17805759656433, 11.167283670483565, + 11.147050200274544, 10.517921919244333, 10.651764778156886, + 10.17074446448919, 11.217636876224745, 11.261630721139484, + 11.403140815247259, 10.892472096873417, 11.1859607804481, + 8.017346947551262, 7.895143720278828, 11.036577113822025, + 11.170562110315794, 10.326988722591086, 10.40872184751056, + 11.213498225466386, 11.30580635516863, 10.672272515665442, + 10.768069466228063, 11.145257364153565, 11.64668307145549, + 10.593156194627339, 11.207499484844943, 10.767517766396908, + 10.826629811407042, 10.737764794499988, 10.6200448518045, + 10.191315385198092, 8.468384171390085, 11.731295299170432, + 11.824619886654398, 10.41518844301179, 10.16310536548649, + 10.539423685097576, 10.495136599328031, 10.469112847728267, + 11.72057686174922, 10.910326337834674, 11.378921834673758, + 11.847759036098536, 11.92071647623854, 10.810628276345282, + 11.008601085273893, 11.910326337834674, 11.949212023423133, + 11.298614839104337, 11.611603659010392, 10.472930394619985, + 11.835564720850282, 11.523267392285337, 12.01055816679611, + 8.413029688994023, 11.895784139536406, 11.984679534970505, + 11.220654278717394, 11.716311684833672, 10.61036646226114, + 10.89849965960364, 10.203762898863669, 10.997560826267238, + 11.484217379438984, 11.792836176993665, 12.24310468755171, + 11.464858097919262, 12.212747017409377, 11.425595666074955, + 11.572048533398757, 12.742093965163013, 11.381874288645637, + 12.191870445817015, 11.683156920035426, 11.152442115262197, + 11.90303691580457, 11.653292787169159, 11.938615382266098, + 16.970641701570223, 16.853602280380002, 17.26240782594733, + 16.644655390108507, 17.14310889757499, 16.910935455445955, + 17.505678976959697, 17.213498225466388, 2.4162310293553024, + 3.494587244462329, 3.5258600986408344, 3.4959806589517095, + 3.098390886949687, 3.343454654302911, 3.588847442290287, + 4.14614790111827, 5.152948641990529, 7.433696808092598, + 9.716311684833672, + }; + size_t table_size = sizeof kCostTable / sizeof *kCostTable; + if (tok >= table_size) tok = table_size - 1; + return kCostTable[tok] + nbits; +} + +void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts, + const std::vector<std::vector<Token>>& tokens, + LZ77Params& lz77, + std::vector<std::vector<Token>>& tokens_lz77) { + // TODO(veluca): tune heuristics here. + SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77); + float bit_decrease = 0; + size_t total_symbols = 0; + tokens_lz77.resize(tokens.size()); + HybridUintConfig uint_config; + std::vector<float> sym_cost; + for (size_t stream = 0; stream < tokens.size(); stream++) { + size_t distance_multiplier = + params.image_widths.size() > stream ? params.image_widths[stream] : 0; + const auto& in = tokens[stream]; + auto& out = tokens_lz77[stream]; + total_symbols += in.size(); + // Cumulative sum of bit costs. + sym_cost.resize(in.size() + 1); + for (size_t i = 0; i < in.size(); i++) { + uint32_t tok, nbits, unused_bits; + uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); + sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; + } + + out.reserve(in.size()); + size_t max_distance = in.size(); + size_t min_length = lz77.min_length; + JXL_ASSERT(min_length >= 3); + size_t max_length = in.size(); + + // Use next power of two as window size. + size_t window_size = 1; + while (window_size < max_distance && window_size < kWindowSize) { + window_size <<= 1; + } + + HashChain chain(in.data(), in.size(), window_size, min_length, max_length, + distance_multiplier); + size_t len, dist_symbol; + + const size_t max_lazy_match_len = 256; // 0 to disable lazy matching + + // Whether the next symbol was already updated (to test lazy matching) + bool already_updated = false; + for (size_t i = 0; i < in.size(); i++) { + out.push_back(in[i]); + if (!already_updated) chain.Update(i); + already_updated = false; + chain.FindMatch(i, max_distance, &dist_symbol, &len); + if (len >= min_length) { + if (len < max_lazy_match_len && i + 1 < in.size()) { + // Try length at next symbol lazy matching + chain.Update(i + 1); + already_updated = true; + size_t len2, dist_symbol2; + chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2); + if (len2 > len) { + // Use the lazy match. Add literal, and use the next length starting + // from the next byte. + ++i; + already_updated = false; + len = len2; + dist_symbol = dist_symbol2; + out.push_back(in[i]); + } + } + + float cost = sym_cost[i + len] - sym_cost[i]; + size_t lz77_len = len - lz77.min_length; + float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) + + sce.AddSymbolCost(out.back().context); + + if (lz77_cost <= cost) { + out.back().value = len - min_length; + out.back().is_lz77_length = true; + out.emplace_back(lz77.nonserialized_distance_context, dist_symbol); + bit_decrease += cost - lz77_cost; + } else { + // LZ77 match ignored, and symbol already pushed. Push all other + // symbols and skip. + for (size_t j = 1; j < len; j++) { + out.push_back(in[i + j]); + } + } + + if (already_updated) { + chain.Update(i + 2, len - 2); + already_updated = false; + } else { + chain.Update(i + 1, len - 1); + } + i += len - 1; + } else { + // Literal, already pushed + } + } + } + + if (bit_decrease > total_symbols * 0.2 + 16) { + lz77.enabled = true; + } +} + +void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts, + const std::vector<std::vector<Token>>& tokens, + LZ77Params& lz77, + std::vector<std::vector<Token>>& tokens_lz77) { + std::vector<std::vector<Token>> tokens_for_cost_estimate; + ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate); + // If greedy-LZ77 does not give better compression than no-lz77, no reason to + // run the optimal matching. + if (!lz77.enabled) return; + SymbolCostEstimator sce(num_contexts + 1, params.force_huffman, + tokens_for_cost_estimate, lz77); + tokens_lz77.resize(tokens.size()); + HybridUintConfig uint_config; + std::vector<float> sym_cost; + std::vector<uint32_t> dist_symbols; + for (size_t stream = 0; stream < tokens.size(); stream++) { + size_t distance_multiplier = + params.image_widths.size() > stream ? params.image_widths[stream] : 0; + const auto& in = tokens[stream]; + auto& out = tokens_lz77[stream]; + // Cumulative sum of bit costs. + sym_cost.resize(in.size() + 1); + for (size_t i = 0; i < in.size(); i++) { + uint32_t tok, nbits, unused_bits; + uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); + sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; + } + + out.reserve(in.size()); + size_t max_distance = in.size(); + size_t min_length = lz77.min_length; + JXL_ASSERT(min_length >= 3); + size_t max_length = in.size(); + + // Use next power of two as window size. + size_t window_size = 1; + while (window_size < max_distance && window_size < kWindowSize) { + window_size <<= 1; + } + + HashChain chain(in.data(), in.size(), window_size, min_length, max_length, + distance_multiplier); + + struct MatchInfo { + uint32_t len; + uint32_t dist_symbol; + uint32_t ctx; + float total_cost = std::numeric_limits<float>::max(); + }; + // Total cost to encode the first N symbols. + std::vector<MatchInfo> prefix_costs(in.size() + 1); + prefix_costs[0].total_cost = 0; + + size_t rle_length = 0; + size_t skip_lz77 = 0; + for (size_t i = 0; i < in.size(); i++) { + chain.Update(i); + float lit_cost = + prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i]; + if (prefix_costs[i + 1].total_cost > lit_cost) { + prefix_costs[i + 1].dist_symbol = 0; + prefix_costs[i + 1].len = 1; + prefix_costs[i + 1].ctx = in[i].context; + prefix_costs[i + 1].total_cost = lit_cost; + } + if (skip_lz77 > 0) { + skip_lz77--; + continue; + } + dist_symbols.clear(); + chain.FindMatches(i, max_distance, + [&dist_symbols](size_t len, size_t dist_symbol) { + if (dist_symbols.size() <= len) { + dist_symbols.resize(len + 1, dist_symbol); + } + if (dist_symbol < dist_symbols[len]) { + dist_symbols[len] = dist_symbol; + } + }); + if (dist_symbols.size() <= min_length) continue; + { + size_t best_cost = dist_symbols.back(); + for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) { + if (dist_symbols[j] < best_cost) { + best_cost = dist_symbols[j]; + } + dist_symbols[j] = best_cost; + } + } + for (size_t j = min_length; j < dist_symbols.size(); j++) { + // Cost model that uses results from lazy LZ77. + float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) + + sce.DistCost(dist_symbols[j], lz77); + float cost = prefix_costs[i].total_cost + lz77_cost; + if (prefix_costs[i + j].total_cost > cost) { + prefix_costs[i + j].len = j; + prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1; + prefix_costs[i + j].ctx = in[i].context; + prefix_costs[i + j].total_cost = cost; + } + } + // We are in a RLE sequence: skip all the symbols except the first 8 and + // the last 8. This avoid quadratic costs for sequences with long runs of + // the same symbol. + if ((dist_symbols.back() == 0 && distance_multiplier == 0) || + (dist_symbols.back() == 1 && distance_multiplier != 0)) { + rle_length++; + } else { + rle_length = 0; + } + if (rle_length >= 8 && dist_symbols.size() > 9) { + skip_lz77 = dist_symbols.size() - 10; + rle_length = 0; + } + } + size_t pos = in.size(); + while (pos > 0) { + bool is_lz77_length = prefix_costs[pos].dist_symbol != 0; + if (is_lz77_length) { + size_t dist_symbol = prefix_costs[pos].dist_symbol - 1; + out.emplace_back(lz77.nonserialized_distance_context, dist_symbol); + } + size_t val = is_lz77_length ? prefix_costs[pos].len - min_length + : in[pos - 1].value; + out.emplace_back(prefix_costs[pos].ctx, val); + out.back().is_lz77_length = is_lz77_length; + pos -= prefix_costs[pos].len; + } + std::reverse(out.begin(), out.end()); + } +} + +void ApplyLZ77(const HistogramParams& params, size_t num_contexts, + const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77, + std::vector<std::vector<Token>>& tokens_lz77) { + if (params.initialize_global_state) { + lz77.enabled = false; + } + if (params.force_huffman) { + lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512); + } else { + lz77.min_symbol = 224; + } + if (params.lz77_method == HistogramParams::LZ77Method::kNone) { + return; + } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) { + ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77); + } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) { + ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77); + } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) { + ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77); + } else { + JXL_UNREACHABLE("Not implemented"); + } +} +} // namespace + +void EncodeHistograms(const std::vector<uint8_t>& context_map, + const EntropyEncodingData& codes, BitWriter* writer, + size_t layer, AuxOut* aux_out) { + BitWriter::Allotment allotment(writer, 128 + kClustersLimit * 136); + JXL_CHECK(Bundle::Write(codes.lz77, writer, layer, aux_out)); + if (codes.lz77.enabled) { + EncodeUintConfig(codes.lz77.length_uint_config, writer, + /*log_alpha_size=*/8); + } + EncodeContextMap(context_map, codes.encoding_info.size(), writer, layer, + aux_out); + writer->Write(1, codes.use_prefix_code); + size_t log_alpha_size = 8; + if (codes.use_prefix_code) { + log_alpha_size = PREFIX_MAX_BITS; + } else { + log_alpha_size = 8; // streaming_mode + writer->Write(2, log_alpha_size - 5); + } + EncodeUintConfigs(codes.uint_config, writer, log_alpha_size); + if (codes.use_prefix_code) { + for (const auto& info : codes.encoding_info) { + StoreVarLenUint16(info.size() - 1, writer); + } + } + for (const auto& histo_writer : codes.encoded_histograms) { + writer->AppendUnaligned(histo_writer); + } + allotment.FinishedHistogram(writer); + allotment.ReclaimAndCharge(writer, layer, aux_out); +} + +size_t BuildAndEncodeHistograms(const HistogramParams& params, + size_t num_contexts, + std::vector<std::vector<Token>>& tokens, + EntropyEncodingData* codes, + std::vector<uint8_t>* context_map, + BitWriter* writer, size_t layer, + AuxOut* aux_out) { + size_t total_bits = 0; + codes->lz77.nonserialized_distance_context = num_contexts; + std::vector<std::vector<Token>> tokens_lz77; + ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77); + if (ans_fuzzer_friendly_) { + codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0); + codes->lz77.min_symbol = 2048; + } + + const size_t max_contexts = std::min(num_contexts, kClustersLimit); + BitWriter::Allotment allotment(writer, + 128 + num_contexts * 40 + max_contexts * 96); + if (writer) { + JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out)); + } else { + size_t ebits, bits; + JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits)); + total_bits += bits; + } + if (codes->lz77.enabled) { + if (writer) { + size_t b = writer->BitsWritten(); + EncodeUintConfig(codes->lz77.length_uint_config, writer, + /*log_alpha_size=*/8); + total_bits += writer->BitsWritten() - b; + } else { + SizeWriter size_writer; + EncodeUintConfig(codes->lz77.length_uint_config, &size_writer, + /*log_alpha_size=*/8); + total_bits += size_writer.size; + } + num_contexts += 1; + tokens = std::move(tokens_lz77); + } + size_t total_tokens = 0; + // Build histograms. + HistogramBuilder builder(num_contexts); + HybridUintConfig uint_config; // Default config for clustering. + // Unless we are using the kContextMap histogram option. + if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) { + uint_config = HybridUintConfig(2, 0, 1); + } + if (params.uint_method == HistogramParams::HybridUintMethod::k000) { + uint_config = HybridUintConfig(0, 0, 0); + } + if (ans_fuzzer_friendly_) { + uint_config = HybridUintConfig(10, 0, 0); + } + for (size_t i = 0; i < tokens.size(); ++i) { + if (codes->lz77.enabled) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token& token = tokens[i][j]; + total_tokens++; + uint32_t tok, nbits, bits; + (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config) + .Encode(token.value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? codes->lz77.min_symbol : 0; + builder.VisitSymbol(tok, token.context); + } + } else if (num_contexts == 1) { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token& token = tokens[i][j]; + total_tokens++; + uint32_t tok, nbits, bits; + uint_config.Encode(token.value, &tok, &nbits, &bits); + builder.VisitSymbol(tok, /*token.context=*/0); + } + } else { + for (size_t j = 0; j < tokens[i].size(); ++j) { + const Token& token = tokens[i][j]; + total_tokens++; + uint32_t tok, nbits, bits; + uint_config.Encode(token.value, &tok, &nbits, &bits); + builder.VisitSymbol(tok, token.context); + } + } + } + + if (params.add_missing_symbols) { + for (size_t c = 0; c < num_contexts; ++c) { + for (int symbol = 0; symbol < ANS_MAX_ALPHABET_SIZE; ++symbol) { + builder.VisitSymbol(symbol, c); + } + } + } + + if (params.initialize_global_state) { + bool use_prefix_code = + params.force_huffman || total_tokens < 100 || + params.clustering == HistogramParams::ClusteringType::kFastest || + ans_fuzzer_friendly_; + if (!use_prefix_code) { + bool all_singleton = true; + for (size_t i = 0; i < num_contexts; i++) { + if (builder.Histo(i).ShannonEntropy() >= 1e-5) { + all_singleton = false; + } + } + if (all_singleton) { + use_prefix_code = true; + } + } + codes->use_prefix_code = use_prefix_code; + } + + if (params.add_fixed_histograms) { + // TODO(szabadka) Add more fixed histograms. + // TODO(szabadka) Reduce alphabet size by choosing a non-default + // uint_config. + const size_t alphabet_size = ANS_MAX_ALPHABET_SIZE; + const size_t log_alpha_size = 8; + JXL_ASSERT(alphabet_size == 1u << log_alpha_size); + std::vector<int32_t> counts = + CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE); + codes->encoding_info.emplace_back(); + codes->encoding_info.back().resize(alphabet_size); + codes->encoded_histograms.emplace_back(); + BitWriter* histo_writer = &codes->encoded_histograms.back(); + BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24); + BuildAndStoreANSEncodingData(params.ans_histogram_strategy, counts.data(), + alphabet_size, log_alpha_size, + codes->use_prefix_code, + &codes->encoding_info.back()[0], histo_writer); + allotment.ReclaimAndCharge(histo_writer, 0, nullptr); + } + + // Encode histograms. + total_bits += builder.BuildAndStoreEntropyCodes( + params, tokens, codes, context_map, writer, layer, aux_out); + allotment.FinishedHistogram(writer); + allotment.ReclaimAndCharge(writer, layer, aux_out); + + if (aux_out != nullptr) { + aux_out->layers[layer].num_clustered_histograms += + codes->encoding_info.size(); + } + return total_bits; +} + +size_t WriteTokens(const std::vector<Token>& tokens, + const EntropyEncodingData& codes, + const std::vector<uint8_t>& context_map, + size_t context_offset, BitWriter* writer) { + size_t num_extra_bits = 0; + if (codes.use_prefix_code) { + for (size_t i = 0; i < tokens.size(); i++) { + uint32_t tok, nbits, bits; + const Token& token = tokens[i]; + size_t histo = context_map[context_offset + token.context]; + (token.is_lz77_length ? codes.lz77.length_uint_config + : codes.uint_config[histo]) + .Encode(token.value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? codes.lz77.min_symbol : 0; + // Combine two calls to the BitWriter. Equivalent to: + // writer->Write(codes.encoding_info[histo][tok].depth, + // codes.encoding_info[histo][tok].bits); + // writer->Write(nbits, bits); + uint64_t data = codes.encoding_info[histo][tok].bits; + data |= bits << codes.encoding_info[histo][tok].depth; + writer->Write(codes.encoding_info[histo][tok].depth + nbits, data); + num_extra_bits += nbits; + } + return num_extra_bits; + } + std::vector<uint64_t> out; + std::vector<uint8_t> out_nbits; + out.reserve(tokens.size()); + out_nbits.reserve(tokens.size()); + uint64_t allbits = 0; + size_t numallbits = 0; + // Writes in *reversed* order. + auto addbits = [&](size_t bits, size_t nbits) { + if (JXL_UNLIKELY(nbits)) { + JXL_DASSERT(bits >> nbits == 0); + if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) { + out.push_back(allbits); + out_nbits.push_back(numallbits); + numallbits = allbits = 0; + } + allbits <<= nbits; + allbits |= bits; + numallbits += nbits; + } + }; + const int end = tokens.size(); + ANSCoder ans; + if (codes.lz77.enabled || context_map.size() > 1) { + for (int i = end - 1; i >= 0; --i) { + const Token token = tokens[i]; + const uint8_t histo = context_map[context_offset + token.context]; + uint32_t tok, nbits, bits; + (token.is_lz77_length ? codes.lz77.length_uint_config + : codes.uint_config[histo]) + .Encode(tokens[i].value, &tok, &nbits, &bits); + tok += token.is_lz77_length ? codes.lz77.min_symbol : 0; + const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok]; + JXL_DASSERT(info.freq_ > 0); + // Extra bits first as this is reversed. + addbits(bits, nbits); + num_extra_bits += nbits; + uint8_t ans_nbits = 0; + uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits); + addbits(ans_bits, ans_nbits); + } + } else { + for (int i = end - 1; i >= 0; --i) { + uint32_t tok, nbits, bits; + codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits); + const ANSEncSymbolInfo& info = codes.encoding_info[0][tok]; + // Extra bits first as this is reversed. + addbits(bits, nbits); + num_extra_bits += nbits; + uint8_t ans_nbits = 0; + uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits); + addbits(ans_bits, ans_nbits); + } + } + const uint32_t state = ans.GetState(); + writer->Write(32, state); + writer->Write(numallbits, allbits); + for (int i = out.size(); i > 0; --i) { + writer->Write(out_nbits[i - 1], out[i - 1]); + } + return num_extra_bits; +} + +void WriteTokens(const std::vector<Token>& tokens, + const EntropyEncodingData& codes, + const std::vector<uint8_t>& context_map, size_t context_offset, + BitWriter* writer, size_t layer, AuxOut* aux_out) { + BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4); + size_t num_extra_bits = + WriteTokens(tokens, codes, context_map, context_offset, writer); + allotment.ReclaimAndCharge(writer, layer, aux_out); + if (aux_out != nullptr) { + aux_out->layers[layer].extra_bits += num_extra_bits; + } +} + +void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) { +#if JXL_IS_DEBUG_BUILD // Guard against accidental / malicious changes. + ans_fuzzer_friendly_ = ans_fuzzer_friendly; +#endif +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_ans.h b/third_party/jpeg-xl/lib/jxl/enc_ans.h new file mode 100644 index 0000000000..445a5f0c9a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_ans.h @@ -0,0 +1,141 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ANS_H_ +#define LIB_JXL_ENC_ANS_H_ + +// Library to encode the ANS population counts to the bit-stream and encode +// symbols based on the respective distributions. + +#include <cstddef> +#include <cstdint> +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_ans_params.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +struct AuxOut; + +#define USE_MULT_BY_RECIPROCAL + +// precision must be equal to: #bits(state_) + #bits(freq) +#define RECIPROCAL_PRECISION (32 + ANS_LOG_TAB_SIZE) + +// Data structure representing one element of the encoding table built +// from a distribution. +// TODO(veluca): split this up, or use an union. +struct ANSEncSymbolInfo { + // ANS + uint16_t freq_; + std::vector<uint16_t> reverse_map_; +#ifdef USE_MULT_BY_RECIPROCAL + uint64_t ifreq_; +#endif + // Prefix coding. + uint8_t depth; + uint16_t bits; +}; + +class ANSCoder { + public: + ANSCoder() : state_(ANS_SIGNATURE << 16) {} + + uint32_t PutSymbol(const ANSEncSymbolInfo& t, uint8_t* nbits) { + uint32_t bits = 0; + *nbits = 0; + if ((state_ >> (32 - ANS_LOG_TAB_SIZE)) >= t.freq_) { + bits = state_ & 0xffff; + state_ >>= 16; + *nbits = 16; + } +#ifdef USE_MULT_BY_RECIPROCAL + // We use mult-by-reciprocal trick, but that requires 64b calc. + const uint32_t v = (state_ * t.ifreq_) >> RECIPROCAL_PRECISION; + const uint32_t offset = t.reverse_map_[state_ - v * t.freq_]; + state_ = (v << ANS_LOG_TAB_SIZE) + offset; +#else + state_ = ((state_ / t.freq_) << ANS_LOG_TAB_SIZE) + + t.reverse_map_[state_ % t.freq_]; +#endif + return bits; + } + + uint32_t GetState() const { return state_; } + + private: + uint32_t state_; +}; + +static const int kNumFixedHistograms = 1; + +struct EntropyEncodingData { + std::vector<std::vector<ANSEncSymbolInfo>> encoding_info; + bool use_prefix_code; + std::vector<HybridUintConfig> uint_config; + LZ77Params lz77; + std::vector<BitWriter> encoded_histograms; +}; + +// Integer to be encoded by an entropy coder, either ANS or Huffman. +struct Token { + Token() {} + Token(uint32_t c, uint32_t value) + : is_lz77_length(false), context(c), value(value) {} + uint32_t is_lz77_length : 1; + uint32_t context : 31; + uint32_t value; +}; + +// Returns an estimate of the number of bits required to encode the given +// histogram (header bits plus data bits). +float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size); + +// Writes the context map to the bitstream and concatenates the individual +// histogram bistreams in codes.encoded_histograms. Used in streaming mode. +void EncodeHistograms(const std::vector<uint8_t>& context_map, + const EntropyEncodingData& codes, BitWriter* writer, + size_t layer, AuxOut* aux_out); + +// Apply context clustering, compute histograms and encode them. Returns an +// estimate of the total bits used for encoding the stream. If `writer` == +// nullptr, the bit estimate will not take into account the context map (which +// does not get written if `num_contexts` == 1). +size_t BuildAndEncodeHistograms(const HistogramParams& params, + size_t num_contexts, + std::vector<std::vector<Token>>& tokens, + EntropyEncodingData* codes, + std::vector<uint8_t>* context_map, + BitWriter* writer, size_t layer, + AuxOut* aux_out); + +// Write the tokens to a string. +void WriteTokens(const std::vector<Token>& tokens, + const EntropyEncodingData& codes, + const std::vector<uint8_t>& context_map, size_t context_offset, + BitWriter* writer, size_t layer, AuxOut* aux_out); + +// Same as above, but assumes allotment created by caller. +size_t WriteTokens(const std::vector<Token>& tokens, + const EntropyEncodingData& codes, + const std::vector<uint8_t>& context_map, + size_t context_offset, BitWriter* writer); + +// Exposed for tests; to be used with Writer=BitWriter only. +template <typename Writer> +void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config, + Writer* writer, size_t log_alpha_size); +extern template void EncodeUintConfigs(const std::vector<HybridUintConfig>&, + BitWriter*, size_t); + +// Globally set the option to create fuzzer-friendly ANS streams. Negatively +// impacts compression. Not thread-safe. +void SetANSFuzzerFriendly(bool ans_fuzzer_friendly); +} // namespace jxl + +#endif // LIB_JXL_ENC_ANS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_ans_params.h b/third_party/jpeg-xl/lib/jxl/enc_ans_params.h new file mode 100644 index 0000000000..86664f593e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_ans_params.h @@ -0,0 +1,83 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ANS_PARAMS_H_ +#define LIB_JXL_ENC_ANS_PARAMS_H_ + +// Encoder-only parameter needed for ANS entropy encoding methods. + +#include <stdint.h> +#include <stdlib.h> + +#include "lib/jxl/enc_params.h" + +namespace jxl { + +// RebalanceHistogram requires a signed type. +using ANSHistBin = int32_t; + +struct HistogramParams { + enum class ClusteringType { + kFastest, // Only 4 clusters. + kFast, + kBest, + }; + + enum class HybridUintMethod { + kNone, // just use kHybridUint420Config. + k000, // force the fastest option. + kFast, // just try a couple of options. + kContextMap, // fast choice for ctx map. + kBest, + }; + + enum class LZ77Method { + kNone, // do not try lz77. + kRLE, // only try doing RLE. + kLZ77, // try lz77 with backward references. + kOptimal, // optimal-matching LZ77 parsing. + }; + + enum class ANSHistogramStrategy { + kFast, // Only try some methods, early exit. + kApproximate, // Only try some methods. + kPrecise, // Try all methods. + }; + + HistogramParams() = default; + + HistogramParams(SpeedTier tier, size_t num_ctx) { + if (tier > SpeedTier::kFalcon) { + clustering = ClusteringType::kFastest; + lz77_method = LZ77Method::kNone; + } else if (tier > SpeedTier::kTortoise) { + clustering = ClusteringType::kFast; + } else { + clustering = ClusteringType::kBest; + } + if (tier > SpeedTier::kTortoise) { + uint_method = HybridUintMethod::kNone; + } + if (tier >= SpeedTier::kSquirrel) { + ans_histogram_strategy = ANSHistogramStrategy::kApproximate; + } + } + + ClusteringType clustering = ClusteringType::kBest; + HybridUintMethod uint_method = HybridUintMethod::kBest; + LZ77Method lz77_method = LZ77Method::kRLE; + ANSHistogramStrategy ans_histogram_strategy = ANSHistogramStrategy::kPrecise; + std::vector<size_t> image_widths; + size_t max_histograms = ~0; + bool force_huffman = false; + bool initialize_global_state = true; + bool streaming_mode = false; + bool add_missing_symbols = false; + bool add_fixed_histograms = false; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_ANS_PARAMS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_ar_control_field.cc b/third_party/jpeg-xl/lib/jxl/enc_ar_control_field.cc new file mode 100644 index 0000000000..ed8a42d299 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_ar_control_field.cc @@ -0,0 +1,322 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_ar_control_field.h" + +#include <stdint.h> +#include <stdlib.h> + +#include <algorithm> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_ar_control_field.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/enc_adaptive_quantization.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Sqrt; + +void ProcessTile(const CompressParams& cparams, const FrameHeader& frame_header, + const Image3F& opsin, const Rect& opsin_rect, + const ImageF& quant_field, const AcStrategyImage& ac_strategy, + ImageB* epf_sharpness, const Rect& rect, + ArControlFieldHeuristics::TempImages* temp_image) { + JXL_ASSERT(opsin_rect.x0() % 8 == 0); + JXL_ASSERT(opsin_rect.y0() % 8 == 0); + JXL_ASSERT(opsin_rect.xsize() % 8 == 0); + JXL_ASSERT(opsin_rect.ysize() % 8 == 0); + constexpr size_t N = kBlockDim; + if (cparams.butteraugli_distance < kMinButteraugliForDynamicAR || + cparams.speed_tier > SpeedTier::kWombat || + frame_header.loop_filter.epf_iters == 0) { + FillPlane(static_cast<uint8_t>(4), epf_sharpness, rect); + return; + } + + // Likely better to have a higher X weight, like: + // const float kChannelWeights[3] = {47.0f, 4.35f, 0.287f}; + const float kChannelWeights[3] = {4.35f, 4.35f, 0.287f}; + const float kChannelWeightsLapNeg[3] = {-0.125f * kChannelWeights[0], + -0.125f * kChannelWeights[1], + -0.125f * kChannelWeights[2]}; + const size_t sharpness_stride = + static_cast<size_t>(epf_sharpness->PixelsPerRow()); + + size_t by0 = opsin_rect.y0() / 8 + rect.y0(); + size_t by1 = by0 + rect.ysize(); + size_t bx0 = opsin_rect.x0() / 8 + rect.x0(); + size_t bx1 = bx0 + rect.xsize(); + temp_image->InitOnce(); + ImageF& laplacian_sqrsum = temp_image->laplacian_sqrsum; + // Calculate the L2 of the 3x3 Laplacian in an integral transform + // (for example 32x32 dct). This relates to transforms ability + // to propagate artefacts. + size_t y0 = by0 == 0 ? 0 : by0 * N - 2; + size_t y1 = by1 * N == opsin.ysize() ? by1 * N : by1 * N + 2; + size_t x0 = bx0 == 0 ? 0 : bx0 * N - 2; + size_t x1 = bx1 * N == opsin.xsize() ? bx1 * N : bx1 * N + 2; + HWY_FULL(float) df; + for (size_t y = y0; y < y1; y++) { + float* JXL_RESTRICT laplacian_sqrsum_row = + laplacian_sqrsum.Row(y + 2 - by0 * N); + const float* JXL_RESTRICT in_row_t[3]; + const float* JXL_RESTRICT in_row[3]; + const float* JXL_RESTRICT in_row_b[3]; + for (size_t c = 0; c < 3; c++) { + in_row_t[c] = opsin.ConstPlaneRow(c, y > 0 ? y - 1 : y); + in_row[c] = opsin.ConstPlaneRow(c, y); + in_row_b[c] = opsin.ConstPlaneRow(c, y + 1 < opsin.ysize() ? y + 1 : y); + } + auto compute_laplacian_scalar = [&](size_t x) { + const size_t prevX = x >= 1 ? x - 1 : x; + const size_t nextX = x + 1 < opsin.xsize() ? x + 1 : x; + float sumsqr = 0; + for (size_t c = 0; c < 3; c++) { + float laplacian = + kChannelWeights[c] * in_row[c][x] + + kChannelWeightsLapNeg[c] * + (in_row[c][prevX] + in_row[c][nextX] + in_row_b[c][prevX] + + in_row_b[c][x] + in_row_b[c][nextX] + in_row_t[c][prevX] + + in_row_t[c][x] + in_row_t[c][nextX]); + sumsqr += laplacian * laplacian; + } + laplacian_sqrsum_row[x + 2 - bx0 * N] = sumsqr; + }; + size_t x = x0; + for (; x < 1; x++) { + compute_laplacian_scalar(x); + } + // Interior. One extra pixel of border as the last pixel is special. + for (; x + Lanes(df) <= x1 && x + Lanes(df) + 1 <= opsin.xsize(); + x += Lanes(df)) { + auto sumsqr = Zero(df); + for (size_t c = 0; c < 3; c++) { + auto laplacian = + Mul(LoadU(df, in_row[c] + x), Set(df, kChannelWeights[c])); + auto sum_oth0 = LoadU(df, in_row[c] + x - 1); + auto sum_oth1 = LoadU(df, in_row[c] + x + 1); + auto sum_oth2 = LoadU(df, in_row_t[c] + x - 1); + auto sum_oth3 = LoadU(df, in_row_t[c] + x); + sum_oth0 = Add(sum_oth0, LoadU(df, in_row_t[c] + x + 1)); + sum_oth1 = Add(sum_oth1, LoadU(df, in_row_b[c] + x - 1)); + sum_oth2 = Add(sum_oth2, LoadU(df, in_row_b[c] + x)); + sum_oth3 = Add(sum_oth3, LoadU(df, in_row_b[c] + x + 1)); + sum_oth0 = Add(sum_oth0, sum_oth1); + sum_oth2 = Add(sum_oth2, sum_oth3); + sum_oth0 = Add(sum_oth0, sum_oth2); + laplacian = + MulAdd(Set(df, kChannelWeightsLapNeg[c]), sum_oth0, laplacian); + sumsqr = MulAdd(laplacian, laplacian, sumsqr); + } + StoreU(sumsqr, df, laplacian_sqrsum_row + x + 2 - bx0 * N); + } + for (; x < x1; x++) { + compute_laplacian_scalar(x); + } + } + HWY_CAPPED(float, 4) df4; + // Calculate the L2 of the 3x3 Laplacian in 4x4 blocks within the area + // of the integral transform. Sample them within the integral transform + // with two offsets (0,0) and (-2, -2) pixels (sqrsum_00 and sqrsum_22, + // respectively). + ImageF& sqrsum_00 = temp_image->sqrsum_00; + size_t sqrsum_00_stride = sqrsum_00.PixelsPerRow(); + float* JXL_RESTRICT sqrsum_00_row = sqrsum_00.Row(0); + for (size_t y = 0; y < rect.ysize() * 2; y++) { + const float* JXL_RESTRICT rows_in[4]; + for (size_t iy = 0; iy < 4; iy++) { + rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy + 2); + } + float* JXL_RESTRICT row_out = sqrsum_00_row + y * sqrsum_00_stride; + for (size_t x = 0; x < rect.xsize() * 2; x++) { + auto sum = Zero(df4); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix += Lanes(df4)) { + sum = Add(sum, LoadU(df4, rows_in[iy] + x * 4 + ix + 2)); + } + } + row_out[x] = GetLane(Sqrt(SumOfLanes(df4, sum))) * (1.0f / 4.0f); + } + } + // Indexing iy and ix is a bit tricky as we include a 2 pixel border + // around the block for evenness calculations. This is similar to what + // we did in guetzli for the observability of artefacts, except there + // the element is a sliding 5x5, not sparsely sampled 4x4 box like here. + ImageF& sqrsum_22 = temp_image->sqrsum_22; + size_t sqrsum_22_stride = sqrsum_22.PixelsPerRow(); + float* JXL_RESTRICT sqrsum_22_row = sqrsum_22.Row(0); + for (size_t y = 0; y < rect.ysize() * 2 + 1; y++) { + const float* JXL_RESTRICT rows_in[4]; + for (size_t iy = 0; iy < 4; iy++) { + rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy); + } + float* JXL_RESTRICT row_out = sqrsum_22_row + y * sqrsum_22_stride; + // ignore pixels outside the image. + // Y coordinates are relative to by0*8+y*4. + size_t sy = y * 4 + by0 * 8 > 0 ? 0 : 2; + size_t ey = y * 4 + by0 * 8 + 2 <= opsin.ysize() + ? 4 + : opsin.ysize() - y * 4 - by0 * 8 + 2; + for (size_t x = 0; x < rect.xsize() * 2 + 1; x++) { + // ignore pixels outside the image. + // X coordinates are relative to bx0*8. + size_t sx = x * 4 + bx0 * 8 > 0 ? x * 4 : x * 4 + 2; + size_t ex = x * 4 + bx0 * 8 + 2 <= opsin.xsize() + ? x * 4 + 4 + : opsin.xsize() - bx0 * 8 + 2; + if (ex - sx == 4 && ey - sy == 4) { + auto sum = Zero(df4); + for (size_t iy = sy; iy < ey; iy++) { + for (size_t ix = sx; ix < ex; ix += Lanes(df4)) { + sum = Add(sum, Load(df4, rows_in[iy] + ix)); + } + } + row_out[x] = GetLane(Sqrt(SumOfLanes(df4, sum))) * (1.0f / 4.0f); + } else { + float sum = 0; + for (size_t iy = sy; iy < ey; iy++) { + for (size_t ix = sx; ix < ex; ix++) { + sum += rows_in[iy][ix]; + } + } + row_out[x] = std::sqrt(sum / ((ex - sx) * (ey - sy))); + } + } + } + for (size_t by = rect.y0(); by < rect.y1(); by++) { + AcStrategyRow acs_row = ac_strategy.ConstRow(by); + uint8_t* JXL_RESTRICT out_row = epf_sharpness->Row(by); + const float* JXL_RESTRICT quant_row = quant_field.Row(by); + for (size_t bx = rect.x0(); bx < rect.x1(); bx++) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + // The errors are going to be linear to the quantization value in this + // locality. We only have access to the initial quant field here. + float quant_val = 1.0f / quant_row[bx]; + + const auto sq00 = [&](size_t y, size_t x) { + return sqrsum_00_row[((by - rect.y0()) * 2 + y) * sqrsum_00_stride + + (bx - rect.x0()) * 2 + x]; + }; + const auto sq22 = [&](size_t y, size_t x) { + return sqrsum_22_row[((by - rect.y0()) * 2 + y) * sqrsum_22_stride + + (bx - rect.x0()) * 2 + x]; + }; + float sqrsum_integral_transform = 0; + for (size_t iy = 0; iy < acs.covered_blocks_y() * 2; iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x() * 2; ix++) { + sqrsum_integral_transform += sq00(iy, ix) * sq00(iy, ix); + } + } + sqrsum_integral_transform /= + 4 * acs.covered_blocks_x() * acs.covered_blocks_y(); + sqrsum_integral_transform = std::sqrt(sqrsum_integral_transform); + // If masking is high or amplitude of the artefacts is low, then no + // smoothing is needed. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + // Five 4x4 blocks for masking estimation, all within the + // 8x8 area. + float minval_1 = std::min(sq00(2 * iy + 0, 2 * ix + 0), + sq00(2 * iy + 0, 2 * ix + 1)); + float minval_2 = std::min(sq00(2 * iy + 1, 2 * ix + 0), + sq00(2 * iy + 1, 2 * ix + 1)); + float minval = std::min(minval_1, minval_2); + minval = std::min(minval, sq22(2 * iy + 1, 2 * ix + 1)); + // Nine more 4x4 blocks for masking estimation, includes + // the 2 pixel area around the 8x8 block being controlled. + float minval2_1 = std::min(sq22(2 * iy + 0, 2 * ix + 0), + sq22(2 * iy + 0, 2 * ix + 1)); + float minval2_2 = std::min(sq22(2 * iy + 0, 2 * ix + 2), + sq22(2 * iy + 1, 2 * ix + 0)); + float minval2_3 = std::min(sq22(2 * iy + 1, 2 * ix + 1), + sq22(2 * iy + 1, 2 * ix + 2)); + float minval2_4 = std::min(sq22(2 * iy + 2, 2 * ix + 0), + sq22(2 * iy + 2, 2 * ix + 1)); + float minval2_5 = std::min(minval2_1, minval2_2); + float minval2_6 = std::min(minval2_3, minval2_4); + float minval2 = std::min(minval2_5, minval2_6); + minval2 = std::min(minval2, sq22(2 * iy + 2, 2 * ix + 2)); + float minval3 = std::min(minval, minval2); + minval *= 0.125f; + minval += 0.625f * minval3; + minval += + 0.125f * std::min(1.5f * minval3, sq22(2 * iy + 1, 2 * ix + 1)); + minval += 0.125f * minval2; + // Larger kBias, less smoothing for low intensity changes. + float kDeltaLimit = 3.2; + float bias = 0.0625f * quant_val; + float delta = + (sqrsum_integral_transform + (kDeltaLimit + 0.05) * bias) / + (minval + bias); + int out = 4; + if (delta > kDeltaLimit) { + out = 4; // smooth + } else { + out = 0; + } + // 'threshold' is separate from 'bias' for easier tuning of these + // heuristics. + float threshold = 0.0625f * quant_val; + const float kSmoothLimit = 0.085f; + float smooth = 0.20f * (sq00(2 * iy + 0, 2 * ix + 0) + + sq00(2 * iy + 0, 2 * ix + 1) + + sq00(2 * iy + 1, 2 * ix + 0) + + sq00(2 * iy + 1, 2 * ix + 1) + minval); + if (smooth < kSmoothLimit * threshold) { + out = 4; + } + out_row[bx + sharpness_stride * iy + ix] = out; + } + } + } + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ProcessTile); + +void ArControlFieldHeuristics::RunRect( + const CompressParams& cparams, const FrameHeader& frame_header, + const Rect& block_rect, const Image3F& opsin, const Rect& opsin_rect, + const ImageF& quant_field, const AcStrategyImage& ac_strategy, + ImageB* epf_sharpness, size_t thread) { + HWY_DYNAMIC_DISPATCH(ProcessTile) + (cparams, frame_header, opsin, opsin_rect, quant_field, ac_strategy, + epf_sharpness, block_rect, &temp_images[thread]); +} + +} // namespace jxl + +#endif diff --git a/third_party/jpeg-xl/lib/jxl/enc_ar_control_field.h b/third_party/jpeg-xl/lib/jxl/enc_ar_control_field.h new file mode 100644 index 0000000000..fe602c16e3 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_ar_control_field.h @@ -0,0 +1,51 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_AR_CONTROL_FIELD_H_ +#define LIB_JXL_ENC_AR_CONTROL_FIELD_H_ + +#include <stddef.h> + +#include <vector> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" + +namespace jxl { + +struct PassesEncoderState; + +struct ArControlFieldHeuristics { + struct TempImages { + void InitOnce() { + if (laplacian_sqrsum.xsize() != 0) return; + laplacian_sqrsum = ImageF(kEncTileDim + 4, kEncTileDim + 4); + sqrsum_00 = ImageF(kEncTileDim / 4, kEncTileDim / 4); + sqrsum_22 = ImageF(kEncTileDim / 4 + 1, kEncTileDim / 4 + 1); + } + + ImageF laplacian_sqrsum; + ImageF sqrsum_00; + ImageF sqrsum_22; + }; + + void PrepareForThreads(size_t num_threads) { + temp_images.resize(num_threads); + } + + void RunRect(const CompressParams& cparams, const FrameHeader& frame_header, + const Rect& block_rect, const Image3F& opsin, + const Rect& opsin_rect, const ImageF& quant_field, + const AcStrategyImage& ac_strategy, ImageB* epf_sharpness, + size_t thread); + + std::vector<TempImages> temp_images; +}; + +} // namespace jxl + +#endif // LIB_JXL_AR_ENC_CONTROL_FIELD_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_aux_out.cc b/third_party/jpeg-xl/lib/jxl/enc_aux_out.cc new file mode 100644 index 0000000000..12c8619e91 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_aux_out.cc @@ -0,0 +1,130 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_aux_out.h" + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif + +#include <inttypes.h> +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <numeric> // accumulate +#include <sstream> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +const char* LayerName(size_t layer) { + switch (layer) { + case kLayerHeader: + return "Headers"; + case kLayerTOC: + return "TOC"; + case kLayerDictionary: + return "Patches"; + case kLayerSplines: + return "Splines"; + case kLayerNoise: + return "Noise"; + case kLayerQuant: + return "Quantizer"; + case kLayerModularTree: + return "ModularTree"; + case kLayerModularGlobal: + return "ModularGlobal"; + case kLayerDC: + return "DC"; + case kLayerModularDcGroup: + return "ModularDcGroup"; + case kLayerControlFields: + return "ControlFields"; + case kLayerOrder: + return "CoeffOrder"; + case kLayerAC: + return "ACHistograms"; + case kLayerACTokens: + return "ACTokens"; + case kLayerModularAcGroup: + return "ModularAcGroup"; + default: + JXL_UNREACHABLE("Invalid layer %d\n", static_cast<int>(layer)); + } +} + +void AuxOut::LayerTotals::Print(size_t num_inputs) const { + if (JXL_DEBUG_V_LEVEL > 0) { + printf("%10" PRId64, static_cast<int64_t>(total_bits)); + if (histogram_bits != 0) { + printf(" [c/i:%6.2f | hst:%8" PRId64 " | ex:%8" PRId64 + " | h+c+e:%12.3f", + num_clustered_histograms * 1.0 / num_inputs, + static_cast<int64_t>(histogram_bits >> 3), + static_cast<int64_t>(extra_bits >> 3), + (histogram_bits + clustered_entropy + extra_bits) / 8.0); + printf("]"); + } + printf("\n"); + } +} + +void AuxOut::Assimilate(const AuxOut& victim) { + for (size_t i = 0; i < layers.size(); ++i) { + layers[i].Assimilate(victim.layers[i]); + } + num_blocks += victim.num_blocks; + num_small_blocks += victim.num_small_blocks; + num_dct4x8_blocks += victim.num_dct4x8_blocks; + num_afv_blocks += victim.num_afv_blocks; + num_dct8_blocks += victim.num_dct8_blocks; + num_dct8x16_blocks += victim.num_dct8x16_blocks; + num_dct8x32_blocks += victim.num_dct8x32_blocks; + num_dct16_blocks += victim.num_dct16_blocks; + num_dct16x32_blocks += victim.num_dct16x32_blocks; + num_dct32_blocks += victim.num_dct32_blocks; + num_dct32x64_blocks += victim.num_dct32x64_blocks; + num_dct64_blocks += victim.num_dct64_blocks; + num_butteraugli_iters += victim.num_butteraugli_iters; +} + +void AuxOut::Print(size_t num_inputs) const { + if (JXL_DEBUG_V_LEVEL > 0) { + if (num_inputs == 0) return; + + LayerTotals all_layers; + for (size_t i = 0; i < layers.size(); ++i) { + all_layers.Assimilate(layers[i]); + } + + printf("Average butteraugli iters: %10.2f\n", + num_butteraugli_iters * 1.0 / num_inputs); + + for (size_t i = 0; i < layers.size(); ++i) { + if (layers[i].total_bits != 0) { + printf("Total layer bits %-10s\t", LayerName(i)); + printf("%10f%%", 100.0 * layers[i].total_bits / all_layers.total_bits); + layers[i].Print(num_inputs); + } + } + printf("Total image size "); + all_layers.Print(num_inputs); + + size_t total_blocks = 0; + size_t total_positions = 0; + if (total_blocks != 0 && total_positions != 0) { + printf("\n\t\t Blocks\t\tPositions\t\t\tBlocks/Position\n"); + printf(" Total:\t\t %7" PRIuS "\t\t %7" PRIuS " \t\t\t%10f%%\n\n", + total_blocks, total_positions, + 100.0 * total_blocks / total_positions); + } + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_aux_out.h b/third_party/jpeg-xl/lib/jxl/enc_aux_out.h new file mode 100644 index 0000000000..545711af83 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_aux_out.h @@ -0,0 +1,102 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_AUX_OUT_H_ +#define LIB_JXL_AUX_OUT_H_ + +// Optional output information for debugging and analyzing size usage. + +#include <stddef.h> + +#include <array> +#include <functional> +#include <string> + +namespace jxl { + +struct ColorEncoding; + +// For LayerName and AuxOut::layers[] index. Order does not matter. +enum { + kLayerHeader = 0, + kLayerTOC, + kLayerDictionary, + kLayerSplines, + kLayerNoise, + kLayerQuant, + kLayerModularTree, + kLayerModularGlobal, + kLayerDC, + kLayerModularDcGroup, + kLayerControlFields, + kLayerOrder, + kLayerAC, + kLayerACTokens, + kLayerModularAcGroup, + kNumImageLayers +}; + +const char* LayerName(size_t layer); + +// Statistics gathered during compression or decompression. +struct AuxOut { + private: + struct LayerTotals { + void Assimilate(const LayerTotals& victim) { + num_clustered_histograms += victim.num_clustered_histograms; + histogram_bits += victim.histogram_bits; + extra_bits += victim.extra_bits; + total_bits += victim.total_bits; + clustered_entropy += victim.clustered_entropy; + } + void Print(size_t num_inputs) const; + + size_t num_clustered_histograms = 0; + size_t extra_bits = 0; + + // Set via BitsWritten below + size_t histogram_bits = 0; + size_t total_bits = 0; + + double clustered_entropy = 0.0; + }; + + public: + AuxOut() = default; + AuxOut(const AuxOut&) = default; + + void Assimilate(const AuxOut& victim); + + void Print(size_t num_inputs) const; + + size_t TotalBits() const { + size_t total = 0; + for (const auto& layer : layers) { + total += layer.total_bits; + } + return total; + } + + std::array<LayerTotals, kNumImageLayers> layers; + size_t num_blocks = 0; + + // Number of blocks that use larger DCT (set by ac_strategy). + size_t num_small_blocks = 0; + size_t num_dct4x8_blocks = 0; + size_t num_afv_blocks = 0; + size_t num_dct8_blocks = 0; + size_t num_dct8x16_blocks = 0; + size_t num_dct8x32_blocks = 0; + size_t num_dct16_blocks = 0; + size_t num_dct16x32_blocks = 0; + size_t num_dct32_blocks = 0; + size_t num_dct32x64_blocks = 0; + size_t num_dct64_blocks = 0; + + int num_butteraugli_iters = 0; +}; +} // namespace jxl + +#endif // LIB_JXL_AUX_OUT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc new file mode 100644 index 0000000000..a9a86dca3b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.cc @@ -0,0 +1,215 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_bit_writer.h" + +#include <string.h> // memcpy + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_aux_out.h" + +namespace jxl { + +BitWriter::Allotment::Allotment(BitWriter* JXL_RESTRICT writer, size_t max_bits) + : max_bits_(max_bits) { + if (writer == nullptr) return; + prev_bits_written_ = writer->BitsWritten(); + const size_t prev_bytes = writer->storage_.size(); + const size_t next_bytes = DivCeil(max_bits, kBitsPerByte); + writer->storage_.resize(prev_bytes + next_bytes); + parent_ = writer->current_allotment_; + writer->current_allotment_ = this; +} + +BitWriter::Allotment::~Allotment() { + if (!called_) { + // Not calling is a bug - unused storage will not be reclaimed. + JXL_UNREACHABLE("Did not call Allotment::ReclaimUnused"); + } +} + +void BitWriter::Allotment::FinishedHistogram(BitWriter* JXL_RESTRICT writer) { + if (writer == nullptr) return; + JXL_ASSERT(!called_); // Call before ReclaimUnused + JXL_ASSERT(histogram_bits_ == 0); // Do not call twice + JXL_ASSERT(writer->BitsWritten() >= prev_bits_written_); + histogram_bits_ = writer->BitsWritten() - prev_bits_written_; +} + +void BitWriter::Allotment::ReclaimAndCharge(BitWriter* JXL_RESTRICT writer, + size_t layer, + AuxOut* JXL_RESTRICT aux_out) { + size_t used_bits = 0, unused_bits = 0; + PrivateReclaim(writer, &used_bits, &unused_bits); + +#if 0 + printf("Layer %s bits: max %" PRIuS " used %" PRIuS " unused %" PRIuS "\n", + LayerName(layer), MaxBits(), used_bits, unused_bits); +#endif + + // This may be a nested call with aux_out == null. Whenever we know that + // aux_out is null, we can call ReclaimUnused directly. + if (aux_out != nullptr) { + aux_out->layers[layer].total_bits += used_bits; + aux_out->layers[layer].histogram_bits += HistogramBits(); + } +} + +void BitWriter::Allotment::PrivateReclaim(BitWriter* JXL_RESTRICT writer, + size_t* JXL_RESTRICT used_bits, + size_t* JXL_RESTRICT unused_bits) { + JXL_ASSERT(!called_); // Do not call twice + called_ = true; + if (writer == nullptr) return; + + JXL_ASSERT(writer->BitsWritten() >= prev_bits_written_); + *used_bits = writer->BitsWritten() - prev_bits_written_; + JXL_ASSERT(*used_bits <= max_bits_); + *unused_bits = max_bits_ - *used_bits; + + // Reclaim unused bytes whole bytes from writer's allotment. + const size_t unused_bytes = *unused_bits / kBitsPerByte; // truncate + JXL_ASSERT(writer->storage_.size() >= unused_bytes); + writer->storage_.resize(writer->storage_.size() - unused_bytes); + writer->current_allotment_ = parent_; + // Ensure we don't also charge the parent for these bits. + auto parent = parent_; + while (parent != nullptr) { + parent->prev_bits_written_ += *used_bits; + parent = parent->parent_; + } +} + +void BitWriter::AppendByteAligned(const Span<const uint8_t>& span) { + if (span.empty()) return; + storage_.resize(storage_.size() + span.size() + 1); // extra zero padding + + // Concatenate by copying bytes because both source and destination are bytes. + JXL_ASSERT(BitsWritten() % kBitsPerByte == 0); + size_t pos = BitsWritten() / kBitsPerByte; + memcpy(storage_.data() + pos, span.data(), span.size()); + pos += span.size(); + storage_[pos++] = 0; // for next Write + JXL_ASSERT(pos <= storage_.size()); + bits_written_ += span.size() * kBitsPerByte; +} + +void BitWriter::AppendByteAligned(const BitWriter& other) { + JXL_ASSERT(other.BitsWritten() % kBitsPerByte == 0); + JXL_ASSERT(other.BitsWritten() / kBitsPerByte != 0); + + AppendByteAligned(other.GetSpan()); +} + +void BitWriter::AppendUnaligned(const BitWriter& other) { + Allotment allotment(this, other.BitsWritten()); + size_t full_bytes = other.BitsWritten() / kBitsPerByte; + size_t remaining_bits = other.BitsWritten() % kBitsPerByte; + for (size_t i = 0; i < full_bytes; ++i) { + Write(8, other.storage_[i]); + } + if (remaining_bits > 0) { + Write(remaining_bits, + other.storage_[full_bytes] & ((1u << remaining_bits) - 1)); + } + allotment.ReclaimAndCharge(this, 0, nullptr); +} + +void BitWriter::AppendByteAligned(const std::vector<BitWriter>& others) { + // Total size to add so we can preallocate + size_t other_bytes = 0; + for (const BitWriter& writer : others) { + JXL_ASSERT(writer.BitsWritten() % kBitsPerByte == 0); + other_bytes += writer.BitsWritten() / kBitsPerByte; + } + if (other_bytes == 0) { + // No bytes to append: this happens for example when creating per-group + // storage for groups, but not writing anything in them for e.g. lossless + // images with no alpha. Do nothing. + return; + } + storage_.resize(storage_.size() + other_bytes + 1); // extra zero padding + + // Concatenate by copying bytes because both source and destination are bytes. + JXL_ASSERT(BitsWritten() % kBitsPerByte == 0); + size_t pos = BitsWritten() / kBitsPerByte; + for (const BitWriter& writer : others) { + const Span<const uint8_t> span = writer.GetSpan(); + if (!span.empty()) { + memcpy(storage_.data() + pos, span.data(), span.size()); + pos += span.size(); + } + } + storage_[pos++] = 0; // for next Write + JXL_ASSERT(pos <= storage_.size()); + bits_written_ += other_bytes * kBitsPerByte; +} + +// TODO(lode): avoid code duplication +void BitWriter::AppendByteAligned( + const std::vector<std::unique_ptr<BitWriter>>& others) { + // Total size to add so we can preallocate + size_t other_bytes = 0; + for (const auto& writer : others) { + JXL_ASSERT(writer->BitsWritten() % kBitsPerByte == 0); + other_bytes += writer->BitsWritten() / kBitsPerByte; + } + if (other_bytes == 0) { + // No bytes to append: this happens for example when creating per-group + // storage for groups, but not writing anything in them for e.g. lossless + // images with no alpha. Do nothing. + return; + } + storage_.resize(storage_.size() + other_bytes + 1); // extra zero padding + + // Concatenate by copying bytes because both source and destination are bytes. + JXL_ASSERT(BitsWritten() % kBitsPerByte == 0); + size_t pos = BitsWritten() / kBitsPerByte; + for (const auto& writer : others) { + const Span<const uint8_t> span = writer->GetSpan(); + memcpy(storage_.data() + pos, span.data(), span.size()); + pos += span.size(); + } + storage_[pos++] = 0; // for next Write + JXL_ASSERT(pos <= storage_.size()); + bits_written_ += other_bytes * kBitsPerByte; +} + +// Example: let's assume that 3 bits (Rs below) have been written already: +// BYTE+0 BYTE+1 BYTE+2 +// 0000 0RRR ???? ???? ???? ???? +// +// Now, we could write up to 5 bits by just shifting them left by 3 bits and +// OR'ing to BYTE-0. +// +// For n > 5 bits, we write the lowest 5 bits as above, then write the next +// lowest bits into BYTE+1 starting from its lower bits and so on. +void BitWriter::Write(size_t n_bits, uint64_t bits) { + JXL_DASSERT((bits >> n_bits) == 0); + JXL_DASSERT(n_bits <= kMaxBitsPerCall); + uint8_t* p = &storage_[bits_written_ / kBitsPerByte]; + const size_t bits_in_first_byte = bits_written_ % kBitsPerByte; + bits <<= bits_in_first_byte; +#if JXL_BYTE_ORDER_LITTLE + uint64_t v = *p; + // Last (partial) or next byte to write must be zero-initialized! + // PaddedBytes initializes the first, and Write/Append maintain this. + JXL_DASSERT(v >> bits_in_first_byte == 0); + v |= bits; + memcpy(p, &v, sizeof(v)); // Write bytes: possibly more than n_bits/8 +#else + *p++ |= static_cast<uint8_t>(bits & 0xFF); + for (size_t bits_left_to_write = n_bits + bits_in_first_byte; + bits_left_to_write >= 9; bits_left_to_write -= 8) { + bits >>= 8; + *p++ = static_cast<uint8_t>(bits & 0xFF); + } + *p = 0; +#endif + bits_written_ += n_bits; +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_bit_writer.h b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.h new file mode 100644 index 0000000000..6f4865077d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_bit_writer.h @@ -0,0 +1,129 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_BIT_WRITER_H_ +#define LIB_JXL_ENC_BIT_WRITER_H_ + +// BitWriter class: unbuffered writes using unaligned 64-bit stores. + +#include <stddef.h> +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/padded_bytes.h" + +namespace jxl { + +struct AuxOut; + +struct BitWriter { + // Upper bound on `n_bits` in each call to Write. We shift a 64-bit word by + // 7 bits (max already valid bits in the last byte) and at least 1 bit is + // needed to zero-initialize the bit-stream ahead (i.e. if 7 bits are valid + // and we write 57 bits, then the next write will access a byte that was not + // yet zero-initialized). + static constexpr size_t kMaxBitsPerCall = 56; + + BitWriter() : bits_written_(0) {} + + // Disallow copying - may lead to bugs. + BitWriter(const BitWriter&) = delete; + BitWriter& operator=(const BitWriter&) = delete; + BitWriter(BitWriter&&) = default; + BitWriter& operator=(BitWriter&&) = default; + + size_t BitsWritten() const { return bits_written_; } + + Span<const uint8_t> GetSpan() const { + // Callers must ensure byte alignment to avoid uninitialized bits. + JXL_ASSERT(bits_written_ % kBitsPerByte == 0); + return Bytes(storage_.data(), bits_written_ / kBitsPerByte); + } + + // Example usage: bytes = std::move(writer).TakeBytes(); Useful for the + // top-level encoder which returns PaddedBytes, not a BitWriter. + // *this must be an rvalue reference and is invalid afterwards. + PaddedBytes&& TakeBytes() && { + // Callers must ensure byte alignment to avoid uninitialized bits. + JXL_ASSERT(bits_written_ % kBitsPerByte == 0); + storage_.resize(bits_written_ / kBitsPerByte); + return std::move(storage_); + } + + // Must be byte-aligned before calling. + void AppendByteAligned(const Span<const uint8_t>& span); + + // NOTE: no allotment needed, the other BitWriters have already been charged. + void AppendByteAligned(const BitWriter& other); + void AppendByteAligned(const std::vector<std::unique_ptr<BitWriter>>& others); + void AppendByteAligned(const std::vector<BitWriter>& others); + + void AppendUnaligned(const BitWriter& other); + + class Allotment { + public: + // Expands a BitWriter's storage. Must happen before calling Write or + // ZeroPadToByte. Must call ReclaimUnused after writing to reclaim the + // unused storage so that BitWriter memory use remains tightly bounded. + Allotment(BitWriter* JXL_RESTRICT writer, size_t max_bits); + ~Allotment(); + + size_t MaxBits() const { return max_bits_; } + + // Call after writing a histogram, but before ReclaimUnused. + void FinishedHistogram(BitWriter* JXL_RESTRICT writer); + + size_t HistogramBits() const { + JXL_ASSERT(called_); + return histogram_bits_; + } + + void ReclaimAndCharge(BitWriter* JXL_RESTRICT writer, size_t layer, + AuxOut* JXL_RESTRICT aux_out); + + private: + void PrivateReclaim(BitWriter* JXL_RESTRICT writer, + size_t* JXL_RESTRICT used_bits, + size_t* JXL_RESTRICT unused_bits); + + size_t prev_bits_written_; + const size_t max_bits_; + size_t histogram_bits_ = 0; + bool called_ = false; + Allotment* parent_; + }; + + // Writes bits into bytes in increasing addresses, and within a byte + // least-significant-bit first. + // + // The function can write up to 56 bits in one go. + void Write(size_t n_bits, uint64_t bits); + + // This should only rarely be used - e.g. when the current location will be + // referenced via byte offset (TOCs point to groups), or byte-aligned reading + // is required for speed. + void ZeroPadToByte() { + const size_t remainder_bits = + RoundUpBitsToByteMultiple(bits_written_) - bits_written_; + if (remainder_bits == 0) return; + Write(remainder_bits, 0); + JXL_ASSERT(bits_written_ % kBitsPerByte == 0); + } + + private: + size_t bits_written_; + PaddedBytes storage_; + Allotment* current_allotment_ = nullptr; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_BIT_WRITER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.cc b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.cc new file mode 100644 index 0000000000..019d6125a2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.cc @@ -0,0 +1,107 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_butteraugli_comparator.h" + +#include <algorithm> +#include <vector> + +#include "lib/jxl/enc_image_bundle.h" + +namespace jxl { + +JxlButteraugliComparator::JxlButteraugliComparator( + const ButteraugliParams& params, const JxlCmsInterface& cms) + : params_(params), cms_(cms) {} + +Status JxlButteraugliComparator::SetReferenceImage(const ImageBundle& ref) { + const ImageBundle* ref_linear_srgb; + ImageMetadata metadata = *ref.metadata(); + ImageBundle store(&metadata); + if (!TransformIfNeeded(ref, ColorEncoding::LinearSRGB(ref.IsGray()), cms_, + /*pool=*/nullptr, &store, &ref_linear_srgb)) { + return false; + } + + comparator_.reset( + new ButteraugliComparator(ref_linear_srgb->color(), params_)); + xsize_ = ref.xsize(); + ysize_ = ref.ysize(); + return true; +} + +Status JxlButteraugliComparator::SetLinearReferenceImage( + const Image3F& linear) { + comparator_.reset(new ButteraugliComparator(linear, params_)); + xsize_ = linear.xsize(); + ysize_ = linear.ysize(); + return true; +} + +Status JxlButteraugliComparator::CompareWith(const ImageBundle& actual, + ImageF* diffmap, float* score) { + if (!comparator_) { + return JXL_FAILURE("Must set reference image first"); + } + if (xsize_ != actual.xsize() || ysize_ != actual.ysize()) { + return JXL_FAILURE("Images must have same size"); + } + + const ImageBundle* actual_linear_srgb; + ImageMetadata metadata = *actual.metadata(); + ImageBundle store(&metadata); + if (!TransformIfNeeded(actual, ColorEncoding::LinearSRGB(actual.IsGray()), + cms_, + /*pool=*/nullptr, &store, &actual_linear_srgb)) { + return false; + } + + ImageF temp_diffmap(xsize_, ysize_); + comparator_->Diffmap(actual_linear_srgb->color(), temp_diffmap); + + if (score != nullptr) { + *score = ButteraugliScoreFromDiffmap(temp_diffmap, ¶ms_); + } + if (diffmap != nullptr) { + diffmap->Swap(temp_diffmap); + } + + return true; +} + +float JxlButteraugliComparator::GoodQualityScore() const { + return ButteraugliFuzzyInverse(1.5); +} + +float JxlButteraugliComparator::BadQualityScore() const { + return ButteraugliFuzzyInverse(0.5); +} + +float ButteraugliDistance(const ImageBundle& rgb0, const ImageBundle& rgb1, + const ButteraugliParams& params, + const JxlCmsInterface& cms, ImageF* distmap, + ThreadPool* pool, bool ignore_alpha) { + JxlButteraugliComparator comparator(params, cms); + return ComputeScore(rgb0, rgb1, &comparator, cms, distmap, pool, + ignore_alpha); +} + +float ButteraugliDistance(const std::vector<ImageBundle>& frames0, + const std::vector<ImageBundle>& frames1, + const ButteraugliParams& params, + const JxlCmsInterface& cms, ImageF* distmap, + ThreadPool* pool) { + JxlButteraugliComparator comparator(params, cms); + JXL_ASSERT(frames0.size() == frames1.size()); + float max_dist = 0.0f; + for (size_t i = 0; i < frames0.size(); ++i) { + max_dist = std::max( + max_dist, + ComputeScore(frames0[i], frames1[i], &comparator, cms, distmap, pool)); + } + return max_dist; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.h b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.h new file mode 100644 index 0000000000..641d7732d5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_butteraugli_comparator.h @@ -0,0 +1,62 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_BUTTERAUGLI_COMPARATOR_H_ +#define LIB_JXL_ENC_BUTTERAUGLI_COMPARATOR_H_ + +#include <jxl/cms_interface.h> +#include <stddef.h> + +#include <memory> +#include <vector> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/enc_comparator.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +class JxlButteraugliComparator : public Comparator { + public: + explicit JxlButteraugliComparator(const ButteraugliParams& params, + const JxlCmsInterface& cms); + + Status SetReferenceImage(const ImageBundle& ref) override; + Status SetLinearReferenceImage(const Image3F& linear); + + Status CompareWith(const ImageBundle& actual, ImageF* diffmap, + float* score) override; + + float GoodQualityScore() const override; + float BadQualityScore() const override; + + private: + ButteraugliParams params_; + JxlCmsInterface cms_; + std::unique_ptr<ButteraugliComparator> comparator_; + size_t xsize_ = 0; + size_t ysize_ = 0; +}; + +// Returns the butteraugli distance between rgb0 and rgb1. +// If distmap is not null, it must be the same size as rgb0 and rgb1. +float ButteraugliDistance(const ImageBundle& rgb0, const ImageBundle& rgb1, + const ButteraugliParams& params, + const JxlCmsInterface& cms, ImageF* distmap = nullptr, + ThreadPool* pool = nullptr, + bool ignore_alpha = false); + +float ButteraugliDistance(const std::vector<ImageBundle>& frames0, + const std::vector<ImageBundle>& frames1, + const ButteraugliParams& params, + const JxlCmsInterface& cms, ImageF* distmap = nullptr, + ThreadPool* pool = nullptr); + +} // namespace jxl + +#endif // LIB_JXL_ENC_BUTTERAUGLI_COMPARATOR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_cache.cc b/third_party/jpeg-xl/lib/jxl/enc_cache.cc new file mode 100644 index 0000000000..ff62c57e4d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_cache.cc @@ -0,0 +1,214 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_cache.h" + +#include <stddef.h> +#include <stdint.h> + +#include <type_traits> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/enc_group.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +Status InitializePassesEncoder(const FrameHeader& frame_header, + const Image3F& opsin, const Rect& rect, + const JxlCmsInterface& cms, ThreadPool* pool, + PassesEncoderState* enc_state, + ModularFrameEncoder* modular_frame_encoder, + AuxOut* aux_out) { + PassesSharedState& JXL_RESTRICT shared = enc_state->shared; + + enc_state->x_qm_multiplier = std::pow(1.25f, frame_header.x_qm_scale - 2.0f); + enc_state->b_qm_multiplier = std::pow(1.25f, frame_header.b_qm_scale - 2.0f); + + if (enc_state->coeffs.size() < frame_header.passes.num_passes) { + enc_state->coeffs.reserve(frame_header.passes.num_passes); + for (size_t i = enc_state->coeffs.size(); + i < frame_header.passes.num_passes; i++) { + // Allocate enough coefficients for each group on every row. + enc_state->coeffs.emplace_back(make_unique<ACImageT<int32_t>>( + kGroupDim * kGroupDim, shared.frame_dim.num_groups)); + } + } + while (enc_state->coeffs.size() > frame_header.passes.num_passes) { + enc_state->coeffs.pop_back(); + } + + if (enc_state->initialize_global_state) { + float scale = + shared.quantizer.ScaleGlobalScale(enc_state->cparams.quant_ac_rescale); + DequantMatricesScaleDC(&shared.matrices, scale); + shared.quantizer.RecomputeFromGlobalScale(); + } + + Image3F dc(shared.frame_dim.xsize_blocks, shared.frame_dim.ysize_blocks); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, shared.frame_dim.num_groups, ThreadPool::NoInit, + [&](size_t group_idx, size_t _) { + ComputeCoefficients(group_idx, enc_state, opsin, rect, &dc); + }, + "Compute coeffs")); + + if (frame_header.flags & FrameHeader::kUseDcFrame) { + CompressParams cparams = enc_state->cparams; + cparams.dots = Override::kOff; + cparams.noise = Override::kOff; + cparams.patches = Override::kOff; + cparams.gaborish = Override::kOff; + cparams.epf = 0; + cparams.resampling = 1; + cparams.ec_resampling = 1; + // The DC frame will have alpha=0. Don't erase its contents. + cparams.keep_invisible = Override::kOn; + JXL_ASSERT(cparams.progressive_dc > 0); + cparams.progressive_dc--; + // Use kVarDCT in max_error_mode for intermediate progressive DC, + // and kModular for the smallest DC (first in the bitstream) + if (cparams.progressive_dc == 0) { + cparams.modular_mode = true; + cparams.speed_tier = + SpeedTier(std::max(static_cast<int>(SpeedTier::kTortoise), + static_cast<int>(cparams.speed_tier) - 1)); + cparams.butteraugli_distance = + std::max(kMinButteraugliDistance, + enc_state->cparams.butteraugli_distance * 0.02f); + } else { + cparams.max_error_mode = true; + for (size_t c = 0; c < 3; c++) { + cparams.max_error[c] = shared.quantizer.MulDC()[c]; + } + // Guess a distance that produces good initial results. + cparams.butteraugli_distance = + std::max(kMinButteraugliDistance, + enc_state->cparams.butteraugli_distance * 0.1f); + } + ImageBundle ib(&shared.metadata->m); + // This is a lie - dc is in XYB + // (but EncodeFrame will skip RGB->XYB conversion anyway) + ib.SetFromImage( + std::move(dc), + ColorEncoding::LinearSRGB(shared.metadata->m.color_encoding.IsGray())); + if (!ib.metadata()->extra_channel_info.empty()) { + // Add placeholder extra channels to the patch image: dc_level frames do + // not yet support extra channels, but the codec expects that the amount + // of extra channels in frames matches that in the metadata of the + // codestream. + std::vector<ImageF> extra_channels; + extra_channels.reserve(ib.metadata()->extra_channel_info.size()); + for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) { + extra_channels.emplace_back(ib.xsize(), ib.ysize()); + // Must initialize the image with data to not affect blending with + // uninitialized memory. + // TODO(lode): dc_level must copy and use the real extra channels + // instead. + ZeroFillImage(&extra_channels.back()); + } + ib.SetExtraChannels(std::move(extra_channels)); + } + auto special_frame = std::unique_ptr<BitWriter>(new BitWriter()); + FrameInfo dc_frame_info; + dc_frame_info.frame_type = FrameType::kDCFrame; + dc_frame_info.dc_level = frame_header.dc_level + 1; + dc_frame_info.ib_needs_color_transform = false; + dc_frame_info.save_before_color_transform = true; // Implicitly true + AuxOut dc_aux_out; + JXL_CHECK(EncodeFrame(cparams, dc_frame_info, shared.metadata, ib, cms, + pool, special_frame.get(), + aux_out ? &dc_aux_out : nullptr)); + if (aux_out) { + for (const auto& l : dc_aux_out.layers) { + aux_out->layers[kLayerDC].Assimilate(l); + } + } + const Span<const uint8_t> encoded = special_frame->GetSpan(); + enc_state->special_frames.emplace_back(std::move(special_frame)); + + ImageBundle decoded(&shared.metadata->m); + std::unique_ptr<PassesDecoderState> dec_state = + jxl::make_unique<PassesDecoderState>(); + JXL_CHECK( + dec_state->output_encoding_info.SetFromMetadata(*shared.metadata)); + const uint8_t* frame_start = encoded.data(); + size_t encoded_size = encoded.size(); + for (int i = 0; i <= cparams.progressive_dc; ++i) { + JXL_CHECK(DecodeFrame(dec_state.get(), pool, frame_start, encoded_size, + /*frame_header=*/nullptr, &decoded, + *shared.metadata)); + frame_start += decoded.decoded_bytes(); + encoded_size -= decoded.decoded_bytes(); + } + // TODO(lode): frame_header.dc_level should be equal to + // dec_state.frame_header.dc_level - 1 here, since above we set + // dc_frame_info.dc_level = frame_header.dc_level + 1, and + // dc_frame_info.dc_level is used by EncodeFrame. However, if EncodeFrame + // outputs multiple frames, this assumption could be wrong. + const Image3F& dc_frame = + dec_state->shared->dc_frames[frame_header.dc_level]; + shared.dc_storage = Image3F(dc_frame.xsize(), dc_frame.ysize()); + CopyImageTo(dc_frame, &shared.dc_storage); + ZeroFillImage(&shared.quant_dc); + shared.dc = &shared.dc_storage; + JXL_CHECK(encoded_size == 0); + } else { + auto compute_dc_coeffs = [&](int group_index, int /* thread */) { + const Rect r = enc_state->shared.frame_dim.DCGroupRect(group_index); + int modular_group_index = group_index; + if (enc_state->streaming_mode) { + JXL_ASSERT(group_index == 0); + modular_group_index = enc_state->dc_group_index; + } + modular_frame_encoder->AddVarDCTDC( + frame_header, dc, r, modular_group_index, + enc_state->cparams.speed_tier < SpeedTier::kFalcon, enc_state, + /*jpeg_transcode=*/false); + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, shared.frame_dim.num_dc_groups, + ThreadPool::NoInit, compute_dc_coeffs, + "Compute DC coeffs")); + // TODO(veluca): this is only useful in tests and if inspection is enabled. + if (!(frame_header.flags & FrameHeader::kSkipAdaptiveDCSmoothing)) { + AdaptiveDCSmoothing(shared.quantizer.MulDC(), &shared.dc_storage, pool); + } + } + auto compute_ac_meta = [&](int group_index, int /* thread */) { + const Rect r = enc_state->shared.frame_dim.DCGroupRect(group_index); + int modular_group_index = group_index; + if (enc_state->streaming_mode) { + JXL_ASSERT(group_index == 0); + modular_group_index = enc_state->dc_group_index; + } + modular_frame_encoder->AddACMetadata(r, modular_group_index, + /*jpeg_transcode=*/false, enc_state); + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, shared.frame_dim.num_dc_groups, + ThreadPool::NoInit, compute_ac_meta, + "Compute AC Metadata")); + + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_cache.h b/third_party/jpeg-xl/lib/jxl/enc_cache.h new file mode 100644 index 0000000000..6efcc081c1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_cache.h @@ -0,0 +1,81 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_CACHE_H_ +#define LIB_JXL_ENC_CACHE_H_ + +#include <jxl/cms_interface.h> +#include <stddef.h> +#include <stdint.h> + +#include <memory> +#include <vector> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_progressive_split.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +struct AuxOut; + +// Contains encoder state. +struct PassesEncoderState { + PassesSharedState shared; + + bool streaming_mode = false; + bool initialize_global_state = true; + size_t dc_group_index = 0; + + // Per-pass DCT coefficients for the image. One row per group. + std::vector<std::unique_ptr<ACImage>> coeffs; + + // Raw data for special (reference+DC) frames. + std::vector<std::unique_ptr<BitWriter>> special_frames; + + // For splitting into passes. + ProgressiveSplitter progressive_splitter; + + CompressParams cparams; + + struct PassData { + std::vector<std::vector<Token>> ac_tokens; + std::vector<uint8_t> context_map; + EntropyEncodingData codes; + }; + + std::vector<PassData> passes; + std::vector<uint8_t> histogram_idx; + + // Block sizes seen so far. + uint32_t used_acs = 0; + // Coefficient orders that are non-default. + std::vector<uint32_t> used_orders; + + // Multiplier to be applied to the quant matrices of the x channel. + float x_qm_multiplier = 1.0f; + float b_qm_multiplier = 1.0f; +}; + +// Initialize per-frame information. +class ModularFrameEncoder; +Status InitializePassesEncoder(const FrameHeader& frame_header, + const Image3F& opsin, const Rect& rect, + const JxlCmsInterface& cms, ThreadPool* pool, + PassesEncoderState* passes_enc_state, + ModularFrameEncoder* modular_frame_encoder, + AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_CACHE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.cc b/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.cc new file mode 100644 index 0000000000..9a894d89cc --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.cc @@ -0,0 +1,403 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_chroma_from_luma.h" + +#include <float.h> +#include <stdlib.h> + +#include <algorithm> +#include <array> +#include <cmath> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_chroma_from_luma.cc" +#include <hwy/aligned_allocator.h> +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/dec_transforms-inl.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/simd_util.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Ge; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Lt; + +static HWY_FULL(float) df; + +struct CFLFunction { + static constexpr float kCoeff = 1.f / 3; + static constexpr float kThres = 100.0f; + static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; + CFLFunction(const float* values_m, const float* values_s, size_t num, + float base, float distance_mul) + : values_m(values_m), + values_s(values_s), + num(num), + base(base), + distance_mul(distance_mul) {} + + // Returns f'(x), where f is 1/3 * sum ((|color residual| + 1)^2-1) + + // distance_mul * x^2 * num. + float Compute(float x, float eps, float* fpeps, float* fmeps) const { + float first_derivative = 2 * distance_mul * num * x; + float first_derivative_peps = 2 * distance_mul * num * (x + eps); + float first_derivative_meps = 2 * distance_mul * num * (x - eps); + + const auto inv_color_factor = Set(df, kInvColorFactor); + const auto thres = Set(df, kThres); + const auto coeffx2 = Set(df, kCoeff * 2.0f); + const auto one = Set(df, 1.0f); + const auto zero = Set(df, 0.0f); + const auto base_v = Set(df, base); + const auto x_v = Set(df, x); + const auto xpe_v = Set(df, x + eps); + const auto xme_v = Set(df, x - eps); + auto fd_v = Zero(df); + auto fdpe_v = Zero(df); + auto fdme_v = Zero(df); + JXL_ASSERT(num % Lanes(df) == 0); + + for (size_t i = 0; i < num; i += Lanes(df)) { + // color residual = ax + b + const auto a = Mul(inv_color_factor, Load(df, values_m + i)); + const auto b = + Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i)); + const auto v = MulAdd(a, x_v, b); + const auto vpe = MulAdd(a, xpe_v, b); + const auto vme = MulAdd(a, xme_v, b); + const auto av = Abs(v); + const auto avpe = Abs(vpe); + const auto avme = Abs(vme); + const auto acoeffx2 = Mul(coeffx2, a); + auto d = Mul(acoeffx2, Add(av, one)); + auto dpe = Mul(acoeffx2, Add(avpe, one)); + auto dme = Mul(acoeffx2, Add(avme, one)); + d = IfThenElse(Lt(v, zero), Sub(zero, d), d); + dpe = IfThenElse(Lt(vpe, zero), Sub(zero, dpe), dpe); + dme = IfThenElse(Lt(vme, zero), Sub(zero, dme), dme); + const auto above = Ge(av, thres); + // TODO(eustas): use IfThenElseZero + fd_v = Add(fd_v, IfThenElse(above, zero, d)); + fdpe_v = Add(fdpe_v, IfThenElse(above, zero, dpe)); + fdme_v = Add(fdme_v, IfThenElse(above, zero, dme)); + } + + *fpeps = first_derivative_peps + GetLane(SumOfLanes(df, fdpe_v)); + *fmeps = first_derivative_meps + GetLane(SumOfLanes(df, fdme_v)); + return first_derivative + GetLane(SumOfLanes(df, fd_v)); + } + + const float* JXL_RESTRICT values_m; + const float* JXL_RESTRICT values_s; + size_t num; + float base; + float distance_mul; +}; + +// Chroma-from-luma search, values_m will have luma -- and values_s chroma. +int32_t FindBestMultiplier(const float* values_m, const float* values_s, + size_t num, float base, float distance_mul, + bool fast) { + if (num == 0) { + return 0; + } + float x; + if (fast) { + static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; + auto ca = Zero(df); + auto cb = Zero(df); + const auto inv_color_factor = Set(df, kInvColorFactor); + const auto base_v = Set(df, base); + for (size_t i = 0; i < num; i += Lanes(df)) { + // color residual = ax + b + const auto a = Mul(inv_color_factor, Load(df, values_m + i)); + const auto b = + Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i)); + ca = MulAdd(a, a, ca); + cb = MulAdd(a, b, cb); + } + // + distance_mul * x^2 * num + x = -GetLane(SumOfLanes(df, cb)) / + (GetLane(SumOfLanes(df, ca)) + num * distance_mul * 0.5f); + } else { + constexpr float eps = 100; + constexpr float kClamp = 20.0f; + CFLFunction fn(values_m, values_s, num, base, distance_mul); + x = 0; + // Up to 20 Newton iterations, with approximate derivatives. + // Derivatives are approximate due to the high amount of noise in the exact + // derivatives. + for (size_t i = 0; i < 20; i++) { + float dfpeps, dfmeps; + float df = fn.Compute(x, eps, &dfpeps, &dfmeps); + float ddf = (dfpeps - dfmeps) / (2 * eps); + float kExperimentalInsignificantStabilizer = 0.85; + float step = df / (ddf + kExperimentalInsignificantStabilizer); + x -= std::min(kClamp, std::max(-kClamp, step)); + if (std::abs(step) < 3e-3) break; + } + } + // CFL seems to be tricky for larger transforms for HF components + // close to zero. This heuristic brings the solutions closer to zero + // and reduces red-green oscillations. A better approach would + // look into variance of the multiplier within separate (e.g. 8x8) + // areas and only apply this heuristic where there is a high variance. + // This would give about 1 % more compression density. + float towards_zero = 2.6; + if (x >= towards_zero) { + x -= towards_zero; + } else if (x <= -towards_zero) { + x += towards_zero; + } else { + x = 0; + } + return std::max(-128.0f, std::min(127.0f, roundf(x))); +} + +void InitDCStorage(size_t num_blocks, ImageF* dc_values) { + // First row: Y channel + // Second row: X channel + // Third row: Y channel + // Fourth row: B channel + *dc_values = ImageF(RoundUpTo(num_blocks, Lanes(df)), 4); + + JXL_ASSERT(dc_values->xsize() != 0); + // Zero-fill the last lanes + for (size_t y = 0; y < 4; y++) { + for (size_t x = dc_values->xsize() - Lanes(df); x < dc_values->xsize(); + x++) { + dc_values->Row(y)[x] = 0; + } + } +} + +void ComputeTile(const Image3F& opsin, const Rect& opsin_rect, + const DequantMatrices& dequant, + const AcStrategyImage* ac_strategy, + const ImageI* raw_quant_field, const Quantizer* quantizer, + const Rect& rect, bool fast, bool use_dct8, ImageSB* map_x, + ImageSB* map_b, ImageF* dc_values, float* mem) { + static_assert(kEncTileDimInBlocks == kColorTileDimInBlocks, + "Invalid color tile dim"); + size_t xsize_blocks = opsin_rect.xsize() / kBlockDim; + constexpr float kDistanceMultiplierAC = 1e-9f; + const size_t dct_scratch_size = + 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; + + const size_t y0 = rect.y0(); + const size_t x0 = rect.x0(); + const size_t x1 = rect.x0() + rect.xsize(); + const size_t y1 = rect.y0() + rect.ysize(); + + int ty = y0 / kColorTileDimInBlocks; + int tx = x0 / kColorTileDimInBlocks; + + int8_t* JXL_RESTRICT row_out_x = map_x->Row(ty); + int8_t* JXL_RESTRICT row_out_b = map_b->Row(ty); + + float* JXL_RESTRICT dc_values_yx = dc_values->Row(0); + float* JXL_RESTRICT dc_values_x = dc_values->Row(1); + float* JXL_RESTRICT dc_values_yb = dc_values->Row(2); + float* JXL_RESTRICT dc_values_b = dc_values->Row(3); + + // All are aligned. + float* HWY_RESTRICT block_y = mem; + float* HWY_RESTRICT block_x = block_y + AcStrategy::kMaxCoeffArea; + float* HWY_RESTRICT block_b = block_x + AcStrategy::kMaxCoeffArea; + float* HWY_RESTRICT coeffs_yx = block_b + AcStrategy::kMaxCoeffArea; + float* HWY_RESTRICT coeffs_x = coeffs_yx + kColorTileDim * kColorTileDim; + float* HWY_RESTRICT coeffs_yb = coeffs_x + kColorTileDim * kColorTileDim; + float* HWY_RESTRICT coeffs_b = coeffs_yb + kColorTileDim * kColorTileDim; + float* HWY_RESTRICT scratch_space = coeffs_b + kColorTileDim * kColorTileDim; + float* scratch_space_end = + scratch_space + 2 * AcStrategy::kMaxCoeffArea + dct_scratch_size; + JXL_DASSERT(scratch_space_end == block_y + CfLHeuristics::ItemsPerThread()); + (void)scratch_space_end; + + // Small (~256 bytes each) + HWY_ALIGN_MAX float + dc_y[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; + HWY_ALIGN_MAX float + dc_x[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; + HWY_ALIGN_MAX float + dc_b[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; + size_t num_ac = 0; + + for (size_t y = y0; y < y1; ++y) { + const float* JXL_RESTRICT row_y = + opsin_rect.ConstPlaneRow(opsin, 1, y * kBlockDim); + const float* JXL_RESTRICT row_x = + opsin_rect.ConstPlaneRow(opsin, 0, y * kBlockDim); + const float* JXL_RESTRICT row_b = + opsin_rect.ConstPlaneRow(opsin, 2, y * kBlockDim); + size_t stride = opsin.PixelsPerRow(); + + for (size_t x = x0; x < x1; x++) { + AcStrategy acs = use_dct8 + ? AcStrategy::FromRawStrategy(AcStrategy::Type::DCT) + : ac_strategy->ConstRow(y)[x]; + if (!acs.IsFirstBlock()) continue; + size_t xs = acs.covered_blocks_x(); + TransformFromPixels(acs.Strategy(), row_y + x * kBlockDim, stride, + block_y, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), block_y, dc_y, xs); + TransformFromPixels(acs.Strategy(), row_x + x * kBlockDim, stride, + block_x, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), block_x, dc_x, xs); + TransformFromPixels(acs.Strategy(), row_b + x * kBlockDim, stride, + block_b, scratch_space); + DCFromLowestFrequencies(acs.Strategy(), block_b, dc_b, xs); + const float* const JXL_RESTRICT qm_x = + dequant.InvMatrix(acs.Strategy(), 0); + const float* const JXL_RESTRICT qm_b = + dequant.InvMatrix(acs.Strategy(), 2); + float q_dc_x = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(0); + float q_dc_b = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(2); + + // Copy DCs in dc_values. + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < xs; ix++) { + dc_values_yx[(iy + y) * xsize_blocks + ix + x] = + dc_y[iy * xs + ix] * q_dc_x; + dc_values_x[(iy + y) * xsize_blocks + ix + x] = + dc_x[iy * xs + ix] * q_dc_x; + dc_values_yb[(iy + y) * xsize_blocks + ix + x] = + dc_y[iy * xs + ix] * q_dc_b; + dc_values_b[(iy + y) * xsize_blocks + ix + x] = + dc_b[iy * xs + ix] * q_dc_b; + } + } + + // Do not use this block for computing AC CfL. + if (acs.covered_blocks_x() + x0 > x1 || + acs.covered_blocks_y() + y0 > y1) { + continue; + } + + // Copy AC coefficients in the local block. The order in which + // coefficients get stored does not matter. + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + CoefficientLayout(&cy, &cx); + // Zero out LFs. This introduces terms in the optimization loop that + // don't affect the result, as they are all 0, but allow for simpler + // SIMDfication. + for (size_t iy = 0; iy < cy; iy++) { + for (size_t ix = 0; ix < cx; ix++) { + block_y[cx * kBlockDim * iy + ix] = 0; + block_x[cx * kBlockDim * iy + ix] = 0; + block_b[cx * kBlockDim * iy + ix] = 0; + } + } + // Unclear why this is like it is. (This works slightly better + // than the previous approach which was also a hack.) + const float qq = + (raw_quant_field == nullptr) ? 1.0f : raw_quant_field->Row(y)[x]; + // Experimentally values 128-130 seem best -- I don't know why we + // need this multiplier. + const float kStrangeMultiplier = 128; + float q = use_dct8 ? 1 : quantizer->Scale() * kStrangeMultiplier * qq; + const auto qv = Set(df, q); + for (size_t i = 0; i < cx * cy * 64; i += Lanes(df)) { + const auto b_y = Load(df, block_y + i); + const auto b_x = Load(df, block_x + i); + const auto b_b = Load(df, block_b + i); + const auto qqm_x = Mul(qv, Load(df, qm_x + i)); + const auto qqm_b = Mul(qv, Load(df, qm_b + i)); + Store(Mul(b_y, qqm_x), df, coeffs_yx + num_ac); + Store(Mul(b_x, qqm_x), df, coeffs_x + num_ac); + Store(Mul(b_y, qqm_b), df, coeffs_yb + num_ac); + Store(Mul(b_b, qqm_b), df, coeffs_b + num_ac); + num_ac += Lanes(df); + } + } + } + JXL_CHECK(num_ac % Lanes(df) == 0); + row_out_x[tx] = FindBestMultiplier(coeffs_yx, coeffs_x, num_ac, 0.0f, + kDistanceMultiplierAC, fast); + row_out_b[tx] = + FindBestMultiplier(coeffs_yb, coeffs_b, num_ac, jxl::cms::kYToBRatio, + kDistanceMultiplierAC, fast); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(InitDCStorage); +HWY_EXPORT(ComputeTile); + +void CfLHeuristics::Init(const Rect& rect) { + size_t xsize_blocks = rect.xsize() / kBlockDim; + size_t ysize_blocks = rect.ysize() / kBlockDim; + HWY_DYNAMIC_DISPATCH(InitDCStorage) + (xsize_blocks * ysize_blocks, &dc_values); +} + +void CfLHeuristics::ComputeTile(const Rect& r, const Image3F& opsin, + const Rect& opsin_rect, + const DequantMatrices& dequant, + const AcStrategyImage* ac_strategy, + const ImageI* raw_quant_field, + const Quantizer* quantizer, bool fast, + size_t thread, ColorCorrelationMap* cmap) { + bool use_dct8 = ac_strategy == nullptr; + HWY_DYNAMIC_DISPATCH(ComputeTile) + (opsin, opsin_rect, dequant, ac_strategy, raw_quant_field, quantizer, r, fast, + use_dct8, &cmap->ytox_map, &cmap->ytob_map, &dc_values, + mem.get() + thread * ItemsPerThread()); +} + +void ColorCorrelationMapEncodeDC(const ColorCorrelationMap& map, + BitWriter* writer, size_t layer, + AuxOut* aux_out) { + float color_factor = map.GetColorFactor(); + float base_correlation_x = map.GetBaseCorrelationX(); + float base_correlation_b = map.GetBaseCorrelationB(); + int32_t ytox_dc = map.GetYToXDC(); + int32_t ytob_dc = map.GetYToBDC(); + + BitWriter::Allotment allotment(writer, 1 + 2 * kBitsPerByte + 12 + 32); + if (ytox_dc == 0 && ytob_dc == 0 && color_factor == kDefaultColorFactor && + base_correlation_x == 0.0f && + base_correlation_b == jxl::cms::kYToBRatio) { + writer->Write(1, 1); + allotment.ReclaimAndCharge(writer, layer, aux_out); + return; + } + writer->Write(1, 0); + JXL_CHECK(U32Coder::Write(kColorFactorDist, color_factor, writer)); + JXL_CHECK(F16Coder::Write(base_correlation_x, writer)); + JXL_CHECK(F16Coder::Write(base_correlation_b, writer)); + writer->Write(kBitsPerByte, ytox_dc - std::numeric_limits<int8_t>::min()); + writer->Write(kBitsPerByte, ytob_dc - std::numeric_limits<int8_t>::min()); + allotment.ReclaimAndCharge(writer, layer, aux_out); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.h b/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.h new file mode 100644 index 0000000000..04743842bf --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_chroma_from_luma.h @@ -0,0 +1,60 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_CHROMA_FROM_LUMA_H_ +#define LIB_JXL_ENC_CHROMA_FROM_LUMA_H_ + +// Chroma-from-luma, computed using heuristics to determine the best linear +// model for the X and B channels from the Y channel. + +#include <cstddef> +#include <hwy/aligned_allocator.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/simd_util.h" + +namespace jxl { + +struct AuxOut; +class Quantizer; + +void ColorCorrelationMapEncodeDC(const ColorCorrelationMap& map, + BitWriter* writer, size_t layer, + AuxOut* aux_out); + +struct CfLHeuristics { + void Init(const Rect& rect); + + void PrepareForThreads(size_t num_threads) { + mem = hwy::AllocateAligned<float>(num_threads * ItemsPerThread()); + } + + void ComputeTile(const Rect& r, const Image3F& opsin, const Rect& opsin_rect, + const DequantMatrices& dequant, + const AcStrategyImage* ac_strategy, + const ImageI* raw_quant_field, const Quantizer* quantizer, + bool fast, size_t thread, ColorCorrelationMap* cmap); + + ImageF dc_values; + hwy::AlignedFreeUniquePtr<float[]> mem; + + // Working set is too large for stack; allocate dynamically. + static size_t ItemsPerThread() { + const size_t dct_scratch_size = + 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; + return AcStrategy::kMaxCoeffArea * 3 // Blocks + + kColorTileDim * kColorTileDim * 4 // AC coeff storage + + AcStrategy::kMaxCoeffArea * 2 // Scratch space + + dct_scratch_size; + } +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_CHROMA_FROM_LUMA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_cluster.cc b/third_party/jpeg-xl/lib/jxl/enc_cluster.cc new file mode 100644 index 0000000000..df1b31ddf7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_cluster.cc @@ -0,0 +1,352 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_cluster.h" + +#include <algorithm> +#include <cmath> +#include <limits> +#include <map> +#include <memory> +#include <numeric> +#include <queue> +#include <tuple> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/base/fast_math-inl.h" +#include "lib/jxl/enc_ans.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::IfThenZeroElse; + +template <class V> +V Entropy(V count, V inv_total, V total) { + const HWY_CAPPED(float, Histogram::kRounding) d; + const auto zero = Set(d, 0.0f); + // TODO(eustas): why (0 - x) instead of Neg(x)? + return IfThenZeroElse( + Eq(count, total), + Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count))))); +} + +void HistogramEntropy(const Histogram& a) { + a.entropy_ = 0.0f; + if (a.total_count_ == 0) return; + + const HWY_CAPPED(float, Histogram::kRounding) df; + const HWY_CAPPED(int32_t, Histogram::kRounding) di; + + const auto inv_tot = Set(df, 1.0f / a.total_count_); + auto entropy_lanes = Zero(df); + auto total = Set(df, a.total_count_); + + for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) { + const auto counts = LoadU(di, &a.data_[i]); + entropy_lanes = + Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total)); + } + a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes)); +} + +float HistogramDistance(const Histogram& a, const Histogram& b) { + if (a.total_count_ == 0 || b.total_count_ == 0) return 0; + + const HWY_CAPPED(float, Histogram::kRounding) df; + const HWY_CAPPED(int32_t, Histogram::kRounding) di; + + const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_)); + auto distance_lanes = Zero(df); + auto total = Set(df, a.total_count_ + b.total_count_); + + for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size()); + i += Lanes(di)) { + const auto a_counts = + a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di); + const auto b_counts = + b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di); + const auto counts = ConvertTo(df, Add(a_counts, b_counts)); + distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total)); + } + const float total_distance = GetLane(SumOfLanes(df, distance_lanes)); + return total_distance - a.entropy_ - b.entropy_; +} + +constexpr const float kInfinity = std::numeric_limits<float>::infinity(); + +float HistogramKLDivergence(const Histogram& actual, const Histogram& coding) { + if (actual.total_count_ == 0) return 0; + if (coding.total_count_ == 0) return kInfinity; + + const HWY_CAPPED(float, Histogram::kRounding) df; + const HWY_CAPPED(int32_t, Histogram::kRounding) di; + + const auto coding_inv = Set(df, 1.0f / coding.total_count_); + auto cost_lanes = Zero(df); + + for (size_t i = 0; i < actual.data_.size(); i += Lanes(di)) { + const auto counts = LoadU(di, &actual.data_[i]); + const auto coding_counts = + coding.data_.size() > i ? LoadU(di, &coding.data_[i]) : Zero(di); + const auto coding_probs = Mul(ConvertTo(df, coding_counts), coding_inv); + const auto neg_coding_cost = BitCast( + df, + IfThenZeroElse(Eq(counts, Zero(di)), + IfThenElse(Eq(coding_counts, Zero(di)), + BitCast(di, Set(df, -kInfinity)), + BitCast(di, FastLog2f(df, coding_probs))))); + cost_lanes = NegMulAdd(ConvertTo(df, counts), neg_coding_cost, cost_lanes); + } + const float total_cost = GetLane(SumOfLanes(df, cost_lanes)); + return total_cost - actual.entropy_; +} + +// First step of a k-means clustering with a fancy distance metric. +void FastClusterHistograms(const std::vector<Histogram>& in, + size_t max_histograms, std::vector<Histogram>* out, + std::vector<uint32_t>* histogram_symbols) { + const size_t prev_histograms = out->size(); + out->reserve(max_histograms); + histogram_symbols->clear(); + histogram_symbols->resize(in.size(), max_histograms); + + std::vector<float> dists(in.size(), std::numeric_limits<float>::max()); + size_t largest_idx = 0; + for (size_t i = 0; i < in.size(); i++) { + if (in[i].total_count_ == 0) { + (*histogram_symbols)[i] = 0; + dists[i] = 0.0f; + continue; + } + HistogramEntropy(in[i]); + if (in[i].total_count_ > in[largest_idx].total_count_) { + largest_idx = i; + } + } + + if (prev_histograms > 0) { + for (size_t j = 0; j < prev_histograms; ++j) { + HistogramEntropy((*out)[j]); + } + for (size_t i = 0; i < in.size(); i++) { + if (dists[i] == 0.0f) continue; + for (size_t j = 0; j < prev_histograms; ++j) { + dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]); + } + } + auto max_dist = std::max_element(dists.begin(), dists.end()); + if (*max_dist > 0.0f) { + largest_idx = max_dist - dists.begin(); + } + } + + constexpr float kMinDistanceForDistinct = 48.0f; + while (out->size() < max_histograms) { + (*histogram_symbols)[largest_idx] = out->size(); + out->push_back(in[largest_idx]); + dists[largest_idx] = 0.0f; + largest_idx = 0; + for (size_t i = 0; i < in.size(); i++) { + if (dists[i] == 0.0f) continue; + dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]); + if (dists[i] > dists[largest_idx]) largest_idx = i; + } + if (dists[largest_idx] < kMinDistanceForDistinct) break; + } + + for (size_t i = 0; i < in.size(); i++) { + if ((*histogram_symbols)[i] != max_histograms) continue; + size_t best = 0; + float best_dist = std::numeric_limits<float>::max(); + for (size_t j = 0; j < out->size(); j++) { + float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j]) + : HistogramDistance(in[i], (*out)[j]); + if (dist < best_dist) { + best = j; + best_dist = dist; + } + } + JXL_ASSERT(best_dist < std::numeric_limits<float>::max()); + if (best >= prev_histograms) { + (*out)[best].AddHistogram(in[i]); + HistogramEntropy((*out)[best]); + } + (*histogram_symbols)[i] = best; + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(FastClusterHistograms); // Local function +HWY_EXPORT(HistogramEntropy); // Local function + +float Histogram::PopulationCost() const { + return ANSPopulationCost(data_.data(), data_.size()); +} + +float Histogram::ShannonEntropy() const { + HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this); + return entropy_; +} + +namespace { +// ----------------------------------------------------------------------------- +// Histogram refinement + +// Reorder histograms in *out so that the new symbols in *symbols come in +// increasing order. +void HistogramReindex(std::vector<Histogram>* out, size_t prev_histograms, + std::vector<uint32_t>* symbols) { + std::vector<Histogram> tmp(*out); + std::map<int, int> new_index; + for (size_t i = 0; i < prev_histograms; ++i) { + new_index[i] = i; + } + int next_index = prev_histograms; + for (uint32_t symbol : *symbols) { + if (new_index.find(symbol) == new_index.end()) { + new_index[symbol] = next_index; + (*out)[next_index] = tmp[symbol]; + ++next_index; + } + } + out->resize(next_index); + for (uint32_t& symbol : *symbols) { + symbol = new_index[symbol]; + } +} + +} // namespace + +// Clusters similar histograms in 'in' together, the selected histograms are +// placed in 'out', and for each index in 'in', *histogram_symbols will +// indicate which of the 'out' histograms is the best approximation. +void ClusterHistograms(const HistogramParams params, + const std::vector<Histogram>& in, size_t max_histograms, + std::vector<Histogram>* out, + std::vector<uint32_t>* histogram_symbols) { + size_t prev_histograms = out->size(); + max_histograms = std::min(max_histograms, params.max_histograms); + max_histograms = std::min(max_histograms, in.size()); + if (params.clustering == HistogramParams::ClusteringType::kFastest) { + max_histograms = std::min(max_histograms, static_cast<size_t>(4)); + } + + HWY_DYNAMIC_DISPATCH(FastClusterHistograms) + (in, prev_histograms + max_histograms, out, histogram_symbols); + + if (prev_histograms == 0 && + params.clustering == HistogramParams::ClusteringType::kBest) { + for (size_t i = 0; i < out->size(); i++) { + (*out)[i].entropy_ = + ANSPopulationCost((*out)[i].data_.data(), (*out)[i].data_.size()); + } + uint32_t next_version = 2; + std::vector<uint32_t> version(out->size(), 1); + std::vector<uint32_t> renumbering(out->size()); + std::iota(renumbering.begin(), renumbering.end(), 0); + + // Try to pair up clusters if doing so reduces the total cost. + + struct HistogramPair { + // validity of a pair: p.version == max(version[i], version[j]) + float cost; + uint32_t first; + uint32_t second; + uint32_t version; + // We use > because priority queues sort in *decreasing* order, but we + // want lower cost elements to appear first. + bool operator<(const HistogramPair& other) const { + return std::make_tuple(cost, first, second, version) > + std::make_tuple(other.cost, other.first, other.second, + other.version); + } + }; + + // Create list of all pairs by increasing merging cost. + std::priority_queue<HistogramPair> pairs_to_merge; + for (uint32_t i = 0; i < out->size(); i++) { + for (uint32_t j = i + 1; j < out->size(); j++) { + Histogram histo; + histo.AddHistogram((*out)[i]); + histo.AddHistogram((*out)[j]); + float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) - + (*out)[i].entropy_ - (*out)[j].entropy_; + // Avoid enqueueing pairs that are not advantageous to merge. + if (cost >= 0) continue; + pairs_to_merge.push( + HistogramPair{cost, i, j, std::max(version[i], version[j])}); + } + } + + // Merge the best pair to merge, add new pairs that get formed as a + // consequence. + while (!pairs_to_merge.empty()) { + uint32_t first = pairs_to_merge.top().first; + uint32_t second = pairs_to_merge.top().second; + uint32_t ver = pairs_to_merge.top().version; + pairs_to_merge.pop(); + if (ver != std::max(version[first], version[second]) || + version[first] == 0 || version[second] == 0) { + continue; + } + (*out)[first].AddHistogram((*out)[second]); + (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(), + (*out)[first].data_.size()); + for (size_t i = 0; i < renumbering.size(); i++) { + if (renumbering[i] == second) { + renumbering[i] = first; + } + } + version[second] = 0; + version[first] = next_version++; + for (uint32_t j = 0; j < out->size(); j++) { + if (j == first) continue; + if (version[j] == 0) continue; + Histogram histo; + histo.AddHistogram((*out)[first]); + histo.AddHistogram((*out)[j]); + float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) - + (*out)[first].entropy_ - (*out)[j].entropy_; + // Avoid enqueueing pairs that are not advantageous to merge. + if (cost >= 0) continue; + pairs_to_merge.push( + HistogramPair{cost, std::min(first, j), std::max(first, j), + std::max(version[first], version[j])}); + } + } + std::vector<uint32_t> reverse_renumbering(out->size(), -1); + size_t num_alive = 0; + for (size_t i = 0; i < out->size(); i++) { + if (version[i] == 0) continue; + (*out)[num_alive++] = (*out)[i]; + reverse_renumbering[i] = num_alive - 1; + } + out->resize(num_alive); + for (size_t i = 0; i < histogram_symbols->size(); i++) { + (*histogram_symbols)[i] = + reverse_renumbering[renumbering[(*histogram_symbols)[i]]]; + } + } + + // Convert the context map to a canonical form. + HistogramReindex(out, prev_histograms, histogram_symbols); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_cluster.h b/third_party/jpeg-xl/lib/jxl/enc_cluster.h new file mode 100644 index 0000000000..923aaaccfe --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_cluster.h @@ -0,0 +1,69 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Functions for clustering similar histograms together. + +#ifndef LIB_JXL_ENC_CLUSTER_H_ +#define LIB_JXL_ENC_CLUSTER_H_ + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/enc_ans_params.h" + +namespace jxl { + +struct Histogram { + Histogram() { + total_count_ = 0; + entropy_ = 0.0; + } + void Clear() { + data_.clear(); + total_count_ = 0; + } + void Add(size_t symbol) { + if (data_.size() <= symbol) { + data_.resize(DivCeil(symbol + 1, kRounding) * kRounding); + } + ++data_[symbol]; + ++total_count_; + } + void AddHistogram(const Histogram& other) { + if (other.data_.size() > data_.size()) { + data_.resize(other.data_.size()); + } + for (size_t i = 0; i < other.data_.size(); ++i) { + data_[i] += other.data_[i]; + } + total_count_ += other.total_count_; + } + size_t alphabet_size() const { + for (int i = data_.size() - 1; i >= 0; --i) { + if (data_[i] > 0) { + return i + 1; + } + } + return 1; + } + float PopulationCost() const; + float ShannonEntropy() const; + + std::vector<ANSHistBin> data_; + size_t total_count_; + mutable float entropy_; // WARNING: not kept up-to-date. + static constexpr size_t kRounding = 8; +}; + +void ClusterHistograms(HistogramParams params, const std::vector<Histogram>& in, + size_t max_histograms, std::vector<Histogram>* out, + std::vector<uint32_t>* histogram_symbols); +} // namespace jxl + +#endif // LIB_JXL_ENC_CLUSTER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_coeff_order.cc b/third_party/jpeg-xl/lib/jxl/enc_coeff_order.cc new file mode 100644 index 0000000000..5be012aa92 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_coeff_order.cc @@ -0,0 +1,297 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <stdint.h> + +#include <algorithm> +#include <hwy/aligned_allocator.h> +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/lehmer_code.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +struct AuxOut; + +std::pair<uint32_t, uint32_t> ComputeUsedOrders( + const SpeedTier speed, const AcStrategyImage& ac_strategy, + const Rect& rect) { + // No coefficient reordering in Falcon or faster. + // Only uses DCT8 = 0, so bitfield = 1. + if (speed >= SpeedTier::kFalcon) return {1, 1}; + + uint32_t ret = 0; + uint32_t ret_customize = 0; + size_t xsize_blocks = rect.xsize(); + size_t ysize_blocks = rect.ysize(); + // TODO(veluca): precompute when doing DCT. + for (size_t by = 0; by < ysize_blocks; ++by) { + AcStrategyRow acs_row = ac_strategy.ConstRow(rect, by); + for (size_t bx = 0; bx < xsize_blocks; ++bx) { + int ord = kStrategyOrder[acs_row[bx].RawStrategy()]; + // Do not customize coefficient orders for blocks bigger than 32x32. + ret |= 1u << ord; + if (ord > 6) { + continue; + } + ret_customize |= 1u << ord; + } + } + // Use default orders for small images. + if (ac_strategy.xsize() < 5 && ac_strategy.ysize() < 5) return {ret, 0}; + return {ret, ret_customize}; +} + +void ComputeCoeffOrder(SpeedTier speed, const ACImage& acs, + const AcStrategyImage& ac_strategy, + const FrameDimensions& frame_dim, + uint32_t& all_used_orders, uint32_t prev_used_acs, + uint32_t current_used_acs, uint32_t current_used_orders, + coeff_order_t* JXL_RESTRICT order) { + std::vector<int32_t> num_zeros(kCoeffOrderMaxSize); + // If compressing at high speed and only using 8x8 DCTs, only consider a + // subset of blocks. + double block_fraction = 1.0f; + // TODO(veluca): figure out why sampling blocks if non-8x8s are used makes + // encoding significantly less dense. + if (speed >= SpeedTier::kSquirrel && current_used_orders == 1) { + block_fraction = 0.5f; + } + // No need to compute number of zero coefficients if all orders are the + // default. + if (current_used_orders != 0) { + uint64_t threshold = + (std::numeric_limits<uint64_t>::max() >> 32) * block_fraction; + uint64_t s[2] = {static_cast<uint64_t>(0x94D049BB133111EBull), + static_cast<uint64_t>(0xBF58476D1CE4E5B9ull)}; + // Xorshift128+ adapted from xorshift128+-inl.h + auto use_sample = [&]() { + auto s1 = s[0]; + const auto s0 = s[1]; + const auto bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return (bits >> 32) <= threshold; + }; + + // Count number of zero coefficients, separately for each DCT band. + // TODO(veluca): precompute when doing DCT. + for (size_t group_index = 0; group_index < frame_dim.num_groups; + group_index++) { + const size_t gx = group_index % frame_dim.xsize_groups; + const size_t gy = group_index / frame_dim.xsize_groups; + const Rect rect(gx * kGroupDimInBlocks, gy * kGroupDimInBlocks, + kGroupDimInBlocks, kGroupDimInBlocks, + frame_dim.xsize_blocks, frame_dim.ysize_blocks); + ConstACPtr rows[3]; + ACType type = acs.Type(); + for (size_t c = 0; c < 3; c++) { + rows[c] = acs.PlaneRow(c, group_index, 0); + } + size_t ac_offset = 0; + + // TODO(veluca): SIMDfy. + for (size_t by = 0; by < rect.ysize(); ++by) { + AcStrategyRow acs_row = ac_strategy.ConstRow(rect, by); + for (size_t bx = 0; bx < rect.xsize(); ++bx) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + if (!use_sample()) continue; + size_t size = kDCTBlockSize << acs.log2_covered_blocks(); + for (size_t c = 0; c < 3; ++c) { + const size_t order_offset = + CoeffOrderOffset(kStrategyOrder[acs.RawStrategy()], c); + if (type == ACType::k16) { + for (size_t k = 0; k < size; k++) { + bool is_zero = rows[c].ptr16[ac_offset + k] == 0; + num_zeros[order_offset + k] += is_zero ? 1 : 0; + } + } else { + for (size_t k = 0; k < size; k++) { + bool is_zero = rows[c].ptr32[ac_offset + k] == 0; + num_zeros[order_offset + k] += is_zero ? 1 : 0; + } + } + // Ensure LLFs are first in the order. + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + CoefficientLayout(&cy, &cx); + for (size_t iy = 0; iy < cy; iy++) { + for (size_t ix = 0; ix < cx; ix++) { + num_zeros[order_offset + iy * kBlockDim * cx + ix] = -1; + } + } + } + ac_offset += size; + } + } + } + } + struct PosAndCount { + uint32_t pos; + uint32_t count; + }; + auto mem = hwy::AllocateAligned<PosAndCount>(AcStrategy::kMaxCoeffArea); + + std::vector<coeff_order_t> natural_order_buffer; + + uint16_t computed = 0; + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + uint8_t ord = kStrategyOrder[o]; + if (computed & (1 << ord)) continue; + computed |= 1 << ord; + AcStrategy acs = AcStrategy::FromRawStrategy(o); + size_t sz = kDCTBlockSize * acs.covered_blocks_x() * acs.covered_blocks_y(); + + // Do nothing for transforms that don't appear. + if ((1 << ord) & ~current_used_acs) continue; + + // Do nothing if we already committed to this custom order previously. + if ((1 << ord) & prev_used_acs) continue; + if ((1 << ord) & all_used_orders) continue; + + if (natural_order_buffer.size() < sz) natural_order_buffer.resize(sz); + acs.ComputeNaturalCoeffOrder(natural_order_buffer.data()); + + // Ensure natural coefficient order is not permuted if the order is + // not transmitted. + if ((1 << ord) & ~current_used_orders) { + for (size_t c = 0; c < 3; c++) { + size_t offset = CoeffOrderOffset(ord, c); + JXL_DASSERT(CoeffOrderOffset(ord, c + 1) - offset == sz); + memcpy(&order[offset], natural_order_buffer.data(), + sz * sizeof(*order)); + } + continue; + } + + bool is_nondefault = false; + for (uint8_t c = 0; c < 3; c++) { + // Apply zig-zag order. + PosAndCount* pos_and_val = mem.get(); + size_t offset = CoeffOrderOffset(ord, c); + JXL_DASSERT(CoeffOrderOffset(ord, c + 1) - offset == sz); + float inv_sqrt_sz = 1.0f / std::sqrt(sz); + for (size_t i = 0; i < sz; ++i) { + size_t pos = natural_order_buffer[i]; + pos_and_val[i].pos = pos; + // We don't care for the exact number -> quantize number of zeros, + // to get less permuted order. + pos_and_val[i].count = num_zeros[offset + pos] * inv_sqrt_sz + 0.1f; + } + + // Stable-sort -> elements with same number of zeros will preserve their + // order. + auto comparator = [](const PosAndCount& a, const PosAndCount& b) -> bool { + return a.count < b.count; + }; + std::stable_sort(pos_and_val, pos_and_val + sz, comparator); + + // Grab indices. + for (size_t i = 0; i < sz; ++i) { + order[offset + i] = pos_and_val[i].pos; + is_nondefault |= natural_order_buffer[i] != pos_and_val[i].pos; + } + } + if (!is_nondefault) { + current_used_orders &= ~(1 << ord); + } + } + all_used_orders |= current_used_orders; +} + +namespace { + +void TokenizePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip, + size_t size, std::vector<Token>* tokens) { + std::vector<LehmerT> lehmer(size); + std::vector<uint32_t> temp(size + 1); + ComputeLehmerCode(order, temp.data(), size, lehmer.data()); + size_t end = size; + while (end > skip && lehmer[end - 1] == 0) { + --end; + } + tokens->emplace_back(CoeffOrderContext(size), end - skip); + uint32_t last = 0; + for (size_t i = skip; i < end; ++i) { + tokens->emplace_back(CoeffOrderContext(last), lehmer[i]); + last = lehmer[i]; + } +} + +} // namespace + +void EncodePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip, + size_t size, BitWriter* writer, int layer, + AuxOut* aux_out) { + std::vector<std::vector<Token>> tokens(1); + TokenizePermutation(order, skip, size, &tokens[0]); + std::vector<uint8_t> context_map; + EntropyEncodingData codes; + BuildAndEncodeHistograms(HistogramParams(), kPermutationContexts, tokens, + &codes, &context_map, writer, layer, aux_out); + WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out); +} + +namespace { +void EncodeCoeffOrder(const coeff_order_t* JXL_RESTRICT order, AcStrategy acs, + std::vector<Token>* tokens, coeff_order_t* order_zigzag, + std::vector<coeff_order_t>& natural_order_lut) { + const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y(); + const size_t size = kDCTBlockSize * llf; + for (size_t i = 0; i < size; ++i) { + order_zigzag[i] = natural_order_lut[order[i]]; + } + TokenizePermutation(order_zigzag, llf, size, tokens); +} +} // namespace + +void EncodeCoeffOrders(uint16_t used_orders, + const coeff_order_t* JXL_RESTRICT order, + BitWriter* writer, size_t layer, + AuxOut* JXL_RESTRICT aux_out) { + auto mem = hwy::AllocateAligned<coeff_order_t>(AcStrategy::kMaxCoeffArea); + uint16_t computed = 0; + std::vector<std::vector<Token>> tokens(1); + std::vector<coeff_order_t> natural_order_lut; + for (uint8_t o = 0; o < AcStrategy::kNumValidStrategies; ++o) { + uint8_t ord = kStrategyOrder[o]; + if (computed & (1 << ord)) continue; + computed |= 1 << ord; + if ((used_orders & (1 << ord)) == 0) continue; + AcStrategy acs = AcStrategy::FromRawStrategy(o); + const size_t llf = acs.covered_blocks_x() * acs.covered_blocks_y(); + const size_t size = kDCTBlockSize * llf; + if (natural_order_lut.size() < size) natural_order_lut.resize(size); + acs.ComputeNaturalCoeffOrderLut(natural_order_lut.data()); + for (size_t c = 0; c < 3; c++) { + EncodeCoeffOrder(&order[CoeffOrderOffset(ord, c)], acs, &tokens[0], + mem.get(), natural_order_lut); + } + } + // Do not write anything if no order is used. + if (used_orders != 0) { + std::vector<uint8_t> context_map; + EntropyEncodingData codes; + BuildAndEncodeHistograms(HistogramParams(), kPermutationContexts, tokens, + &codes, &context_map, writer, layer, aux_out); + WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_coeff_order.h b/third_party/jpeg-xl/lib/jxl/enc_coeff_order.h new file mode 100644 index 0000000000..25e0f17a8d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_coeff_order.h @@ -0,0 +1,56 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_COEFF_ORDER_H_ +#define LIB_JXL_ENC_COEFF_ORDER_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_dimensions.h" + +namespace jxl { + +struct AuxOut; + +// Orders that are actually used in part of image. `rect` is in block units. +// Returns {orders that are used, orders that might be made non-default}. +std::pair<uint32_t, uint32_t> ComputeUsedOrders( + SpeedTier speed, const AcStrategyImage& ac_strategy, const Rect& rect); + +// Modify zig-zag order, so that DCT bands with more zeros go later. +// Order of DCT bands with same number of zeros is untouched, so +// permutation will be cheaper to encode. +void ComputeCoeffOrder(SpeedTier speed, const ACImage& acs, + const AcStrategyImage& ac_strategy, + const FrameDimensions& frame_dim, + uint32_t& all_used_orders, uint32_t prev_used_acs, + uint32_t current_used_acs, uint32_t current_used_orders, + coeff_order_t* JXL_RESTRICT order); + +void EncodeCoeffOrders(uint16_t used_orders, + const coeff_order_t* JXL_RESTRICT order, + BitWriter* writer, size_t layer, + AuxOut* JXL_RESTRICT aux_out); + +// Encoding/decoding of a single permutation. `size`: number of elements in the +// permutation. `skip`: number of elements to skip from the *beginning* of the +// permutation. +void EncodePermutation(const coeff_order_t* JXL_RESTRICT order, size_t skip, + size_t size, BitWriter* writer, int layer, + AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_COEFF_ORDER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_comparator.cc b/third_party/jpeg-xl/lib/jxl/enc_comparator.cc new file mode 100644 index 0000000000..268122af06 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_comparator.cc @@ -0,0 +1,127 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_comparator.h" + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/enc_gamma_correct.h" +#include "lib/jxl/enc_image_bundle.h" + +namespace jxl { +namespace { + +// color is linear, but blending happens in gamma-compressed space using +// (gamma-compressed) grayscale background color, alpha image represents +// weights of the sRGB colors in the [0 .. (1 << bit_depth) - 1] interval, +// output image is in linear space. +void AlphaBlend(const Image3F& in, const size_t c, float background_linear, + const ImageF& alpha, Image3F* out) { + const float background = LinearToSrgb8Direct(background_linear); + + for (size_t y = 0; y < out->ysize(); ++y) { + const float* JXL_RESTRICT row_a = alpha.ConstRow(y); + const float* JXL_RESTRICT row_i = in.ConstPlaneRow(c, y); + float* JXL_RESTRICT row_o = out->PlaneRow(c, y); + for (size_t x = 0; x < out->xsize(); ++x) { + const float a = row_a[x]; + if (a <= 0.f) { + row_o[x] = background_linear; + } else if (a >= 1.f) { + row_o[x] = row_i[x]; + } else { + const float w_fg = a; + const float w_bg = 1.0f - w_fg; + const float fg = w_fg * LinearToSrgb8Direct(row_i[x]); + const float bg = w_bg * background; + row_o[x] = Srgb8ToLinearDirect(fg + bg); + } + } + } +} + +void AlphaBlend(float background_linear, ImageBundle* io_linear_srgb) { + // No alpha => all opaque. + if (!io_linear_srgb->HasAlpha()) return; + + for (size_t c = 0; c < 3; ++c) { + AlphaBlend(*io_linear_srgb->color(), c, background_linear, + *io_linear_srgb->alpha(), io_linear_srgb->color()); + } +} + +float ComputeScoreImpl(const ImageBundle& rgb0, const ImageBundle& rgb1, + Comparator* comparator, ImageF* distmap) { + JXL_CHECK(comparator->SetReferenceImage(rgb0)); + float score; + JXL_CHECK(comparator->CompareWith(rgb1, distmap, &score)); + return score; +} + +} // namespace + +float ComputeScore(const ImageBundle& rgb0, const ImageBundle& rgb1, + Comparator* comparator, const JxlCmsInterface& cms, + ImageF* diffmap, ThreadPool* pool, bool ignore_alpha) { + // Convert to linear sRGB (unless already in that space) + ImageMetadata metadata0 = *rgb0.metadata(); + ImageBundle store0(&metadata0); + const ImageBundle* linear_srgb0; + JXL_CHECK(TransformIfNeeded(rgb0, ColorEncoding::LinearSRGB(rgb0.IsGray()), + cms, pool, &store0, &linear_srgb0)); + ImageMetadata metadata1 = *rgb1.metadata(); + ImageBundle store1(&metadata1); + const ImageBundle* linear_srgb1; + JXL_CHECK(TransformIfNeeded(rgb1, ColorEncoding::LinearSRGB(rgb1.IsGray()), + cms, pool, &store1, &linear_srgb1)); + + // No alpha: skip blending, only need a single call to Butteraugli. + if (ignore_alpha || (!rgb0.HasAlpha() && !rgb1.HasAlpha())) { + return ComputeScoreImpl(*linear_srgb0, *linear_srgb1, comparator, diffmap); + } + + // Blend on black and white backgrounds + + const float black = 0.0f; + ImageBundle blended_black0 = linear_srgb0->Copy(); + ImageBundle blended_black1 = linear_srgb1->Copy(); + AlphaBlend(black, &blended_black0); + AlphaBlend(black, &blended_black1); + + const float white = 1.0f; + ImageBundle blended_white0 = linear_srgb0->Copy(); + ImageBundle blended_white1 = linear_srgb1->Copy(); + + AlphaBlend(white, &blended_white0); + AlphaBlend(white, &blended_white1); + + ImageF diffmap_black, diffmap_white; + const float dist_black = ComputeScoreImpl(blended_black0, blended_black1, + comparator, &diffmap_black); + const float dist_white = ComputeScoreImpl(blended_white0, blended_white1, + comparator, &diffmap_white); + + // diffmap and return values are the max of diffmap_black/white. + if (diffmap != nullptr) { + const size_t xsize = rgb0.xsize(); + const size_t ysize = rgb0.ysize(); + *diffmap = ImageF(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + const float* JXL_RESTRICT row_black = diffmap_black.ConstRow(y); + const float* JXL_RESTRICT row_white = diffmap_white.ConstRow(y); + float* JXL_RESTRICT row_out = diffmap->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = std::max(row_black[x], row_white[x]); + } + } + } + return std::max(dist_black, dist_white); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_comparator.h b/third_party/jpeg-xl/lib/jxl/enc_comparator.h new file mode 100644 index 0000000000..c545ea6111 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_comparator.h @@ -0,0 +1,53 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_COMPARATOR_H_ +#define LIB_JXL_ENC_COMPARATOR_H_ + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +class Comparator { + public: + virtual ~Comparator() = default; + + // Sets the reference image, the first to compare + // Image must be in linear sRGB (gamma expanded) in range 0.0f-1.0f as + // the range from standard black point to standard white point, but values + // outside permitted. + virtual Status SetReferenceImage(const ImageBundle& ref) = 0; + + // Sets the actual image (with loss), the second to compare + // Image must be in linear sRGB (gamma expanded) in range 0.0f-1.0f as + // the range from standard black point to standard white point, but values + // outside permitted. + // In diffmap it outputs the local score per pixel, while in score it outputs + // a single score. Any one may be set to nullptr to not compute it. + virtual Status CompareWith(const ImageBundle& actual, ImageF* diffmap, + float* score) = 0; + + // Quality thresholds for diffmap and score values. + // The good score must represent a value where the images are considered to + // be perceptually indistinguishable (but not identical) + // The bad value must be larger than good to indicate "lower means better" + // and smaller than good to indicate "higher means better" + virtual float GoodQualityScore() const = 0; + virtual float BadQualityScore() const = 0; +}; + +// Computes the score given images in any RGB color model, optionally with +// alpha channel. +float ComputeScore(const ImageBundle& rgb0, const ImageBundle& rgb1, + Comparator* comparator, const JxlCmsInterface& cms, + ImageF* diffmap = nullptr, ThreadPool* pool = nullptr, + bool ignore_alpha = false); + +} // namespace jxl + +#endif // LIB_JXL_ENC_COMPARATOR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_context_map.cc b/third_party/jpeg-xl/lib/jxl/enc_context_map.cc new file mode 100644 index 0000000000..6968a6fbae --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_context_map.cc @@ -0,0 +1,155 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Library to encode the context map. + +#include "lib/jxl/enc_context_map.h" + +#include <stdint.h> + +#include <algorithm> +#include <cstddef> +#include <vector> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/pack_signed.h" + +namespace jxl { + +namespace { + +size_t IndexOf(const std::vector<uint8_t>& v, uint8_t value) { + size_t i = 0; + for (; i < v.size(); ++i) { + if (v[i] == value) return i; + } + return i; +} + +void MoveToFront(std::vector<uint8_t>* v, size_t index) { + uint8_t value = (*v)[index]; + for (size_t i = index; i != 0; --i) { + (*v)[i] = (*v)[i - 1]; + } + (*v)[0] = value; +} + +std::vector<uint8_t> MoveToFrontTransform(const std::vector<uint8_t>& v) { + if (v.empty()) return v; + uint8_t max_value = *std::max_element(v.begin(), v.end()); + std::vector<uint8_t> mtf(max_value + 1); + for (size_t i = 0; i <= max_value; ++i) mtf[i] = i; + std::vector<uint8_t> result(v.size()); + for (size_t i = 0; i < v.size(); ++i) { + size_t index = IndexOf(mtf, v[i]); + JXL_ASSERT(index < mtf.size()); + result[i] = static_cast<uint8_t>(index); + MoveToFront(&mtf, index); + } + return result; +} + +} // namespace + +void EncodeContextMap(const std::vector<uint8_t>& context_map, + size_t num_histograms, BitWriter* writer, size_t layer, + AuxOut* aux_out) { + if (num_histograms == 1) { + // Simple code + writer->Write(1, 1); + // 0 bits per entry. + writer->Write(2, 0); + return; + } + + std::vector<uint8_t> transformed_symbols = MoveToFrontTransform(context_map); + std::vector<std::vector<Token>> tokens(1), mtf_tokens(1); + for (size_t i = 0; i < context_map.size(); i++) { + tokens[0].emplace_back(0, context_map[i]); + } + for (size_t i = 0; i < transformed_symbols.size(); i++) { + mtf_tokens[0].emplace_back(0, transformed_symbols[i]); + } + HistogramParams params; + params.uint_method = HistogramParams::HybridUintMethod::kContextMap; + size_t ans_cost, mtf_cost; + { + EntropyEncodingData codes; + std::vector<uint8_t> sink_context_map; + ans_cost = BuildAndEncodeHistograms(params, 1, tokens, &codes, + &sink_context_map, nullptr, 0, nullptr); + } + { + EntropyEncodingData codes; + std::vector<uint8_t> sink_context_map; + mtf_cost = BuildAndEncodeHistograms(params, 1, mtf_tokens, &codes, + &sink_context_map, nullptr, 0, nullptr); + } + bool use_mtf = mtf_cost < ans_cost; + // Rebuild token list. + tokens[0].clear(); + for (size_t i = 0; i < transformed_symbols.size(); i++) { + tokens[0].emplace_back(0, + use_mtf ? transformed_symbols[i] : context_map[i]); + } + size_t entry_bits = CeilLog2Nonzero(num_histograms); + size_t simple_cost = entry_bits * context_map.size(); + if (entry_bits < 4 && simple_cost < ans_cost && simple_cost < mtf_cost) { + BitWriter::Allotment allotment(writer, 3 + entry_bits * context_map.size()); + writer->Write(1, 1); + writer->Write(2, entry_bits); + for (size_t i = 0; i < context_map.size(); i++) { + writer->Write(entry_bits, context_map[i]); + } + allotment.ReclaimAndCharge(writer, layer, aux_out); + } else { + BitWriter::Allotment allotment(writer, 2 + tokens[0].size() * 24); + writer->Write(1, 0); + writer->Write(1, use_mtf); // Use/don't use MTF. + EntropyEncodingData codes; + std::vector<uint8_t> sink_context_map; + BuildAndEncodeHistograms(params, 1, tokens, &codes, &sink_context_map, + writer, layer, aux_out); + WriteTokens(tokens[0], codes, sink_context_map, 0, writer); + allotment.ReclaimAndCharge(writer, layer, aux_out); + } +} + +void EncodeBlockCtxMap(const BlockCtxMap& block_ctx_map, BitWriter* writer, + AuxOut* aux_out) { + auto& dct = block_ctx_map.dc_thresholds; + auto& qft = block_ctx_map.qf_thresholds; + auto& ctx_map = block_ctx_map.ctx_map; + BitWriter::Allotment allotment( + writer, + (dct[0].size() + dct[1].size() + dct[2].size() + qft.size()) * 34 + 1 + + 4 + 4 + ctx_map.size() * 10 + 1024); + if (dct[0].empty() && dct[1].empty() && dct[2].empty() && qft.empty() && + ctx_map.size() == 21 && + std::equal(ctx_map.begin(), ctx_map.end(), BlockCtxMap::kDefaultCtxMap)) { + writer->Write(1, 1); // default + allotment.ReclaimAndCharge(writer, kLayerAC, aux_out); + return; + } + writer->Write(1, 0); + for (int j : {0, 1, 2}) { + writer->Write(4, dct[j].size()); + for (int i : dct[j]) { + JXL_CHECK(U32Coder::Write(kDCThresholdDist, PackSigned(i), writer)); + } + } + writer->Write(4, qft.size()); + for (uint32_t i : qft) { + JXL_CHECK(U32Coder::Write(kQFThresholdDist, i - 1, writer)); + } + EncodeContextMap(ctx_map, block_ctx_map.num_ctxs, writer, kLayerAC, aux_out); + allotment.ReclaimAndCharge(writer, kLayerAC, aux_out); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_context_map.h b/third_party/jpeg-xl/lib/jxl/enc_context_map.h new file mode 100644 index 0000000000..041e71de7a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_context_map.h @@ -0,0 +1,35 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_CONTEXT_MAP_H_ +#define LIB_JXL_ENC_CONTEXT_MAP_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +struct AuxOut; + +// Max limit is 255 because encoding assumes numbers < 255 +// More clusters can help compression, but makes encode/decode somewhat slower +static const size_t kClustersLimit = 128; + +// Encodes the given context map to the bit stream. The number of different +// histogram ids is given by num_histograms. +void EncodeContextMap(const std::vector<uint8_t>& context_map, + size_t num_histograms, BitWriter* writer, size_t layer, + AuxOut* aux_out); + +void EncodeBlockCtxMap(const BlockCtxMap& block_ctx_map, BitWriter* writer, + AuxOut* aux_out); +} // namespace jxl + +#endif // LIB_JXL_ENC_CONTEXT_MAP_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_debug_image.cc b/third_party/jpeg-xl/lib/jxl/enc_debug_image.cc new file mode 100644 index 0000000000..261570e690 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_debug_image.cc @@ -0,0 +1,115 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_debug_image.h" + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +namespace { +template <typename From> +Plane<float> ConvertToFloat(const Plane<From>& from) { + float factor = 1.0f / std::numeric_limits<From>::max(); + if (std::is_same<From, double>::value || std::is_same<From, float>::value) { + factor = 1.0f; + } + Plane<float> to(from.xsize(), from.ysize()); + for (size_t y = 0; y < from.ysize(); ++y) { + const From* const JXL_RESTRICT row_from = from.Row(y); + float* const JXL_RESTRICT row_to = to.Row(y); + for (size_t x = 0; x < from.xsize(); ++x) { + row_to[x] = row_from[x] * factor; + } + } + return to; +} +template <typename From> +Image3F ConvertToFloat(const Image3<From>& from) { + return Image3F(ConvertToFloat(from.Plane(0)), ConvertToFloat(from.Plane(1)), + ConvertToFloat(from.Plane(2))); +} + +template <typename T> +void DumpImageT(const CompressParams& cparams, const char* label, + const ColorEncoding& color_encoding, const Image3<T>& image) { + if (!cparams.debug_image) return; + Image3F float_image = ConvertToFloat(image); + JxlColorEncoding color = color_encoding.ToExternal(); + size_t num_pixels = 3 * image.xsize() * image.ysize(); + std::vector<uint16_t> pixels(num_pixels); + const ImageF* channels[3]; + for (int c = 0; c < 3; ++c) { + channels[c] = &float_image.Plane(c); + } + JXL_CHECK(ConvertChannelsToExternal( + channels, 3, 16, false, JXL_BIG_ENDIAN, 6 * image.xsize(), nullptr, + &pixels[0], 2 * num_pixels, PixelCallback(), Orientation::kIdentity)); + (*cparams.debug_image)(cparams.debug_image_opaque, label, image.xsize(), + image.ysize(), &color, &pixels[0]); +} + +template <typename T> +void DumpPlaneNormalizedT(const CompressParams& cparams, const char* label, + const Plane<T>& image) { + T min; + T max; + ImageMinMax(image, &min, &max); + Image3B normalized(image.xsize(), image.ysize()); + for (size_t c = 0; c < 3; ++c) { + float mul = min == max ? 0 : (255.0f / (max - min)); + for (size_t y = 0; y < image.ysize(); ++y) { + const T* JXL_RESTRICT row_in = image.ConstRow(y); + uint8_t* JXL_RESTRICT row_out = normalized.PlaneRow(c, y); + for (size_t x = 0; x < image.xsize(); ++x) { + row_out[x] = static_cast<uint8_t>((row_in[x] - min) * mul); + } + } + } + DumpImageT(cparams, label, ColorEncoding::SRGB(), normalized); +} + +} // namespace + +void DumpImage(const CompressParams& cparams, const char* label, + const Image3<float>& image) { + DumpImageT(cparams, label, ColorEncoding::SRGB(), image); +} + +void DumpImage(const CompressParams& cparams, const char* label, + const Image3<uint8_t>& image) { + DumpImageT(cparams, label, ColorEncoding::SRGB(), image); +} + +void DumpXybImage(const CompressParams& cparams, const char* label, + const Image3F& image) { + if (!cparams.debug_image) return; + + Image3F linear(image.xsize(), image.ysize()); + OpsinParams opsin_params; + opsin_params.Init(kDefaultIntensityTarget); + OpsinToLinear(image, Rect(linear), nullptr, &linear, opsin_params); + + DumpImageT(cparams, label, ColorEncoding::LinearSRGB(), linear); +} + +void DumpPlaneNormalized(const CompressParams& cparams, const char* label, + const Plane<float>& image) { + DumpPlaneNormalizedT(cparams, label, image); +} + +void DumpPlaneNormalized(const CompressParams& cparams, const char* label, + const Plane<uint8_t>& image) { + DumpPlaneNormalizedT(cparams, label, image); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_debug_image.h b/third_party/jpeg-xl/lib/jxl/enc_debug_image.h new file mode 100644 index 0000000000..33799a5f7f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_debug_image.h @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_DEBUG_IMAGE_H_ +#define LIB_JXL_ENC_DEBUG_IMAGE_H_ + +// Optional output images for debugging. + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" + +namespace jxl { + +void DumpImage(const CompressParams& cparams, const char* label, + const Image3<float>& image); +void DumpImage(const CompressParams& cparams, const char* label, + const Image3<uint8_t>& image); +void DumpXybImage(const CompressParams& cparams, const char* label, + const Image3<float>& image); +void DumpPlaneNormalized(const CompressParams& cparams, const char* label, + const Plane<float>& image); +void DumpPlaneNormalized(const CompressParams& cparams, const char* label, + const Plane<uint8_t>& image); + +// Used to skip image creation if they won't be written to debug directory. +static inline bool WantDebugOutput(const CompressParams& cparams) { + return cparams.debug_image != nullptr; +} + +} // namespace jxl + +#endif // LIB_JXL_ENC_DEBUG_IMAGE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_detect_dots.cc b/third_party/jpeg-xl/lib/jxl/enc_detect_dots.cc new file mode 100644 index 0000000000..4ee8808766 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_detect_dots.cc @@ -0,0 +1,587 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_detect_dots.h" + +#include <stdint.h> + +#include <algorithm> +#include <array> +#include <cmath> +#include <cstdio> +#include <utility> +#include <vector> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_detect_dots.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/enc_linalg.h" +#include "lib/jxl/enc_optimize.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +// Set JXL_DEBUG_DOT_DETECT to 1 to enable debugging. +#ifndef JXL_DEBUG_DOT_DETECT +#define JXL_DEBUG_DOT_DETECT 0 +#endif + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::Sub; + +ImageF SumOfSquareDifferences(const Image3F& forig, const Image3F& smooth, + ThreadPool* pool) { + const HWY_FULL(float) d; + const auto color_coef0 = Set(d, 0.0f); + const auto color_coef1 = Set(d, 10.0f); + const auto color_coef2 = Set(d, 0.0f); + + ImageF sum_of_squares(forig.xsize(), forig.ysize()); + JXL_CHECK(RunOnPool( + pool, 0, forig.ysize(), ThreadPool::NoInit, + [&](const uint32_t task, size_t thread) { + const size_t y = static_cast<size_t>(task); + const float* JXL_RESTRICT orig_row0 = forig.Plane(0).ConstRow(y); + const float* JXL_RESTRICT orig_row1 = forig.Plane(1).ConstRow(y); + const float* JXL_RESTRICT orig_row2 = forig.Plane(2).ConstRow(y); + const float* JXL_RESTRICT smooth_row0 = smooth.Plane(0).ConstRow(y); + const float* JXL_RESTRICT smooth_row1 = smooth.Plane(1).ConstRow(y); + const float* JXL_RESTRICT smooth_row2 = smooth.Plane(2).ConstRow(y); + float* JXL_RESTRICT sos_row = sum_of_squares.Row(y); + + for (size_t x = 0; x < forig.xsize(); x += Lanes(d)) { + auto v0 = Sub(Load(d, orig_row0 + x), Load(d, smooth_row0 + x)); + auto v1 = Sub(Load(d, orig_row1 + x), Load(d, smooth_row1 + x)); + auto v2 = Sub(Load(d, orig_row2 + x), Load(d, smooth_row2 + x)); + v0 = Mul(Mul(v0, v0), color_coef0); + v1 = Mul(Mul(v1, v1), color_coef1); + v2 = Mul(Mul(v2, v2), color_coef2); + const auto sos = + Add(v0, Add(v1, v2)); // weighted sum of square diffs + Store(sos, d, sos_row + x); + } + }, + "ComputeEnergyImage")); + return sum_of_squares; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(SumOfSquareDifferences); // Local function + +const int kEllipseWindowSize = 5; + +namespace { +struct GaussianEllipse { + double x; // position in x + double y; // position in y + double sigma_x; // scale in x + double sigma_y; // scale in y + double angle; // ellipse rotation in radians + std::array<double, 3> intensity; // intensity in each channel + + // The following variables do not need to be encoded + double l2_loss; // error after the Gaussian was fit + double l1_loss; + double ridge_loss; // the l2_loss plus regularization term + double custom_loss; // experimental custom loss + std::array<double, 3> bgColor; // best background color + size_t neg_pixels; // number of negative pixels when subtracting dot + std::array<double, 3> neg_value; // debt due to channel truncation +}; +double DotGaussianModel(double dx, double dy, double ct, double st, + double sigma_x, double sigma_y, double intensity) { + double rx = ct * dx + st * dy; + double ry = -st * dx + ct * dy; + double md = (rx * rx / sigma_x) + (ry * ry / sigma_y); + double value = intensity * exp(-0.5 * md); + return value; +} + +constexpr bool kOptimizeBackground = true; + +// Gaussian that smooths noise but preserves dots +const WeightsSeparable5& WeightsSeparable5Gaussian0_65() { + constexpr float w0 = 0.558311f; + constexpr float w1 = 0.210395f; + constexpr float w2 = 0.010449f; + static constexpr WeightsSeparable5 weights = { + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}}; + return weights; +} + +// (Iterated) Gaussian that removes dots. +const WeightsSeparable5& WeightsSeparable5Gaussian3() { + constexpr float w0 = 0.222338f; + constexpr float w1 = 0.210431f; + constexpr float w2 = 0.1784f; + static constexpr WeightsSeparable5 weights = { + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}, + {HWY_REP4(w0), HWY_REP4(w1), HWY_REP4(w2)}}; + return weights; +} + +ImageF ComputeEnergyImage(const Image3F& orig, Image3F* smooth, + ThreadPool* pool) { + // Prepare guidance images for dot selection. + Image3F forig(orig.xsize(), orig.ysize()); + *smooth = Image3F(orig.xsize(), orig.ysize()); + Rect rect(orig); + + const auto& weights1 = WeightsSeparable5Gaussian0_65(); + const auto& weights3 = WeightsSeparable5Gaussian3(); + + for (size_t c = 0; c < 3; ++c) { + // Use forig as temporary storage to reduce memory and keep it warmer. + Separable5(orig.Plane(c), rect, weights3, pool, &forig.Plane(c)); + Separable5(forig.Plane(c), rect, weights3, pool, &smooth->Plane(c)); + Separable5(orig.Plane(c), rect, weights1, pool, &forig.Plane(c)); + } + + return HWY_DYNAMIC_DISPATCH(SumOfSquareDifferences)(forig, *smooth, pool); +} + +struct Pixel { + int x; + int y; +}; + +Pixel operator+(const Pixel& a, const Pixel& b) { + return Pixel{a.x + b.x, a.y + b.y}; +} + +// Maximum area in pixels of a ellipse +const size_t kMaxCCSize = 1000; + +// Extracts a connected component from a Binary image where seed is part +// of the component +bool ExtractComponent(ImageF* img, std::vector<Pixel>* pixels, + const Pixel& seed, double threshold) { + static const std::vector<Pixel> neighbors{{1, -1}, {1, 0}, {1, 1}, {0, -1}, + {0, 1}, {-1, -1}, {-1, 1}, {1, 0}}; + std::vector<Pixel> q{seed}; + while (!q.empty()) { + Pixel current = q.back(); + q.pop_back(); + pixels->push_back(current); + if (pixels->size() > kMaxCCSize) return false; + for (const Pixel& delta : neighbors) { + Pixel child = current + delta; + if (child.x >= 0 && static_cast<size_t>(child.x) < img->xsize() && + child.y >= 0 && static_cast<size_t>(child.y) < img->ysize()) { + float* value = &img->Row(child.y)[child.x]; + if (*value > threshold) { + *value = 0.0; + q.push_back(child); + } + } + } + } + return true; +} + +inline bool PointInRect(const Rect& r, const Pixel& p) { + return (static_cast<size_t>(p.x) >= r.x0() && + static_cast<size_t>(p.x) < (r.x0() + r.xsize()) && + static_cast<size_t>(p.y) >= r.y0() && + static_cast<size_t>(p.y) < (r.y0() + r.ysize())); +} + +struct ConnectedComponent { + ConnectedComponent(const Rect& bounds, const std::vector<Pixel>&& pixels) + : bounds(bounds), pixels(pixels) {} + Rect bounds; + std::vector<Pixel> pixels; + float maxEnergy; + float meanEnergy; + float varEnergy; + float meanBg; + float varBg; + float score; + Pixel mode; + + void CompStats(const ImageF& energy, int extra) { + maxEnergy = 0.0; + meanEnergy = 0.0; + varEnergy = 0.0; + meanBg = 0.0; + varBg = 0.0; + int nIn = 0; + int nOut = 0; + mode.x = 0; + mode.y = 0; + for (int sy = -extra; sy < (static_cast<int>(bounds.ysize()) + extra); + sy++) { + int y = sy + static_cast<int>(bounds.y0()); + if (y < 0 || static_cast<size_t>(y) >= energy.ysize()) continue; + const float* JXL_RESTRICT erow = energy.ConstRow(y); + for (int sx = -extra; sx < (static_cast<int>(bounds.xsize()) + extra); + sx++) { + int x = sx + static_cast<int>(bounds.x0()); + if (x < 0 || static_cast<size_t>(x) >= energy.xsize()) continue; + if (erow[x] > maxEnergy) { + maxEnergy = erow[x]; + mode.x = x; + mode.y = y; + } + if (PointInRect(bounds, Pixel{x, y})) { + meanEnergy += erow[x]; + varEnergy += erow[x] * erow[x]; + nIn++; + } else { + meanBg += erow[x]; + varBg += erow[x] * erow[x]; + nOut++; + } + } + } + meanEnergy = meanEnergy / nIn; + meanBg = meanBg / nOut; + varEnergy = (varEnergy / nIn) - meanEnergy * meanEnergy; + varBg = (varBg / nOut) - meanBg * meanBg; + score = (meanEnergy - meanBg) / std::sqrt(varBg); + } +}; + +Rect BoundingRectangle(const std::vector<Pixel>& pixels) { + JXL_ASSERT(!pixels.empty()); + int low_x, high_x, low_y, high_y; + low_x = high_x = pixels[0].x; + low_y = high_y = pixels[0].y; + for (const Pixel& p : pixels) { + low_x = std::min(low_x, p.x); + high_x = std::max(high_x, p.x); + low_y = std::min(low_y, p.y); + high_y = std::max(high_y, p.y); + } + return Rect(low_x, low_y, high_x - low_x + 1, high_y - low_y + 1); +} + +std::vector<ConnectedComponent> FindCC(const ImageF& energy, double t_low, + double t_high, uint32_t maxWindow, + double minScore) { + const int kExtraRect = 4; + ImageF img(energy.xsize(), energy.ysize()); + CopyImageTo(energy, &img); + std::vector<ConnectedComponent> ans; + for (size_t y = 0; y < img.ysize(); y++) { + float* JXL_RESTRICT row = img.Row(y); + for (size_t x = 0; x < img.xsize(); x++) { + if (row[x] > t_high) { + std::vector<Pixel> pixels; + row[x] = 0.0; + bool success = ExtractComponent( + &img, &pixels, Pixel{static_cast<int>(x), static_cast<int>(y)}, + t_low); + if (!success) continue; +#if JXL_DEBUG_DOT_DETECT + for (size_t i = 0; i < pixels.size(); i++) { + fprintf(stderr, "(%d,%d) ", pixels[i].x, pixels[i].y); + } + fprintf(stderr, "\n"); +#endif // JXL_DEBUG_DOT_DETECT + Rect bounds = BoundingRectangle(pixels); + if (bounds.xsize() < maxWindow && bounds.ysize() < maxWindow) { + ConnectedComponent cc{bounds, std::move(pixels)}; + cc.CompStats(energy, kExtraRect); + if (cc.score < minScore) continue; + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, + "cc mode: (%d,%d), max: %f, bgMean: %f bgVar: " + "%f bound:(%" PRIuS ",%" PRIuS ",%" PRIuS ",%" PRIuS ")\n", + cc.mode.x, cc.mode.y, cc.maxEnergy, cc.meanEnergy, + cc.varEnergy, cc.bounds.x0(), cc.bounds.y0(), + cc.bounds.xsize(), cc.bounds.ysize()); + ans.push_back(cc); + } + } + } + } + return ans; +} + +// TODO(sggonzalez): Adapt this function for the different color spaces or +// remove it if the color space with the best performance does not need it +void ComputeDotLosses(GaussianEllipse* ellipse, const ConnectedComponent& cc, + const Image3F& img, const Image3F& background) { + const int rectBounds = 2; + const double kIntensityR = 0.0; // 0.015; + const double kSigmaR = 0.0; // 0.01; + const double kZeroEpsilon = 0.1; // Tolerance to consider a value negative + double ct = cos(ellipse->angle), st = sin(ellipse->angle); + const std::array<double, 3> channelGains{{1.0, 1.0, 1.0}}; + int N = 0; + ellipse->l1_loss = 0.0; + ellipse->l2_loss = 0.0; + ellipse->neg_pixels = 0; + ellipse->neg_value.fill(0.0); + double distMeanModeSq = (cc.mode.x - ellipse->x) * (cc.mode.x - ellipse->x) + + (cc.mode.y - ellipse->y) * (cc.mode.y - ellipse->y); + ellipse->custom_loss = 0.0; + for (int c = 0; c < 3; c++) { + for (int sy = -rectBounds; + sy < (static_cast<int>(cc.bounds.ysize()) + rectBounds); sy++) { + int y = sy + cc.bounds.y0(); + if (y < 0 || static_cast<size_t>(y) >= img.ysize()) continue; + const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y); + // bgrow is only used if kOptimizeBackground is false. + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y); + for (int sx = -rectBounds; + sx < (static_cast<int>(cc.bounds.xsize()) + rectBounds); sx++) { + int x = sx + cc.bounds.x0(); + if (x < 0 || static_cast<size_t>(x) >= img.xsize()) continue; + double target = row[x]; + double dotDelta = DotGaussianModel( + x - ellipse->x, y - ellipse->y, ct, st, ellipse->sigma_x, + ellipse->sigma_y, ellipse->intensity[c]); + if (dotDelta > target + kZeroEpsilon) { + ellipse->neg_pixels++; + ellipse->neg_value[c] += dotDelta - target; + } + double bkg = kOptimizeBackground ? ellipse->bgColor[c] : bgrow[x]; + double pred = bkg + dotDelta; + double diff = target - pred; + double l2 = channelGains[c] * diff * diff; + double l1 = channelGains[c] * std::fabs(diff); + ellipse->l2_loss += l2; + ellipse->l1_loss += l1; + double w = DotGaussianModel(x - cc.mode.x, y - cc.mode.y, 1.0, 0.0, + 1.0 + ellipse->sigma_x, + 1.0 + ellipse->sigma_y, 1.0); + ellipse->custom_loss += w * l2; + N++; + } + } + } + ellipse->l2_loss /= N; + ellipse->custom_loss /= N; + ellipse->custom_loss += 20.0 * distMeanModeSq + ellipse->neg_value[1]; + ellipse->l1_loss /= N; + double ridgeTerm = kSigmaR * ellipse->sigma_x + kSigmaR * ellipse->sigma_y; + for (int c = 0; c < 3; c++) { + ridgeTerm += kIntensityR * ellipse->intensity[c] * ellipse->intensity[c]; + } + ellipse->ridge_loss = ellipse->l2_loss + ridgeTerm; +} + +GaussianEllipse FitGaussianFast(const ConnectedComponent& cc, + const ImageF& energy, const Image3F& img, + const Image3F& background) { + constexpr bool leastSqIntensity = true; + constexpr double kEpsilon = 1e-6; + GaussianEllipse ans; + constexpr int kRectBounds = (kEllipseWindowSize >> 1); + + // Compute the 1st and 2nd moments of the CC + double sum = 0.0; + int N = 0; + std::array<double, 3> m1{{0.0, 0.0, 0.0}}; + std::array<double, 3> m2{{0.0, 0.0, 0.0}}; + std::array<double, 3> color{{0.0, 0.0, 0.0}}; + std::array<double, 3> bgColor{{0.0, 0.0, 0.0}}; + + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, + "%" PRIuS " %" PRIuS " %" PRIuS " %" PRIuS "\n", cc.bounds.x0(), + cc.bounds.y0(), cc.bounds.xsize(), cc.bounds.ysize()); + for (int c = 0; c < 3; c++) { + color[c] = img.ConstPlaneRow(c, cc.mode.y)[cc.mode.x] - + background.ConstPlaneRow(c, cc.mode.y)[cc.mode.x]; + } + double sign = (color[1] > 0) ? 1 : -1; + for (int sy = -kRectBounds; sy <= kRectBounds; sy++) { + int y = sy + cc.mode.y; + if (y < 0 || static_cast<size_t>(y) >= energy.ysize()) continue; + const float* JXL_RESTRICT row = img.ConstPlaneRow(1, y); + const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(1, y); + for (int sx = -kRectBounds; sx <= kRectBounds; sx++) { + int x = sx + cc.mode.x; + if (x < 0 || static_cast<size_t>(x) >= energy.xsize()) continue; + double w = std::max(kEpsilon, sign * (row[x] - bgrow[x])); + sum += w; + + m1[0] += w * x; + m1[1] += w * y; + m2[0] += w * x * x; + m2[1] += w * x * y; + m2[2] += w * y * y; + for (int c = 0; c < 3; c++) { + bgColor[c] += background.ConstPlaneRow(c, y)[x]; + } + N++; + } + } + JXL_CHECK(N > 0); + + for (int i = 0; i < 3; i++) { + m1[i] /= sum; + m2[i] /= sum; + bgColor[i] /= N; + } + + // Some magic constants + constexpr double kSigmaMult = 1.0; + constexpr std::array<double, 3> kScaleMult{{1.1, 1.1, 1.1}}; + + // Now set the parameters of the Gaussian + ans.x = m1[0]; + ans.y = m1[1]; + for (int j = 0; j < 3; j++) { + ans.intensity[j] = kScaleMult[j] * color[j]; + } + + ImageD Sigma(2, 2), D(1, 2), U(2, 2); + Sigma.Row(0)[0] = m2[0] - m1[0] * m1[0]; + Sigma.Row(1)[1] = m2[2] - m1[1] * m1[1]; + Sigma.Row(0)[1] = Sigma.Row(1)[0] = m2[1] - m1[0] * m1[1]; + ConvertToDiagonal(Sigma, &D, &U); + const double* JXL_RESTRICT d = D.ConstRow(0); + const double* JXL_RESTRICT u = U.ConstRow(1); + int p1 = 0, p2 = 1; + if (d[0] < d[1]) std::swap(p1, p2); + ans.sigma_x = kSigmaMult * d[p1]; + ans.sigma_y = kSigmaMult * d[p2]; + ans.angle = std::atan2(u[p1], u[p2]); + ans.l2_loss = 0.0; + ans.bgColor = bgColor; + if (leastSqIntensity) { + GaussianEllipse* ellipse = &ans; + double ct = cos(ans.angle), st = sin(ans.angle); + // Estimate intensity with least squares (fixed background) + for (int c = 0; c < 3; c++) { + double gg = 0.0; + double gd = 0.0; + int yc = static_cast<int>(cc.mode.y); + int xc = static_cast<int>(cc.mode.x); + for (int y = yc - kRectBounds; y <= yc + kRectBounds; y++) { + if (y < 0 || static_cast<size_t>(y) >= img.ysize()) continue; + const float* JXL_RESTRICT row = img.ConstPlaneRow(c, y); + const float* JXL_RESTRICT bgrow = background.ConstPlaneRow(c, y); + for (int x = xc - kRectBounds; x <= xc + kRectBounds; x++) { + if (x < 0 || static_cast<size_t>(x) >= img.xsize()) continue; + double target = row[x] - bgrow[x]; + double gaussian = + DotGaussianModel(x - ellipse->x, y - ellipse->y, ct, st, + ellipse->sigma_x, ellipse->sigma_y, 1.0); + gg += gaussian * gaussian; + gd += gaussian * target; + } + } + ans.intensity[c] = gd / (gg + 1e-6); // Regularized least squares + } + } + ComputeDotLosses(&ans, cc, img, background); + return ans; +} + +GaussianEllipse FitGaussian(const ConnectedComponent& cc, const ImageF& energy, + const Image3F& img, const Image3F& background) { + auto ellipse = FitGaussianFast(cc, energy, img, background); + if (ellipse.sigma_x < ellipse.sigma_y) { + std::swap(ellipse.sigma_x, ellipse.sigma_y); + ellipse.angle += kPi / 2.0; + } + ellipse.angle -= kPi * std::floor(ellipse.angle / kPi); + if (fabs(ellipse.angle - kPi) < 1e-6 || fabs(ellipse.angle) < 1e-6) { + ellipse.angle = 0.0; + } + JXL_CHECK(ellipse.angle >= 0 && ellipse.angle <= kPi && + ellipse.sigma_x >= ellipse.sigma_y); + JXL_DEBUG(JXL_DEBUG_DOT_DETECT, + "Ellipse mu=(%lf,%lf) sigma=(%lf,%lf) angle=%lf " + "intensity=(%lf,%lf,%lf) bg=(%lf,%lf,%lf) l2_loss=%lf " + "custom_loss=%lf, neg_pix=%" PRIuS ", neg_v=(%lf,%lf,%lf)\n", + ellipse.x, ellipse.y, ellipse.sigma_x, ellipse.sigma_y, + ellipse.angle, ellipse.intensity[0], ellipse.intensity[1], + ellipse.intensity[2], ellipse.bgColor[0], ellipse.bgColor[1], + ellipse.bgColor[2], ellipse.l2_loss, ellipse.custom_loss, + ellipse.neg_pixels, ellipse.neg_value[0], ellipse.neg_value[1], + ellipse.neg_value[2]); + return ellipse; +} + +} // namespace + +std::vector<PatchInfo> DetectGaussianEllipses( + const Image3F& opsin, const GaussianDetectParams& params, + const EllipseQuantParams& qParams, ThreadPool* pool) { + std::vector<PatchInfo> dots; + Image3F smooth(opsin.xsize(), opsin.ysize()); + ImageF energy = ComputeEnergyImage(opsin, &smooth, pool); + std::vector<ConnectedComponent> components = FindCC( + energy, params.t_low, params.t_high, params.maxWinSize, params.minScore); + size_t numCC = + std::min(params.maxCC, (components.size() * params.percCC) / 100); + if (components.size() > numCC) { + std::sort( + components.begin(), components.end(), + [](const ConnectedComponent& a, const ConnectedComponent& b) -> bool { + return a.score > b.score; + }); + components.erase(components.begin() + numCC, components.end()); + } + for (const auto& cc : components) { + GaussianEllipse ellipse = FitGaussian(cc, energy, opsin, smooth); + if (ellipse.x < 0.0 || + std::ceil(ellipse.x) >= static_cast<double>(opsin.xsize()) || + ellipse.y < 0.0 || + std::ceil(ellipse.y) >= static_cast<double>(opsin.ysize())) { + continue; + } + if (ellipse.neg_pixels > params.maxNegPixels) continue; + double intensity = 0.21 * ellipse.intensity[0] + + 0.72 * ellipse.intensity[1] + + 0.07 * ellipse.intensity[2]; + double intensitySq = intensity * intensity; + // for (int c = 0; c < 3; c++) { + // intensitySq += ellipse.intensity[c] * ellipse.intensity[c]; + //} + double sqDistMeanMode = (ellipse.x - cc.mode.x) * (ellipse.x - cc.mode.x) + + (ellipse.y - cc.mode.y) * (ellipse.y - cc.mode.y); + if (ellipse.l2_loss < params.maxL2Loss && + ellipse.custom_loss < params.maxCustomLoss && + intensitySq > (params.minIntensity * params.minIntensity) && + sqDistMeanMode < params.maxDistMeanMode * params.maxDistMeanMode) { + size_t x0 = cc.bounds.x0(); + size_t y0 = cc.bounds.y0(); + dots.emplace_back(); + dots.back().second.emplace_back(x0, y0); + QuantizedPatch& patch = dots.back().first; + patch.xsize = cc.bounds.xsize(); + patch.ysize = cc.bounds.ysize(); + for (size_t y = 0; y < patch.ysize; y++) { + for (size_t x = 0; x < patch.xsize; x++) { + for (size_t c = 0; c < 3; c++) { + patch.fpixels[c][y * patch.xsize + x] = + opsin.ConstPlaneRow(c, y0 + y)[x0 + x] - + smooth.ConstPlaneRow(c, y0 + y)[x0 + x]; + } + } + } + } + } + return dots; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_detect_dots.h b/third_party/jpeg-xl/lib/jxl/enc_detect_dots.h new file mode 100644 index 0000000000..c3071d9a2f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_detect_dots.h @@ -0,0 +1,67 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// We attempt to remove dots, or speckle from images using Gaussian blur. +#ifndef LIB_JXL_ENC_DETECT_DOTS_H_ +#define LIB_JXL_ENC_DETECT_DOTS_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <vector> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/image.h" + +namespace jxl { + +struct GaussianDetectParams { + double t_high = 0; // at least one pixel must have larger energy than t_high + double t_low = 0; // all pixels must have a larger energy than tLow + uint32_t maxWinSize = 0; // discard dots larger than this containing window + double maxL2Loss = 0; + double maxCustomLoss = 0; + double minIntensity = 0; // If the intensity is too low, discard it + double maxDistMeanMode = 0; // The mean and the mode must be close + size_t maxNegPixels = 0; // Maximum number of negative pixel + size_t minScore = 0; + size_t maxCC = 50; // Maximum number of CC to keep + size_t percCC = 15; // Percentage in [0,100] of CC to keep +}; + +// Ellipse Quantization Params +struct EllipseQuantParams { + size_t xsize; // Image size in x + size_t ysize; // Image size in y + size_t qPosition; // Position quantization delta + // Quantization for the Gaussian sigma parameters + double minSigma; + double maxSigma; + size_t qSigma; // number of quantization levels + // Quantization for the rotation angle (between -pi and pi) + size_t qAngle; + // Quantization for the intensity + std::array<double, 3> minIntensity; + std::array<double, 3> maxIntensity; + std::array<size_t, 3> qIntensity; // number of quantization levels + // Extra parameters for the encoding + bool subtractQuantized; // Should we subtract quantized or detected dots? + float ytox; + float ytob; + + void QuantPositionSize(size_t* xsize, size_t* ysize) const; +}; + +// Detects dots in XYB image. +std::vector<PatchInfo> DetectGaussianEllipses( + const Image3F& opsin, const GaussianDetectParams& params, + const EllipseQuantParams& qParams, ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_ENC_DETECT_DOTS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.cc b/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.cc new file mode 100644 index 0000000000..a5b1af63b2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.cc @@ -0,0 +1,71 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_dot_dictionary.h" + +#include <stddef.h> +#include <string.h> + +#include <array> +#include <utility> + +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_detect_dots.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// Private implementation of Dictionary Encode/Decode +namespace { + +/* Quantization constants for Ellipse dots */ +const size_t kEllipsePosQ = 2; // Quantization level for the position +const double kEllipseMinSigma = 0.1; // Minimum sigma value +const double kEllipseMaxSigma = 3.1; // Maximum Sigma value +const size_t kEllipseSigmaQ = 16; // Number of quantization levels for sigma +const size_t kEllipseAngleQ = 8; // Quantization level for the angle +// TODO(user): fix these values. +const std::array<double, 3> kEllipseMinIntensity{{-0.05, 0.0, -0.5}}; +const std::array<double, 3> kEllipseMaxIntensity{{0.05, 1.0, 0.4}}; +const std::array<size_t, 3> kEllipseIntensityQ{{10, 36, 10}}; +} // namespace + +std::vector<PatchInfo> FindDotDictionary(const CompressParams& cparams, + const Image3F& opsin, + const ColorCorrelationMap& cmap, + ThreadPool* pool) { + if (ApplyOverride(cparams.dots, + cparams.butteraugli_distance >= kMinButteraugliForDots)) { + GaussianDetectParams ellipse_params; + ellipse_params.t_high = 0.04; + ellipse_params.t_low = 0.02; + ellipse_params.maxWinSize = 5; + ellipse_params.maxL2Loss = 0.005; + ellipse_params.maxCustomLoss = 300; + ellipse_params.minIntensity = 0.12; + ellipse_params.maxDistMeanMode = 1.0; + ellipse_params.maxNegPixels = 0; + ellipse_params.minScore = 12.0; + ellipse_params.maxCC = 100; + ellipse_params.percCC = 100; + EllipseQuantParams qParams{ + opsin.xsize(), opsin.ysize(), kEllipsePosQ, + kEllipseMinSigma, kEllipseMaxSigma, kEllipseSigmaQ, + kEllipseAngleQ, kEllipseMinIntensity, kEllipseMaxIntensity, + kEllipseIntensityQ, kEllipsePosQ <= 5, cmap.YtoXRatio(0), + cmap.YtoBRatio(0)}; + + return DetectGaussianEllipses(opsin, ellipse_params, qParams, pool); + } + return {}; +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.h b/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.h new file mode 100644 index 0000000000..2ba4393f30 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_dot_dictionary.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_DOT_DICTIONARY_H_ +#define LIB_JXL_ENC_DOT_DICTIONARY_H_ + +// Dots are stored in a dictionary to avoid storing similar dots multiple +// times. + +#include <stddef.h> + +#include <vector> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/image.h" + +namespace jxl { + +std::vector<PatchInfo> FindDotDictionary(const CompressParams& cparams, + const Image3F& opsin, + const ColorCorrelationMap& cmap, + ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_ENC_DOT_DICTIONARY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.cc b/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.cc new file mode 100644 index 0000000000..07601a2221 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.cc @@ -0,0 +1,272 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_entropy_coder.h" + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_entropy_coder.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_context_map.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/pack_signed.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::AndNot; +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::GetLane; + +// Returns number of non-zero coefficients (but skip LLF). +// We cannot rely on block[] being all-zero bits, so first truncate to integer. +// Also writes the per-8x8 block nzeros starting at nzeros_pos. +int32_t NumNonZeroExceptLLF(const size_t cx, const size_t cy, + const AcStrategy acs, const size_t covered_blocks, + const size_t log2_covered_blocks, + const int32_t* JXL_RESTRICT block, + const size_t nzeros_stride, + int32_t* JXL_RESTRICT nzeros_pos) { + const HWY_CAPPED(int32_t, kBlockDim) di; + + const auto zero = Zero(di); + // Add FF..FF for every zero coefficient, negate to get #zeros. + auto neg_sum_zero = zero; + + { + // Mask sufficient for one row of coefficients. + HWY_ALIGN const int32_t + llf_mask_lanes[AcStrategy::kMaxCoeffBlocks * (1 + kBlockDim)] = { + -1, -1, -1, -1}; + // First cx=1,2,4 elements are FF..FF, others 0. + const int32_t* llf_mask_pos = + llf_mask_lanes + AcStrategy::kMaxCoeffBlocks - cx; + + // Rows with LLF: mask out the LLF + for (size_t y = 0; y < cy; y++) { + for (size_t x = 0; x < cx * kBlockDim; x += Lanes(di)) { + const auto llf_mask = LoadU(di, llf_mask_pos + x); + + // LLF counts as zero so we don't include it in nzeros. + const auto coef = + AndNot(llf_mask, Load(di, &block[y * cx * kBlockDim + x])); + + neg_sum_zero = Add(neg_sum_zero, VecFromMask(di, Eq(coef, zero))); + } + } + } + + // Remaining rows: no mask + for (size_t y = cy; y < cy * kBlockDim; y++) { + for (size_t x = 0; x < cx * kBlockDim; x += Lanes(di)) { + const auto coef = Load(di, &block[y * cx * kBlockDim + x]); + neg_sum_zero = Add(neg_sum_zero, VecFromMask(di, Eq(coef, zero))); + } + } + + // We want area - sum_zero, add because neg_sum_zero is already negated. + const int32_t nzeros = + int32_t(cx * cy * kDCTBlockSize) + GetLane(SumOfLanes(di, neg_sum_zero)); + + const int32_t shifted_nzeros = static_cast<int32_t>( + (nzeros + covered_blocks - 1) >> log2_covered_blocks); + // Need non-canonicalized dimensions! + for (size_t y = 0; y < acs.covered_blocks_y(); y++) { + for (size_t x = 0; x < acs.covered_blocks_x(); x++) { + nzeros_pos[x + y * nzeros_stride] = shifted_nzeros; + } + } + + return nzeros; +} + +// Specialization for 8x8, where only top-left is LLF/DC. +// About 1% overall speedup vs. NumNonZeroExceptLLF. +int32_t NumNonZero8x8ExceptDC(const int32_t* JXL_RESTRICT block, + int32_t* JXL_RESTRICT nzeros_pos) { + const HWY_CAPPED(int32_t, kBlockDim) di; + + const auto zero = Zero(di); + // Add FF..FF for every zero coefficient, negate to get #zeros. + auto neg_sum_zero = zero; + + { + // First row has DC, so mask + const size_t y = 0; + HWY_ALIGN const int32_t dc_mask_lanes[kBlockDim] = {-1}; + + for (size_t x = 0; x < kBlockDim; x += Lanes(di)) { + const auto dc_mask = Load(di, dc_mask_lanes + x); + + // DC counts as zero so we don't include it in nzeros. + const auto coef = AndNot(dc_mask, Load(di, &block[y * kBlockDim + x])); + + neg_sum_zero = Add(neg_sum_zero, VecFromMask(di, Eq(coef, zero))); + } + } + + // Remaining rows: no mask + for (size_t y = 1; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x += Lanes(di)) { + const auto coef = Load(di, &block[y * kBlockDim + x]); + neg_sum_zero = Add(neg_sum_zero, VecFromMask(di, Eq(coef, zero))); + } + } + + // We want 64 - sum_zero, add because neg_sum_zero is already negated. + const int32_t nzeros = + int32_t(kDCTBlockSize) + GetLane(SumOfLanes(di, neg_sum_zero)); + + *nzeros_pos = nzeros; + + return nzeros; +} + +// The number of nonzeros of each block is predicted from the top and the left +// blocks, with opportune scaling to take into account the number of blocks of +// each strategy. The predicted number of nonzeros divided by two is used as a +// context; if this number is above 63, a specific context is used. If the +// number of nonzeros of a strategy is above 63, it is written directly using a +// fixed number of bits (that depends on the size of the strategy). +void TokenizeCoefficients(const coeff_order_t* JXL_RESTRICT orders, + const Rect& rect, + const int32_t* JXL_RESTRICT* JXL_RESTRICT ac_rows, + const AcStrategyImage& ac_strategy, + YCbCrChromaSubsampling cs, + Image3I* JXL_RESTRICT tmp_num_nzeroes, + std::vector<Token>* JXL_RESTRICT output, + const ImageB& qdc, const ImageI& qf, + const BlockCtxMap& block_ctx_map) { + const size_t xsize_blocks = rect.xsize(); + const size_t ysize_blocks = rect.ysize(); + output->clear(); + // TODO(user): update the estimate: usually less coefficients are used. + output->reserve(3 * xsize_blocks * ysize_blocks * kDCTBlockSize); + + size_t offset[3] = {}; + const size_t nzeros_stride = tmp_num_nzeroes->PixelsPerRow(); + for (size_t by = 0; by < ysize_blocks; ++by) { + size_t sby[3] = {by >> cs.VShift(0), by >> cs.VShift(1), + by >> cs.VShift(2)}; + int32_t* JXL_RESTRICT row_nzeros[3] = { + tmp_num_nzeroes->PlaneRow(0, sby[0]), + tmp_num_nzeroes->PlaneRow(1, sby[1]), + tmp_num_nzeroes->PlaneRow(2, sby[2]), + }; + const int32_t* JXL_RESTRICT row_nzeros_top[3] = { + sby[0] == 0 ? nullptr : tmp_num_nzeroes->ConstPlaneRow(0, sby[0] - 1), + sby[1] == 0 ? nullptr : tmp_num_nzeroes->ConstPlaneRow(1, sby[1] - 1), + sby[2] == 0 ? nullptr : tmp_num_nzeroes->ConstPlaneRow(2, sby[2] - 1), + }; + const uint8_t* JXL_RESTRICT row_qdc = + qdc.ConstRow(rect.y0() + by) + rect.x0(); + const int32_t* JXL_RESTRICT row_qf = rect.ConstRow(qf, by); + AcStrategyRow acs_row = ac_strategy.ConstRow(rect, by); + for (size_t bx = 0; bx < xsize_blocks; ++bx) { + AcStrategy acs = acs_row[bx]; + if (!acs.IsFirstBlock()) continue; + size_t sbx[3] = {bx >> cs.HShift(0), bx >> cs.HShift(1), + bx >> cs.HShift(2)}; + size_t cx = acs.covered_blocks_x(); + size_t cy = acs.covered_blocks_y(); + const size_t covered_blocks = cx * cy; // = #LLF coefficients + const size_t log2_covered_blocks = + Num0BitsBelowLS1Bit_Nonzero(covered_blocks); + const size_t size = covered_blocks * kDCTBlockSize; + + CoefficientLayout(&cy, &cx); // swap cx/cy to canonical order + + for (int c : {1, 0, 2}) { + if (sbx[c] << cs.HShift(c) != bx) continue; + if (sby[c] << cs.VShift(c) != by) continue; + const int32_t* JXL_RESTRICT block = ac_rows[c] + offset[c]; + + int32_t nzeros = + (covered_blocks == 1) + ? NumNonZero8x8ExceptDC(block, row_nzeros[c] + sbx[c]) + : NumNonZeroExceptLLF(cx, cy, acs, covered_blocks, + log2_covered_blocks, block, nzeros_stride, + row_nzeros[c] + sbx[c]); + + int ord = kStrategyOrder[acs.RawStrategy()]; + const coeff_order_t* JXL_RESTRICT order = + &orders[CoeffOrderOffset(ord, c)]; + + int32_t predicted_nzeros = + PredictFromTopAndLeft(row_nzeros_top[c], row_nzeros[c], sbx[c], 32); + size_t block_ctx = + block_ctx_map.Context(row_qdc[bx], row_qf[sbx[c]], ord, c); + const int32_t nzero_ctx = + block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx); + + output->emplace_back(nzero_ctx, nzeros); + const size_t histo_offset = + block_ctx_map.ZeroDensityContextsOffset(block_ctx); + // Skip LLF. + size_t prev = (nzeros > static_cast<ssize_t>(size / 16) ? 0 : 1); + for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) { + int32_t coeff = block[order[k]]; + size_t ctx = + histo_offset + ZeroDensityContext(nzeros, k, covered_blocks, + log2_covered_blocks, prev); + uint32_t u_coeff = PackSigned(coeff); + output->emplace_back(ctx, u_coeff); + prev = coeff != 0; + nzeros -= prev; + } + JXL_DASSERT(nzeros == 0); + offset[c] += size; + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(TokenizeCoefficients); +void TokenizeCoefficients(const coeff_order_t* JXL_RESTRICT orders, + const Rect& rect, + const int32_t* JXL_RESTRICT* JXL_RESTRICT ac_rows, + const AcStrategyImage& ac_strategy, + YCbCrChromaSubsampling cs, + Image3I* JXL_RESTRICT tmp_num_nzeroes, + std::vector<Token>* JXL_RESTRICT output, + const ImageB& qdc, const ImageI& qf, + const BlockCtxMap& block_ctx_map) { + return HWY_DYNAMIC_DISPATCH(TokenizeCoefficients)( + orders, rect, ac_rows, ac_strategy, cs, tmp_num_nzeroes, output, qdc, qf, + block_ctx_map); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.h b/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.h new file mode 100644 index 0000000000..7dfc71c726 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_entropy_coder.h @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ENTROPY_CODER_H_ +#define LIB_JXL_ENC_ENTROPY_CODER_H_ + +#include <stddef.h> +#include <stdint.h> +#include <stdlib.h> +#include <string.h> +#include <sys/types.h> + +#include <memory> +#include <utility> +#include <vector> + +#include "lib/jxl/ac_context.h" // BlockCtxMap +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/frame_header.h" // YCbCrChromaSubsampling +#include "lib/jxl/image.h" + +// Entropy coding and context modeling of DC and AC coefficients, as well as AC +// strategy and quantization field. + +namespace jxl { + +// Generate DCT NxN quantized AC values tokens. +// Only the subset "rect" [in units of blocks] within all images. +// See also DecodeACVarBlock. +void TokenizeCoefficients(const coeff_order_t* JXL_RESTRICT orders, + const Rect& rect, + const int32_t* JXL_RESTRICT* JXL_RESTRICT ac_rows, + const AcStrategyImage& ac_strategy, + YCbCrChromaSubsampling cs, + Image3I* JXL_RESTRICT tmp_num_nzeroes, + std::vector<Token>* JXL_RESTRICT output, + const ImageB& qdc, const ImageI& qf, + const BlockCtxMap& block_ctx_map); + +} // namespace jxl + +#endif // LIB_JXL_ENC_ENTROPY_CODER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_external_image.cc b/third_party/jpeg-xl/lib/jxl/enc_external_image.cc new file mode 100644 index 0000000000..680323e79a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_external_image.cc @@ -0,0 +1,250 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_external_image.h" + +#include <jxl/types.h> +#include <string.h> + +#include <algorithm> +#include <array> +#include <atomic> +#include <functional> +#include <utility> +#include <vector> + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/float.h" +#include "lib/jxl/base/printf_macros.h" + +namespace jxl { +namespace { + +size_t JxlDataTypeBytes(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_UINT8: + return 1; + case JXL_TYPE_UINT16: + return 2; + case JXL_TYPE_FLOAT16: + return 2; + case JXL_TYPE_FLOAT: + return 4; + default: + return 0; + } +} + +} // namespace + +Status ConvertFromExternalNoSizeCheck(const uint8_t* data, size_t xsize, + size_t ysize, size_t stride, + size_t bits_per_sample, + JxlPixelFormat format, size_t c, + ThreadPool* pool, ImageF* channel) { + if (format.data_type == JXL_TYPE_UINT8) { + JXL_RETURN_IF_ERROR(bits_per_sample > 0 && bits_per_sample <= 8); + } else if (format.data_type == JXL_TYPE_UINT16) { + JXL_RETURN_IF_ERROR(bits_per_sample > 8 && bits_per_sample <= 16); + } else if (format.data_type != JXL_TYPE_FLOAT16 && + format.data_type != JXL_TYPE_FLOAT) { + JXL_FAILURE("unsupported pixel format data type %d", format.data_type); + } + + JXL_ASSERT(channel->xsize() == xsize); + JXL_ASSERT(channel->ysize() == ysize); + + size_t bytes_per_channel = JxlDataTypeBytes(format.data_type); + size_t bytes_per_pixel = format.num_channels * bytes_per_channel; + size_t pixel_offset = c * bytes_per_channel; + // Only for uint8/16. + float scale = 1.0f; + if (format.data_type == JXL_TYPE_UINT8) { + // We will do an integer multiplication by 257 in LoadFloatRow so that a + // UINT8 value and the corresponding UINT16 value convert to the same float + scale = 1.0f / (257 * ((1ull << bits_per_sample) - 1)); + } else { + scale = 1.0f / ((1ull << bits_per_sample) - 1); + } + + const bool little_endian = + format.endianness == JXL_LITTLE_ENDIAN || + (format.endianness == JXL_NATIVE_ENDIAN && IsLittleEndian()); + + std::atomic<size_t> error_count = {0}; + + const auto convert_row = [&](const uint32_t task, size_t /*thread*/) { + const size_t y = task; + size_t offset = y * stride + pixel_offset; + float* JXL_RESTRICT row_out = channel->Row(y); + const auto save_value = [&](size_t index, float value) { + row_out[index] = value; + }; + if (!LoadFloatRow(data + offset, xsize, bytes_per_pixel, format.data_type, + little_endian, scale, save_value)) { + error_count++; + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, static_cast<uint32_t>(ysize), + ThreadPool::NoInit, convert_row, + "ConvertExtraChannel")); + + if (error_count) { + JXL_FAILURE("unsupported pixel format data type"); + } + + return true; +} + +Status ConvertFromExternalNoSizeCheck(const uint8_t* data, size_t xsize, + size_t ysize, size_t stride, + const ColorEncoding& c_current, + size_t color_channels, + size_t bits_per_sample, + JxlPixelFormat format, ThreadPool* pool, + ImageBundle* ib) { + bool has_alpha = format.num_channels == 2 || format.num_channels == 4; + if (format.num_channels < color_channels) { + return JXL_FAILURE("Expected %" PRIuS + " color channels, received only %u channels", + color_channels, format.num_channels); + } + + Image3F color(xsize, ysize); + for (size_t c = 0; c < color_channels; ++c) { + JXL_RETURN_IF_ERROR(ConvertFromExternalNoSizeCheck( + data, xsize, ysize, stride, bits_per_sample, format, c, pool, + &color.Plane(c))); + } + if (color_channels == 1) { + CopyImageTo(color.Plane(0), &color.Plane(1)); + CopyImageTo(color.Plane(0), &color.Plane(2)); + } + ib->SetFromImage(std::move(color), c_current); + + // Passing an interleaved image with an alpha channel to an image that doesn't + // have alpha channel just discards the passed alpha channel. + if (has_alpha && ib->HasAlpha()) { + ImageF alpha(xsize, ysize); + JXL_RETURN_IF_ERROR(ConvertFromExternalNoSizeCheck( + data, xsize, ysize, stride, bits_per_sample, format, + format.num_channels - 1, pool, &alpha)); + ib->SetAlpha(std::move(alpha)); + } else if (!has_alpha && ib->HasAlpha()) { + // if alpha is not passed, but it is expected, then assume + // it is all-opaque + ImageF alpha(xsize, ysize); + FillImage(1.0f, &alpha); + ib->SetAlpha(std::move(alpha)); + } + + return true; +} + +Status ConvertFromExternal(const uint8_t* data, size_t size, size_t xsize, + size_t ysize, size_t bits_per_sample, + JxlPixelFormat format, size_t c, ThreadPool* pool, + ImageF* channel) { + size_t bytes_per_channel = JxlDataTypeBytes(format.data_type); + size_t bytes_per_pixel = format.num_channels * bytes_per_channel; + const size_t last_row_size = xsize * bytes_per_pixel; + const size_t align = format.align; + const size_t row_size = + (align > 1 ? jxl::DivCeil(last_row_size, align) * align : last_row_size); + const size_t bytes_to_read = row_size * (ysize - 1) + last_row_size; + if (xsize == 0 || ysize == 0) return JXL_FAILURE("Empty image"); + if (size > 0 && size < bytes_to_read) { + return JXL_FAILURE("Buffer size is too small, expected: %" PRIuS + " got: %" PRIuS " (Image: %" PRIuS "x%" PRIuS + "x%u, bytes_per_channel: %" PRIuS ")", + bytes_to_read, size, xsize, ysize, format.num_channels, + bytes_per_channel); + } + // Too large buffer is likely an application bug, so also fail for that. + // Do allow padding to stride in last row though. + if (size > row_size * ysize) { + return JXL_FAILURE("Buffer size is too large"); + } + return ConvertFromExternalNoSizeCheck( + data, xsize, ysize, row_size, bits_per_sample, format, c, pool, channel); +} +Status ConvertFromExternal(Span<const uint8_t> bytes, size_t xsize, + size_t ysize, const ColorEncoding& c_current, + size_t color_channels, size_t bits_per_sample, + JxlPixelFormat format, ThreadPool* pool, + ImageBundle* ib) { + bool has_alpha = format.num_channels == 2 || format.num_channels == 4; + if (format.num_channels < color_channels) { + return JXL_FAILURE("Expected %" PRIuS + " color channels, received only %u channels", + color_channels, format.num_channels); + } + + Image3F color(xsize, ysize); + for (size_t c = 0; c < color_channels; ++c) { + JXL_RETURN_IF_ERROR(ConvertFromExternal(bytes.data(), bytes.size(), xsize, + ysize, bits_per_sample, format, c, + pool, &color.Plane(c))); + } + if (color_channels == 1) { + CopyImageTo(color.Plane(0), &color.Plane(1)); + CopyImageTo(color.Plane(0), &color.Plane(2)); + } + ib->SetFromImage(std::move(color), c_current); + + // Passing an interleaved image with an alpha channel to an image that doesn't + // have alpha channel just discards the passed alpha channel. + if (has_alpha && ib->HasAlpha()) { + ImageF alpha(xsize, ysize); + JXL_RETURN_IF_ERROR(ConvertFromExternal( + bytes.data(), bytes.size(), xsize, ysize, bits_per_sample, format, + format.num_channels - 1, pool, &alpha)); + ib->SetAlpha(std::move(alpha)); + } else if (!has_alpha && ib->HasAlpha()) { + // if alpha is not passed, but it is expected, then assume + // it is all-opaque + ImageF alpha(xsize, ysize); + FillImage(1.0f, &alpha); + ib->SetAlpha(std::move(alpha)); + } + + return true; +} + +Status ConvertFromExternal(Span<const uint8_t> bytes, size_t xsize, + size_t ysize, const ColorEncoding& c_current, + size_t bits_per_sample, JxlPixelFormat format, + ThreadPool* pool, ImageBundle* ib) { + return ConvertFromExternal(bytes, xsize, ysize, c_current, + c_current.Channels(), bits_per_sample, format, + pool, ib); +} + +Status BufferToImageF(const JxlPixelFormat& pixel_format, size_t xsize, + size_t ysize, const void* buffer, size_t size, + ThreadPool* pool, ImageF* channel) { + size_t bitdepth = JxlDataTypeBytes(pixel_format.data_type) * kBitsPerByte; + return ConvertFromExternal(reinterpret_cast<const uint8_t*>(buffer), size, + xsize, ysize, bitdepth, pixel_format, 0, pool, + channel); +} + +Status BufferToImageBundle(const JxlPixelFormat& pixel_format, uint32_t xsize, + uint32_t ysize, const void* buffer, size_t size, + jxl::ThreadPool* pool, + const jxl::ColorEncoding& c_current, + jxl::ImageBundle* ib) { + size_t bitdepth = JxlDataTypeBytes(pixel_format.data_type) * kBitsPerByte; + JXL_RETURN_IF_ERROR(ConvertFromExternal( + jxl::Bytes(static_cast<const uint8_t*>(buffer), size), xsize, ysize, + c_current, bitdepth, pixel_format, pool, ib)); + ib->VerifyMetadata(); + + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_external_image.h b/third_party/jpeg-xl/lib/jxl/enc_external_image.h new file mode 100644 index 0000000000..0d0fb75c5d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_external_image.h @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_EXTERNAL_IMAGE_H_ +#define LIB_JXL_ENC_EXTERNAL_IMAGE_H_ + +// Interleaved image for color transforms and Codec. + +#include <jxl/types.h> +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { +Status ConvertFromExternalNoSizeCheck(const uint8_t* data, size_t xsize, + size_t ysize, size_t stride, + size_t bits_per_sample, + JxlPixelFormat format, size_t c, + ThreadPool* pool, ImageF* channel); + +Status ConvertFromExternalNoSizeCheck(const uint8_t* data, size_t xsize, + size_t ysize, size_t stride, + const ColorEncoding& c_current, + size_t color_channels, + size_t bits_per_sample, + JxlPixelFormat format, ThreadPool* pool, + ImageBundle* ib); + +Status ConvertFromExternal(const uint8_t* data, size_t size, size_t xsize, + size_t ysize, size_t bits_per_sample, + JxlPixelFormat format, size_t c, ThreadPool* pool, + ImageF* channel); + +// Convert an interleaved pixel buffer to the internal ImageBundle +// representation. This is the opposite of ConvertToExternal(). +Status ConvertFromExternal(Span<const uint8_t> bytes, size_t xsize, + size_t ysize, const ColorEncoding& c_current, + size_t color_channels, size_t bits_per_sample, + JxlPixelFormat format, ThreadPool* pool, + ImageBundle* ib); +Status ConvertFromExternal(Span<const uint8_t> bytes, size_t xsize, + size_t ysize, const ColorEncoding& c_current, + size_t bits_per_sample, JxlPixelFormat format, + ThreadPool* pool, ImageBundle* ib); +Status BufferToImageF(const JxlPixelFormat& pixel_format, size_t xsize, + size_t ysize, const void* buffer, size_t size, + ThreadPool* pool, ImageF* channel); +Status BufferToImageBundle(const JxlPixelFormat& pixel_format, uint32_t xsize, + uint32_t ysize, const void* buffer, size_t size, + jxl::ThreadPool* pool, + const jxl::ColorEncoding& c_current, + jxl::ImageBundle* ib); + +} // namespace jxl + +#endif // LIB_JXL_ENC_EXTERNAL_IMAGE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_external_image_gbench.cc b/third_party/jpeg-xl/lib/jxl/enc_external_image_gbench.cc new file mode 100644 index 0000000000..64e9cf6ad5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_external_image_gbench.cc @@ -0,0 +1,45 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { +namespace { + +// Encoder case, deinterleaves a buffer. +void BM_EncExternalImage_ConvertImageRGBA(benchmark::State& state) { + const size_t kNumIter = 5; + size_t xsize = state.range(); + size_t ysize = state.range(); + + ImageMetadata im; + im.SetAlphaBits(8); + ImageBundle ib(&im); + + std::vector<uint8_t> interleaved(xsize * ysize * 4); + JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + for (auto _ : state) { + for (size_t i = 0; i < kNumIter; ++i) { + JXL_CHECK(ConvertFromExternal( + Bytes(interleaved.data(), interleaved.size()), xsize, ysize, + /*c_current=*/ColorEncoding::SRGB(), + /*bits_per_sample=*/8, format, + /*pool=*/nullptr, &ib)); + } + } + + // Pixels per second. + state.SetItemsProcessed(kNumIter * state.iterations() * xsize * ysize); + state.SetBytesProcessed(kNumIter * state.iterations() * interleaved.size()); +} + +BENCHMARK(BM_EncExternalImage_ConvertImageRGBA) + ->RangeMultiplier(2) + ->Range(256, 2048); + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_external_image_test.cc b/third_party/jpeg-xl/lib/jxl/enc_external_image_test.cc new file mode 100644 index 0000000000..de2e15ed16 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_external_image_test.cc @@ -0,0 +1,78 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_external_image.h" + +#include <array> +#include <new> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +#if !defined(JXL_CRASH_ON_ERROR) +TEST(ExternalImageTest, InvalidSize) { + ImageMetadata im; + im.SetAlphaBits(8); + ImageBundle ib(&im); + + JxlPixelFormat format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + const uint8_t buf[10 * 100 * 8] = {}; + EXPECT_FALSE(ConvertFromExternal(Bytes(buf, 10), /*xsize=*/10, /*ysize=*/100, + /*c_current=*/ColorEncoding::SRGB(), + /*bits_per_sample=*/16, format, nullptr, + &ib)); + EXPECT_FALSE(ConvertFromExternal( + Bytes(buf, sizeof(buf) - 1), /*xsize=*/10, /*ysize=*/100, + /*c_current=*/ColorEncoding::SRGB(), + /*bits_per_sample=*/16, format, nullptr, &ib)); + EXPECT_TRUE( + ConvertFromExternal(Bytes(buf, sizeof(buf)), /*xsize=*/10, + /*ysize=*/100, /*c_current=*/ColorEncoding::SRGB(), + /*bits_per_sample=*/16, format, nullptr, &ib)); +} +#endif + +TEST(ExternalImageTest, AlphaMissing) { + ImageMetadata im; + im.SetAlphaBits(0); // No alpha + ImageBundle ib(&im); + + const size_t xsize = 10; + const size_t ysize = 20; + const uint8_t buf[xsize * ysize * 4] = {}; + + JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_BIG_ENDIAN, 0}; + // has_alpha is true but the ImageBundle has no alpha. Alpha channel should + // be ignored. + EXPECT_TRUE(ConvertFromExternal(Bytes(buf, sizeof(buf)), xsize, ysize, + /*c_current=*/ColorEncoding::SRGB(), + /*bits_per_sample=*/8, format, nullptr, &ib)); + EXPECT_FALSE(ib.HasAlpha()); +} + +TEST(ExternalImageTest, AlphaPremultiplied) { + ImageMetadata im; + im.SetAlphaBits(8, true); + + ImageBundle ib(&im); + const size_t xsize = 10; + const size_t ysize = 20; + const size_t size = xsize * ysize * 8; + const uint8_t buf[size] = {}; + + JxlPixelFormat format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + EXPECT_TRUE(BufferToImageBundle(format, xsize, ysize, buf, size, nullptr, + ColorEncoding::SRGB(), &ib)); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc b/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc new file mode 100644 index 0000000000..b32d2478e0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc @@ -0,0 +1,4213 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef FJXL_SELF_INCLUDE + +#include "lib/jxl/enc_fast_lossless.h" + +#include <assert.h> +#include <stdint.h> +#include <string.h> + +#include <algorithm> +#include <array> +#include <limits> +#include <memory> +#include <vector> + +#if !FJXL_STANDALONE +#include "lib/jxl/encode_internal.h" +#endif + +// Enable NEON and AVX2/AVX512 if not asked to do otherwise and the compilers +// support it. +#if defined(__aarch64__) || defined(_M_ARM64) +#include <arm_neon.h> + +#ifndef FJXL_ENABLE_NEON +#define FJXL_ENABLE_NEON 1 +#endif + +#elif (defined(__x86_64__) || defined(_M_X64)) && !defined(_MSC_VER) +#include <immintrin.h> + +// manually add _mm512_cvtsi512_si32 definition if missing +// (e.g. with Xcode on macOS Mojave) +// copied from gcc 11.1.0 include/avx512fintrin.h line 14367-14373 +#if defined(__clang__) && \ + ((!defined(__apple_build_version__) && __clang_major__ < 10) || \ + (defined(__apple_build_version__) && __apple_build_version__ < 12000032)) +inline int __attribute__((__gnu_inline__, __always_inline__, __artificial__)) +_mm512_cvtsi512_si32(__m512i __A) { + __v16si __B = (__v16si)__A; + return __B[0]; +} +#endif + +// TODO(veluca): MSVC support for dynamic dispatch. +#if defined(__clang__) || defined(__GNUC__) + +#ifndef FJXL_ENABLE_AVX2 +#define FJXL_ENABLE_AVX2 1 +#endif + +#ifndef FJXL_ENABLE_AVX512 +// On clang-7 or earlier, and gcc-10 or earlier, AVX512 seems broken. +#if (defined(__clang__) && \ + (!defined(__apple_build_version__) && __clang_major__ > 7) || \ + (defined(__apple_build_version__) && \ + __apple_build_version__ > 10010046)) || \ + (defined(__GNUC__) && __GNUC__ > 10) +#define FJXL_ENABLE_AVX512 1 +#endif +#endif + +#endif + +#endif + +#ifndef FJXL_ENABLE_NEON +#define FJXL_ENABLE_NEON 0 +#endif + +#ifndef FJXL_ENABLE_AVX2 +#define FJXL_ENABLE_AVX2 0 +#endif + +#ifndef FJXL_ENABLE_AVX512 +#define FJXL_ENABLE_AVX512 0 +#endif + +namespace { +#if defined(_MSC_VER) && !defined(__clang__) +#define FJXL_INLINE __forceinline +FJXL_INLINE uint32_t FloorLog2(uint32_t v) { + unsigned long index; + _BitScanReverse(&index, v); + return index; +} +FJXL_INLINE uint32_t CtzNonZero(uint64_t v) { + unsigned long index; + _BitScanForward(&index, v); + return index; +} +#else +#define FJXL_INLINE inline __attribute__((always_inline)) +FJXL_INLINE uint32_t FloorLog2(uint32_t v) { + return v ? 31 - __builtin_clz(v) : 0; +} +FJXL_INLINE uint32_t CtzNonZero(uint64_t v) { return __builtin_ctzll(v); } +#endif + +// Compiles to a memcpy on little-endian systems. +FJXL_INLINE void StoreLE64(uint8_t* tgt, uint64_t data) { +#if (!defined(__BYTE_ORDER__) || (__BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__)) + for (int i = 0; i < 8; i++) { + tgt[i] = (data >> (i * 8)) & 0xFF; + } +#else + memcpy(tgt, &data, 8); +#endif +} + +FJXL_INLINE size_t AddBits(uint32_t count, uint64_t bits, uint8_t* data_buf, + size_t& bits_in_buffer, uint64_t& bit_buffer) { + bit_buffer |= bits << bits_in_buffer; + bits_in_buffer += count; + StoreLE64(data_buf, bit_buffer); + size_t bytes_in_buffer = bits_in_buffer / 8; + bits_in_buffer -= bytes_in_buffer * 8; + bit_buffer >>= bytes_in_buffer * 8; + return bytes_in_buffer; +} + +struct BitWriter { + void Allocate(size_t maximum_bit_size) { + assert(data == nullptr); + // Leave some padding. + data.reset(static_cast<uint8_t*>(malloc(maximum_bit_size / 8 + 64))); + } + + void Write(uint32_t count, uint64_t bits) { + bytes_written += AddBits(count, bits, data.get() + bytes_written, + bits_in_buffer, buffer); + } + + void ZeroPadToByte() { + if (bits_in_buffer != 0) { + Write(8 - bits_in_buffer, 0); + } + } + + FJXL_INLINE void WriteMultiple(const uint64_t* nbits, const uint64_t* bits, + size_t n) { + // Necessary because Write() is only guaranteed to work with <=56 bits. + // Trying to SIMD-fy this code results in lower speed (and definitely less + // clarity). + { + for (size_t i = 0; i < n; i++) { + this->buffer |= bits[i] << this->bits_in_buffer; + memcpy(this->data.get() + this->bytes_written, &this->buffer, 8); + uint64_t shift = 64 - this->bits_in_buffer; + this->bits_in_buffer += nbits[i]; + // This `if` seems to be faster than using ternaries. + if (this->bits_in_buffer >= 64) { + uint64_t next_buffer = bits[i] >> shift; + this->buffer = next_buffer; + this->bits_in_buffer -= 64; + this->bytes_written += 8; + } + } + memcpy(this->data.get() + this->bytes_written, &this->buffer, 8); + size_t bytes_in_buffer = this->bits_in_buffer / 8; + this->bits_in_buffer -= bytes_in_buffer * 8; + this->buffer >>= bytes_in_buffer * 8; + this->bytes_written += bytes_in_buffer; + } + } + + std::unique_ptr<uint8_t[], void (*)(void*)> data = {nullptr, free}; + size_t bytes_written = 0; + size_t bits_in_buffer = 0; + uint64_t buffer = 0; +}; + +size_t SectionSize(const std::array<BitWriter, 4>& group_data) { + size_t sz = 0; + for (size_t j = 0; j < 4; j++) { + const auto& writer = group_data[j]; + sz += writer.bytes_written * 8 + writer.bits_in_buffer; + } + sz = (sz + 7) / 8; + return sz; +} + +constexpr size_t kMaxFrameHeaderSize = 5; + +constexpr size_t kGroupSizeOffset[4] = { + static_cast<size_t>(0), + static_cast<size_t>(1024), + static_cast<size_t>(17408), + static_cast<size_t>(4211712), +}; +constexpr size_t kTOCBits[4] = {12, 16, 24, 32}; + +size_t TOCBucket(size_t group_size) { + size_t bucket = 0; + while (bucket < 3 && group_size >= kGroupSizeOffset[bucket + 1]) ++bucket; + return bucket; +} + +size_t TOCSize(const std::vector<size_t>& group_sizes) { + size_t toc_bits = 0; + for (size_t i = 0; i < group_sizes.size(); i++) { + toc_bits += kTOCBits[TOCBucket(group_sizes[i])]; + } + return (toc_bits + 7) / 8; +} + +size_t FrameHeaderSize(bool have_alpha, bool is_last) { + size_t nbits = 28 + (have_alpha ? 4 : 0) + (is_last ? 0 : 2); + return (nbits + 7) / 8; +} + +void ComputeAcGroupDataOffset(size_t dc_global_size, size_t num_dc_groups, + size_t num_ac_groups, size_t& min_dc_global_size, + size_t& ac_group_offset) { + // Max AC group size is 768 kB, so max AC group TOC bits is 24. + size_t ac_toc_max_bits = num_ac_groups * 24; + size_t ac_toc_min_bits = num_ac_groups * 12; + size_t max_padding = 1 + (ac_toc_max_bits - ac_toc_min_bits + 7) / 8; + min_dc_global_size = dc_global_size; + size_t dc_global_bucket = TOCBucket(min_dc_global_size); + while (TOCBucket(min_dc_global_size + max_padding) > dc_global_bucket) { + dc_global_bucket = TOCBucket(min_dc_global_size + max_padding); + min_dc_global_size = kGroupSizeOffset[dc_global_bucket]; + } + assert(TOCBucket(min_dc_global_size) == dc_global_bucket); + assert(TOCBucket(min_dc_global_size + max_padding) == dc_global_bucket); + size_t max_toc_bits = + kTOCBits[dc_global_bucket] + 12 * (1 + num_dc_groups) + ac_toc_max_bits; + size_t max_toc_size = (max_toc_bits + 7) / 8; + ac_group_offset = kMaxFrameHeaderSize + max_toc_size + min_dc_global_size; +} + +size_t ComputeDcGlobalPadding(const std::vector<size_t>& group_sizes, + size_t ac_group_data_offset, + size_t min_dc_global_size, bool have_alpha, + bool is_last) { + std::vector<size_t> new_group_sizes = group_sizes; + new_group_sizes[0] = min_dc_global_size; + size_t toc_size = TOCSize(new_group_sizes); + size_t actual_offset = + FrameHeaderSize(have_alpha, is_last) + toc_size + group_sizes[0]; + return ac_group_data_offset - actual_offset; +} + +constexpr size_t kNumRawSymbols = 19; +constexpr size_t kNumLZ77 = 33; +constexpr size_t kLZ77CacheSize = 32; + +constexpr size_t kLZ77Offset = 224; +constexpr size_t kLZ77MinLength = 7; + +void EncodeHybridUintLZ77(uint32_t value, uint32_t* token, uint32_t* nbits, + uint32_t* bits) { + // 400 config + uint32_t n = FloorLog2(value); + *token = value < 16 ? value : 16 + n - 4; + *nbits = value < 16 ? 0 : n; + *bits = value < 16 ? 0 : value - (1 << *nbits); +} + +struct PrefixCode { + uint8_t raw_nbits[kNumRawSymbols] = {}; + uint8_t raw_bits[kNumRawSymbols] = {}; + + uint8_t lz77_nbits[kNumLZ77] = {}; + uint16_t lz77_bits[kNumLZ77] = {}; + + uint64_t lz77_cache_bits[kLZ77CacheSize] = {}; + uint8_t lz77_cache_nbits[kLZ77CacheSize] = {}; + + size_t numraw; + + static uint16_t BitReverse(size_t nbits, uint16_t bits) { + constexpr uint16_t kNibbleLookup[16] = { + 0b0000, 0b1000, 0b0100, 0b1100, 0b0010, 0b1010, 0b0110, 0b1110, + 0b0001, 0b1001, 0b0101, 0b1101, 0b0011, 0b1011, 0b0111, 0b1111, + }; + uint16_t rev16 = (kNibbleLookup[bits & 0xF] << 12) | + (kNibbleLookup[(bits >> 4) & 0xF] << 8) | + (kNibbleLookup[(bits >> 8) & 0xF] << 4) | + (kNibbleLookup[bits >> 12]); + return rev16 >> (16 - nbits); + } + + // Create the prefix codes given the code lengths. + // Supports the code lengths being split into two halves. + static void ComputeCanonicalCode(const uint8_t* first_chunk_nbits, + uint8_t* first_chunk_bits, + size_t first_chunk_size, + const uint8_t* second_chunk_nbits, + uint16_t* second_chunk_bits, + size_t second_chunk_size) { + constexpr size_t kMaxCodeLength = 15; + uint8_t code_length_counts[kMaxCodeLength + 1] = {}; + for (size_t i = 0; i < first_chunk_size; i++) { + code_length_counts[first_chunk_nbits[i]]++; + assert(first_chunk_nbits[i] <= kMaxCodeLength); + assert(first_chunk_nbits[i] <= 8); + assert(first_chunk_nbits[i] > 0); + } + for (size_t i = 0; i < second_chunk_size; i++) { + code_length_counts[second_chunk_nbits[i]]++; + assert(second_chunk_nbits[i] <= kMaxCodeLength); + } + + uint16_t next_code[kMaxCodeLength + 1] = {}; + + uint16_t code = 0; + for (size_t i = 1; i < kMaxCodeLength + 1; i++) { + code = (code + code_length_counts[i - 1]) << 1; + next_code[i] = code; + } + + for (size_t i = 0; i < first_chunk_size; i++) { + first_chunk_bits[i] = + BitReverse(first_chunk_nbits[i], next_code[first_chunk_nbits[i]]++); + } + for (size_t i = 0; i < second_chunk_size; i++) { + second_chunk_bits[i] = + BitReverse(second_chunk_nbits[i], next_code[second_chunk_nbits[i]]++); + } + } + + template <typename T> + static void ComputeCodeLengthsNonZeroImpl(const uint64_t* freqs, size_t n, + size_t precision, T infty, + uint8_t* min_limit, + uint8_t* max_limit, + uint8_t* nbits) { + std::vector<T> dynp(((1U << precision) + 1) * (n + 1), infty); + auto d = [&](size_t sym, size_t off) -> T& { + return dynp[sym * ((1 << precision) + 1) + off]; + }; + d(0, 0) = 0; + for (size_t sym = 0; sym < n; sym++) { + for (T bits = min_limit[sym]; bits <= max_limit[sym]; bits++) { + size_t off_delta = 1U << (precision - bits); + for (size_t off = 0; off + off_delta <= (1U << precision); off++) { + d(sym + 1, off + off_delta) = + std::min(d(sym, off) + static_cast<T>(freqs[sym]) * bits, + d(sym + 1, off + off_delta)); + } + } + } + + size_t sym = n; + size_t off = 1U << precision; + + assert(d(sym, off) != infty); + + while (sym-- > 0) { + assert(off > 0); + for (size_t bits = min_limit[sym]; bits <= max_limit[sym]; bits++) { + size_t off_delta = 1U << (precision - bits); + if (off_delta <= off && + d(sym + 1, off) == d(sym, off - off_delta) + freqs[sym] * bits) { + off -= off_delta; + nbits[sym] = bits; + break; + } + } + } + } + + // Computes nbits[i] for i <= n, subject to min_limit[i] <= nbits[i] <= + // max_limit[i] and sum 2**-nbits[i] == 1, so to minimize sum(nbits[i] * + // freqs[i]). + static void ComputeCodeLengthsNonZero(const uint64_t* freqs, size_t n, + uint8_t* min_limit, uint8_t* max_limit, + uint8_t* nbits) { + size_t precision = 0; + size_t shortest_length = 255; + uint64_t freqsum = 0; + for (size_t i = 0; i < n; i++) { + assert(freqs[i] != 0); + freqsum += freqs[i]; + if (min_limit[i] < 1) min_limit[i] = 1; + assert(min_limit[i] <= max_limit[i]); + precision = std::max<size_t>(max_limit[i], precision); + shortest_length = std::min<size_t>(min_limit[i], shortest_length); + } + // If all the minimum limits are greater than 1, shift precision so that we + // behave as if the shortest was 1. + precision -= shortest_length - 1; + uint64_t infty = freqsum * precision; + if (infty < std::numeric_limits<uint32_t>::max() / 2) { + ComputeCodeLengthsNonZeroImpl(freqs, n, precision, + static_cast<uint32_t>(infty), min_limit, + max_limit, nbits); + } else { + ComputeCodeLengthsNonZeroImpl(freqs, n, precision, infty, min_limit, + max_limit, nbits); + } + } + + static constexpr size_t kMaxNumSymbols = + kNumRawSymbols + 1 < kNumLZ77 ? kNumLZ77 : kNumRawSymbols + 1; + static void ComputeCodeLengths(const uint64_t* freqs, size_t n, + const uint8_t* min_limit_in, + const uint8_t* max_limit_in, uint8_t* nbits) { + assert(n <= kMaxNumSymbols); + uint64_t compact_freqs[kMaxNumSymbols]; + uint8_t min_limit[kMaxNumSymbols]; + uint8_t max_limit[kMaxNumSymbols]; + size_t ni = 0; + for (size_t i = 0; i < n; i++) { + if (freqs[i]) { + compact_freqs[ni] = freqs[i]; + min_limit[ni] = min_limit_in[i]; + max_limit[ni] = max_limit_in[i]; + ni++; + } + } + uint8_t num_bits[kMaxNumSymbols] = {}; + ComputeCodeLengthsNonZero(compact_freqs, ni, min_limit, max_limit, + num_bits); + ni = 0; + for (size_t i = 0; i < n; i++) { + nbits[i] = 0; + if (freqs[i]) { + nbits[i] = num_bits[ni++]; + } + } + } + + // Invalid code, used to construct arrays. + PrefixCode() {} + + template <typename BitDepth> + PrefixCode(BitDepth, uint64_t* raw_counts, uint64_t* lz77_counts) { + // "merge" together all the lz77 counts in a single symbol for the level 1 + // table (containing just the raw symbols, up to length 7). + uint64_t level1_counts[kNumRawSymbols + 1]; + memcpy(level1_counts, raw_counts, kNumRawSymbols * sizeof(uint64_t)); + numraw = kNumRawSymbols; + while (numraw > 0 && level1_counts[numraw - 1] == 0) numraw--; + + level1_counts[numraw] = 0; + for (size_t i = 0; i < kNumLZ77; i++) { + level1_counts[numraw] += lz77_counts[i]; + } + uint8_t level1_nbits[kNumRawSymbols + 1] = {}; + ComputeCodeLengths(level1_counts, numraw + 1, BitDepth::kMinRawLength, + BitDepth::kMaxRawLength, level1_nbits); + + uint8_t level2_nbits[kNumLZ77] = {}; + uint8_t min_lengths[kNumLZ77] = {}; + uint8_t l = 15 - level1_nbits[numraw]; + uint8_t max_lengths[kNumLZ77]; + for (size_t i = 0; i < kNumLZ77; i++) { + max_lengths[i] = l; + } + size_t num_lz77 = kNumLZ77; + while (num_lz77 > 0 && lz77_counts[num_lz77 - 1] == 0) num_lz77--; + ComputeCodeLengths(lz77_counts, num_lz77, min_lengths, max_lengths, + level2_nbits); + for (size_t i = 0; i < numraw; i++) { + raw_nbits[i] = level1_nbits[i]; + } + for (size_t i = 0; i < num_lz77; i++) { + lz77_nbits[i] = + level2_nbits[i] ? level1_nbits[numraw] + level2_nbits[i] : 0; + } + + ComputeCanonicalCode(raw_nbits, raw_bits, numraw, lz77_nbits, lz77_bits, + kNumLZ77); + + // Prepare lz77 cache + for (size_t count = 0; count < kLZ77CacheSize; count++) { + unsigned token, nbits, bits; + EncodeHybridUintLZ77(count, &token, &nbits, &bits); + lz77_cache_nbits[count] = lz77_nbits[token] + nbits + raw_nbits[0]; + lz77_cache_bits[count] = + (((bits << lz77_nbits[token]) | lz77_bits[token]) << raw_nbits[0]) | + raw_bits[0]; + } + } + + // Max bits written: 2 + 72 + 95 + 24 + 165 = 286 + void WriteTo(BitWriter* writer) const { + uint64_t code_length_counts[18] = {}; + code_length_counts[17] = 3 + 2 * (kNumLZ77 - 1); + for (size_t i = 0; i < kNumRawSymbols; i++) { + code_length_counts[raw_nbits[i]]++; + } + for (size_t i = 0; i < kNumLZ77; i++) { + code_length_counts[lz77_nbits[i]]++; + } + uint8_t code_length_nbits[18] = {}; + uint8_t code_length_nbits_min[18] = {}; + uint8_t code_length_nbits_max[18] = { + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + }; + ComputeCodeLengths(code_length_counts, 18, code_length_nbits_min, + code_length_nbits_max, code_length_nbits); + writer->Write(2, 0b00); // HSKIP = 0, i.e. don't skip code lengths. + + // As per Brotli RFC. + uint8_t code_length_order[18] = {1, 2, 3, 4, 0, 5, 17, 6, 16, + 7, 8, 9, 10, 11, 12, 13, 14, 15}; + uint8_t code_length_length_nbits[] = {2, 4, 3, 2, 2, 4}; + uint8_t code_length_length_bits[] = {0, 7, 3, 2, 1, 15}; + + // Encode lengths of code lengths. + size_t num_code_lengths = 18; + while (code_length_nbits[code_length_order[num_code_lengths - 1]] == 0) { + num_code_lengths--; + } + // Max bits written in this loop: 18 * 4 = 72 + for (size_t i = 0; i < num_code_lengths; i++) { + int symbol = code_length_nbits[code_length_order[i]]; + writer->Write(code_length_length_nbits[symbol], + code_length_length_bits[symbol]); + } + + // Compute the canonical codes for the codes that represent the lengths of + // the actual codes for data. + uint16_t code_length_bits[18] = {}; + ComputeCanonicalCode(nullptr, nullptr, 0, code_length_nbits, + code_length_bits, 18); + // Encode raw bit code lengths. + // Max bits written in this loop: 19 * 5 = 95 + for (size_t i = 0; i < kNumRawSymbols; i++) { + writer->Write(code_length_nbits[raw_nbits[i]], + code_length_bits[raw_nbits[i]]); + } + size_t num_lz77 = kNumLZ77; + while (lz77_nbits[num_lz77 - 1] == 0) { + num_lz77--; + } + // Encode 0s until 224 (start of LZ77 symbols). This is in total 224-19 = + // 205. + static_assert(kLZ77Offset == 224, ""); + static_assert(kNumRawSymbols == 19, ""); + { + // Max bits in this block: 24 + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b010); // 5 + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b000); // (5-2)*8 + 3 = 27 + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b010); // (27-2)*8 + 5 = 205 + } + // Encode LZ77 symbols, with values 224+i. + // Max bits written in this loop: 33 * 5 = 165 + for (size_t i = 0; i < num_lz77; i++) { + writer->Write(code_length_nbits[lz77_nbits[i]], + code_length_bits[lz77_nbits[i]]); + } + } +}; + +} // namespace + +extern "C" { + +struct JxlFastLosslessFrameState { + JxlChunkedFrameInputSource input; + size_t width; + size_t height; + size_t num_groups_x; + size_t num_groups_y; + size_t num_dc_groups_x; + size_t num_dc_groups_y; + size_t nb_chans; + size_t bitdepth; + int big_endian; + int effort; + bool collided; + PrefixCode hcode[4]; + std::vector<int16_t> lookup; + BitWriter header; + std::vector<std::array<BitWriter, 4>> group_data; + std::vector<size_t> group_sizes; + size_t ac_group_data_offset = 0; + size_t min_dc_global_size = 0; + size_t current_bit_writer = 0; + size_t bit_writer_byte_pos = 0; + size_t bits_in_buffer = 0; + uint64_t bit_buffer = 0; + bool process_done = false; +}; + +size_t JxlFastLosslessOutputSize(const JxlFastLosslessFrameState* frame) { + size_t total_size_groups = 0; + for (size_t i = 0; i < frame->group_data.size(); i++) { + total_size_groups += SectionSize(frame->group_data[i]); + } + return frame->header.bytes_written + total_size_groups; +} + +size_t JxlFastLosslessMaxRequiredOutput( + const JxlFastLosslessFrameState* frame) { + return JxlFastLosslessOutputSize(frame) + 32; +} + +void JxlFastLosslessPrepareHeader(JxlFastLosslessFrameState* frame, + int add_image_header, int is_last) { + BitWriter* output = &frame->header; + output->Allocate(1000 + frame->group_sizes.size() * 32); + + bool have_alpha = (frame->nb_chans == 2 || frame->nb_chans == 4); + +#if FJXL_STANDALONE + if (add_image_header) { + // Signature + output->Write(16, 0x0AFF); + + // Size header, hand-crafted. + // Not small + output->Write(1, 0); + + auto wsz = [output](size_t size) { + if (size - 1 < (1 << 9)) { + output->Write(2, 0b00); + output->Write(9, size - 1); + } else if (size - 1 < (1 << 13)) { + output->Write(2, 0b01); + output->Write(13, size - 1); + } else if (size - 1 < (1 << 18)) { + output->Write(2, 0b10); + output->Write(18, size - 1); + } else { + output->Write(2, 0b11); + output->Write(30, size - 1); + } + }; + + wsz(frame->height); + + // No special ratio. + output->Write(3, 0); + + wsz(frame->width); + + // Hand-crafted ImageMetadata. + output->Write(1, 0); // all_default + output->Write(1, 0); // extra_fields + output->Write(1, 0); // bit_depth.floating_point_sample + if (frame->bitdepth == 8) { + output->Write(2, 0b00); // bit_depth.bits_per_sample = 8 + } else if (frame->bitdepth == 10) { + output->Write(2, 0b01); // bit_depth.bits_per_sample = 10 + } else if (frame->bitdepth == 12) { + output->Write(2, 0b10); // bit_depth.bits_per_sample = 12 + } else { + output->Write(2, 0b11); // 1 + u(6) + output->Write(6, frame->bitdepth - 1); + } + if (frame->bitdepth <= 14) { + output->Write(1, 1); // 16-bit-buffer sufficient + } else { + output->Write(1, 0); // 16-bit-buffer NOT sufficient + } + if (have_alpha) { + output->Write(2, 0b01); // One extra channel + output->Write(1, 1); // ... all_default (ie. 8-bit alpha) + } else { + output->Write(2, 0b00); // No extra channel + } + output->Write(1, 0); // Not XYB + if (frame->nb_chans > 2) { + output->Write(1, 1); // color_encoding.all_default (sRGB) + } else { + output->Write(1, 0); // color_encoding.all_default false + output->Write(1, 0); // color_encoding.want_icc false + output->Write(2, 1); // grayscale + output->Write(2, 1); // D65 + output->Write(1, 0); // no gamma transfer function + output->Write(2, 0b10); // tf: 2 + u(4) + output->Write(4, 11); // tf of sRGB + output->Write(2, 1); // relative rendering intent + } + output->Write(2, 0b00); // No extensions. + + output->Write(1, 1); // all_default transform data + + // No ICC, no preview. Frame should start at byte boundery. + output->ZeroPadToByte(); + } +#else + assert(!add_image_header); +#endif + // Handcrafted frame header. + output->Write(1, 0); // all_default + output->Write(2, 0b00); // regular frame + output->Write(1, 1); // modular + output->Write(2, 0b00); // default flags + output->Write(1, 0); // not YCbCr + output->Write(2, 0b00); // no upsampling + if (have_alpha) { + output->Write(2, 0b00); // no alpha upsampling + } + output->Write(2, 0b01); // default group size + output->Write(2, 0b00); // exactly one pass + output->Write(1, 0); // no custom size or origin + output->Write(2, 0b00); // kReplace blending mode + if (have_alpha) { + output->Write(2, 0b00); // kReplace blending mode for alpha channel + } + output->Write(1, is_last); // is_last + if (!is_last) { + output->Write(2, 0b00); // can not be saved as reference + } + output->Write(2, 0b00); // a frame has no name + output->Write(1, 0); // loop filter is not all_default + output->Write(1, 0); // no gaborish + output->Write(2, 0); // 0 EPF iters + output->Write(2, 0b00); // No LF extensions + output->Write(2, 0b00); // No FH extensions + + output->Write(1, 0); // No TOC permutation + output->ZeroPadToByte(); // TOC is byte-aligned. + assert(add_image_header || output->bytes_written <= kMaxFrameHeaderSize); + for (size_t i = 0; i < frame->group_sizes.size(); i++) { + size_t sz = frame->group_sizes[i]; + size_t bucket = TOCBucket(sz); + output->Write(2, bucket); + output->Write(kTOCBits[bucket] - 2, sz - kGroupSizeOffset[bucket]); + } + output->ZeroPadToByte(); // Groups are byte-aligned. +} + +#if !FJXL_STANDALONE +void JxlFastLosslessOutputAlignedSection( + const BitWriter& bw, JxlEncoderOutputProcessorWrapper* output_processor) { + assert(bw.bits_in_buffer == 0); + const uint8_t* data = bw.data.get(); + size_t remaining_len = bw.bytes_written; + while (remaining_len > 0) { + auto retval = output_processor->GetBuffer(1, remaining_len); + assert(retval.status()); + auto buffer = std::move(retval).value(); + size_t n = std::min(buffer.size(), remaining_len); + if (n == 0) break; + memcpy(buffer.data(), data, n); + buffer.advance(n); + data += n; + remaining_len -= n; + }; +} + +void JxlFastLosslessOutputHeaders( + JxlFastLosslessFrameState* frame_state, + JxlEncoderOutputProcessorWrapper* output_processor) { + JxlFastLosslessOutputAlignedSection(frame_state->header, output_processor); + JxlFastLosslessOutputAlignedSection(frame_state->group_data[0][0], + output_processor); +} +#endif + +#if FJXL_ENABLE_AVX512 +__attribute__((target("avx512vbmi2"))) static size_t AppendBytesWithBitOffset( + const uint8_t* data, size_t n, size_t bit_buffer_nbits, + unsigned char* output, uint64_t& bit_buffer) { + if (n < 128) { + return 0; + } + + size_t i = 0; + __m512i shift = _mm512_set1_epi64(64 - bit_buffer_nbits); + __m512i carry = _mm512_set1_epi64(bit_buffer << (64 - bit_buffer_nbits)); + + for (; i + 64 <= n; i += 64) { + __m512i current = _mm512_loadu_si512(data + i); + __m512i previous_u64 = _mm512_alignr_epi64(current, carry, 7); + carry = current; + __m512i out = _mm512_shrdv_epi64(previous_u64, current, shift); + _mm512_storeu_si512(output + i, out); + } + + bit_buffer = data[i - 1] >> (8 - bit_buffer_nbits); + + return i; +} +#endif + +size_t JxlFastLosslessWriteOutput(JxlFastLosslessFrameState* frame, + unsigned char* output, size_t output_size) { + assert(output_size >= 32); + unsigned char* initial_output = output; + size_t (*append_bytes_with_bit_offset)(const uint8_t*, size_t, size_t, + unsigned char*, uint64_t&) = nullptr; + +#if FJXL_ENABLE_AVX512 + if (__builtin_cpu_supports("avx512vbmi2")) { + append_bytes_with_bit_offset = AppendBytesWithBitOffset; + } +#endif + + while (true) { + size_t& cur = frame->current_bit_writer; + size_t& bw_pos = frame->bit_writer_byte_pos; + if (cur >= 1 + frame->group_data.size() * frame->nb_chans) { + return output - initial_output; + } + if (output_size <= 9) { + return output - initial_output; + } + size_t nbc = frame->nb_chans; + const BitWriter& writer = + cur == 0 ? frame->header + : frame->group_data[(cur - 1) / nbc][(cur - 1) % nbc]; + size_t full_byte_count = + std::min(output_size - 9, writer.bytes_written - bw_pos); + if (frame->bits_in_buffer == 0) { + memcpy(output, writer.data.get() + bw_pos, full_byte_count); + } else { + size_t i = 0; + if (append_bytes_with_bit_offset) { + i += append_bytes_with_bit_offset( + writer.data.get() + bw_pos, full_byte_count, frame->bits_in_buffer, + output, frame->bit_buffer); + } +#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) + // Copy 8 bytes at a time until we reach the border. + for (; i + 8 < full_byte_count; i += 8) { + uint64_t chunk; + memcpy(&chunk, writer.data.get() + bw_pos + i, 8); + uint64_t out = frame->bit_buffer | (chunk << frame->bits_in_buffer); + memcpy(output + i, &out, 8); + frame->bit_buffer = chunk >> (64 - frame->bits_in_buffer); + } +#endif + for (; i < full_byte_count; i++) { + AddBits(8, writer.data.get()[bw_pos + i], output + i, + frame->bits_in_buffer, frame->bit_buffer); + } + } + output += full_byte_count; + output_size -= full_byte_count; + bw_pos += full_byte_count; + if (bw_pos == writer.bytes_written) { + auto write = [&](size_t num, uint64_t bits) { + size_t n = AddBits(num, bits, output, frame->bits_in_buffer, + frame->bit_buffer); + output += n; + output_size -= n; + }; + if (writer.bits_in_buffer) { + write(writer.bits_in_buffer, writer.buffer); + } + bw_pos = 0; + cur++; + if ((cur - 1) % nbc == 0 && frame->bits_in_buffer != 0) { + write(8 - frame->bits_in_buffer, 0); + } + } + } +} + +void JxlFastLosslessFreeFrameState(JxlFastLosslessFrameState* frame) { + delete frame; +} + +} // extern "C" + +#endif + +#ifdef FJXL_SELF_INCLUDE + +namespace { + +template <typename T> +struct VecPair { + T low; + T hi; +}; + +#ifdef FJXL_GENERIC_SIMD +#undef FJXL_GENERIC_SIMD +#endif + +#ifdef FJXL_AVX512 +#define FJXL_GENERIC_SIMD +struct SIMDVec32; +struct Mask32 { + __mmask16 mask; + SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false); + size_t CountPrefix() const { + return CtzNonZero(~uint64_t{_cvtmask16_u32(mask)}); + } +}; + +struct SIMDVec32 { + __m512i vec; + + static constexpr size_t kLanes = 16; + + FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) { + return SIMDVec32{_mm512_loadu_si512((__m512i*)data)}; + } + FJXL_INLINE void Store(uint32_t* data) { + _mm512_storeu_si512((__m512i*)data, vec); + } + FJXL_INLINE static SIMDVec32 Val(uint32_t v) { + return SIMDVec32{_mm512_set1_epi32(v)}; + } + FJXL_INLINE SIMDVec32 ValToToken() const { + return SIMDVec32{ + _mm512_sub_epi32(_mm512_set1_epi32(32), _mm512_lzcnt_epi32(vec))}; + } + FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const { + return SIMDVec32{_mm512_sub_epi32(_mm512_max_epu32(vec, to_subtract.vec), + to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const { + return SIMDVec32{_mm512_sub_epi32(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const { + return SIMDVec32{_mm512_add_epi32(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const { + return SIMDVec32{_mm512_xor_epi32(vec, oth.vec)}; + } + FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const { + return Mask32{_mm512_cmpeq_epi32_mask(vec, oth.vec)}; + } + FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const { + return Mask32{_mm512_cmpgt_epi32_mask(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Pow2() const { + return SIMDVec32{_mm512_sllv_epi32(_mm512_set1_epi32(1), vec)}; + } + template <size_t i> + FJXL_INLINE SIMDVec32 SignedShiftRight() const { + return SIMDVec32{_mm512_srai_epi32(vec, i)}; + } +}; + +struct SIMDVec16; + +struct Mask16 { + __mmask32 mask; + SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false); + Mask16 And(const Mask16& oth) const { + return Mask16{_kand_mask32(mask, oth.mask)}; + } + size_t CountPrefix() const { + return CtzNonZero(~uint64_t{_cvtmask32_u32(mask)}); + } +}; + +struct SIMDVec16 { + __m512i vec; + + static constexpr size_t kLanes = 32; + + FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) { + return SIMDVec16{_mm512_loadu_si512((__m512i*)data)}; + } + FJXL_INLINE void Store(uint16_t* data) { + _mm512_storeu_si512((__m512i*)data, vec); + } + FJXL_INLINE static SIMDVec16 Val(uint16_t v) { + return SIMDVec16{_mm512_set1_epi16(v)}; + } + FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo, + const SIMDVec32& hi) { + auto tmp = _mm512_packus_epi32(lo.vec, hi.vec); + alignas(64) uint64_t perm[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return SIMDVec16{ + _mm512_permutex2var_epi64(tmp, _mm512_load_si512((__m512i*)perm), tmp)}; + } + + FJXL_INLINE SIMDVec16 ValToToken() const { + auto c16 = _mm512_set1_epi32(16); + auto c32 = _mm512_set1_epi32(32); + auto low16bit = _mm512_set1_epi32(0x0000FFFF); + auto lzhi = + _mm512_sub_epi32(c16, _mm512_min_epu32(c16, _mm512_lzcnt_epi32(vec))); + auto lzlo = _mm512_sub_epi32( + c32, _mm512_lzcnt_epi32(_mm512_and_si512(low16bit, vec))); + return SIMDVec16{_mm512_or_si512(lzlo, _mm512_slli_epi32(lzhi, 16))}; + } + + FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const { + return SIMDVec16{_mm512_subs_epu16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const { + return SIMDVec16{_mm512_sub_epi16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_add_epi16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_min_epu16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const { + return Mask16{_mm512_cmpeq_epi16_mask(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const { + return Mask16{_mm512_cmpgt_epi16_mask(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Pow2() const { + return SIMDVec16{_mm512_sllv_epi16(_mm512_set1_epi16(1), vec)}; + } + FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_or_si512(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_xor_si512(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_and_si512(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_srai_epi16(_mm512_add_epi16(vec, oth.vec), 1)}; + } + FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const { + return SIMDVec16{_mm512_or_si512(vec, _mm512_set1_epi16(0xFF00))}; + } + FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const { + return SIMDVec16{_mm512_shuffle_epi8( + _mm512_broadcast_i32x4(_mm_loadu_si128((__m128i*)table)), vec)}; + } + FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const { + auto lo = _mm512_unpacklo_epi16(low.vec, vec); + auto hi = _mm512_unpackhi_epi16(low.vec, vec); + alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11}; + alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15}; + return {SIMDVec16{_mm512_permutex2var_epi64( + lo, _mm512_load_si512((__m512i*)perm1), hi)}, + SIMDVec16{_mm512_permutex2var_epi64( + lo, _mm512_load_si512((__m512i*)perm2), hi)}}; + } + FJXL_INLINE VecPair<SIMDVec32> Upcast() const { + auto lo = _mm512_unpacklo_epi16(vec, _mm512_setzero_si512()); + auto hi = _mm512_unpackhi_epi16(vec, _mm512_setzero_si512()); + alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11}; + alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15}; + return {SIMDVec32{_mm512_permutex2var_epi64( + lo, _mm512_load_si512((__m512i*)perm1), hi)}, + SIMDVec32{_mm512_permutex2var_epi64( + lo, _mm512_load_si512((__m512i*)perm2), hi)}}; + } + template <size_t i> + FJXL_INLINE SIMDVec16 SignedShiftRight() const { + return SIMDVec16{_mm512_srai_epi16(vec, i)}; + } + + static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) { + __m256i bytes = _mm256_loadu_si256((__m256i*)data); + return {SIMDVec16{_mm512_cvtepu8_epi16(bytes)}}; + } + static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) { + return {Load((const uint16_t*)data)}; + } + + static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) { + __m512i bytes = _mm512_loadu_si512((__m512i*)data); + __m512i gray = _mm512_and_si512(bytes, _mm512_set1_epi16(0xFF)); + __m512i alpha = _mm512_srli_epi16(bytes, 8); + return {SIMDVec16{gray}, SIMDVec16{alpha}}; + } + static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) { + __m512i bytes1 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64)); + __m512i g_mask = _mm512_set1_epi32(0xFFFF); + __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + __m512i g = _mm512_permutexvar_epi64( + permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, g_mask), + _mm512_and_si512(bytes2, g_mask))); + __m512i a = _mm512_permutexvar_epi64( + permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16), + _mm512_srli_epi32(bytes2, 16))); + return {SIMDVec16{g}, SIMDVec16{a}}; + } + + static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) { + __m512i bytes0 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes1 = + _mm512_zextsi256_si512(_mm256_loadu_si256((__m256i*)(data + 64))); + + // 0x7A = element of upper half of second vector = 0 after lookup; still in + // the upper half once we add 1 or 2. + uint8_t z = 0x7A; + __m512i ridx = + _mm512_set_epi8(z, 93, z, 90, z, 87, z, 84, z, 81, z, 78, z, 75, z, 72, + z, 69, z, 66, z, 63, z, 60, z, 57, z, 54, z, 51, z, 48, + z, 45, z, 42, z, 39, z, 36, z, 33, z, 30, z, 27, z, 24, + z, 21, z, 18, z, 15, z, 12, z, 9, z, 6, z, 3, z, 0); + __m512i gidx = _mm512_add_epi8(ridx, _mm512_set1_epi8(1)); + __m512i bidx = _mm512_add_epi8(gidx, _mm512_set1_epi8(1)); + __m512i r = _mm512_permutex2var_epi8(bytes0, ridx, bytes1); + __m512i g = _mm512_permutex2var_epi8(bytes0, gidx, bytes1); + __m512i b = _mm512_permutex2var_epi8(bytes0, bidx, bytes1); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}}; + } + static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) { + __m512i bytes0 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64)); + __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128)); + + __m512i ridx_lo = _mm512_set_epi16(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 63, 60, 57, + 54, 51, 48, 45, 42, 39, 36, 33, 30, 27, + 24, 21, 18, 15, 12, 9, 6, 3, 0); + // -1 is such that when adding 1 or 2, we get the correct index for + // green/blue. + __m512i ridx_hi = + _mm512_set_epi16(29, 26, 23, 20, 17, 14, 11, 8, 5, 2, -1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512i gidx_lo = _mm512_add_epi16(ridx_lo, _mm512_set1_epi16(1)); + __m512i gidx_hi = _mm512_add_epi16(ridx_hi, _mm512_set1_epi16(1)); + __m512i bidx_lo = _mm512_add_epi16(gidx_lo, _mm512_set1_epi16(1)); + __m512i bidx_hi = _mm512_add_epi16(gidx_hi, _mm512_set1_epi16(1)); + + __mmask32 rmask = _cvtu32_mask32(0b11111111110000000000000000000000); + __mmask32 gbmask = _cvtu32_mask32(0b11111111111000000000000000000000); + + __m512i rlo = _mm512_permutex2var_epi16(bytes0, ridx_lo, bytes1); + __m512i glo = _mm512_permutex2var_epi16(bytes0, gidx_lo, bytes1); + __m512i blo = _mm512_permutex2var_epi16(bytes0, bidx_lo, bytes1); + __m512i r = _mm512_mask_permutexvar_epi16(rlo, rmask, ridx_hi, bytes2); + __m512i g = _mm512_mask_permutexvar_epi16(glo, gbmask, gidx_hi, bytes2); + __m512i b = _mm512_mask_permutexvar_epi16(blo, gbmask, bidx_hi, bytes2); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}}; + } + + static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) { + __m512i bytes1 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64)); + __m512i rg_mask = _mm512_set1_epi32(0xFFFF); + __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + __m512i rg = _mm512_permutexvar_epi64( + permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, rg_mask), + _mm512_and_si512(bytes2, rg_mask))); + __m512i ba = _mm512_permutexvar_epi64( + permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16), + _mm512_srli_epi32(bytes2, 16))); + __m512i r = _mm512_and_si512(rg, _mm512_set1_epi16(0xFF)); + __m512i g = _mm512_srli_epi16(rg, 8); + __m512i b = _mm512_and_si512(ba, _mm512_set1_epi16(0xFF)); + __m512i a = _mm512_srli_epi16(ba, 8); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}}; + } + static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) { + __m512i bytes0 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64)); + __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128)); + __m512i bytes3 = _mm512_loadu_si512((__m512i*)(data + 192)); + + auto pack32 = [](__m512i a, __m512i b) { + __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(permuteidx, _mm512_packus_epi32(a, b)); + }; + auto packlow32 = [&pack32](__m512i a, __m512i b) { + __m512i mask = _mm512_set1_epi32(0xFFFF); + return pack32(_mm512_and_si512(a, mask), _mm512_and_si512(b, mask)); + }; + auto packhi32 = [&pack32](__m512i a, __m512i b) { + return pack32(_mm512_srli_epi32(a, 16), _mm512_srli_epi32(b, 16)); + }; + + __m512i rb0 = packlow32(bytes0, bytes1); + __m512i rb1 = packlow32(bytes2, bytes3); + __m512i ga0 = packhi32(bytes0, bytes1); + __m512i ga1 = packhi32(bytes2, bytes3); + + __m512i r = packlow32(rb0, rb1); + __m512i g = packlow32(ga0, ga1); + __m512i b = packhi32(rb0, rb1); + __m512i a = packhi32(ga0, ga1); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}}; + } + + void SwapEndian() { + auto indices = _mm512_broadcast_i32x4( + _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14)); + vec = _mm512_shuffle_epi8(vec, indices); + } +}; + +SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true, + const SIMDVec16& if_false) { + return SIMDVec16{_mm512_mask_blend_epi16(mask, if_false.vec, if_true.vec)}; +} + +SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true, + const SIMDVec32& if_false) { + return SIMDVec32{_mm512_mask_blend_epi32(mask, if_false.vec, if_true.vec)}; +} + +struct Bits64 { + static constexpr size_t kLanes = 8; + + __m512i nbits; + __m512i bits; + + FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) { + _mm512_storeu_si512((__m512i*)nbits_out, nbits); + _mm512_storeu_si512((__m512i*)bits_out, bits); + } +}; + +struct Bits32 { + __m512i nbits; + __m512i bits; + + static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) { + return Bits32{nbits.vec, bits.vec}; + } + + Bits64 Merge() const { + auto nbits_hi32 = _mm512_srli_epi64(nbits, 32); + auto nbits_lo32 = _mm512_and_si512(nbits, _mm512_set1_epi64(0xFFFFFFFF)); + auto bits_hi32 = _mm512_srli_epi64(bits, 32); + auto bits_lo32 = _mm512_and_si512(bits, _mm512_set1_epi64(0xFFFFFFFF)); + + auto nbits64 = _mm512_add_epi64(nbits_hi32, nbits_lo32); + auto bits64 = + _mm512_or_si512(_mm512_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32); + return Bits64{nbits64, bits64}; + } + + void Interleave(const Bits32& low) { + bits = _mm512_or_si512(_mm512_sllv_epi32(bits, low.nbits), low.bits); + nbits = _mm512_add_epi32(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 16); + constexpr uint32_t kMask[32] = { + ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, + ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n)); + nbits = _mm512_and_si512(mask, nbits); + bits = _mm512_and_si512(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 16); + constexpr uint32_t kMask[32] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, + ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, + }; + __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n)); + nbits = _mm512_and_si512(mask, nbits); + bits = _mm512_and_si512(mask, bits); + } +}; + +struct Bits16 { + __m512i nbits; + __m512i bits; + + static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) { + return Bits16{nbits.vec, bits.vec}; + } + + Bits32 Merge() const { + auto nbits_hi16 = _mm512_srli_epi32(nbits, 16); + auto nbits_lo16 = _mm512_and_si512(nbits, _mm512_set1_epi32(0xFFFF)); + auto bits_hi16 = _mm512_srli_epi32(bits, 16); + auto bits_lo16 = _mm512_and_si512(bits, _mm512_set1_epi32(0xFFFF)); + + auto nbits32 = _mm512_add_epi32(nbits_hi16, nbits_lo16); + auto bits32 = + _mm512_or_si512(_mm512_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16); + return Bits32{nbits32, bits32}; + } + + void Interleave(const Bits16& low) { + bits = _mm512_or_si512(_mm512_sllv_epi16(bits, low.nbits), low.bits); + nbits = _mm512_add_epi16(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 32); + constexpr uint16_t kMask[64] = { + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }; + __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n)); + nbits = _mm512_and_si512(mask, nbits); + bits = _mm512_and_si512(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 32); + constexpr uint16_t kMask[64] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + }; + __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n)); + nbits = _mm512_and_si512(mask, nbits); + bits = _mm512_and_si512(mask, bits); + } +}; + +#endif + +#ifdef FJXL_AVX2 +#define FJXL_GENERIC_SIMD + +struct SIMDVec32; + +struct Mask32 { + __m256i mask; + SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false); + size_t CountPrefix() const { + return CtzNonZero(~static_cast<uint64_t>( + (uint8_t)_mm256_movemask_ps(_mm256_castsi256_ps(mask)))); + } +}; + +struct SIMDVec32 { + __m256i vec; + + static constexpr size_t kLanes = 8; + + FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) { + return SIMDVec32{_mm256_loadu_si256((__m256i*)data)}; + } + FJXL_INLINE void Store(uint32_t* data) { + _mm256_storeu_si256((__m256i*)data, vec); + } + FJXL_INLINE static SIMDVec32 Val(uint32_t v) { + return SIMDVec32{_mm256_set1_epi32(v)}; + } + FJXL_INLINE SIMDVec32 ValToToken() const { + // we know that each value has at most 20 bits, so we just need 5 nibbles + // and don't need to mask the fifth. However we do need to set the higher + // bytes to 0xFF, which will make table lookups return 0. + auto nibble0 = + _mm256_or_si256(_mm256_and_si256(vec, _mm256_set1_epi32(0xF)), + _mm256_set1_epi32(0xFFFFFF00)); + auto nibble1 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi32(vec, 4), _mm256_set1_epi32(0xF)), + _mm256_set1_epi32(0xFFFFFF00)); + auto nibble2 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi32(vec, 8), _mm256_set1_epi32(0xF)), + _mm256_set1_epi32(0xFFFFFF00)); + auto nibble3 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi32(vec, 12), _mm256_set1_epi32(0xF)), + _mm256_set1_epi32(0xFFFFFF00)); + auto nibble4 = _mm256_or_si256(_mm256_srli_epi32(vec, 16), + _mm256_set1_epi32(0xFFFFFF00)); + + auto lut0 = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4)); + auto lut1 = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8)); + auto lut2 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12)); + auto lut3 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 13, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16)); + auto lut4 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 17, 18, 18, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20)); + + auto token0 = _mm256_shuffle_epi8(lut0, nibble0); + auto token1 = _mm256_shuffle_epi8(lut1, nibble1); + auto token2 = _mm256_shuffle_epi8(lut2, nibble2); + auto token3 = _mm256_shuffle_epi8(lut3, nibble3); + auto token4 = _mm256_shuffle_epi8(lut4, nibble4); + + auto token = + _mm256_max_epi32(_mm256_max_epi32(_mm256_max_epi32(token0, token1), + _mm256_max_epi32(token2, token3)), + token4); + return SIMDVec32{token}; + } + FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const { + return SIMDVec32{_mm256_sub_epi32(_mm256_max_epu32(vec, to_subtract.vec), + to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const { + return SIMDVec32{_mm256_sub_epi32(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const { + return SIMDVec32{_mm256_add_epi32(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const { + return SIMDVec32{_mm256_xor_si256(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Pow2() const { + return SIMDVec32{_mm256_sllv_epi32(_mm256_set1_epi32(1), vec)}; + } + FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const { + return Mask32{_mm256_cmpeq_epi32(vec, oth.vec)}; + } + FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const { + return Mask32{_mm256_cmpgt_epi32(vec, oth.vec)}; + } + template <size_t i> + FJXL_INLINE SIMDVec32 SignedShiftRight() const { + return SIMDVec32{_mm256_srai_epi32(vec, i)}; + } +}; + +struct SIMDVec16; + +struct Mask16 { + __m256i mask; + SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false); + Mask16 And(const Mask16& oth) const { + return Mask16{_mm256_and_si256(mask, oth.mask)}; + } + size_t CountPrefix() const { + return CtzNonZero( + ~static_cast<uint64_t>((uint32_t)_mm256_movemask_epi8(mask))) / + 2; + } +}; + +struct SIMDVec16 { + __m256i vec; + + static constexpr size_t kLanes = 16; + + FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) { + return SIMDVec16{_mm256_loadu_si256((__m256i*)data)}; + } + FJXL_INLINE void Store(uint16_t* data) { + _mm256_storeu_si256((__m256i*)data, vec); + } + FJXL_INLINE static SIMDVec16 Val(uint16_t v) { + return SIMDVec16{_mm256_set1_epi16(v)}; + } + FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo, + const SIMDVec32& hi) { + auto tmp = _mm256_packus_epi32(lo.vec, hi.vec); + return SIMDVec16{_mm256_permute4x64_epi64(tmp, 0b11011000)}; + } + + FJXL_INLINE SIMDVec16 ValToToken() const { + auto nibble0 = + _mm256_or_si256(_mm256_and_si256(vec, _mm256_set1_epi16(0xF)), + _mm256_set1_epi16(0xFF00)); + auto nibble1 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi16(vec, 4), _mm256_set1_epi16(0xF)), + _mm256_set1_epi16(0xFF00)); + auto nibble2 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi16(vec, 8), _mm256_set1_epi16(0xF)), + _mm256_set1_epi16(0xFF00)); + auto nibble3 = + _mm256_or_si256(_mm256_srli_epi16(vec, 12), _mm256_set1_epi16(0xFF00)); + + auto lut0 = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4)); + auto lut1 = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8)); + auto lut2 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12)); + auto lut3 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 13, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16)); + + auto token0 = _mm256_shuffle_epi8(lut0, nibble0); + auto token1 = _mm256_shuffle_epi8(lut1, nibble1); + auto token2 = _mm256_shuffle_epi8(lut2, nibble2); + auto token3 = _mm256_shuffle_epi8(lut3, nibble3); + + auto token = _mm256_max_epi16(_mm256_max_epi16(token0, token1), + _mm256_max_epi16(token2, token3)); + return SIMDVec16{token}; + } + + FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const { + return SIMDVec16{_mm256_subs_epu16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const { + return SIMDVec16{_mm256_sub_epi16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_add_epi16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_min_epu16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const { + return Mask16{_mm256_cmpeq_epi16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const { + return Mask16{_mm256_cmpgt_epi16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Pow2() const { + auto pow2_lo_lut = _mm256_broadcastsi128_si256( + _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, + 1u << 7, 0, 0, 0, 0, 0, 0, 0, 0)); + auto pow2_hi_lut = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 1 << 0, 1 << 1, 1 << 2, 1 << 3, + 1 << 4, 1 << 5, 1 << 6, 1u << 7)); + + auto masked = _mm256_or_si256(vec, _mm256_set1_epi16(0xFF00)); + + auto pow2_lo = _mm256_shuffle_epi8(pow2_lo_lut, masked); + auto pow2_hi = _mm256_shuffle_epi8(pow2_hi_lut, masked); + + auto pow2 = _mm256_or_si256(_mm256_slli_epi16(pow2_hi, 8), pow2_lo); + return SIMDVec16{pow2}; + } + FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_or_si256(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_xor_si256(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_and_si256(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_srai_epi16(_mm256_add_epi16(vec, oth.vec), 1)}; + } + FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const { + return SIMDVec16{_mm256_or_si256(vec, _mm256_set1_epi16(0xFF00))}; + } + FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const { + return SIMDVec16{_mm256_shuffle_epi8( + _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)table)), vec)}; + } + FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const { + auto v02 = _mm256_unpacklo_epi16(low.vec, vec); + auto v13 = _mm256_unpackhi_epi16(low.vec, vec); + return {SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x20)}, + SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x31)}}; + } + FJXL_INLINE VecPair<SIMDVec32> Upcast() const { + auto v02 = _mm256_unpacklo_epi16(vec, _mm256_setzero_si256()); + auto v13 = _mm256_unpackhi_epi16(vec, _mm256_setzero_si256()); + return {SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x20)}, + SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x31)}}; + } + template <size_t i> + FJXL_INLINE SIMDVec16 SignedShiftRight() const { + return SIMDVec16{_mm256_srai_epi16(vec, i)}; + } + + static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) { + __m128i bytes = _mm_loadu_si128((__m128i*)data); + return {SIMDVec16{_mm256_cvtepu8_epi16(bytes)}}; + } + static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) { + return {Load((const uint16_t*)data)}; + } + + static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) { + __m256i bytes = _mm256_loadu_si256((__m256i*)data); + __m256i gray = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF)); + __m256i alpha = _mm256_srli_epi16(bytes, 8); + return {SIMDVec16{gray}, SIMDVec16{alpha}}; + } + static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) { + __m256i bytes1 = _mm256_loadu_si256((__m256i*)data); + __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32)); + __m256i g_mask = _mm256_set1_epi32(0xFFFF); + __m256i g = _mm256_permute4x64_epi64( + _mm256_packus_epi32(_mm256_and_si256(bytes1, g_mask), + _mm256_and_si256(bytes2, g_mask)), + 0b11011000); + __m256i a = _mm256_permute4x64_epi64( + _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16), + _mm256_srli_epi32(bytes2, 16)), + 0b11011000); + return {SIMDVec16{g}, SIMDVec16{a}}; + } + + static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) { + __m128i bytes0 = _mm_loadu_si128((__m128i*)data); + __m128i bytes1 = _mm_loadu_si128((__m128i*)(data + 16)); + __m128i bytes2 = _mm_loadu_si128((__m128i*)(data + 32)); + + __m128i idx = + _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13); + + __m128i r6b5g5_0 = _mm_shuffle_epi8(bytes0, idx); + __m128i g6r5b5_1 = _mm_shuffle_epi8(bytes1, idx); + __m128i b6g5r5_2 = _mm_shuffle_epi8(bytes2, idx); + + __m128i mask010 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0, 0, 0, 0, 0); + __m128i mask001 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF); + + __m128i b2g2b1 = _mm_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001); + __m128i b2b0b1 = _mm_blendv_epi8(b2g2b1, r6b5g5_0, mask010); + + __m128i r0r1b1 = _mm_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010); + __m128i r0r1r2 = _mm_blendv_epi8(r0r1b1, b6g5r5_2, mask001); + + __m128i g1r1g0 = _mm_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001); + __m128i g1g2g0 = _mm_blendv_epi8(g1r1g0, b6g5r5_2, mask010); + + __m128i g0g1g2 = _mm_alignr_epi8(g1g2g0, g1g2g0, 11); + __m128i b0b1b2 = _mm_alignr_epi8(b2b0b1, b2b0b1, 6); + + return {SIMDVec16{_mm256_cvtepu8_epi16(r0r1r2)}, + SIMDVec16{_mm256_cvtepu8_epi16(g0g1g2)}, + SIMDVec16{_mm256_cvtepu8_epi16(b0b1b2)}}; + } + static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) { + auto load_and_split_lohi = [](const unsigned char* data) { + // LHLHLH... + __m256i bytes = _mm256_loadu_si256((__m256i*)data); + // L0L0L0... + __m256i lo = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF)); + // H0H0H0... + __m256i hi = _mm256_srli_epi16(bytes, 8); + // LLLLLLLLHHHHHHHHLLLLLLLLHHHHHHHH + __m256i packed = _mm256_packus_epi16(lo, hi); + return _mm256_permute4x64_epi64(packed, 0b11011000); + }; + __m256i bytes0 = load_and_split_lohi(data); + __m256i bytes1 = load_and_split_lohi(data + 32); + __m256i bytes2 = load_and_split_lohi(data + 64); + + __m256i idx = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13)); + + __m256i r6b5g5_0 = _mm256_shuffle_epi8(bytes0, idx); + __m256i g6r5b5_1 = _mm256_shuffle_epi8(bytes1, idx); + __m256i b6g5r5_2 = _mm256_shuffle_epi8(bytes2, idx); + + __m256i mask010 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0)); + __m256i mask001 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF)); + + __m256i b2g2b1 = _mm256_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001); + __m256i b2b0b1 = _mm256_blendv_epi8(b2g2b1, r6b5g5_0, mask010); + + __m256i r0r1b1 = _mm256_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010); + __m256i r0r1r2 = _mm256_blendv_epi8(r0r1b1, b6g5r5_2, mask001); + + __m256i g1r1g0 = _mm256_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001); + __m256i g1g2g0 = _mm256_blendv_epi8(g1r1g0, b6g5r5_2, mask010); + + __m256i g0g1g2 = _mm256_alignr_epi8(g1g2g0, g1g2g0, 11); + __m256i b0b1b2 = _mm256_alignr_epi8(b2b0b1, b2b0b1, 6); + + // Now r0r1r2, g0g1g2, b0b1b2 have the low bytes of the RGB pixels in their + // lower half, and the high bytes in their upper half. + + auto combine_low_hi = [](__m256i v) { + __m128i low = _mm256_extracti128_si256(v, 0); + __m128i hi = _mm256_extracti128_si256(v, 1); + __m256i low16 = _mm256_cvtepu8_epi16(low); + __m256i hi16 = _mm256_cvtepu8_epi16(hi); + return _mm256_or_si256(_mm256_slli_epi16(hi16, 8), low16); + }; + + return {SIMDVec16{combine_low_hi(r0r1r2)}, + SIMDVec16{combine_low_hi(g0g1g2)}, + SIMDVec16{combine_low_hi(b0b1b2)}}; + } + + static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) { + __m256i bytes1 = _mm256_loadu_si256((__m256i*)data); + __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32)); + __m256i rg_mask = _mm256_set1_epi32(0xFFFF); + __m256i rg = _mm256_permute4x64_epi64( + _mm256_packus_epi32(_mm256_and_si256(bytes1, rg_mask), + _mm256_and_si256(bytes2, rg_mask)), + 0b11011000); + __m256i ba = _mm256_permute4x64_epi64( + _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16), + _mm256_srli_epi32(bytes2, 16)), + 0b11011000); + __m256i r = _mm256_and_si256(rg, _mm256_set1_epi16(0xFF)); + __m256i g = _mm256_srli_epi16(rg, 8); + __m256i b = _mm256_and_si256(ba, _mm256_set1_epi16(0xFF)); + __m256i a = _mm256_srli_epi16(ba, 8); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}}; + } + static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) { + __m256i bytes0 = _mm256_loadu_si256((__m256i*)data); + __m256i bytes1 = _mm256_loadu_si256((__m256i*)(data + 32)); + __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 64)); + __m256i bytes3 = _mm256_loadu_si256((__m256i*)(data + 96)); + + auto pack32 = [](__m256i a, __m256i b) { + return _mm256_permute4x64_epi64(_mm256_packus_epi32(a, b), 0b11011000); + }; + auto packlow32 = [&pack32](__m256i a, __m256i b) { + __m256i mask = _mm256_set1_epi32(0xFFFF); + return pack32(_mm256_and_si256(a, mask), _mm256_and_si256(b, mask)); + }; + auto packhi32 = [&pack32](__m256i a, __m256i b) { + return pack32(_mm256_srli_epi32(a, 16), _mm256_srli_epi32(b, 16)); + }; + + __m256i rb0 = packlow32(bytes0, bytes1); + __m256i rb1 = packlow32(bytes2, bytes3); + __m256i ga0 = packhi32(bytes0, bytes1); + __m256i ga1 = packhi32(bytes2, bytes3); + + __m256i r = packlow32(rb0, rb1); + __m256i g = packlow32(ga0, ga1); + __m256i b = packhi32(rb0, rb1); + __m256i a = packhi32(ga0, ga1); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}}; + } + + void SwapEndian() { + auto indices = _mm256_broadcastsi128_si256( + _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14)); + vec = _mm256_shuffle_epi8(vec, indices); + } +}; + +SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true, + const SIMDVec16& if_false) { + return SIMDVec16{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)}; +} + +SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true, + const SIMDVec32& if_false) { + return SIMDVec32{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)}; +} + +struct Bits64 { + static constexpr size_t kLanes = 4; + + __m256i nbits; + __m256i bits; + + FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) { + _mm256_storeu_si256((__m256i*)nbits_out, nbits); + _mm256_storeu_si256((__m256i*)bits_out, bits); + } +}; + +struct Bits32 { + __m256i nbits; + __m256i bits; + + static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) { + return Bits32{nbits.vec, bits.vec}; + } + + Bits64 Merge() const { + auto nbits_hi32 = _mm256_srli_epi64(nbits, 32); + auto nbits_lo32 = _mm256_and_si256(nbits, _mm256_set1_epi64x(0xFFFFFFFF)); + auto bits_hi32 = _mm256_srli_epi64(bits, 32); + auto bits_lo32 = _mm256_and_si256(bits, _mm256_set1_epi64x(0xFFFFFFFF)); + + auto nbits64 = _mm256_add_epi64(nbits_hi32, nbits_lo32); + auto bits64 = + _mm256_or_si256(_mm256_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32); + return Bits64{nbits64, bits64}; + } + + void Interleave(const Bits32& low) { + bits = _mm256_or_si256(_mm256_sllv_epi32(bits, low.nbits), low.bits); + nbits = _mm256_add_epi32(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 8); + constexpr uint32_t kMask[16] = { + ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0, 0, + }; + __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n)); + nbits = _mm256_and_si256(mask, nbits); + bits = _mm256_and_si256(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 8); + constexpr uint32_t kMask[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, + }; + __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n)); + nbits = _mm256_and_si256(mask, nbits); + bits = _mm256_and_si256(mask, bits); + } +}; + +struct Bits16 { + __m256i nbits; + __m256i bits; + + static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) { + return Bits16{nbits.vec, bits.vec}; + } + + Bits32 Merge() const { + auto nbits_hi16 = _mm256_srli_epi32(nbits, 16); + auto nbits_lo16 = _mm256_and_si256(nbits, _mm256_set1_epi32(0xFFFF)); + auto bits_hi16 = _mm256_srli_epi32(bits, 16); + auto bits_lo16 = _mm256_and_si256(bits, _mm256_set1_epi32(0xFFFF)); + + auto nbits32 = _mm256_add_epi32(nbits_hi16, nbits_lo16); + auto bits32 = + _mm256_or_si256(_mm256_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16); + return Bits32{nbits32, bits32}; + } + + void Interleave(const Bits16& low) { + auto pow2_lo_lut = _mm256_broadcastsi128_si256( + _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, + 1u << 7, 0, 0, 0, 0, 0, 0, 0, 0)); + auto low_nbits_masked = + _mm256_or_si256(low.nbits, _mm256_set1_epi16(0xFF00)); + + auto bits_shifted = _mm256_mullo_epi16( + bits, _mm256_shuffle_epi8(pow2_lo_lut, low_nbits_masked)); + + nbits = _mm256_add_epi16(nbits, low.nbits); + bits = _mm256_or_si256(bits_shifted, low.bits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 16); + constexpr uint16_t kMask[32] = { + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }; + __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n)); + nbits = _mm256_and_si256(mask, nbits); + bits = _mm256_and_si256(mask, bits); + } + + void Skip(size_t n) { + n = std::min<size_t>(n, 16); + constexpr uint16_t kMask[32] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + }; + __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n)); + nbits = _mm256_and_si256(mask, nbits); + bits = _mm256_and_si256(mask, bits); + } +}; + +#endif + +#ifdef FJXL_NEON +#define FJXL_GENERIC_SIMD + +struct SIMDVec32; + +struct Mask32 { + uint32x4_t mask; + SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false); + Mask32 And(const Mask32& oth) const { + return Mask32{vandq_u32(mask, oth.mask)}; + } + size_t CountPrefix() const { + uint32_t val_unset[4] = {0, 1, 2, 3}; + uint32_t val_set[4] = {4, 4, 4, 4}; + uint32x4_t val = vbslq_u32(mask, vld1q_u32(val_set), vld1q_u32(val_unset)); + return vminvq_u32(val); + } +}; + +struct SIMDVec32 { + uint32x4_t vec; + + static constexpr size_t kLanes = 4; + + FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) { + return SIMDVec32{vld1q_u32(data)}; + } + FJXL_INLINE void Store(uint32_t* data) { vst1q_u32(data, vec); } + FJXL_INLINE static SIMDVec32 Val(uint32_t v) { + return SIMDVec32{vdupq_n_u32(v)}; + } + FJXL_INLINE SIMDVec32 ValToToken() const { + return SIMDVec32{vsubq_u32(vdupq_n_u32(32), vclzq_u32(vec))}; + } + FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const { + return SIMDVec32{vqsubq_u32(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const { + return SIMDVec32{vsubq_u32(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const { + return SIMDVec32{vaddq_u32(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const { + return SIMDVec32{veorq_u32(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Pow2() const { + return SIMDVec32{vshlq_u32(vdupq_n_u32(1), vreinterpretq_s32_u32(vec))}; + } + FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const { + return Mask32{vceqq_u32(vec, oth.vec)}; + } + FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const { + return Mask32{ + vcgtq_s32(vreinterpretq_s32_u32(vec), vreinterpretq_s32_u32(oth.vec))}; + } + template <size_t i> + FJXL_INLINE SIMDVec32 SignedShiftRight() const { + return SIMDVec32{ + vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(vec), i))}; + } +}; + +struct SIMDVec16; + +struct Mask16 { + uint16x8_t mask; + SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false); + Mask16 And(const Mask16& oth) const { + return Mask16{vandq_u16(mask, oth.mask)}; + } + size_t CountPrefix() const { + uint16_t val_unset[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + uint16_t val_set[8] = {8, 8, 8, 8, 8, 8, 8, 8}; + uint16x8_t val = vbslq_u16(mask, vld1q_u16(val_set), vld1q_u16(val_unset)); + return vminvq_u16(val); + } +}; + +struct SIMDVec16 { + uint16x8_t vec; + + static constexpr size_t kLanes = 8; + + FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) { + return SIMDVec16{vld1q_u16(data)}; + } + FJXL_INLINE void Store(uint16_t* data) { vst1q_u16(data, vec); } + FJXL_INLINE static SIMDVec16 Val(uint16_t v) { + return SIMDVec16{vdupq_n_u16(v)}; + } + FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo, + const SIMDVec32& hi) { + return SIMDVec16{vmovn_high_u32(vmovn_u32(lo.vec), hi.vec)}; + } + + FJXL_INLINE SIMDVec16 ValToToken() const { + return SIMDVec16{vsubq_u16(vdupq_n_u16(16), vclzq_u16(vec))}; + } + FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const { + return SIMDVec16{vqsubq_u16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const { + return SIMDVec16{vsubq_u16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const { + return SIMDVec16{vaddq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const { + return SIMDVec16{vminq_u16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const { + return Mask16{vceqq_u16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const { + return Mask16{ + vcgtq_s16(vreinterpretq_s16_u16(vec), vreinterpretq_s16_u16(oth.vec))}; + } + FJXL_INLINE SIMDVec16 Pow2() const { + return SIMDVec16{vshlq_u16(vdupq_n_u16(1), vreinterpretq_s16_u16(vec))}; + } + FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const { + return SIMDVec16{vorrq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const { + return SIMDVec16{veorq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const { + return SIMDVec16{vandq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const { + return SIMDVec16{vhaddq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const { + return SIMDVec16{vorrq_u16(vec, vdupq_n_u16(0xFF00))}; + } + FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const { + uint8x16_t tbl = vld1q_u8(table); + uint8x16_t indices = vreinterpretq_u8_u16(vec); + return SIMDVec16{vreinterpretq_u16_u8(vqtbl1q_u8(tbl, indices))}; + } + FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const { + return {SIMDVec16{vzip1q_u16(low.vec, vec)}, + SIMDVec16{vzip2q_u16(low.vec, vec)}}; + } + FJXL_INLINE VecPair<SIMDVec32> Upcast() const { + uint32x4_t lo = vmovl_u16(vget_low_u16(vec)); + uint32x4_t hi = vmovl_high_u16(vec); + return {SIMDVec32{lo}, SIMDVec32{hi}}; + } + template <size_t i> + FJXL_INLINE SIMDVec16 SignedShiftRight() const { + return SIMDVec16{ + vreinterpretq_u16_s16(vshrq_n_s16(vreinterpretq_s16_u16(vec), i))}; + } + + static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) { + uint8x8_t v = vld1_u8(data); + return {SIMDVec16{vmovl_u8(v)}}; + } + static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) { + return {Load((const uint16_t*)data)}; + } + + static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) { + uint8x8x2_t v = vld2_u8(data); + return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])}}; + } + static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) { + uint16x8x2_t v = vld2q_u16((const uint16_t*)data); + return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}}; + } + + static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) { + uint8x8x3_t v = vld3_u8(data); + return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])}, + SIMDVec16{vmovl_u8(v.val[2])}}; + } + static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) { + uint16x8x3_t v = vld3q_u16((const uint16_t*)data); + return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]}}; + } + + static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) { + uint8x8x4_t v = vld4_u8(data); + return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])}, + SIMDVec16{vmovl_u8(v.val[2])}, SIMDVec16{vmovl_u8(v.val[3])}}; + } + static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) { + uint16x8x4_t v = vld4q_u16((const uint16_t*)data); + return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]}, + SIMDVec16{v.val[3]}}; + } + + void SwapEndian() { + vec = vreinterpretq_u16_u8(vrev16q_u8(vreinterpretq_u8_u16(vec))); + } +}; + +SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true, + const SIMDVec16& if_false) { + return SIMDVec16{vbslq_u16(mask, if_true.vec, if_false.vec)}; +} + +SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true, + const SIMDVec32& if_false) { + return SIMDVec32{vbslq_u32(mask, if_true.vec, if_false.vec)}; +} + +struct Bits64 { + static constexpr size_t kLanes = 2; + + uint64x2_t nbits; + uint64x2_t bits; + + FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) { + vst1q_u64(nbits_out, nbits); + vst1q_u64(bits_out, bits); + } +}; + +struct Bits32 { + uint32x4_t nbits; + uint32x4_t bits; + + static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) { + return Bits32{nbits.vec, bits.vec}; + } + + Bits64 Merge() const { + // TODO(veluca): can probably be optimized. + uint64x2_t nbits_lo32 = + vandq_u64(vreinterpretq_u64_u32(nbits), vdupq_n_u64(0xFFFFFFFF)); + uint64x2_t bits_hi32 = + vshlq_u64(vshrq_n_u64(vreinterpretq_u64_u32(bits), 32), + vreinterpretq_s64_u64(nbits_lo32)); + uint64x2_t bits_lo32 = + vandq_u64(vreinterpretq_u64_u32(bits), vdupq_n_u64(0xFFFFFFFF)); + uint64x2_t nbits64 = + vsraq_n_u64(nbits_lo32, vreinterpretq_u64_u32(nbits), 32); + uint64x2_t bits64 = vorrq_u64(bits_hi32, bits_lo32); + return Bits64{nbits64, bits64}; + } + + void Interleave(const Bits32& low) { + bits = + vorrq_u32(vshlq_u32(bits, vreinterpretq_s32_u32(low.nbits)), low.bits); + nbits = vaddq_u32(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 4); + constexpr uint32_t kMask[8] = { + ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, + }; + uint32x4_t mask = vld1q_u32(kMask + 4 - n); + nbits = vandq_u32(mask, nbits); + bits = vandq_u32(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 4); + constexpr uint32_t kMask[8] = { + 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, + }; + uint32x4_t mask = vld1q_u32(kMask + 4 - n); + nbits = vandq_u32(mask, nbits); + bits = vandq_u32(mask, bits); + } +}; + +struct Bits16 { + uint16x8_t nbits; + uint16x8_t bits; + + static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) { + return Bits16{nbits.vec, bits.vec}; + } + + Bits32 Merge() const { + // TODO(veluca): can probably be optimized. + uint32x4_t nbits_lo16 = + vandq_u32(vreinterpretq_u32_u16(nbits), vdupq_n_u32(0xFFFF)); + uint32x4_t bits_hi16 = + vshlq_u32(vshrq_n_u32(vreinterpretq_u32_u16(bits), 16), + vreinterpretq_s32_u32(nbits_lo16)); + uint32x4_t bits_lo16 = + vandq_u32(vreinterpretq_u32_u16(bits), vdupq_n_u32(0xFFFF)); + uint32x4_t nbits32 = + vsraq_n_u32(nbits_lo16, vreinterpretq_u32_u16(nbits), 16); + uint32x4_t bits32 = vorrq_u32(bits_hi16, bits_lo16); + return Bits32{nbits32, bits32}; + } + + void Interleave(const Bits16& low) { + bits = + vorrq_u16(vshlq_u16(bits, vreinterpretq_s16_u16(low.nbits)), low.bits); + nbits = vaddq_u16(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 8); + constexpr uint16_t kMask[16] = { + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0, 0, 0, 0, 0, 0, 0, 0, + }; + uint16x8_t mask = vld1q_u16(kMask + 8 - n); + nbits = vandq_u16(mask, nbits); + bits = vandq_u16(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 8); + constexpr uint16_t kMask[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + }; + uint16x8_t mask = vld1q_u16(kMask + 8 - n); + nbits = vandq_u16(mask, nbits); + bits = vandq_u16(mask, bits); + } +}; + +#endif + +#ifdef FJXL_GENERIC_SIMD +constexpr size_t SIMDVec32::kLanes; +constexpr size_t SIMDVec16::kLanes; + +// Each of these functions will process SIMDVec16::kLanes worth of values. + +FJXL_INLINE void TokenizeSIMD(const uint16_t* residuals, uint16_t* token_out, + uint16_t* nbits_out, uint16_t* bits_out) { + SIMDVec16 res = SIMDVec16::Load(residuals); + SIMDVec16 token = res.ValToToken(); + SIMDVec16 nbits = token.SatSubU(SIMDVec16::Val(1)); + SIMDVec16 bits = res.SatSubU(nbits.Pow2()); + token.Store(token_out); + nbits.Store(nbits_out); + bits.Store(bits_out); +} + +FJXL_INLINE void TokenizeSIMD(const uint32_t* residuals, uint16_t* token_out, + uint32_t* nbits_out, uint32_t* bits_out) { + static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, ""); + SIMDVec32 res_lo = SIMDVec32::Load(residuals); + SIMDVec32 res_hi = SIMDVec32::Load(residuals + SIMDVec32::kLanes); + SIMDVec32 token_lo = res_lo.ValToToken(); + SIMDVec32 token_hi = res_hi.ValToToken(); + SIMDVec32 nbits_lo = token_lo.SatSubU(SIMDVec32::Val(1)); + SIMDVec32 nbits_hi = token_hi.SatSubU(SIMDVec32::Val(1)); + SIMDVec32 bits_lo = res_lo.SatSubU(nbits_lo.Pow2()); + SIMDVec32 bits_hi = res_hi.SatSubU(nbits_hi.Pow2()); + SIMDVec16 token = SIMDVec16::FromTwo32(token_lo, token_hi); + token.Store(token_out); + nbits_lo.Store(nbits_out); + nbits_hi.Store(nbits_out + SIMDVec32::kLanes); + bits_lo.Store(bits_out); + bits_hi.Store(bits_out + SIMDVec32::kLanes); +} + +FJXL_INLINE void HuffmanSIMDUpTo13(const uint16_t* tokens, + const uint8_t* raw_nbits_simd, + const uint8_t* raw_bits_simd, + uint16_t* nbits_out, uint16_t* bits_out) { + SIMDVec16 tok = SIMDVec16::Load(tokens).PrepareForU8Lookup(); + tok.U8Lookup(raw_nbits_simd).Store(nbits_out); + tok.U8Lookup(raw_bits_simd).Store(bits_out); +} + +FJXL_INLINE void HuffmanSIMD14(const uint16_t* tokens, + const uint8_t* raw_nbits_simd, + const uint8_t* raw_bits_simd, + uint16_t* nbits_out, uint16_t* bits_out) { + SIMDVec16 token_cap = SIMDVec16::Val(15); + SIMDVec16 tok = SIMDVec16::Load(tokens); + SIMDVec16 tok_index = tok.Min(token_cap).PrepareForU8Lookup(); + SIMDVec16 huff_bits_pre = tok_index.U8Lookup(raw_bits_simd); + // Set the highest bit when token == 16; the Huffman code is constructed in + // such a way that the code for token 15 is the same as the code for 16, + // except for the highest bit. + Mask16 needs_high_bit = tok.Eq(SIMDVec16::Val(16)); + SIMDVec16 huff_bits = needs_high_bit.IfThenElse( + huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre); + huff_bits.Store(bits_out); + tok_index.U8Lookup(raw_nbits_simd).Store(nbits_out); +} + +FJXL_INLINE void HuffmanSIMDAbove14(const uint16_t* tokens, + const uint8_t* raw_nbits_simd, + const uint8_t* raw_bits_simd, + uint16_t* nbits_out, uint16_t* bits_out) { + SIMDVec16 tok = SIMDVec16::Load(tokens); + // We assume `tok` fits in a *signed* 16-bit integer. + Mask16 above = tok.Gt(SIMDVec16::Val(12)); + // 13, 14 -> 13 + // 15, 16 -> 14 + // 17, 18 -> 15 + SIMDVec16 remap_tok = above.IfThenElse(tok.HAdd(SIMDVec16::Val(13)), tok); + SIMDVec16 tok_index = remap_tok.PrepareForU8Lookup(); + SIMDVec16 huff_bits_pre = tok_index.U8Lookup(raw_bits_simd); + // Set the highest bit when token == 14, 16, 18. + Mask16 needs_high_bit = above.And(tok.Eq(tok.And(SIMDVec16::Val(0xFFFE)))); + SIMDVec16 huff_bits = needs_high_bit.IfThenElse( + huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre); + huff_bits.Store(bits_out); + tok_index.U8Lookup(raw_nbits_simd).Store(nbits_out); +} + +FJXL_INLINE void StoreSIMDUpTo8(const uint16_t* nbits_tok, + const uint16_t* bits_tok, + const uint16_t* nbits_huff, + const uint16_t* bits_huff, size_t n, + size_t skip, Bits32* bits_out) { + Bits16 bits = + Bits16::FromRaw(SIMDVec16::Load(nbits_tok), SIMDVec16::Load(bits_tok)); + Bits16 huff_bits = + Bits16::FromRaw(SIMDVec16::Load(nbits_huff), SIMDVec16::Load(bits_huff)); + bits.Interleave(huff_bits); + bits.ClipTo(n); + bits.Skip(skip); + bits_out[0] = bits.Merge(); +} + +// Huffman and raw bits don't necessarily fit in a single u16 here. +FJXL_INLINE void StoreSIMDUpTo14(const uint16_t* nbits_tok, + const uint16_t* bits_tok, + const uint16_t* nbits_huff, + const uint16_t* bits_huff, size_t n, + size_t skip, Bits32* bits_out) { + VecPair<SIMDVec16> bits = + SIMDVec16::Load(bits_tok).Interleave(SIMDVec16::Load(bits_huff)); + VecPair<SIMDVec16> nbits = + SIMDVec16::Load(nbits_tok).Interleave(SIMDVec16::Load(nbits_huff)); + Bits16 low = Bits16::FromRaw(nbits.low, bits.low); + Bits16 hi = Bits16::FromRaw(nbits.hi, bits.hi); + low.ClipTo(2 * n); + low.Skip(2 * skip); + hi.ClipTo(std::max(2 * n, SIMDVec16::kLanes) - SIMDVec16::kLanes); + hi.Skip(std::max(2 * skip, SIMDVec16::kLanes) - SIMDVec16::kLanes); + + bits_out[0] = low.Merge(); + bits_out[1] = hi.Merge(); +} + +FJXL_INLINE void StoreSIMDAbove14(const uint32_t* nbits_tok, + const uint32_t* bits_tok, + const uint16_t* nbits_huff, + const uint16_t* bits_huff, size_t n, + size_t skip, Bits32* bits_out) { + static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, ""); + Bits32 bits_low = + Bits32::FromRaw(SIMDVec32::Load(nbits_tok), SIMDVec32::Load(bits_tok)); + Bits32 bits_hi = + Bits32::FromRaw(SIMDVec32::Load(nbits_tok + SIMDVec32::kLanes), + SIMDVec32::Load(bits_tok + SIMDVec32::kLanes)); + + VecPair<SIMDVec32> huff_bits = SIMDVec16::Load(bits_huff).Upcast(); + VecPair<SIMDVec32> huff_nbits = SIMDVec16::Load(nbits_huff).Upcast(); + + Bits32 huff_low = Bits32::FromRaw(huff_nbits.low, huff_bits.low); + Bits32 huff_hi = Bits32::FromRaw(huff_nbits.hi, huff_bits.hi); + + bits_low.Interleave(huff_low); + bits_low.ClipTo(n); + bits_low.Skip(skip); + bits_out[0] = bits_low; + bits_hi.Interleave(huff_hi); + bits_hi.ClipTo(std::max(n, SIMDVec32::kLanes) - SIMDVec32::kLanes); + bits_hi.Skip(std::max(skip, SIMDVec32::kLanes) - SIMDVec32::kLanes); + bits_out[1] = bits_hi; +} + +#ifdef FJXL_AVX512 +FJXL_INLINE void StoreToWriterAVX512(const Bits32& bits32, BitWriter& output) { + __m512i bits = bits32.bits; + __m512i nbits = bits32.nbits; + + // Insert the leftover bits from the bit buffer at the bottom of the vector + // and extract the top of the vector. + uint64_t trail_bits = + _mm512_cvtsi512_si32(_mm512_alignr_epi32(bits, bits, 15)); + uint64_t trail_nbits = + _mm512_cvtsi512_si32(_mm512_alignr_epi32(nbits, nbits, 15)); + __m512i lead_bits = _mm512_set1_epi32(output.buffer); + __m512i lead_nbits = _mm512_set1_epi32(output.bits_in_buffer); + bits = _mm512_alignr_epi32(bits, lead_bits, 15); + nbits = _mm512_alignr_epi32(nbits, lead_nbits, 15); + + // Merge 32 -> 64 bits. + Bits32 b{nbits, bits}; + Bits64 b64 = b.Merge(); + bits = b64.bits; + nbits = b64.nbits; + + __m512i zero = _mm512_setzero_si512(); + + auto sh1 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 7); }; + auto sh2 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 6); }; + auto sh4 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 4); }; + + // Compute first-past-end-bit-position. + __m512i end_interm0 = _mm512_add_epi64(nbits, sh1(nbits)); + __m512i end_interm1 = _mm512_add_epi64(end_interm0, sh2(end_interm0)); + __m512i end = _mm512_add_epi64(end_interm1, sh4(end_interm1)); + + uint64_t simd_nbits = _mm512_cvtsi512_si32(_mm512_alignr_epi64(end, end, 7)); + + // Compute begin-bit-position. + __m512i begin = _mm512_sub_epi64(end, nbits); + + // Index of the last bit in the chunk, or the end bit if nbits==0. + __m512i last = _mm512_mask_sub_epi64( + end, _mm512_cmpneq_epi64_mask(nbits, zero), end, _mm512_set1_epi64(1)); + + __m512i lane_offset_mask = _mm512_set1_epi64(63); + + // Starting position of the chunk that each lane will ultimately belong to. + __m512i chunk_start = _mm512_andnot_si512(lane_offset_mask, last); + + // For all lanes that contain bits belonging to two different 64-bit chunks, + // compute the number of bits that belong to the first chunk. + // total # of bits fit in a u16, so we can satsub_u16 here. + __m512i first_chunk_nbits = _mm512_subs_epu16(chunk_start, begin); + + // Move all the previous-chunk-bits to the previous lane. + __m512i negnbits = _mm512_sub_epi64(_mm512_set1_epi64(64), first_chunk_nbits); + __m512i first_chunk_bits = + _mm512_srlv_epi64(_mm512_sllv_epi64(bits, negnbits), negnbits); + __m512i first_chunk_bits_down = + _mm512_alignr_epi32(zero, first_chunk_bits, 2); + bits = _mm512_srlv_epi64(bits, first_chunk_nbits); + nbits = _mm512_sub_epi64(nbits, first_chunk_nbits); + bits = _mm512_or_si512(bits, _mm512_sllv_epi64(first_chunk_bits_down, nbits)); + begin = _mm512_add_epi64(begin, first_chunk_nbits); + + // We now know that every lane should give bits to only one chunk. We can + // shift the bits and then horizontally-or-reduce them within the same chunk. + __m512i offset = _mm512_and_si512(begin, lane_offset_mask); + __m512i aligned_bits = _mm512_sllv_epi64(bits, offset); + // h-or-reduce within same chunk + __m512i red0 = _mm512_mask_or_epi64( + aligned_bits, _mm512_cmpeq_epi64_mask(sh1(chunk_start), chunk_start), + sh1(aligned_bits), aligned_bits); + __m512i red1 = _mm512_mask_or_epi64( + red0, _mm512_cmpeq_epi64_mask(sh2(chunk_start), chunk_start), sh2(red0), + red0); + __m512i reduced = _mm512_mask_or_epi64( + red1, _mm512_cmpeq_epi64_mask(sh4(chunk_start), chunk_start), sh4(red1), + red1); + // Extract the highest lane that belongs to each chunk (the lane that ends up + // with the OR-ed value of all the other lanes of that chunk). + __m512i next_chunk_start = + _mm512_alignr_epi32(_mm512_set1_epi64(~0), chunk_start, 2); + __m512i result = _mm512_maskz_compress_epi64( + _mm512_cmpneq_epi64_mask(chunk_start, next_chunk_start), reduced); + + _mm512_storeu_si512((__m512i*)(output.data.get() + output.bytes_written), + result); + + // Update the bit writer and add the last 32-bit lane. + // Note that since trail_nbits was at most 32 to begin with, operating on + // trail_bits does not risk overflowing. + output.bytes_written += simd_nbits / 8; + // Here we are implicitly relying on the fact that simd_nbits < 512 to know + // that the byte of bitreader data we access is initialized. This is + // guaranteed because the remaining bits in the bitreader buffer are at most + // 7, so simd_nbits <= 505 always. + trail_bits = (trail_bits << (simd_nbits % 8)) + + output.data.get()[output.bytes_written]; + trail_nbits += simd_nbits % 8; + StoreLE64(output.data.get() + output.bytes_written, trail_bits); + size_t trail_bytes = trail_nbits / 8; + output.bits_in_buffer = trail_nbits % 8; + output.buffer = trail_bits >> (trail_bytes * 8); + output.bytes_written += trail_bytes; +} + +#endif + +template <size_t n> +FJXL_INLINE void StoreToWriter(const Bits32* bits, BitWriter& output) { +#ifdef FJXL_AVX512 + static_assert(n <= 2, ""); + StoreToWriterAVX512(bits[0], output); + if (n == 2) { + StoreToWriterAVX512(bits[1], output); + } + return; +#endif + static_assert(n <= 4, ""); + alignas(64) uint64_t nbits64[Bits64::kLanes * n]; + alignas(64) uint64_t bits64[Bits64::kLanes * n]; + bits[0].Merge().Store(nbits64, bits64); + if (n > 1) { + bits[1].Merge().Store(nbits64 + Bits64::kLanes, bits64 + Bits64::kLanes); + } + if (n > 2) { + bits[2].Merge().Store(nbits64 + 2 * Bits64::kLanes, + bits64 + 2 * Bits64::kLanes); + } + if (n > 3) { + bits[3].Merge().Store(nbits64 + 3 * Bits64::kLanes, + bits64 + 3 * Bits64::kLanes); + } + output.WriteMultiple(nbits64, bits64, Bits64::kLanes * n); +} + +namespace detail { +template <typename T> +struct IntegerTypes; + +template <> +struct IntegerTypes<SIMDVec16> { + using signed_ = int16_t; + using unsigned_ = uint16_t; +}; + +template <> +struct IntegerTypes<SIMDVec32> { + using signed_ = int32_t; + using unsigned_ = uint32_t; +}; + +template <typename T> +struct SIMDType; + +template <> +struct SIMDType<int16_t> { + using type = SIMDVec16; +}; + +template <> +struct SIMDType<int32_t> { + using type = SIMDVec32; +}; + +} // namespace detail + +template <typename T> +using signed_t = typename detail::IntegerTypes<T>::signed_; + +template <typename T> +using unsigned_t = typename detail::IntegerTypes<T>::unsigned_; + +template <typename T> +using simd_t = typename detail::SIMDType<T>::type; + +// This function will process exactly one vector worth of pixels. + +template <typename T> +size_t PredictPixels(const signed_t<T>* pixels, const signed_t<T>* pixels_left, + const signed_t<T>* pixels_top, + const signed_t<T>* pixels_topleft, + unsigned_t<T>* residuals) { + T px = T::Load((unsigned_t<T>*)pixels); + T left = T::Load((unsigned_t<T>*)pixels_left); + T top = T::Load((unsigned_t<T>*)pixels_top); + T topleft = T::Load((unsigned_t<T>*)pixels_topleft); + T ac = left.Sub(topleft); + T ab = left.Sub(top); + T bc = top.Sub(topleft); + T grad = ac.Add(top); + T d = ab.Xor(bc); + T zero = T::Val(0); + T clamp = zero.Gt(d).IfThenElse(top, left); + T s = ac.Xor(bc); + T pred = zero.Gt(s).IfThenElse(grad, clamp); + T res = px.Sub(pred); + T res_times_2 = res.Add(res); + res = zero.Gt(res).IfThenElse(T::Val(-1).Sub(res_times_2), res_times_2); + res.Store(residuals); + return res.Eq(T::Val(0)).CountPrefix(); +} + +#endif + +void EncodeHybridUint000(uint32_t value, uint32_t* token, uint32_t* nbits, + uint32_t* bits) { + uint32_t n = FloorLog2(value); + *token = value ? n + 1 : 0; + *nbits = value ? n : 0; + *bits = value ? value - (1 << n) : 0; +} + +#ifdef FJXL_AVX512 +constexpr static size_t kLogChunkSize = 5; +#elif defined(FJXL_AVX2) || defined(FJXL_NEON) +// Even if NEON only has 128-bit lanes, it is still significantly (~1.3x) faster +// to process two vectors at a time. +constexpr static size_t kLogChunkSize = 4; +#else +constexpr static size_t kLogChunkSize = 3; +#endif + +constexpr static size_t kChunkSize = 1 << kLogChunkSize; + +template <typename Residual> +void GenericEncodeChunk(const Residual* residuals, size_t n, size_t skip, + const PrefixCode& code, BitWriter& output) { + for (size_t ix = skip; ix < n; ix++) { + unsigned token, nbits, bits; + EncodeHybridUint000(residuals[ix], &token, &nbits, &bits); + output.Write(code.raw_nbits[token] + nbits, + code.raw_bits[token] | bits << code.raw_nbits[token]); + } +} + +struct UpTo8Bits { + size_t bitdepth; + explicit UpTo8Bits(size_t bitdepth) : bitdepth(bitdepth) { + assert(bitdepth <= 8); + } + // Here we can fit up to 9 extra bits + 7 Huffman bits in a u16; for all other + // symbols, we could actually go up to 8 Huffman bits as we have at most 8 + // extra bits; however, the SIMD bit merging logic for AVX2 assumes that no + // Huffman length is 8 or more, so we cap at 8 anyway. Last symbol is used for + // LZ77 lengths and has no limitations except allowing to represent 32 symbols + // in total. + static constexpr uint8_t kMinRawLength[12] = {}; + static constexpr uint8_t kMaxRawLength[12] = { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 10, + }; + static size_t MaxEncodedBitsPerSample() { return 16; } + static constexpr size_t kInputBytes = 1; + using pixel_t = int16_t; + using upixel_t = uint16_t; + + static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits, + size_t n, uint8_t* nbits_simd, + uint8_t* bits_simd) { + assert(n <= 16); + memcpy(nbits_simd, nbits, 16); + memcpy(bits_simd, bits, 16); + } + +#ifdef FJXL_GENERIC_SIMD + static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip, + const uint8_t* raw_nbits_simd, + const uint8_t* raw_bits_simd, BitWriter& output) { + Bits32 bits32[kChunkSize / SIMDVec16::kLanes]; + alignas(64) uint16_t bits[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits[SIMDVec16::kLanes]; + alignas(64) uint16_t bits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t token[SIMDVec16::kLanes]; + for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) { + TokenizeSIMD(residuals + i, token, nbits, bits); + HuffmanSIMDUpTo13(token, raw_nbits_simd, raw_bits_simd, nbits_huff, + bits_huff); + StoreSIMDUpTo8(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i, + std::max(skip, i) - i, bits32 + i / SIMDVec16::kLanes); + } + StoreToWriter<kChunkSize / SIMDVec16::kLanes>(bits32, output); + } +#endif + + size_t NumSymbols(bool doing_ycocg_or_large_palette) const { + // values gain 1 bit for YCoCg, 1 bit for prediction. + // Maximum symbol is 1 + effective bit depth of residuals. + if (doing_ycocg_or_large_palette) { + return bitdepth + 3; + } else { + return bitdepth + 2; + } + } +}; +constexpr uint8_t UpTo8Bits::kMinRawLength[]; +constexpr uint8_t UpTo8Bits::kMaxRawLength[]; + +struct From9To13Bits { + size_t bitdepth; + explicit From9To13Bits(size_t bitdepth) : bitdepth(bitdepth) { + assert(bitdepth <= 13 && bitdepth >= 9); + } + // Last symbol is used for LZ77 lengths and has no limitations except allowing + // to represent 32 symbols in total. + // We cannot fit all the bits in a u16, so do not even try and use up to 8 + // bits per raw symbol. + // There are at most 16 raw symbols, so Huffman coding can be SIMDfied without + // any special tricks. + static constexpr uint8_t kMinRawLength[17] = {}; + static constexpr uint8_t kMaxRawLength[17] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 10, + }; + static size_t MaxEncodedBitsPerSample() { return 21; } + static constexpr size_t kInputBytes = 2; + using pixel_t = int16_t; + using upixel_t = uint16_t; + + static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits, + size_t n, uint8_t* nbits_simd, + uint8_t* bits_simd) { + assert(n <= 16); + memcpy(nbits_simd, nbits, 16); + memcpy(bits_simd, bits, 16); + } + +#ifdef FJXL_GENERIC_SIMD + static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip, + const uint8_t* raw_nbits_simd, + const uint8_t* raw_bits_simd, BitWriter& output) { + Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes]; + alignas(64) uint16_t bits[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits[SIMDVec16::kLanes]; + alignas(64) uint16_t bits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t token[SIMDVec16::kLanes]; + for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) { + TokenizeSIMD(residuals + i, token, nbits, bits); + HuffmanSIMDUpTo13(token, raw_nbits_simd, raw_bits_simd, nbits_huff, + bits_huff); + StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i, + std::max(skip, i) - i, + bits32 + 2 * i / SIMDVec16::kLanes); + } + StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output); + } +#endif + + size_t NumSymbols(bool doing_ycocg_or_large_palette) const { + // values gain 1 bit for YCoCg, 1 bit for prediction. + // Maximum symbol is 1 + effective bit depth of residuals. + if (doing_ycocg_or_large_palette) { + return bitdepth + 3; + } else { + return bitdepth + 2; + } + } +}; +constexpr uint8_t From9To13Bits::kMinRawLength[]; +constexpr uint8_t From9To13Bits::kMaxRawLength[]; + +void CheckHuffmanBitsSIMD(int bits1, int nbits1, int bits2, int nbits2) { + assert(nbits1 == 8); + assert(nbits2 == 8); + assert(bits2 == (bits1 | 128)); +} + +struct Exactly14Bits { + explicit Exactly14Bits(size_t bitdepth) { assert(bitdepth == 14); } + // Force LZ77 symbols to have at least 8 bits, and raw symbols 15 and 16 to + // have exactly 8, and no other symbol to have 8 or more. This ensures that + // the representation for 15 and 16 is identical up to one bit. + static constexpr uint8_t kMinRawLength[18] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 7, + }; + static constexpr uint8_t kMaxRawLength[18] = { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 10, + }; + static constexpr size_t bitdepth = 14; + static size_t MaxEncodedBitsPerSample() { return 22; } + static constexpr size_t kInputBytes = 2; + using pixel_t = int16_t; + using upixel_t = uint16_t; + + static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits, + size_t n, uint8_t* nbits_simd, + uint8_t* bits_simd) { + assert(n == 17); + CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]); + memcpy(nbits_simd, nbits, 16); + memcpy(bits_simd, bits, 16); + } + +#ifdef FJXL_GENERIC_SIMD + static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip, + const uint8_t* raw_nbits_simd, + const uint8_t* raw_bits_simd, BitWriter& output) { + Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes]; + alignas(64) uint16_t bits[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits[SIMDVec16::kLanes]; + alignas(64) uint16_t bits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t token[SIMDVec16::kLanes]; + for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) { + TokenizeSIMD(residuals + i, token, nbits, bits); + HuffmanSIMD14(token, raw_nbits_simd, raw_bits_simd, nbits_huff, + bits_huff); + StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i, + std::max(skip, i) - i, + bits32 + 2 * i / SIMDVec16::kLanes); + } + StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output); + } +#endif + + size_t NumSymbols(bool) const { return 17; } +}; +constexpr uint8_t Exactly14Bits::kMinRawLength[]; +constexpr uint8_t Exactly14Bits::kMaxRawLength[]; + +struct MoreThan14Bits { + size_t bitdepth; + explicit MoreThan14Bits(size_t bitdepth) : bitdepth(bitdepth) { + assert(bitdepth > 14); + assert(bitdepth <= 16); + } + // Force LZ77 symbols to have at least 8 bits, and raw symbols 13 to 18 to + // have exactly 8, and no other symbol to have 8 or more. This ensures that + // the representation for (13, 14), (15, 16), (17, 18) is identical up to one + // bit. + static constexpr uint8_t kMinRawLength[20] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 7, + }; + static constexpr uint8_t kMaxRawLength[20] = { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 10, + }; + static size_t MaxEncodedBitsPerSample() { return 24; } + static constexpr size_t kInputBytes = 2; + using pixel_t = int32_t; + using upixel_t = uint32_t; + + static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits, + size_t n, uint8_t* nbits_simd, + uint8_t* bits_simd) { + assert(n == 19); + CheckHuffmanBitsSIMD(bits[13], nbits[13], bits[14], nbits[14]); + CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]); + CheckHuffmanBitsSIMD(bits[17], nbits[17], bits[18], nbits[18]); + for (size_t i = 0; i < 14; i++) { + nbits_simd[i] = nbits[i]; + bits_simd[i] = bits[i]; + } + nbits_simd[14] = nbits[15]; + bits_simd[14] = bits[15]; + nbits_simd[15] = nbits[17]; + bits_simd[15] = bits[17]; + } + +#ifdef FJXL_GENERIC_SIMD + static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip, + const uint8_t* raw_nbits_simd, + const uint8_t* raw_bits_simd, BitWriter& output) { + Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes]; + alignas(64) uint32_t bits[SIMDVec16::kLanes]; + alignas(64) uint32_t nbits[SIMDVec16::kLanes]; + alignas(64) uint16_t bits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t token[SIMDVec16::kLanes]; + for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) { + TokenizeSIMD(residuals + i, token, nbits, bits); + HuffmanSIMDAbove14(token, raw_nbits_simd, raw_bits_simd, nbits_huff, + bits_huff); + StoreSIMDAbove14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i, + std::max(skip, i) - i, + bits32 + 2 * i / SIMDVec16::kLanes); + } + StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output); + } +#endif + size_t NumSymbols(bool) const { return 19; } +}; +constexpr uint8_t MoreThan14Bits::kMinRawLength[]; +constexpr uint8_t MoreThan14Bits::kMaxRawLength[]; + +void PrepareDCGlobalCommon(bool is_single_group, size_t width, size_t height, + const PrefixCode code[4], BitWriter* output) { + output->Allocate(100000 + (is_single_group ? width * height * 16 : 0)); + // No patches, spline or noise. + output->Write(1, 1); // default DC dequantization factors (?) + output->Write(1, 1); // use global tree / histograms + output->Write(1, 0); // no lz77 for the tree + + output->Write(1, 1); // simple code for the tree's context map + output->Write(2, 0); // all contexts clustered together + output->Write(1, 1); // use prefix code for tree + output->Write(4, 0); // 000 hybrid uint + output->Write(6, 0b100011); // Alphabet size is 4 (var16) + output->Write(2, 1); // simple prefix code + output->Write(2, 3); // with 4 symbols + output->Write(2, 0); + output->Write(2, 1); + output->Write(2, 2); + output->Write(2, 3); + output->Write(1, 0); // First tree encoding option + + // Huffman table + extra bits for the tree. + uint8_t symbol_bits[6] = {0b00, 0b10, 0b001, 0b101, 0b0011, 0b0111}; + uint8_t symbol_nbits[6] = {2, 2, 3, 3, 4, 4}; + // Write a tree with a leaf per channel, and gradient predictor for every + // leaf. + for (auto v : {1, 2, 1, 4, 1, 0, 0, 5, 0, 0, 0, 0, 5, + 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0}) { + output->Write(symbol_nbits[v], symbol_bits[v]); + } + + output->Write(1, 1); // Enable lz77 for the main bitstream + output->Write(2, 0b00); // lz77 offset 224 + static_assert(kLZ77Offset == 224, ""); + output->Write(4, 0b1010); // lz77 min length 7 + // 400 hybrid uint config for lz77 + output->Write(4, 4); + output->Write(3, 0); + output->Write(3, 0); + + output->Write(1, 1); // simple code for the context map + output->Write(2, 3); // 3 bits per entry + output->Write(3, 4); // channel 3 + output->Write(3, 3); // channel 2 + output->Write(3, 2); // channel 1 + output->Write(3, 1); // channel 0 + output->Write(3, 0); // distance histogram first + + output->Write(1, 1); // use prefix codes + output->Write(4, 0); // 000 hybrid uint config for distances (only need 0) + for (size_t i = 0; i < 4; i++) { + output->Write(4, 0); // 000 hybrid uint config for symbols (only <= 10) + } + + // Distance alphabet size: + output->Write(5, 0b00001); // 2: just need 1 for RLE (i.e. distance 1) + // Symbol + LZ77 alphabet size: + for (size_t i = 0; i < 4; i++) { + output->Write(1, 1); // > 1 + output->Write(4, 8); // <= 512 + output->Write(8, 256); // == 512 + } + + // Distance histogram: + output->Write(2, 1); // simple prefix code + output->Write(2, 0); // with one symbol + output->Write(1, 1); // 1 + + // Symbol + lz77 histogram: + for (size_t i = 0; i < 4; i++) { + code[i].WriteTo(output); + } + + // Group header for global modular image. + output->Write(1, 1); // Global tree + output->Write(1, 1); // All default wp +} + +void PrepareDCGlobal(bool is_single_group, size_t width, size_t height, + size_t nb_chans, const PrefixCode code[4], + BitWriter* output) { + PrepareDCGlobalCommon(is_single_group, width, height, code, output); + if (nb_chans > 2) { + output->Write(2, 0b01); // 1 transform + output->Write(2, 0b00); // RCT + output->Write(5, 0b00000); // Starting from ch 0 + output->Write(2, 0b00); // YCoCg + } else { + output->Write(2, 0b00); // no transforms + } + if (!is_single_group) { + output->ZeroPadToByte(); + } +} + +template <typename BitDepth> +struct ChunkEncoder { + void PrepareForSimd() { + BitDepth::PrepareForSimd(code->raw_nbits, code->raw_bits, code->numraw, + raw_nbits_simd, raw_bits_simd); + } + FJXL_INLINE static void EncodeRle(size_t count, const PrefixCode& code, + BitWriter& output) { + if (count == 0) return; + count -= kLZ77MinLength + 1; + if (count < kLZ77CacheSize) { + output.Write(code.lz77_cache_nbits[count], code.lz77_cache_bits[count]); + } else { + unsigned token, nbits, bits; + EncodeHybridUintLZ77(count, &token, &nbits, &bits); + uint64_t wbits = bits; + wbits = (wbits << code.lz77_nbits[token]) | code.lz77_bits[token]; + wbits = (wbits << code.raw_nbits[0]) | code.raw_bits[0]; + output.Write(code.lz77_nbits[token] + nbits + code.raw_nbits[0], wbits); + } + } + + FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals, + size_t skip, size_t n) { + EncodeRle(run, *code, *output); +#ifdef FJXL_GENERIC_SIMD + BitDepth::EncodeChunkSimd(residuals, n, skip, raw_nbits_simd, raw_bits_simd, + *output); +#else + GenericEncodeChunk(residuals, n, skip, *code, *output); +#endif + } + + inline void Finalize(size_t run) { EncodeRle(run, *code, *output); } + + const PrefixCode* code; + BitWriter* output; + alignas(64) uint8_t raw_nbits_simd[16] = {}; + alignas(64) uint8_t raw_bits_simd[16] = {}; +}; + +template <typename BitDepth> +struct ChunkSampleCollector { + FJXL_INLINE void Rle(size_t count, uint64_t* lz77_counts) { + if (count == 0) return; + raw_counts[0] += 1; + count -= kLZ77MinLength + 1; + unsigned token, nbits, bits; + EncodeHybridUintLZ77(count, &token, &nbits, &bits); + lz77_counts[token]++; + } + + FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals, + size_t skip, size_t n) { + // Run is broken. Encode the run and encode the individual vector. + Rle(run, lz77_counts); + for (size_t ix = skip; ix < n; ix++) { + unsigned token, nbits, bits; + EncodeHybridUint000(residuals[ix], &token, &nbits, &bits); + raw_counts[token]++; + } + } + + // don't count final run since we don't know how long it really is + void Finalize(size_t run) {} + + uint64_t* raw_counts; + uint64_t* lz77_counts; +}; + +constexpr uint32_t PackSigned(int32_t value) { + return (static_cast<uint32_t>(value) << 1) ^ + ((static_cast<uint32_t>(~value) >> 31) - 1); +} + +template <typename T, typename BitDepth> +struct ChannelRowProcessor { + using upixel_t = typename BitDepth::upixel_t; + using pixel_t = typename BitDepth::pixel_t; + T* t; + void ProcessChunk(const pixel_t* row, const pixel_t* row_left, + const pixel_t* row_top, const pixel_t* row_topleft, + size_t n) { + alignas(64) upixel_t residuals[kChunkSize] = {}; + size_t prefix_size = 0; + size_t required_prefix_size = 0; +#ifdef FJXL_GENERIC_SIMD + constexpr size_t kNum = + sizeof(pixel_t) == 2 ? SIMDVec16::kLanes : SIMDVec32::kLanes; + for (size_t ix = 0; ix < kChunkSize; ix += kNum) { + size_t c = + PredictPixels<simd_t<pixel_t>>(row + ix, row_left + ix, row_top + ix, + row_topleft + ix, residuals + ix); + prefix_size = + prefix_size == required_prefix_size ? prefix_size + c : prefix_size; + required_prefix_size += kNum; + } +#else + for (size_t ix = 0; ix < kChunkSize; ix++) { + pixel_t px = row[ix]; + pixel_t left = row_left[ix]; + pixel_t top = row_top[ix]; + pixel_t topleft = row_topleft[ix]; + pixel_t ac = left - topleft; + pixel_t ab = left - top; + pixel_t bc = top - topleft; + pixel_t grad = static_cast<pixel_t>(static_cast<upixel_t>(ac) + + static_cast<upixel_t>(top)); + pixel_t d = ab ^ bc; + pixel_t clamp = d < 0 ? top : left; + pixel_t s = ac ^ bc; + pixel_t pred = s < 0 ? grad : clamp; + residuals[ix] = PackSigned(px - pred); + prefix_size = prefix_size == required_prefix_size + ? prefix_size + (residuals[ix] == 0) + : prefix_size; + required_prefix_size += 1; + } +#endif + prefix_size = std::min(n, prefix_size); + if (prefix_size == n && (run > 0 || prefix_size > kLZ77MinLength)) { + // Run continues, nothing to do. + run += prefix_size; + } else if (prefix_size + run > kLZ77MinLength) { + // Run is broken. Encode the run and encode the individual vector. + t->Chunk(run + prefix_size, residuals, prefix_size, n); + run = 0; + } else { + // There was no run to begin with. + t->Chunk(0, residuals, 0, n); + } + } + + void ProcessRow(const pixel_t* row, const pixel_t* row_left, + const pixel_t* row_top, const pixel_t* row_topleft, + size_t xs) { + for (size_t x = 0; x < xs; x += kChunkSize) { + ProcessChunk(row + x, row_left + x, row_top + x, row_topleft + x, + std::min(kChunkSize, xs - x)); + } + } + + void Finalize() { t->Finalize(run); } + // Invariant: run == 0 or run > kLZ77MinLength. + size_t run = 0; +}; + +uint16_t LoadLE16(const unsigned char* ptr) { + return uint16_t{ptr[0]} | (uint16_t{ptr[1]} << 8); +} + +uint16_t SwapEndian(uint16_t in) { return (in >> 8) | (in << 8); } + +#ifdef FJXL_GENERIC_SIMD +void StorePixels(SIMDVec16 p, int16_t* dest) { p.Store((uint16_t*)dest); } + +void StorePixels(SIMDVec16 p, int32_t* dest) { + VecPair<SIMDVec32> p_up = p.Upcast(); + p_up.low.Store((uint32_t*)dest); + p_up.hi.Store((uint32_t*)dest + SIMDVec32::kLanes); +} +#endif + +template <typename pixel_t> +void FillRowG8(const unsigned char* rgba, size_t oxs, pixel_t* luma) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadG8(rgba + x); + StorePixels(rgb[0], luma + x); + } +#endif + for (; x < oxs; x++) { + luma[x] = rgba[x]; + } +} + +template <bool big_endian, typename pixel_t> +void FillRowG16(const unsigned char* rgba, size_t oxs, pixel_t* luma) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadG16(rgba + 2 * x); + if (big_endian) { + rgb[0].SwapEndian(); + } + StorePixels(rgb[0], luma + x); + } +#endif + for (; x < oxs; x++) { + uint16_t val = LoadLE16(rgba + 2 * x); + if (big_endian) { + val = SwapEndian(val); + } + luma[x] = val; + } +} + +template <typename pixel_t> +void FillRowGA8(const unsigned char* rgba, size_t oxs, pixel_t* luma, + pixel_t* alpha) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadGA8(rgba + 2 * x); + StorePixels(rgb[0], luma + x); + StorePixels(rgb[1], alpha + x); + } +#endif + for (; x < oxs; x++) { + luma[x] = rgba[2 * x]; + alpha[x] = rgba[2 * x + 1]; + } +} + +template <bool big_endian, typename pixel_t> +void FillRowGA16(const unsigned char* rgba, size_t oxs, pixel_t* luma, + pixel_t* alpha) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadGA16(rgba + 4 * x); + if (big_endian) { + rgb[0].SwapEndian(); + rgb[1].SwapEndian(); + } + StorePixels(rgb[0], luma + x); + StorePixels(rgb[1], alpha + x); + } +#endif + for (; x < oxs; x++) { + uint16_t l = LoadLE16(rgba + 4 * x); + uint16_t a = LoadLE16(rgba + 4 * x + 2); + if (big_endian) { + l = SwapEndian(l); + a = SwapEndian(a); + } + luma[x] = l; + alpha[x] = a; + } +} + +template <typename pixel_t> +void StoreYCoCg(pixel_t r, pixel_t g, pixel_t b, pixel_t* y, pixel_t* co, + pixel_t* cg) { + *co = r - b; + pixel_t tmp = b + (*co >> 1); + *cg = g - tmp; + *y = tmp + (*cg >> 1); +} + +#ifdef FJXL_GENERIC_SIMD +void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int16_t* y, int16_t* co, + int16_t* cg) { + SIMDVec16 co_v = r.Sub(b); + SIMDVec16 tmp = b.Add(co_v.SignedShiftRight<1>()); + SIMDVec16 cg_v = g.Sub(tmp); + SIMDVec16 y_v = tmp.Add(cg_v.SignedShiftRight<1>()); + y_v.Store((uint16_t*)y); + co_v.Store((uint16_t*)co); + cg_v.Store((uint16_t*)cg); +} + +void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int32_t* y, int32_t* co, + int32_t* cg) { + VecPair<SIMDVec32> r_up = r.Upcast(); + VecPair<SIMDVec32> g_up = g.Upcast(); + VecPair<SIMDVec32> b_up = b.Upcast(); + SIMDVec32 co_lo_v = r_up.low.Sub(b_up.low); + SIMDVec32 tmp_lo = b_up.low.Add(co_lo_v.SignedShiftRight<1>()); + SIMDVec32 cg_lo_v = g_up.low.Sub(tmp_lo); + SIMDVec32 y_lo_v = tmp_lo.Add(cg_lo_v.SignedShiftRight<1>()); + SIMDVec32 co_hi_v = r_up.hi.Sub(b_up.hi); + SIMDVec32 tmp_hi = b_up.hi.Add(co_hi_v.SignedShiftRight<1>()); + SIMDVec32 cg_hi_v = g_up.hi.Sub(tmp_hi); + SIMDVec32 y_hi_v = tmp_hi.Add(cg_hi_v.SignedShiftRight<1>()); + y_lo_v.Store((uint32_t*)y); + co_lo_v.Store((uint32_t*)co); + cg_lo_v.Store((uint32_t*)cg); + y_hi_v.Store((uint32_t*)y + SIMDVec32::kLanes); + co_hi_v.Store((uint32_t*)co + SIMDVec32::kLanes); + cg_hi_v.Store((uint32_t*)cg + SIMDVec32::kLanes); +} +#endif + +template <typename pixel_t> +void FillRowRGB8(const unsigned char* rgba, size_t oxs, pixel_t* y, pixel_t* co, + pixel_t* cg) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadRGB8(rgba + 3 * x); + StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x); + } +#endif + for (; x < oxs; x++) { + uint16_t r = rgba[3 * x]; + uint16_t g = rgba[3 * x + 1]; + uint16_t b = rgba[3 * x + 2]; + StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x); + } +} + +template <bool big_endian, typename pixel_t> +void FillRowRGB16(const unsigned char* rgba, size_t oxs, pixel_t* y, + pixel_t* co, pixel_t* cg) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadRGB16(rgba + 6 * x); + if (big_endian) { + rgb[0].SwapEndian(); + rgb[1].SwapEndian(); + rgb[2].SwapEndian(); + } + StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x); + } +#endif + for (; x < oxs; x++) { + uint16_t r = LoadLE16(rgba + 6 * x); + uint16_t g = LoadLE16(rgba + 6 * x + 2); + uint16_t b = LoadLE16(rgba + 6 * x + 4); + if (big_endian) { + r = SwapEndian(r); + g = SwapEndian(g); + b = SwapEndian(b); + } + StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x); + } +} + +template <typename pixel_t> +void FillRowRGBA8(const unsigned char* rgba, size_t oxs, pixel_t* y, + pixel_t* co, pixel_t* cg, pixel_t* alpha) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadRGBA8(rgba + 4 * x); + StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x); + StorePixels(rgb[3], alpha + x); + } +#endif + for (; x < oxs; x++) { + uint16_t r = rgba[4 * x]; + uint16_t g = rgba[4 * x + 1]; + uint16_t b = rgba[4 * x + 2]; + uint16_t a = rgba[4 * x + 3]; + StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x); + alpha[x] = a; + } +} + +template <bool big_endian, typename pixel_t> +void FillRowRGBA16(const unsigned char* rgba, size_t oxs, pixel_t* y, + pixel_t* co, pixel_t* cg, pixel_t* alpha) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadRGBA16(rgba + 8 * x); + if (big_endian) { + rgb[0].SwapEndian(); + rgb[1].SwapEndian(); + rgb[2].SwapEndian(); + rgb[3].SwapEndian(); + } + StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x); + StorePixels(rgb[3], alpha + x); + } +#endif + for (; x < oxs; x++) { + uint16_t r = LoadLE16(rgba + 8 * x); + uint16_t g = LoadLE16(rgba + 8 * x + 2); + uint16_t b = LoadLE16(rgba + 8 * x + 4); + uint16_t a = LoadLE16(rgba + 8 * x + 6); + if (big_endian) { + r = SwapEndian(r); + g = SwapEndian(g); + b = SwapEndian(b); + a = SwapEndian(a); + } + StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x); + alpha[x] = a; + } +} + +template <typename Processor, typename BitDepth> +void ProcessImageArea(const unsigned char* rgba, size_t x0, size_t y0, + size_t xs, size_t yskip, size_t ys, size_t row_stride, + BitDepth bitdepth, size_t nb_chans, bool big_endian, + Processor* processors) { + constexpr size_t kPadding = 32; + + using pixel_t = typename BitDepth::pixel_t; + + constexpr size_t kAlign = 64; + constexpr size_t kAlignPixels = kAlign / sizeof(pixel_t); + + auto align = [=](pixel_t* ptr) { + size_t offset = reinterpret_cast<uintptr_t>(ptr) % kAlign; + if (offset) { + ptr += offset / sizeof(pixel_t); + } + return ptr; + }; + + constexpr size_t kNumPx = + (256 + kPadding * 2 + kAlignPixels + kAlignPixels - 1) / kAlignPixels * + kAlignPixels; + + std::vector<std::array<std::array<pixel_t, kNumPx>, 2>> group_data(nb_chans); + + for (size_t y = 0; y < ys; y++) { + const auto rgba_row = + rgba + row_stride * (y0 + y) + x0 * nb_chans * BitDepth::kInputBytes; + pixel_t* crow[4] = {}; + pixel_t* prow[4] = {}; + for (size_t i = 0; i < nb_chans; i++) { + crow[i] = align(&group_data[i][y & 1][kPadding]); + prow[i] = align(&group_data[i][(y - 1) & 1][kPadding]); + } + + // Pre-fill rows with YCoCg converted pixels. + if (nb_chans == 1) { + if (BitDepth::kInputBytes == 1) { + FillRowG8(rgba_row, xs, crow[0]); + } else if (big_endian) { + FillRowG16</*big_endian=*/true>(rgba_row, xs, crow[0]); + } else { + FillRowG16</*big_endian=*/false>(rgba_row, xs, crow[0]); + } + } else if (nb_chans == 2) { + if (BitDepth::kInputBytes == 1) { + FillRowGA8(rgba_row, xs, crow[0], crow[1]); + } else if (big_endian) { + FillRowGA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1]); + } else { + FillRowGA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1]); + } + } else if (nb_chans == 3) { + if (BitDepth::kInputBytes == 1) { + FillRowRGB8(rgba_row, xs, crow[0], crow[1], crow[2]); + } else if (big_endian) { + FillRowRGB16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1], + crow[2]); + } else { + FillRowRGB16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1], + crow[2]); + } + } else { + if (BitDepth::kInputBytes == 1) { + FillRowRGBA8(rgba_row, xs, crow[0], crow[1], crow[2], crow[3]); + } else if (big_endian) { + FillRowRGBA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1], + crow[2], crow[3]); + } else { + FillRowRGBA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1], + crow[2], crow[3]); + } + } + // Deal with x == 0. + for (size_t c = 0; c < nb_chans; c++) { + *(crow[c] - 1) = y > 0 ? *(prow[c]) : 0; + // Fix topleft. + *(prow[c] - 1) = y > 0 ? *(prow[c]) : 0; + } + if (y < yskip) continue; + for (size_t c = 0; c < nb_chans; c++) { + // Get pointers to px/left/top/topleft data to speedup loop. + const pixel_t* row = crow[c]; + const pixel_t* row_left = crow[c] - 1; + const pixel_t* row_top = y == 0 ? row_left : prow[c]; + const pixel_t* row_topleft = y == 0 ? row_left : prow[c] - 1; + + processors[c].ProcessRow(row, row_left, row_top, row_topleft, xs); + } + } + for (size_t c = 0; c < nb_chans; c++) { + processors[c].Finalize(); + } +} + +template <typename BitDepth> +void WriteACSection(const unsigned char* rgba, size_t x0, size_t y0, size_t xs, + size_t ys, size_t row_stride, bool is_single_group, + BitDepth bitdepth, size_t nb_chans, bool big_endian, + const PrefixCode code[4], + std::array<BitWriter, 4>& output) { + for (size_t i = 0; i < nb_chans; i++) { + if (is_single_group && i == 0) continue; + output[i].Allocate(xs * ys * bitdepth.MaxEncodedBitsPerSample() + 4); + } + if (!is_single_group) { + // Group header for modular image. + // When the image is single-group, the global modular image is the one + // that contains the pixel data, and there is no group header. + output[0].Write(1, 1); // Global tree + output[0].Write(1, 1); // All default wp + output[0].Write(2, 0b00); // 0 transforms + } + + ChunkEncoder<BitDepth> encoders[4]; + ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth> row_encoders[4]; + for (size_t c = 0; c < nb_chans; c++) { + row_encoders[c].t = &encoders[c]; + encoders[c].output = &output[c]; + encoders[c].code = &code[c]; + encoders[c].PrepareForSimd(); + } + ProcessImageArea<ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth>>( + rgba, x0, y0, xs, 0, ys, row_stride, bitdepth, nb_chans, big_endian, + row_encoders); +} + +constexpr int kHashExp = 16; +constexpr uint32_t kHashSize = 1 << kHashExp; +constexpr uint32_t kHashMultiplier = 2654435761; +constexpr int kMaxColors = 512; + +// can be any function that returns a value in 0 .. kHashSize-1 +// has to map 0 to 0 +inline uint32_t pixel_hash(uint32_t p) { + return (p * kHashMultiplier) >> (32 - kHashExp); +} + +template <size_t nb_chans> +void FillRowPalette(const unsigned char* inrow, size_t xs, + const int16_t* lookup, int16_t* out) { + for (size_t x = 0; x < xs; x++) { + uint32_t p = 0; + memcpy(&p, inrow + x * nb_chans, nb_chans); + out[x] = lookup[pixel_hash(p)]; + } +} + +template <typename Processor> +void ProcessImageAreaPalette(const unsigned char* rgba, size_t x0, size_t y0, + size_t xs, size_t yskip, size_t ys, + size_t row_stride, const int16_t* lookup, + size_t nb_chans, Processor* processors) { + constexpr size_t kPadding = 32; + + std::vector<std::array<int16_t, 256 + kPadding * 2>> group_data(2); + Processor& row_encoder = processors[0]; + + for (size_t y = 0; y < ys; y++) { + // Pre-fill rows with palette converted pixels. + const unsigned char* inrow = rgba + row_stride * (y0 + y) + x0 * nb_chans; + int16_t* outrow = &group_data[y & 1][kPadding]; + if (nb_chans == 1) { + FillRowPalette<1>(inrow, xs, lookup, outrow); + } else if (nb_chans == 2) { + FillRowPalette<2>(inrow, xs, lookup, outrow); + } else if (nb_chans == 3) { + FillRowPalette<3>(inrow, xs, lookup, outrow); + } else if (nb_chans == 4) { + FillRowPalette<4>(inrow, xs, lookup, outrow); + } + // Deal with x == 0. + group_data[y & 1][kPadding - 1] = + y > 0 ? group_data[(y - 1) & 1][kPadding] : 0; + // Fix topleft. + group_data[(y - 1) & 1][kPadding - 1] = + y > 0 ? group_data[(y - 1) & 1][kPadding] : 0; + // Get pointers to px/left/top/topleft data to speedup loop. + const int16_t* row = &group_data[y & 1][kPadding]; + const int16_t* row_left = &group_data[y & 1][kPadding - 1]; + const int16_t* row_top = + y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding]; + const int16_t* row_topleft = + y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding - 1]; + + row_encoder.ProcessRow(row, row_left, row_top, row_topleft, xs); + } + row_encoder.Finalize(); +} + +void WriteACSectionPalette(const unsigned char* rgba, size_t x0, size_t y0, + size_t xs, size_t ys, size_t row_stride, + bool is_single_group, const PrefixCode code[4], + const int16_t* lookup, size_t nb_chans, + BitWriter& output) { + if (!is_single_group) { + output.Allocate(16 * xs * ys + 4); + // Group header for modular image. + // When the image is single-group, the global modular image is the one + // that contains the pixel data, and there is no group header. + output.Write(1, 1); // Global tree + output.Write(1, 1); // All default wp + output.Write(2, 0b00); // 0 transforms + } + + ChunkEncoder<UpTo8Bits> encoder; + ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder; + + row_encoder.t = &encoder; + encoder.output = &output; + encoder.code = &code[is_single_group ? 1 : 0]; + encoder.PrepareForSimd(); + ProcessImageAreaPalette< + ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits>>( + rgba, x0, y0, xs, 0, ys, row_stride, lookup, nb_chans, &row_encoder); +} + +template <typename BitDepth> +void CollectSamples(const unsigned char* rgba, size_t x0, size_t y0, size_t xs, + size_t row_stride, size_t row_count, + uint64_t raw_counts[4][kNumRawSymbols], + uint64_t lz77_counts[4][kNumLZ77], bool is_single_group, + bool palette, BitDepth bitdepth, size_t nb_chans, + bool big_endian, const int16_t* lookup) { + if (palette) { + ChunkSampleCollector<UpTo8Bits> sample_collectors[4]; + ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits> + row_sample_collectors[4]; + for (size_t c = 0; c < nb_chans; c++) { + row_sample_collectors[c].t = &sample_collectors[c]; + sample_collectors[c].raw_counts = raw_counts[is_single_group ? 1 : 0]; + sample_collectors[c].lz77_counts = lz77_counts[is_single_group ? 1 : 0]; + } + ProcessImageAreaPalette< + ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits>>( + rgba, x0, y0, xs, 1, 1 + row_count, row_stride, lookup, nb_chans, + row_sample_collectors); + } else { + ChunkSampleCollector<BitDepth> sample_collectors[4]; + ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth> + row_sample_collectors[4]; + for (size_t c = 0; c < nb_chans; c++) { + row_sample_collectors[c].t = &sample_collectors[c]; + sample_collectors[c].raw_counts = raw_counts[c]; + sample_collectors[c].lz77_counts = lz77_counts[c]; + } + ProcessImageArea< + ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth>>( + rgba, x0, y0, xs, 1, 1 + row_count, row_stride, bitdepth, nb_chans, + big_endian, row_sample_collectors); + } +} + +void PrepareDCGlobalPalette(bool is_single_group, size_t width, size_t height, + size_t nb_chans, const PrefixCode code[4], + const std::vector<uint32_t>& palette, + size_t pcolors, BitWriter* output) { + PrepareDCGlobalCommon(is_single_group, width, height, code, output); + output->Write(2, 0b01); // 1 transform + output->Write(2, 0b01); // Palette + output->Write(5, 0b00000); // Starting from ch 0 + if (nb_chans == 1) { + output->Write(2, 0b00); // 1-channel palette (Gray) + } else if (nb_chans == 3) { + output->Write(2, 0b01); // 3-channel palette (RGB) + } else if (nb_chans == 4) { + output->Write(2, 0b10); // 4-channel palette (RGBA) + } else { + output->Write(2, 0b11); + output->Write(13, nb_chans - 1); + } + // pcolors <= kMaxColors + kChunkSize - 1 + static_assert(kMaxColors + kChunkSize < 1281, + "add code to signal larger palette sizes"); + if (pcolors < 256) { + output->Write(2, 0b00); + output->Write(8, pcolors); + } else { + output->Write(2, 0b01); + output->Write(10, pcolors - 256); + } + + output->Write(2, 0b00); // nb_deltas == 0 + output->Write(4, 0); // Zero predictor for delta palette + // Encode palette + ChunkEncoder<UpTo8Bits> encoder; + ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder; + row_encoder.t = &encoder; + encoder.output = output; + encoder.code = &code[0]; + encoder.PrepareForSimd(); + int16_t p[4][32 + 1024] = {}; + uint8_t prgba[4]; + size_t i = 0; + size_t have_zero = 0; + if (palette[pcolors - 1] == 0) have_zero = 1; + for (; i < pcolors; i++) { + memcpy(prgba, &palette[i], 4); + p[0][16 + i + have_zero] = prgba[0]; + p[1][16 + i + have_zero] = prgba[1]; + p[2][16 + i + have_zero] = prgba[2]; + p[3][16 + i + have_zero] = prgba[3]; + } + p[0][15] = 0; + row_encoder.ProcessRow(p[0] + 16, p[0] + 15, p[0] + 15, p[0] + 15, pcolors); + p[1][15] = p[0][16]; + p[0][15] = p[0][16]; + if (nb_chans > 1) { + row_encoder.ProcessRow(p[1] + 16, p[1] + 15, p[0] + 16, p[0] + 15, pcolors); + } + p[2][15] = p[1][16]; + p[1][15] = p[1][16]; + if (nb_chans > 2) { + row_encoder.ProcessRow(p[2] + 16, p[2] + 15, p[1] + 16, p[1] + 15, pcolors); + } + p[3][15] = p[2][16]; + p[2][15] = p[2][16]; + if (nb_chans > 3) { + row_encoder.ProcessRow(p[3] + 16, p[3] + 15, p[2] + 16, p[2] + 15, pcolors); + } + row_encoder.Finalize(); + + if (!is_single_group) { + output->ZeroPadToByte(); + } +} + +template <size_t nb_chans> +bool detect_palette(const unsigned char* r, size_t width, + std::vector<uint32_t>& palette) { + size_t x = 0; + bool collided = false; + // this is just an unrolling of the next loop + for (; x + 7 < width; x += 8) { + uint32_t p[8] = {}, index[8]; + for (int i = 0; i < 8; i++) memcpy(&p[i], r + (x + i) * nb_chans, 4); + for (int i = 0; i < 8; i++) p[i] &= ((1llu << (8 * nb_chans)) - 1); + for (int i = 0; i < 8; i++) index[i] = pixel_hash(p[i]); + for (int i = 0; i < 8; i++) { + collided |= (palette[index[i]] != 0 && p[i] != palette[index[i]]); + } + for (int i = 0; i < 8; i++) palette[index[i]] = p[i]; + } + for (; x < width; x++) { + uint32_t p = 0; + memcpy(&p, r + x * nb_chans, nb_chans); + uint32_t index = pixel_hash(p); + collided |= (palette[index] != 0 && p != palette[index]); + palette[index] = p; + } + return collided; +} + +template <typename BitDepth> +JxlFastLosslessFrameState* LLPrepare(JxlChunkedFrameInputSource input, + size_t width, size_t height, + BitDepth bitdepth, size_t nb_chans, + bool big_endian, int effort, int oneshot) { + assert(width != 0); + assert(height != 0); + + // Count colors to try palette + std::vector<uint32_t> palette(kHashSize); + std::vector<int16_t> lookup(kHashSize); + lookup[0] = 0; + int pcolors = 0; + bool collided = effort < 2 || bitdepth.bitdepth != 8 || !oneshot; + for (size_t y0 = 0; y0 < height && !collided; y0 += 256) { + size_t ys = std::min<size_t>(height - y0, 256); + for (size_t x0 = 0; x0 < width && !collided; x0 += 256) { + size_t xs = std::min<size_t>(width - x0, 256); + size_t stride; + // TODO(szabadka): Add RAII wrapper around this. + const void* buffer = input.get_color_channel_data_at(input.opaque, x0, y0, + xs, ys, &stride); + auto rgba = reinterpret_cast<const unsigned char*>(buffer); + for (size_t y = 0; y < ys && !collided; y++) { + const unsigned char* r = rgba + stride * y; + if (nb_chans == 1) collided = detect_palette<1>(r, xs, palette); + if (nb_chans == 2) collided = detect_palette<2>(r, xs, palette); + if (nb_chans == 3) collided = detect_palette<3>(r, xs, palette); + if (nb_chans == 4) collided = detect_palette<4>(r, xs, palette); + } + input.release_buffer(input.opaque, buffer); + } + } + int nb_entries = 0; + if (!collided) { + pcolors = 1; // always have all-zero as a palette color + bool have_color = false; + uint8_t minG = 255, maxG = 0; + for (uint32_t k = 0; k < kHashSize; k++) { + if (palette[k] == 0) continue; + uint8_t p[4]; + memcpy(p, &palette[k], 4); + // move entries to front so sort has less work + palette[nb_entries] = palette[k]; + if (p[0] != p[1] || p[0] != p[2]) have_color = true; + if (p[1] < minG) minG = p[1]; + if (p[1] > maxG) maxG = p[1]; + nb_entries++; + // don't do palette if too many colors are needed + if (nb_entries + pcolors > kMaxColors) { + collided = true; + break; + } + } + if (!have_color) { + // don't do palette if it's just grayscale without many holes + if (maxG - minG < nb_entries * 1.4f) collided = true; + } + } + if (!collided) { + std::sort( + palette.begin(), palette.begin() + nb_entries, + [&nb_chans](uint32_t ap, uint32_t bp) { + if (ap == 0) return false; + if (bp == 0) return true; + uint8_t a[4], b[4]; + memcpy(a, &ap, 4); + memcpy(b, &bp, 4); + float ay, by; + if (nb_chans == 4) { + ay = (0.299f * a[0] + 0.587f * a[1] + 0.114f * a[2] + 0.01f) * a[3]; + by = (0.299f * b[0] + 0.587f * b[1] + 0.114f * b[2] + 0.01f) * b[3]; + } else { + ay = (0.299f * a[0] + 0.587f * a[1] + 0.114f * a[2] + 0.01f); + by = (0.299f * b[0] + 0.587f * b[1] + 0.114f * b[2] + 0.01f); + } + return ay < by; // sort on alpha*luma + }); + for (int k = 0; k < nb_entries; k++) { + if (palette[k] == 0) break; + lookup[pixel_hash(palette[k])] = pcolors++; + } + } + + size_t num_groups_x = (width + 255) / 256; + size_t num_groups_y = (height + 255) / 256; + size_t num_dc_groups_x = (width + 2047) / 2048; + size_t num_dc_groups_y = (height + 2047) / 2048; + + uint64_t raw_counts[4][kNumRawSymbols] = {}; + uint64_t lz77_counts[4][kNumLZ77] = {}; + + bool onegroup = num_groups_x == 1 && num_groups_y == 1; + + auto sample_rows = [&](size_t xg, size_t yg, size_t num_rows) { + size_t y0 = yg * 256; + size_t x0 = xg * 256; + size_t ys = std::min<size_t>(height - y0, 256); + size_t xs = std::min<size_t>(width - x0, 256); + size_t stride; + const void* buffer = + input.get_color_channel_data_at(input.opaque, x0, y0, xs, ys, &stride); + auto rgba = reinterpret_cast<const unsigned char*>(buffer); + int y_begin = std::max<int>(0, ys - 2 * effort) / 2; + int y_count = std::min<int>(num_rows, y0 + ys - y_begin - 1); + int x_max = xs / kChunkSize * kChunkSize; + CollectSamples(rgba, 0, y_begin, x_max, stride, y_count, raw_counts, + lz77_counts, onegroup, !collided, bitdepth, nb_chans, + big_endian, lookup.data()); + input.release_buffer(input.opaque, buffer); + }; + + // TODO(veluca): that `64` is an arbitrary constant, meant to correspond to + // the point where the number of processed rows is large enough that loading + // the entire image is cost-effective. + if (oneshot || effort >= 64) { + for (size_t g = 0; g < num_groups_y * num_groups_x; g++) { + size_t xg = g % num_groups_x; + size_t yg = g / num_groups_x; + size_t y0 = yg * 256; + size_t ys = std::min<size_t>(height - y0, 256); + size_t num_rows = 2 * effort * ys / 256; + sample_rows(xg, yg, num_rows); + } + } else { + // sample the middle (effort * 2 * num_groups) rows of the center group + // (possibly all of them). + sample_rows((num_groups_x - 1) / 2, (num_groups_y - 1) / 2, + 2 * effort * num_groups_x * num_groups_y); + } + + // TODO(veluca): can probably improve this and make it bitdepth-dependent. + uint64_t base_raw_counts[kNumRawSymbols] = { + 3843, 852, 1270, 1214, 1014, 727, 481, 300, 159, 51, + 5, 1, 1, 1, 1, 1, 1, 1, 1}; + + bool doing_ycocg = nb_chans > 2 && collided; + bool large_palette = !collided || pcolors >= 256; + for (size_t i = bitdepth.NumSymbols(doing_ycocg || large_palette); + i < kNumRawSymbols; i++) { + base_raw_counts[i] = 0; + } + + for (size_t c = 0; c < 4; c++) { + for (size_t i = 0; i < kNumRawSymbols; i++) { + raw_counts[c][i] = (raw_counts[c][i] << 8) + base_raw_counts[i]; + } + } + + if (!collided) { + unsigned token, nbits, bits; + EncodeHybridUint000(PackSigned(pcolors - 1), &token, &nbits, &bits); + // ensure all palette indices can actually be encoded + for (size_t i = 0; i < token + 1; i++) + raw_counts[0][i] = std::max<uint64_t>(raw_counts[0][i], 1); + // these tokens are only used for the palette itself so they can get a bad + // code + for (size_t i = token + 1; i < 10; i++) raw_counts[0][i] = 1; + } + + uint64_t base_lz77_counts[kNumLZ77] = { + 29, 27, 25, 23, 21, 21, 19, 18, 21, 17, 16, 15, 15, 14, + 13, 13, 137, 98, 61, 34, 1, 1, 1, 1, 1, 1, 1, 1, + }; + + for (size_t c = 0; c < 4; c++) { + for (size_t i = 0; i < kNumLZ77; i++) { + lz77_counts[c][i] = (lz77_counts[c][i] << 8) + base_lz77_counts[i]; + } + } + + JxlFastLosslessFrameState* frame_state = new JxlFastLosslessFrameState(); + for (size_t i = 0; i < 4; i++) { + frame_state->hcode[i] = PrefixCode(bitdepth, raw_counts[i], lz77_counts[i]); + } + + size_t num_dc_groups = num_dc_groups_x * num_dc_groups_y; + size_t num_ac_groups = num_groups_x * num_groups_y; + size_t num_groups = onegroup ? 1 : (2 + num_dc_groups + num_ac_groups); + frame_state->input = input; + frame_state->width = width; + frame_state->height = height; + frame_state->num_groups_x = num_groups_x; + frame_state->num_groups_y = num_groups_y; + frame_state->num_dc_groups_x = num_dc_groups_x; + frame_state->num_dc_groups_y = num_dc_groups_y; + frame_state->nb_chans = nb_chans; + frame_state->bitdepth = bitdepth.bitdepth; + frame_state->big_endian = big_endian; + frame_state->effort = effort; + frame_state->collided = collided; + frame_state->lookup = lookup; + + frame_state->group_data = std::vector<std::array<BitWriter, 4>>(num_groups); + frame_state->group_sizes.resize(num_groups); + if (collided) { + PrepareDCGlobal(onegroup, width, height, nb_chans, frame_state->hcode, + &frame_state->group_data[0][0]); + } else { + PrepareDCGlobalPalette(onegroup, width, height, nb_chans, + frame_state->hcode, palette, pcolors, + &frame_state->group_data[0][0]); + } + frame_state->group_sizes[0] = SectionSize(frame_state->group_data[0]); + if (!onegroup) { + ComputeAcGroupDataOffset(frame_state->group_sizes[0], num_dc_groups, + num_ac_groups, frame_state->min_dc_global_size, + frame_state->ac_group_data_offset); + } + + return frame_state; +} + +template <typename BitDepth> +void LLProcess(JxlFastLosslessFrameState* frame_state, bool is_last, + BitDepth bitdepth, void* runner_opaque, + FJxlParallelRunner runner, + JxlEncoderOutputProcessorWrapper* output_processor) { +#if !FJXL_STANDALONE + if (frame_state->process_done) { + JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/0, is_last); + if (output_processor) { + JxlFastLosslessOutputFrame(frame_state, output_processor); + } + return; + } +#endif + // The maximum number of groups that we process concurrently here. + // TODO(szabadka) Use the number of threads or some outside parameter for the + // maximum memory usage instead. + constexpr size_t kMaxLocalGroups = 16; + bool onegroup = frame_state->group_sizes.size() == 1; + bool streaming = !onegroup && output_processor; + size_t total_groups = frame_state->num_groups_x * frame_state->num_groups_y; + size_t max_groups = streaming ? kMaxLocalGroups : total_groups; +#if !FJXL_STANDALONE + size_t start_pos = 0; + if (streaming) { + start_pos = output_processor->CurrentPosition(); + output_processor->Seek(start_pos + frame_state->ac_group_data_offset); + } +#endif + for (size_t offset = 0; offset < total_groups; offset += max_groups) { + size_t num_groups = std::min(max_groups, total_groups - offset); + JxlFastLosslessFrameState local_frame_state; + if (streaming) { + local_frame_state.group_data = + std::vector<std::array<BitWriter, 4>>(num_groups); + } + auto run_one = [&](size_t i) { + size_t g = offset + i; + size_t xg = g % frame_state->num_groups_x; + size_t yg = g / frame_state->num_groups_x; + size_t num_dc_groups = + frame_state->num_dc_groups_x * frame_state->num_dc_groups_y; + size_t group_id = onegroup ? 0 : (2 + num_dc_groups + g); + size_t xs = std::min<size_t>(frame_state->width - xg * 256, 256); + size_t ys = std::min<size_t>(frame_state->height - yg * 256, 256); + size_t x0 = xg * 256; + size_t y0 = yg * 256; + size_t stride; + JxlChunkedFrameInputSource input = frame_state->input; + const void* buffer = input.get_color_channel_data_at(input.opaque, x0, y0, + xs, ys, &stride); + const unsigned char* rgba = + reinterpret_cast<const unsigned char*>(buffer); + + auto& gd = streaming ? local_frame_state.group_data[i] + : frame_state->group_data[group_id]; + if (frame_state->collided) { + WriteACSection(rgba, 0, 0, xs, ys, stride, onegroup, bitdepth, + frame_state->nb_chans, frame_state->big_endian, + frame_state->hcode, gd); + } else { + WriteACSectionPalette(rgba, 0, 0, xs, ys, stride, onegroup, + frame_state->hcode, frame_state->lookup.data(), + frame_state->nb_chans, gd[0]); + } + frame_state->group_sizes[group_id] = SectionSize(gd); + input.release_buffer(input.opaque, buffer); + }; + runner( + runner_opaque, &run_one, + +[](void* r, size_t i) { + (*reinterpret_cast<decltype(&run_one)>(r))(i); + }, + num_groups); +#if !FJXL_STANDALONE + if (streaming) { + local_frame_state.nb_chans = frame_state->nb_chans; + local_frame_state.current_bit_writer = 1; + JxlFastLosslessOutputFrame(&local_frame_state, output_processor); + } +#endif + } +#if !FJXL_STANDALONE + if (streaming) { + size_t end_pos = output_processor->CurrentPosition(); + output_processor->Seek(start_pos); + frame_state->group_data.resize(1); + bool have_alpha = frame_state->nb_chans == 2 || frame_state->nb_chans == 4; + size_t padding = ComputeDcGlobalPadding( + frame_state->group_sizes, frame_state->ac_group_data_offset, + frame_state->min_dc_global_size, have_alpha, is_last); + + for (size_t i = 0; i < padding; ++i) { + frame_state->group_data[0][0].Write(8, 0); + } + frame_state->group_sizes[0] += padding; + JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/0, is_last); + assert(frame_state->ac_group_data_offset == + JxlFastLosslessOutputSize(frame_state)); + JxlFastLosslessOutputHeaders(frame_state, output_processor); + output_processor->Seek(end_pos); + } else if (output_processor) { + assert(onegroup); + JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/0, is_last); + if (output_processor) { + JxlFastLosslessOutputFrame(frame_state, output_processor); + } + } + frame_state->process_done = true; +#endif +} + +JxlFastLosslessFrameState* JxlFastLosslessPrepareImpl( + JxlChunkedFrameInputSource input, size_t width, size_t height, + size_t nb_chans, size_t bitdepth, bool big_endian, int effort, + int oneshot) { + assert(bitdepth > 0); + assert(nb_chans <= 4); + assert(nb_chans != 0); + if (bitdepth <= 8) { + return LLPrepare(input, width, height, UpTo8Bits(bitdepth), nb_chans, + big_endian, effort, oneshot); + } + if (bitdepth <= 13) { + return LLPrepare(input, width, height, From9To13Bits(bitdepth), nb_chans, + big_endian, effort, oneshot); + } + if (bitdepth == 14) { + return LLPrepare(input, width, height, Exactly14Bits(bitdepth), nb_chans, + big_endian, effort, oneshot); + } + return LLPrepare(input, width, height, MoreThan14Bits(bitdepth), nb_chans, + big_endian, effort, oneshot); +} + +void JxlFastLosslessProcessFrameImpl( + JxlFastLosslessFrameState* frame_state, bool is_last, void* runner_opaque, + FJxlParallelRunner runner, + JxlEncoderOutputProcessorWrapper* output_processor) { + const size_t bitdepth = frame_state->bitdepth; + if (bitdepth <= 8) { + LLProcess(frame_state, is_last, UpTo8Bits(bitdepth), runner_opaque, runner, + output_processor); + } else if (bitdepth <= 13) { + LLProcess(frame_state, is_last, From9To13Bits(bitdepth), runner_opaque, + runner, output_processor); + } else if (bitdepth == 14) { + LLProcess(frame_state, is_last, Exactly14Bits(bitdepth), runner_opaque, + runner, output_processor); + } else { + LLProcess(frame_state, is_last, MoreThan14Bits(bitdepth), runner_opaque, + runner, output_processor); + } +} + +} // namespace + +#endif // FJXL_SELF_INCLUDE + +#ifndef FJXL_SELF_INCLUDE + +#define FJXL_SELF_INCLUDE + +// If we have NEON enabled, it is the default target. +#if FJXL_ENABLE_NEON + +namespace default_implementation { +#define FJXL_NEON +#include "lib/jxl/enc_fast_lossless.cc" +#undef FJXL_NEON +} // namespace default_implementation + +#else // FJXL_ENABLE_NEON + +namespace default_implementation { +#include "lib/jxl/enc_fast_lossless.cc" +} + +#if FJXL_ENABLE_AVX2 +#ifdef __clang__ +#pragma clang attribute push(__attribute__((target("avx,avx2"))), \ + apply_to = function) +// Causes spurious warnings on clang5. +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wmissing-braces" +#elif defined(__GNUC__) +#pragma GCC push_options +// Seems to cause spurious errors on GCC8. +#pragma GCC diagnostic ignored "-Wpsabi" +#pragma GCC target "avx,avx2" +#endif + +namespace AVX2 { +#define FJXL_AVX2 +#include "lib/jxl/enc_fast_lossless.cc" +#undef FJXL_AVX2 +} // namespace AVX2 + +#ifdef __clang__ +#pragma clang attribute pop +#pragma clang diagnostic pop +#elif defined(__GNUC__) +#pragma GCC pop_options +#endif +#endif // FJXL_ENABLE_AVX2 + +#if FJXL_ENABLE_AVX512 +#ifdef __clang__ +#pragma clang attribute push( \ + __attribute__((target("avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi"))), \ + apply_to = function) +#elif defined(__GNUC__) +#pragma GCC push_options +#pragma GCC target "avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi" +#endif + +namespace AVX512 { +#define FJXL_AVX512 +#include "lib/jxl/enc_fast_lossless.cc" +#undef FJXL_AVX512 +} // namespace AVX512 + +#ifdef __clang__ +#pragma clang attribute pop +#elif defined(__GNUC__) +#pragma GCC pop_options +#endif +#endif // FJXL_ENABLE_AVX512 + +#endif + +extern "C" { + +#if FJXL_STANDALONE +class FJxlFrameInput { + public: + FJxlFrameInput(const unsigned char* rgba, size_t row_stride, size_t nb_chans, + size_t bitdepth) + : rgba_(rgba), + row_stride_(row_stride), + bytes_per_pixel_(bitdepth <= 8 ? nb_chans : 2 * nb_chans) {} + + JxlChunkedFrameInputSource GetInputSource() { + return JxlChunkedFrameInputSource{this, GetDataAt, + [](void*, const void*) {}}; + } + + private: + static const void* GetDataAt(void* opaque, size_t xpos, size_t ypos, + size_t xsize, size_t ysize, size_t* row_offset) { + FJxlFrameInput* self = static_cast<FJxlFrameInput*>(opaque); + *row_offset = self->row_stride_; + return self->rgba_ + ypos * (*row_offset) + xpos * self->bytes_per_pixel_; + } + + const uint8_t* rgba_; + size_t row_stride_; + size_t bytes_per_pixel_; +}; + +size_t JxlFastLosslessEncode(const unsigned char* rgba, size_t width, + size_t row_stride, size_t height, size_t nb_chans, + size_t bitdepth, int big_endian, int effort, + unsigned char** output, void* runner_opaque, + FJxlParallelRunner runner) { + FJxlFrameInput input(rgba, row_stride, nb_chans, bitdepth); + auto frame_state = JxlFastLosslessPrepareFrame( + input.GetInputSource(), width, height, nb_chans, bitdepth, big_endian, + effort, /*oneshot=*/true); + JxlFastLosslessProcessFrame(frame_state, /*is_last=*/true, runner_opaque, + runner, nullptr); + JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/1, + /*is_last=*/1); + size_t output_size = JxlFastLosslessMaxRequiredOutput(frame_state); + *output = (unsigned char*)malloc(output_size); + size_t written = 0; + size_t total = 0; + while ((written = JxlFastLosslessWriteOutput(frame_state, *output + total, + output_size - total)) != 0) { + total += written; + } + JxlFastLosslessFreeFrameState(frame_state); + return total; +} +#endif + +JxlFastLosslessFrameState* JxlFastLosslessPrepareFrame( + JxlChunkedFrameInputSource input, size_t width, size_t height, + size_t nb_chans, size_t bitdepth, int big_endian, int effort, int oneshot) { +#if FJXL_ENABLE_AVX512 + if (__builtin_cpu_supports("avx512cd") && + __builtin_cpu_supports("avx512vbmi") && + __builtin_cpu_supports("avx512bw") && __builtin_cpu_supports("avx512f") && + __builtin_cpu_supports("avx512vl")) { + return AVX512::JxlFastLosslessPrepareImpl( + input, width, height, nb_chans, bitdepth, big_endian, effort, oneshot); + } +#endif +#if FJXL_ENABLE_AVX2 + if (__builtin_cpu_supports("avx2")) { + return AVX2::JxlFastLosslessPrepareImpl( + input, width, height, nb_chans, bitdepth, big_endian, effort, oneshot); + } +#endif + + return default_implementation::JxlFastLosslessPrepareImpl( + input, width, height, nb_chans, bitdepth, big_endian, effort, oneshot); +} + +void JxlFastLosslessProcessFrame( + JxlFastLosslessFrameState* frame_state, bool is_last, void* runner_opaque, + FJxlParallelRunner runner, + JxlEncoderOutputProcessorWrapper* output_processor) { + auto trivial_runner = + +[](void*, void* opaque, void fun(void*, size_t), size_t count) { + for (size_t i = 0; i < count; i++) { + fun(opaque, i); + } + }; + + if (runner == nullptr) { + runner = trivial_runner; + } + +#if FJXL_ENABLE_AVX512 + if (__builtin_cpu_supports("avx512cd") && + __builtin_cpu_supports("avx512vbmi") && + __builtin_cpu_supports("avx512bw") && __builtin_cpu_supports("avx512f") && + __builtin_cpu_supports("avx512vl")) { + return AVX512::JxlFastLosslessProcessFrameImpl( + frame_state, is_last, runner_opaque, runner, output_processor); + } +#endif +#if FJXL_ENABLE_AVX2 + if (__builtin_cpu_supports("avx2")) { + return AVX2::JxlFastLosslessProcessFrameImpl( + frame_state, is_last, runner_opaque, runner, output_processor); + } +#endif + + return default_implementation::JxlFastLosslessProcessFrameImpl( + frame_state, is_last, runner_opaque, runner, output_processor); +} + +} // extern "C" + +#if !FJXL_STANDALONE +void JxlFastLosslessOutputFrame( + JxlFastLosslessFrameState* frame_state, + JxlEncoderOutputProcessorWrapper* output_processor) { + size_t fl_size = JxlFastLosslessOutputSize(frame_state); + size_t written = 0; + while (written < fl_size) { + auto retval = output_processor->GetBuffer(32, fl_size - written); + assert(retval.status()); + auto buffer = std::move(retval).value(); + size_t n = + JxlFastLosslessWriteOutput(frame_state, buffer.data(), buffer.size()); + if (n == 0) break; + buffer.advance(n); + written += n; + }; +} +#endif + +#endif // FJXL_SELF_INCLUDE diff --git a/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.h b/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.h new file mode 100644 index 0000000000..33c3aa54fa --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.h @@ -0,0 +1,117 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_FAST_LOSSLESS_H_ +#define LIB_JXL_ENC_FAST_LOSSLESS_H_ +#include <stdlib.h> + +// FJXL_STANDALONE=1 for a stand-alone jxl encoder +// FJXL_STANDALONE=0 for use in libjxl to encode frames (but no image header) +#ifndef FJXL_STANDALONE +#define FJXL_STANDALONE 0 +#endif + +#if !FJXL_STANDALONE +#include <jxl/encode.h> +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#if FJXL_STANDALONE +// Simplified version of the streaming input source from jxl/encode.h +// We only need this part to wrap the full image buffer in the standalone mode +// and this way we don't need to depend on the jxl headers. +struct JxlChunkedFrameInputSource { + void* opaque; + const void* (*get_color_channel_data_at)(void* opaque, size_t xpos, + size_t ypos, size_t xsize, + size_t ysize, size_t* row_offset); + void (*release_buffer)(void* opaque, const void* buf); +}; +// The standalone version does not use this struct, but we define it here so +// that we don't have to clutter all the function signatures with defines. +struct JxlEncoderOutputProcessorWrapper { + int unused; +}; +#endif + +// Simple encoding API. + +// A FJxlParallelRunner must call fun(opaque, i) for all i from 0 to count. It +// may do so in parallel. +typedef void(FJxlParallelRunner)(void* runner_opaque, void* opaque, + void fun(void*, size_t), size_t count); + +#if FJXL_STANDALONE +// You may pass `nullptr` as a runner: encoding will be sequential. +size_t JxlFastLosslessEncode(const unsigned char* rgba, size_t width, + size_t row_stride, size_t height, size_t nb_chans, + size_t bitdepth, int big_endian, int effort, + unsigned char** output, void* runner_opaque, + FJxlParallelRunner runner); +#endif + +// More complex API for cases in which you may want to allocate your own buffer +// and other advanced use cases. + +// Opaque struct that represents an intermediate state of the computation. +struct JxlFastLosslessFrameState; + +// Returned JxlFastLosslessFrameState must be freed by calling +// JxlFastLosslessFreeFrameState. +JxlFastLosslessFrameState* JxlFastLosslessPrepareFrame( + JxlChunkedFrameInputSource input, size_t width, size_t height, + size_t nb_chans, size_t bitdepth, int big_endian, int effort, int oneshot); + +#if !FJXL_STANDALONE +class JxlEncoderOutputProcessorWrapper; +#endif + +void JxlFastLosslessProcessFrame( + JxlFastLosslessFrameState* frame_state, bool is_last, void* runner_opaque, + FJxlParallelRunner runner, + JxlEncoderOutputProcessorWrapper* output_processor); + +// Prepare the (image/frame) header. You may encode animations by concatenating +// the output of multiple frames, of which the first one has add_image_header = +// 1 and subsequent ones have add_image_header = 0, and all frames but the last +// one have is_last = 0. +// (when FJXL_STANDALONE=0, add_image_header has to be 0) +void JxlFastLosslessPrepareHeader(JxlFastLosslessFrameState* frame, + int add_image_header, int is_last); + +// Upper bound on the required output size, including any padding that may be +// required by JxlFastLosslessWriteOutput. Cannot be called before +// JxlFastLosslessPrepareHeader. +size_t JxlFastLosslessMaxRequiredOutput(const JxlFastLosslessFrameState* frame); + +// Actual size of the frame once it is encoded. This is not identical to +// JxlFastLosslessMaxRequiredOutput because JxlFastLosslessWriteOutput may +// require extra padding. +size_t JxlFastLosslessOutputSize(const JxlFastLosslessFrameState* frame); + +// Writes the frame to the given output buffer. Returns the number of bytes that +// were written, which is at least 1 unless the entire output has been written +// already. It is required that `output_size >= 32` when calling this function. +// This function must be called repeatedly until it returns 0. +size_t JxlFastLosslessWriteOutput(JxlFastLosslessFrameState* frame, + unsigned char* output, size_t output_size); + +// Frees the provided frame state. +void JxlFastLosslessFreeFrameState(JxlFastLosslessFrameState* frame); + +#ifdef __cplusplus +} // extern "C" +#endif + +#if !FJXL_STANDALONE +void JxlFastLosslessOutputFrame( + JxlFastLosslessFrameState* frame_state, + JxlEncoderOutputProcessorWrapper* output_process); +#endif + +#endif // LIB_JXL_ENC_FAST_LOSSLESS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_fields.cc b/third_party/jpeg-xl/lib/jxl/enc_fields.cc new file mode 100644 index 0000000000..dc0cbb7913 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_fields.cc @@ -0,0 +1,241 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_fields.h" + +#include <cinttypes> + +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/fields.h" + +namespace jxl { + +namespace { +using ::jxl::fields_internal::VisitorBase; +class WriteVisitor : public VisitorBase { + public: + WriteVisitor(const size_t extension_bits, BitWriter* JXL_RESTRICT writer) + : extension_bits_(extension_bits), writer_(writer) {} + + Status Bits(const size_t bits, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + ok_ &= BitsCoder::Write(bits, *value, writer_); + return true; + } + Status U32(const U32Enc enc, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + ok_ &= U32Coder::Write(enc, *value, writer_); + return true; + } + + Status U64(const uint64_t /*default_value*/, + uint64_t* JXL_RESTRICT value) override { + ok_ &= U64Coder::Write(*value, writer_); + return true; + } + + Status F16(const float /*default_value*/, + float* JXL_RESTRICT value) override { + ok_ &= F16Coder::Write(*value, writer_); + return true; + } + + Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { + JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions)); + if (*extensions == 0) { + JXL_ASSERT(extension_bits_ == 0); + return true; + } + // TODO(janwas): extend API to pass in array of extension_bits, one per + // extension. We currently ascribe all bits to the first extension, but + // this is only an encoder limitation. NOTE: extension_bits_ can be zero + // if an extension does not require any additional fields. + ok_ &= U64Coder::Write(extension_bits_, writer_); + // For each nonzero bit except the lowest/first (already written): + for (uint64_t remaining_extensions = *extensions & (*extensions - 1); + remaining_extensions != 0; + remaining_extensions &= remaining_extensions - 1) { + ok_ &= U64Coder::Write(0, writer_); + } + return true; + } + // EndExtensions = default. + + Status OK() const { return ok_; } + + private: + const size_t extension_bits_; + BitWriter* JXL_RESTRICT writer_; + bool ok_ = true; +}; +} // namespace + +Status Bundle::Write(const Fields& fields, BitWriter* writer, size_t layer, + AuxOut* aux_out) { + size_t extension_bits, total_bits; + JXL_RETURN_IF_ERROR(Bundle::CanEncode(fields, &extension_bits, &total_bits)); + + BitWriter::Allotment allotment(writer, total_bits); + WriteVisitor visitor(extension_bits, writer); + JXL_RETURN_IF_ERROR(visitor.VisitConst(fields)); + JXL_RETURN_IF_ERROR(visitor.OK()); + allotment.ReclaimAndCharge(writer, layer, aux_out); + return true; +} + +// Returns false if the value is too large to encode. +Status BitsCoder::Write(const size_t bits, const uint32_t value, + BitWriter* JXL_RESTRICT writer) { + if (value >= (1ULL << bits)) { + return JXL_FAILURE("Value %d too large to encode in %" PRIu64 " bits", + value, static_cast<uint64_t>(bits)); + } + writer->Write(bits, value); + return true; +} + +// Returns false if the value is too large to encode. +Status U32Coder::Write(const U32Enc enc, const uint32_t value, + BitWriter* JXL_RESTRICT writer) { + uint32_t selector; + size_t total_bits; + JXL_RETURN_IF_ERROR(ChooseSelector(enc, value, &selector, &total_bits)); + + writer->Write(2, selector); + + const U32Distr d = enc.GetDistr(selector); + if (!d.IsDirect()) { // Nothing more to write for direct encoding + const uint32_t offset = d.Offset(); + JXL_ASSERT(value >= offset); + writer->Write(total_bits - 2, value - offset); + } + + return true; +} + +// Returns false if the value is too large to encode. +Status U64Coder::Write(uint64_t value, BitWriter* JXL_RESTRICT writer) { + if (value == 0) { + // Selector: use 0 bits, value 0 + writer->Write(2, 0); + } else if (value <= 16) { + // Selector: use 4 bits, value 1..16 + writer->Write(2, 1); + writer->Write(4, value - 1); + } else if (value <= 272) { + // Selector: use 8 bits, value 17..272 + writer->Write(2, 2); + writer->Write(8, value - 17); + } else { + // Selector: varint, first a 12-bit group, after that per 8-bit group. + writer->Write(2, 3); + writer->Write(12, value & 4095); + value >>= 12; + int shift = 12; + while (value > 0 && shift < 60) { + // Indicate varint not done + writer->Write(1, 1); + writer->Write(8, value & 255); + value >>= 8; + shift += 8; + } + if (value > 0) { + // This only could happen if shift == N - 4. + writer->Write(1, 1); + writer->Write(4, value & 15); + // Implicitly closed sequence, no extra stop bit is required. + } else { + // Indicate end of varint + writer->Write(1, 0); + } + } + + return true; +} + +Status F16Coder::Write(float value, BitWriter* JXL_RESTRICT writer) { + uint32_t bits32; + memcpy(&bits32, &value, sizeof(bits32)); + const uint32_t sign = bits32 >> 31; + const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; + const uint32_t mantissa32 = bits32 & 0x7FFFFF; + + const int32_t exp = static_cast<int32_t>(biased_exp32) - 127; + if (JXL_UNLIKELY(exp > 15)) { + return JXL_FAILURE("Too big to encode, CanEncode should return false"); + } + + // Tiny or zero => zero. + if (exp < -24) { + writer->Write(16, 0); + return true; + } + + uint32_t biased_exp16, mantissa16; + + // exp = [-24, -15] => subnormal + if (JXL_UNLIKELY(exp < -14)) { + biased_exp16 = 0; + const uint32_t sub_exp = static_cast<uint32_t>(-14 - exp); + JXL_ASSERT(1 <= sub_exp && sub_exp < 11); + mantissa16 = (1 << (10 - sub_exp)) + (mantissa32 >> (13 + sub_exp)); + } else { + // exp = [-14, 15] + biased_exp16 = static_cast<uint32_t>(exp + 15); + JXL_ASSERT(1 <= biased_exp16 && biased_exp16 < 31); + mantissa16 = mantissa32 >> 13; + } + + JXL_ASSERT(mantissa16 < 1024); + const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; + JXL_ASSERT(bits16 < 0x10000); + writer->Write(16, bits16); + return true; +} + +Status WriteCodestreamHeaders(CodecMetadata* metadata, BitWriter* writer, + AuxOut* aux_out) { + // Marker/signature + BitWriter::Allotment allotment(writer, 16); + writer->Write(8, 0xFF); + writer->Write(8, kCodestreamMarker); + allotment.ReclaimAndCharge(writer, kLayerHeader, aux_out); + + JXL_RETURN_IF_ERROR( + WriteSizeHeader(metadata->size, writer, kLayerHeader, aux_out)); + + JXL_RETURN_IF_ERROR( + WriteImageMetadata(metadata->m, writer, kLayerHeader, aux_out)); + + metadata->transform_data.nonserialized_xyb_encoded = metadata->m.xyb_encoded; + JXL_RETURN_IF_ERROR( + Bundle::Write(metadata->transform_data, writer, kLayerHeader, aux_out)); + + return true; +} + +Status WriteFrameHeader(const FrameHeader& frame, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out) { + return Bundle::Write(frame, writer, kLayerHeader, aux_out); +} + +Status WriteImageMetadata(const ImageMetadata& metadata, + BitWriter* JXL_RESTRICT writer, size_t layer, + AuxOut* aux_out) { + return Bundle::Write(metadata, writer, layer, aux_out); +} + +Status WriteQuantizerParams(const QuantizerParams& params, + BitWriter* JXL_RESTRICT writer, size_t layer, + AuxOut* aux_out) { + return Bundle::Write(params, writer, layer, aux_out); +} + +Status WriteSizeHeader(const SizeHeader& size, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* aux_out) { + return Bundle::Write(size, writer, layer, aux_out); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_fields.h b/third_party/jpeg-xl/lib/jxl/enc_fields.h new file mode 100644 index 0000000000..5bb179a719 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_fields.h @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_FIELDS_H_ +#define LIB_JXL_ENC_FIELDS_H_ + +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +struct AuxOut; + +// Write headers from the CodecMetadata. Also may modify nonserialized_... +// fields of the metadata. +Status WriteCodestreamHeaders(CodecMetadata* metadata, BitWriter* writer, + AuxOut* aux_out); + +Status WriteFrameHeader(const FrameHeader& frame, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out); + +Status WriteQuantizerParams(const QuantizerParams& params, + BitWriter* JXL_RESTRICT writer, size_t layer, + AuxOut* aux_out); + +Status WriteSizeHeader(const SizeHeader& size, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_FIELDS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_frame.cc b/third_party/jpeg-xl/lib/jxl/enc_frame.cc new file mode 100644 index 0000000000..aae59c49a6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_frame.cc @@ -0,0 +1,2197 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_frame.h" + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <array> +#include <atomic> +#include <cmath> +#include <limits> +#include <numeric> +#include <vector> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" // kMaxNumPasses +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_external_image.h" +#include "lib/jxl/enc_ac_strategy.h" +#include "lib/jxl/enc_adaptive_quantization.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_ar_control_field.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_chroma_from_luma.h" +#include "lib/jxl/enc_coeff_order.h" +#include "lib/jxl/enc_context_map.h" +#include "lib/jxl/enc_entropy_coder.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/enc_gaborish.h" +#include "lib/jxl/enc_group.h" +#include "lib/jxl/enc_heuristics.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_noise.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/enc_photon_noise.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/enc_splines.h" +#include "lib/jxl/enc_toc.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/splines.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +Status ParamsPostInit(CompressParams* p) { + if (!p->manual_noise.empty() && + p->manual_noise.size() != NoiseParams::kNumNoisePoints) { + return JXL_FAILURE("Invalid number of noise lut entries"); + } + if (!p->manual_xyb_factors.empty() && p->manual_xyb_factors.size() != 3) { + return JXL_FAILURE("Invalid number of XYB quantization factors"); + } + if (!p->modular_mode && p->butteraugli_distance == 0.0) { + p->butteraugli_distance = kMinButteraugliDistance; + } + if (p->original_butteraugli_distance == -1.0) { + p->original_butteraugli_distance = p->butteraugli_distance; + } + if (p->resampling <= 0) { + p->resampling = 1; + // For very low bit rates, using 2x2 resampling gives better results on + // most photographic images, with an adjusted butteraugli score chosen to + // give roughly the same amount of bits per pixel. + if (!p->already_downsampled && p->butteraugli_distance >= 20) { + p->resampling = 2; + p->butteraugli_distance = 6 + ((p->butteraugli_distance - 20) * 0.25); + } + } + if (p->ec_resampling <= 0) { + p->ec_resampling = p->resampling; + } + return true; +} + +namespace { + +template <typename T> +uint32_t GetBitDepth(JxlBitDepth bit_depth, const T& metadata, + JxlPixelFormat format) { + if (bit_depth.type == JXL_BIT_DEPTH_FROM_PIXEL_FORMAT) { + return BitsPerChannel(format.data_type); + } else if (bit_depth.type == JXL_BIT_DEPTH_FROM_CODESTREAM) { + return metadata.bit_depth.bits_per_sample; + } else if (bit_depth.type == JXL_BIT_DEPTH_CUSTOM) { + return bit_depth.bits_per_sample; + } else { + return 0; + } +} + +Status CopyColorChannels(JxlChunkedFrameInputSource input, Rect rect, + const FrameInfo& frame_info, + const ImageMetadata& metadata, ThreadPool* pool, + Image3F* color, ImageF* alpha, + bool* has_interleaved_alpha) { + JxlPixelFormat format = {4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + input.get_color_channels_pixel_format(input.opaque, &format); + *has_interleaved_alpha = format.num_channels == 2 || format.num_channels == 4; + size_t bits_per_sample = + GetBitDepth(frame_info.image_bit_depth, metadata, format); + size_t row_offset; + auto buffer = GetColorBuffer(input, rect.x0(), rect.y0(), rect.xsize(), + rect.ysize(), &row_offset); + if (!buffer) { + return JXL_FAILURE("no buffer for color channels given"); + } + size_t color_channels = frame_info.ib_needs_color_transform + ? metadata.color_encoding.Channels() + : 3; + if (format.num_channels < color_channels) { + return JXL_FAILURE("Expected %" PRIuS + " color channels, received only %u channels", + color_channels, format.num_channels); + } + const uint8_t* data = reinterpret_cast<const uint8_t*>(buffer.get()); + for (size_t c = 0; c < color_channels; ++c) { + JXL_RETURN_IF_ERROR(ConvertFromExternalNoSizeCheck( + data, rect.xsize(), rect.ysize(), row_offset, bits_per_sample, format, + c, pool, &color->Plane(c))); + } + if (color_channels == 1) { + CopyImageTo(color->Plane(0), &color->Plane(1)); + CopyImageTo(color->Plane(0), &color->Plane(2)); + } + if (alpha) { + if (*has_interleaved_alpha) { + JXL_RETURN_IF_ERROR(ConvertFromExternalNoSizeCheck( + data, rect.xsize(), rect.ysize(), row_offset, bits_per_sample, format, + format.num_channels - 1, pool, alpha)); + } else { + // if alpha is not passed, but it is expected, then assume + // it is all-opaque + FillImage(1.0f, alpha); + } + } + return true; +} + +Status CopyExtraChannels(JxlChunkedFrameInputSource input, Rect rect, + const FrameInfo& frame_info, + const ImageMetadata& metadata, + bool has_interleaved_alpha, ThreadPool* pool, + std::vector<ImageF>* extra_channels) { + for (size_t ec = 0; ec < metadata.num_extra_channels; ec++) { + if (has_interleaved_alpha && + metadata.extra_channel_info[ec].type == ExtraChannel::kAlpha) { + // Skip this alpha channel, but still request additional alpha channels + // if they exist. + has_interleaved_alpha = false; + continue; + } + JxlPixelFormat ec_format = {1, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + input.get_extra_channel_pixel_format(input.opaque, ec, &ec_format); + ec_format.num_channels = 1; + size_t row_offset; + auto buffer = + GetExtraChannelBuffer(input, ec, rect.x0(), rect.y0(), rect.xsize(), + rect.ysize(), &row_offset); + if (!buffer) { + return JXL_FAILURE("no buffer for extra channel given"); + } + size_t bits_per_sample = GetBitDepth( + frame_info.image_bit_depth, metadata.extra_channel_info[ec], ec_format); + if (!ConvertFromExternalNoSizeCheck( + reinterpret_cast<const uint8_t*>(buffer.get()), rect.xsize(), + rect.ysize(), row_offset, bits_per_sample, ec_format, 0, pool, + &(*extra_channels)[ec])) { + return JXL_FAILURE("Failed to set buffer for extra channel"); + } + } + return true; +} + +void SetProgressiveMode(const CompressParams& cparams, + ProgressiveSplitter* progressive_splitter) { + constexpr PassDefinition progressive_passes_dc_vlf_lf_full_ac[] = { + {/*num_coefficients=*/2, /*shift=*/0, + /*suitable_for_downsampling_of_at_least=*/4}, + {/*num_coefficients=*/3, /*shift=*/0, + /*suitable_for_downsampling_of_at_least=*/2}, + {/*num_coefficients=*/8, /*shift=*/0, + /*suitable_for_downsampling_of_at_least=*/0}, + }; + constexpr PassDefinition progressive_passes_dc_quant_ac_full_ac[] = { + {/*num_coefficients=*/8, /*shift=*/1, + /*suitable_for_downsampling_of_at_least=*/2}, + {/*num_coefficients=*/8, /*shift=*/0, + /*suitable_for_downsampling_of_at_least=*/0}, + }; + bool progressive_mode = ApplyOverride(cparams.progressive_mode, false); + bool qprogressive_mode = ApplyOverride(cparams.qprogressive_mode, false); + if (cparams.custom_progressive_mode) { + progressive_splitter->SetProgressiveMode(*cparams.custom_progressive_mode); + } else if (qprogressive_mode) { + progressive_splitter->SetProgressiveMode( + ProgressiveMode{progressive_passes_dc_quant_ac_full_ac}); + } else if (progressive_mode) { + progressive_splitter->SetProgressiveMode( + ProgressiveMode{progressive_passes_dc_vlf_lf_full_ac}); + } +} + +uint64_t FrameFlagsFromParams(const CompressParams& cparams) { + uint64_t flags = 0; + + const float dist = cparams.butteraugli_distance; + + // We don't add noise at low butteraugli distances because the original + // noise is stored within the compressed image and adding noise makes things + // worse. + if (ApplyOverride(cparams.noise, dist >= kMinButteraugliForNoise) || + cparams.photon_noise_iso > 0 || + cparams.manual_noise.size() == NoiseParams::kNumNoisePoints) { + flags |= FrameHeader::kNoise; + } + + if (cparams.progressive_dc > 0 && cparams.modular_mode == false) { + flags |= FrameHeader::kUseDcFrame; + } + + return flags; +} + +Status LoopFilterFromParams(const CompressParams& cparams, bool streaming_mode, + FrameHeader* JXL_RESTRICT frame_header) { + LoopFilter* loop_filter = &frame_header->loop_filter; + + // Gaborish defaults to enabled in Hare or slower. + loop_filter->gab = ApplyOverride( + cparams.gaborish, cparams.speed_tier <= SpeedTier::kHare && + frame_header->encoding == FrameEncoding::kVarDCT && + cparams.decoding_speed_tier < 4); + + if (cparams.epf != -1) { + loop_filter->epf_iters = cparams.epf; + } else { + if (frame_header->encoding == FrameEncoding::kModular) { + loop_filter->epf_iters = 0; + } else { + constexpr float kThresholds[3] = {0.7, 1.5, 4.0}; + loop_filter->epf_iters = 0; + if (cparams.decoding_speed_tier < 3) { + for (size_t i = cparams.decoding_speed_tier == 2 ? 1 : 0; i < 3; i++) { + if (cparams.butteraugli_distance >= kThresholds[i]) { + loop_filter->epf_iters++; + } + } + } + } + } + // Strength of EPF in modular mode. + if (frame_header->encoding == FrameEncoding::kModular && + !cparams.IsLossless()) { + // TODO(veluca): this formula is nonsense. + loop_filter->epf_sigma_for_modular = cparams.butteraugli_distance; + } + if (frame_header->encoding == FrameEncoding::kModular && + cparams.lossy_palette) { + loop_filter->epf_sigma_for_modular = 1.0f; + } + + return true; +} + +Status MakeFrameHeader(size_t xsize, size_t ysize, + const CompressParams& cparams, + const ProgressiveSplitter& progressive_splitter, + const FrameInfo& frame_info, + const jpeg::JPEGData* jpeg_data, bool streaming_mode, + FrameHeader* JXL_RESTRICT frame_header) { + frame_header->nonserialized_is_preview = frame_info.is_preview; + frame_header->is_last = frame_info.is_last; + frame_header->save_before_color_transform = + frame_info.save_before_color_transform; + frame_header->frame_type = frame_info.frame_type; + frame_header->name = frame_info.name; + + progressive_splitter.InitPasses(&frame_header->passes); + + if (cparams.modular_mode) { + frame_header->encoding = FrameEncoding::kModular; + if (cparams.modular_group_size_shift == -1) { + frame_header->group_size_shift = 1; + // no point using groups when only one group is full and the others are + // less than half full: multithreading will not really help much, while + // compression does suffer + if (xsize <= 400 && ysize <= 400) { + frame_header->group_size_shift = 2; + } + } else { + frame_header->group_size_shift = cparams.modular_group_size_shift; + } + } + + if (jpeg_data) { + // we are transcoding a JPEG, so we don't get to choose + frame_header->encoding = FrameEncoding::kVarDCT; + frame_header->x_qm_scale = 2; + frame_header->b_qm_scale = 2; + JXL_RETURN_IF_ERROR(SetChromaSubsamplingFromJpegData( + *jpeg_data, &frame_header->chroma_subsampling)); + JXL_RETURN_IF_ERROR(SetColorTransformFromJpegData( + *jpeg_data, &frame_header->color_transform)); + } else { + frame_header->color_transform = cparams.color_transform; + if (!cparams.modular_mode && + (frame_header->chroma_subsampling.MaxHShift() != 0 || + frame_header->chroma_subsampling.MaxVShift() != 0)) { + return JXL_FAILURE( + "Chroma subsampling is not supported in VarDCT mode when not " + "recompressing JPEGs"); + } + } + if (frame_header->color_transform != ColorTransform::kYCbCr && + (frame_header->chroma_subsampling.MaxHShift() != 0 || + frame_header->chroma_subsampling.MaxVShift() != 0)) { + return JXL_FAILURE( + "Chroma subsampling is not supported when color transform is not " + "YCbCr"); + } + + frame_header->flags = FrameFlagsFromParams(cparams); + // Non-photon noise is not supported in the Modular encoder for now. + if (frame_header->encoding != FrameEncoding::kVarDCT && + cparams.photon_noise_iso == 0 && cparams.manual_noise.empty()) { + frame_header->UpdateFlag(false, FrameHeader::Flags::kNoise); + } + + JXL_RETURN_IF_ERROR( + LoopFilterFromParams(cparams, streaming_mode, frame_header)); + + frame_header->dc_level = frame_info.dc_level; + if (frame_header->dc_level > 2) { + // With 3 or more progressive_dc frames, the implementation does not yet + // work, see enc_cache.cc. + return JXL_FAILURE("progressive_dc > 2 is not yet supported"); + } + if (cparams.progressive_dc > 0 && + (cparams.ec_resampling != 1 || cparams.resampling != 1)) { + return JXL_FAILURE("Resampling not supported with DC frames"); + } + if (cparams.resampling != 1 && cparams.resampling != 2 && + cparams.resampling != 4 && cparams.resampling != 8) { + return JXL_FAILURE("Invalid resampling factor"); + } + if (cparams.ec_resampling != 1 && cparams.ec_resampling != 2 && + cparams.ec_resampling != 4 && cparams.ec_resampling != 8) { + return JXL_FAILURE("Invalid ec_resampling factor"); + } + // Resized frames. + if (frame_info.frame_type != FrameType::kDCFrame) { + frame_header->frame_origin = frame_info.origin; + size_t ups = 1; + if (cparams.already_downsampled) ups = cparams.resampling; + + // TODO(lode): this is not correct in case of odd original image sizes in + // combination with cparams.already_downsampled. Likely these values should + // be set to respectively frame_header->default_xsize() and + // frame_header->default_ysize() instead, the original (non downsampled) + // intended decoded image dimensions. But it may be more subtle than that + // if combined with crop. This issue causes custom_size_or_origin to be + // incorrectly set to true in case of already_downsampled with odd output + // image size when no cropping is used. + frame_header->frame_size.xsize = xsize * ups; + frame_header->frame_size.ysize = ysize * ups; + if (frame_info.origin.x0 != 0 || frame_info.origin.y0 != 0 || + frame_header->frame_size.xsize != frame_header->default_xsize() || + frame_header->frame_size.ysize != frame_header->default_ysize()) { + frame_header->custom_size_or_origin = true; + } + } + // Upsampling. + frame_header->upsampling = cparams.resampling; + const std::vector<ExtraChannelInfo>& extra_channels = + frame_header->nonserialized_metadata->m.extra_channel_info; + frame_header->extra_channel_upsampling.clear(); + frame_header->extra_channel_upsampling.resize(extra_channels.size(), + cparams.ec_resampling); + frame_header->save_as_reference = frame_info.save_as_reference; + + // Set blending-related information. + if (frame_info.blend || frame_header->custom_size_or_origin) { + // Set blend_channel to the first alpha channel. These values are only + // encoded in case a blend mode involving alpha is used and there are more + // than one extra channels. + size_t index = 0; + if (frame_info.alpha_channel == -1) { + if (extra_channels.size() > 1) { + for (size_t i = 0; i < extra_channels.size(); i++) { + if (extra_channels[i].type == ExtraChannel::kAlpha) { + index = i; + break; + } + } + } + } else { + index = static_cast<size_t>(frame_info.alpha_channel); + JXL_ASSERT(index == 0 || index < extra_channels.size()); + } + frame_header->blending_info.alpha_channel = index; + frame_header->blending_info.mode = + frame_info.blend ? frame_info.blendmode : BlendMode::kReplace; + frame_header->blending_info.source = frame_info.source; + frame_header->blending_info.clamp = frame_info.clamp; + const auto& extra_channel_info = frame_info.extra_channel_blending_info; + for (size_t i = 0; i < extra_channels.size(); i++) { + if (i < extra_channel_info.size()) { + frame_header->extra_channel_blending_info[i] = extra_channel_info[i]; + } else { + frame_header->extra_channel_blending_info[i].alpha_channel = index; + BlendMode default_blend = frame_info.blendmode; + if (extra_channels[i].type != ExtraChannel::kBlack && i != index) { + // K needs to be blended, spot colors and other stuff gets added + default_blend = BlendMode::kAdd; + } + frame_header->extra_channel_blending_info[i].mode = + frame_info.blend ? default_blend : BlendMode::kReplace; + frame_header->extra_channel_blending_info[i].source = 1; + } + } + } + + frame_header->animation_frame.duration = frame_info.duration; + frame_header->animation_frame.timecode = frame_info.timecode; + + if (jpeg_data) { + frame_header->UpdateFlag(false, FrameHeader::kUseDcFrame); + frame_header->UpdateFlag(true, FrameHeader::kSkipAdaptiveDCSmoothing); + } + + return true; +} + +// Invisible (alpha = 0) pixels tend to be a mess in optimized PNGs. +// Since they have no visual impact whatsoever, we can replace them with +// something that compresses better and reduces artifacts near the edges. This +// does some kind of smooth stuff that seems to work. +// Replace invisible pixels with a weighted average of the pixel to the left, +// the pixel to the topright, and non-invisible neighbours. +// Produces downward-blurry smears, with in the upwards direction only a 1px +// edge duplication but not more. It would probably be better to smear in all +// directions. That requires an alpha-weighed convolution with a large enough +// kernel though, which might be overkill... +void SimplifyInvisible(Image3F* image, const ImageF& alpha, bool lossless) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < image->ysize(); ++y) { + float* JXL_RESTRICT row = image->PlaneRow(c, y); + const float* JXL_RESTRICT prow = + (y > 0 ? image->PlaneRow(c, y - 1) : nullptr); + const float* JXL_RESTRICT nrow = + (y + 1 < image->ysize() ? image->PlaneRow(c, y + 1) : nullptr); + const float* JXL_RESTRICT a = alpha.Row(y); + const float* JXL_RESTRICT pa = (y > 0 ? alpha.Row(y - 1) : nullptr); + const float* JXL_RESTRICT na = + (y + 1 < image->ysize() ? alpha.Row(y + 1) : nullptr); + for (size_t x = 0; x < image->xsize(); ++x) { + if (a[x] == 0) { + if (lossless) { + row[x] = 0; + continue; + } + float d = 0.f; + row[x] = 0; + if (x > 0) { + row[x] += row[x - 1]; + d++; + if (a[x - 1] > 0.f) { + row[x] += row[x - 1]; + d++; + } + } + if (x + 1 < image->xsize()) { + if (y > 0) { + row[x] += prow[x + 1]; + d++; + } + if (a[x + 1] > 0.f) { + row[x] += 2.f * row[x + 1]; + d += 2.f; + } + if (y > 0 && pa[x + 1] > 0.f) { + row[x] += 2.f * prow[x + 1]; + d += 2.f; + } + if (y + 1 < image->ysize() && na[x + 1] > 0.f) { + row[x] += 2.f * nrow[x + 1]; + d += 2.f; + } + } + if (y > 0 && pa[x] > 0.f) { + row[x] += 2.f * prow[x]; + d += 2.f; + } + if (y + 1 < image->ysize() && na[x] > 0.f) { + row[x] += 2.f * nrow[x]; + d += 2.f; + } + if (d > 1.f) row[x] /= d; + } + } + } + } +} + +struct PixelStatsForChromacityAdjustment { + float dx = 0; + float db = 0; + float exposed_blue = 0; + float CalcPlane(const ImageF* JXL_RESTRICT plane, const Rect& rect) const { + float xmax = 0; + float ymax = 0; + for (size_t ty = 1; ty < rect.ysize(); ++ty) { + for (size_t tx = 1; tx < rect.xsize(); ++tx) { + float cur = rect.Row(plane, ty)[tx]; + float prev_row = rect.Row(plane, ty - 1)[tx]; + float prev = rect.Row(plane, ty)[tx - 1]; + xmax = std::max(xmax, std::abs(cur - prev)); + ymax = std::max(ymax, std::abs(cur - prev_row)); + } + } + return std::max(xmax, ymax); + } + void CalcExposedBlue(const ImageF* JXL_RESTRICT plane_y, + const ImageF* JXL_RESTRICT plane_b, const Rect& rect) { + float eb = 0; + float xmax = 0; + float ymax = 0; + for (size_t ty = 1; ty < rect.ysize(); ++ty) { + for (size_t tx = 1; tx < rect.xsize(); ++tx) { + float cur_y = rect.Row(plane_y, ty)[tx]; + float cur_b = rect.Row(plane_b, ty)[tx]; + float exposed_b = cur_b - cur_y * 1.2; + float diff_b = cur_b - cur_y; + float prev_row = rect.Row(plane_b, ty - 1)[tx]; + float prev = rect.Row(plane_b, ty)[tx - 1]; + float diff_prev_row = prev_row - rect.Row(plane_y, ty - 1)[tx]; + float diff_prev = prev - rect.Row(plane_y, ty)[tx - 1]; + xmax = std::max(xmax, std::abs(diff_b - diff_prev)); + ymax = std::max(ymax, std::abs(diff_b - diff_prev_row)); + if (exposed_b >= 0) { + exposed_b *= fabs(cur_b - prev) + fabs(cur_b - prev_row); + eb = std::max(eb, exposed_b); + } + } + } + exposed_blue = eb; + db = std::max(xmax, ymax); + } + void Calc(const Image3F* JXL_RESTRICT opsin, const Rect& rect) { + dx = CalcPlane(&opsin->Plane(0), rect); + CalcExposedBlue(&opsin->Plane(1), &opsin->Plane(2), rect); + } + int HowMuchIsXChannelPixelized() { + if (dx >= 0.03) { + return 2; + } + if (dx >= 0.017) { + return 1; + } + return 0; + } + int HowMuchIsBChannelPixelized() { + int add = exposed_blue >= 0.13 ? 1 : 0; + if (db > 0.38) { + return 2 + add; + } + if (db > 0.33) { + return 1 + add; + } + if (db > 0.28) { + return add; + } + return 0; + } +}; + +void ComputeChromacityAdjustments(const CompressParams& cparams, + const Image3F& opsin, const Rect& rect, + FrameHeader* frame_header) { + if (frame_header->encoding != FrameEncoding::kVarDCT || + cparams.max_error_mode) { + return; + } + // 1) Distance based approach for chromacity adjustment: + float x_qm_scale_steps[4] = {1.25f, 7.0f, 15.0f, 24.0f}; + frame_header->x_qm_scale = 2; + for (float x_qm_scale_step : x_qm_scale_steps) { + if (cparams.original_butteraugli_distance > x_qm_scale_step) { + frame_header->x_qm_scale++; + } + } + if (cparams.butteraugli_distance < 0.299f) { + // Favor chromacity preservation for making images appear more + // faithful to original even with extreme (5-10x) zooming. + frame_header->x_qm_scale++; + } + // 2) Pixel-based approach for chromacity adjustment: + // look at the individual pixels and make a guess how difficult + // the image would be based on the worst case pixel. + PixelStatsForChromacityAdjustment pixel_stats; + if (cparams.speed_tier <= SpeedTier::kSquirrel) { + pixel_stats.Calc(&opsin, rect); + } + // For X take the most severe adjustment. + frame_header->x_qm_scale = std::max<int>( + frame_header->x_qm_scale, 2 + pixel_stats.HowMuchIsXChannelPixelized()); + // B only adjusted by pixel-based approach. + frame_header->b_qm_scale = 2 + pixel_stats.HowMuchIsBChannelPixelized(); +} + +void ComputeNoiseParams(const CompressParams& cparams, bool streaming_mode, + bool color_is_jpeg, const Image3F& opsin, + const FrameDimensions& frame_dim, + FrameHeader* frame_header, NoiseParams* noise_params) { + if (cparams.photon_noise_iso > 0) { + *noise_params = SimulatePhotonNoise(frame_dim.xsize, frame_dim.ysize, + cparams.photon_noise_iso); + } else if (cparams.manual_noise.size() == NoiseParams::kNumNoisePoints) { + for (size_t i = 0; i < NoiseParams::kNumNoisePoints; i++) { + noise_params->lut[i] = cparams.manual_noise[i]; + } + } else if (frame_header->encoding == FrameEncoding::kVarDCT && + frame_header->flags & FrameHeader::kNoise && !color_is_jpeg && + !streaming_mode) { + // Don't start at zero amplitude since adding noise is expensive -- it + // significantly slows down decoding, and this is unlikely to + // completely go away even with advanced optimizations. After the + // kNoiseModelingRampUpDistanceRange we have reached the full level, + // i.e. noise is no longer represented by the compressed image, so we + // can add full noise by the noise modeling itself. + static const float kNoiseModelingRampUpDistanceRange = 0.6; + static const float kNoiseLevelAtStartOfRampUp = 0.25; + static const float kNoiseRampupStart = 1.0; + // TODO(user) test and properly select quality_coef with smooth + // filter + float quality_coef = 1.0f; + const float rampup = (cparams.butteraugli_distance - kNoiseRampupStart) / + kNoiseModelingRampUpDistanceRange; + if (rampup < 1.0f) { + quality_coef = kNoiseLevelAtStartOfRampUp + + (1.0f - kNoiseLevelAtStartOfRampUp) * rampup; + } + if (rampup < 0.0f) { + quality_coef = kNoiseRampupStart; + } + if (!GetNoiseParameter(opsin, noise_params, quality_coef)) { + frame_header->flags &= ~FrameHeader::kNoise; + } + } +} + +void DownsampleColorChannels(const CompressParams& cparams, + const FrameHeader& frame_header, + bool color_is_jpeg, Image3F* opsin) { + if (color_is_jpeg || frame_header.upsampling == 1 || + cparams.already_downsampled) { + return; + } + if (frame_header.encoding == FrameEncoding::kVarDCT && + frame_header.upsampling == 2) { + // TODO(lode): use the regular DownsampleImage, or adapt to the custom + // coefficients, if there is are custom upscaling coefficients in + // CustomTransformData + if (cparams.speed_tier <= SpeedTier::kSquirrel) { + // TODO(lode): DownsampleImage2_Iterative is currently too slow to + // be used for squirrel, make it faster, and / or enable it only for + // kitten. + DownsampleImage2_Iterative(opsin); + } else { + DownsampleImage2_Sharper(opsin); + } + } else { + DownsampleImage(opsin, frame_header.upsampling); + } + if (frame_header.encoding == FrameEncoding::kVarDCT) { + PadImageToBlockMultipleInPlace(opsin); + } +} + +template <typename V, typename R> +void FindIndexOfSumMaximum(const V* array, const size_t len, R* idx, V* sum) { + JXL_ASSERT(len > 0); + V maxval = 0; + V val = 0; + R maxidx = 0; + for (size_t i = 0; i < len; ++i) { + val += array[i]; + if (val > maxval) { + maxval = val; + maxidx = i; + } + } + *idx = maxidx; + *sum = maxval; +} + +Status ComputeJPEGTranscodingData(const jpeg::JPEGData& jpeg_data, + const FrameHeader& frame_header, + ThreadPool* pool, + ModularFrameEncoder* enc_modular, + PassesEncoderState* enc_state) { + PassesSharedState& shared = enc_state->shared; + const FrameDimensions& frame_dim = shared.frame_dim; + + const size_t xsize = frame_dim.xsize_padded; + const size_t ysize = frame_dim.ysize_padded; + const size_t xsize_blocks = frame_dim.xsize_blocks; + const size_t ysize_blocks = frame_dim.ysize_blocks; + + // no-op chroma from luma + shared.cmap = ColorCorrelationMap(xsize, ysize, false); + shared.ac_strategy.FillDCT8(); + FillImage(uint8_t(0), &shared.epf_sharpness); + + enc_state->coeffs.clear(); + while (enc_state->coeffs.size() < enc_state->passes.size()) { + enc_state->coeffs.emplace_back(make_unique<ACImageT<int32_t>>( + kGroupDim * kGroupDim, frame_dim.num_groups)); + } + + // convert JPEG quantization table to a Quantizer object + float dcquantization[3]; + std::vector<QuantEncoding> qe(DequantMatrices::kNum, + QuantEncoding::Library(0)); + + auto jpeg_c_map = + JpegOrder(frame_header.color_transform, jpeg_data.components.size() == 1); + + std::vector<int> qt(192); + for (size_t c = 0; c < 3; c++) { + size_t jpeg_c = jpeg_c_map[c]; + const int32_t* quant = + jpeg_data.quant[jpeg_data.components[jpeg_c].quant_idx].values.data(); + + dcquantization[c] = 255 * 8.0f / quant[0]; + for (size_t y = 0; y < 8; y++) { + for (size_t x = 0; x < 8; x++) { + // JPEG XL transposes the DCT, JPEG doesn't. + qt[c * 64 + 8 * x + y] = quant[8 * y + x]; + } + } + } + DequantMatricesSetCustomDC(&shared.matrices, dcquantization); + float dcquantization_r[3] = {1.0f / dcquantization[0], + 1.0f / dcquantization[1], + 1.0f / dcquantization[2]}; + + qe[AcStrategy::Type::DCT] = QuantEncoding::RAW(qt); + DequantMatricesSetCustom(&shared.matrices, qe, enc_modular); + + // Ensure that InvGlobalScale() is 1. + shared.quantizer = Quantizer(&shared.matrices, 1, kGlobalScaleDenom); + // Recompute MulDC() and InvMulDC(). + shared.quantizer.RecomputeFromGlobalScale(); + + // Per-block dequant scaling should be 1. + FillImage(static_cast<int32_t>(shared.quantizer.InvGlobalScale()), + &shared.raw_quant_field); + + std::vector<int32_t> scaled_qtable(192); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 64; i++) { + scaled_qtable[64 * c + i] = + (1 << kCFLFixedPointPrecision) * qt[64 + i] / qt[64 * c + i]; + } + } + + auto jpeg_row = [&](size_t c, size_t y) { + return jpeg_data.components[jpeg_c_map[c]].coeffs.data() + + jpeg_data.components[jpeg_c_map[c]].width_in_blocks * kDCTBlockSize * + y; + }; + + bool DCzero = (frame_header.color_transform == ColorTransform::kYCbCr); + // Compute chroma-from-luma for AC (doesn't seem to be useful for DC) + if (frame_header.chroma_subsampling.Is444() && + enc_state->cparams.force_cfl_jpeg_recompression && + jpeg_data.components.size() == 3) { + for (size_t c : {0, 2}) { + ImageSB* map = (c == 0 ? &shared.cmap.ytox_map : &shared.cmap.ytob_map); + const float kScale = kDefaultColorFactor; + const int kOffset = 127; + const float kBase = + c == 0 ? shared.cmap.YtoXRatio(0) : shared.cmap.YtoBRatio(0); + const float kZeroThresh = + kScale * kZeroBiasDefault[c] * + 0.9999f; // just epsilon less for better rounding + + auto process_row = [&](const uint32_t task, const size_t thread) { + size_t ty = task; + int8_t* JXL_RESTRICT row_out = map->Row(ty); + for (size_t tx = 0; tx < map->xsize(); ++tx) { + const size_t y0 = ty * kColorTileDimInBlocks; + const size_t x0 = tx * kColorTileDimInBlocks; + const size_t y1 = std::min(frame_dim.ysize_blocks, + (ty + 1) * kColorTileDimInBlocks); + const size_t x1 = std::min(frame_dim.xsize_blocks, + (tx + 1) * kColorTileDimInBlocks); + int32_t d_num_zeros[257] = {0}; + // TODO(veluca): this needs SIMD + fixed point adaptation, and/or + // conversion to the new CfL algorithm. + for (size_t y = y0; y < y1; ++y) { + const int16_t* JXL_RESTRICT row_m = jpeg_row(1, y); + const int16_t* JXL_RESTRICT row_s = jpeg_row(c, y); + for (size_t x = x0; x < x1; ++x) { + for (size_t coeffpos = 1; coeffpos < kDCTBlockSize; coeffpos++) { + const float scaled_m = row_m[x * kDCTBlockSize + coeffpos] * + scaled_qtable[64 * c + coeffpos] * + (1.0f / (1 << kCFLFixedPointPrecision)); + const float scaled_s = + kScale * row_s[x * kDCTBlockSize + coeffpos] + + (kOffset - kBase * kScale) * scaled_m; + if (std::abs(scaled_m) > 1e-8f) { + float from, to; + if (scaled_m > 0) { + from = (scaled_s - kZeroThresh) / scaled_m; + to = (scaled_s + kZeroThresh) / scaled_m; + } else { + from = (scaled_s + kZeroThresh) / scaled_m; + to = (scaled_s - kZeroThresh) / scaled_m; + } + if (from < 0.0f) { + from = 0.0f; + } + if (to > 255.0f) { + to = 255.0f; + } + // Instead of clamping the both values + // we just check that range is sane. + if (from <= to) { + d_num_zeros[static_cast<int>(std::ceil(from))]++; + d_num_zeros[static_cast<int>(std::floor(to + 1))]--; + } + } + } + } + } + int best = 0; + int32_t best_sum = 0; + FindIndexOfSumMaximum(d_num_zeros, 256, &best, &best_sum); + int32_t offset_sum = 0; + for (int i = 0; i < 256; ++i) { + if (i <= kOffset) { + offset_sum += d_num_zeros[i]; + } + } + row_out[tx] = 0; + if (best_sum > offset_sum + 1) { + row_out[tx] = best - kOffset; + } + } + }; + + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, map->ysize(), ThreadPool::NoInit, + process_row, "FindCorrelation")); + } + } + + Image3F dc = Image3F(xsize_blocks, ysize_blocks); + if (!frame_header.chroma_subsampling.Is444()) { + ZeroFillImage(&dc); + for (auto& coeff : enc_state->coeffs) { + coeff->ZeroFill(); + } + } + // JPEG DC is from -1024 to 1023. + std::vector<size_t> dc_counts[3] = {}; + dc_counts[0].resize(2048); + dc_counts[1].resize(2048); + dc_counts[2].resize(2048); + size_t total_dc[3] = {}; + for (size_t c : {1, 0, 2}) { + if (jpeg_data.components.size() == 1 && c != 1) { + for (auto& coeff : enc_state->coeffs) { + coeff->ZeroFillPlane(c); + } + ZeroFillImage(&dc.Plane(c)); + // Ensure no division by 0. + dc_counts[c][1024] = 1; + total_dc[c] = 1; + continue; + } + size_t hshift = frame_header.chroma_subsampling.HShift(c); + size_t vshift = frame_header.chroma_subsampling.VShift(c); + ImageSB& map = (c == 0 ? shared.cmap.ytox_map : shared.cmap.ytob_map); + for (size_t group_index = 0; group_index < frame_dim.num_groups; + group_index++) { + const size_t gx = group_index % frame_dim.xsize_groups; + const size_t gy = group_index / frame_dim.xsize_groups; + int32_t* coeffs[kMaxNumPasses]; + for (size_t i = 0; i < enc_state->coeffs.size(); i++) { + coeffs[i] = enc_state->coeffs[i]->PlaneRow(c, group_index, 0).ptr32; + } + int32_t block[64]; + for (size_t by = gy * kGroupDimInBlocks; + by < ysize_blocks && by < (gy + 1) * kGroupDimInBlocks; ++by) { + if ((by >> vshift) << vshift != by) continue; + const int16_t* JXL_RESTRICT inputjpeg = jpeg_row(c, by >> vshift); + const int16_t* JXL_RESTRICT inputjpegY = jpeg_row(1, by); + float* JXL_RESTRICT fdc = dc.PlaneRow(c, by >> vshift); + const int8_t* JXL_RESTRICT cm = + map.ConstRow(by / kColorTileDimInBlocks); + for (size_t bx = gx * kGroupDimInBlocks; + bx < xsize_blocks && bx < (gx + 1) * kGroupDimInBlocks; ++bx) { + if ((bx >> hshift) << hshift != bx) continue; + size_t base = (bx >> hshift) * kDCTBlockSize; + int idc; + if (DCzero) { + idc = inputjpeg[base]; + } else { + idc = inputjpeg[base] + 1024 / qt[c * 64]; + } + dc_counts[c][std::min(static_cast<uint32_t>(idc + 1024), + uint32_t(2047))]++; + total_dc[c]++; + fdc[bx >> hshift] = idc * dcquantization_r[c]; + if (c == 1 || !enc_state->cparams.force_cfl_jpeg_recompression || + !frame_header.chroma_subsampling.Is444()) { + for (size_t y = 0; y < 8; y++) { + for (size_t x = 0; x < 8; x++) { + block[y * 8 + x] = inputjpeg[base + x * 8 + y]; + } + } + } else { + const int32_t scale = + shared.cmap.RatioJPEG(cm[bx / kColorTileDimInBlocks]); + + for (size_t y = 0; y < 8; y++) { + for (size_t x = 0; x < 8; x++) { + int Y = inputjpegY[kDCTBlockSize * bx + x * 8 + y]; + int QChroma = inputjpeg[kDCTBlockSize * bx + x * 8 + y]; + // Fixed-point multiply of CfL scale with quant table ratio + // first, and Y value second. + int coeff_scale = (scale * scaled_qtable[64 * c + y * 8 + x] + + (1 << (kCFLFixedPointPrecision - 1))) >> + kCFLFixedPointPrecision; + int cfl_factor = + (Y * coeff_scale + (1 << (kCFLFixedPointPrecision - 1))) >> + kCFLFixedPointPrecision; + int QCR = QChroma - cfl_factor; + block[y * 8 + x] = QCR; + } + } + } + enc_state->progressive_splitter.SplitACCoefficients( + block, AcStrategy::FromRawStrategy(AcStrategy::Type::DCT), bx, by, + coeffs); + for (size_t i = 0; i < enc_state->coeffs.size(); i++) { + coeffs[i] += kDCTBlockSize; + } + } + } + } + } + + auto& dct = enc_state->shared.block_ctx_map.dc_thresholds; + auto& num_dc_ctxs = enc_state->shared.block_ctx_map.num_dc_ctxs; + num_dc_ctxs = 1; + for (size_t i = 0; i < 3; i++) { + dct[i].clear(); + int num_thresholds = (CeilLog2Nonzero(total_dc[i]) - 12) / 2; + // up to 3 buckets per channel: + // dark/medium/bright, yellow/unsat/blue, green/unsat/red + num_thresholds = std::min(std::max(num_thresholds, 0), 2); + size_t cumsum = 0; + size_t cut = total_dc[i] / (num_thresholds + 1); + for (int j = 0; j < 2048; j++) { + cumsum += dc_counts[i][j]; + if (cumsum > cut) { + dct[i].push_back(j - 1025); + cut = total_dc[i] * (dct[i].size() + 1) / (num_thresholds + 1); + } + } + num_dc_ctxs *= dct[i].size() + 1; + } + + auto& ctx_map = enc_state->shared.block_ctx_map.ctx_map; + ctx_map.clear(); + ctx_map.resize(3 * kNumOrders * num_dc_ctxs, 0); + + int lbuckets = (dct[1].size() + 1); + for (size_t i = 0; i < num_dc_ctxs; i++) { + // up to 9 contexts for luma + ctx_map[i] = i / lbuckets; + // up to 3 contexts for chroma + ctx_map[kNumOrders * num_dc_ctxs + i] = + ctx_map[2 * kNumOrders * num_dc_ctxs + i] = + num_dc_ctxs / lbuckets + (i % lbuckets); + } + enc_state->shared.block_ctx_map.num_ctxs = + *std::max_element(ctx_map.begin(), ctx_map.end()) + 1; + + // disable DC frame for now + auto compute_dc_coeffs = [&](const uint32_t group_index, + size_t /* thread */) { + const Rect r = enc_state->shared.frame_dim.DCGroupRect(group_index); + enc_modular->AddVarDCTDC(frame_header, dc, r, group_index, + /*nl_dc=*/false, enc_state, + /*jpeg_transcode=*/true); + enc_modular->AddACMetadata(r, group_index, /*jpeg_transcode=*/true, + enc_state); + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, shared.frame_dim.num_dc_groups, + ThreadPool::NoInit, compute_dc_coeffs, + "Compute DC coeffs")); + + return true; +} + +Status ComputeVarDCTEncodingData(const FrameHeader& frame_header, + const Image3F* linear, + Image3F* JXL_RESTRICT opsin, const Rect& rect, + const JxlCmsInterface& cms, ThreadPool* pool, + ModularFrameEncoder* enc_modular, + PassesEncoderState* enc_state, + AuxOut* aux_out) { + JXL_ASSERT((rect.xsize() % kBlockDim) == 0 && + (rect.ysize() % kBlockDim) == 0); + JXL_RETURN_IF_ERROR(LossyFrameHeuristics(frame_header, enc_state, enc_modular, + linear, opsin, rect, cms, pool, + aux_out)); + + JXL_RETURN_IF_ERROR(InitializePassesEncoder( + frame_header, *opsin, rect, cms, pool, enc_state, enc_modular, aux_out)); + return true; +} + +void ComputeAllCoeffOrders(PassesEncoderState& enc_state, + const FrameDimensions& frame_dim) { + auto used_orders_info = ComputeUsedOrders( + enc_state.cparams.speed_tier, enc_state.shared.ac_strategy, + Rect(enc_state.shared.raw_quant_field)); + enc_state.used_orders.resize(enc_state.progressive_splitter.GetNumPasses()); + for (size_t i = 0; i < enc_state.progressive_splitter.GetNumPasses(); i++) { + ComputeCoeffOrder( + enc_state.cparams.speed_tier, *enc_state.coeffs[i], + enc_state.shared.ac_strategy, frame_dim, enc_state.used_orders[i], + enc_state.used_acs, used_orders_info.first, used_orders_info.second, + &enc_state.shared.coeff_orders[i * enc_state.shared.coeff_order_size]); + } + enc_state.used_acs |= used_orders_info.first; +} + +// Working area for TokenizeCoefficients (per-group!) +struct EncCache { + // Allocates memory when first called. + void InitOnce() { + if (num_nzeroes.xsize() == 0) { + num_nzeroes = Image3I(kGroupDimInBlocks, kGroupDimInBlocks); + } + } + // TokenizeCoefficients + Image3I num_nzeroes; +}; + +Status TokenizeAllCoefficients(const FrameHeader& frame_header, + ThreadPool* pool, + PassesEncoderState* enc_state) { + PassesSharedState& shared = enc_state->shared; + std::vector<EncCache> group_caches; + const auto tokenize_group_init = [&](const size_t num_threads) { + group_caches.resize(num_threads); + return true; + }; + const auto tokenize_group = [&](const uint32_t group_index, + const size_t thread) { + // Tokenize coefficients. + const Rect rect = shared.frame_dim.BlockGroupRect(group_index); + for (size_t idx_pass = 0; idx_pass < enc_state->passes.size(); idx_pass++) { + JXL_ASSERT(enc_state->coeffs[idx_pass]->Type() == ACType::k32); + const int32_t* JXL_RESTRICT ac_rows[3] = { + enc_state->coeffs[idx_pass]->PlaneRow(0, group_index, 0).ptr32, + enc_state->coeffs[idx_pass]->PlaneRow(1, group_index, 0).ptr32, + enc_state->coeffs[idx_pass]->PlaneRow(2, group_index, 0).ptr32, + }; + // Ensure group cache is initialized. + group_caches[thread].InitOnce(); + TokenizeCoefficients( + &shared.coeff_orders[idx_pass * shared.coeff_order_size], rect, + ac_rows, shared.ac_strategy, frame_header.chroma_subsampling, + &group_caches[thread].num_nzeroes, + &enc_state->passes[idx_pass].ac_tokens[group_index], shared.quant_dc, + shared.raw_quant_field, shared.block_ctx_map); + } + }; + return RunOnPool(pool, 0, shared.frame_dim.num_groups, tokenize_group_init, + tokenize_group, "TokenizeGroup"); +} + +Status EncodeGlobalDCInfo(const PassesSharedState& shared, BitWriter* writer, + AuxOut* aux_out) { + // Encode quantizer DC and global scale. + QuantizerParams params = shared.quantizer.GetParams(); + JXL_RETURN_IF_ERROR( + WriteQuantizerParams(params, writer, kLayerQuant, aux_out)); + EncodeBlockCtxMap(shared.block_ctx_map, writer, aux_out); + ColorCorrelationMapEncodeDC(shared.cmap, writer, kLayerDC, aux_out); + return true; +} + +// In streaming mode, this function only performs the histogram clustering and +// saves the histogram bitstreams in enc_state, the actual AC global bitstream +// is written in OutputAcGlobal() function after all the groups are processed. +Status EncodeGlobalACInfo(PassesEncoderState* enc_state, BitWriter* writer, + ModularFrameEncoder* enc_modular, AuxOut* aux_out) { + PassesSharedState& shared = enc_state->shared; + JXL_RETURN_IF_ERROR(DequantMatricesEncode(shared.matrices, writer, + kLayerQuant, aux_out, enc_modular)); + size_t num_histo_bits = CeilLog2Nonzero(shared.frame_dim.num_groups); + if (!enc_state->streaming_mode && num_histo_bits != 0) { + BitWriter::Allotment allotment(writer, num_histo_bits); + writer->Write(num_histo_bits, shared.num_histograms - 1); + allotment.ReclaimAndCharge(writer, kLayerAC, aux_out); + } + + for (size_t i = 0; i < enc_state->progressive_splitter.GetNumPasses(); i++) { + // Encode coefficient orders. + if (!enc_state->streaming_mode) { + size_t order_bits = 0; + JXL_RETURN_IF_ERROR(U32Coder::CanEncode( + kOrderEnc, enc_state->used_orders[i], &order_bits)); + BitWriter::Allotment allotment(writer, order_bits); + JXL_CHECK(U32Coder::Write(kOrderEnc, enc_state->used_orders[i], writer)); + allotment.ReclaimAndCharge(writer, kLayerOrder, aux_out); + EncodeCoeffOrders(enc_state->used_orders[i], + &shared.coeff_orders[i * shared.coeff_order_size], + writer, kLayerOrder, aux_out); + } + + // Encode histograms. + HistogramParams hist_params(enc_state->cparams.speed_tier, + shared.block_ctx_map.NumACContexts()); + if (enc_state->cparams.speed_tier > SpeedTier::kTortoise) { + hist_params.lz77_method = HistogramParams::LZ77Method::kNone; + } + if (enc_state->cparams.decoding_speed_tier >= 1) { + hist_params.max_histograms = 6; + } + size_t num_histogram_groups = shared.num_histograms; + if (enc_state->streaming_mode) { + size_t prev_num_histograms = + enc_state->passes[i].codes.encoding_info.size(); + if (enc_state->initialize_global_state) { + prev_num_histograms += kNumFixedHistograms; + hist_params.add_fixed_histograms = true; + } + size_t remaining_histograms = kClustersLimit - prev_num_histograms; + // Heuristic to assign budget of new histograms to DC groups. + // TODO(szabadka) Tune this together with the DC group ordering. + size_t max_histograms = remaining_histograms < 20 + ? std::min<size_t>(remaining_histograms, 4) + : remaining_histograms / 4; + hist_params.max_histograms = + std::min(max_histograms, hist_params.max_histograms); + num_histogram_groups = 1; + } + hist_params.streaming_mode = enc_state->streaming_mode; + hist_params.initialize_global_state = enc_state->initialize_global_state; + BuildAndEncodeHistograms( + hist_params, + num_histogram_groups * shared.block_ctx_map.NumACContexts(), + enc_state->passes[i].ac_tokens, &enc_state->passes[i].codes, + &enc_state->passes[i].context_map, writer, kLayerAC, aux_out); + } + + return true; +} + +Status EncodeGroups(const FrameHeader& frame_header, + PassesEncoderState* enc_state, + ModularFrameEncoder* enc_modular, ThreadPool* pool, + std::vector<BitWriter>* group_codes, AuxOut* aux_out) { + const PassesSharedState& shared = enc_state->shared; + const FrameDimensions& frame_dim = shared.frame_dim; + const size_t num_groups = frame_dim.num_groups; + const size_t num_passes = enc_state->progressive_splitter.GetNumPasses(); + const size_t global_ac_index = frame_dim.num_dc_groups + 1; + const bool is_small_image = frame_dim.num_groups == 1 && num_passes == 1; + + group_codes->resize( + NumTocEntries(num_groups, frame_dim.num_dc_groups, num_passes)); + + const auto get_output = [&](const size_t index) { + return &(*group_codes)[is_small_image ? 0 : index]; + }; + auto ac_group_code = [&](size_t pass, size_t group) { + return get_output(AcGroupIndex(pass, group, frame_dim.num_groups, + frame_dim.num_dc_groups)); + }; + + if (enc_state->initialize_global_state) { + if (frame_header.flags & FrameHeader::kPatches) { + PatchDictionaryEncoder::Encode(shared.image_features.patches, + get_output(0), kLayerDictionary, aux_out); + } + if (frame_header.flags & FrameHeader::kSplines) { + EncodeSplines(shared.image_features.splines, get_output(0), kLayerSplines, + HistogramParams(), aux_out); + } + if (frame_header.flags & FrameHeader::kNoise) { + EncodeNoise(shared.image_features.noise_params, get_output(0), + kLayerNoise, aux_out); + } + + JXL_RETURN_IF_ERROR(DequantMatricesEncodeDC(shared.matrices, get_output(0), + kLayerQuant, aux_out)); + if (frame_header.encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR(EncodeGlobalDCInfo(shared, get_output(0), aux_out)); + } + JXL_RETURN_IF_ERROR(enc_modular->EncodeGlobalInfo(enc_state->streaming_mode, + get_output(0), aux_out)); + JXL_RETURN_IF_ERROR(enc_modular->EncodeStream(get_output(0), aux_out, + kLayerModularGlobal, + ModularStreamId::Global())); + } + + std::vector<std::unique_ptr<AuxOut>> aux_outs; + auto resize_aux_outs = [&aux_outs, + aux_out](const size_t num_threads) -> Status { + if (aux_out == nullptr) { + aux_outs.resize(num_threads); + } else { + while (aux_outs.size() > num_threads) { + aux_out->Assimilate(*aux_outs.back()); + aux_outs.pop_back(); + } + while (num_threads > aux_outs.size()) { + aux_outs.emplace_back(jxl::make_unique<AuxOut>()); + } + } + return true; + }; + + const auto process_dc_group = [&](const uint32_t group_index, + const size_t thread) { + AuxOut* my_aux_out = aux_outs[thread].get(); + BitWriter* output = get_output(group_index + 1); + int modular_group_index = group_index; + if (enc_state->streaming_mode) { + JXL_ASSERT(group_index == 0); + modular_group_index = enc_state->dc_group_index; + } + if (frame_header.encoding == FrameEncoding::kVarDCT && + !(frame_header.flags & FrameHeader::kUseDcFrame)) { + BitWriter::Allotment allotment(output, 2); + output->Write(2, enc_modular->extra_dc_precision[modular_group_index]); + allotment.ReclaimAndCharge(output, kLayerDC, my_aux_out); + JXL_CHECK(enc_modular->EncodeStream( + output, my_aux_out, kLayerDC, + ModularStreamId::VarDCTDC(modular_group_index))); + } + JXL_CHECK(enc_modular->EncodeStream( + output, my_aux_out, kLayerModularDcGroup, + ModularStreamId::ModularDC(modular_group_index))); + if (frame_header.encoding == FrameEncoding::kVarDCT) { + const Rect& rect = enc_state->shared.frame_dim.DCGroupRect(group_index); + size_t nb_bits = CeilLog2Nonzero(rect.xsize() * rect.ysize()); + if (nb_bits != 0) { + BitWriter::Allotment allotment(output, nb_bits); + output->Write(nb_bits, + enc_modular->ac_metadata_size[modular_group_index] - 1); + allotment.ReclaimAndCharge(output, kLayerControlFields, my_aux_out); + } + JXL_CHECK(enc_modular->EncodeStream( + output, my_aux_out, kLayerControlFields, + ModularStreamId::ACMetadata(modular_group_index))); + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, frame_dim.num_dc_groups, + resize_aux_outs, process_dc_group, + "EncodeDCGroup")); + + if (frame_header.encoding == FrameEncoding::kVarDCT) { + JXL_RETURN_IF_ERROR(EncodeGlobalACInfo( + enc_state, get_output(global_ac_index), enc_modular, aux_out)); + } + + std::atomic<int> num_errors{0}; + const auto process_group = [&](const uint32_t group_index, + const size_t thread) { + AuxOut* my_aux_out = aux_outs[thread].get(); + + for (size_t i = 0; i < num_passes; i++) { + if (frame_header.encoding == FrameEncoding::kVarDCT) { + if (!EncodeGroupTokenizedCoefficients( + group_index, i, enc_state->histogram_idx[group_index], + *enc_state, ac_group_code(i, group_index), my_aux_out)) { + num_errors.fetch_add(1, std::memory_order_relaxed); + return; + } + } + // Write all modular encoded data (color?, alpha, depth, extra channels) + if (!enc_modular->EncodeStream( + ac_group_code(i, group_index), my_aux_out, kLayerModularAcGroup, + ModularStreamId::ModularAC(group_index, i))) { + num_errors.fetch_add(1, std::memory_order_relaxed); + return; + } + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, num_groups, resize_aux_outs, + process_group, "EncodeGroupCoefficients")); + + // Resizing aux_outs to 0 also Assimilates the array. + static_cast<void>(resize_aux_outs(0)); + JXL_RETURN_IF_ERROR(num_errors.load(std::memory_order_relaxed) == 0); + + for (BitWriter& bw : *group_codes) { + BitWriter::Allotment allotment(&bw, 8); + bw.ZeroPadToByte(); // end of group. + allotment.ReclaimAndCharge(&bw, kLayerAC, aux_out); + } + return true; +} + +Status ComputeEncodingData( + const CompressParams& cparams, const FrameInfo& frame_info, + const CodecMetadata* metadata, JxlEncoderChunkedFrameAdapter& frame_data, + const jpeg::JPEGData* jpeg_data, size_t x0, size_t y0, size_t xsize, + size_t ysize, const JxlCmsInterface& cms, ThreadPool* pool, + FrameHeader& mutable_frame_header, ModularFrameEncoder& enc_modular, + PassesEncoderState& enc_state, std::vector<BitWriter>* group_codes, + AuxOut* aux_out) { + JXL_ASSERT(x0 + xsize <= frame_data.xsize); + JXL_ASSERT(y0 + ysize <= frame_data.ysize); + const FrameHeader& frame_header = mutable_frame_header; + PassesSharedState& shared = enc_state.shared; + shared.metadata = metadata; + if (enc_state.streaming_mode) { + shared.frame_dim.Set(xsize, ysize, /*group_size_shift=*/1, + /*maxhshift=*/0, /*maxvshift=*/0, + /*modular_mode=*/false, /*upsampling=*/1); + } else { + shared.frame_dim = frame_header.ToFrameDimensions(); + } + + shared.image_features.patches.SetPassesSharedState(&shared); + const FrameDimensions& frame_dim = shared.frame_dim; + shared.ac_strategy = + AcStrategyImage(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared.raw_quant_field = + ImageI(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared.epf_sharpness = ImageB(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared.cmap = ColorCorrelationMap(frame_dim.xsize, frame_dim.ysize); + shared.coeff_order_size = kCoeffOrderMaxSize; + if (frame_header.encoding == FrameEncoding::kVarDCT) { + shared.coeff_orders.resize(frame_header.passes.num_passes * + kCoeffOrderMaxSize); + } + + shared.quant_dc = ImageB(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared.dc_storage = Image3F(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared.dc = &shared.dc_storage; + + const size_t num_extra_channels = metadata->m.num_extra_channels; + const ExtraChannelInfo* alpha_eci = metadata->m.Find(ExtraChannel::kAlpha); + const ExtraChannelInfo* black_eci = metadata->m.Find(ExtraChannel::kBlack); + const size_t alpha_idx = alpha_eci - metadata->m.extra_channel_info.data(); + const size_t black_idx = black_eci - metadata->m.extra_channel_info.data(); + const ColorEncoding c_enc = metadata->m.color_encoding; + + // Make the image patch bigger than the currently processed group in streaming + // mode so that we can take into account border pixels around the group when + // computing inverse Gaborish and adaptive quantization map. + int max_border = enc_state.streaming_mode ? kBlockDim : 0; + Rect frame_rect(0, 0, frame_data.xsize, frame_data.ysize); + Rect patch_rect = Rect(x0, y0, xsize, ysize).Extend(max_border, frame_rect); + JXL_ASSERT(patch_rect.IsInside(frame_rect)); + + // Allocating a large enough image avoids a copy when padding. + Image3F color(RoundUpToBlockDim(patch_rect.xsize()), + RoundUpToBlockDim(patch_rect.ysize())); + color.ShrinkTo(patch_rect.xsize(), patch_rect.ysize()); + std::vector<ImageF> extra_channels(num_extra_channels); + for (auto& extra_channel : extra_channels) { + extra_channel = jxl::ImageF(patch_rect.xsize(), patch_rect.ysize()); + } + ImageF* alpha = alpha_eci ? &extra_channels[alpha_idx] : nullptr; + ImageF* black = black_eci ? &extra_channels[black_idx] : nullptr; + bool has_interleaved_alpha = false; + JxlChunkedFrameInputSource input = frame_data.GetInputSource(); + if (!frame_data.IsJPEG()) { + JXL_RETURN_IF_ERROR(CopyColorChannels(input, patch_rect, frame_info, + metadata->m, pool, &color, alpha, + &has_interleaved_alpha)); + } + JXL_RETURN_IF_ERROR(CopyExtraChannels(input, patch_rect, frame_info, + metadata->m, has_interleaved_alpha, + pool, &extra_channels)); + + shared.image_features.patches.SetPassesSharedState(&shared); + enc_state.cparams = cparams; + + Image3F linear_storage; + Image3F* linear = nullptr; + + if (!jpeg_data) { + if (frame_header.color_transform == ColorTransform::kXYB && + frame_info.ib_needs_color_transform) { + if (frame_header.encoding == FrameEncoding::kVarDCT && + cparams.speed_tier <= SpeedTier::kKitten) { + linear_storage = Image3F(patch_rect.xsize(), patch_rect.ysize()); + linear = &linear_storage; + } + ToXYB(c_enc, metadata->m.IntensityTarget(), black, pool, &color, cms, + linear); + } else { + // Nothing to do. + // RGB or YCbCr: forward YCbCr is not implemented, this is only used when + // the input is already in YCbCr + // If encoding a special DC or reference frame: input is already in XYB. + } + bool lossless = cparams.IsLossless(); + if (alpha && !alpha_eci->alpha_associated && + frame_header.frame_type == FrameType::kRegularFrame && + !ApplyOverride(cparams.keep_invisible, lossless) && + cparams.ec_resampling == cparams.resampling) { + // simplify invisible pixels + SimplifyInvisible(&color, *alpha, lossless); + if (linear) { + SimplifyInvisible(linear, *alpha, lossless); + } + } + PadImageToBlockMultipleInPlace(&color); + } + + // Rectangle within color that corresponds to the currently processed group in + // streaming mode. + Rect group_rect(x0 - patch_rect.x0(), y0 - patch_rect.y0(), + RoundUpToBlockDim(xsize), RoundUpToBlockDim(ysize)); + + if (enc_state.initialize_global_state && !jpeg_data) { + ComputeChromacityAdjustments(cparams, color, group_rect, + &mutable_frame_header); + } + + ComputeNoiseParams(cparams, enc_state.streaming_mode, !!jpeg_data, color, + frame_dim, &mutable_frame_header, + &shared.image_features.noise_params); + + DownsampleColorChannels(cparams, frame_header, !!jpeg_data, &color); + + if (cparams.ec_resampling != 1 && !cparams.already_downsampled) { + for (ImageF& ec : extra_channels) { + DownsampleImage(&ec, cparams.ec_resampling); + } + } + + if (!enc_state.streaming_mode) { + group_rect = Rect(color); + } + + if (frame_header.encoding == FrameEncoding::kVarDCT) { + enc_state.passes.resize(enc_state.progressive_splitter.GetNumPasses()); + for (PassesEncoderState::PassData& pass : enc_state.passes) { + pass.ac_tokens.resize(shared.frame_dim.num_groups); + } + if (jpeg_data) { + JXL_RETURN_IF_ERROR(ComputeJPEGTranscodingData( + *jpeg_data, frame_header, pool, &enc_modular, &enc_state)); + } else { + JXL_RETURN_IF_ERROR(ComputeVarDCTEncodingData( + frame_header, linear, &color, group_rect, cms, pool, &enc_modular, + &enc_state, aux_out)); + } + ComputeAllCoeffOrders(enc_state, frame_dim); + if (!enc_state.streaming_mode) { + shared.num_histograms = 1; + enc_state.histogram_idx.resize(frame_dim.num_groups); + } + JXL_RETURN_IF_ERROR( + TokenizeAllCoefficients(frame_header, pool, &enc_state)); + } + + if (!enc_state.streaming_mode) { + if (cparams.modular_mode || !extra_channels.empty()) { + JXL_RETURN_IF_ERROR(enc_modular.ComputeEncodingData( + frame_header, metadata->m, &color, extra_channels, &enc_state, cms, + pool, aux_out, /*do_color=*/cparams.modular_mode)); + } + JXL_RETURN_IF_ERROR(enc_modular.ComputeTree(pool)); + JXL_RETURN_IF_ERROR(enc_modular.ComputeTokens(pool)); + + mutable_frame_header.UpdateFlag(shared.image_features.patches.HasAny(), + FrameHeader::kPatches); + mutable_frame_header.UpdateFlag(shared.image_features.splines.HasAny(), + FrameHeader::kSplines); + } + + JXL_RETURN_IF_ERROR(EncodeGroups(frame_header, &enc_state, &enc_modular, pool, + group_codes, aux_out)); + if (enc_state.streaming_mode) { + const size_t group_index = enc_state.dc_group_index; + enc_modular.ClearStreamData(ModularStreamId::VarDCTDC(group_index)); + enc_modular.ClearStreamData(ModularStreamId::ACMetadata(group_index)); + } + return true; +} + +Status PermuteGroups(const CompressParams& cparams, + const FrameDimensions& frame_dim, size_t num_passes, + std::vector<coeff_order_t>* permutation, + std::vector<BitWriter>* group_codes) { + const size_t num_groups = frame_dim.num_groups; + if (!cparams.centerfirst || (num_passes == 1 && num_groups == 1)) { + return true; + } + // Don't permute global DC/AC or DC. + permutation->resize(frame_dim.num_dc_groups + 2); + std::iota(permutation->begin(), permutation->end(), 0); + std::vector<coeff_order_t> ac_group_order(num_groups); + std::iota(ac_group_order.begin(), ac_group_order.end(), 0); + size_t group_dim = frame_dim.group_dim; + + // The center of the image is either given by parameters or chosen + // to be the middle of the image by default if center_x, center_y resp. + // are not provided. + + int64_t imag_cx; + if (cparams.center_x != static_cast<size_t>(-1)) { + JXL_RETURN_IF_ERROR(cparams.center_x < frame_dim.xsize); + imag_cx = cparams.center_x; + } else { + imag_cx = frame_dim.xsize / 2; + } + + int64_t imag_cy; + if (cparams.center_y != static_cast<size_t>(-1)) { + JXL_RETURN_IF_ERROR(cparams.center_y < frame_dim.ysize); + imag_cy = cparams.center_y; + } else { + imag_cy = frame_dim.ysize / 2; + } + + // The center of the group containing the center of the image. + int64_t cx = (imag_cx / group_dim) * group_dim + group_dim / 2; + int64_t cy = (imag_cy / group_dim) * group_dim + group_dim / 2; + // This identifies in what area of the central group the center of the image + // lies in. + double direction = -std::atan2(imag_cy - cy, imag_cx - cx); + // This identifies the side of the central group the center of the image + // lies closest to. This can take values 0, 1, 2, 3 corresponding to left, + // bottom, right, top. + int64_t side = std::fmod((direction + 5 * kPi / 4), 2 * kPi) * 2 / kPi; + auto get_distance_from_center = [&](size_t gid) { + Rect r = frame_dim.GroupRect(gid); + int64_t gcx = r.x0() + group_dim / 2; + int64_t gcy = r.y0() + group_dim / 2; + int64_t dx = gcx - cx; + int64_t dy = gcy - cy; + // The angle is determined by taking atan2 and adding an appropriate + // starting point depending on the side we want to start on. + double angle = std::remainder( + std::atan2(dy, dx) + kPi / 4 + side * (kPi / 2), 2 * kPi); + // Concentric squares in clockwise order. + return std::make_pair(std::max(std::abs(dx), std::abs(dy)), angle); + }; + std::sort(ac_group_order.begin(), ac_group_order.end(), + [&](coeff_order_t a, coeff_order_t b) { + return get_distance_from_center(a) < get_distance_from_center(b); + }); + std::vector<coeff_order_t> inv_ac_group_order(ac_group_order.size(), 0); + for (size_t i = 0; i < ac_group_order.size(); i++) { + inv_ac_group_order[ac_group_order[i]] = i; + } + for (size_t i = 0; i < num_passes; i++) { + size_t pass_start = permutation->size(); + for (coeff_order_t v : inv_ac_group_order) { + permutation->push_back(pass_start + v); + } + } + std::vector<BitWriter> new_group_codes(group_codes->size()); + for (size_t i = 0; i < permutation->size(); i++) { + new_group_codes[(*permutation)[i]] = std::move((*group_codes)[i]); + } + *group_codes = std::move(new_group_codes); + return true; +} + +bool CanDoStreamingEncoding(const CompressParams& cparams, + const FrameInfo& frame_info, + const CodecMetadata& metadata, + const JxlEncoderChunkedFrameAdapter& frame_data) { + if (frame_data.IsJPEG()) { + return false; + } + if (cparams.noise == Override::kOn || cparams.patches == Override::kOn) { + return false; + } + if (cparams.progressive_dc != 0 || frame_info.dc_level != 0) { + return false; + } + if (cparams.resampling != 1 || cparams.ec_resampling != 1) { + return false; + } + if (cparams.max_error_mode) { + return false; + } + if (cparams.color_transform != ColorTransform::kXYB) { + return false; + } + if (cparams.modular_mode) { + return false; + } + if (metadata.m.num_extra_channels > 0) { + return false; + } + if (cparams.buffering == 0) { + return false; + } + if (cparams.buffering == 1 && frame_data.xsize <= 2048 && + frame_data.ysize <= 2048) { + return false; + } + if (frame_data.xsize <= 256 && frame_data.ysize <= 256) { + return false; + } + return true; +} + +void ComputePermutationForStreaming(size_t xsize, size_t ysize, + size_t num_passes, + std::vector<coeff_order_t>& permutation, + std::vector<size_t>& dc_group_order) { + // This is only valid in VarDCT mode, otherwise there can be group shift. + const size_t group_size = 256; + const size_t dc_group_size = group_size * kBlockDim; + const size_t group_xsize = DivCeil(xsize, group_size); + const size_t group_ysize = DivCeil(ysize, group_size); + const size_t dc_group_xsize = DivCeil(xsize, dc_group_size); + const size_t dc_group_ysize = DivCeil(ysize, dc_group_size); + const size_t num_groups = group_xsize * group_ysize; + const size_t num_dc_groups = dc_group_xsize * dc_group_ysize; + const size_t num_sections = 2 + num_dc_groups + num_passes * num_groups; + permutation.resize(num_sections); + size_t new_ix = 0; + // DC Global is first + permutation[0] = new_ix++; + // TODO(szabadka) Change the dc group order to center-first. + for (size_t dc_y = 0; dc_y < dc_group_ysize; ++dc_y) { + for (size_t dc_x = 0; dc_x < dc_group_xsize; ++dc_x) { + size_t dc_ix = dc_y * dc_group_xsize + dc_x; + dc_group_order.push_back(dc_ix); + permutation[1 + dc_ix] = new_ix++; + size_t ac_y0 = dc_y * kBlockDim; + size_t ac_x0 = dc_x * kBlockDim; + size_t ac_y1 = std::min<size_t>(group_ysize, ac_y0 + kBlockDim); + size_t ac_x1 = std::min<size_t>(group_xsize, ac_x0 + kBlockDim); + for (size_t pass = 0; pass < num_passes; ++pass) { + for (size_t ac_y = ac_y0; ac_y < ac_y1; ++ac_y) { + for (size_t ac_x = ac_x0; ac_x < ac_x1; ++ac_x) { + size_t group_ix = ac_y * group_xsize + ac_x; + size_t old_ix = + AcGroupIndex(pass, group_ix, num_groups, num_dc_groups); + permutation[old_ix] = new_ix++; + } + } + } + } + } + // AC Global is last + permutation[1 + num_dc_groups] = new_ix++; + JXL_ASSERT(new_ix == num_sections); +} + +constexpr size_t kGroupSizeOffset[4] = { + static_cast<size_t>(0), + static_cast<size_t>(1024), + static_cast<size_t>(17408), + static_cast<size_t>(4211712), +}; +constexpr size_t kTOCBits[4] = {12, 16, 24, 32}; + +size_t TOCBucket(size_t group_size) { + size_t bucket = 0; + while (bucket < 3 && group_size >= kGroupSizeOffset[bucket + 1]) ++bucket; + return bucket; +} + +size_t TOCSize(const std::vector<size_t>& group_sizes) { + size_t toc_bits = 0; + for (size_t i = 0; i < group_sizes.size(); i++) { + toc_bits += kTOCBits[TOCBucket(group_sizes[i])]; + } + return (toc_bits + 7) / 8; +} + +PaddedBytes EncodeTOC(const std::vector<size_t>& group_sizes, AuxOut* aux_out) { + BitWriter writer; + BitWriter::Allotment allotment(&writer, 32 * group_sizes.size()); + for (size_t i = 0; i < group_sizes.size(); i++) { + JXL_CHECK(U32Coder::Write(kTocDist, group_sizes[i], &writer)); + } + writer.ZeroPadToByte(); // before first group + allotment.ReclaimAndCharge(&writer, kLayerTOC, aux_out); + return std::move(writer).TakeBytes(); +} + +void ComputeGroupDataOffset(size_t frame_header_size, size_t dc_global_size, + size_t num_sections, size_t& min_dc_global_size, + size_t& group_offset) { + size_t max_toc_bits = (num_sections - 1) * 32; + size_t min_toc_bits = (num_sections - 1) * 12; + size_t max_padding = (max_toc_bits - min_toc_bits + 7) / 8; + min_dc_global_size = dc_global_size; + size_t dc_global_bucket = TOCBucket(min_dc_global_size); + while (TOCBucket(min_dc_global_size + max_padding) > dc_global_bucket) { + dc_global_bucket = TOCBucket(min_dc_global_size + max_padding); + min_dc_global_size = kGroupSizeOffset[dc_global_bucket]; + } + JXL_ASSERT(TOCBucket(min_dc_global_size) == dc_global_bucket); + JXL_ASSERT(TOCBucket(min_dc_global_size + max_padding) == dc_global_bucket); + max_toc_bits += kTOCBits[dc_global_bucket]; + size_t max_toc_size = (max_toc_bits + 7) / 8; + group_offset = frame_header_size + max_toc_size + min_dc_global_size; +} + +size_t ComputeDcGlobalPadding(const std::vector<size_t>& group_sizes, + size_t frame_header_size, + size_t group_data_offset, + size_t min_dc_global_size) { + std::vector<size_t> new_group_sizes = group_sizes; + new_group_sizes[0] = min_dc_global_size; + size_t toc_size = TOCSize(new_group_sizes); + size_t actual_offset = frame_header_size + toc_size + group_sizes[0]; + return group_data_offset - actual_offset; +} + +Status OutputGroups(std::vector<BitWriter>&& group_codes, + std::vector<size_t>* group_sizes, + JxlEncoderOutputProcessorWrapper* output_processor) { + JXL_ASSERT(group_codes.size() >= 4); + { + PaddedBytes dc_group = std::move(group_codes[1]).TakeBytes(); + group_sizes->push_back(dc_group.size()); + JXL_RETURN_IF_ERROR(AppendData(*output_processor, dc_group)); + } + for (size_t i = 3; i < group_codes.size(); ++i) { + PaddedBytes ac_group = std::move(group_codes[i]).TakeBytes(); + group_sizes->push_back(ac_group.size()); + JXL_RETURN_IF_ERROR(AppendData(*output_processor, ac_group)); + } + return true; +} + +void RemoveUnusedHistograms(std::vector<uint8_t>& context_map, + EntropyEncodingData& codes) { + std::vector<int> remap(256, -1); + std::vector<uint8_t> inv_remap; + for (size_t i = 0; i < context_map.size(); ++i) { + const uint8_t histo_ix = context_map[i]; + if (remap[histo_ix] == -1) { + remap[histo_ix] = inv_remap.size(); + inv_remap.push_back(histo_ix); + } + context_map[i] = remap[histo_ix]; + } + EntropyEncodingData new_codes; + new_codes.use_prefix_code = codes.use_prefix_code; + new_codes.lz77 = codes.lz77; + for (uint8_t histo_idx : inv_remap) { + new_codes.encoding_info.emplace_back( + std::move(codes.encoding_info[histo_idx])); + new_codes.uint_config.emplace_back(std::move(codes.uint_config[histo_idx])); + new_codes.encoded_histograms.emplace_back( + std::move(codes.encoded_histograms[histo_idx])); + } + codes = std::move(new_codes); +} + +Status OutputAcGlobal(PassesEncoderState& enc_state, + const FrameDimensions& frame_dim, + std::vector<size_t>* group_sizes, + JxlEncoderOutputProcessorWrapper* output_processor, + AuxOut* aux_out) { + JXL_ASSERT(frame_dim.num_groups > 1); + BitWriter writer; + { + size_t num_histo_bits = CeilLog2Nonzero(frame_dim.num_groups); + BitWriter::Allotment allotment(&writer, num_histo_bits + 1); + writer.Write(1, 1); // default dequant matrices + writer.Write(num_histo_bits, frame_dim.num_dc_groups - 1); + allotment.ReclaimAndCharge(&writer, kLayerAC, aux_out); + } + const PassesSharedState& shared = enc_state.shared; + for (size_t i = 0; i < enc_state.progressive_splitter.GetNumPasses(); i++) { + // Encode coefficient orders. + size_t order_bits = 0; + JXL_RETURN_IF_ERROR( + U32Coder::CanEncode(kOrderEnc, enc_state.used_orders[i], &order_bits)); + BitWriter::Allotment allotment(&writer, order_bits); + JXL_CHECK(U32Coder::Write(kOrderEnc, enc_state.used_orders[i], &writer)); + allotment.ReclaimAndCharge(&writer, kLayerOrder, aux_out); + EncodeCoeffOrders(enc_state.used_orders[i], + &shared.coeff_orders[i * shared.coeff_order_size], + &writer, kLayerOrder, aux_out); + // Fix up context map and entropy codes to remove any fix histograms that + // were not selected by clustering. + RemoveUnusedHistograms(enc_state.passes[i].context_map, + enc_state.passes[i].codes); + EncodeHistograms(enc_state.passes[i].context_map, enc_state.passes[i].codes, + &writer, kLayerAC, aux_out); + } + { + BitWriter::Allotment allotment(&writer, 8); + writer.ZeroPadToByte(); // end of group. + allotment.ReclaimAndCharge(&writer, kLayerAC, aux_out); + } + PaddedBytes ac_global = std::move(writer).TakeBytes(); + group_sizes->push_back(ac_global.size()); + JXL_RETURN_IF_ERROR(AppendData(*output_processor, ac_global)); + return true; +} + +Status EncodeFrameStreaming(const CompressParams& cparams, + const FrameInfo& frame_info, + const CodecMetadata* metadata, + JxlEncoderChunkedFrameAdapter& frame_data, + const JxlCmsInterface& cms, ThreadPool* pool, + JxlEncoderOutputProcessorWrapper* output_processor, + AuxOut* aux_out) { + PassesEncoderState enc_state; + SetProgressiveMode(cparams, &enc_state.progressive_splitter); + FrameHeader frame_header(metadata); + std::unique_ptr<jpeg::JPEGData> jpeg_data; + if (frame_data.IsJPEG()) { + jpeg_data = make_unique<jpeg::JPEGData>(frame_data.TakeJPEGData()); + } + JXL_RETURN_IF_ERROR(MakeFrameHeader(frame_data.xsize, frame_data.ysize, + cparams, enc_state.progressive_splitter, + frame_info, jpeg_data.get(), true, + &frame_header)); + const size_t num_passes = enc_state.progressive_splitter.GetNumPasses(); + ModularFrameEncoder enc_modular(frame_header, cparams); + std::vector<coeff_order_t> permutation; + std::vector<size_t> dc_group_order; + ComputePermutationForStreaming(frame_data.xsize, frame_data.ysize, num_passes, + permutation, dc_group_order); + enc_state.shared.num_histograms = dc_group_order.size(); + // This is only valid in VarDCT mode, otherwise there can be group shift. + size_t group_size = 256; + size_t dc_group_size = group_size * kBlockDim; + size_t dc_group_xsize = DivCeil(frame_data.xsize, dc_group_size); + size_t min_dc_global_size = 0; + size_t group_data_offset = 0; + PaddedBytes frame_header_bytes; + PaddedBytes dc_global_bytes; + std::vector<size_t> group_sizes; + size_t start_pos = output_processor->CurrentPosition(); + for (size_t i = 0; i < dc_group_order.size(); ++i) { + size_t dc_ix = dc_group_order[i]; + size_t dc_y = dc_ix / dc_group_xsize; + size_t dc_x = dc_ix % dc_group_xsize; + size_t y0 = dc_y * dc_group_size; + size_t x0 = dc_x * dc_group_size; + size_t ysize = std::min<size_t>(dc_group_size, frame_data.ysize - y0); + size_t xsize = std::min<size_t>(dc_group_size, frame_data.xsize - x0); + size_t group_xsize = DivCeil(xsize, group_size); + size_t group_ysize = DivCeil(ysize, group_size); + JXL_DEBUG_V(2, + "Encoding DC group #%" PRIuS " dc_y = %" PRIuS " dc_x = %" PRIuS + " (x0, y0) = (%" PRIuS ", %" PRIuS ") (xsize, ysize) = (%" PRIuS + ", %" PRIuS ")", + dc_ix, dc_y, dc_x, x0, y0, xsize, ysize); + enc_state.streaming_mode = true; + enc_state.initialize_global_state = (i == 0); + enc_state.dc_group_index = dc_ix; + enc_state.histogram_idx = + std::vector<uint8_t>(group_xsize * group_ysize, i); + std::vector<BitWriter> group_codes; + JXL_RETURN_IF_ERROR(ComputeEncodingData( + cparams, frame_info, metadata, frame_data, jpeg_data.get(), x0, y0, + xsize, ysize, cms, pool, frame_header, enc_modular, enc_state, + &group_codes, aux_out)); + JXL_ASSERT(enc_state.special_frames.empty()); + if (i == 0) { + BitWriter writer; + JXL_RETURN_IF_ERROR(WriteFrameHeader(frame_header, &writer, aux_out)); + BitWriter::Allotment allotment(&writer, 8); + writer.Write(1, 1); // write permutation + EncodePermutation(permutation.data(), /*skip=*/0, permutation.size(), + &writer, kLayerHeader, aux_out); + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, kLayerHeader, aux_out); + frame_header_bytes = std::move(writer).TakeBytes(); + dc_global_bytes = std::move(group_codes[0]).TakeBytes(); + ComputeGroupDataOffset(frame_header_bytes.size(), dc_global_bytes.size(), + permutation.size(), min_dc_global_size, + group_data_offset); + JXL_DEBUG_V(2, "Frame header size: %" PRIuS, frame_header_bytes.size()); + JXL_DEBUG_V(2, "DC global size: %" PRIuS ", min size for TOC: %" PRIuS, + dc_global_bytes.size(), min_dc_global_size); + JXL_DEBUG_V(2, "Num groups: %" PRIuS " group data offset: %" PRIuS, + permutation.size(), group_data_offset); + group_sizes.push_back(dc_global_bytes.size()); + output_processor->Seek(start_pos + group_data_offset); + } + JXL_RETURN_IF_ERROR( + OutputGroups(std::move(group_codes), &group_sizes, output_processor)); + } + JXL_RETURN_IF_ERROR(OutputAcGlobal(enc_state, + frame_header.ToFrameDimensions(), + &group_sizes, output_processor, aux_out)); + JXL_ASSERT(group_sizes.size() == permutation.size()); + size_t end_pos = output_processor->CurrentPosition(); + output_processor->Seek(start_pos); + size_t padding_size = + ComputeDcGlobalPadding(group_sizes, frame_header_bytes.size(), + group_data_offset, min_dc_global_size); + group_sizes[0] += padding_size; + PaddedBytes toc_bytes = EncodeTOC(group_sizes, aux_out); + std::vector<uint8_t> padding_bytes(padding_size); + JXL_RETURN_IF_ERROR(AppendData(*output_processor, frame_header_bytes)); + JXL_RETURN_IF_ERROR(AppendData(*output_processor, toc_bytes)); + JXL_RETURN_IF_ERROR(AppendData(*output_processor, dc_global_bytes)); + JXL_RETURN_IF_ERROR(AppendData(*output_processor, padding_bytes)); + JXL_DEBUG_V(2, "TOC size: %" PRIuS " padding bytes after DC global: %" PRIuS, + toc_bytes.size(), padding_size); + JXL_ASSERT(output_processor->CurrentPosition() == + start_pos + group_data_offset); + output_processor->Seek(end_pos); + return true; +} + +Status EncodeFrameOneShot(const CompressParams& cparams, + const FrameInfo& frame_info, + const CodecMetadata* metadata, + JxlEncoderChunkedFrameAdapter& frame_data, + const JxlCmsInterface& cms, ThreadPool* pool, + JxlEncoderOutputProcessorWrapper* output_processor, + AuxOut* aux_out) { + PassesEncoderState enc_state; + SetProgressiveMode(cparams, &enc_state.progressive_splitter); + std::vector<BitWriter> group_codes; + FrameHeader frame_header(metadata); + std::unique_ptr<jpeg::JPEGData> jpeg_data; + if (frame_data.IsJPEG()) { + jpeg_data = make_unique<jpeg::JPEGData>(frame_data.TakeJPEGData()); + } + JXL_RETURN_IF_ERROR(MakeFrameHeader(frame_data.xsize, frame_data.ysize, + cparams, enc_state.progressive_splitter, + frame_info, jpeg_data.get(), false, + &frame_header)); + const size_t num_passes = enc_state.progressive_splitter.GetNumPasses(); + ModularFrameEncoder enc_modular(frame_header, cparams); + JXL_RETURN_IF_ERROR(ComputeEncodingData( + cparams, frame_info, metadata, frame_data, jpeg_data.get(), 0, 0, + frame_data.xsize, frame_data.ysize, cms, pool, frame_header, enc_modular, + enc_state, &group_codes, aux_out)); + + BitWriter writer; + writer.AppendByteAligned(enc_state.special_frames); + JXL_RETURN_IF_ERROR(WriteFrameHeader(frame_header, &writer, aux_out)); + + std::vector<coeff_order_t> permutation; + JXL_RETURN_IF_ERROR(PermuteGroups(cparams, enc_state.shared.frame_dim, + num_passes, &permutation, &group_codes)); + + JXL_RETURN_IF_ERROR( + WriteGroupOffsets(group_codes, permutation, &writer, aux_out)); + + writer.AppendByteAligned(group_codes); + PaddedBytes frame_bytes = std::move(writer).TakeBytes(); + JXL_RETURN_IF_ERROR(AppendData(*output_processor, frame_bytes)); + + return true; +} + +} // namespace + +Status EncodeFrame(const CompressParams& cparams_orig, + const FrameInfo& frame_info, const CodecMetadata* metadata, + JxlEncoderChunkedFrameAdapter& frame_data, + const JxlCmsInterface& cms, ThreadPool* pool, + JxlEncoderOutputProcessorWrapper* output_processor, + AuxOut* aux_out) { + CompressParams cparams = cparams_orig; + if (cparams.speed_tier == SpeedTier::kGlacier && !cparams.IsLossless()) { + cparams.speed_tier = SpeedTier::kTortoise; + } + if (cparams.speed_tier == SpeedTier::kGlacier) { + std::vector<CompressParams> all_params; + std::vector<size_t> size; + + CompressParams cparams_attempt = cparams_orig; + cparams_attempt.speed_tier = SpeedTier::kTortoise; + cparams_attempt.options.max_properties = 4; + + for (float x : {0.0f, 80.f}) { + cparams_attempt.channel_colors_percent = x; + for (float y : {0.0f, 95.0f}) { + cparams_attempt.channel_colors_pre_transform_percent = y; + // 70000 ensures that the number of palette colors is representable in + // modular headers. + for (int K : {0, 1 << 10, 70000}) { + cparams_attempt.palette_colors = K; + for (int tree_mode : {-1, (int)ModularOptions::TreeMode::kNoWP, + (int)ModularOptions::TreeMode::kDefault}) { + if (tree_mode == -1) { + // LZ77 only + cparams_attempt.options.nb_repeats = 0; + } else { + cparams_attempt.options.nb_repeats = 1; + cparams_attempt.options.wp_tree_mode = + static_cast<ModularOptions::TreeMode>(tree_mode); + } + for (Predictor pred : {Predictor::Zero, Predictor::Variable}) { + cparams_attempt.options.predictor = pred; + for (int g : {0, -1, 3}) { + cparams_attempt.modular_group_size_shift = g; + for (Override patches : {Override::kDefault, Override::kOff}) { + cparams_attempt.patches = patches; + all_params.push_back(cparams_attempt); + } + } + } + } + } + } + } + + size.resize(all_params.size()); + + std::atomic<int> num_errors{0}; + + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, all_params.size(), ThreadPool::NoInit, + [&](size_t task, size_t) { + std::vector<uint8_t> output(64); + uint8_t* next_out = output.data(); + size_t avail_out = output.size(); + JxlEncoderOutputProcessorWrapper local_output; + local_output.SetAvailOut(&next_out, &avail_out); + if (!EncodeFrame(all_params[task], frame_info, metadata, frame_data, + cms, nullptr, &local_output, aux_out)) { + num_errors.fetch_add(1, std::memory_order_relaxed); + return; + } + size[task] = local_output.CurrentPosition(); + }, + "Compress kGlacier")); + JXL_RETURN_IF_ERROR(num_errors.load(std::memory_order_relaxed) == 0); + + size_t best_idx = 0; + for (size_t i = 1; i < all_params.size(); i++) { + if (size[best_idx] > size[i]) { + best_idx = i; + } + } + cparams = all_params[best_idx]; + } + + JXL_RETURN_IF_ERROR(ParamsPostInit(&cparams)); + + if (cparams.butteraugli_distance < 0) { + return JXL_FAILURE("Expected non-negative distance"); + } + + if (cparams.progressive_dc < 0) { + if (cparams.progressive_dc != -1) { + return JXL_FAILURE("Invalid progressive DC setting value (%d)", + cparams.progressive_dc); + } + cparams.progressive_dc = 0; + } + if (cparams.ec_resampling < cparams.resampling) { + cparams.ec_resampling = cparams.resampling; + } + if (cparams.resampling > 1 || frame_info.is_preview) { + cparams.progressive_dc = 0; + } + + if (frame_info.dc_level + cparams.progressive_dc > 4) { + return JXL_FAILURE("Too many levels of progressive DC"); + } + + if (cparams.butteraugli_distance != 0 && + cparams.butteraugli_distance < kMinButteraugliDistance) { + return JXL_FAILURE("Butteraugli distance is too low (%f)", + cparams.butteraugli_distance); + } + + if (frame_data.IsJPEG()) { + cparams.gaborish = Override::kOff; + cparams.epf = 0; + cparams.modular_mode = false; + } + + if (frame_data.xsize == 0 || frame_data.ysize == 0) { + return JXL_FAILURE("Empty image"); + } + + // Assert that this metadata is correctly set up for the compression params, + // this should have been done by enc_file.cc + JXL_ASSERT(metadata->m.xyb_encoded == + (cparams.color_transform == ColorTransform::kXYB)); + + if (frame_data.IsJPEG() && cparams.color_transform == ColorTransform::kXYB) { + return JXL_FAILURE("Can't add JPEG frame to XYB codestream"); + } + + if (CanDoStreamingEncoding(cparams, frame_info, *metadata, frame_data)) { + return EncodeFrameStreaming(cparams, frame_info, metadata, frame_data, cms, + pool, output_processor, aux_out); + } else { + return EncodeFrameOneShot(cparams, frame_info, metadata, frame_data, cms, + pool, output_processor, aux_out); + } +} + +Status EncodeFrame(const CompressParams& cparams_orig, + const FrameInfo& frame_info, const CodecMetadata* metadata, + const ImageBundle& ib, const JxlCmsInterface& cms, + ThreadPool* pool, BitWriter* writer, AuxOut* aux_out) { + JxlEncoderChunkedFrameAdapter frame_data(ib.xsize(), ib.ysize(), + ib.extra_channels().size()); + std::vector<uint8_t> color; + if (ib.IsJPEG()) { + frame_data.SetJPEGData(*ib.jpeg_data); + } else { + uint32_t num_channels = + ib.IsGray() && frame_info.ib_needs_color_transform ? 1 : 3; + size_t stride = ib.xsize() * num_channels * 4; + color.resize(ib.ysize() * stride); + JXL_RETURN_IF_ERROR(ConvertToExternal( + ib, /*bites_per_sample=*/32, /*float_out=*/true, num_channels, + JXL_NATIVE_ENDIAN, stride, pool, color.data(), color.size(), + /*out_callback=*/{}, Orientation::kIdentity)); + JxlPixelFormat format{num_channels, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + frame_data.SetFromBuffer(0, color.data(), color.size(), format); + } + for (size_t ec = 0; ec < ib.extra_channels().size(); ++ec) { + JxlPixelFormat ec_format{1, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + size_t ec_stride = ib.xsize() * 4; + std::vector<uint8_t> ec_data(ib.ysize() * ec_stride); + const ImageF* channel = &ib.extra_channels()[ec]; + JXL_RETURN_IF_ERROR(ConvertChannelsToExternal( + &channel, 1, + /*bites_per_sample=*/32, + /*float_out=*/true, JXL_NATIVE_ENDIAN, ec_stride, pool, ec_data.data(), + ec_data.size(), /*out_callback=*/{}, Orientation::kIdentity)); + frame_data.SetFromBuffer(1 + ec, ec_data.data(), ec_data.size(), ec_format); + } + FrameInfo fi = frame_info; + fi.origin = ib.origin; + fi.blend = ib.blend; + fi.blendmode = ib.blendmode; + fi.duration = ib.duration; + fi.timecode = ib.timecode; + fi.name = ib.name; + std::vector<uint8_t> output(64); + uint8_t* next_out = output.data(); + size_t avail_out = output.size(); + JxlEncoderOutputProcessorWrapper output_processor; + output_processor.SetAvailOut(&next_out, &avail_out); + JXL_RETURN_IF_ERROR(EncodeFrame(cparams_orig, fi, metadata, frame_data, cms, + pool, &output_processor, aux_out)); + output_processor.SetFinalizedPosition(); + output_processor.CopyOutput(output, next_out, avail_out); + writer->AppendByteAligned(Bytes(output)); + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_frame.h b/third_party/jpeg-xl/lib/jxl/enc_frame.h new file mode 100644 index 0000000000..c6db64ee4e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_frame.h @@ -0,0 +1,108 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_FRAME_H_ +#define LIB_JXL_ENC_FRAME_H_ + +#include <jxl/cms_interface.h> +#include <jxl/types.h> + +#include <cstddef> +#include <cstdint> +#include <string> +#include <vector> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" + +namespace jxl { + +struct AuxOut; + +// Information needed for encoding a frame that is not contained elsewhere and +// does not belong to `cparams`. +// TODO(lode): if possible, it might be better to replace FrameInfo and several +// fields from ImageBundle (such as frame name and duration) by direct usage of +// jxl::FrameHeader itself. +struct FrameInfo { + // TODO(veluca): consider adding more parameters, such as custom patches. + bool save_before_color_transform = false; + // Whether or not the input image bundle is already in the codestream + // colorspace (as deduced by cparams). + // TODO(veluca): this is a hack - ImageBundle doesn't have a simple way to say + // "this is already in XYB". + bool ib_needs_color_transform = true; + FrameType frame_type = FrameType::kRegularFrame; + size_t dc_level = 0; + // Only used for kRegularFrame. + bool is_last = true; + bool is_preview = false; + // Information for storing this frame for future use (only for non-DC frames). + size_t save_as_reference = 0; + // The source frame for blending of a next frame, matching the + // save_as_reference value of a previous frame. Animated frames can use + // save_as_reference values 1, 2 and 3, while composite still frames can use + // save_as_reference values 0, 1, 2 and 3. The current C++ encoder + // implementation is assuming and using 1 for all frames of animations, so + // using that as the default value here. + // Corresponds to BlendingInfo::source from the FrameHeader. + size_t source = 1; + // Corresponds to BlendingInfo::clamp from the FrameHeader. + size_t clamp = 1; + // Corresponds to BlendingInfo::alpha_channel from the FrameHeader, or set to + // -1 to automatically choose it as the index of the first extra channel of + // type alpha. + int alpha_channel = -1; + + FrameOrigin origin{0, 0}; + + bool blend = false; + BlendMode blendmode = BlendMode::kBlend; + + JxlBitDepth image_bit_depth = {}; + + // Animation-related information, corresponding to the timecode and duration + // fields of the jxl::AnimationFrame of the jxl::FrameHeader. + uint32_t duration = 0; + uint32_t timecode = 0; + + std::string name; + + // If non-empty, uses this blending info for the extra channels, otherwise + // automatically chooses it. The encoder API will fill this vector with the + // extra channel info and allows more options. The non-API cjxl leaves it + // empty and relies on the default behavior. + std::vector<BlendingInfo> extra_channel_blending_info; +}; + +// Checks and adjusts CompressParams when they are all initialized. +Status ParamsPostInit(CompressParams* p); + +// Encodes a single frame (including its header) into a byte stream. Groups may +// be processed in parallel by `pool`. metadata is the ImageMetadata encoded in +// the codestream, and must be used for the FrameHeaders, do not use +// ib.metadata. +Status EncodeFrame(const CompressParams& cparams_orig, + const FrameInfo& frame_info, const CodecMetadata* metadata, + JxlEncoderChunkedFrameAdapter& frame_data, + const JxlCmsInterface& cms, ThreadPool* pool, + JxlEncoderOutputProcessorWrapper* output_processor, + AuxOut* aux_out); + +Status EncodeFrame(const CompressParams& cparams_orig, + const FrameInfo& frame_info, const CodecMetadata* metadata, + const ImageBundle& ib, const JxlCmsInterface& cms, + ThreadPool* pool, BitWriter* writer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_FRAME_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_gaborish.cc b/third_party/jpeg-xl/lib/jxl/enc_gaborish.cc new file mode 100644 index 0000000000..3f2ee32afd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_gaborish.cc @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_gaborish.h" + +#include <stddef.h> + +#include <hwy/base.h> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +void GaborishInverse(Image3F* in_out, const Rect& rect, float mul[3], + ThreadPool* pool) { + WeightsSymmetric5 weights[3]; + // Only an approximation. One or even two 3x3, and rank-1 (separable) 5x5 + // are insufficient. The numbers here have been obtained by butteraugli + // based optimizing the whole system and the errors produced are likely + // more favorable for good rate-distortion compromises rather than + // just using mathematical optimization to find the inverse. + static const float kGaborish[5] = { + -0.090881924078487886f, -0.043663953593472138f, 0.01392497846646211f, + 0.0036189602184591141f, 0.0030557936884763499f}; + for (int i = 0; i < 3; ++i) { + double sum = 1.0 + mul[i] * 4 * + (kGaborish[0] + kGaborish[1] + kGaborish[2] + + kGaborish[4] + 2 * kGaborish[3]); + if (sum < 1e-5) { + sum = 1e-5; + } + const float normalize = static_cast<float>(1.0 / sum); + const float normalize_mul = mul[i] * normalize; + weights[i] = WeightsSymmetric5{{HWY_REP4(normalize)}, + {HWY_REP4(normalize_mul * kGaborish[0])}, + {HWY_REP4(normalize_mul * kGaborish[2])}, + {HWY_REP4(normalize_mul * kGaborish[1])}, + {HWY_REP4(normalize_mul * kGaborish[4])}, + {HWY_REP4(normalize_mul * kGaborish[3])}}; + } + // Reduce memory footprint by only allocating a single plane and swapping it + // into the output Image3F. Better still would be tiling. + // Note that we cannot *allocate* a plane, as doing so might cause Image3F to + // have planes of different stride. Instead, we copy one plane in a temporary + // image and reuse the existing planes of the in/out image. + ImageF temp(in_out->Plane(2).xsize(), in_out->Plane(2).ysize()); + CopyImageTo(in_out->Plane(2), &temp); + Rect xrect = rect.Extend(3, Rect(*in_out)); + Symmetric5(in_out->Plane(0), xrect, weights[0], pool, &in_out->Plane(2), + xrect); + Symmetric5(in_out->Plane(1), xrect, weights[1], pool, &in_out->Plane(0), + xrect); + Symmetric5(temp, xrect, weights[2], pool, &in_out->Plane(1), xrect); + // Now planes are 1, 2, 0. + in_out->Plane(0).Swap(in_out->Plane(1)); + // 2 1 0 + in_out->Plane(0).Swap(in_out->Plane(2)); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_gaborish.h b/third_party/jpeg-xl/lib/jxl/enc_gaborish.h new file mode 100644 index 0000000000..ece4959f36 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_gaborish.h @@ -0,0 +1,24 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_GABORISH_H_ +#define LIB_JXL_GABORISH_H_ + +// Linear smoothing (3x3 convolution) for deblocking without too much blur. + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/image.h" + +namespace jxl { + +// Used in encoder to reduce the impact of the decoder's smoothing. +// This is not exact. Works in-place to reduce memory use. +// The input is typically in XYB space. +void GaborishInverse(Image3F* in_out, const Rect& rect, float mul[3], + ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_GABORISH_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_gaborish_test.cc b/third_party/jpeg-xl/lib/jxl/enc_gaborish_test.cc new file mode 100644 index 0000000000..426f08ecb0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_gaborish_test.cc @@ -0,0 +1,81 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_gaborish.h" + +#include <hwy/base.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +// weight1,2 need not be normalized. +WeightsSymmetric3 GaborishKernel(float weight1, float weight2) { + constexpr float weight0 = 1.0f; + + // Normalize + const float mul = 1.0f / (weight0 + 4 * (weight1 + weight2)); + const float w0 = weight0 * mul; + const float w1 = weight1 * mul; + const float w2 = weight2 * mul; + + const WeightsSymmetric3 w = {{HWY_REP4(w0)}, {HWY_REP4(w1)}, {HWY_REP4(w2)}}; + return w; +} + +void ConvolveGaborish(const ImageF& in, float weight1, float weight2, + ThreadPool* pool, ImageF* JXL_RESTRICT out) { + JXL_CHECK(SameSize(in, *out)); + Symmetric3(in, Rect(in), GaborishKernel(weight1, weight2), pool, out); +} + +void TestRoundTrip(const Image3F& in, float max_l1) { + Image3F fwd(in.xsize(), in.ysize()); + ThreadPool* null_pool = nullptr; + ConvolveGaborish(in.Plane(0), 0, 0, null_pool, &fwd.Plane(0)); + ConvolveGaborish(in.Plane(1), 0, 0, null_pool, &fwd.Plane(1)); + ConvolveGaborish(in.Plane(2), 0, 0, null_pool, &fwd.Plane(2)); + float w = 0.92718927264540152f; + float weights[3] = { + w, + w, + w, + }; + GaborishInverse(&fwd, Rect(fwd), weights, null_pool); + JXL_ASSERT_OK(VerifyRelativeError(in, fwd, max_l1, 1E-4f, _)); +} + +TEST(GaborishTest, TestZero) { + Image3F in(20, 20); + ZeroFillImage(&in); + TestRoundTrip(in, 0.0f); +} + +// Disabled: large difference. +#if 0 +TEST(GaborishTest, TestDirac) { + Image3F in(20, 20); + ZeroFillImage(&in); + in.PlaneRow(1, 10)[10] = 10.0f; + TestRoundTrip(in, 0.26f); +} +#endif + +TEST(GaborishTest, TestFlat) { + Image3F in(20, 20); + FillImage(1.0f, &in); + TestRoundTrip(in, 1E-5f); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_gamma_correct.h b/third_party/jpeg-xl/lib/jxl/enc_gamma_correct.h new file mode 100644 index 0000000000..0d1b9123e0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_gamma_correct.h @@ -0,0 +1,35 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_GAMMA_CORRECT_H_ +#define LIB_JXL_ENC_GAMMA_CORRECT_H_ + +// Deprecated: sRGB transfer function. Use JxlCms instead. + +#include <cmath> + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// Values are in [0, 1]. +static JXL_INLINE double Srgb8ToLinearDirect(double srgb) { + if (srgb <= 0.0) return 0.0; + if (srgb <= 0.04045) return srgb / 12.92; + if (srgb >= 1.0) return 1.0; + return std::pow((srgb + 0.055) / 1.055, 2.4); +} + +// Values are in [0, 1]. +static JXL_INLINE double LinearToSrgb8Direct(double linear) { + if (linear <= 0.0) return 0.0; + if (linear >= 1.0) return 1.0; + if (linear <= 0.0031308) return linear * 12.92; + return std::pow(linear, 1.0 / 2.4) * 1.055 - 0.055; +} + +} // namespace jxl + +#endif // LIB_JXL_ENC_GAMMA_CORRECT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_group.cc b/third_party/jpeg-xl/lib/jxl/enc_group.cc new file mode 100644 index 0000000000..09bab534c9 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_group.cc @@ -0,0 +1,540 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_group.h" + +#include <hwy/aligned_allocator.h> +#include <utility> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_group.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/common.h" // kMaxNumPasses +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_transforms-inl.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quantizer-inl.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/simd_util.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Ge; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenElseZero; +using hwy::HWY_NAMESPACE::MaskFromVec; +using hwy::HWY_NAMESPACE::Round; + +// NOTE: caller takes care of extracting quant from rect of RawQuantField. +void QuantizeBlockAC(const Quantizer& quantizer, const bool error_diffusion, + size_t c, float qm_multiplier, size_t quant_kind, + size_t xsize, size_t ysize, float* thresholds, + const float* JXL_RESTRICT block_in, int32_t* quant, + int32_t* JXL_RESTRICT block_out) { + const float* JXL_RESTRICT qm = quantizer.InvDequantMatrix(quant_kind, c); + float qac = quantizer.Scale() * (*quant); + // Not SIMD-ified for now. + if (c != 1 && xsize * ysize >= 4) { + for (int i = 0; i < 4; ++i) { + thresholds[i] -= 0.00744f * xsize * ysize; + if (thresholds[i] < 0.5) { + thresholds[i] = 0.5; + } + } + } + HWY_CAPPED(float, kBlockDim) df; + HWY_CAPPED(int32_t, kBlockDim) di; + HWY_CAPPED(uint32_t, kBlockDim) du; + const auto quantv = Set(df, qac * qm_multiplier); + for (size_t y = 0; y < ysize * kBlockDim; y++) { + size_t yfix = static_cast<size_t>(y >= ysize * kBlockDim / 2) * 2; + const size_t off = y * kBlockDim * xsize; + for (size_t x = 0; x < xsize * kBlockDim; x += Lanes(df)) { + auto thr = Zero(df); + if (xsize == 1) { + HWY_ALIGN uint32_t kMask[kBlockDim] = {0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u}; + const auto mask = MaskFromVec(BitCast(df, Load(du, kMask + x))); + thr = IfThenElse(mask, Set(df, thresholds[yfix + 1]), + Set(df, thresholds[yfix])); + } else { + // Same for all lanes in the vector. + thr = Set( + df, + thresholds[yfix + static_cast<size_t>(x >= xsize * kBlockDim / 2)]); + } + const auto q = Mul(Load(df, qm + off + x), quantv); + const auto in = Load(df, block_in + off + x); + const auto val = Mul(q, in); + const auto nzero_mask = Ge(Abs(val), thr); + const auto v = ConvertTo(di, IfThenElseZero(nzero_mask, Round(val))); + Store(v, di, block_out + off + x); + } + } +} + +void AdjustQuantBlockAC(const Quantizer& quantizer, size_t c, + float qm_multiplier, size_t quant_kind, size_t xsize, + size_t ysize, float* thresholds, + const float* JXL_RESTRICT block_in, int32_t* quant) { + // No quantization adjusting for these small blocks. + // Quantization adjusting attempts to fix some known issues + // with larger blocks and on the 8x8 dct's emerging 8x8 blockiness + // when there are not many non-zeros. + constexpr size_t kPartialBlockKinds = + (1 << AcStrategy::Type::IDENTITY) | (1 << AcStrategy::Type::DCT2X2) | + (1 << AcStrategy::Type::DCT4X4) | (1 << AcStrategy::Type::DCT4X8) | + (1 << AcStrategy::Type::DCT8X4) | (1 << AcStrategy::Type::AFV0) | + (1 << AcStrategy::Type::AFV1) | (1 << AcStrategy::Type::AFV2) | + (1 << AcStrategy::Type::AFV3); + if ((1 << quant_kind) & kPartialBlockKinds) { + return; + } + + const float* JXL_RESTRICT qm = quantizer.InvDequantMatrix(quant_kind, c); + float qac = quantizer.Scale() * (*quant); + if (xsize > 1 || ysize > 1) { + for (int i = 0; i < 4; ++i) { + thresholds[i] -= Clamp1(0.003f * xsize * ysize, 0.f, 0.08f); + if (thresholds[i] < 0.54) { + thresholds[i] = 0.54; + } + } + } + float sum_of_highest_freq_row_and_column = 0; + float sum_of_error = 0; + float sum_of_vals = 0; + float hfNonZeros[4] = {}; + float hfMaxError[4] = {}; + + for (size_t y = 0; y < ysize * kBlockDim; y++) { + for (size_t x = 0; x < xsize * kBlockDim; x++) { + const size_t pos = y * kBlockDim * xsize + x; + if (x < xsize && y < ysize) { + continue; + } + const size_t hfix = (static_cast<size_t>(y >= ysize * kBlockDim / 2) * 2 + + static_cast<size_t>(x >= xsize * kBlockDim / 2)); + const float val = block_in[pos] * (qm[pos] * qac * qm_multiplier); + const float v = (std::abs(val) < thresholds[hfix]) ? 0 : rintf(val); + const float error = std::abs(val - v); + sum_of_error += error; + sum_of_vals += std::abs(v); + if (c == 1 && v == 0) { + if (hfMaxError[hfix] < error) { + hfMaxError[hfix] = error; + } + } + if (v != 0.0f) { + hfNonZeros[hfix] += std::abs(v); + bool in_corner = y >= 7 * ysize && x >= 7 * xsize; + bool on_border = + y == ysize * kBlockDim - 1 || x == xsize * kBlockDim - 1; + bool in_larger_corner = x >= 4 * xsize && y >= 4 * ysize; + if (in_corner || (on_border && in_larger_corner)) { + sum_of_highest_freq_row_and_column += std::abs(val); + } + } + } + } + if (c == 1 && sum_of_vals * 8 < xsize * ysize) { + static const double kLimit[4] = { + 0.46, + 0.46, + 0.46, + 0.46, + }; + static const double kMul[4] = { + 0.9999, + 0.9999, + 0.9999, + 0.9999, + }; + const int32_t orig_quant = *quant; + int32_t new_quant = *quant; + for (int i = 1; i < 4; ++i) { + if (hfNonZeros[i] == 0.0 && hfMaxError[i] > kLimit[i]) { + new_quant = orig_quant + 1; + break; + } + } + *quant = new_quant; + if (hfNonZeros[3] == 0.0 && hfMaxError[3] > kLimit[3]) { + thresholds[3] = kMul[3] * hfMaxError[3] * new_quant / orig_quant; + } else if ((hfNonZeros[1] == 0.0 && hfMaxError[1] > kLimit[1]) || + (hfNonZeros[2] == 0.0 && hfMaxError[2] > kLimit[2])) { + thresholds[1] = kMul[1] * std::max(hfMaxError[1], hfMaxError[2]) * + new_quant / orig_quant; + thresholds[2] = thresholds[1]; + } else if (hfNonZeros[0] == 0.0 && hfMaxError[0] > kLimit[0]) { + thresholds[0] = kMul[0] * hfMaxError[0] * new_quant / orig_quant; + } + } + // Heuristic for improving accuracy of high-frequency patterns + // occurring in an environment with no medium-frequency masking + // patterns. + { + float all = + hfNonZeros[0] + hfNonZeros[1] + hfNonZeros[2] + hfNonZeros[3] + 1; + float mul[3] = {70, 30, 60}; + if (mul[c] * sum_of_highest_freq_row_and_column >= all) { + *quant += mul[c] * sum_of_highest_freq_row_and_column / all; + if (*quant >= Quantizer::kQuantMax) { + *quant = Quantizer::kQuantMax - 1; + } + } + } + if (quant_kind == AcStrategy::Type::DCT) { + // If this 8x8 block is too flat, increase the adaptive quantization level + // a bit to reduce visible block boundaries and requantize the block. + if (hfNonZeros[0] + hfNonZeros[1] + hfNonZeros[2] + hfNonZeros[3] < 11) { + *quant += 1; + if (*quant >= Quantizer::kQuantMax) { + *quant = Quantizer::kQuantMax - 1; + } + } + } + { + static const double kMul1[4][3] = { + { + 0.22080615753848404, + 0.45797479824262011, + 0.29859235095977965, + }, + { + 0.70109486510286834, + 0.16185281305512639, + 0.14387691730035473, + }, + { + 0.114985964456218638, + 0.44656840441027695, + 0.10587658215149048, + }, + { + 0.46849665264409396, + 0.41239077937781954, + 0.088667407767185444, + }, + }; + static const double kMul2[4][3] = { + { + 0.27450281941822197, + 1.1255766549984996, + 0.98950459134128388, + }, + { + 0.4652168675598285, + 0.40945807983455818, + 0.36581899811751367, + }, + { + 0.28034972424715715, + 0.9182653201929738, + 1.5581531543057416, + }, + { + 0.26873118114033728, + 0.68863712390392484, + 1.2082185408666786, + }, + }; + static const double kQuantNormalizer = 2.2942708343284721; + sum_of_error *= kQuantNormalizer; + sum_of_vals *= kQuantNormalizer; + if (quant_kind >= AcStrategy::Type::DCT16X16) { + int ix = 3; + if (quant_kind == AcStrategy::Type::DCT32X16 || + quant_kind == AcStrategy::Type::DCT16X32) { + ix = 1; + } else if (quant_kind == AcStrategy::Type::DCT16X16) { + ix = 0; + } else if (quant_kind == AcStrategy::Type::DCT32X32) { + ix = 2; + } + int step = + sum_of_error / (kMul1[ix][c] * xsize * ysize * kBlockDim * kBlockDim + + kMul2[ix][c] * sum_of_vals); + if (step >= 2) { + step = 2; + } + if (step < 0) { + step = 0; + } + if (sum_of_error > kMul1[ix][c] * xsize * ysize * kBlockDim * kBlockDim + + kMul2[ix][c] * sum_of_vals) { + *quant += step; + if (*quant >= Quantizer::kQuantMax) { + *quant = Quantizer::kQuantMax - 1; + } + } + } + } + { + // Reduce quant in highly active areas. + int32_t div = (xsize * ysize); + int32_t activity = (hfNonZeros[0] + div / 2) / div; + int32_t orig_qp_limit = std::max(4, *quant / 2); + for (int i = 1; i < 4; ++i) { + activity = std::min<int32_t>(activity, (hfNonZeros[i] + div / 2) / div); + } + if (activity >= 15) { + activity = 15; + } + int32_t qp = *quant - activity; + if (c == 1) { + for (int i = 1; i < 4; ++i) { + thresholds[i] += 0.01 * activity; + } + } + if (qp < orig_qp_limit) { + qp = orig_qp_limit; + } + *quant = qp; + } +} + +// NOTE: caller takes care of extracting quant from rect of RawQuantField. +void QuantizeRoundtripYBlockAC(PassesEncoderState* enc_state, const size_t size, + const Quantizer& quantizer, + const bool error_diffusion, size_t quant_kind, + size_t xsize, size_t ysize, + const float* JXL_RESTRICT biases, int32_t* quant, + float* JXL_RESTRICT inout, + int32_t* JXL_RESTRICT quantized) { + float thres_y[4] = {0.58f, 0.64f, 0.64f, 0.64f}; + { + int32_t max_quant = 0; + int quant_orig = *quant; + float val[3] = {enc_state->x_qm_multiplier, 1.0f, + enc_state->b_qm_multiplier}; + int clut[3] = {1, 0, 2}; + for (int ii = 0; ii < 3; ++ii) { + float thres[4] = {0.58f, 0.64f, 0.64f, 0.64f}; + int c = clut[ii]; + *quant = quant_orig; + AdjustQuantBlockAC(quantizer, c, val[c], quant_kind, xsize, ysize, + &thres[0], inout + c * size, quant); + // Dead zone adjustment + if (c == 1) { + for (int k = 0; k < 4; ++k) { + thres_y[k] = thres[k]; + } + } + max_quant = std::max(*quant, max_quant); + } + *quant = max_quant; + } + + QuantizeBlockAC(quantizer, error_diffusion, 1, 1.0f, quant_kind, xsize, ysize, + &thres_y[0], inout + size, quant, quantized + size); + + const float* JXL_RESTRICT dequant_matrix = + quantizer.DequantMatrix(quant_kind, 1); + + HWY_CAPPED(float, kDCTBlockSize) df; + HWY_CAPPED(int32_t, kDCTBlockSize) di; + const auto inv_qac = Set(df, quantizer.inv_quant_ac(*quant)); + for (size_t k = 0; k < kDCTBlockSize * xsize * ysize; k += Lanes(df)) { + const auto quant = Load(di, quantized + size + k); + const auto adj_quant = AdjustQuantBias(di, 1, quant, biases); + const auto dequantm = Load(df, dequant_matrix + k); + Store(Mul(Mul(adj_quant, dequantm), inv_qac), df, inout + size + k); + } +} + +void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state, + const Image3F& opsin, const Rect& rect, Image3F* dc) { + const Rect block_group_rect = + enc_state->shared.frame_dim.BlockGroupRect(group_idx); + const Rect cmap_rect( + block_group_rect.x0() / kColorTileDimInBlocks, + block_group_rect.y0() / kColorTileDimInBlocks, + DivCeil(block_group_rect.xsize(), kColorTileDimInBlocks), + DivCeil(block_group_rect.ysize(), kColorTileDimInBlocks)); + const Rect group_rect = + enc_state->shared.frame_dim.GroupRect(group_idx).Translate(rect.x0(), + rect.y0()); + + const size_t xsize_blocks = block_group_rect.xsize(); + const size_t ysize_blocks = block_group_rect.ysize(); + + const size_t dc_stride = static_cast<size_t>(dc->PixelsPerRow()); + const size_t opsin_stride = static_cast<size_t>(opsin.PixelsPerRow()); + + ImageI& full_quant_field = enc_state->shared.raw_quant_field; + const CompressParams& cparams = enc_state->cparams; + + const size_t dct_scratch_size = + 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; + + // TODO(veluca): consider strategies to reduce this memory. + auto mem = hwy::AllocateAligned<int32_t>(3 * AcStrategy::kMaxCoeffArea); + auto fmem = hwy::AllocateAligned<float>(5 * AcStrategy::kMaxCoeffArea + + dct_scratch_size); + float* JXL_RESTRICT scratch_space = + fmem.get() + 3 * AcStrategy::kMaxCoeffArea; + { + // Only use error diffusion in Squirrel mode or slower. + const bool error_diffusion = cparams.speed_tier <= SpeedTier::kSquirrel; + constexpr HWY_CAPPED(float, kDCTBlockSize) d; + + int32_t* JXL_RESTRICT coeffs[3][kMaxNumPasses] = {}; + size_t num_passes = enc_state->progressive_splitter.GetNumPasses(); + JXL_DASSERT(num_passes > 0); + for (size_t i = 0; i < num_passes; i++) { + // TODO(veluca): 16-bit quantized coeffs are not implemented yet. + JXL_ASSERT(enc_state->coeffs[i]->Type() == ACType::k32); + for (size_t c = 0; c < 3; c++) { + coeffs[c][i] = enc_state->coeffs[i]->PlaneRow(c, group_idx, 0).ptr32; + } + } + + HWY_ALIGN float* coeffs_in = fmem.get(); + HWY_ALIGN int32_t* quantized = mem.get(); + + for (size_t by = 0; by < ysize_blocks; ++by) { + int32_t* JXL_RESTRICT row_quant_ac = + block_group_rect.Row(&full_quant_field, by); + size_t ty = by / kColorTileDimInBlocks; + const int8_t* JXL_RESTRICT row_cmap[3] = { + cmap_rect.ConstRow(enc_state->shared.cmap.ytox_map, ty), + nullptr, + cmap_rect.ConstRow(enc_state->shared.cmap.ytob_map, ty), + }; + const float* JXL_RESTRICT opsin_rows[3] = { + group_rect.ConstPlaneRow(opsin, 0, by * kBlockDim), + group_rect.ConstPlaneRow(opsin, 1, by * kBlockDim), + group_rect.ConstPlaneRow(opsin, 2, by * kBlockDim), + }; + float* JXL_RESTRICT dc_rows[3] = { + block_group_rect.PlaneRow(dc, 0, by), + block_group_rect.PlaneRow(dc, 1, by), + block_group_rect.PlaneRow(dc, 2, by), + }; + AcStrategyRow ac_strategy_row = + enc_state->shared.ac_strategy.ConstRow(block_group_rect, by); + for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks); + tx++) { + const auto x_factor = + Set(d, enc_state->shared.cmap.YtoXRatio(row_cmap[0][tx])); + const auto b_factor = + Set(d, enc_state->shared.cmap.YtoBRatio(row_cmap[2][tx])); + for (size_t bx = tx * kColorTileDimInBlocks; + bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks; ++bx) { + const AcStrategy acs = ac_strategy_row[bx]; + if (!acs.IsFirstBlock()) continue; + + size_t xblocks = acs.covered_blocks_x(); + size_t yblocks = acs.covered_blocks_y(); + + CoefficientLayout(&yblocks, &xblocks); + + size_t size = kDCTBlockSize * xblocks * yblocks; + + // DCT Y channel, roundtrip-quantize it and set DC. + int32_t quant_ac = row_quant_ac[bx]; + for (size_t c : {0, 1, 2}) { + TransformFromPixels(acs.Strategy(), opsin_rows[c] + bx * kBlockDim, + opsin_stride, coeffs_in + c * size, + scratch_space); + } + DCFromLowestFrequencies(acs.Strategy(), coeffs_in + size, + dc_rows[1] + bx, dc_stride); + + QuantizeRoundtripYBlockAC( + enc_state, size, enc_state->shared.quantizer, error_diffusion, + acs.RawStrategy(), xblocks, yblocks, kDefaultQuantBias, &quant_ac, + coeffs_in, quantized); + + // Unapply color correlation + for (size_t k = 0; k < size; k += Lanes(d)) { + const auto in_x = Load(d, coeffs_in + k); + const auto in_y = Load(d, coeffs_in + size + k); + const auto in_b = Load(d, coeffs_in + 2 * size + k); + const auto out_x = NegMulAdd(x_factor, in_y, in_x); + const auto out_b = NegMulAdd(b_factor, in_y, in_b); + Store(out_x, d, coeffs_in + k); + Store(out_b, d, coeffs_in + 2 * size + k); + } + + // Quantize X and B channels and set DC. + for (size_t c : {0, 2}) { + float thres[4] = {0.58f, 0.62f, 0.62f, 0.62f}; + QuantizeBlockAC(enc_state->shared.quantizer, error_diffusion, c, + c == 0 ? enc_state->x_qm_multiplier + : enc_state->b_qm_multiplier, + acs.RawStrategy(), xblocks, yblocks, &thres[0], + coeffs_in + c * size, &quant_ac, + quantized + c * size); + DCFromLowestFrequencies(acs.Strategy(), coeffs_in + c * size, + dc_rows[c] + bx, dc_stride); + } + row_quant_ac[bx] = quant_ac; + for (size_t c = 0; c < 3; c++) { + enc_state->progressive_splitter.SplitACCoefficients( + quantized + c * size, acs, bx, by, coeffs[c]); + for (size_t p = 0; p < num_passes; p++) { + coeffs[c][p] += size; + } + } + } + } + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ComputeCoefficients); +void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state, + const Image3F& opsin, const Rect& rect, Image3F* dc) { + return HWY_DYNAMIC_DISPATCH(ComputeCoefficients)(group_idx, enc_state, opsin, + rect, dc); +} + +Status EncodeGroupTokenizedCoefficients(size_t group_idx, size_t pass_idx, + size_t histogram_idx, + const PassesEncoderState& enc_state, + BitWriter* writer, AuxOut* aux_out) { + // Select which histogram to use among those of the current pass. + const size_t num_histograms = enc_state.shared.num_histograms; + // num_histograms is 0 only for lossless. + JXL_ASSERT(num_histograms == 0 || histogram_idx < num_histograms); + size_t histo_selector_bits = CeilLog2Nonzero(num_histograms); + + if (histo_selector_bits != 0) { + BitWriter::Allotment allotment(writer, histo_selector_bits); + writer->Write(histo_selector_bits, histogram_idx); + allotment.ReclaimAndCharge(writer, kLayerAC, aux_out); + } + size_t context_offset = + histogram_idx * enc_state.shared.block_ctx_map.NumACContexts(); + WriteTokens(enc_state.passes[pass_idx].ac_tokens[group_idx], + enc_state.passes[pass_idx].codes, + enc_state.passes[pass_idx].context_map, context_offset, writer, + kLayerACTokens, aux_out); + + return true; +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_group.h b/third_party/jpeg-xl/lib/jxl/enc_group.h new file mode 100644 index 0000000000..78484c2e9b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_group.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_GROUP_H_ +#define LIB_JXL_ENC_GROUP_H_ + +#include <stddef.h> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" + +namespace jxl { + +struct AuxOut; +struct PassesEncoderState; + +// Fills DC +void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state, + const Image3F& opsin, const Rect& rect, Image3F* dc); + +Status EncodeGroupTokenizedCoefficients(size_t group_idx, size_t pass_idx, + size_t histogram_idx, + const PassesEncoderState& enc_state, + BitWriter* writer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_GROUP_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc b/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc new file mode 100644 index 0000000000..9d6bf11184 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_heuristics.cc @@ -0,0 +1,909 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_heuristics.h" + +#include <jxl/cms_interface.h> +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <cstdlib> +#include <limits> +#include <memory> +#include <numeric> +#include <string> +#include <utility> +#include <vector> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_ac_strategy.h" +#include "lib/jxl/enc_adaptive_quantization.h" +#include "lib/jxl/enc_ar_control_field.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_chroma_from_luma.h" +#include "lib/jxl/enc_gaborish.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_noise.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/enc_splines.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/passes_state.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +struct AuxOut; + +void FindBestBlockEntropyModel(const CompressParams& cparams, const ImageI& rqf, + const AcStrategyImage& ac_strategy, + BlockCtxMap* block_ctx_map) { + if (cparams.decoding_speed_tier >= 1) { + static constexpr uint8_t kSimpleCtxMap[] = { + // Cluster all blocks together + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // + }; + static_assert( + 3 * kNumOrders == sizeof(kSimpleCtxMap) / sizeof *kSimpleCtxMap, + "Update simple context map"); + + auto bcm = *block_ctx_map; + bcm.ctx_map.assign(std::begin(kSimpleCtxMap), std::end(kSimpleCtxMap)); + bcm.num_ctxs = 2; + bcm.num_dc_ctxs = 1; + return; + } + if (cparams.speed_tier >= SpeedTier::kFalcon) { + return; + } + // No need to change context modeling for small images. + size_t tot = rqf.xsize() * rqf.ysize(); + size_t size_for_ctx_model = (1 << 10) * cparams.butteraugli_distance; + if (tot < size_for_ctx_model) return; + + struct OccCounters { + // count the occurrences of each qf value and each strategy type. + OccCounters(const ImageI& rqf, const AcStrategyImage& ac_strategy) { + for (size_t y = 0; y < rqf.ysize(); y++) { + const int32_t* qf_row = rqf.Row(y); + AcStrategyRow acs_row = ac_strategy.ConstRow(y); + for (size_t x = 0; x < rqf.xsize(); x++) { + int ord = kStrategyOrder[acs_row[x].RawStrategy()]; + int qf = qf_row[x] - 1; + qf_counts[qf]++; + qf_ord_counts[ord][qf]++; + ord_counts[ord]++; + } + } + } + + size_t qf_counts[256] = {}; + size_t qf_ord_counts[kNumOrders][256] = {}; + size_t ord_counts[kNumOrders] = {}; + }; + // The OccCounters struct is too big to allocate on the stack. + std::unique_ptr<OccCounters> counters(new OccCounters(rqf, ac_strategy)); + + // Splitting the context model according to the quantization field seems to + // mostly benefit only large images. + size_t size_for_qf_split = (1 << 13) * cparams.butteraugli_distance; + size_t num_qf_segments = tot < size_for_qf_split ? 1 : 2; + std::vector<uint32_t>& qft = block_ctx_map->qf_thresholds; + qft.clear(); + // Divide the quant field in up to num_qf_segments segments. + size_t cumsum = 0; + size_t next = 1; + size_t last_cut = 256; + size_t cut = tot * next / num_qf_segments; + for (uint32_t j = 0; j < 256; j++) { + cumsum += counters->qf_counts[j]; + if (cumsum > cut) { + if (j != 0) { + qft.push_back(j); + } + last_cut = j; + while (cumsum > cut) { + next++; + cut = tot * next / num_qf_segments; + } + } else if (next > qft.size() + 1) { + if (j - 1 == last_cut && j != 0) { + qft.push_back(j); + } + } + } + + // Count the occurrences of each segment. + std::vector<size_t> counts(kNumOrders * (qft.size() + 1)); + size_t qft_pos = 0; + for (size_t j = 0; j < 256; j++) { + if (qft_pos < qft.size() && j == qft[qft_pos]) { + qft_pos++; + } + for (size_t i = 0; i < kNumOrders; i++) { + counts[qft_pos + i * (qft.size() + 1)] += counters->qf_ord_counts[i][j]; + } + } + + // Repeatedly merge the lowest-count pair. + std::vector<uint8_t> remap((qft.size() + 1) * kNumOrders); + std::iota(remap.begin(), remap.end(), 0); + std::vector<uint8_t> clusters(remap); + size_t nb_clusters = Clamp1((int)(tot / size_for_ctx_model / 2), 2, 9); + size_t nb_clusters_chroma = Clamp1((int)(tot / size_for_ctx_model / 3), 1, 5); + // This is O(n^2 log n), but n is small. + while (clusters.size() > nb_clusters) { + std::sort(clusters.begin(), clusters.end(), + [&](int a, int b) { return counts[a] > counts[b]; }); + counts[clusters[clusters.size() - 2]] += counts[clusters.back()]; + counts[clusters.back()] = 0; + remap[clusters.back()] = clusters[clusters.size() - 2]; + clusters.pop_back(); + } + for (size_t i = 0; i < remap.size(); i++) { + while (remap[remap[i]] != remap[i]) { + remap[i] = remap[remap[i]]; + } + } + // Relabel starting from 0. + std::vector<uint8_t> remap_remap(remap.size(), remap.size()); + size_t num = 0; + for (size_t i = 0; i < remap.size(); i++) { + if (remap_remap[remap[i]] == remap.size()) { + remap_remap[remap[i]] = num++; + } + remap[i] = remap_remap[remap[i]]; + } + // Write the block context map. + auto& ctx_map = block_ctx_map->ctx_map; + ctx_map = remap; + ctx_map.resize(remap.size() * 3); + // for chroma, only use up to nb_clusters_chroma separate block contexts + // (those for the biggest clusters) + for (size_t i = remap.size(); i < remap.size() * 3; i++) { + ctx_map[i] = num + Clamp1((int)remap[i % remap.size()], 0, + (int)nb_clusters_chroma - 1); + } + block_ctx_map->num_ctxs = + *std::max_element(ctx_map.begin(), ctx_map.end()) + 1; +} + +namespace { + +void FindBestDequantMatrices(const CompressParams& cparams, + ModularFrameEncoder* modular_frame_encoder, + DequantMatrices* dequant_matrices) { + // TODO(veluca): quant matrices for no-gaborish. + // TODO(veluca): heuristics for in-bitstream quant tables. + *dequant_matrices = DequantMatrices(); + if (cparams.max_error_mode) { + // Set numerators of all quantization matrices to constant values. + float weights[3][1] = {{1.0f / cparams.max_error[0]}, + {1.0f / cparams.max_error[1]}, + {1.0f / cparams.max_error[2]}}; + DctQuantWeightParams dct_params(weights); + std::vector<QuantEncoding> encodings(DequantMatrices::kNum, + QuantEncoding::DCT(dct_params)); + DequantMatricesSetCustom(dequant_matrices, encodings, + modular_frame_encoder); + float dc_weights[3] = {1.0f / cparams.max_error[0], + 1.0f / cparams.max_error[1], + 1.0f / cparams.max_error[2]}; + DequantMatricesSetCustomDC(dequant_matrices, dc_weights); + } +} + +void StoreMin2(const float v, float& min1, float& min2) { + if (v < min2) { + if (v < min1) { + min2 = min1; + min1 = v; + } else { + min2 = v; + } + } +} + +void CreateMask(const ImageF& image, ImageF& mask) { + for (size_t y = 0; y < image.ysize(); y++) { + auto* row_n = y > 0 ? image.Row(y - 1) : image.Row(y); + auto* row_in = image.Row(y); + auto* row_s = y + 1 < image.ysize() ? image.Row(y + 1) : image.Row(y); + auto* row_out = mask.Row(y); + for (size_t x = 0; x < image.xsize(); x++) { + // Center, west, east, north, south values and their absolute difference + float c = row_in[x]; + float w = x > 0 ? row_in[x - 1] : row_in[x]; + float e = x + 1 < image.xsize() ? row_in[x + 1] : row_in[x]; + float n = row_n[x]; + float s = row_s[x]; + float dw = std::abs(c - w); + float de = std::abs(c - e); + float dn = std::abs(c - n); + float ds = std::abs(c - s); + float min = std::numeric_limits<float>::max(); + float min2 = std::numeric_limits<float>::max(); + StoreMin2(dw, min, min2); + StoreMin2(de, min, min2); + StoreMin2(dn, min, min2); + StoreMin2(ds, min, min2); + row_out[x] = min2; + } + } +} + +// Downsamples the image by a factor of 2 with a kernel that's sharper than +// the standard 2x2 box kernel used by DownsampleImage. +// The kernel is optimized against the result of the 2x2 upsampling kernel used +// by the decoder. Ringing is slightly reduced by clamping the values of the +// resulting pixels within certain bounds of a small region in the original +// image. +void DownsampleImage2_Sharper(const ImageF& input, ImageF* output) { + const int64_t kernelx = 12; + const int64_t kernely = 12; + + static const float kernel[144] = { + -0.000314256996835, -0.000314256996835, -0.000897597057705, + -0.000562751488849, -0.000176807273646, 0.001864627368902, + 0.001864627368902, -0.000176807273646, -0.000562751488849, + -0.000897597057705, -0.000314256996835, -0.000314256996835, + -0.000314256996835, -0.001527942804748, -0.000121760530512, + 0.000191123989093, 0.010193185932466, 0.058637519197110, + 0.058637519197110, 0.010193185932466, 0.000191123989093, + -0.000121760530512, -0.001527942804748, -0.000314256996835, + -0.000897597057705, -0.000121760530512, 0.000946363683751, + 0.007113577630288, 0.000437956841058, -0.000372823835211, + -0.000372823835211, 0.000437956841058, 0.007113577630288, + 0.000946363683751, -0.000121760530512, -0.000897597057705, + -0.000562751488849, 0.000191123989093, 0.007113577630288, + 0.044592622228814, 0.000222278879007, -0.162864473015945, + -0.162864473015945, 0.000222278879007, 0.044592622228814, + 0.007113577630288, 0.000191123989093, -0.000562751488849, + -0.000176807273646, 0.010193185932466, 0.000437956841058, + 0.000222278879007, -0.000913092543974, -0.017071696107902, + -0.017071696107902, -0.000913092543974, 0.000222278879007, + 0.000437956841058, 0.010193185932466, -0.000176807273646, + 0.001864627368902, 0.058637519197110, -0.000372823835211, + -0.162864473015945, -0.017071696107902, 0.414660099370354, + 0.414660099370354, -0.017071696107902, -0.162864473015945, + -0.000372823835211, 0.058637519197110, 0.001864627368902, + 0.001864627368902, 0.058637519197110, -0.000372823835211, + -0.162864473015945, -0.017071696107902, 0.414660099370354, + 0.414660099370354, -0.017071696107902, -0.162864473015945, + -0.000372823835211, 0.058637519197110, 0.001864627368902, + -0.000176807273646, 0.010193185932466, 0.000437956841058, + 0.000222278879007, -0.000913092543974, -0.017071696107902, + -0.017071696107902, -0.000913092543974, 0.000222278879007, + 0.000437956841058, 0.010193185932466, -0.000176807273646, + -0.000562751488849, 0.000191123989093, 0.007113577630288, + 0.044592622228814, 0.000222278879007, -0.162864473015945, + -0.162864473015945, 0.000222278879007, 0.044592622228814, + 0.007113577630288, 0.000191123989093, -0.000562751488849, + -0.000897597057705, -0.000121760530512, 0.000946363683751, + 0.007113577630288, 0.000437956841058, -0.000372823835211, + -0.000372823835211, 0.000437956841058, 0.007113577630288, + 0.000946363683751, -0.000121760530512, -0.000897597057705, + -0.000314256996835, -0.001527942804748, -0.000121760530512, + 0.000191123989093, 0.010193185932466, 0.058637519197110, + 0.058637519197110, 0.010193185932466, 0.000191123989093, + -0.000121760530512, -0.001527942804748, -0.000314256996835, + -0.000314256996835, -0.000314256996835, -0.000897597057705, + -0.000562751488849, -0.000176807273646, 0.001864627368902, + 0.001864627368902, -0.000176807273646, -0.000562751488849, + -0.000897597057705, -0.000314256996835, -0.000314256996835}; + + int64_t xsize = input.xsize(); + int64_t ysize = input.ysize(); + + ImageF box_downsample(xsize, ysize); + CopyImageTo(input, &box_downsample); + DownsampleImage(&box_downsample, 2); + + ImageF mask(box_downsample.xsize(), box_downsample.ysize()); + CreateMask(box_downsample, mask); + + for (size_t y = 0; y < output->ysize(); y++) { + float* row_out = output->Row(y); + const float* row_in[kernely]; + const float* row_mask = mask.Row(y); + // get the rows in the support + for (size_t ky = 0; ky < kernely; ky++) { + int64_t iy = y * 2 + ky - (kernely - 1) / 2; + if (iy < 0) iy = 0; + if (iy >= ysize) iy = ysize - 1; + row_in[ky] = input.Row(iy); + } + + for (size_t x = 0; x < output->xsize(); x++) { + // get min and max values of the original image in the support + float min = std::numeric_limits<float>::max(); + float max = std::numeric_limits<float>::min(); + // kernelx - R and kernely - R are the radius of a rectangular region in + // which the values of a pixel are bounded to reduce ringing. + static constexpr int64_t R = 5; + for (int64_t ky = R; ky + R < kernely; ky++) { + for (int64_t kx = R; kx + R < kernelx; kx++) { + int64_t ix = x * 2 + kx - (kernelx - 1) / 2; + if (ix < 0) ix = 0; + if (ix >= xsize) ix = xsize - 1; + min = std::min<float>(min, row_in[ky][ix]); + max = std::max<float>(max, row_in[ky][ix]); + } + } + + float sum = 0; + for (int64_t ky = 0; ky < kernely; ky++) { + for (int64_t kx = 0; kx < kernelx; kx++) { + int64_t ix = x * 2 + kx - (kernelx - 1) / 2; + if (ix < 0) ix = 0; + if (ix >= xsize) ix = xsize - 1; + sum += row_in[ky][ix] * kernel[ky * kernelx + kx]; + } + } + + row_out[x] = sum; + + // Clamp the pixel within the value of a small area to prevent ringning. + // The mask determines how much to clamp, clamp more to reduce more + // ringing in smooth areas, clamp less in noisy areas to get more + // sharpness. Higher mask_multiplier gives less clamping, so less + // ringing reduction. + const constexpr float mask_multiplier = 1; + float a = row_mask[x] * mask_multiplier; + float clip_min = min - a; + float clip_max = max + a; + if (row_out[x] < clip_min) { + row_out[x] = clip_min; + } else if (row_out[x] > clip_max) { + row_out[x] = clip_max; + } + } + } +} + +} // namespace + +void DownsampleImage2_Sharper(Image3F* opsin) { + // Allocate extra space to avoid a reallocation when padding. + Image3F downsampled(DivCeil(opsin->xsize(), 2) + kBlockDim, + DivCeil(opsin->ysize(), 2) + kBlockDim); + downsampled.ShrinkTo(downsampled.xsize() - kBlockDim, + downsampled.ysize() - kBlockDim); + + for (size_t c = 0; c < 3; c++) { + DownsampleImage2_Sharper(opsin->Plane(c), &downsampled.Plane(c)); + } + *opsin = std::move(downsampled); +} + +namespace { + +// The default upsampling kernels used by Upsampler in the decoder. +static const constexpr int64_t kSize = 5; + +static const float kernel00[25] = { + -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f, + -0.03452303f, 0.14111091f, 0.28896755f, 0.00278718f, -0.01610267f, + -0.04022174f, 0.28896755f, 0.56661550f, 0.03777607f, -0.01986694f, + -0.02921014f, 0.00278718f, 0.03777607f, -0.03144731f, -0.01185068f, + -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f, +}; +static const float kernel01[25] = { + -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f, + -0.02921014f, 0.00278718f, 0.03777607f, -0.03144731f, -0.01185068f, + -0.04022174f, 0.28896755f, 0.56661550f, 0.03777607f, -0.01986694f, + -0.03452303f, 0.14111091f, 0.28896755f, 0.00278718f, -0.01610267f, + -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f, +}; +static const float kernel10[25] = { + -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f, + -0.01610267f, 0.00278718f, 0.28896755f, 0.14111091f, -0.03452303f, + -0.01986694f, 0.03777607f, 0.56661550f, 0.28896755f, -0.04022174f, + -0.01185068f, -0.03144731f, 0.03777607f, 0.00278718f, -0.02921014f, + -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f, +}; +static const float kernel11[25] = { + -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f, + -0.01185068f, -0.03144731f, 0.03777607f, 0.00278718f, -0.02921014f, + -0.01986694f, 0.03777607f, 0.56661550f, 0.28896755f, -0.04022174f, + -0.01610267f, 0.00278718f, 0.28896755f, 0.14111091f, -0.03452303f, + -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f, +}; + +// Does exactly the same as the Upsampler in dec_upsampler for 2x2 pixels, with +// default CustomTransformData. +// TODO(lode): use Upsampler instead. However, it requires pre-initialization +// and padding on the left side of the image which requires refactoring the +// other code using this. +static void UpsampleImage(const ImageF& input, ImageF* output) { + int64_t xsize = input.xsize(); + int64_t ysize = input.ysize(); + int64_t xsize2 = output->xsize(); + int64_t ysize2 = output->ysize(); + for (int64_t y = 0; y < ysize2; y++) { + for (int64_t x = 0; x < xsize2; x++) { + auto kernel = kernel00; + if ((x & 1) && (y & 1)) { + kernel = kernel11; + } else if (x & 1) { + kernel = kernel10; + } else if (y & 1) { + kernel = kernel01; + } + float sum = 0; + int64_t x2 = x / 2; + int64_t y2 = y / 2; + + // get min and max values of the original image in the support + float min = std::numeric_limits<float>::max(); + float max = std::numeric_limits<float>::min(); + + for (int64_t ky = 0; ky < kSize; ky++) { + for (int64_t kx = 0; kx < kSize; kx++) { + int64_t xi = x2 - kSize / 2 + kx; + int64_t yi = y2 - kSize / 2 + ky; + if (xi < 0) xi = 0; + if (xi >= xsize) xi = input.xsize() - 1; + if (yi < 0) yi = 0; + if (yi >= ysize) yi = input.ysize() - 1; + min = std::min<float>(min, input.Row(yi)[xi]); + max = std::max<float>(max, input.Row(yi)[xi]); + } + } + + for (int64_t ky = 0; ky < kSize; ky++) { + for (int64_t kx = 0; kx < kSize; kx++) { + int64_t xi = x2 - kSize / 2 + kx; + int64_t yi = y2 - kSize / 2 + ky; + if (xi < 0) xi = 0; + if (xi >= xsize) xi = input.xsize() - 1; + if (yi < 0) yi = 0; + if (yi >= ysize) yi = input.ysize() - 1; + sum += input.Row(yi)[xi] * kernel[ky * kSize + kx]; + } + } + output->Row(y)[x] = sum; + if (output->Row(y)[x] < min) output->Row(y)[x] = min; + if (output->Row(y)[x] > max) output->Row(y)[x] = max; + } + } +} + +// Returns the derivative of Upsampler, with respect to input pixel x2, y2, to +// output pixel x, y (ignoring the clamping). +float UpsamplerDeriv(int64_t x2, int64_t y2, int64_t x, int64_t y) { + auto kernel = kernel00; + if ((x & 1) && (y & 1)) { + kernel = kernel11; + } else if (x & 1) { + kernel = kernel10; + } else if (y & 1) { + kernel = kernel01; + } + + int64_t ix = x / 2; + int64_t iy = y / 2; + int64_t kx = x2 - ix + kSize / 2; + int64_t ky = y2 - iy + kSize / 2; + + // This should not happen. + if (kx < 0 || kx >= kSize || ky < 0 || ky >= kSize) return 0; + + return kernel[ky * kSize + kx]; +} + +// Apply the derivative of the Upsampler to the input, reversing the effect of +// its coefficients. The output image is 2x2 times smaller than the input. +void AntiUpsample(const ImageF& input, ImageF* d) { + int64_t xsize = input.xsize(); + int64_t ysize = input.ysize(); + int64_t xsize2 = d->xsize(); + int64_t ysize2 = d->ysize(); + int64_t k0 = kSize - 1; + int64_t k1 = kSize; + for (int64_t y2 = 0; y2 < ysize2; ++y2) { + auto* row = d->Row(y2); + for (int64_t x2 = 0; x2 < xsize2; ++x2) { + int64_t x0 = x2 * 2 - k0; + if (x0 < 0) x0 = 0; + int64_t x1 = x2 * 2 + k1 + 1; + if (x1 > xsize) x1 = xsize; + int64_t y0 = y2 * 2 - k0; + if (y0 < 0) y0 = 0; + int64_t y1 = y2 * 2 + k1 + 1; + if (y1 > ysize) y1 = ysize; + + float sum = 0; + for (int64_t y = y0; y < y1; ++y) { + const auto* row_in = input.Row(y); + for (int64_t x = x0; x < x1; ++x) { + double deriv = UpsamplerDeriv(x2, y2, x, y); + sum += deriv * row_in[x]; + } + } + row[x2] = sum; + } + } +} + +// Element-wise multiplies two images. +template <typename T> +void ElwiseMul(const Plane<T>& image1, const Plane<T>& image2, Plane<T>* out) { + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + JXL_CHECK(xsize == out->xsize()); + JXL_CHECK(ysize == out->ysize()); + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row1[x] * row2[x]; + } + } +} + +// Element-wise divides two images. +template <typename T> +void ElwiseDiv(const Plane<T>& image1, const Plane<T>& image2, Plane<T>* out) { + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + JXL_CHECK(xsize == out->xsize()); + JXL_CHECK(ysize == out->ysize()); + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row1[x] / row2[x]; + } + } +} + +void ReduceRinging(const ImageF& initial, const ImageF& mask, ImageF& down) { + int64_t xsize2 = down.xsize(); + int64_t ysize2 = down.ysize(); + + for (size_t y = 0; y < down.ysize(); y++) { + const float* row_mask = mask.Row(y); + float* row_out = down.Row(y); + for (size_t x = 0; x < down.xsize(); x++) { + float v = down.Row(y)[x]; + float min = initial.Row(y)[x]; + float max = initial.Row(y)[x]; + for (int64_t yi = -1; yi < 2; yi++) { + for (int64_t xi = -1; xi < 2; xi++) { + int64_t x2 = (int64_t)x + xi; + int64_t y2 = (int64_t)y + yi; + if (x2 < 0 || y2 < 0 || x2 >= (int64_t)xsize2 || + y2 >= (int64_t)ysize2) + continue; + min = std::min<float>(min, initial.Row(y2)[x2]); + max = std::max<float>(max, initial.Row(y2)[x2]); + } + } + + row_out[x] = v; + + // Clamp the pixel within the value of a small area to prevent ringning. + // The mask determines how much to clamp, clamp more to reduce more + // ringing in smooth areas, clamp less in noisy areas to get more + // sharpness. Higher mask_multiplier gives less clamping, so less + // ringing reduction. + const constexpr float mask_multiplier = 2; + float a = row_mask[x] * mask_multiplier; + float clip_min = min - a; + float clip_max = max + a; + if (row_out[x] < clip_min) row_out[x] = clip_min; + if (row_out[x] > clip_max) row_out[x] = clip_max; + } + } +} + +// TODO(lode): move this to a separate file enc_downsample.cc +void DownsampleImage2_Iterative(const ImageF& orig, ImageF* output) { + int64_t xsize = orig.xsize(); + int64_t ysize = orig.ysize(); + int64_t xsize2 = DivCeil(orig.xsize(), 2); + int64_t ysize2 = DivCeil(orig.ysize(), 2); + + ImageF box_downsample(xsize, ysize); + CopyImageTo(orig, &box_downsample); + DownsampleImage(&box_downsample, 2); + ImageF mask(box_downsample.xsize(), box_downsample.ysize()); + CreateMask(box_downsample, mask); + + output->ShrinkTo(xsize2, ysize2); + + // Initial result image using the sharper downsampling. + // Allocate extra space to avoid a reallocation when padding. + ImageF initial(DivCeil(orig.xsize(), 2) + kBlockDim, + DivCeil(orig.ysize(), 2) + kBlockDim); + initial.ShrinkTo(initial.xsize() - kBlockDim, initial.ysize() - kBlockDim); + DownsampleImage2_Sharper(orig, &initial); + + ImageF down(initial.xsize(), initial.ysize()); + CopyImageTo(initial, &down); + ImageF up(xsize, ysize); + ImageF corr(xsize, ysize); + ImageF corr2(xsize2, ysize2); + + // In the weights map, relatively higher values will allow less ringing but + // also less sharpness. With all constant values, it optimizes equally + // everywhere. Even in this case, the weights2 computed from + // this is still used and differs at the borders of the image. + // TODO(lode): Make use of the weights field for anti-ringing and clamping, + // the values are all set to 1 for now, but it is intended to be used for + // reducing ringing based on the mask, and taking clamping into account. + ImageF weights(xsize, ysize); + for (size_t y = 0; y < weights.ysize(); y++) { + auto* row = weights.Row(y); + for (size_t x = 0; x < weights.xsize(); x++) { + row[x] = 1; + } + } + ImageF weights2(xsize2, ysize2); + AntiUpsample(weights, &weights2); + + const size_t num_it = 3; + for (size_t it = 0; it < num_it; ++it) { + UpsampleImage(down, &up); + corr = LinComb<float>(1, orig, -1, up); + ElwiseMul(corr, weights, &corr); + AntiUpsample(corr, &corr2); + ElwiseDiv(corr2, weights2, &corr2); + + down = LinComb<float>(1, down, 1, corr2); + } + + ReduceRinging(initial, mask, down); + + // can't just use CopyImage, because the output image was prepared with + // padding. + for (size_t y = 0; y < down.ysize(); y++) { + for (size_t x = 0; x < down.xsize(); x++) { + float v = down.Row(y)[x]; + output->Row(y)[x] = v; + } + } +} + +} // namespace + +void DownsampleImage2_Iterative(Image3F* opsin) { + // Allocate extra space to avoid a reallocation when padding. + Image3F downsampled(DivCeil(opsin->xsize(), 2) + kBlockDim, + DivCeil(opsin->ysize(), 2) + kBlockDim); + downsampled.ShrinkTo(downsampled.xsize() - kBlockDim, + downsampled.ysize() - kBlockDim); + + Image3F rgb(opsin->xsize(), opsin->ysize()); + OpsinParams opsin_params; // TODO(user): use the ones that are actually used + opsin_params.Init(kDefaultIntensityTarget); + OpsinToLinear(*opsin, Rect(rgb), nullptr, &rgb, opsin_params); + + ImageF mask(opsin->xsize(), opsin->ysize()); + ButteraugliParams butter_params; + ButteraugliComparator butter(rgb, butter_params); + butter.Mask(&mask); + ImageF mask_fuzzy(opsin->xsize(), opsin->ysize()); + + for (size_t c = 0; c < 3; c++) { + DownsampleImage2_Iterative(opsin->Plane(c), &downsampled.Plane(c)); + } + *opsin = std::move(downsampled); +} + +Status LossyFrameHeuristics(const FrameHeader& frame_header, + PassesEncoderState* enc_state, + ModularFrameEncoder* modular_frame_encoder, + const Image3F* original_pixels, Image3F* opsin, + const Rect& rect, const JxlCmsInterface& cms, + ThreadPool* pool, AuxOut* aux_out) { + const CompressParams& cparams = enc_state->cparams; + const bool streaming_mode = enc_state->streaming_mode; + const bool initialize_global_state = enc_state->initialize_global_state; + PassesSharedState& shared = enc_state->shared; + const FrameDimensions& frame_dim = shared.frame_dim; + ImageFeatures& image_features = shared.image_features; + DequantMatrices& matrices = shared.matrices; + Quantizer& quantizer = shared.quantizer; + ImageI& raw_quant_field = shared.raw_quant_field; + ColorCorrelationMap& cmap = shared.cmap; + AcStrategyImage& ac_strategy = shared.ac_strategy; + ImageB& epf_sharpness = shared.epf_sharpness; + BlockCtxMap& block_ctx_map = shared.block_ctx_map; + + // Find and subtract splines. + if (!streaming_mode && cparams.speed_tier <= SpeedTier::kSquirrel) { + if (cparams.custom_splines.HasAny()) { + image_features.splines = cparams.custom_splines; + } else { + image_features.splines = FindSplines(*opsin); + } + JXL_RETURN_IF_ERROR(image_features.splines.InitializeDrawCache( + opsin->xsize(), opsin->ysize(), cmap)); + image_features.splines.SubtractFrom(opsin); + } + + // Find and subtract patches/dots. + if (!streaming_mode && + ApplyOverride(cparams.patches, + cparams.speed_tier <= SpeedTier::kSquirrel)) { + FindBestPatchDictionary(*opsin, enc_state, cms, pool, aux_out); + PatchDictionaryEncoder::SubtractFrom(image_features.patches, opsin); + } + + const float quant_dc = InitialQuantDC(cparams.butteraugli_distance); + + // TODO(veluca): we can now run all the code from here to FindBestQuantizer + // (excluded) one rect at a time. Do that. + + // Dependency graph: + // + // input: either XYB or input image + // + // input image -> XYB [optional] + // XYB -> initial quant field + // XYB -> Gaborished XYB + // Gaborished XYB -> CfL1 + // initial quant field, Gaborished XYB, CfL1 -> ACS + // initial quant field, ACS, Gaborished XYB -> EPF control field + // initial quant field -> adjusted initial quant field + // adjusted initial quant field, ACS -> raw quant field + // raw quant field, ACS, Gaborished XYB -> CfL2 + // + // output: Gaborished XYB, CfL, ACS, raw quant field, EPF control field. + + ArControlFieldHeuristics ar_heuristics; + AcStrategyHeuristics acs_heuristics(cparams); + CfLHeuristics cfl_heuristics; + ImageF initial_quant_field; + ImageF initial_quant_masking; + ImageF initial_quant_masking1x1; + + // Compute an initial estimate of the quantization field. + // Call InitialQuantField only in Hare mode or slower. Otherwise, rely + // on simple heuristics in FindBestAcStrategy, or set a constant for Falcon + // mode. + if (cparams.speed_tier > SpeedTier::kHare) { + initial_quant_field = + ImageF(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + initial_quant_masking = + ImageF(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + float q = 0.79 / cparams.butteraugli_distance; + FillImage(q, &initial_quant_field); + FillImage(1.0f / (q + 0.001f), &initial_quant_masking); + quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0); + } else { + // Call this here, as it relies on pre-gaborish values. + float butteraugli_distance_for_iqf = cparams.butteraugli_distance; + if (!frame_header.loop_filter.gab) { + butteraugli_distance_for_iqf *= 0.73f; + } + initial_quant_field = InitialQuantField( + butteraugli_distance_for_iqf, *opsin, rect, pool, 1.0f, + &initial_quant_masking, &initial_quant_masking1x1); + float q = 0.39 / cparams.butteraugli_distance; + quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0); + } + + // TODO(veluca): do something about animations. + + // Apply inverse-gaborish. + if (frame_header.loop_filter.gab) { + // Unsure why better to do some more gaborish on X and B than Y. + float weight[3] = { + 1.0036278514398933f, + 0.99406123118127299f, + 0.99719338015886894f, + }; + GaborishInverse(opsin, rect, weight, pool); + } + + if (initialize_global_state) { + FindBestDequantMatrices(cparams, modular_frame_encoder, &matrices); + } + + cfl_heuristics.Init(rect); + acs_heuristics.Init(*opsin, rect, initial_quant_field, initial_quant_masking, + initial_quant_masking1x1, &matrices); + + auto process_tile = [&](const uint32_t tid, const size_t thread) { + size_t n_enc_tiles = DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks); + size_t tx = tid % n_enc_tiles; + size_t ty = tid / n_enc_tiles; + size_t by0 = ty * kEncTileDimInBlocks; + size_t by1 = + std::min((ty + 1) * kEncTileDimInBlocks, frame_dim.ysize_blocks); + size_t bx0 = tx * kEncTileDimInBlocks; + size_t bx1 = + std::min((tx + 1) * kEncTileDimInBlocks, frame_dim.xsize_blocks); + Rect r(bx0, by0, bx1 - bx0, by1 - by0); + + // For speeds up to Wombat, we only compute the color correlation map + // once we know the transform type and the quantization map. + if (cparams.speed_tier <= SpeedTier::kSquirrel) { + cfl_heuristics.ComputeTile(r, *opsin, rect, matrices, + /*ac_strategy=*/nullptr, + /*raw_quant_field=*/nullptr, + /*quantizer=*/nullptr, /*fast=*/false, thread, + &cmap); + } + + // Choose block sizes. + acs_heuristics.ProcessRect(r, cmap, &ac_strategy); + + // Choose amount of post-processing smoothing. + // TODO(veluca): should this go *after* AdjustQuantField? + ar_heuristics.RunRect(cparams, frame_header, r, *opsin, rect, + initial_quant_field, ac_strategy, &epf_sharpness, + thread); + + // Always set the initial quant field, so we can compute the CfL map with + // more accuracy. The initial quant field might change in slower modes, but + // adjusting the quant field with butteraugli when all the other encoding + // parameters are fixed is likely a more reliable choice anyway. + AdjustQuantField(ac_strategy, r, cparams.butteraugli_distance, + &initial_quant_field); + quantizer.SetQuantFieldRect(initial_quant_field, r, &raw_quant_field); + + // Compute a non-default CfL map if we are at Hare speed, or slower. + if (cparams.speed_tier <= SpeedTier::kHare) { + cfl_heuristics.ComputeTile( + r, *opsin, rect, matrices, &ac_strategy, &raw_quant_field, &quantizer, + /*fast=*/cparams.speed_tier >= SpeedTier::kWombat, thread, &cmap); + } + }; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, + DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks) * + DivCeil(frame_dim.ysize_blocks, kEncTileDimInBlocks), + [&](const size_t num_threads) { + ar_heuristics.PrepareForThreads(num_threads); + cfl_heuristics.PrepareForThreads(num_threads); + return true; + }, + process_tile, "Enc Heuristics")); + + acs_heuristics.Finalize(frame_dim, ac_strategy, aux_out); + + // Refine quantization levels. + if (!streaming_mode) { + FindBestQuantizer(frame_header, original_pixels, *opsin, + initial_quant_field, enc_state, cms, pool, aux_out); + } + + // Choose a context model that depends on the amount of quantization for AC. + if (cparams.speed_tier < SpeedTier::kFalcon && initialize_global_state) { + FindBestBlockEntropyModel(cparams, raw_quant_field, ac_strategy, + &block_ctx_map); + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_heuristics.h b/third_party/jpeg-xl/lib/jxl/enc_heuristics.h new file mode 100644 index 0000000000..14cb596387 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_heuristics.h @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_HEURISTICS_H_ +#define LIB_JXL_ENC_HEURISTICS_H_ + +// Hook for custom encoder heuristics (VarDCT only for now). + +#include <jxl/cms_interface.h> +#include <stddef.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" + +namespace jxl { + +struct AuxOut; +struct PassesEncoderState; +class DequantMatrices; +class ImageBundle; +class ModularFrameEncoder; + +// Initializes encoder structures in `enc_state` using the original image data +// in `original_pixels`, and the XYB image data in `opsin`. Also modifies the +// `opsin` image by applying Gaborish, and doing other modifications if +// necessary. `pool` is used for running the computations on multiple threads. +// `aux_out` collects statistics and can be used to print debug images. +Status LossyFrameHeuristics(const FrameHeader& frame_header, + PassesEncoderState* enc_state, + ModularFrameEncoder* modular_frame_encoder, + const Image3F* original_pixels, Image3F* opsin, + const Rect& rect, const JxlCmsInterface& cms, + ThreadPool* pool, AuxOut* aux_out); + +void FindBestBlockEntropyModel(PassesEncoderState& enc_state); + +void DownsampleImage2_Iterative(Image3F* output); +void DownsampleImage2_Sharper(Image3F* opsin); + +} // namespace jxl + +#endif // LIB_JXL_ENC_HEURISTICS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_huffman.cc b/third_party/jpeg-xl/lib/jxl/enc_huffman.cc new file mode 100644 index 0000000000..3eab2c218a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_huffman.cc @@ -0,0 +1,214 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_huffman.h" + +#include <algorithm> +#include <memory> + +#include "lib/jxl/enc_huffman_tree.h" + +namespace jxl { + +namespace { + +constexpr int kCodeLengthCodes = 18; + +void StoreHuffmanTreeOfHuffmanTreeToBitMask(const int num_codes, + const uint8_t* code_length_bitdepth, + BitWriter* writer) { + static const uint8_t kStorageOrder[kCodeLengthCodes] = { + 1, 2, 3, 4, 0, 5, 17, 6, 16, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + // The bit lengths of the Huffman code over the code length alphabet + // are compressed with the following static Huffman code: + // Symbol Code + // ------ ---- + // 0 00 + // 1 1110 + // 2 110 + // 3 01 + // 4 10 + // 5 1111 + static const uint8_t kHuffmanBitLengthHuffmanCodeSymbols[6] = {0, 7, 3, + 2, 1, 15}; + static const uint8_t kHuffmanBitLengthHuffmanCodeBitLengths[6] = {2, 4, 3, + 2, 2, 4}; + + // Throw away trailing zeros: + size_t codes_to_store = kCodeLengthCodes; + if (num_codes > 1) { + for (; codes_to_store > 0; --codes_to_store) { + if (code_length_bitdepth[kStorageOrder[codes_to_store - 1]] != 0) { + break; + } + } + } + size_t skip_some = 0; // skips none. + if (code_length_bitdepth[kStorageOrder[0]] == 0 && + code_length_bitdepth[kStorageOrder[1]] == 0) { + skip_some = 2; // skips two. + if (code_length_bitdepth[kStorageOrder[2]] == 0) { + skip_some = 3; // skips three. + } + } + writer->Write(2, skip_some); + for (size_t i = skip_some; i < codes_to_store; ++i) { + size_t l = code_length_bitdepth[kStorageOrder[i]]; + writer->Write(kHuffmanBitLengthHuffmanCodeBitLengths[l], + kHuffmanBitLengthHuffmanCodeSymbols[l]); + } +} + +void StoreHuffmanTreeToBitMask(const size_t huffman_tree_size, + const uint8_t* huffman_tree, + const uint8_t* huffman_tree_extra_bits, + const uint8_t* code_length_bitdepth, + const uint16_t* code_length_bitdepth_symbols, + BitWriter* writer) { + for (size_t i = 0; i < huffman_tree_size; ++i) { + size_t ix = huffman_tree[i]; + writer->Write(code_length_bitdepth[ix], code_length_bitdepth_symbols[ix]); + // Extra bits + switch (ix) { + case 16: + writer->Write(2, huffman_tree_extra_bits[i]); + break; + case 17: + writer->Write(3, huffman_tree_extra_bits[i]); + break; + } + } +} + +void StoreSimpleHuffmanTree(const uint8_t* depths, size_t symbols[4], + size_t num_symbols, size_t max_bits, + BitWriter* writer) { + // value of 1 indicates a simple Huffman code + writer->Write(2, 1); + writer->Write(2, num_symbols - 1); // NSYM - 1 + + // Sort + for (size_t i = 0; i < num_symbols; i++) { + for (size_t j = i + 1; j < num_symbols; j++) { + if (depths[symbols[j]] < depths[symbols[i]]) { + std::swap(symbols[j], symbols[i]); + } + } + } + + if (num_symbols == 2) { + writer->Write(max_bits, symbols[0]); + writer->Write(max_bits, symbols[1]); + } else if (num_symbols == 3) { + writer->Write(max_bits, symbols[0]); + writer->Write(max_bits, symbols[1]); + writer->Write(max_bits, symbols[2]); + } else { + writer->Write(max_bits, symbols[0]); + writer->Write(max_bits, symbols[1]); + writer->Write(max_bits, symbols[2]); + writer->Write(max_bits, symbols[3]); + // tree-select + writer->Write(1, depths[symbols[0]] == 1 ? 1 : 0); + } +} + +// num = alphabet size +// depths = symbol depths +void StoreHuffmanTree(const uint8_t* depths, size_t num, BitWriter* writer) { + // Write the Huffman tree into the compact representation. + std::unique_ptr<uint8_t[]> arena(new uint8_t[2 * num]); + uint8_t* huffman_tree = arena.get(); + uint8_t* huffman_tree_extra_bits = arena.get() + num; + size_t huffman_tree_size = 0; + WriteHuffmanTree(depths, num, &huffman_tree_size, huffman_tree, + huffman_tree_extra_bits); + + // Calculate the statistics of the Huffman tree in the compact representation. + uint32_t huffman_tree_histogram[kCodeLengthCodes] = {0}; + for (size_t i = 0; i < huffman_tree_size; ++i) { + ++huffman_tree_histogram[huffman_tree[i]]; + } + + int num_codes = 0; + int code = 0; + for (int i = 0; i < kCodeLengthCodes; ++i) { + if (huffman_tree_histogram[i]) { + if (num_codes == 0) { + code = i; + num_codes = 1; + } else if (num_codes == 1) { + num_codes = 2; + break; + } + } + } + + // Calculate another Huffman tree to use for compressing both the + // earlier Huffman tree with. + uint8_t code_length_bitdepth[kCodeLengthCodes] = {0}; + uint16_t code_length_bitdepth_symbols[kCodeLengthCodes] = {0}; + CreateHuffmanTree(&huffman_tree_histogram[0], kCodeLengthCodes, 5, + &code_length_bitdepth[0]); + ConvertBitDepthsToSymbols(code_length_bitdepth, kCodeLengthCodes, + &code_length_bitdepth_symbols[0]); + + // Now, we have all the data, let's start storing it + StoreHuffmanTreeOfHuffmanTreeToBitMask(num_codes, code_length_bitdepth, + writer); + + if (num_codes == 1) { + code_length_bitdepth[code] = 0; + } + + // Store the real huffman tree now. + StoreHuffmanTreeToBitMask(huffman_tree_size, huffman_tree, + huffman_tree_extra_bits, &code_length_bitdepth[0], + code_length_bitdepth_symbols, writer); +} + +} // namespace + +void BuildAndStoreHuffmanTree(const uint32_t* histogram, const size_t length, + uint8_t* depth, uint16_t* bits, + BitWriter* writer) { + size_t count = 0; + size_t s4[4] = {0}; + for (size_t i = 0; i < length; i++) { + if (histogram[i]) { + if (count < 4) { + s4[count] = i; + } else if (count > 4) { + break; + } + count++; + } + } + + size_t max_bits_counter = length - 1; + size_t max_bits = 0; + while (max_bits_counter) { + max_bits_counter >>= 1; + ++max_bits; + } + + if (count <= 1) { + // Output symbol bits and depths are initialized with 0, nothing to do. + writer->Write(4, 1); + writer->Write(max_bits, s4[0]); + return; + } + + CreateHuffmanTree(histogram, length, 15, depth); + ConvertBitDepthsToSymbols(depth, length, bits); + + if (count <= 4) { + StoreSimpleHuffmanTree(depth, s4, count, max_bits, writer); + } else { + StoreHuffmanTree(depth, length, writer); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_huffman.h b/third_party/jpeg-xl/lib/jxl/enc_huffman.h new file mode 100644 index 0000000000..d7a66584e8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_huffman.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_HUFFMAN_H_ +#define LIB_JXL_ENC_HUFFMAN_H_ + +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +// Builds a Huffman tree for the given histogram, and encodes it into writer +// in a format that can be read by HuffmanDecodingData::ReadFromBitstream. +// An allotment for `writer` must already have been created by the caller. +void BuildAndStoreHuffmanTree(const uint32_t* histogram, size_t length, + uint8_t* depth, uint16_t* bits, + BitWriter* writer); + +} // namespace jxl + +#endif // LIB_JXL_ENC_HUFFMAN_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_huffman_tree.cc b/third_party/jpeg-xl/lib/jxl/enc_huffman_tree.cc new file mode 100644 index 0000000000..5c40dea770 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_huffman_tree.cc @@ -0,0 +1,328 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_huffman_tree.h" + +#include <algorithm> +#include <limits> +#include <vector> + +#include "lib/jxl/base/status.h" + +namespace jxl { + +void SetDepth(const HuffmanTree& p, HuffmanTree* pool, uint8_t* depth, + uint8_t level) { + if (p.index_left >= 0) { + ++level; + SetDepth(pool[p.index_left], pool, depth, level); + SetDepth(pool[p.index_right_or_value], pool, depth, level); + } else { + depth[p.index_right_or_value] = level; + } +} + +// Sort the root nodes, least popular first. +static JXL_INLINE bool Compare(const HuffmanTree& v0, const HuffmanTree& v1) { + return v0.total_count < v1.total_count; +} + +// This function will create a Huffman tree. +// +// The catch here is that the tree cannot be arbitrarily deep. +// Brotli specifies a maximum depth of 15 bits for "code trees" +// and 7 bits for "code length code trees." +// +// count_limit is the value that is to be faked as the minimum value +// and this minimum value is raised until the tree matches the +// maximum length requirement. +// +// This algorithm is not of excellent performance for very long data blocks, +// especially when population counts are longer than 2**tree_limit, but +// we are not planning to use this with extremely long blocks. +// +// See http://en.wikipedia.org/wiki/Huffman_coding +void CreateHuffmanTree(const uint32_t* data, const size_t length, + const int tree_limit, uint8_t* depth) { + // For block sizes below 64 kB, we never need to do a second iteration + // of this loop. Probably all of our block sizes will be smaller than + // that, so this loop is mostly of academic interest. If we actually + // would need this, we would be better off with the Katajainen algorithm. + for (uint32_t count_limit = 1;; count_limit *= 2) { + std::vector<HuffmanTree> tree; + tree.reserve(2 * length + 1); + + for (size_t i = length; i != 0;) { + --i; + if (data[i]) { + const uint32_t count = std::max(data[i], count_limit - 1); + tree.emplace_back(count, -1, static_cast<int16_t>(i)); + } + } + + const size_t n = tree.size(); + if (n == 1) { + // Fake value; will be fixed on upper level. + depth[tree[0].index_right_or_value] = 1; + break; + } + + std::stable_sort(tree.begin(), tree.end(), Compare); + + // The nodes are: + // [0, n): the sorted leaf nodes that we start with. + // [n]: we add a sentinel here. + // [n + 1, 2n): new parent nodes are added here, starting from + // (n+1). These are naturally in ascending order. + // [2n]: we add a sentinel at the end as well. + // There will be (2n+1) elements at the end. + const HuffmanTree sentinel(std::numeric_limits<uint32_t>::max(), -1, -1); + tree.push_back(sentinel); + tree.push_back(sentinel); + + size_t i = 0; // Points to the next leaf node. + size_t j = n + 1; // Points to the next non-leaf node. + for (size_t k = n - 1; k != 0; --k) { + size_t left, right; + if (tree[i].total_count <= tree[j].total_count) { + left = i; + ++i; + } else { + left = j; + ++j; + } + if (tree[i].total_count <= tree[j].total_count) { + right = i; + ++i; + } else { + right = j; + ++j; + } + + // The sentinel node becomes the parent node. + size_t j_end = tree.size() - 1; + tree[j_end].total_count = + tree[left].total_count + tree[right].total_count; + tree[j_end].index_left = static_cast<int16_t>(left); + tree[j_end].index_right_or_value = static_cast<int16_t>(right); + + // Add back the last sentinel node. + tree.push_back(sentinel); + } + JXL_DASSERT(tree.size() == 2 * n + 1); + SetDepth(tree[2 * n - 1], &tree[0], depth, 0); + + // We need to pack the Huffman tree in tree_limit bits. + // If this was not successful, add fake entities to the lowest values + // and retry. + if (*std::max_element(&depth[0], &depth[length]) <= tree_limit) { + break; + } + } +} + +void Reverse(uint8_t* v, size_t start, size_t end) { + --end; + while (start < end) { + uint8_t tmp = v[start]; + v[start] = v[end]; + v[end] = tmp; + ++start; + --end; + } +} + +void WriteHuffmanTreeRepetitions(const uint8_t previous_value, + const uint8_t value, size_t repetitions, + size_t* tree_size, uint8_t* tree, + uint8_t* extra_bits_data) { + JXL_DASSERT(repetitions > 0); + if (previous_value != value) { + tree[*tree_size] = value; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + --repetitions; + } + if (repetitions == 7) { + tree[*tree_size] = value; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + --repetitions; + } + if (repetitions < 3) { + for (size_t i = 0; i < repetitions; ++i) { + tree[*tree_size] = value; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + } + } else { + repetitions -= 3; + size_t start = *tree_size; + while (true) { + tree[*tree_size] = 16; + extra_bits_data[*tree_size] = repetitions & 0x3; + ++(*tree_size); + repetitions >>= 2; + if (repetitions == 0) { + break; + } + --repetitions; + } + Reverse(tree, start, *tree_size); + Reverse(extra_bits_data, start, *tree_size); + } +} + +void WriteHuffmanTreeRepetitionsZeros(size_t repetitions, size_t* tree_size, + uint8_t* tree, uint8_t* extra_bits_data) { + if (repetitions == 11) { + tree[*tree_size] = 0; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + --repetitions; + } + if (repetitions < 3) { + for (size_t i = 0; i < repetitions; ++i) { + tree[*tree_size] = 0; + extra_bits_data[*tree_size] = 0; + ++(*tree_size); + } + } else { + repetitions -= 3; + size_t start = *tree_size; + while (true) { + tree[*tree_size] = 17; + extra_bits_data[*tree_size] = repetitions & 0x7; + ++(*tree_size); + repetitions >>= 3; + if (repetitions == 0) { + break; + } + --repetitions; + } + Reverse(tree, start, *tree_size); + Reverse(extra_bits_data, start, *tree_size); + } +} + +static void DecideOverRleUse(const uint8_t* depth, const size_t length, + bool* use_rle_for_non_zero, + bool* use_rle_for_zero) { + size_t total_reps_zero = 0; + size_t total_reps_non_zero = 0; + size_t count_reps_zero = 1; + size_t count_reps_non_zero = 1; + for (size_t i = 0; i < length;) { + const uint8_t value = depth[i]; + size_t reps = 1; + for (size_t k = i + 1; k < length && depth[k] == value; ++k) { + ++reps; + } + if (reps >= 3 && value == 0) { + total_reps_zero += reps; + ++count_reps_zero; + } + if (reps >= 4 && value != 0) { + total_reps_non_zero += reps; + ++count_reps_non_zero; + } + i += reps; + } + *use_rle_for_non_zero = total_reps_non_zero > count_reps_non_zero * 2; + *use_rle_for_zero = total_reps_zero > count_reps_zero * 2; +} + +void WriteHuffmanTree(const uint8_t* depth, size_t length, size_t* tree_size, + uint8_t* tree, uint8_t* extra_bits_data) { + uint8_t previous_value = 8; + + // Throw away trailing zeros. + size_t new_length = length; + for (size_t i = 0; i < length; ++i) { + if (depth[length - i - 1] == 0) { + --new_length; + } else { + break; + } + } + + // First gather statistics on if it is a good idea to do rle. + bool use_rle_for_non_zero = false; + bool use_rle_for_zero = false; + if (length > 50) { + // Find rle coding for longer codes. + // Shorter codes seem not to benefit from rle. + DecideOverRleUse(depth, new_length, &use_rle_for_non_zero, + &use_rle_for_zero); + } + + // Actual rle coding. + for (size_t i = 0; i < new_length;) { + const uint8_t value = depth[i]; + size_t reps = 1; + if ((value != 0 && use_rle_for_non_zero) || + (value == 0 && use_rle_for_zero)) { + for (size_t k = i + 1; k < new_length && depth[k] == value; ++k) { + ++reps; + } + } + if (value == 0) { + WriteHuffmanTreeRepetitionsZeros(reps, tree_size, tree, extra_bits_data); + } else { + WriteHuffmanTreeRepetitions(previous_value, value, reps, tree_size, tree, + extra_bits_data); + previous_value = value; + } + i += reps; + } +} + +namespace { + +uint16_t ReverseBits(int num_bits, uint16_t bits) { + static const size_t kLut[16] = {// Pre-reversed 4-bit values. + 0x0, 0x8, 0x4, 0xc, 0x2, 0xa, 0x6, 0xe, + 0x1, 0x9, 0x5, 0xd, 0x3, 0xb, 0x7, 0xf}; + size_t retval = kLut[bits & 0xf]; + for (int i = 4; i < num_bits; i += 4) { + retval <<= 4; + bits = static_cast<uint16_t>(bits >> 4); + retval |= kLut[bits & 0xf]; + } + retval >>= (-num_bits & 0x3); + return static_cast<uint16_t>(retval); +} + +} // namespace + +void ConvertBitDepthsToSymbols(const uint8_t* depth, size_t len, + uint16_t* bits) { + // In Brotli, all bit depths are [1..15] + // 0 bit depth means that the symbol does not exist. + const int kMaxBits = 16; // 0..15 are values for bits + uint16_t bl_count[kMaxBits] = {0}; + { + for (size_t i = 0; i < len; ++i) { + ++bl_count[depth[i]]; + } + bl_count[0] = 0; + } + uint16_t next_code[kMaxBits]; + next_code[0] = 0; + { + int code = 0; + for (size_t i = 1; i < kMaxBits; ++i) { + code = (code + bl_count[i - 1]) << 1; + next_code[i] = static_cast<uint16_t>(code); + } + } + for (size_t i = 0; i < len; ++i) { + if (depth[i]) { + bits[i] = ReverseBits(depth[i], next_code[depth[i]]++); + } + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_huffman_tree.h b/third_party/jpeg-xl/lib/jxl/enc_huffman_tree.h new file mode 100644 index 0000000000..7d716cd3b5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_huffman_tree.h @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Library for creating Huffman codes from population counts. + +#ifndef LIB_JXL_HUFFMAN_TREE_H_ +#define LIB_JXL_HUFFMAN_TREE_H_ + +#include <stdint.h> +#include <stdlib.h> + +namespace jxl { + +// A node of a Huffman tree. +struct HuffmanTree { + HuffmanTree(uint32_t count, int16_t left, int16_t right) + : total_count(count), index_left(left), index_right_or_value(right) {} + uint32_t total_count; + int16_t index_left; + int16_t index_right_or_value; +}; + +void SetDepth(const HuffmanTree& p, HuffmanTree* pool, uint8_t* depth, + uint8_t level); + +// This function will create a Huffman tree. +// +// The (data,length) contains the population counts. +// The tree_limit is the maximum bit depth of the Huffman codes. +// +// The depth contains the tree, i.e., how many bits are used for +// the symbol. +// +// See http://en.wikipedia.org/wiki/Huffman_coding +void CreateHuffmanTree(const uint32_t* data, size_t length, int tree_limit, + uint8_t* depth); + +// Write a Huffman tree from bit depths into the bitstream representation +// of a Huffman tree. The generated Huffman tree is to be compressed once +// more using a Huffman tree +void WriteHuffmanTree(const uint8_t* depth, size_t length, size_t* tree_size, + uint8_t* tree, uint8_t* extra_bits_data); + +// Get the actual bit values for a tree of bit depths. +void ConvertBitDepthsToSymbols(const uint8_t* depth, size_t len, + uint16_t* bits); + +} // namespace jxl + +#endif // LIB_JXL_HUFFMAN_TREE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_icc_codec.cc b/third_party/jpeg-xl/lib/jxl/enc_icc_codec.cc new file mode 100644 index 0000000000..8e92fe3452 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_icc_codec.cc @@ -0,0 +1,447 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_icc_codec.h" + +#include <stdint.h> + +#include <limits> +#include <map> +#include <string> +#include <vector> + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/icc_codec_common.h" +#include "lib/jxl/padded_bytes.h" + +namespace jxl { +namespace { + +// Unshuffles or de-interleaves bytes, for example with width 2, turns +// "AaBbCcDc" into "ABCDabcd", this for example de-interleaves UTF-16 bytes into +// first all the high order bytes, then all the low order bytes. +// Transposes a matrix of width columns and ceil(size / width) rows. There are +// size elements, size may be < width * height, if so the +// last elements of the bottom row are missing, the missing spots are +// transposed along with the filled spots, and the result has the missing +// elements at the bottom of the rightmost column. The input is the input matrix +// in scanline order, the output is the result matrix in scanline order, with +// missing elements skipped over (this may occur at multiple positions). +void Unshuffle(uint8_t* data, size_t size, size_t width) { + size_t height = (size + width - 1) / width; // amount of rows of input + PaddedBytes result(size); + // i = input index, j output index + size_t s = 0, j = 0; + for (size_t i = 0; i < size; i++) { + result[j] = data[i]; + j += height; + if (j >= size) j = ++s; + } + + for (size_t i = 0; i < size; i++) { + data[i] = result[i]; + } +} + +// This is performed by the encoder, the encoder must be able to encode any +// random byte stream (not just byte streams that are a valid ICC profile), so +// an error returned by this function is an implementation error. +Status PredictAndShuffle(size_t stride, size_t width, int order, size_t num, + const uint8_t* data, size_t size, size_t* pos, + PaddedBytes* result) { + JXL_RETURN_IF_ERROR(CheckOutOfBounds(*pos, num, size)); + // Required by the specification, see decoder. stride * 4 must be < *pos. + if (!*pos || ((*pos - 1u) >> 2u) < stride) { + return JXL_FAILURE("Invalid stride"); + } + if (*pos < stride * 4) return JXL_FAILURE("Too large stride"); + size_t start = result->size(); + for (size_t i = 0; i < num; i++) { + uint8_t predicted = + LinearPredictICCValue(data, *pos, i, stride, width, order); + result->push_back(data[*pos + i] - predicted); + } + *pos += num; + if (width > 1) Unshuffle(result->data() + start, num, width); + return true; +} + +static inline void EncodeVarInt(uint64_t value, PaddedBytes* data) { + size_t pos = data->size(); + data->resize(data->size() + 9); + size_t output_size = data->size(); + uint8_t* output = data->data(); + + // While more than 7 bits of data are left, + // store 7 bits and set the next byte flag + while (value > 127) { + // TODO(eustas): should it be `<` ? + JXL_CHECK(pos <= output_size); + // |128: Set the next byte flag + output[pos++] = ((uint8_t)(value & 127)) | 128; + // Remove the seven bits we just wrote + value >>= 7; + } + // TODO(eustas): should it be `<` ? + JXL_CHECK(pos <= output_size); + output[pos++] = ((uint8_t)value) & 127; + + data->resize(pos); +} + +constexpr size_t kSizeLimit = std::numeric_limits<uint32_t>::max() >> 2; + +} // namespace + +// Outputs a transformed form of the given icc profile. The result itself is +// not particularly smaller than the input data in bytes, but it will be in a +// form that is easier to compress (more zeroes, ...) and will compress better +// with brotli. +Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) { + PaddedBytes commands; + PaddedBytes data; + + static_assert(sizeof(size_t) >= 4, "size_t is too short"); + // Fuzzer expects that PredictICC can accept any input, + // but 1GB should be enough for any purpose. + if (size > kSizeLimit) { + return JXL_FAILURE("ICC profile is too large"); + } + + EncodeVarInt(size, result); + + // Header + PaddedBytes header; + header.append(ICCInitialHeaderPrediction()); + EncodeUint32(0, size, &header); + for (size_t i = 0; i < kICCHeaderSize && i < size; i++) { + ICCPredictHeader(icc, size, header.data(), i); + data.push_back(icc[i] - header[i]); + } + if (size <= kICCHeaderSize) { + EncodeVarInt(0, result); // 0 commands + for (size_t i = 0; i < data.size(); i++) { + result->push_back(data[i]); + } + return true; + } + + std::vector<Tag> tags; + std::vector<size_t> tagstarts; + std::vector<size_t> tagsizes; + std::map<size_t, size_t> tagmap; + + // Tag list + size_t pos = kICCHeaderSize; + if (pos + 4 <= size) { + uint64_t numtags = DecodeUint32(icc, size, pos); + pos += 4; + EncodeVarInt(numtags + 1, &commands); + uint64_t prevtagstart = kICCHeaderSize + numtags * 12; + uint32_t prevtagsize = 0; + for (size_t i = 0; i < numtags; i++) { + if (pos + 12 > size) break; + + Tag tag = DecodeKeyword(icc, size, pos + 0); + uint32_t tagstart = DecodeUint32(icc, size, pos + 4); + uint32_t tagsize = DecodeUint32(icc, size, pos + 8); + pos += 12; + + tags.push_back(tag); + tagstarts.push_back(tagstart); + tagsizes.push_back(tagsize); + tagmap[tagstart] = tags.size() - 1; + + uint8_t tagcode = kCommandTagUnknown; + for (size_t j = 0; j < kNumTagStrings; j++) { + if (tag == *kTagStrings[j]) { + tagcode = j + kCommandTagStringFirst; + break; + } + } + + if (tag == kRtrcTag && pos + 24 < size) { + bool ok = true; + ok &= DecodeKeyword(icc, size, pos + 0) == kGtrcTag; + ok &= DecodeKeyword(icc, size, pos + 12) == kBtrcTag; + if (ok) { + for (size_t kk = 0; kk < 8; kk++) { + if (icc[pos - 8 + kk] != icc[pos + 4 + kk]) ok = false; + if (icc[pos - 8 + kk] != icc[pos + 16 + kk]) ok = false; + } + } + if (ok) { + tagcode = kCommandTagTRC; + pos += 24; + i += 2; + } + } + + if (tag == kRxyzTag && pos + 24 < size) { + bool ok = true; + ok &= DecodeKeyword(icc, size, pos + 0) == kGxyzTag; + ok &= DecodeKeyword(icc, size, pos + 12) == kBxyzTag; + uint32_t offsetr = tagstart; + uint32_t offsetg = DecodeUint32(icc, size, pos + 4); + uint32_t offsetb = DecodeUint32(icc, size, pos + 16); + uint32_t sizer = tagsize; + uint32_t sizeg = DecodeUint32(icc, size, pos + 8); + uint32_t sizeb = DecodeUint32(icc, size, pos + 20); + ok &= sizer == 20; + ok &= sizeg == 20; + ok &= sizeb == 20; + ok &= (offsetg == offsetr + 20); + ok &= (offsetb == offsetr + 40); + if (ok) { + tagcode = kCommandTagXYZ; + pos += 24; + i += 2; + } + } + + uint8_t command = tagcode; + uint64_t predicted_tagstart = prevtagstart + prevtagsize; + if (predicted_tagstart != tagstart) command |= kFlagBitOffset; + size_t predicted_tagsize = prevtagsize; + if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag || + tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag || + tag == kLumiTag) { + predicted_tagsize = 20; + } + if (predicted_tagsize != tagsize) command |= kFlagBitSize; + commands.push_back(command); + if (tagcode == 1) { + AppendKeyword(tag, &data); + } + if (command & kFlagBitOffset) EncodeVarInt(tagstart, &commands); + if (command & kFlagBitSize) EncodeVarInt(tagsize, &commands); + + prevtagstart = tagstart; + prevtagsize = tagsize; + } + } + // Indicate end of tag list or varint indicating there's none + commands.push_back(0); + + // Main content + // The main content in a valid ICC profile contains tagged elements, with the + // tag types (4 letter names) given by the tag list above, and the tag list + // pointing to the start and indicating the size of each tagged element. It is + // allowed for tagged elements to overlap, e.g. the curve for R, G and B could + // all point to the same one. + Tag tag; + size_t tagstart = 0, tagsize = 0, clutstart = 0; + + // Should always check tag_sane before doing math with tagsize. + const auto tag_sane = [&tagsize]() { + return (tagsize > 8) && (tagsize < kSizeLimit); + }; + + size_t last0 = pos; + // This loop appends commands to the output, processing some sub-section of a + // current tagged element each time. We need to keep track of the tagtype of + // the current element, and update it when we encounter the boundary of a + // next one. + // It is not required that the input data is a valid ICC profile, if the + // encoder does not recognize the data it will still be able to output bytes + // but will not predict as well. + while (pos <= size) { + size_t last1 = pos; + PaddedBytes commands_add; + PaddedBytes data_add; + + // This means the loop brought the position beyond the tag end. + // If tagsize is nonsensical, any pos looks "ok-ish". + if ((pos > tagstart + tagsize) && (tagsize < kSizeLimit)) { + tag = {{0, 0, 0, 0}}; // nonsensical value + } + + if (commands_add.empty() && data_add.empty() && tagmap.count(pos) && + pos + 4 <= size) { + size_t index = tagmap[pos]; + tag = DecodeKeyword(icc, size, pos); + tagstart = tagstarts[index]; + tagsize = tagsizes[index]; + + if (tag == kMlucTag && tag_sane() && pos + tagsize <= size && + icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 && + icc[pos + 7] == 0) { + size_t num = tagsize - 8; + commands_add.push_back(kCommandTypeStartFirst + 3); + pos += 8; + commands_add.push_back(kCommandShuffle2); + EncodeVarInt(num, &commands_add); + size_t start = data_add.size(); + for (size_t i = 0; i < num; i++) { + data_add.push_back(icc[pos]); + pos++; + } + Unshuffle(data_add.data() + start, num, 2); + } + + if (tag == kCurvTag && tag_sane() && pos + tagsize <= size && + icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 && + icc[pos + 7] == 0) { + size_t num = tagsize - 8; + if (num > 16 && num < (1 << 28) && pos + num <= size && pos > 0) { + commands_add.push_back(kCommandTypeStartFirst + 5); + pos += 8; + commands_add.push_back(kCommandPredict); + int order = 1, width = 2, stride = width; + commands_add.push_back((order << 2) | (width - 1)); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + } + } + + if (tag == kMab_Tag || tag == kMba_Tag) { + Tag subTag = DecodeKeyword(icc, size, pos); + if (pos + 12 < size && (subTag == kCurvTag || subTag == kVcgtTag) && + DecodeUint32(icc, size, pos + 4) == 0) { + uint32_t num = DecodeUint32(icc, size, pos + 8) * 2; + if (num > 16 && num < (1 << 28) && pos + 12 + num <= size) { + pos += 12; + last1 = pos; + commands_add.push_back(kCommandPredict); + int order = 1, width = 2, stride = width; + commands_add.push_back((order << 2) | (width - 1)); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + } + + if (pos == tagstart + 24 && pos + 4 < size) { + // Note that this value can be remembered for next iterations of the + // loop, so the "pos == clutstart" if below can trigger during a later + // iteration. + clutstart = tagstart + DecodeUint32(icc, size, pos); + } + + if (pos == clutstart && clutstart + 16 < size) { + size_t numi = icc[tagstart + 8]; + size_t numo = icc[tagstart + 9]; + size_t width = icc[clutstart + 16]; + size_t stride = width * numo; + size_t num = width * numo; + for (size_t i = 0; i < numi && clutstart + i < size; i++) { + num *= icc[clutstart + i]; + } + if ((width == 1 || width == 2) && num > 64 && num < (1 << 28) && + pos + num <= size && pos > stride * 4) { + commands_add.push_back(kCommandPredict); + int order = 1; + uint8_t flags = + (order << 2) | (width - 1) | (stride == width ? 0 : 16); + commands_add.push_back(flags); + if (flags & 16) EncodeVarInt(stride, &commands_add); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + } + } + + if (commands_add.empty() && data_add.empty() && tag == kGbd_Tag && + tag_sane() && pos == tagstart + 8 && pos + tagsize - 8 <= size && + pos > 16) { + size_t width = 4, order = 0, stride = width; + size_t num = tagsize - 8; + uint8_t flags = (order << 2) | (width - 1) | (stride == width ? 0 : 16); + commands_add.push_back(kCommandPredict); + commands_add.push_back(flags); + if (flags & 16) EncodeVarInt(stride, &commands_add); + EncodeVarInt(num, &commands_add); + JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc, + size, &pos, &data_add)); + } + + if (commands_add.empty() && data_add.empty() && pos + 20 <= size) { + Tag subTag = DecodeKeyword(icc, size, pos); + if (subTag == kXyz_Tag && DecodeUint32(icc, size, pos + 4) == 0) { + commands_add.push_back(kCommandXYZ); + pos += 8; + for (size_t j = 0; j < 12; j++) data_add.push_back(icc[pos++]); + } + } + + if (commands_add.empty() && data_add.empty() && pos + 8 <= size) { + if (DecodeUint32(icc, size, pos + 4) == 0) { + Tag subTag = DecodeKeyword(icc, size, pos); + for (size_t i = 0; i < kNumTypeStrings; i++) { + if (subTag == *kTypeStrings[i]) { + commands_add.push_back(kCommandTypeStartFirst + i); + pos += 8; + break; + } + } + } + } + + if (!(commands_add.empty() && data_add.empty()) || pos == size) { + if (last0 < last1) { + commands.push_back(kCommandInsert); + EncodeVarInt(last1 - last0, &commands); + while (last0 < last1) { + data.push_back(icc[last0++]); + } + } + for (size_t i = 0; i < commands_add.size(); i++) { + commands.push_back(commands_add[i]); + } + for (size_t i = 0; i < data_add.size(); i++) { + data.push_back(data_add[i]); + } + last0 = pos; + } + if (commands_add.empty() && data_add.empty()) { + pos++; + } + } + + EncodeVarInt(commands.size(), result); + for (size_t i = 0; i < commands.size(); i++) { + result->push_back(commands[i]); + } + for (size_t i = 0; i < data.size(); i++) { + result->push_back(data[i]); + } + + return true; +} + +Status WriteICC(const IccBytes& icc, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* JXL_RESTRICT aux_out) { + if (icc.empty()) return JXL_FAILURE("ICC must be non-empty"); + PaddedBytes enc; + JXL_RETURN_IF_ERROR(PredictICC(icc.data(), icc.size(), &enc)); + std::vector<std::vector<Token>> tokens(1); + BitWriter::Allotment allotment(writer, 128); + JXL_RETURN_IF_ERROR(U64Coder::Write(enc.size(), writer)); + allotment.ReclaimAndCharge(writer, layer, aux_out); + + for (size_t i = 0; i < enc.size(); i++) { + tokens[0].emplace_back( + ICCANSContext(i, i > 0 ? enc[i - 1] : 0, i > 1 ? enc[i - 2] : 0), + enc[i]); + } + HistogramParams params; + params.lz77_method = enc.size() < 4096 ? HistogramParams::LZ77Method::kOptimal + : HistogramParams::LZ77Method::kLZ77; + EntropyEncodingData code; + std::vector<uint8_t> context_map; + params.force_huffman = true; + BuildAndEncodeHistograms(params, kNumICCContexts, tokens, &code, &context_map, + writer, layer, aux_out); + WriteTokens(tokens[0], code, context_map, 0, writer, layer, aux_out); + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_icc_codec.h b/third_party/jpeg-xl/lib/jxl/enc_icc_codec.h new file mode 100644 index 0000000000..224c2e5316 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_icc_codec.h @@ -0,0 +1,35 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_ICC_CODEC_H_ +#define LIB_JXL_ENC_ICC_CODEC_H_ + +// Compressed representation of ICC profiles. + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +struct AuxOut; +class PaddedBytes; + +// Should still be called if `icc.empty()` - if so, writes only 1 bit. +Status WriteICC(const std::vector<uint8_t>& icc, BitWriter* JXL_RESTRICT writer, + size_t layer, AuxOut* JXL_RESTRICT aux_out); + +// Exposed only for testing +Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result); + +} // namespace jxl + +#endif // LIB_JXL_ENC_ICC_CODEC_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc new file mode 100644 index 0000000000..1b41361320 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.cc @@ -0,0 +1,158 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_image_bundle.h" + +#include <jxl/cms_interface.h> + +#include <atomic> +#include <limits> +#include <utility> + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +Status ApplyColorTransform(const ColorEncoding& c_current, + float intensity_target, const Image3F& color, + const ImageF* black, const Rect& rect, + const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool, + Image3F* out) { + ColorSpaceTransform c_transform(cms); + // Changing IsGray is probably a bug. + JXL_CHECK(c_current.IsGray() == c_desired.IsGray()); + bool is_gray = c_current.IsGray(); + if (out->xsize() < rect.xsize() || out->ysize() < rect.ysize()) { + *out = Image3F(rect.xsize(), rect.ysize()); + } else { + out->ShrinkTo(rect.xsize(), rect.ysize()); + } + std::atomic<bool> ok{true}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, rect.ysize(), + [&](const size_t num_threads) { + return c_transform.Init(c_current, c_desired, intensity_target, + rect.xsize(), num_threads); + }, + [&](const uint32_t y, const size_t thread) { + float* mutable_src_buf = c_transform.BufSrc(thread); + const float* src_buf = mutable_src_buf; + // Interleave input. + if (is_gray) { + src_buf = rect.ConstPlaneRow(color, 0, y); + } else if (c_current.IsCMYK()) { + if (!black) { + ok.store(false); + return; + } + const float* JXL_RESTRICT row_in0 = rect.ConstPlaneRow(color, 0, y); + const float* JXL_RESTRICT row_in1 = rect.ConstPlaneRow(color, 1, y); + const float* JXL_RESTRICT row_in2 = rect.ConstPlaneRow(color, 2, y); + const float* JXL_RESTRICT row_in3 = rect.ConstRow(*black, y); + for (size_t x = 0; x < rect.xsize(); x++) { + // CMYK convention in JXL: 0 = max ink, 1 = white + mutable_src_buf[4 * x + 0] = row_in0[x]; + mutable_src_buf[4 * x + 1] = row_in1[x]; + mutable_src_buf[4 * x + 2] = row_in2[x]; + mutable_src_buf[4 * x + 3] = row_in3[x]; + } + } else { + const float* JXL_RESTRICT row_in0 = rect.ConstPlaneRow(color, 0, y); + const float* JXL_RESTRICT row_in1 = rect.ConstPlaneRow(color, 1, y); + const float* JXL_RESTRICT row_in2 = rect.ConstPlaneRow(color, 2, y); + for (size_t x = 0; x < rect.xsize(); x++) { + mutable_src_buf[3 * x + 0] = row_in0[x]; + mutable_src_buf[3 * x + 1] = row_in1[x]; + mutable_src_buf[3 * x + 2] = row_in2[x]; + } + } + float* JXL_RESTRICT dst_buf = c_transform.BufDst(thread); + if (!c_transform.Run(thread, src_buf, dst_buf)) { + ok.store(false); + return; + } + float* JXL_RESTRICT row_out0 = out->PlaneRow(0, y); + float* JXL_RESTRICT row_out1 = out->PlaneRow(1, y); + float* JXL_RESTRICT row_out2 = out->PlaneRow(2, y); + // De-interleave output and convert type. + if (is_gray) { + for (size_t x = 0; x < rect.xsize(); x++) { + row_out0[x] = dst_buf[x]; + row_out1[x] = dst_buf[x]; + row_out2[x] = dst_buf[x]; + } + } else { + for (size_t x = 0; x < rect.xsize(); x++) { + row_out0[x] = dst_buf[3 * x + 0]; + row_out1[x] = dst_buf[3 * x + 1]; + row_out2[x] = dst_buf[3 * x + 2]; + } + } + }, + "Colorspace transform")); + return ok.load(); +} + +namespace { + +// Copies ib:rect, converts, and copies into out. +Status CopyToT(const ImageMetadata* metadata, const ImageBundle* ib, + const Rect& rect, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool, Image3F* out) { + return ApplyColorTransform( + ib->c_current(), metadata->IntensityTarget(), ib->color(), + ib->HasBlack() ? &ib->black() : nullptr, rect, c_desired, cms, pool, out); +} + +} // namespace + +Status ImageBundle::TransformTo(const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool) { + JXL_RETURN_IF_ERROR(CopyTo(Rect(color_), c_desired, cms, &color_, pool)); + c_current_ = c_desired; + return true; +} +Status ImageBundle::CopyTo(const Rect& rect, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, Image3F* out, + ThreadPool* pool) const { + return CopyToT(metadata_, this, rect, c_desired, cms, pool, out); +} +Status TransformIfNeeded(const ImageBundle& in, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool, + ImageBundle* store, const ImageBundle** out) { + if (in.c_current().SameColorEncoding(c_desired) && !in.HasBlack()) { + *out = ∈ + return true; + } + // TODO(janwas): avoid copying via createExternal+copyBackToIO + // instead of copy+createExternal+copyBackToIO + Image3F color(in.color().xsize(), in.color().ysize()); + CopyImageTo(in.color(), &color); + store->SetFromImage(std::move(color), in.c_current()); + + // Must at least copy the alpha channel for use by external_image. + if (in.HasExtraChannels()) { + std::vector<ImageF> extra_channels; + for (const ImageF& extra_channel : in.extra_channels()) { + ImageF ec(extra_channel.xsize(), extra_channel.ysize()); + CopyImageTo(extra_channel, &ec); + extra_channels.emplace_back(std::move(ec)); + } + store->SetExtraChannels(std::move(extra_channels)); + } + + if (!store->TransformTo(c_desired, cms, pool)) { + return false; + } + *out = store; + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_image_bundle.h b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.h new file mode 100644 index 0000000000..38536c8c7a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_image_bundle.h @@ -0,0 +1,38 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_IMAGE_BUNDLE_H_ +#define LIB_JXL_ENC_IMAGE_BUNDLE_H_ + +#include <jxl/cms_interface.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +Status ApplyColorTransform(const ColorEncoding& c_current, + float intensity_target, const Image3F& color, + const ImageF* black, const Rect& rect, + const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool, + Image3F* out); + +// Does color transformation from in.c_current() to c_desired if the color +// encodings are different, or nothing if they are already the same. +// If color transformation is done, stores the transformed values into store and +// sets the out pointer to store, else leaves store untouched and sets the out +// pointer to &in. +// Returns false if color transform fails. +Status TransformIfNeeded(const ImageBundle& in, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, ThreadPool* pool, + ImageBundle* store, const ImageBundle** out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_IMAGE_BUNDLE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_linalg.cc b/third_party/jpeg-xl/lib/jxl/enc_linalg.cc new file mode 100644 index 0000000000..fe2090a909 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_linalg.cc @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_linalg.h" + +#include <cmath> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +void ConvertToDiagonal(const ImageD& A, ImageD* const JXL_RESTRICT diag, + ImageD* const JXL_RESTRICT U) { +#if JXL_ENABLE_ASSERT + JXL_ASSERT(A.xsize() == 2); + JXL_ASSERT(A.ysize() == 2); + JXL_ASSERT(std::abs(A.Row(0)[1] - A.Row(1)[0]) < 1e-15); +#endif + + if (std::abs(A.ConstRow(0)[1]) < 1e-15) { + // Already diagonal. + diag->Row(0)[0] = A.ConstRow(0)[0]; + diag->Row(0)[1] = A.ConstRow(1)[1]; + U->Row(0)[0] = U->Row(1)[1] = 1.0; + U->Row(0)[1] = U->Row(1)[0] = 0.0; + return; + } + double b = -(A.Row(0)[0] + A.Row(1)[1]); + double c = A.Row(0)[0] * A.Row(1)[1] - A.Row(0)[1] * A.Row(0)[1]; + double d = b * b - 4.0 * c; + double sqd = std::sqrt(d); + double l1 = (-b - sqd) * 0.5; + double l2 = (-b + sqd) * 0.5; + + double v1[2] = {A.Row(0)[0] - l1, A.Row(1)[0]}; + double v1n = 1.0 / std::hypot(v1[0], v1[1]); + v1[0] = v1[0] * v1n; + v1[1] = v1[1] * v1n; + + diag->Row(0)[0] = l1; + diag->Row(0)[1] = l2; + + U->Row(0)[0] = v1[1]; + U->Row(0)[1] = -v1[0]; + U->Row(1)[0] = v1[0]; + U->Row(1)[1] = v1[1]; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_linalg.h b/third_party/jpeg-xl/lib/jxl/enc_linalg.h new file mode 100644 index 0000000000..791770d5d4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_linalg.h @@ -0,0 +1,24 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_LINALG_H_ +#define LIB_JXL_LINALG_H_ + +// Linear algebra. + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/image.h" + +namespace jxl { + +using ImageD = Plane<double>; + +// A is symmetric, U is orthogonal, and A = U * Diagonal(diag) * Transpose(U). +void ConvertToDiagonal(const ImageD& A, ImageD* JXL_RESTRICT diag, + ImageD* JXL_RESTRICT U); + +} // namespace jxl + +#endif // LIB_JXL_LINALG_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_linalg_test.cc b/third_party/jpeg-xl/lib/jxl/enc_linalg_test.cc new file mode 100644 index 0000000000..967b9a3afb --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_linalg_test.cc @@ -0,0 +1,118 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_linalg.h" + +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +ImageD Identity(const size_t N) { + ImageD out(N, N); + for (size_t i = 0; i < N; ++i) { + double* JXL_RESTRICT row = out.Row(i); + std::fill(row, row + N, 0); + row[i] = 1.0; + } + return out; +} + +ImageD Diagonal(const ImageD& d) { + JXL_ASSERT(d.ysize() == 1); + ImageD out(d.xsize(), d.xsize()); + const double* JXL_RESTRICT row_diag = d.Row(0); + for (size_t k = 0; k < d.xsize(); ++k) { + double* JXL_RESTRICT row_out = out.Row(k); + std::fill(row_out, row_out + d.xsize(), 0.0); + row_out[k] = row_diag[k]; + } + return out; +} + +ImageD MatMul(const ImageD& A, const ImageD& B) { + JXL_ASSERT(A.ysize() == B.xsize()); + ImageD out(A.xsize(), B.ysize()); + for (size_t y = 0; y < B.ysize(); ++y) { + const double* const JXL_RESTRICT row_b = B.Row(y); + double* const JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < A.xsize(); ++x) { + row_out[x] = 0.0; + for (size_t k = 0; k < B.xsize(); ++k) { + row_out[x] += A.Row(k)[x] * row_b[k]; + } + } + } + return out; +} + +ImageD Transpose(const ImageD& A) { + ImageD out(A.ysize(), A.xsize()); + for (size_t x = 0; x < A.xsize(); ++x) { + double* const JXL_RESTRICT row_out = out.Row(x); + for (size_t y = 0; y < A.ysize(); ++y) { + row_out[y] = A.Row(y)[x]; + } + } + return out; +} + +ImageD RandomSymmetricMatrix(const size_t N, Rng& rng, const double vmin, + const double vmax) { + ImageD A(N, N); + GenerateImage(rng, &A, vmin, vmax); + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < i; ++j) { + A.Row(j)[i] = A.Row(i)[j]; + } + } + return A; +} + +void VerifyMatrixEqual(const ImageD& A, const ImageD& B, const double eps) { + ASSERT_EQ(A.xsize(), B.xsize()); + ASSERT_EQ(A.ysize(), B.ysize()); + for (size_t y = 0; y < A.ysize(); ++y) { + for (size_t x = 0; x < A.xsize(); ++x) { + ASSERT_NEAR(A.Row(y)[x], B.Row(y)[x], eps); + } + } +} + +void VerifyOrthogonal(const ImageD& A, const double eps) { + VerifyMatrixEqual(Identity(A.xsize()), MatMul(Transpose(A), A), eps); +} + +TEST(LinAlgTest, ConvertToDiagonal) { + { + ImageD I = Identity(2); + ImageD U(2, 2), d(2, 1); + ConvertToDiagonal(I, &d, &U); + VerifyMatrixEqual(I, U, 1e-15); + for (size_t k = 0; k < 2; ++k) { + ASSERT_NEAR(d.Row(0)[k], 1.0, 1e-15); + } + } + { + ImageD A = Identity(2); + A.Row(0)[1] = A.Row(1)[0] = 2.0; + ImageD U(2, 2), d(2, 1); + ConvertToDiagonal(A, &d, &U); + VerifyOrthogonal(U, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12); + } + Rng rng(0); + for (size_t i = 0; i < 100; ++i) { + ImageD A = RandomSymmetricMatrix(2, rng, -1.0, 1.0); + ImageD U(2, 2), d(2, 1); + ConvertToDiagonal(A, &d, &U); + VerifyOrthogonal(U, 1e-12); + VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12); + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_modular.cc b/third_party/jpeg-xl/lib/jxl/enc_modular.cc new file mode 100644 index 0000000000..b8366953b7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_modular.cc @@ -0,0 +1,1646 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_modular.h" + +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <atomic> +#include <limits> +#include <queue> +#include <utility> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/compressed_dc.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cluster.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/enc_gaborish.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_patch_dictionary.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_debug_tree.h" +#include "lib/jxl/modular/encoding/enc_encoding.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/enc_transform.h" +#include "lib/jxl/pack_signed.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +// constexpr bool kPrintTree = false; + +// Squeeze default quantization factors +// these quantization factors are for -Q 50 (other qualities simply scale the +// factors; things are rounded down and obviously cannot get below 1) +static const float squeeze_quality_factor = + 0.35; // for easy tweaking of the quality range (decrease this number for + // higher quality) +static const float squeeze_luma_factor = + 1.1; // for easy tweaking of the balance between luma (or anything + // non-chroma) and chroma (decrease this number for higher quality + // luma) +static const float squeeze_quality_factor_xyb = 2.4f; +static const float squeeze_xyb_qtable[3][16] = { + {163.84, 81.92, 40.96, 20.48, 10.24, 5.12, 2.56, 1.28, 0.64, 0.32, 0.16, + 0.08, 0.04, 0.02, 0.01, 0.005}, // Y + {1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, 0.5, + 0.5}, // X + {2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, + 0.5}, // B-Y +}; + +static const float squeeze_luma_qtable[16] = { + 163.84, 81.92, 40.96, 20.48, 10.24, 5.12, 2.56, 1.28, + 0.64, 0.32, 0.16, 0.08, 0.04, 0.02, 0.01, 0.005}; +// for 8-bit input, the range of YCoCg chroma is -255..255 so basically this +// does 4:2:0 subsampling (two most fine grained layers get quantized away) +static const float squeeze_chroma_qtable[16] = { + 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, 0.5, 0.5}; + +// Merges the trees in `trees` using nodes that decide on stream_id, as defined +// by `tree_splits`. +void MergeTrees(const std::vector<Tree>& trees, + const std::vector<size_t>& tree_splits, size_t begin, + size_t end, Tree* tree) { + JXL_ASSERT(trees.size() + 1 == tree_splits.size()); + JXL_ASSERT(end > begin); + JXL_ASSERT(end <= trees.size()); + if (end == begin + 1) { + // Insert the tree, adding the opportune offset to all child nodes. + // This will make the leaf IDs wrong, but subsequent roundtripping will fix + // them. + size_t sz = tree->size(); + tree->insert(tree->end(), trees[begin].begin(), trees[begin].end()); + for (size_t i = sz; i < tree->size(); i++) { + (*tree)[i].lchild += sz; + (*tree)[i].rchild += sz; + } + return; + } + size_t mid = (begin + end) / 2; + size_t splitval = tree_splits[mid] - 1; + size_t cur = tree->size(); + tree->emplace_back(1 /*stream_id*/, splitval, 0, 0, Predictor::Zero, 0, 1); + (*tree)[cur].lchild = tree->size(); + MergeTrees(trees, tree_splits, mid, end, tree); + (*tree)[cur].rchild = tree->size(); + MergeTrees(trees, tree_splits, begin, mid, tree); +} + +void QuantizeChannel(Channel& ch, const int q) { + if (q == 1) return; + for (size_t y = 0; y < ch.plane.ysize(); y++) { + pixel_type* row = ch.plane.Row(y); + for (size_t x = 0; x < ch.plane.xsize(); x++) { + if (row[x] < 0) { + row[x] = -((-row[x] + q / 2) / q) * q; + } else { + row[x] = ((row[x] + q / 2) / q) * q; + } + } + } +} + +// convert binary32 float that corresponds to custom [bits]-bit float (with +// [exp_bits] exponent bits) to a [bits]-bit integer representation that should +// fit in pixel_type +Status float_to_int(const float* const row_in, pixel_type* const row_out, + size_t xsize, unsigned int bits, unsigned int exp_bits, + bool fp, double dfactor) { + JXL_ASSERT(sizeof(pixel_type) * 8 >= bits); + if (!fp) { + if (bits > 22) { + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x] * dfactor + (row_in[x] < 0 ? -0.5 : 0.5); + } + } else { + float factor = dfactor; + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x] * factor + (row_in[x] < 0 ? -0.5f : 0.5f); + } + } + return true; + } + if (bits == 32 && fp) { + JXL_ASSERT(exp_bits == 8); + memcpy((void*)row_out, (const void*)row_in, 4 * xsize); + return true; + } + + int exp_bias = (1 << (exp_bits - 1)) - 1; + int max_exp = (1 << exp_bits) - 1; + uint32_t sign = (1u << (bits - 1)); + int mant_bits = bits - exp_bits - 1; + int mant_shift = 23 - mant_bits; + for (size_t x = 0; x < xsize; ++x) { + uint32_t f; + memcpy(&f, &row_in[x], 4); + int signbit = (f >> 31); + f &= 0x7fffffff; + if (f == 0) { + row_out[x] = (signbit ? sign : 0); + continue; + } + int exp = (f >> 23) - 127; + if (exp == 128) return JXL_FAILURE("Inf/NaN not allowed"); + int mantissa = (f & 0x007fffff); + // broke up the binary32 into its parts, now reassemble into + // arbitrary float + exp += exp_bias; + if (exp < 0) { // will become a subnormal number + // add implicit leading 1 to mantissa + mantissa |= 0x00800000; + if (exp < -mant_bits) { + return JXL_FAILURE( + "Invalid float number: %g cannot be represented with %i " + "exp_bits and %i mant_bits (exp %i)", + row_in[x], exp_bits, mant_bits, exp); + } + mantissa >>= 1 - exp; + exp = 0; + } + // exp should be representable in exp_bits, otherwise input was + // invalid + if (exp > max_exp) return JXL_FAILURE("Invalid float exponent"); + if (mantissa & ((1 << mant_shift) - 1)) { + return JXL_FAILURE("%g is losing precision (mant: %x)", row_in[x], + mantissa); + } + mantissa >>= mant_shift; + f = (signbit ? sign : 0); + f |= (exp << mant_bits); + f |= mantissa; + row_out[x] = (pixel_type)f; + } + return true; +} +} // namespace + +ModularFrameEncoder::ModularFrameEncoder(const FrameHeader& frame_header, + const CompressParams& cparams_orig) + : frame_dim_(frame_header.ToFrameDimensions()), cparams_(cparams_orig) { + size_t num_streams = + ModularStreamId::Num(frame_dim_, frame_header.passes.num_passes); + if (cparams_.ModularPartIsLossless()) { + switch (cparams_.decoding_speed_tier) { + case 0: + break; + case 1: + cparams_.options.wp_tree_mode = ModularOptions::TreeMode::kWPOnly; + break; + case 2: { + cparams_.options.wp_tree_mode = ModularOptions::TreeMode::kGradientOnly; + cparams_.options.predictor = Predictor::Gradient; + break; + } + case 3: { // LZ77, no Gradient. + cparams_.options.nb_repeats = 0; + cparams_.options.predictor = Predictor::Gradient; + break; + } + default: { // LZ77, no predictor. + cparams_.options.nb_repeats = 0; + cparams_.options.predictor = Predictor::Zero; + break; + } + } + } + if (cparams_.decoding_speed_tier >= 1 && cparams_.responsive && + cparams_.ModularPartIsLossless()) { + cparams_.options.tree_kind = + ModularOptions::TreeKind::kTrivialTreeNoPredictor; + cparams_.options.nb_repeats = 0; + } + stream_images_.resize(num_streams); + + // use a sensible default if nothing explicit is specified: + // Squeeze for lossy, no squeeze for lossless + if (cparams_.responsive < 0) { + if (cparams_.ModularPartIsLossless()) { + cparams_.responsive = 0; + } else { + cparams_.responsive = 1; + } + } + + cparams_.options.splitting_heuristics_node_threshold = + 82 + 14 * static_cast<int>(cparams_.speed_tier); + + { + // Set properties. + std::vector<uint32_t> prop_order; + if (cparams_.responsive) { + // Properties in order of their likelihood of being useful for Squeeze + // residuals. + prop_order = {0, 1, 4, 5, 6, 7, 8, 15, 9, 10, 11, 12, 13, 14, 2, 3}; + } else { + // Same, but for the non-Squeeze case. + prop_order = {0, 1, 15, 9, 10, 11, 12, 13, 14, 2, 3, 4, 5, 6, 7, 8}; + // if few groups, don't use group as a property + if (num_streams < 30 && cparams_.speed_tier > SpeedTier::kTortoise) { + prop_order.erase(prop_order.begin() + 1); + } + } + switch (cparams_.speed_tier) { + case SpeedTier::kHare: + cparams_.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 4); + cparams_.options.max_property_values = 24; + break; + case SpeedTier::kWombat: + cparams_.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 5); + cparams_.options.max_property_values = 32; + break; + case SpeedTier::kSquirrel: + cparams_.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 7); + cparams_.options.max_property_values = 48; + break; + case SpeedTier::kKitten: + cparams_.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 10); + cparams_.options.max_property_values = 96; + break; + case SpeedTier::kTortoise: + cparams_.options.splitting_heuristics_properties = prop_order; + cparams_.options.max_property_values = 256; + break; + default: + cparams_.options.splitting_heuristics_properties.assign( + prop_order.begin(), prop_order.begin() + 3); + cparams_.options.max_property_values = 16; + break; + } + if (cparams_.speed_tier > SpeedTier::kTortoise) { + // Gradient in previous channels. + for (int i = 0; i < cparams_.options.max_properties; i++) { + cparams_.options.splitting_heuristics_properties.push_back( + kNumNonrefProperties + i * 4 + 3); + } + } else { + // All the extra properties in Tortoise mode. + for (int i = 0; i < cparams_.options.max_properties * 4; i++) { + cparams_.options.splitting_heuristics_properties.push_back( + kNumNonrefProperties + i); + } + } + } + + if (cparams_.options.predictor == static_cast<Predictor>(-1)) { + // no explicit predictor(s) given, set a good default + if ((cparams_.speed_tier <= SpeedTier::kTortoise || + cparams_.modular_mode == false) && + cparams_.IsLossless() && cparams_.responsive == false) { + // TODO(veluca): allow all predictors that don't break residual + // multipliers in lossy mode. + cparams_.options.predictor = Predictor::Variable; + } else if (cparams_.responsive || cparams_.lossy_palette) { + // zero predictor for Squeeze residues and lossy palette + cparams_.options.predictor = Predictor::Zero; + } else if (!cparams_.IsLossless()) { + // If not responsive and lossy. TODO(veluca): use near_lossless instead? + cparams_.options.predictor = Predictor::Gradient; + } else if (cparams_.speed_tier < SpeedTier::kFalcon) { + // try median and weighted predictor for anything else + cparams_.options.predictor = Predictor::Best; + } else if (cparams_.speed_tier == SpeedTier::kFalcon) { + // just weighted predictor in falcon mode + cparams_.options.predictor = Predictor::Weighted; + } else if (cparams_.speed_tier > SpeedTier::kFalcon) { + // just gradient predictor in thunder mode + cparams_.options.predictor = Predictor::Gradient; + } + } else { + delta_pred_ = cparams_.options.predictor; + if (cparams_.lossy_palette) cparams_.options.predictor = Predictor::Zero; + } + if (!cparams_.ModularPartIsLossless()) { + if (cparams_.options.predictor == Predictor::Weighted || + cparams_.options.predictor == Predictor::Variable || + cparams_.options.predictor == Predictor::Best) + cparams_.options.predictor = Predictor::Zero; + } + tree_splits_.push_back(0); + if (cparams_.modular_mode == false) { + cparams_.options.fast_decode_multiplier = 1.0f; + tree_splits_.push_back(ModularStreamId::VarDCTDC(0).ID(frame_dim_)); + tree_splits_.push_back(ModularStreamId::ModularDC(0).ID(frame_dim_)); + tree_splits_.push_back(ModularStreamId::ACMetadata(0).ID(frame_dim_)); + tree_splits_.push_back(ModularStreamId::QuantTable(0).ID(frame_dim_)); + tree_splits_.push_back(ModularStreamId::ModularAC(0, 0).ID(frame_dim_)); + ac_metadata_size.resize(frame_dim_.num_dc_groups); + extra_dc_precision.resize(frame_dim_.num_dc_groups); + } + tree_splits_.push_back(num_streams); + cparams_.options.max_chan_size = frame_dim_.group_dim; + cparams_.options.group_dim = frame_dim_.group_dim; + + // TODO(veluca): figure out how to use different predictor sets per channel. + stream_options_.resize(num_streams, cparams_.options); +} + +bool do_transform(Image& image, const Transform& tr, + const weighted::Header& wp_header, + jxl::ThreadPool* pool = nullptr, bool force_jxlart = false) { + Transform t = tr; + bool did_it = true; + if (force_jxlart) { + if (!t.MetaApply(image)) return false; + } else { + did_it = TransformForward(t, image, wp_header, pool); + } + if (did_it) image.transform.push_back(t); + return did_it; +} + +Status ModularFrameEncoder::ComputeEncodingData( + const FrameHeader& frame_header, const ImageMetadata& metadata, + Image3F* JXL_RESTRICT color, const std::vector<ImageF>& extra_channels, + PassesEncoderState* JXL_RESTRICT enc_state, const JxlCmsInterface& cms, + ThreadPool* pool, AuxOut* aux_out, bool do_color) { + JXL_DEBUG_V(6, "Computing modular encoding data for frame %s", + frame_header.DebugString().c_str()); + + if (do_color && frame_header.loop_filter.gab) { + float w = 0.9908511000000001f; + float weights[3] = {w, w, w}; + GaborishInverse(color, Rect(*color), weights, pool); + } + + if (do_color && metadata.bit_depth.bits_per_sample <= 16 && + cparams_.speed_tier < SpeedTier::kCheetah && + cparams_.decoding_speed_tier < 2) { + FindBestPatchDictionary(*color, enc_state, cms, nullptr, aux_out, + cparams_.color_transform == ColorTransform::kXYB); + PatchDictionaryEncoder::SubtractFrom( + enc_state->shared.image_features.patches, color); + } + + // Convert ImageBundle to modular Image object + const size_t xsize = frame_dim_.xsize; + const size_t ysize = frame_dim_.ysize; + + int nb_chans = 3; + if (metadata.color_encoding.IsGray() && + cparams_.color_transform == ColorTransform::kNone) { + nb_chans = 1; + } + if (!do_color) nb_chans = 0; + + nb_chans += extra_channels.size(); + + bool fp = metadata.bit_depth.floating_point_sample && + cparams_.color_transform != ColorTransform::kXYB; + + // bits_per_sample is just metadata for XYB images. + if (metadata.bit_depth.bits_per_sample >= 32 && do_color && + cparams_.color_transform != ColorTransform::kXYB) { + if (metadata.bit_depth.bits_per_sample == 32 && fp == false) { + return JXL_FAILURE("uint32_t not supported in enc_modular"); + } else if (metadata.bit_depth.bits_per_sample > 32) { + return JXL_FAILURE("bits_per_sample > 32 not supported"); + } + } + + // in the non-float case, there is an implicit 0 sign bit + int max_bitdepth = + do_color ? metadata.bit_depth.bits_per_sample + (fp ? 0 : 1) : 0; + Image& gi = stream_images_[0]; + gi = Image(xsize, ysize, metadata.bit_depth.bits_per_sample, nb_chans); + int c = 0; + if (cparams_.color_transform == ColorTransform::kXYB && + cparams_.modular_mode == true) { + float enc_factors[3] = {32768.0f, 2048.0f, 2048.0f}; + if (cparams_.butteraugli_distance > 0 && !cparams_.responsive) { + // quantize XYB here and then treat it as a lossless image + enc_factors[0] *= 1.f / (1.f + 23.f * cparams_.butteraugli_distance); + enc_factors[1] *= 1.f / (1.f + 14.f * cparams_.butteraugli_distance); + enc_factors[2] *= 1.f / (1.f + 14.f * cparams_.butteraugli_distance); + cparams_.butteraugli_distance = 0; + } + if (cparams_.manual_xyb_factors.size() == 3) { + DequantMatricesSetCustomDC(&enc_state->shared.matrices, + cparams_.manual_xyb_factors.data()); + // TODO(jon): update max_bitdepth in this case + } else { + DequantMatricesSetCustomDC(&enc_state->shared.matrices, enc_factors); + max_bitdepth = 12; + } + } + pixel_type maxval = gi.bitdepth < 32 ? (1u << gi.bitdepth) - 1 : 0; + if (do_color) { + for (; c < 3; c++) { + if (metadata.color_encoding.IsGray() && + cparams_.color_transform == ColorTransform::kNone && + c != (cparams_.color_transform == ColorTransform::kXYB ? 1 : 0)) + continue; + int c_out = c; + // XYB is encoded as YX(B-Y) + if (cparams_.color_transform == ColorTransform::kXYB && c < 2) + c_out = 1 - c_out; + double factor = maxval; + if (cparams_.color_transform == ColorTransform::kXYB) + factor = enc_state->shared.matrices.InvDCQuant(c); + if (c == 2 && cparams_.color_transform == ColorTransform::kXYB) { + JXL_ASSERT(!fp); + for (size_t y = 0; y < ysize; ++y) { + const float* const JXL_RESTRICT row_in = color->PlaneRow(c, y); + pixel_type* const JXL_RESTRICT row_out = gi.channel[c_out].Row(y); + pixel_type* const JXL_RESTRICT row_Y = gi.channel[0].Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row_in[x] * factor + 0.5f; + row_out[x] -= row_Y[x]; + // zero the lsb of B + row_out[x] = row_out[x] / 2 * 2; + } + } + } else { + int bits = metadata.bit_depth.bits_per_sample; + int exp_bits = metadata.bit_depth.exponent_bits_per_sample; + gi.channel[c_out].hshift = frame_header.chroma_subsampling.HShift(c); + gi.channel[c_out].vshift = frame_header.chroma_subsampling.VShift(c); + size_t xsize_shifted = DivCeil(xsize, 1 << gi.channel[c_out].hshift); + size_t ysize_shifted = DivCeil(ysize, 1 << gi.channel[c_out].vshift); + gi.channel[c_out].shrink(xsize_shifted, ysize_shifted); + std::atomic<bool> has_error{false}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, ysize_shifted, ThreadPool::NoInit, + [&](const int task, const int thread) { + const size_t y = task; + const float* const JXL_RESTRICT row_in = color->PlaneRow(c, y); + pixel_type* const JXL_RESTRICT row_out = gi.channel[c_out].Row(y); + if (!float_to_int(row_in, row_out, xsize_shifted, bits, exp_bits, + fp, factor)) { + has_error = true; + }; + }, + "float2int")); + if (has_error) { + return JXL_FAILURE("Error in float to integer conversion"); + } + } + } + if (metadata.color_encoding.IsGray() && + cparams_.color_transform == ColorTransform::kNone) + c = 1; + } + + for (size_t ec = 0; ec < extra_channels.size(); ec++, c++) { + const ExtraChannelInfo& eci = metadata.extra_channel_info[ec]; + size_t ecups = frame_header.extra_channel_upsampling[ec]; + gi.channel[c].shrink(DivCeil(frame_dim_.xsize_upsampled, ecups), + DivCeil(frame_dim_.ysize_upsampled, ecups)); + gi.channel[c].hshift = gi.channel[c].vshift = + CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling); + + int bits = eci.bit_depth.bits_per_sample; + int exp_bits = eci.bit_depth.exponent_bits_per_sample; + bool fp = eci.bit_depth.floating_point_sample; + double factor = (fp ? 1 : ((1u << eci.bit_depth.bits_per_sample) - 1)); + if (bits + (fp ? 0 : 1) > max_bitdepth) max_bitdepth = bits + (fp ? 0 : 1); + std::atomic<bool> has_error{false}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, gi.channel[c].plane.ysize(), ThreadPool::NoInit, + [&](const int task, const int thread) { + const size_t y = task; + const float* const JXL_RESTRICT row_in = extra_channels[ec].Row(y); + pixel_type* const JXL_RESTRICT row_out = gi.channel[c].Row(y); + if (!float_to_int(row_in, row_out, gi.channel[c].plane.xsize(), bits, + exp_bits, fp, factor)) { + has_error = true; + }; + }, + "float2int")); + if (has_error) return JXL_FAILURE("Error in float to integer conversion"); + } + JXL_ASSERT(c == nb_chans); + + int level_max_bitdepth = (cparams_.level == 5 ? 16 : 32); + if (max_bitdepth > level_max_bitdepth) + return JXL_FAILURE( + "Bitdepth too high for level %i (need %i bits, have only %i in this " + "level)", + cparams_.level, max_bitdepth, level_max_bitdepth); + + // Set options and apply transformations + if (!cparams_.ModularPartIsLossless()) { + if (cparams_.palette_colors != 0) { + JXL_DEBUG_V(3, "Lossy encode, not doing palette transforms"); + } + if (cparams_.color_transform == ColorTransform::kXYB) { + cparams_.channel_colors_pre_transform_percent = 0; + } + cparams_.channel_colors_percent = 0; + cparams_.palette_colors = 0; + cparams_.lossy_palette = false; + } + + // Global palette + if (cparams_.palette_colors != 0 || cparams_.lossy_palette) { + // all-channel palette (e.g. RGBA) + if (gi.channel.size() - gi.nb_meta_channels > 1) { + Transform maybe_palette(TransformId::kPalette); + maybe_palette.begin_c = gi.nb_meta_channels; + maybe_palette.num_c = gi.channel.size() - gi.nb_meta_channels; + maybe_palette.nb_colors = + std::min((int)(xsize * ysize / 2), std::abs(cparams_.palette_colors)); + maybe_palette.ordered_palette = cparams_.palette_colors >= 0; + maybe_palette.lossy_palette = + (cparams_.lossy_palette && maybe_palette.num_c == 3); + if (maybe_palette.lossy_palette) { + maybe_palette.predictor = delta_pred_; + } + // TODO(veluca): use a custom weighted header if using the weighted + // predictor. + do_transform(gi, maybe_palette, weighted::Header(), pool, + cparams_.options.zero_tokens); + } + // all-minus-one-channel palette (RGB with separate alpha, or CMY with + // separate K) + if (gi.channel.size() - gi.nb_meta_channels > 3) { + Transform maybe_palette_3(TransformId::kPalette); + maybe_palette_3.begin_c = gi.nb_meta_channels; + maybe_palette_3.num_c = gi.channel.size() - gi.nb_meta_channels - 1; + maybe_palette_3.nb_colors = + std::min((int)(xsize * ysize / 3), std::abs(cparams_.palette_colors)); + maybe_palette_3.ordered_palette = cparams_.palette_colors >= 0; + maybe_palette_3.lossy_palette = cparams_.lossy_palette; + if (maybe_palette_3.lossy_palette) { + maybe_palette_3.predictor = delta_pred_; + } + do_transform(gi, maybe_palette_3, weighted::Header(), pool, + cparams_.options.zero_tokens); + } + } + + // Global channel palette + if (cparams_.channel_colors_pre_transform_percent > 0 && + !cparams_.lossy_palette && + (cparams_.speed_tier <= SpeedTier::kThunder || + (do_color && metadata.bit_depth.bits_per_sample > 8))) { + // single channel palette (like FLIF's ChannelCompact) + size_t nb_channels = gi.channel.size() - gi.nb_meta_channels; + int orig_bitdepth = max_bitdepth; + max_bitdepth = 0; + for (size_t i = 0; i < nb_channels; i++) { + int32_t min, max; + compute_minmax(gi.channel[gi.nb_meta_channels + i], &min, &max); + int64_t colors = (int64_t)max - min + 1; + JXL_DEBUG_V(10, "Channel %" PRIuS ": range=%i..%i", i, min, max); + Transform maybe_palette_1(TransformId::kPalette); + maybe_palette_1.begin_c = i + gi.nb_meta_channels; + maybe_palette_1.num_c = 1; + // simple heuristic: if less than X percent of the values in the range + // actually occur, it is probably worth it to do a compaction + // (but only if the channel palette is less than 6% the size of the + // image itself) + maybe_palette_1.nb_colors = std::min( + (int)(xsize * ysize / 16), + (int)(cparams_.channel_colors_pre_transform_percent / 100. * colors)); + if (do_transform(gi, maybe_palette_1, weighted::Header(), pool)) { + // effective bit depth is lower, adjust quantization accordingly + compute_minmax(gi.channel[gi.nb_meta_channels + i], &min, &max); + if (max < maxval) maxval = max; + int ch_bitdepth = + (max > 0 ? CeilLog2Nonzero(static_cast<uint32_t>(max)) : 0); + if (ch_bitdepth > max_bitdepth) max_bitdepth = ch_bitdepth; + } else + max_bitdepth = orig_bitdepth; + } + } + + // don't do an RCT if we're short on bits + if (cparams_.color_transform == ColorTransform::kNone && do_color && + gi.channel.size() - gi.nb_meta_channels >= 3 && + max_bitdepth + 1 < level_max_bitdepth) { + if (cparams_.colorspace < 0 && (!cparams_.ModularPartIsLossless() || + cparams_.speed_tier > SpeedTier::kHare)) { + Transform ycocg{TransformId::kRCT}; + ycocg.rct_type = 6; + ycocg.begin_c = gi.nb_meta_channels; + do_transform(gi, ycocg, weighted::Header(), pool); + max_bitdepth++; + } else if (cparams_.colorspace > 0) { + Transform sg(TransformId::kRCT); + sg.begin_c = gi.nb_meta_channels; + sg.rct_type = cparams_.colorspace; + do_transform(gi, sg, weighted::Header(), pool); + max_bitdepth++; + } + } + + // don't do squeeze if we don't have some spare bits + if (cparams_.responsive && !gi.channel.empty() && + max_bitdepth + 2 < level_max_bitdepth) { + Transform t(TransformId::kSqueeze); + do_transform(gi, t, weighted::Header(), pool); + max_bitdepth += 2; + } + + if (max_bitdepth + 1 > level_max_bitdepth) { + // force no group RCTs if we don't have a spare bit + cparams_.colorspace = 0; + } + JXL_ASSERT(max_bitdepth <= level_max_bitdepth); + + if (!cparams_.ModularPartIsLossless()) { + quants_.resize(gi.channel.size(), 1); + float quantizer = 0.25f; + if (!cparams_.responsive) { + JXL_DEBUG_V(1, + "Warning: lossy compression without Squeeze " + "transform is just color quantization."); + quantizer *= 0.1f; + } + float bitdepth_correction = 1.f; + if (cparams_.color_transform != ColorTransform::kXYB) { + bitdepth_correction = maxval / 255.f; + } + std::vector<float> quantizers; + float dist = cparams_.butteraugli_distance; + for (size_t i = 0; i < 3; i++) { + quantizers.push_back(quantizer * dist * bitdepth_correction); + } + for (size_t i = 0; i < extra_channels.size(); i++) { + int ec_bitdepth = + metadata.extra_channel_info[i].bit_depth.bits_per_sample; + pixel_type ec_maxval = ec_bitdepth < 32 ? (1u << ec_bitdepth) - 1 : 0; + bitdepth_correction = ec_maxval / 255.f; + if (i < cparams_.ec_distance.size()) dist = cparams_.ec_distance[i]; + if (dist < 0) dist = cparams_.butteraugli_distance; + quantizers.push_back(quantizer * dist * bitdepth_correction); + } + if (cparams_.options.nb_repeats == 0) { + return JXL_FAILURE("nb_repeats = 0 not supported with modular lossy!"); + } + for (uint32_t i = gi.nb_meta_channels; i < gi.channel.size(); i++) { + Channel& ch = gi.channel[i]; + int shift = ch.hshift + ch.vshift; // number of pixel halvings + if (shift > 16) shift = 16; + if (shift > 0) shift--; + int q; + // assuming default Squeeze here + int component = + (do_color ? 0 : 3) + ((i - gi.nb_meta_channels) % nb_chans); + // last 4 channels are final chroma residuals + if (nb_chans > 2 && i >= gi.channel.size() - 4 && cparams_.responsive) { + component = 1; + } + if (cparams_.color_transform == ColorTransform::kXYB && component < 3) { + q = quantizers[component] * squeeze_quality_factor_xyb * + squeeze_xyb_qtable[component][shift]; + } else { + if (cparams_.colorspace != 0 && component > 0 && component < 3) { + q = quantizers[component] * squeeze_quality_factor * + squeeze_chroma_qtable[shift]; + } else { + q = quantizers[component] * squeeze_quality_factor * + squeeze_luma_factor * squeeze_luma_qtable[shift]; + } + } + if (q < 1) q = 1; + QuantizeChannel(gi.channel[i], q); + quants_[i] = q; + } + } + + // Fill other groups. + struct GroupParams { + Rect rect; + int minShift; + int maxShift; + ModularStreamId id; + }; + std::vector<GroupParams> stream_params; + + stream_options_[0] = cparams_.options; + + // DC + for (size_t group_id = 0; group_id < frame_dim_.num_dc_groups; group_id++) { + const size_t gx = group_id % frame_dim_.xsize_dc_groups; + const size_t gy = group_id / frame_dim_.xsize_dc_groups; + const Rect rect(gx * frame_dim_.dc_group_dim, gy * frame_dim_.dc_group_dim, + frame_dim_.dc_group_dim, frame_dim_.dc_group_dim); + // minShift==3 because (frame_dim.dc_group_dim >> 3) == frame_dim.group_dim + // maxShift==1000 is infinity + stream_params.push_back( + GroupParams{rect, 3, 1000, ModularStreamId::ModularDC(group_id)}); + } + // AC global -> nothing. + // AC + for (size_t group_id = 0; group_id < frame_dim_.num_groups; group_id++) { + const size_t gx = group_id % frame_dim_.xsize_groups; + const size_t gy = group_id / frame_dim_.xsize_groups; + const Rect mrect(gx * frame_dim_.group_dim, gy * frame_dim_.group_dim, + frame_dim_.group_dim, frame_dim_.group_dim); + for (size_t i = 0; i < enc_state->progressive_splitter.GetNumPasses(); + i++) { + int maxShift, minShift; + frame_header.passes.GetDownsamplingBracket(i, minShift, maxShift); + stream_params.push_back(GroupParams{ + mrect, minShift, maxShift, ModularStreamId::ModularAC(group_id, i)}); + } + } + // if there's only one group, everything ends up in GlobalModular + // in that case, also try RCTs/WP params for the one group + if (stream_params.size() == 2) { + stream_params.push_back(GroupParams{Rect(0, 0, xsize, ysize), 0, 1000, + ModularStreamId::Global()}); + } + gi_channel_.resize(stream_images_.size()); + + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, stream_params.size(), ThreadPool::NoInit, + [&](const uint32_t i, size_t /* thread */) { + stream_options_[stream_params[i].id.ID(frame_dim_)] = cparams_.options; + JXL_CHECK(PrepareStreamParams( + stream_params[i].rect, cparams_, stream_params[i].minShift, + stream_params[i].maxShift, stream_params[i].id, do_color)); + }, + "ChooseParams")); + { + // Clear out channels that have been copied to groups. + Image& full_image = stream_images_[0]; + size_t c = full_image.nb_meta_channels; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + if (fc.w > frame_dim_.group_dim || fc.h > frame_dim_.group_dim) break; + } + for (; c < full_image.channel.size(); c++) { + full_image.channel[c].plane = ImageI(); + } + } + + JXL_RETURN_IF_ERROR(ValidateChannelDimensions(gi, stream_options_[0])); + return true; +} + +Status ModularFrameEncoder::ComputeTree(ThreadPool* pool) { + std::vector<ModularMultiplierInfo> multiplier_info; + if (!quants_.empty()) { + for (uint32_t stream_id = 0; stream_id < stream_images_.size(); + stream_id++) { + // skip non-modular stream_ids + if (stream_id > 0 && gi_channel_[stream_id].empty()) continue; + const Image& image = stream_images_[stream_id]; + const ModularOptions& options = stream_options_[stream_id]; + for (uint32_t i = image.nb_meta_channels; i < image.channel.size(); i++) { + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + continue; + } + if (stream_id > 0 && gi_channel_[stream_id].empty()) continue; + size_t ch_id = stream_id == 0 + ? i + : gi_channel_[stream_id][i - image.nb_meta_channels]; + uint32_t q = quants_[ch_id]; + // Inform the tree splitting heuristics that each channel in each group + // used this quantization factor. This will produce a tree with the + // given multipliers. + if (multiplier_info.empty() || + multiplier_info.back().range[1][0] != stream_id || + multiplier_info.back().multiplier != q) { + StaticPropRange range; + range[0] = {{i, i + 1}}; + range[1] = {{stream_id, stream_id + 1}}; + multiplier_info.push_back({range, (uint32_t)q}); + } else { + // Previous channel in the same group had the same quantization + // factor. Don't provide two different ranges, as that creates + // unnecessary nodes. + multiplier_info.back().range[0][1] = i + 1; + } + } + } + // Merge group+channel settings that have the same channels and quantization + // factors, to avoid unnecessary nodes. + std::sort(multiplier_info.begin(), multiplier_info.end(), + [](ModularMultiplierInfo a, ModularMultiplierInfo b) { + return std::make_tuple(a.range, a.multiplier) < + std::make_tuple(b.range, b.multiplier); + }); + size_t new_num = 1; + for (size_t i = 1; i < multiplier_info.size(); i++) { + ModularMultiplierInfo& prev = multiplier_info[new_num - 1]; + ModularMultiplierInfo& cur = multiplier_info[i]; + if (prev.range[0] == cur.range[0] && prev.multiplier == cur.multiplier && + prev.range[1][1] == cur.range[1][0]) { + prev.range[1][1] = cur.range[1][1]; + } else { + multiplier_info[new_num++] = multiplier_info[i]; + } + } + multiplier_info.resize(new_num); + } + + if (!cparams_.custom_fixed_tree.empty()) { + tree_ = cparams_.custom_fixed_tree; + } else if (cparams_.speed_tier < SpeedTier::kFalcon || + !cparams_.modular_mode) { + // Avoid creating a tree with leaves that don't correspond to any pixels. + std::vector<size_t> useful_splits; + useful_splits.reserve(tree_splits_.size()); + for (size_t chunk = 0; chunk < tree_splits_.size() - 1; chunk++) { + bool has_pixels = false; + size_t start = tree_splits_[chunk]; + size_t stop = tree_splits_[chunk + 1]; + for (size_t i = start; i < stop; i++) { + if (!stream_images_[i].empty()) has_pixels = true; + } + if (has_pixels) { + useful_splits.push_back(tree_splits_[chunk]); + } + } + // Don't do anything if modular mode does not have any pixels in this image + if (useful_splits.empty()) return true; + useful_splits.push_back(tree_splits_.back()); + + std::atomic_flag invalid_force_wp = ATOMIC_FLAG_INIT; + + std::vector<Tree> trees(useful_splits.size() - 1); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, useful_splits.size() - 1, ThreadPool::NoInit, + [&](const uint32_t chunk, size_t /* thread */) { + // TODO(veluca): parallelize more. + size_t total_pixels = 0; + uint32_t start = useful_splits[chunk]; + uint32_t stop = useful_splits[chunk + 1]; + while (start < stop && stream_images_[start].empty()) ++start; + while (start < stop && stream_images_[stop - 1].empty()) --stop; + if (stream_options_[start].tree_kind != + ModularOptions::TreeKind::kLearn) { + for (size_t i = start; i < stop; i++) { + for (const Channel& ch : stream_images_[i].channel) { + total_pixels += ch.w * ch.h; + } + } + trees[chunk] = + PredefinedTree(stream_options_[start].tree_kind, total_pixels); + return; + } + TreeSamples tree_samples; + if (!tree_samples.SetPredictor(stream_options_[start].predictor, + stream_options_[start].wp_tree_mode)) { + invalid_force_wp.test_and_set(std::memory_order_acq_rel); + return; + } + if (!tree_samples.SetProperties( + stream_options_[start].splitting_heuristics_properties, + stream_options_[start].wp_tree_mode)) { + invalid_force_wp.test_and_set(std::memory_order_acq_rel); + return; + } + uint32_t max_c = 0; + std::vector<pixel_type> pixel_samples; + std::vector<pixel_type> diff_samples; + std::vector<uint32_t> group_pixel_count; + std::vector<uint32_t> channel_pixel_count; + for (size_t i = start; i < stop; i++) { + max_c = std::max<uint32_t>(stream_images_[i].channel.size(), max_c); + CollectPixelSamples(stream_images_[i], stream_options_[i], i, + group_pixel_count, channel_pixel_count, + pixel_samples, diff_samples); + } + StaticPropRange range; + range[0] = {{0, max_c}}; + range[1] = {{start, stop}}; + auto local_multiplier_info = multiplier_info; + + tree_samples.PreQuantizeProperties( + range, local_multiplier_info, group_pixel_count, + channel_pixel_count, pixel_samples, diff_samples, + stream_options_[start].max_property_values); + for (size_t i = start; i < stop; i++) { + JXL_CHECK(ModularGenericCompress( + stream_images_[i], stream_options_[i], /*writer=*/nullptr, + /*aux_out=*/nullptr, 0, i, &tree_samples, &total_pixels)); + } + + // TODO(veluca): parallelize more. + trees[chunk] = + LearnTree(std::move(tree_samples), total_pixels, + stream_options_[start], local_multiplier_info, range); + }, + "LearnTrees")); + if (invalid_force_wp.test_and_set(std::memory_order_acq_rel)) { + return JXL_FAILURE("PrepareEncoding: force_no_wp with {Weighted}"); + } + tree_.clear(); + MergeTrees(trees, useful_splits, 0, useful_splits.size() - 1, &tree_); + } else { + // Fixed tree. + size_t total_pixels = 0; + for (const Image& img : stream_images_) { + for (const Channel& ch : img.channel) { + total_pixels += ch.w * ch.h; + } + } + if (cparams_.speed_tier <= SpeedTier::kFalcon) { + tree_ = + PredefinedTree(ModularOptions::TreeKind::kWPFixedDC, total_pixels); + } else if (cparams_.speed_tier <= SpeedTier::kThunder) { + tree_ = PredefinedTree(ModularOptions::TreeKind::kGradientFixedDC, + total_pixels); + } else { + tree_ = {PropertyDecisionNode::Leaf(Predictor::Gradient)}; + } + } + tree_tokens_.resize(1); + tree_tokens_[0].clear(); + Tree decoded_tree; + TokenizeTree(tree_, &tree_tokens_[0], &decoded_tree); + JXL_ASSERT(tree_.size() == decoded_tree.size()); + tree_ = std::move(decoded_tree); + + /* TODO(szabadka) Add text output callback to cparams + if (kPrintTree && WantDebugOutput(aux_out)) { + if (frame_header.dc_level > 0) { + PrintTree(tree_, aux_out->debug_prefix + "/dc_frame_level" + + std::to_string(frame_header.dc_level) + "_tree"); + } else { + PrintTree(tree_, aux_out->debug_prefix + "/global_tree"); + } + } */ + return true; +} + +Status ModularFrameEncoder::ComputeTokens(ThreadPool* pool) { + size_t num_streams = stream_images_.size(); + stream_headers_.resize(num_streams); + tokens_.resize(num_streams); + image_widths_.resize(num_streams); + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, num_streams, ThreadPool::NoInit, + [&](const uint32_t stream_id, size_t /* thread */) { + AuxOut my_aux_out; + tokens_[stream_id].clear(); + JXL_CHECK(ModularGenericCompress( + stream_images_[stream_id], stream_options_[stream_id], + /*writer=*/nullptr, &my_aux_out, 0, stream_id, + /*tree_samples=*/nullptr, + /*total_pixels=*/nullptr, + /*tree=*/&tree_, /*header=*/&stream_headers_[stream_id], + /*tokens=*/&tokens_[stream_id], + /*widths=*/&image_widths_[stream_id])); + }, + "ComputeTokens")); + return true; +} + +Status ModularFrameEncoder::EncodeGlobalInfo(bool streaming_mode, + BitWriter* writer, + AuxOut* aux_out) { + BitWriter::Allotment allotment(writer, 1); + // If we are using brotli, or not using modular mode. + if (tree_tokens_.empty() || tree_tokens_[0].empty()) { + writer->Write(1, 0); + allotment.ReclaimAndCharge(writer, kLayerModularTree, aux_out); + return true; + } + writer->Write(1, 1); + allotment.ReclaimAndCharge(writer, kLayerModularTree, aux_out); + + // Write tree + HistogramParams params; + if (cparams_.speed_tier > SpeedTier::kKitten) { + params.clustering = HistogramParams::ClusteringType::kFast; + params.ans_histogram_strategy = + cparams_.speed_tier > SpeedTier::kThunder + ? HistogramParams::ANSHistogramStrategy::kFast + : HistogramParams::ANSHistogramStrategy::kApproximate; + params.lz77_method = + cparams_.decoding_speed_tier >= 3 && cparams_.modular_mode + ? (cparams_.speed_tier >= SpeedTier::kFalcon + ? HistogramParams::LZ77Method::kRLE + : HistogramParams::LZ77Method::kLZ77) + : HistogramParams::LZ77Method::kNone; + // Near-lossless DC, as well as modular mode, require choosing hybrid uint + // more carefully. + if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) || + (cparams_.modular_mode && cparams_.speed_tier < SpeedTier::kCheetah)) { + params.uint_method = HistogramParams::HybridUintMethod::kFast; + } else { + params.uint_method = HistogramParams::HybridUintMethod::kNone; + } + } else if (cparams_.speed_tier <= SpeedTier::kTortoise) { + params.lz77_method = HistogramParams::LZ77Method::kOptimal; + } else { + params.lz77_method = HistogramParams::LZ77Method::kLZ77; + } + if (cparams_.decoding_speed_tier >= 1) { + params.max_histograms = 12; + } + if (cparams_.decoding_speed_tier >= 1 && cparams_.responsive) { + params.lz77_method = cparams_.speed_tier >= SpeedTier::kCheetah + ? HistogramParams::LZ77Method::kRLE + : cparams_.speed_tier >= SpeedTier::kKitten + ? HistogramParams::LZ77Method::kLZ77 + : HistogramParams::LZ77Method::kOptimal; + } + if (cparams_.decoding_speed_tier >= 2 && cparams_.responsive) { + params.uint_method = HistogramParams::HybridUintMethod::k000; + params.force_huffman = true; + } + { + EntropyEncodingData tree_code; + std::vector<uint8_t> tree_context_map; + BuildAndEncodeHistograms(params, kNumTreeContexts, tree_tokens_, &tree_code, + &tree_context_map, writer, kLayerModularTree, + aux_out); + WriteTokens(tree_tokens_[0], tree_code, tree_context_map, 0, writer, + kLayerModularTree, aux_out); + } + params.streaming_mode = streaming_mode; + params.add_missing_symbols = streaming_mode; + params.image_widths = image_widths_; + // Write histograms. + BuildAndEncodeHistograms(params, (tree_.size() + 1) / 2, tokens_, &code_, + &context_map_, writer, kLayerModularGlobal, aux_out); + return true; +} + +Status ModularFrameEncoder::EncodeStream(BitWriter* writer, AuxOut* aux_out, + size_t layer, + const ModularStreamId& stream) { + size_t stream_id = stream.ID(frame_dim_); + if (stream_images_[stream_id].channel.empty()) { + return true; // Image with no channels, header never gets decoded. + } + if (tokens_.empty()) { + JXL_RETURN_IF_ERROR(ModularGenericCompress( + stream_images_[stream_id], stream_options_[stream_id], writer, aux_out, + layer, stream_id)); + } else { + JXL_RETURN_IF_ERROR( + Bundle::Write(stream_headers_[stream_id], writer, layer, aux_out)); + WriteTokens(tokens_[stream_id], code_, context_map_, 0, writer, layer, + aux_out); + } + return true; +} + +void ModularFrameEncoder::ClearStreamData(const ModularStreamId& stream) { + size_t stream_id = stream.ID(frame_dim_); + Image empty_image; + std::swap(stream_images_[stream_id], empty_image); +} + +namespace { +float EstimateWPCost(const Image& img, size_t i) { + size_t extra_bits = 0; + float histo_cost = 0; + HybridUintConfig config; + int32_t cutoffs[] = {-500, -392, -255, -191, -127, -95, -63, -47, -31, + -23, -15, -11, -7, -4, -3, -1, 0, 1, + 3, 5, 7, 11, 15, 23, 31, 47, 63, + 95, 127, 191, 255, 392, 500}; + constexpr size_t nc = sizeof(cutoffs) / sizeof(*cutoffs) + 1; + Histogram histo[nc] = {}; + weighted::Header wp_header; + PredictorMode(i, &wp_header); + for (const Channel& ch : img.channel) { + const intptr_t onerow = ch.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, ch.w, ch.h); + Properties properties(1); + for (size_t y = 0; y < ch.h; y++) { + const pixel_type* JXL_RESTRICT r = ch.Row(y); + for (size_t x = 0; x < ch.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < ch.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + pixel_type guess = wp_state.Predict</*compute_properties=*/true>( + x, y, ch.w, top, left, topright, topleft, toptop, &properties, + offset); + size_t ctx = 0; + for (int c : cutoffs) { + ctx += c >= properties[0]; + } + pixel_type res = r[x] - guess; + uint32_t token, nbits, bits; + config.Encode(PackSigned(res), &token, &nbits, &bits); + histo[ctx].Add(token); + extra_bits += nbits; + wp_state.UpdateErrors(r[x], x, y, ch.w); + } + } + for (size_t h = 0; h < nc; h++) { + histo_cost += histo[h].ShannonEntropy(); + histo[h].Clear(); + } + } + return histo_cost + extra_bits; +} + +float EstimateCost(const Image& img) { + // TODO(veluca): consider SIMDfication of this code. + size_t extra_bits = 0; + float histo_cost = 0; + HybridUintConfig config; + uint32_t cutoffs[] = {0, 1, 3, 5, 7, 11, 15, 23, 31, + 47, 63, 95, 127, 191, 255, 392, 500}; + constexpr size_t nc = sizeof(cutoffs) / sizeof(*cutoffs) + 1; + Histogram histo[nc] = {}; + for (const Channel& ch : img.channel) { + const intptr_t onerow = ch.plane.PixelsPerRow(); + for (size_t y = 0; y < ch.h; y++) { + const pixel_type* JXL_RESTRICT r = ch.Row(y); + for (size_t x = 0; x < ch.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + size_t maxdiff = std::max(std::max(left, top), topleft) - + std::min(std::min(left, top), topleft); + size_t ctx = 0; + for (uint32_t c : cutoffs) { + ctx += c > maxdiff; + } + pixel_type res = r[x] - ClampedGradient(top, left, topleft); + uint32_t token, nbits, bits; + config.Encode(PackSigned(res), &token, &nbits, &bits); + histo[ctx].Add(token); + extra_bits += nbits; + } + } + for (size_t h = 0; h < nc; h++) { + histo_cost += histo[h].ShannonEntropy(); + histo[h].Clear(); + } + } + return histo_cost + extra_bits; +} + +} // namespace + +Status ModularFrameEncoder::PrepareStreamParams(const Rect& rect, + const CompressParams& cparams_, + int minShift, int maxShift, + const ModularStreamId& stream, + bool do_color) { + size_t stream_id = stream.ID(frame_dim_); + Image& full_image = stream_images_[0]; + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + Image& gi = stream_images_[stream_id]; + if (stream_id > 0) { + gi = Image(xsize, ysize, full_image.bitdepth, 0); + // start at the first bigger-than-frame_dim.group_dim non-metachannel + size_t c = full_image.nb_meta_channels; + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + if (fc.w > frame_dim_.group_dim || fc.h > frame_dim_.group_dim) break; + } + for (; c < full_image.channel.size(); c++) { + Channel& fc = full_image.channel[c]; + int shift = std::min(fc.hshift, fc.vshift); + if (shift > maxShift) continue; + if (shift < minShift) continue; + Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift, + rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h); + if (r.xsize() == 0 || r.ysize() == 0) continue; + gi_channel_[stream_id].push_back(c); + Channel gc(r.xsize(), r.ysize()); + gc.hshift = fc.hshift; + gc.vshift = fc.vshift; + for (size_t y = 0; y < r.ysize(); ++y) { + memcpy(gc.Row(y), r.ConstRow(fc.plane, y), + r.xsize() * sizeof(pixel_type)); + } + gi.channel.emplace_back(std::move(gc)); + } + + if (gi.channel.empty()) return true; + // Do some per-group transforms + + // Local palette + // TODO(veluca): make this work with quantize-after-prediction in lossy + // mode. + if (cparams_.butteraugli_distance == 0.f && cparams_.palette_colors != 0 && + cparams_.speed_tier < SpeedTier::kCheetah) { + // all-channel palette (e.g. RGBA) + if (gi.channel.size() - gi.nb_meta_channels > 1) { + Transform maybe_palette(TransformId::kPalette); + maybe_palette.begin_c = gi.nb_meta_channels; + maybe_palette.num_c = gi.channel.size() - gi.nb_meta_channels; + maybe_palette.nb_colors = std::abs(cparams_.palette_colors); + maybe_palette.ordered_palette = cparams_.palette_colors >= 0; + do_transform(gi, maybe_palette, weighted::Header()); + } + // all-minus-one-channel palette (RGB with separate alpha, or CMY with + // separate K) + if (gi.channel.size() - gi.nb_meta_channels > 3) { + Transform maybe_palette_3(TransformId::kPalette); + maybe_palette_3.begin_c = gi.nb_meta_channels; + maybe_palette_3.num_c = gi.channel.size() - gi.nb_meta_channels - 1; + maybe_palette_3.nb_colors = std::abs(cparams_.palette_colors); + maybe_palette_3.ordered_palette = cparams_.palette_colors >= 0; + maybe_palette_3.lossy_palette = cparams_.lossy_palette; + if (maybe_palette_3.lossy_palette) { + maybe_palette_3.predictor = Predictor::Weighted; + } + do_transform(gi, maybe_palette_3, weighted::Header()); + } + } + + // Local channel palette + if (cparams_.channel_colors_percent > 0 && + cparams_.butteraugli_distance == 0.f && !cparams_.lossy_palette && + cparams_.speed_tier < SpeedTier::kCheetah && + !(cparams_.responsive && cparams_.decoding_speed_tier >= 1)) { + // single channel palette (like FLIF's ChannelCompact) + size_t nb_channels = gi.channel.size() - gi.nb_meta_channels; + for (size_t i = 0; i < nb_channels; i++) { + int32_t min, max; + compute_minmax(gi.channel[gi.nb_meta_channels + i], &min, &max); + int64_t colors = (int64_t)max - min + 1; + JXL_DEBUG_V(10, "Channel %" PRIuS ": range=%i..%i", i, min, max); + Transform maybe_palette_1(TransformId::kPalette); + maybe_palette_1.begin_c = i + gi.nb_meta_channels; + maybe_palette_1.num_c = 1; + // simple heuristic: if less than X percent of the values in the range + // actually occur, it is probably worth it to do a compaction + // (but only if the channel palette is less than 80% the size of the + // image itself) + maybe_palette_1.nb_colors = + std::min((int)(xsize * ysize * 0.8), + (int)(cparams_.channel_colors_percent / 100. * colors)); + do_transform(gi, maybe_palette_1, weighted::Header()); + } + } + } + + // lossless and no specific color transform specified: try Nothing, YCoCg, + // and 17 RCTs + if (cparams_.color_transform == ColorTransform::kNone && + cparams_.IsLossless() && cparams_.colorspace < 0 && + gi.channel.size() - gi.nb_meta_channels >= 3 && + cparams_.responsive == false && do_color && + cparams_.speed_tier <= SpeedTier::kHare) { + Transform sg(TransformId::kRCT); + sg.begin_c = gi.nb_meta_channels; + size_t nb_rcts_to_try = 0; + switch (cparams_.speed_tier) { + case SpeedTier::kLightning: + case SpeedTier::kThunder: + case SpeedTier::kFalcon: + case SpeedTier::kCheetah: + nb_rcts_to_try = 0; // Just do global YCoCg + break; + case SpeedTier::kHare: + nb_rcts_to_try = 4; + break; + case SpeedTier::kWombat: + nb_rcts_to_try = 5; + break; + case SpeedTier::kSquirrel: + nb_rcts_to_try = 7; + break; + case SpeedTier::kKitten: + nb_rcts_to_try = 9; + break; + case SpeedTier::kGlacier: + case SpeedTier::kTortoise: + nb_rcts_to_try = 19; + break; + } + float best_cost = std::numeric_limits<float>::max(); + size_t best_rct = 0; + // These should be 19 actually different transforms; the remaining ones + // are equivalent to one of these (note that the first two are do-nothing + // and YCoCg) modulo channel reordering (which only matters in the case of + // MA-with-prev-channels-properties) and/or sign (e.g. RmG vs GmR) + for (int i : {0 * 7 + 0, 0 * 7 + 6, 0 * 7 + 5, 1 * 7 + 3, 3 * 7 + 5, + 5 * 7 + 5, 1 * 7 + 5, 2 * 7 + 5, 1 * 7 + 1, 0 * 7 + 4, + 1 * 7 + 2, 2 * 7 + 1, 2 * 7 + 2, 2 * 7 + 3, 4 * 7 + 4, + 4 * 7 + 5, 0 * 7 + 2, 0 * 7 + 1, 0 * 7 + 3}) { + if (nb_rcts_to_try == 0) break; + sg.rct_type = i; + nb_rcts_to_try--; + if (do_transform(gi, sg, weighted::Header())) { + float cost = EstimateCost(gi); + if (cost < best_cost) { + best_rct = i; + best_cost = cost; + } + Transform t = gi.transform.back(); + JXL_RETURN_IF_ERROR(t.Inverse(gi, weighted::Header(), nullptr)); + gi.transform.pop_back(); + } + } + // Apply the best RCT to the image for future encoding. + sg.rct_type = best_rct; + do_transform(gi, sg, weighted::Header()); + } else { + // No need to try anything, just use the default options. + } + size_t nb_wp_modes = 1; + if (cparams_.speed_tier <= SpeedTier::kTortoise) { + nb_wp_modes = 5; + } else if (cparams_.speed_tier <= SpeedTier::kKitten) { + nb_wp_modes = 2; + } + if (nb_wp_modes > 1 && + (stream_options_[stream_id].predictor == Predictor::Weighted || + stream_options_[stream_id].predictor == Predictor::Best || + stream_options_[stream_id].predictor == Predictor::Variable)) { + float best_cost = std::numeric_limits<float>::max(); + stream_options_[stream_id].wp_mode = 0; + for (size_t i = 0; i < nb_wp_modes; i++) { + float cost = EstimateWPCost(gi, i); + if (cost < best_cost) { + best_cost = cost; + stream_options_[stream_id].wp_mode = i; + } + } + } + return true; +} + +constexpr float q_deadzone = 0.62f; +int QuantizeWP(const int32_t* qrow, size_t onerow, size_t c, size_t x, size_t y, + size_t w, weighted::State* wp_state, float value, + float inv_factor) { + float svalue = value * inv_factor; + PredictionResult pred = + PredictNoTreeWP(w, qrow + x, onerow, x, y, Predictor::Weighted, wp_state); + svalue -= pred.guess; + if (svalue > -q_deadzone && svalue < q_deadzone) svalue = 0; + int residual = roundf(svalue); + if (residual > 2 || residual < -2) residual = roundf(svalue * 0.5) * 2; + return residual + pred.guess; +} + +int QuantizeGradient(const int32_t* qrow, size_t onerow, size_t c, size_t x, + size_t y, size_t w, float value, float inv_factor) { + float svalue = value * inv_factor; + PredictionResult pred = + PredictNoTreeNoWP(w, qrow + x, onerow, x, y, Predictor::Gradient); + svalue -= pred.guess; + if (svalue > -q_deadzone && svalue < q_deadzone) svalue = 0; + int residual = roundf(svalue); + if (residual > 2 || residual < -2) residual = roundf(svalue * 0.5) * 2; + return residual + pred.guess; +} + +void ModularFrameEncoder::AddVarDCTDC(const FrameHeader& frame_header, + const Image3F& dc, const Rect& r, + size_t group_index, bool nl_dc, + PassesEncoderState* enc_state, + bool jpeg_transcode) { + extra_dc_precision[group_index] = nl_dc ? 1 : 0; + float mul = 1 << extra_dc_precision[group_index]; + + size_t stream_id = ModularStreamId::VarDCTDC(group_index).ID(frame_dim_); + stream_options_[stream_id].max_chan_size = 0xFFFFFF; + stream_options_[stream_id].predictor = Predictor::Weighted; + stream_options_[stream_id].wp_tree_mode = ModularOptions::TreeMode::kWPOnly; + if (cparams_.speed_tier >= SpeedTier::kSquirrel) { + stream_options_[stream_id].tree_kind = ModularOptions::TreeKind::kWPFixedDC; + } + if (cparams_.speed_tier < SpeedTier::kSquirrel && !nl_dc) { + stream_options_[stream_id].predictor = + (cparams_.speed_tier < SpeedTier::kKitten ? Predictor::Variable + : Predictor::Best); + stream_options_[stream_id].wp_tree_mode = + ModularOptions::TreeMode::kDefault; + stream_options_[stream_id].tree_kind = ModularOptions::TreeKind::kLearn; + } + if (cparams_.decoding_speed_tier >= 1) { + stream_options_[stream_id].tree_kind = + ModularOptions::TreeKind::kGradientFixedDC; + } + + stream_images_[stream_id] = Image(r.xsize(), r.ysize(), 8, 3); + if (nl_dc && stream_options_[stream_id].tree_kind == + ModularOptions::TreeKind::kGradientFixedDC) { + JXL_ASSERT(frame_header.chroma_subsampling.Is444()); + for (size_t c : {1, 0, 2}) { + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul; + float cfl_factor = enc_state->shared.cmap.DCFactors()[c]; + for (size_t y = 0; y < r.ysize(); y++) { + int32_t* quant_row = + stream_images_[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y); + size_t stride = stream_images_[stream_id] + .channel[c < 2 ? c ^ 1 : c] + .plane.PixelsPerRow(); + const float* row = r.ConstPlaneRow(dc, c, y); + if (c == 1) { + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeGradient(quant_row, stride, c, x, y, + r.xsize(), row[x], inv_factor); + } + } else { + int32_t* quant_row_y = + stream_images_[stream_id].channel[0].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeGradient( + quant_row, stride, c, x, y, r.xsize(), + row[x] - quant_row_y[x] * (y_factor * cfl_factor), inv_factor); + } + } + } + } + } else if (nl_dc) { + JXL_ASSERT(frame_header.chroma_subsampling.Is444()); + for (size_t c : {1, 0, 2}) { + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul; + float cfl_factor = enc_state->shared.cmap.DCFactors()[c]; + weighted::Header header; + weighted::State wp_state(header, r.xsize(), r.ysize()); + for (size_t y = 0; y < r.ysize(); y++) { + int32_t* quant_row = + stream_images_[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y); + size_t stride = stream_images_[stream_id] + .channel[c < 2 ? c ^ 1 : c] + .plane.PixelsPerRow(); + const float* row = r.ConstPlaneRow(dc, c, y); + if (c == 1) { + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeWP(quant_row, stride, c, x, y, r.xsize(), + &wp_state, row[x], inv_factor); + wp_state.UpdateErrors(quant_row[x], x, y, r.xsize()); + } + } else { + int32_t* quant_row_y = + stream_images_[stream_id].channel[0].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = QuantizeWP( + quant_row, stride, c, x, y, r.xsize(), &wp_state, + row[x] - quant_row_y[x] * (y_factor * cfl_factor), inv_factor); + wp_state.UpdateErrors(quant_row[x], x, y, r.xsize()); + } + } + } + } + } else if (frame_header.chroma_subsampling.Is444()) { + for (size_t c : {1, 0, 2}) { + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul; + float cfl_factor = enc_state->shared.cmap.DCFactors()[c]; + for (size_t y = 0; y < r.ysize(); y++) { + int32_t* quant_row = + stream_images_[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y); + const float* row = r.ConstPlaneRow(dc, c, y); + if (c == 1) { + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = roundf(row[x] * inv_factor); + } + } else { + int32_t* quant_row_y = + stream_images_[stream_id].channel[0].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + quant_row[x] = + roundf((row[x] - quant_row_y[x] * (y_factor * cfl_factor)) * + inv_factor); + } + } + } + } + } else { + for (size_t c : {1, 0, 2}) { + Rect rect(r.x0() >> frame_header.chroma_subsampling.HShift(c), + r.y0() >> frame_header.chroma_subsampling.VShift(c), + r.xsize() >> frame_header.chroma_subsampling.HShift(c), + r.ysize() >> frame_header.chroma_subsampling.VShift(c)); + float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul; + size_t ys = rect.ysize(); + size_t xs = rect.xsize(); + Channel& ch = stream_images_[stream_id].channel[c < 2 ? c ^ 1 : c]; + ch.w = xs; + ch.h = ys; + ch.shrink(); + for (size_t y = 0; y < ys; y++) { + int32_t* quant_row = ch.plane.Row(y); + const float* row = rect.ConstPlaneRow(dc, c, y); + for (size_t x = 0; x < xs; x++) { + quant_row[x] = roundf(row[x] * inv_factor); + } + } + } + } + + DequantDC(r, &enc_state->shared.dc_storage, &enc_state->shared.quant_dc, + stream_images_[stream_id], enc_state->shared.quantizer.MulDC(), + 1.0 / mul, enc_state->shared.cmap.DCFactors(), + frame_header.chroma_subsampling, enc_state->shared.block_ctx_map); +} + +void ModularFrameEncoder::AddACMetadata(const Rect& r, size_t group_index, + bool jpeg_transcode, + PassesEncoderState* enc_state) { + size_t stream_id = ModularStreamId::ACMetadata(group_index).ID(frame_dim_); + stream_options_[stream_id].max_chan_size = 0xFFFFFF; + stream_options_[stream_id].wp_tree_mode = ModularOptions::TreeMode::kNoWP; + if (jpeg_transcode) { + stream_options_[stream_id].tree_kind = + ModularOptions::TreeKind::kJpegTranscodeACMeta; + } else if (cparams_.speed_tier >= SpeedTier::kFalcon) { + stream_options_[stream_id].tree_kind = + ModularOptions::TreeKind::kFalconACMeta; + } else if (cparams_.speed_tier > SpeedTier::kKitten) { + stream_options_[stream_id].tree_kind = ModularOptions::TreeKind::kACMeta; + } + // If we are using a non-constant CfL field, and are in a slow enough mode, + // re-enable tree computation for it. + if (cparams_.speed_tier < SpeedTier::kSquirrel && + cparams_.force_cfl_jpeg_recompression) { + stream_options_[stream_id].tree_kind = ModularOptions::TreeKind::kLearn; + } + // YToX, YToB, ACS + QF, EPF + Image& image = stream_images_[stream_id]; + image = Image(r.xsize(), r.ysize(), 8, 4); + static_assert(kColorTileDimInBlocks == 8, "Color tile size changed"); + Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3); + image.channel[0] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[1] = Channel(cr.xsize(), cr.ysize(), 3, 3); + image.channel[2] = Channel(r.xsize() * r.ysize(), 2, 0, 0); + ConvertPlaneAndClamp(cr, enc_state->shared.cmap.ytox_map, + Rect(image.channel[0].plane), &image.channel[0].plane); + ConvertPlaneAndClamp(cr, enc_state->shared.cmap.ytob_map, + Rect(image.channel[1].plane), &image.channel[1].plane); + size_t num = 0; + for (size_t y = 0; y < r.ysize(); y++) { + AcStrategyRow row_acs = enc_state->shared.ac_strategy.ConstRow(r, y); + const int32_t* row_qf = r.ConstRow(enc_state->shared.raw_quant_field, y); + const uint8_t* row_epf = r.ConstRow(enc_state->shared.epf_sharpness, y); + int32_t* out_acs = image.channel[2].plane.Row(0); + int32_t* out_qf = image.channel[2].plane.Row(1); + int32_t* row_out_epf = image.channel[3].plane.Row(y); + for (size_t x = 0; x < r.xsize(); x++) { + row_out_epf[x] = row_epf[x]; + if (!row_acs[x].IsFirstBlock()) continue; + out_acs[num] = row_acs[x].RawStrategy(); + out_qf[num] = row_qf[x] - 1; + num++; + } + } + image.channel[2].w = num; + ac_metadata_size[group_index] = num; +} + +void ModularFrameEncoder::EncodeQuantTable( + size_t size_x, size_t size_y, BitWriter* writer, + const QuantEncoding& encoding, size_t idx, + ModularFrameEncoder* modular_frame_encoder) { + JXL_ASSERT(encoding.qraw.qtable != nullptr); + JXL_ASSERT(size_x * size_y * 3 == encoding.qraw.qtable->size()); + JXL_CHECK(F16Coder::Write(encoding.qraw.qtable_den, writer)); + if (modular_frame_encoder) { + JXL_CHECK(modular_frame_encoder->EncodeStream( + writer, nullptr, 0, ModularStreamId::QuantTable(idx))); + return; + } + Image image(size_x, size_y, 8, 3); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < size_y; y++) { + int32_t* JXL_RESTRICT row = image.channel[c].Row(y); + for (size_t x = 0; x < size_x; x++) { + row[x] = (*encoding.qraw.qtable)[c * size_x * size_y + y * size_x + x]; + } + } + } + ModularOptions cfopts; + JXL_CHECK(ModularGenericCompress(image, cfopts, writer)); +} + +void ModularFrameEncoder::AddQuantTable(size_t size_x, size_t size_y, + const QuantEncoding& encoding, + size_t idx) { + size_t stream_id = ModularStreamId::QuantTable(idx).ID(frame_dim_); + JXL_ASSERT(encoding.qraw.qtable != nullptr); + JXL_ASSERT(size_x * size_y * 3 == encoding.qraw.qtable->size()); + Image& image = stream_images_[stream_id]; + image = Image(size_x, size_y, 8, 3); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < size_y; y++) { + int32_t* JXL_RESTRICT row = image.channel[c].Row(y); + for (size_t x = 0; x < size_x; x++) { + row[x] = (*encoding.qraw.qtable)[c * size_x * size_y + y * size_x + x]; + } + } + } +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_modular.h b/third_party/jpeg-xl/lib/jxl/enc_modular.h new file mode 100644 index 0000000000..2158a781af --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_modular.h @@ -0,0 +1,96 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_MODULAR_H_ +#define LIB_JXL_ENC_MODULAR_H_ + +#include <cstdint> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +struct AuxOut; + +class ModularFrameEncoder { + public: + ModularFrameEncoder(const FrameHeader& frame_header, + const CompressParams& cparams_orig); + Status ComputeEncodingData(const FrameHeader& frame_header, + const ImageMetadata& metadata, + Image3F* JXL_RESTRICT color, + const std::vector<ImageF>& extra_channels, + PassesEncoderState* JXL_RESTRICT enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, bool do_color); + Status ComputeTree(ThreadPool* pool); + Status ComputeTokens(ThreadPool* pool); + // Encodes global info (tree + histograms) in the `writer`. + Status EncodeGlobalInfo(bool streaming_mode, BitWriter* writer, + AuxOut* aux_out); + // Encodes a specific modular image (identified by `stream`) in the `writer`, + // assigning bits to the provided `layer`. + Status EncodeStream(BitWriter* writer, AuxOut* aux_out, size_t layer, + const ModularStreamId& stream); + void ClearStreamData(const ModularStreamId& stream); + // Creates a modular image for a given DC group of VarDCT mode. `dc` is the + // input DC image, not quantized; the group is specified by `group_index`, and + // `nl_dc` decides whether to apply a near-lossless processing to the DC or + // not. + void AddVarDCTDC(const FrameHeader& frame_header, const Image3F& dc, + const Rect& r, size_t group_index, bool nl_dc, + PassesEncoderState* enc_state, bool jpeg_transcode); + // Creates a modular image for the AC metadata of the given group + // (`group_index`). + void AddACMetadata(const Rect& r, size_t group_index, bool jpeg_transcode, + PassesEncoderState* enc_state); + // Encodes a RAW quantization table in `writer`. If `modular_frame_encoder` is + // null, the quantization table in `encoding` is used, with dimensions `size_x + // x size_y`. Otherwise, the table with ID `idx` is encoded from the given + // `modular_frame_encoder`. + static void EncodeQuantTable(size_t size_x, size_t size_y, BitWriter* writer, + const QuantEncoding& encoding, size_t idx, + ModularFrameEncoder* modular_frame_encoder); + // Stores a quantization table for future usage with `EncodeQuantTable`. + void AddQuantTable(size_t size_x, size_t size_y, + const QuantEncoding& encoding, size_t idx); + + std::vector<size_t> ac_metadata_size; + std::vector<uint8_t> extra_dc_precision; + + private: + Status PrepareStreamParams(const Rect& rect, const CompressParams& cparams, + int minShift, int maxShift, + const ModularStreamId& stream, bool do_color); + std::vector<Image> stream_images_; + std::vector<ModularOptions> stream_options_; + std::vector<uint32_t> quants_; + + Tree tree_; + std::vector<std::vector<Token>> tree_tokens_; + std::vector<GroupHeader> stream_headers_; + std::vector<std::vector<Token>> tokens_; + EntropyEncodingData code_; + std::vector<uint8_t> context_map_; + FrameDimensions frame_dim_; + CompressParams cparams_; + std::vector<size_t> tree_splits_; + std::vector<std::vector<uint32_t>> gi_channel_; + std::vector<size_t> image_widths_; + Predictor delta_pred_ = Predictor::Average4; +}; + +} // namespace jxl + +#endif // LIB_JXL_ENC_MODULAR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_noise.cc b/third_party/jpeg-xl/lib/jxl/enc_noise.cc new file mode 100644 index 0000000000..a12a9e6dc4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_noise.cc @@ -0,0 +1,372 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_noise.h" + +#include <stdint.h> +#include <stdlib.h> + +#include <algorithm> +#include <numeric> +#include <utility> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_optimize.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { +namespace { + +using OptimizeArray = optimize::Array<double, NoiseParams::kNumNoisePoints>; + +float GetScoreSumsOfAbsoluteDifferences(const Image3F& opsin, const int x, + const int y, const int block_size) { + const int small_bl_size_x = 3; + const int small_bl_size_y = 4; + const int kNumSAD = + (block_size - small_bl_size_x) * (block_size - small_bl_size_y); + // block_size x block_size reference pixels + int counter = 0; + const int offset = 2; + + std::vector<float> sad(kNumSAD, 0); + for (int y_bl = 0; y_bl + small_bl_size_y < block_size; ++y_bl) { + for (int x_bl = 0; x_bl + small_bl_size_x < block_size; ++x_bl) { + float sad_sum = 0; + // size of the center patch, we compare all the patches inside window with + // the center one + for (int cy = 0; cy < small_bl_size_y; ++cy) { + for (int cx = 0; cx < small_bl_size_x; ++cx) { + float wnd = 0.5f * (opsin.PlaneRow(1, y + y_bl + cy)[x + x_bl + cx] + + opsin.PlaneRow(0, y + y_bl + cy)[x + x_bl + cx]); + float center = + 0.5f * (opsin.PlaneRow(1, y + offset + cy)[x + offset + cx] + + opsin.PlaneRow(0, y + offset + cy)[x + offset + cx]); + sad_sum += std::abs(center - wnd); + } + } + sad[counter++] = sad_sum; + } + } + const int kSamples = (kNumSAD) / 2; + // As with ROAD (rank order absolute distance), we keep the smallest half of + // the values in SAD (we use here the more robust patch SAD instead of + // absolute single-pixel differences). + std::sort(sad.begin(), sad.end()); + const float total_sad_sum = + std::accumulate(sad.begin(), sad.begin() + kSamples, 0.0f); + return total_sad_sum / kSamples; +} + +class NoiseHistogram { + public: + static constexpr int kBins = 256; + + NoiseHistogram() { std::fill(bins, bins + kBins, 0); } + + void Increment(const float x) { bins[Index(x)] += 1; } + int Get(const float x) const { return bins[Index(x)]; } + int Bin(const size_t bin) const { return bins[bin]; } + + int Mode() const { + size_t max_idx = 0; + for (size_t i = 0; i < kBins; i++) { + if (bins[i] > bins[max_idx]) max_idx = i; + } + return max_idx; + } + + double Quantile(double q01) const { + const int64_t total = std::accumulate(bins, bins + kBins, int64_t{1}); + const int64_t target = static_cast<int64_t>(q01 * total); + // Until sum >= target: + int64_t sum = 0; + size_t i = 0; + for (; i < kBins; ++i) { + sum += bins[i]; + // Exact match: assume middle of bin i + if (sum == target) { + return i + 0.5; + } + if (sum > target) break; + } + + // Next non-empty bin (in case histogram is sparsely filled) + size_t next = i + 1; + while (next < kBins && bins[next] == 0) { + ++next; + } + + // Linear interpolation according to how far into next we went + const double excess = target - sum; + const double weight_next = bins[Index(next)] / excess; + return ClampX(next * weight_next + i * (1.0 - weight_next)); + } + + // Inter-quartile range + double IQR() const { return Quantile(0.75) - Quantile(0.25); } + + private: + template <typename T> + T ClampX(const T x) const { + return std::min(std::max(T(0), x), T(kBins - 1)); + } + size_t Index(const float x) const { return ClampX(static_cast<int>(x)); } + + uint32_t bins[kBins]; +}; + +std::vector<float> GetSADScoresForPatches(const Image3F& opsin, + const size_t block_s, + const size_t num_bin, + NoiseHistogram* sad_histogram) { + std::vector<float> sad_scores( + (opsin.ysize() / block_s) * (opsin.xsize() / block_s), 0.0f); + + int block_index = 0; + + for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) { + for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) { + float sad_sc = GetScoreSumsOfAbsoluteDifferences(opsin, x, y, block_s); + sad_scores[block_index++] = sad_sc; + sad_histogram->Increment(sad_sc * num_bin); + } + } + return sad_scores; +} + +float GetSADThreshold(const NoiseHistogram& histogram, const int num_bin) { + // Here we assume that the most patches with similar SAD value is a "flat" + // patches. However, some images might contain regular texture part and + // generate second strong peak at the histogram + // TODO(user) handle bimodal and heavy-tailed case + const int mode = histogram.Mode(); + return static_cast<float>(mode) / NoiseHistogram::kBins; +} + +// loss = sum asym * (F(x) - nl)^2 + kReg * num_points * sum (w[i] - w[i+1])^2 +// where asym = 1 if F(x) < nl, kAsym if F(x) > nl. +struct LossFunction { + explicit LossFunction(std::vector<NoiseLevel> nl0) : nl(std::move(nl0)) {} + + double Compute(const OptimizeArray& w, OptimizeArray* df, + bool skip_regularization = false) const { + constexpr double kReg = 0.005; + constexpr double kAsym = 1.1; + double loss_function = 0; + for (size_t i = 0; i < w.size(); i++) { + (*df)[i] = 0; + } + for (auto ind : nl) { + std::pair<int, float> pos = IndexAndFrac(ind.intensity); + JXL_DASSERT(pos.first >= 0 && static_cast<size_t>(pos.first) < + NoiseParams::kNumNoisePoints - 1); + double low = w[pos.first]; + double hi = w[pos.first + 1]; + double val = low * (1.0f - pos.second) + hi * pos.second; + double dist = val - ind.noise_level; + if (dist > 0) { + loss_function += kAsym * dist * dist; + (*df)[pos.first] -= kAsym * (1.0f - pos.second) * dist; + (*df)[pos.first + 1] -= kAsym * pos.second * dist; + } else { + loss_function += dist * dist; + (*df)[pos.first] -= (1.0f - pos.second) * dist; + (*df)[pos.first + 1] -= pos.second * dist; + } + } + if (skip_regularization) return loss_function; + for (size_t i = 0; i + 1 < w.size(); i++) { + double diff = w[i] - w[i + 1]; + loss_function += kReg * nl.size() * diff * diff; + (*df)[i] -= kReg * diff * nl.size(); + (*df)[i + 1] += kReg * diff * nl.size(); + } + return loss_function; + } + + std::vector<NoiseLevel> nl; +}; + +void OptimizeNoiseParameters(const std::vector<NoiseLevel>& noise_level, + NoiseParams* noise_params) { + constexpr double kMaxError = 1e-3; + static const double kPrecision = 1e-8; + static const int kMaxIter = 40; + + float avg = 0; + for (const NoiseLevel& nl : noise_level) { + avg += nl.noise_level; + } + avg /= noise_level.size(); + + LossFunction loss_function(noise_level); + OptimizeArray parameter_vector; + for (size_t i = 0; i < parameter_vector.size(); i++) { + parameter_vector[i] = avg; + } + + parameter_vector = optimize::OptimizeWithScaledConjugateGradientMethod( + loss_function, parameter_vector, kPrecision, kMaxIter); + + OptimizeArray df = parameter_vector; + float loss = loss_function.Compute(parameter_vector, &df, + /*skip_regularization=*/true) / + noise_level.size(); + + // Approximation went too badly: escape with no noise at all. + if (loss > kMaxError) { + noise_params->Clear(); + return; + } + + for (size_t i = 0; i < parameter_vector.size(); i++) { + noise_params->lut[i] = std::max(parameter_vector[i], 0.0); + } +} + +std::vector<NoiseLevel> GetNoiseLevel( + const Image3F& opsin, const std::vector<float>& texture_strength, + const float threshold, const size_t block_s) { + std::vector<NoiseLevel> noise_level_per_intensity; + + const int filt_size = 1; + static const float kLaplFilter[filt_size * 2 + 1][filt_size * 2 + 1] = { + {-0.25f, -1.0f, -0.25f}, + {-1.0f, 5.0f, -1.0f}, + {-0.25f, -1.0f, -0.25f}, + }; + + // The noise model is built based on channel 0.5 * (X+Y) as we notice that it + // is similar to the model 0.5 * (Y-X) + size_t patch_index = 0; + + for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) { + for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) { + if (texture_strength[patch_index] <= threshold) { + // Calculate mean value + float mean_int = 0; + for (size_t y_bl = 0; y_bl < block_s; ++y_bl) { + for (size_t x_bl = 0; x_bl < block_s; ++x_bl) { + mean_int += 0.5f * (opsin.PlaneRow(1, y + y_bl)[x + x_bl] + + opsin.PlaneRow(0, y + y_bl)[x + x_bl]); + } + } + mean_int /= block_s * block_s; + + // Calculate Noise level + float noise_level = 0; + size_t count = 0; + for (size_t y_bl = 0; y_bl < block_s; ++y_bl) { + for (size_t x_bl = 0; x_bl < block_s; ++x_bl) { + float filtered_value = 0; + for (int y_f = -1 * filt_size; y_f <= filt_size; ++y_f) { + if ((static_cast<ssize_t>(y_bl) + y_f) >= 0 && + (y_bl + y_f) < block_s) { + for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) { + if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 && + (x_bl + x_f) < block_s) { + filtered_value += + 0.5f * + (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl + x_f] + + opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl + x_f]) * + kLaplFilter[y_f + filt_size][x_f + filt_size]; + } else { + filtered_value += + 0.5f * + (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl - x_f] + + opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl - x_f]) * + kLaplFilter[y_f + filt_size][x_f + filt_size]; + } + } + } else { + for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) { + if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 && + (x_bl + x_f) < block_s) { + filtered_value += + 0.5f * + (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl + x_f] + + opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl + x_f]) * + kLaplFilter[y_f + filt_size][x_f + filt_size]; + } else { + filtered_value += + 0.5f * + (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl - x_f] + + opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl - x_f]) * + kLaplFilter[y_f + filt_size][x_f + filt_size]; + } + } + } + } + noise_level += std::abs(filtered_value); + ++count; + } + } + noise_level /= count; + NoiseLevel nl; + nl.intensity = mean_int; + nl.noise_level = noise_level; + noise_level_per_intensity.push_back(nl); + } + ++patch_index; + } + } + return noise_level_per_intensity; +} + +void EncodeFloatParam(float val, float precision, BitWriter* writer) { + JXL_ASSERT(val >= 0); + const int absval_quant = static_cast<int>(val * precision + 0.5f); + JXL_ASSERT(absval_quant < (1 << 10)); + writer->Write(10, absval_quant); +} + +} // namespace + +Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params, + float quality_coef) { + // The size of a patch in decoder might be different from encoder's patch + // size. + // For encoder: the patch size should be big enough to estimate + // noise level, but, at the same time, it should be not too big + // to be able to estimate intensity value of the patch + const size_t block_s = 8; + const size_t kNumBin = 256; + NoiseHistogram sad_histogram; + std::vector<float> sad_scores = + GetSADScoresForPatches(opsin, block_s, kNumBin, &sad_histogram); + float sad_threshold = GetSADThreshold(sad_histogram, kNumBin); + // If threshold is too large, the image has a strong pattern. This pattern + // fools our model and it will add too much noise. Therefore, we do not add + // noise for such images + if (sad_threshold > 0.15f || sad_threshold <= 0.0f) { + noise_params->Clear(); + return false; + } + std::vector<NoiseLevel> nl = + GetNoiseLevel(opsin, sad_scores, sad_threshold, block_s); + + OptimizeNoiseParameters(nl, noise_params); + for (float& i : noise_params->lut) { + i *= quality_coef * 1.4; + } + return noise_params->HasAny(); +} + +void EncodeNoise(const NoiseParams& noise_params, BitWriter* writer, + size_t layer, AuxOut* aux_out) { + JXL_ASSERT(noise_params.HasAny()); + + BitWriter::Allotment allotment(writer, NoiseParams::kNumNoisePoints * 16); + for (float i : noise_params.lut) { + EncodeFloatParam(i, kNoisePrecision, writer); + } + allotment.ReclaimAndCharge(writer, layer, aux_out); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_noise.h b/third_party/jpeg-xl/lib/jxl/enc_noise.h new file mode 100644 index 0000000000..851fdd12db --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_noise.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_NOISE_H_ +#define LIB_JXL_ENC_NOISE_H_ + +// Noise parameter estimation. + +#include <stddef.h> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" +#include "lib/jxl/noise.h" + +namespace jxl { + +struct AuxOut; + +// Get parameters of the noise for NoiseParams model +// Returns whether a valid noise model (with HasAny()) is set. +Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params, + float quality_coef); + +// Does not write anything if `noise_params` are empty. Otherwise, caller must +// set FrameHeader.flags.kNoise. +void EncodeNoise(const NoiseParams& noise_params, BitWriter* writer, + size_t layer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_NOISE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_optimize.cc b/third_party/jpeg-xl/lib/jxl/enc_optimize.cc new file mode 100644 index 0000000000..6865ff67df --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_optimize.cc @@ -0,0 +1,163 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_optimize.h" + +#include <algorithm> + +#include "lib/jxl/base/status.h" + +namespace jxl { + +namespace optimize { + +namespace { + +// simplex vector must be sorted by first element of its elements +std::vector<double> Midpoint(const std::vector<std::vector<double>>& simplex) { + JXL_CHECK(!simplex.empty()); + JXL_CHECK(simplex.size() == simplex[0].size()); + int dim = simplex.size() - 1; + std::vector<double> result(dim + 1, 0); + for (int i = 0; i < dim; i++) { + for (int k = 0; k < dim; k++) { + result[i + 1] += simplex[k][i + 1]; + } + result[i + 1] /= dim; + } + return result; +} + +// first element ignored +std::vector<double> Subtract(const std::vector<double>& a, + const std::vector<double>& b) { + JXL_CHECK(a.size() == b.size()); + std::vector<double> result(a.size()); + result[0] = 0; + for (size_t i = 1; i < result.size(); i++) { + result[i] = a[i] - b[i]; + } + return result; +} + +// first element ignored +std::vector<double> Add(const std::vector<double>& a, + const std::vector<double>& b) { + JXL_CHECK(a.size() == b.size()); + std::vector<double> result(a.size()); + result[0] = 0; + for (size_t i = 1; i < result.size(); i++) { + result[i] = a[i] + b[i]; + } + return result; +} + +// first element ignored +std::vector<double> Average(const std::vector<double>& a, + const std::vector<double>& b) { + JXL_CHECK(a.size() == b.size()); + std::vector<double> result(a.size()); + result[0] = 0; + for (size_t i = 1; i < result.size(); i++) { + result[i] = 0.5 * (a[i] + b[i]); + } + return result; +} + +// vec: [0] will contain the objective function, [1:] will +// contain the vector position for the objective function. +// fun: the function evaluates the value. +void Eval(std::vector<double>* vec, + const std::function<double(const std::vector<double>&)>& fun) { + std::vector<double> args(vec->begin() + 1, vec->end()); + (*vec)[0] = fun(args); +} + +void Sort(std::vector<std::vector<double>>* simplex) { + std::sort(simplex->begin(), simplex->end()); +} + +// Main iteration step of Nelder-Mead like optimization. +void Reflect(std::vector<std::vector<double>>* simplex, + const std::function<double(const std::vector<double>&)>& fun) { + Sort(simplex); + const std::vector<double>& last = simplex->back(); + std::vector<double> mid = Midpoint(*simplex); + std::vector<double> diff = Subtract(mid, last); + std::vector<double> mirrored = Add(mid, diff); + Eval(&mirrored, fun); + if (mirrored[0] > (*simplex)[simplex->size() - 2][0]) { + // Still the worst, shrink towards the best. + std::vector<double> shrinking = Average(simplex->back(), (*simplex)[0]); + Eval(&shrinking, fun); + simplex->back() = shrinking; + } else if (mirrored[0] < (*simplex)[0][0]) { + // new best + std::vector<double> even_further = Add(mirrored, diff); + Eval(&even_further, fun); + if (even_further[0] < mirrored[0]) { + mirrored = even_further; + } + simplex->back() = mirrored; + } else { + // not a best, not a worst point + simplex->back() = mirrored; + } +} + +// Initialize the simplex at origin. +std::vector<std::vector<double>> InitialSimplex( + int dim, double amount, const std::vector<double>& init, + const std::function<double(const std::vector<double>&)>& fun) { + std::vector<double> best(1 + dim, 0); + std::copy(init.begin(), init.end(), best.begin() + 1); + Eval(&best, fun); + std::vector<std::vector<double>> result{best}; + for (int i = 0; i < dim; i++) { + best = result[0]; + best[i + 1] += amount; + Eval(&best, fun); + result.push_back(best); + Sort(&result); + } + return result; +} + +// For comparing the same with the python tool +/*void RunSimplexExternal( + int dim, double amount, int max_iterations, + const std::function<double((const vector<double>&))>& fun) { + vector<double> vars; + for (int i = 0; i < dim; i++) { + vars.push_back(atof(getenv(StrCat("VAR", i).c_str()))); + } + double result = fun(vars); + std::cout << "Result=" << result; +}*/ + +} // namespace + +std::vector<double> RunSimplex( + int dim, double amount, int max_iterations, const std::vector<double>& init, + const std::function<double(const std::vector<double>&)>& fun) { + std::vector<std::vector<double>> simplex = + InitialSimplex(dim, amount, init, fun); + for (int i = 0; i < max_iterations; i++) { + Sort(&simplex); + Reflect(&simplex, fun); + } + return simplex[0]; +} + +std::vector<double> RunSimplex( + int dim, double amount, int max_iterations, + const std::function<double(const std::vector<double>&)>& fun) { + std::vector<double> init(dim, 0.0); + return RunSimplex(dim, amount, max_iterations, init, fun); +} + +} // namespace optimize + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_optimize.h b/third_party/jpeg-xl/lib/jxl/enc_optimize.h new file mode 100644 index 0000000000..9da523f8ef --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_optimize.h @@ -0,0 +1,216 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Utility functions for optimizing multi-dimensional nonlinear functions. + +#ifndef LIB_JXL_OPTIMIZE_H_ +#define LIB_JXL_OPTIMIZE_H_ + +#include <cmath> +#include <cstdio> +#include <functional> +#include <vector> + +#include "lib/jxl/base/status.h" + +namespace jxl { +namespace optimize { + +// An array type of numeric values that supports math operations with operator-, +// operator+, etc. +template <typename T, size_t N> +class Array { + public: + Array() = default; + explicit Array(T v) { + for (size_t i = 0; i < N; i++) v_[i] = v; + } + + size_t size() const { return N; } + + T& operator[](size_t index) { + JXL_DASSERT(index < N); + return v_[index]; + } + T operator[](size_t index) const { + JXL_DASSERT(index < N); + return v_[index]; + } + + private: + // The values used by this Array. + T v_[N]; +}; + +template <typename T, size_t N> +Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) { + Array<T, N> z; + for (size_t i = 0; i < N; ++i) { + z[i] = x[i] + y[i]; + } + return z; +} + +template <typename T, size_t N> +Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) { + Array<T, N> z; + for (size_t i = 0; i < N; ++i) { + z[i] = x[i] - y[i]; + } + return z; +} + +template <typename T, size_t N> +Array<T, N> operator*(T v, const Array<T, N>& x) { + Array<T, N> y; + for (size_t i = 0; i < N; ++i) { + y[i] = v * x[i]; + } + return y; +} + +template <typename T, size_t N> +T operator*(const Array<T, N>& x, const Array<T, N>& y) { + T r = 0.0; + for (size_t i = 0; i < N; ++i) { + r += x[i] * y[i]; + } + return r; +} + +// Runs Nelder-Mead like optimization. Runs for max_iterations times, +// fun gets called with a vector of size dim as argument, and returns the score +// based on those parameters (lower is better). Returns a vector of dim+1 +// dimensions, where the first value is the optimal value of the function and +// the rest is the argmin value. Use init to pass an initial guess or where +// the optimal value is. +// +// Usage example: +// +// RunSimplex(2, 0.1, 100, [](const vector<float>& v) { +// return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7); +// }); +// +// Returns (0.0, 5, 7) +std::vector<double> RunSimplex( + int dim, double amount, int max_iterations, + const std::function<double(const std::vector<double>&)>& fun); +std::vector<double> RunSimplex( + int dim, double amount, int max_iterations, const std::vector<double>& init, + const std::function<double(const std::vector<double>&)>& fun); + +// Implementation of the Scaled Conjugate Gradient method described in the +// following paper: +// Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised +// Learning", Neural Networks, Vol. 6. pp. 525-533, 1993 +// http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf +// +// The Function template parameter is a class that has the following method: +// +// // Returns the value of the function at point w and sets *df to be the +// // negative gradient vector of the function at point w. +// double Compute(const optimize::Array<T, N>& w, +// optimize::Array<T, N>* df) const; +// +// Returns a vector w, such that |df(w)| < grad_norm_threshold. +template <typename T, size_t N, typename Function> +Array<T, N> OptimizeWithScaledConjugateGradientMethod( + const Function& f, const Array<T, N>& w0, const T grad_norm_threshold, + size_t max_iters) { + const size_t n = w0.size(); + const T rsq_threshold = grad_norm_threshold * grad_norm_threshold; + const T sigma0 = static_cast<T>(0.0001); + const T l_min = static_cast<T>(1.0e-15); + const T l_max = static_cast<T>(1.0e15); + + Array<T, N> w = w0; + Array<T, N> wp; + Array<T, N> r; + Array<T, N> rt; + Array<T, N> e; + Array<T, N> p; + T psq; + T fp; + T D; + T d; + T m; + T a; + T b; + T s; + T t; + + T fw = f.Compute(w, &r); + T rsq = r * r; + e = r; + p = r; + T l = static_cast<T>(1.0); + bool success = true; + size_t n_success = 0; + size_t k = 0; + + while (k++ < max_iters) { + if (success) { + m = -(p * r); + if (m >= 0) { + p = r; + m = -(p * r); + } + psq = p * p; + s = sigma0 / std::sqrt(psq); + f.Compute(w + (s * p), &rt); + t = (p * (r - rt)) / s; + } + + d = t + l * psq; + if (d <= 0) { + d = l * psq; + l = l - t / psq; + } + + a = -m / d; + wp = w + a * p; + fp = f.Compute(wp, &rt); + + D = 2.0 * (fp - fw) / (a * m); + if (D >= 0.0) { + success = true; + n_success++; + w = wp; + } else { + success = false; + } + + if (success) { + e = r; + r = rt; + rsq = r * r; + fw = fp; + if (rsq <= rsq_threshold) { + break; + } + } + + if (D < 0.25) { + l = std::min(4.0 * l, l_max); + } else if (D > 0.75) { + l = std::max(0.25 * l, l_min); + } + + if ((n_success % n) == 0) { + p = r; + l = 1.0; + } else if (success) { + b = ((e - r) * r) / m; + p = b * p + r; + } + } + + return w; +} + +} // namespace optimize +} // namespace jxl + +#endif // LIB_JXL_OPTIMIZE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_optimize_test.cc b/third_party/jpeg-xl/lib/jxl/enc_optimize_test.cc new file mode 100644 index 0000000000..cc65bf1a0c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_optimize_test.cc @@ -0,0 +1,107 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_optimize.h" + +#include "lib/jxl/testing.h" + +namespace jxl { +namespace optimize { +namespace { + +// The maximum number of iterations for the test. +static const size_t kMaxTestIter = 100000; + +// F(w) = (w - w_min)^2. +struct SimpleQuadraticFunction { + typedef Array<double, 2> ArrayType; + explicit SimpleQuadraticFunction(const ArrayType& w0) : w_min(w0) {} + + double Compute(const ArrayType& w, ArrayType* df) const { + ArrayType dw = w - w_min; + *df = -2.0 * dw; + return dw * dw; + } + + ArrayType w_min; +}; + +// F(alpha, beta, gamma| x,y) = \sum_i(y_i - (alpha x_i ^ gamma + beta))^2. +struct PowerFunction { + explicit PowerFunction(const std::vector<double>& x0, + const std::vector<double>& y0) + : x(x0), y(y0) {} + + typedef Array<double, 3> ArrayType; + double Compute(const ArrayType& w, ArrayType* df) const { + double loss_function = 0; + (*df)[0] = 0; + (*df)[1] = 0; + (*df)[2] = 0; + for (size_t ind = 0; ind < y.size(); ++ind) { + if (x[ind] != 0) { + double l_f = y[ind] - (w[0] * pow(x[ind], w[1]) + w[2]); + (*df)[0] += 2.0 * l_f * pow(x[ind], w[1]); + (*df)[1] += 2.0 * l_f * w[0] * pow(x[ind], w[1]) * log(x[ind]); + (*df)[2] += 2.0 * l_f * 1; + loss_function += l_f * l_f; + } + } + return loss_function; + } + + std::vector<double> x; + std::vector<double> y; +}; + +TEST(OptimizeTest, SimpleQuadraticFunction) { + SimpleQuadraticFunction::ArrayType w_min; + w_min[0] = 1.0; + w_min[1] = 2.0; + SimpleQuadraticFunction f(w_min); + SimpleQuadraticFunction::ArrayType w(0.); + static const double kPrecision = 1e-8; + w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision, + kMaxTestIter); + EXPECT_NEAR(w[0], 1.0, kPrecision); + EXPECT_NEAR(w[1], 2.0, kPrecision); +} + +TEST(OptimizeTest, PowerFunction) { + std::vector<double> x(10); + std::vector<double> y(10); + for (int ind = 0; ind < 10; ++ind) { + x[ind] = 1. * ind; + y[ind] = 2. * pow(x[ind], 3) + 5.; + } + PowerFunction f(x, y); + PowerFunction::ArrayType w(0.); + + static const double kPrecision = 0.01; + w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision, + kMaxTestIter); + EXPECT_NEAR(w[0], 2.0, kPrecision); + EXPECT_NEAR(w[1], 3.0, kPrecision); + EXPECT_NEAR(w[2], 5.0, kPrecision); +} + +TEST(OptimizeTest, SimplexOptTest) { + auto f = [](const std::vector<double>& x) -> double { + double t1 = x[0] - 1.0; + double t2 = x[1] + 1.5; + return 2.0 + t1 * t1 + t2 * t2; + }; + auto opt = RunSimplex(2, 0.01, 100, f); + EXPECT_EQ(opt.size(), 3u); + + static const double kPrecision = 0.01; + EXPECT_NEAR(opt[0], 2.0, kPrecision); + EXPECT_NEAR(opt[1], 1.0, kPrecision); + EXPECT_NEAR(opt[2], -1.5, kPrecision); +} + +} // namespace +} // namespace optimize +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_params.h b/third_party/jpeg-xl/lib/jxl/enc_params.h new file mode 100644 index 0000000000..89fd2c924f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_params.h @@ -0,0 +1,234 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_PARAMS_H_ +#define LIB_JXL_ENC_PARAMS_H_ + +// Parameters and flags that govern JXL compression. + +#include <jxl/cms_interface.h> +#include <jxl/encode.h> +#include <stddef.h> + +#include <vector> + +#include "lib/jxl/base/override.h" +#include "lib/jxl/enc_progressive_split.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/splines.h" + +namespace jxl { + +enum class SpeedTier { + // Try multiple combinations of Tortoise flags for modular mode. Otherwise + // like kTortoise. + kGlacier = 0, + // Turns on FindBestQuantizationHQ loop. Equivalent to "guetzli" mode. + kTortoise = 1, + // Turns on FindBestQuantization butteraugli loop. + kKitten = 2, + // Turns on dots, patches, and spline detection by default, as well as full + // context clustering. Default. + kSquirrel = 3, + // Turns on error diffusion and full AC strategy heuristics. Equivalent to + // "fast" mode. + kWombat = 4, + // Turns on gaborish by default, non-default cmap, initial quant field. + kHare = 5, + // Turns on simple heuristics for AC strategy, quant field, and clustering; + // also enables coefficient reordering. + kCheetah = 6, + // Turns off most encoder features. Does context clustering. + // Modular: uses fixed tree with Weighted predictor. + kFalcon = 7, + // Currently fastest possible setting for VarDCT. + // Modular: uses fixed tree with Gradient predictor. + kThunder = 8, + // VarDCT: same as kThunder. + // Modular: no tree, Gradient predictor, fast histograms + kLightning = 9 +}; + +// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) +struct CompressParams { + float butteraugli_distance = 1.0f; + + // explicit distances for extra channels (defaults to butteraugli_distance + // when not set; value of -1 can be used to represent 'default') + std::vector<float> ec_distance; + + // Try to achieve a maximum pixel-by-pixel error on each channel. + bool max_error_mode = false; + float max_error[3] = {0.0, 0.0, 0.0}; + + SpeedTier speed_tier = SpeedTier::kSquirrel; + int brotli_effort = -1; + + // 0 = default. + // 1 = slightly worse quality. + // 4 = fastest speed, lowest quality + size_t decoding_speed_tier = 0; + + ColorTransform color_transform = ColorTransform::kXYB; + + // If true, the "modular mode options" members below are used. + bool modular_mode = false; + + // Change group size in modular mode (0=128, 1=256, 2=512, 3=1024, -1=encoder + // chooses). + int modular_group_size_shift = -1; + + Override preview = Override::kDefault; + Override noise = Override::kDefault; + Override dots = Override::kDefault; + Override patches = Override::kDefault; + Override gaborish = Override::kDefault; + int epf = -1; + + // Progressive mode. + Override progressive_mode = Override::kDefault; + + // Quantized-progressive mode. + Override qprogressive_mode = Override::kDefault; + + // Put center groups first in the bitstream. + bool centerfirst = false; + + // Pixel coordinates of the center. First group will contain that center. + size_t center_x = static_cast<size_t>(-1); + size_t center_y = static_cast<size_t>(-1); + + int progressive_dc = -1; + + // If on: preserve color of invisible pixels (if off: don't care) + // Default: on for lossless, off for lossy + Override keep_invisible = Override::kDefault; + + JxlCmsInterface cms; + bool cms_set = false; + void SetCms(const JxlCmsInterface& cms) { + this->cms = cms; + cms_set = true; + } + + // Force usage of CfL when doing JPEG recompression. This can have unexpected + // effects on the decoded pixels, while still being JPEG-compliant and + // allowing reconstruction of the original JPEG. + bool force_cfl_jpeg_recompression = true; + + // Use brotli compression for any boxes derived from a JPEG frame. + bool jpeg_compress_boxes = true; + + // Preserve this metadata when doing JPEG recompression. + bool jpeg_keep_exif = true; + bool jpeg_keep_xmp = true; + bool jpeg_keep_jumbf = true; + + // Set the noise to what it would approximately be if shooting at the nominal + // exposure for a given ISO setting on a 35mm camera. + float photon_noise_iso = 0; + + // modular mode options below + ModularOptions options; + int responsive = -1; + int colorspace = -1; + // Use Global channel palette if #colors < this percentage of range + float channel_colors_pre_transform_percent = 95.f; + // Use Local channel palette if #colors < this percentage of range + float channel_colors_percent = 80.f; + int palette_colors = 1 << 10; // up to 10-bit palette is probably worthwhile + bool lossy_palette = false; + + // Returns whether these params are lossless as defined by SetLossless(); + bool IsLossless() const { return modular_mode && ModularPartIsLossless(); } + + bool ModularPartIsLossless() const { + if (modular_mode) { + // YCbCr is also considered lossless here since it's intended for + // source material that is already YCbCr (we don't do the fwd transform) + if (butteraugli_distance != 0 || + color_transform == jxl::ColorTransform::kXYB) + return false; + } + for (float f : ec_distance) { + if (f > 0) return false; + if (f < 0 && butteraugli_distance != 0) return false; + } + // if no explicit ec_distance given, and using vardct, then the modular part + // is empty or not lossless + if (!modular_mode && ec_distance.empty()) return false; + // all modular channels are encoded at distance 0 + return true; + } + + // Sets the parameters required to make the codec lossless. + void SetLossless() { + modular_mode = true; + butteraugli_distance = 0.0f; + for (float& f : ec_distance) f = 0.0f; + color_transform = jxl::ColorTransform::kNone; + } + + // Down/upsample the image before encoding / after decoding by this factor. + // The resampling value can also be set to <= 0 to automatically choose based + // on distance, however EncodeFrame doesn't support this, so it is + // required to call PostInit() to set a valid positive resampling + // value and altered butteraugli score if this is used. + int resampling = -1; + int ec_resampling = -1; + // Skip the downsampling before encoding if this is true. + bool already_downsampled = false; + // Butteraugli target distance on the original full size image, this can be + // different from butteraugli_distance if resampling was used. + float original_butteraugli_distance = -1.0f; + + float quant_ac_rescale = 1.0; + + // Codestream level to conform to. + // -1: don't care + int level = -1; + + // See JXL_ENC_FRAME_SETTING_BUFFERING option value. + int buffering = 0; + // See JXL_ENC_FRAME_SETTING_USE_FULL_IMAGE_HEURISTICS option value. + bool use_full_image_heuristics = true; + + std::vector<float> manual_noise; + std::vector<float> manual_xyb_factors; + + // If not empty, this tree will be used for dc global section. + // Used in jxl_from_tree tool. + Tree custom_fixed_tree; + // If not empty, these custom splines will be used instead of the computed + // ones. Used in jxl_from_tee tool. + Splines custom_splines; + // If not null, overrides progressive mode settings. Used in decode_test. + const ProgressiveMode* custom_progressive_mode = nullptr; + + JxlDebugImageCallback debug_image = nullptr; + void* debug_image_opaque; +}; + +static constexpr float kMinButteraugliForDynamicAR = 0.5f; +static constexpr float kMinButteraugliForDots = 3.0f; +static constexpr float kMinButteraugliToSubtractOriginalPatches = 3.0f; + +// Always off +static constexpr float kMinButteraugliForNoise = 99.0f; + +// Minimum butteraugli distance the encoder accepts. +static constexpr float kMinButteraugliDistance = 0.001f; + +// Tile size for encoder-side processing. Must be equal to color tile dim in the +// current implementation. +static constexpr size_t kEncTileDim = 64; +static constexpr size_t kEncTileDimInBlocks = kEncTileDim / kBlockDim; + +} // namespace jxl + +#endif // LIB_JXL_ENC_PARAMS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.cc b/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.cc new file mode 100644 index 0000000000..0abd177809 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.cc @@ -0,0 +1,819 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_patch_dictionary.h" + +#include <stdint.h> +#include <stdlib.h> +#include <sys/types.h> + +#include <algorithm> +#include <atomic> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_debug_image.h" +#include "lib/jxl/enc_dot_dictionary.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/pack_signed.h" +#include "lib/jxl/patch_dictionary_internal.h" + +namespace jxl { + +static constexpr size_t kPatchFrameReferenceId = 3; + +// static +void PatchDictionaryEncoder::Encode(const PatchDictionary& pdic, + BitWriter* writer, size_t layer, + AuxOut* aux_out) { + JXL_ASSERT(pdic.HasAny()); + std::vector<std::vector<Token>> tokens(1); + size_t num_ec = pdic.shared_->metadata->m.num_extra_channels; + + auto add_num = [&](int context, size_t num) { + tokens[0].emplace_back(context, num); + }; + size_t num_ref_patch = 0; + for (size_t i = 0; i < pdic.positions_.size();) { + size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx; + while (i < pdic.positions_.size() && + pdic.positions_[i].ref_pos_idx == ref_pos_idx) { + i++; + } + num_ref_patch++; + } + add_num(kNumRefPatchContext, num_ref_patch); + size_t blend_pos = 0; + for (size_t i = 0; i < pdic.positions_.size();) { + size_t i_start = i; + size_t ref_pos_idx = pdic.positions_[i].ref_pos_idx; + const auto& ref_pos = pdic.ref_positions_[ref_pos_idx]; + while (i < pdic.positions_.size() && + pdic.positions_[i].ref_pos_idx == ref_pos_idx) { + i++; + } + size_t num = i - i_start; + JXL_ASSERT(num > 0); + add_num(kReferenceFrameContext, ref_pos.ref); + add_num(kPatchReferencePositionContext, ref_pos.x0); + add_num(kPatchReferencePositionContext, ref_pos.y0); + add_num(kPatchSizeContext, ref_pos.xsize - 1); + add_num(kPatchSizeContext, ref_pos.ysize - 1); + add_num(kPatchCountContext, num - 1); + for (size_t j = i_start; j < i; j++) { + const PatchPosition& pos = pdic.positions_[j]; + if (j == i_start) { + add_num(kPatchPositionContext, pos.x); + add_num(kPatchPositionContext, pos.y); + } else { + add_num(kPatchOffsetContext, + PackSigned(pos.x - pdic.positions_[j - 1].x)); + add_num(kPatchOffsetContext, + PackSigned(pos.y - pdic.positions_[j - 1].y)); + } + for (size_t j = 0; j < num_ec + 1; ++j, ++blend_pos) { + const PatchBlending& info = pdic.blendings_[blend_pos]; + add_num(kPatchBlendModeContext, static_cast<uint32_t>(info.mode)); + if (UsesAlpha(info.mode) && + pdic.shared_->metadata->m.extra_channel_info.size() > 1) { + add_num(kPatchAlphaChannelContext, info.alpha_channel); + } + if (UsesClamp(info.mode)) { + add_num(kPatchClampContext, info.clamp); + } + } + } + } + + EntropyEncodingData codes; + std::vector<uint8_t> context_map; + BuildAndEncodeHistograms(HistogramParams(), kNumPatchDictionaryContexts, + tokens, &codes, &context_map, writer, layer, + aux_out); + WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out); +} + +// static +void PatchDictionaryEncoder::SubtractFrom(const PatchDictionary& pdic, + Image3F* opsin) { + size_t num_ec = pdic.shared_->metadata->m.num_extra_channels; + // TODO(veluca): this can likely be optimized knowing it runs on full images. + for (size_t y = 0; y < opsin->ysize(); y++) { + float* JXL_RESTRICT rows[3] = { + opsin->PlaneRow(0, y), + opsin->PlaneRow(1, y), + opsin->PlaneRow(2, y), + }; + for (size_t pos_idx : pdic.GetPatchesForRow(y)) { + const size_t blending_idx = pos_idx * (num_ec + 1); + const PatchPosition& pos = pdic.positions_[pos_idx]; + const PatchReferencePosition& ref_pos = + pdic.ref_positions_[pos.ref_pos_idx]; + const PatchBlendMode mode = pdic.blendings_[blending_idx].mode; + size_t by = pos.y; + size_t bx = pos.x; + size_t xsize = ref_pos.xsize; + JXL_DASSERT(y >= by); + JXL_DASSERT(y < by + ref_pos.ysize); + size_t iy = y - by; + size_t ref = ref_pos.ref; + const float* JXL_RESTRICT ref_rows[3] = { + pdic.shared_->reference_frames[ref].frame.color().ConstPlaneRow( + 0, ref_pos.y0 + iy) + + ref_pos.x0, + pdic.shared_->reference_frames[ref].frame.color().ConstPlaneRow( + 1, ref_pos.y0 + iy) + + ref_pos.x0, + pdic.shared_->reference_frames[ref].frame.color().ConstPlaneRow( + 2, ref_pos.y0 + iy) + + ref_pos.x0, + }; + for (size_t ix = 0; ix < xsize; ix++) { + for (size_t c = 0; c < 3; c++) { + if (mode == PatchBlendMode::kAdd) { + rows[c][bx + ix] -= ref_rows[c][ix]; + } else if (mode == PatchBlendMode::kReplace) { + rows[c][bx + ix] = 0; + } else if (mode == PatchBlendMode::kNone) { + // Nothing to do. + } else { + JXL_UNREACHABLE("Blending mode %u not yet implemented", + (uint32_t)mode); + } + } + } + } + } +} + +namespace { + +struct PatchColorspaceInfo { + float kChannelDequant[3]; + float kChannelWeights[3]; + + explicit PatchColorspaceInfo(bool is_xyb) { + if (is_xyb) { + kChannelDequant[0] = 0.01615; + kChannelDequant[1] = 0.08875; + kChannelDequant[2] = 0.1922; + kChannelWeights[0] = 30.0; + kChannelWeights[1] = 3.0; + kChannelWeights[2] = 1.0; + } else { + kChannelDequant[0] = 20.0f / 255; + kChannelDequant[1] = 22.0f / 255; + kChannelDequant[2] = 20.0f / 255; + kChannelWeights[0] = 0.017 * 255; + kChannelWeights[1] = 0.02 * 255; + kChannelWeights[2] = 0.017 * 255; + } + } + + float ScaleForQuantization(float val, size_t c) { + return val / kChannelDequant[c]; + } + + int Quantize(float val, size_t c) { + return truncf(ScaleForQuantization(val, c)); + } + + bool is_similar_v(const float v1[3], const float v2[3], float threshold) { + float distance = 0; + for (size_t c = 0; c < 3; c++) { + distance += std::fabs(v1[c] - v2[c]) * kChannelWeights[c]; + } + return distance <= threshold; + } +}; + +std::vector<PatchInfo> FindTextLikePatches( + const CompressParams& cparams, const Image3F& opsin, + const PassesEncoderState* JXL_RESTRICT state, ThreadPool* pool, + AuxOut* aux_out, bool is_xyb) { + if (state->cparams.patches == Override::kOff) return {}; + const auto& frame_dim = state->shared.frame_dim; + + PatchColorspaceInfo pci(is_xyb); + float kSimilarThreshold = 0.8f; + + auto is_similar_impl = [&pci](std::pair<uint32_t, uint32_t> p1, + std::pair<uint32_t, uint32_t> p2, + const float* JXL_RESTRICT rows[3], + size_t stride, float threshold) { + float v1[3], v2[3]; + for (size_t c = 0; c < 3; c++) { + v1[c] = rows[c][p1.second * stride + p1.first]; + v2[c] = rows[c][p2.second * stride + p2.first]; + } + return pci.is_similar_v(v1, v2, threshold); + }; + + std::atomic<bool> has_screenshot_areas{false}; + const size_t opsin_stride = opsin.PixelsPerRow(); + const float* JXL_RESTRICT opsin_rows[3] = {opsin.ConstPlaneRow(0, 0), + opsin.ConstPlaneRow(1, 0), + opsin.ConstPlaneRow(2, 0)}; + + auto is_same = [&opsin_rows, opsin_stride](std::pair<uint32_t, uint32_t> p1, + std::pair<uint32_t, uint32_t> p2) { + for (size_t c = 0; c < 3; c++) { + float v1 = opsin_rows[c][p1.second * opsin_stride + p1.first]; + float v2 = opsin_rows[c][p2.second * opsin_stride + p2.first]; + if (std::fabs(v1 - v2) > 1e-4) { + return false; + } + } + return true; + }; + + auto is_similar = [&](std::pair<uint32_t, uint32_t> p1, + std::pair<uint32_t, uint32_t> p2) { + return is_similar_impl(p1, p2, opsin_rows, opsin_stride, kSimilarThreshold); + }; + + constexpr int64_t kPatchSide = 4; + constexpr int64_t kExtraSide = 4; + + // Look for kPatchSide size squares, naturally aligned, that all have the same + // pixel values. + ImageB is_screenshot_like(DivCeil(frame_dim.xsize, kPatchSide), + DivCeil(frame_dim.ysize, kPatchSide)); + ZeroFillImage(&is_screenshot_like); + uint8_t* JXL_RESTRICT screenshot_row = is_screenshot_like.Row(0); + const size_t screenshot_stride = is_screenshot_like.PixelsPerRow(); + const auto process_row = [&](const uint32_t y, size_t /* thread */) { + for (uint64_t x = 0; x < frame_dim.xsize / kPatchSide; x++) { + bool all_same = true; + for (size_t iy = 0; iy < static_cast<size_t>(kPatchSide); iy++) { + for (size_t ix = 0; ix < static_cast<size_t>(kPatchSide); ix++) { + size_t cx = x * kPatchSide + ix; + size_t cy = y * kPatchSide + iy; + if (!is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) { + all_same = false; + break; + } + } + } + if (!all_same) continue; + size_t num = 0; + size_t num_same = 0; + for (int64_t iy = -kExtraSide; iy < kExtraSide + kPatchSide; iy++) { + for (int64_t ix = -kExtraSide; ix < kExtraSide + kPatchSide; ix++) { + int64_t cx = x * kPatchSide + ix; + int64_t cy = y * kPatchSide + iy; + if (cx < 0 || static_cast<uint64_t>(cx) >= frame_dim.xsize || // + cy < 0 || static_cast<uint64_t>(cy) >= frame_dim.ysize) { + continue; + } + num++; + if (is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) num_same++; + } + } + // Too few equal pixels nearby. + if (num_same * 8 < num * 7) continue; + screenshot_row[y * screenshot_stride + x] = 1; + has_screenshot_areas = true; + } + }; + JXL_CHECK(RunOnPool(pool, 0, frame_dim.ysize / kPatchSide, ThreadPool::NoInit, + process_row, "IsScreenshotLike")); + + // TODO(veluca): also parallelize the rest of this function. + if (WantDebugOutput(cparams)) { + DumpPlaneNormalized(cparams, "screenshot_like", is_screenshot_like); + } + + constexpr int kSearchRadius = 1; + + if (!ApplyOverride(state->cparams.patches, has_screenshot_areas)) { + return {}; + } + + // Search for "similar enough" pixels near the screenshot-like areas. + ImageB is_background(frame_dim.xsize, frame_dim.ysize); + ZeroFillImage(&is_background); + Image3F background(frame_dim.xsize, frame_dim.ysize); + ZeroFillImage(&background); + constexpr size_t kDistanceLimit = 50; + float* JXL_RESTRICT background_rows[3] = { + background.PlaneRow(0, 0), + background.PlaneRow(1, 0), + background.PlaneRow(2, 0), + }; + const size_t background_stride = background.PixelsPerRow(); + uint8_t* JXL_RESTRICT is_background_row = is_background.Row(0); + const size_t is_background_stride = is_background.PixelsPerRow(); + std::vector< + std::pair<std::pair<uint32_t, uint32_t>, std::pair<uint32_t, uint32_t>>> + queue; + size_t queue_front = 0; + for (size_t y = 0; y < frame_dim.ysize; y++) { + for (size_t x = 0; x < frame_dim.xsize; x++) { + if (!screenshot_row[screenshot_stride * (y / kPatchSide) + + (x / kPatchSide)]) + continue; + queue.push_back({{x, y}, {x, y}}); + } + } + while (queue.size() != queue_front) { + std::pair<uint32_t, uint32_t> cur = queue[queue_front].first; + std::pair<uint32_t, uint32_t> src = queue[queue_front].second; + queue_front++; + if (is_background_row[cur.second * is_background_stride + cur.first]) + continue; + is_background_row[cur.second * is_background_stride + cur.first] = 1; + for (size_t c = 0; c < 3; c++) { + background_rows[c][cur.second * background_stride + cur.first] = + opsin_rows[c][src.second * opsin_stride + src.first]; + } + for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) { + for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) { + if (dx == 0 && dy == 0) continue; + int next_first = cur.first + dx; + int next_second = cur.second + dy; + if (next_first < 0 || next_second < 0 || + static_cast<uint32_t>(next_first) >= frame_dim.xsize || + static_cast<uint32_t>(next_second) >= frame_dim.ysize) { + continue; + } + if (static_cast<uint32_t>( + std::abs(next_first - static_cast<int>(src.first)) + + std::abs(next_second - static_cast<int>(src.second))) > + kDistanceLimit) { + continue; + } + std::pair<uint32_t, uint32_t> next{next_first, next_second}; + if (is_similar(src, next)) { + if (!screenshot_row[next.second / kPatchSide * screenshot_stride + + next.first / kPatchSide] || + is_same(src, next)) { + if (!is_background_row[next.second * is_background_stride + + next.first]) + queue.emplace_back(next, src); + } + } + } + } + } + queue.clear(); + + ImageF ccs; + Rng rng(0); + bool paint_ccs = false; + if (WantDebugOutput(cparams)) { + DumpPlaneNormalized(cparams, "is_background", is_background); + if (is_xyb) { + DumpXybImage(cparams, "background", background); + } else { + DumpImage(cparams, "background", background); + } + ccs = ImageF(frame_dim.xsize, frame_dim.ysize); + ZeroFillImage(&ccs); + paint_ccs = true; + } + + constexpr float kVerySimilarThreshold = 0.03f; + constexpr float kHasSimilarThreshold = 0.03f; + + const float* JXL_RESTRICT const_background_rows[3] = { + background_rows[0], background_rows[1], background_rows[2]}; + auto is_similar_b = [&](std::pair<int, int> p1, std::pair<int, int> p2) { + return is_similar_impl(p1, p2, const_background_rows, background_stride, + kVerySimilarThreshold); + }; + + constexpr int kMinPeak = 2; + constexpr int kHasSimilarRadius = 2; + + std::vector<PatchInfo> info; + + // Find small CC outside the "similar enough" areas, compute bounding boxes, + // and run heuristics to exclude some patches. + ImageB visited(frame_dim.xsize, frame_dim.ysize); + ZeroFillImage(&visited); + uint8_t* JXL_RESTRICT visited_row = visited.Row(0); + const size_t visited_stride = visited.PixelsPerRow(); + std::vector<std::pair<uint32_t, uint32_t>> cc; + std::vector<std::pair<uint32_t, uint32_t>> stack; + for (size_t y = 0; y < frame_dim.ysize; y++) { + for (size_t x = 0; x < frame_dim.xsize; x++) { + if (is_background_row[y * is_background_stride + x]) continue; + cc.clear(); + stack.clear(); + stack.emplace_back(x, y); + size_t min_x = x; + size_t max_x = x; + size_t min_y = y; + size_t max_y = y; + std::pair<uint32_t, uint32_t> reference; + bool found_border = false; + bool all_similar = true; + while (!stack.empty()) { + std::pair<uint32_t, uint32_t> cur = stack.back(); + stack.pop_back(); + if (visited_row[cur.second * visited_stride + cur.first]) continue; + visited_row[cur.second * visited_stride + cur.first] = 1; + if (cur.first < min_x) min_x = cur.first; + if (cur.first > max_x) max_x = cur.first; + if (cur.second < min_y) min_y = cur.second; + if (cur.second > max_y) max_y = cur.second; + if (paint_ccs) { + cc.push_back(cur); + } + for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) { + for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) { + if (dx == 0 && dy == 0) continue; + int next_first = static_cast<int32_t>(cur.first) + dx; + int next_second = static_cast<int32_t>(cur.second) + dy; + if (next_first < 0 || next_second < 0 || + static_cast<uint32_t>(next_first) >= frame_dim.xsize || + static_cast<uint32_t>(next_second) >= frame_dim.ysize) { + continue; + } + std::pair<uint32_t, uint32_t> next{next_first, next_second}; + if (!is_background_row[next.second * is_background_stride + + next.first]) { + stack.push_back(next); + } else { + if (!found_border) { + reference = next; + found_border = true; + } else { + if (!is_similar_b(next, reference)) all_similar = false; + } + } + } + } + } + if (!found_border || !all_similar || max_x - min_x >= kMaxPatchSize || + max_y - min_y >= kMaxPatchSize) { + continue; + } + size_t bpos = background_stride * reference.second + reference.first; + float ref[3] = {background_rows[0][bpos], background_rows[1][bpos], + background_rows[2][bpos]}; + bool has_similar = false; + for (size_t iy = std::max<int>( + static_cast<int32_t>(min_y) - kHasSimilarRadius, 0); + iy < std::min(max_y + kHasSimilarRadius + 1, frame_dim.ysize); + iy++) { + for (size_t ix = std::max<int>( + static_cast<int32_t>(min_x) - kHasSimilarRadius, 0); + ix < std::min(max_x + kHasSimilarRadius + 1, frame_dim.xsize); + ix++) { + size_t opos = opsin_stride * iy + ix; + float px[3] = {opsin_rows[0][opos], opsin_rows[1][opos], + opsin_rows[2][opos]}; + if (pci.is_similar_v(ref, px, kHasSimilarThreshold)) { + has_similar = true; + } + } + } + if (!has_similar) continue; + info.emplace_back(); + info.back().second.emplace_back(min_x, min_y); + QuantizedPatch& patch = info.back().first; + patch.xsize = max_x - min_x + 1; + patch.ysize = max_y - min_y + 1; + int max_value = 0; + for (size_t c : {1, 0, 2}) { + for (size_t iy = min_y; iy <= max_y; iy++) { + for (size_t ix = min_x; ix <= max_x; ix++) { + size_t offset = (iy - min_y) * patch.xsize + ix - min_x; + patch.fpixels[c][offset] = + opsin_rows[c][iy * opsin_stride + ix] - ref[c]; + int val = pci.Quantize(patch.fpixels[c][offset], c); + patch.pixels[c][offset] = val; + if (std::abs(val) > max_value) max_value = std::abs(val); + } + } + } + if (max_value < kMinPeak) { + info.pop_back(); + continue; + } + if (paint_ccs) { + float cc_color = rng.UniformF(0.5, 1.0); + for (std::pair<uint32_t, uint32_t> p : cc) { + ccs.Row(p.second)[p.first] = cc_color; + } + } + } + } + + if (paint_ccs) { + JXL_ASSERT(WantDebugOutput(cparams)); + DumpPlaneNormalized(cparams, "ccs", ccs); + } + if (info.empty()) { + return {}; + } + + // Remove duplicates. + constexpr size_t kMinPatchOccurrences = 2; + std::sort(info.begin(), info.end()); + size_t unique = 0; + for (size_t i = 1; i < info.size(); i++) { + if (info[i].first == info[unique].first) { + info[unique].second.insert(info[unique].second.end(), + info[i].second.begin(), info[i].second.end()); + } else { + if (info[unique].second.size() >= kMinPatchOccurrences) { + unique++; + } + info[unique] = info[i]; + } + } + if (info[unique].second.size() >= kMinPatchOccurrences) { + unique++; + } + info.resize(unique); + + size_t max_patch_size = 0; + + for (size_t i = 0; i < info.size(); i++) { + size_t pixels = info[i].first.xsize * info[i].first.ysize; + if (pixels > max_patch_size) max_patch_size = pixels; + } + + // don't use patches if all patches are smaller than this + constexpr size_t kMinMaxPatchSize = 20; + if (max_patch_size < kMinMaxPatchSize) return {}; + + return info; +} + +} // namespace + +void FindBestPatchDictionary(const Image3F& opsin, + PassesEncoderState* JXL_RESTRICT state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, bool is_xyb) { + std::vector<PatchInfo> info = + FindTextLikePatches(state->cparams, opsin, state, pool, aux_out, is_xyb); + + // TODO(veluca): this doesn't work if both dots and patches are enabled. + // For now, since dots and patches are not likely to occur in the same kind of + // images, disable dots if some patches were found. + if (info.empty() && + ApplyOverride( + state->cparams.dots, + state->cparams.speed_tier <= SpeedTier::kSquirrel && + state->cparams.butteraugli_distance >= kMinButteraugliForDots)) { + info = FindDotDictionary(state->cparams, opsin, state->shared.cmap, pool); + } + + if (info.empty()) return; + + std::sort( + info.begin(), info.end(), [&](const PatchInfo& a, const PatchInfo& b) { + return a.first.xsize * a.first.ysize > b.first.xsize * b.first.ysize; + }); + + size_t max_x_size = 0; + size_t max_y_size = 0; + size_t total_pixels = 0; + + for (size_t i = 0; i < info.size(); i++) { + size_t pixels = info[i].first.xsize * info[i].first.ysize; + if (max_x_size < info[i].first.xsize) max_x_size = info[i].first.xsize; + if (max_y_size < info[i].first.ysize) max_y_size = info[i].first.ysize; + total_pixels += pixels; + } + + // Bin-packing & conversion of patches. + constexpr float kBinPackingSlackness = 1.05f; + size_t ref_xsize = std::max<float>(max_x_size, std::sqrt(total_pixels)); + size_t ref_ysize = std::max<float>(max_y_size, std::sqrt(total_pixels)); + std::vector<std::pair<size_t, size_t>> ref_positions(info.size()); + // TODO(veluca): allow partial overlaps of patches that have the same pixels. + size_t max_y = 0; + do { + max_y = 0; + // Increase packed image size. + ref_xsize = ref_xsize * kBinPackingSlackness + 1; + ref_ysize = ref_ysize * kBinPackingSlackness + 1; + + ImageB occupied(ref_xsize, ref_ysize); + ZeroFillImage(&occupied); + uint8_t* JXL_RESTRICT occupied_rows = occupied.Row(0); + size_t occupied_stride = occupied.PixelsPerRow(); + + bool success = true; + // For every patch... + for (size_t patch = 0; patch < info.size(); patch++) { + size_t x0 = 0; + size_t y0 = 0; + size_t xsize = info[patch].first.xsize; + size_t ysize = info[patch].first.ysize; + bool found = false; + // For every possible start position ... + for (; y0 + ysize <= ref_ysize; y0++) { + x0 = 0; + for (; x0 + xsize <= ref_xsize; x0++) { + bool has_occupied_pixel = false; + size_t x = x0; + // Check if it is possible to place the patch in this position in the + // reference frame. + for (size_t y = y0; y < y0 + ysize; y++) { + x = x0; + for (; x < x0 + xsize; x++) { + if (occupied_rows[y * occupied_stride + x]) { + has_occupied_pixel = true; + break; + } + } + } // end of positioning check + if (!has_occupied_pixel) { + found = true; + break; + } + x0 = x; // Jump to next pixel after the occupied one. + } + if (found) break; + } // end of start position checking + + // We didn't find a possible position: repeat from the beginning with a + // larger reference frame size. + if (!found) { + success = false; + break; + } + + // We found a position: mark the corresponding positions in the reference + // image as used. + ref_positions[patch] = {x0, y0}; + for (size_t y = y0; y < y0 + ysize; y++) { + for (size_t x = x0; x < x0 + xsize; x++) { + occupied_rows[y * occupied_stride + x] = true; + } + } + max_y = std::max(max_y, y0 + ysize); + } + + if (success) break; + } while (true); + + JXL_ASSERT(ref_ysize >= max_y); + + ref_ysize = max_y; + + Image3F reference_frame(ref_xsize, ref_ysize); + // TODO(veluca): figure out a better way to fill the image. + ZeroFillImage(&reference_frame); + std::vector<PatchPosition> positions; + std::vector<PatchReferencePosition> pref_positions; + std::vector<PatchBlending> blendings; + float* JXL_RESTRICT ref_rows[3] = { + reference_frame.PlaneRow(0, 0), + reference_frame.PlaneRow(1, 0), + reference_frame.PlaneRow(2, 0), + }; + size_t ref_stride = reference_frame.PixelsPerRow(); + size_t num_ec = state->shared.metadata->m.num_extra_channels; + + for (size_t i = 0; i < info.size(); i++) { + PatchReferencePosition ref_pos; + ref_pos.xsize = info[i].first.xsize; + ref_pos.ysize = info[i].first.ysize; + ref_pos.x0 = ref_positions[i].first; + ref_pos.y0 = ref_positions[i].second; + ref_pos.ref = kPatchFrameReferenceId; + for (size_t y = 0; y < ref_pos.ysize; y++) { + for (size_t x = 0; x < ref_pos.xsize; x++) { + for (size_t c = 0; c < 3; c++) { + ref_rows[c][(y + ref_pos.y0) * ref_stride + x + ref_pos.x0] = + info[i].first.fpixels[c][y * ref_pos.xsize + x]; + } + } + } + for (const auto& pos : info[i].second) { + positions.emplace_back( + PatchPosition{pos.first, pos.second, pref_positions.size()}); + // Add blending for color channels, ignore other channels. + blendings.push_back({PatchBlendMode::kAdd, 0, false}); + for (size_t j = 0; j < num_ec; ++j) { + blendings.push_back({PatchBlendMode::kNone, 0, false}); + } + } + pref_positions.emplace_back(std::move(ref_pos)); + } + + CompressParams cparams = state->cparams; + // Recursive application of patches could create very weird issues. + cparams.patches = Override::kOff; + + RoundtripPatchFrame(&reference_frame, state, kPatchFrameReferenceId, cparams, + cms, pool, aux_out, /*subtract=*/true); + + // TODO(veluca): this assumes that applying patches is commutative, which is + // not true for all blending modes. This code only produces kAdd patches, so + // this works out. + PatchDictionaryEncoder::SetPositions( + &state->shared.image_features.patches, std::move(positions), + std::move(pref_positions), std::move(blendings)); +} + +void RoundtripPatchFrame(Image3F* reference_frame, + PassesEncoderState* JXL_RESTRICT state, int idx, + CompressParams& cparams, const JxlCmsInterface& cms, + ThreadPool* pool, AuxOut* aux_out, bool subtract) { + FrameInfo patch_frame_info; + cparams.resampling = 1; + cparams.ec_resampling = 1; + cparams.dots = Override::kOff; + cparams.noise = Override::kOff; + cparams.modular_mode = true; + cparams.responsive = 0; + cparams.progressive_dc = 0; + cparams.progressive_mode = Override::kOff; + cparams.qprogressive_mode = Override::kOff; + // Use gradient predictor and not Predictor::Best. + cparams.options.predictor = Predictor::Gradient; + patch_frame_info.save_as_reference = idx; // always saved. + patch_frame_info.frame_type = FrameType::kReferenceOnly; + patch_frame_info.save_before_color_transform = true; + ImageBundle ib(&state->shared.metadata->m); + // TODO(veluca): metadata.color_encoding is a lie: ib is in XYB, but there is + // no simple way to express that yet. + patch_frame_info.ib_needs_color_transform = false; + ib.SetFromImage(std::move(*reference_frame), + state->shared.metadata->m.color_encoding); + if (!ib.metadata()->extra_channel_info.empty()) { + // Add placeholder extra channels to the patch image: patch encoding does + // not yet support extra channels, but the codec expects that the amount of + // extra channels in frames matches that in the metadata of the codestream. + std::vector<ImageF> extra_channels; + extra_channels.reserve(ib.metadata()->extra_channel_info.size()); + for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) { + extra_channels.emplace_back(ib.xsize(), ib.ysize()); + // Must initialize the image with data to not affect blending with + // uninitialized memory. + // TODO(lode): patches must copy and use the real extra channels instead. + ZeroFillImage(&extra_channels.back()); + } + ib.SetExtraChannels(std::move(extra_channels)); + } + auto special_frame = std::unique_ptr<BitWriter>(new BitWriter()); + AuxOut patch_aux_out; + JXL_CHECK(EncodeFrame(cparams, patch_frame_info, state->shared.metadata, ib, + cms, pool, special_frame.get(), + aux_out ? &patch_aux_out : nullptr)); + if (aux_out) { + for (const auto& l : patch_aux_out.layers) { + aux_out->layers[kLayerDictionary].Assimilate(l); + } + } + const Span<const uint8_t> encoded = special_frame->GetSpan(); + state->special_frames.emplace_back(std::move(special_frame)); + if (subtract) { + ImageBundle decoded(&state->shared.metadata->m); + PassesDecoderState dec_state; + JXL_CHECK(dec_state.output_encoding_info.SetFromMetadata( + *state->shared.metadata)); + const uint8_t* frame_start = encoded.data(); + size_t encoded_size = encoded.size(); + JXL_CHECK(DecodeFrame(&dec_state, pool, frame_start, encoded_size, + /*frame_header=*/nullptr, &decoded, + *state->shared.metadata)); + frame_start += decoded.decoded_bytes(); + encoded_size -= decoded.decoded_bytes(); + size_t ref_xsize = + dec_state.shared_storage.reference_frames[idx].frame.color()->xsize(); + // if the frame itself uses patches, we need to decode another frame + if (!ref_xsize) { + JXL_CHECK(DecodeFrame(&dec_state, pool, frame_start, encoded_size, + /*frame_header=*/nullptr, &decoded, + *state->shared.metadata)); + } + JXL_CHECK(encoded_size == 0); + state->shared.reference_frames[idx] = + std::move(dec_state.shared_storage.reference_frames[idx]); + } else { + state->shared.reference_frames[idx].frame = std::move(ib); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.h b/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.h new file mode 100644 index 0000000000..e17bfe4f04 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_patch_dictionary.h @@ -0,0 +1,107 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_PATCH_DICTIONARY_H_ +#define LIB_JXL_ENC_PATCH_DICTIONARY_H_ + +// Chooses reference patches, and avoids encoding them once per occurrence. + +#include <stddef.h> +#include <string.h> +#include <sys/types.h> + +#include <tuple> +#include <vector> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" + +namespace jxl { + +struct AuxOut; + +constexpr size_t kMaxPatchSize = 32; + +struct QuantizedPatch { + size_t xsize; + size_t ysize; + QuantizedPatch() { + for (size_t i = 0; i < 3; i++) { + pixels[i].resize(kMaxPatchSize * kMaxPatchSize); + fpixels[i].resize(kMaxPatchSize * kMaxPatchSize); + } + } + std::vector<int8_t> pixels[3] = {}; + // Not compared. Used only to retrieve original pixels to construct the + // reference image. + std::vector<float> fpixels[3] = {}; + bool operator==(const QuantizedPatch& other) const { + if (xsize != other.xsize) return false; + if (ysize != other.ysize) return false; + for (size_t c = 0; c < 3; c++) { + if (memcmp(pixels[c].data(), other.pixels[c].data(), + sizeof(int8_t) * xsize * ysize) != 0) + return false; + } + return true; + } + + bool operator<(const QuantizedPatch& other) const { + if (xsize != other.xsize) return xsize < other.xsize; + if (ysize != other.ysize) return ysize < other.ysize; + for (size_t c = 0; c < 3; c++) { + int cmp = memcmp(pixels[c].data(), other.pixels[c].data(), + sizeof(int8_t) * xsize * ysize); + if (cmp > 0) return false; + if (cmp < 0) return true; + } + return false; + } +}; + +// Pair (patch, vector of occurrences). +using PatchInfo = + std::pair<QuantizedPatch, std::vector<std::pair<uint32_t, uint32_t>>>; + +// Friend class of PatchDictionary. +class PatchDictionaryEncoder { + public: + // Only call if HasAny(). + static void Encode(const PatchDictionary& pdic, BitWriter* writer, + size_t layer, AuxOut* aux_out); + + static void SetPositions(PatchDictionary* pdic, + std::vector<PatchPosition> positions, + std::vector<PatchReferencePosition> ref_positions, + std::vector<PatchBlending> blendings) { + pdic->positions_ = std::move(positions); + pdic->ref_positions_ = std::move(ref_positions); + pdic->blendings_ = std::move(blendings); + pdic->ComputePatchTree(); + } + + static void SubtractFrom(const PatchDictionary& pdic, Image3F* opsin); +}; + +void FindBestPatchDictionary(const Image3F& opsin, + PassesEncoderState* JXL_RESTRICT state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, bool is_xyb = true); + +void RoundtripPatchFrame(Image3F* reference_frame, + PassesEncoderState* JXL_RESTRICT state, int idx, + CompressParams& cparams, const JxlCmsInterface& cms, + ThreadPool* pool, AuxOut* aux_out, bool subtract); + +} // namespace jxl + +#endif // LIB_JXL_ENC_PATCH_DICTIONARY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_photon_noise.cc b/third_party/jpeg-xl/lib/jxl/enc_photon_noise.cc new file mode 100644 index 0000000000..1933435753 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_photon_noise.cc @@ -0,0 +1,95 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_photon_noise.h" + +#include <algorithm> + +#include "lib/jxl/cms/opsin_params.h" + +namespace jxl { + +namespace { + +// Assumes a daylight-like spectrum. +// https://www.strollswithmydog.com/effective-quantum-efficiency-of-sensor/#:~:text=11%2C260%20photons/um%5E2/lx-s +constexpr float kPhotonsPerLxSPerUm2 = 11260; + +// Order of magnitude for cameras in the 2010-2020 decade, taking the CFA into +// account. +constexpr float kEffectiveQuantumEfficiency = 0.20; + +// TODO(sboukortt): reevaluate whether these are good defaults, notably whether +// it would be worth making read noise higher at lower ISO settings. +constexpr float kPhotoResponseNonUniformity = 0.005; +constexpr float kInputReferredReadNoise = 3; + +// Assumes a 35mm sensor. +constexpr float kSensorAreaUm2 = 36000.f * 24000; + +template <typename T> +inline constexpr T Square(const T x) { + return x * x; +} +template <typename T> +inline constexpr T Cube(const T x) { + return x * x * x; +} + +} // namespace + +NoiseParams SimulatePhotonNoise(const size_t xsize, const size_t ysize, + const float iso) { + const float kOpsinAbsorbanceBiasCbrt = + std::cbrt(jxl::cms::kOpsinAbsorbanceBias[1]); + + // Focal plane exposure for 18% of kDefaultIntensityTarget, in lx·s. + // (ISO = 10 lx·s ÷ H) + const float h_18 = 10 / iso; + + const float pixel_area_um2 = kSensorAreaUm2 / (xsize * ysize); + + const float electrons_per_pixel_18 = kEffectiveQuantumEfficiency * + kPhotonsPerLxSPerUm2 * h_18 * + pixel_area_um2; + + NoiseParams params; + + for (size_t i = 0; i < NoiseParams::kNumNoisePoints; ++i) { + const float scaled_index = i / (NoiseParams::kNumNoisePoints - 2.f); + // scaled_index is used for XYB = (0, 2·scaled_index, 2·scaled_index) + const float y = 2 * scaled_index; + // 1 = default intensity target + const float linear = std::max(0.f, Cube(y - kOpsinAbsorbanceBiasCbrt) + + jxl::cms::kOpsinAbsorbanceBias[1]); + const float electrons_per_pixel = electrons_per_pixel_18 * (linear / 0.18f); + // Quadrature sum of read noise, photon shot noise (sqrt(S) so simply not + // squared here) and photo response non-uniformity. + // https://doi.org/10.1117/3.725073 + // Units are electrons rms. + const float noise = + std::sqrt(Square(kInputReferredReadNoise) + electrons_per_pixel + + Square(kPhotoResponseNonUniformity * electrons_per_pixel)); + const float linear_noise = noise * (0.18f / electrons_per_pixel_18); + const float opsin_derivative = + (1.f / 3) / + Square(std::cbrt(linear - jxl::cms::kOpsinAbsorbanceBias[1])); + const float opsin_noise = linear_noise * opsin_derivative; + + // TODO(sboukortt): verify more thoroughly whether the denominator is + // correct. + params.lut[i] = + Clamp1(opsin_noise / + (0.22f // norm_const + * std::sqrt(2.f) // red_noise + green_noise + * 1.13f // standard deviation of a plane of generated noise + ), + 0.f, 1.f); + } + + return params; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_photon_noise.h b/third_party/jpeg-xl/lib/jxl/enc_photon_noise.h new file mode 100644 index 0000000000..f43e14d560 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_photon_noise.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_PHOTON_NOISE_H_ +#define LIB_JXL_ENC_PHOTON_NOISE_H_ + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/noise.h" + +namespace jxl { + +// Constructs a NoiseParams representing the noise that would be seen at the +// selected nominal exposure on a last-decade (as of 2021) color camera with a +// 36×24mm sensor (“35mm format”). +NoiseParams SimulatePhotonNoise(size_t xsize, size_t ysize, float iso); + +} // namespace jxl + +#endif // LIB_JXL_ENC_PHOTON_NOISE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_photon_noise_test.cc b/third_party/jpeg-xl/lib/jxl/enc_photon_noise_test.cc new file mode 100644 index 0000000000..be11b465ad --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_photon_noise_test.cc @@ -0,0 +1,51 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_photon_noise.h" + +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +using ::testing::FloatNear; +using ::testing::Pointwise; + +MATCHER(AreApproximatelyEqual, "") { + constexpr float kTolerance = 1e-6; + const float actual = std::get<0>(arg); + const float expected = std::get<1>(arg); + return testing::ExplainMatchResult(FloatNear(expected, kTolerance), actual, + result_listener); +} + +TEST(EncPhotonNoiseTest, LUTs) { + EXPECT_THAT( + SimulatePhotonNoise(/*xsize=*/6000, /*ysize=*/4000, /*iso=*/100).lut, + Pointwise(AreApproximatelyEqual(), + {0.00259652, 0.0139648, 0.00681551, 0.00632582, 0.00694917, + 0.00803922, 0.00934574, 0.0107607})); + EXPECT_THAT( + SimulatePhotonNoise(/*xsize=*/6000, /*ysize=*/4000, /*iso=*/800).lut, + Pointwise(AreApproximatelyEqual(), + {0.02077220, 0.0420923, 0.01820690, 0.01439020, 0.01293670, + 0.01254030, 0.01277390, 0.0134161})); + EXPECT_THAT( + SimulatePhotonNoise(/*xsize=*/6000, /*ysize=*/4000, /*iso=*/6400).lut, + Pointwise(AreApproximatelyEqual(), + {0.1661770, 0.1691120, 0.05309080, 0.03963960, 0.03357410, + 0.03001650, 0.02776740, 0.0263478})); + + // Lower when measured on a per-pixel basis as there are fewer of them. + EXPECT_THAT( + SimulatePhotonNoise(/*xsize=*/4000, /*ysize=*/3000, /*iso=*/6400).lut, + Pointwise(AreApproximatelyEqual(), + {0.0830886, 0.1008720, 0.0367748, 0.0280305, 0.0240236, + 0.0218040, 0.0205771, 0.0200058})); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_progressive_split.cc b/third_party/jpeg-xl/lib/jxl/enc_progressive_split.cc new file mode 100644 index 0000000000..811c9455c2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_progressive_split.cc @@ -0,0 +1,82 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_progressive_split.h" + +#include <string.h> + +#include <algorithm> +#include <memory> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/common.h" // kMaxNumPasses +#include "lib/jxl/image.h" + +namespace jxl { + +template <typename T> +void ProgressiveSplitter::SplitACCoefficients( + const T* JXL_RESTRICT block, const AcStrategy& acs, size_t bx, size_t by, + T* JXL_RESTRICT output[kMaxNumPasses]) { + size_t size = acs.covered_blocks_x() * acs.covered_blocks_y() * kDCTBlockSize; + auto shift_right_round0 = [&](T v, int shift) { + T one_if_negative = static_cast<uint32_t>(v) >> 31; + T add = (one_if_negative << shift) - one_if_negative; + return (v + add) >> shift; + }; + // Early quit for the simple case of only one pass. + if (mode_.num_passes == 1) { + memcpy(output[0], block, sizeof(T) * size); + return; + } + size_t ncoeffs_all_done_from_earlier_passes = 1; + + int previous_pass_shift = 0; + for (size_t num_pass = 0; num_pass < mode_.num_passes; num_pass++) { // pass + // Zero out output block. + memset(output[num_pass], 0, size * sizeof(T)); + const int pass_shift = mode_.passes[num_pass].shift; + size_t frame_ncoeffs = mode_.passes[num_pass].num_coefficients; + size_t xsize = acs.covered_blocks_x(); + size_t ysize = acs.covered_blocks_y(); + CoefficientLayout(&ysize, &xsize); + for (size_t y = 0; y < ysize * frame_ncoeffs; y++) { // superblk-y + for (size_t x = 0; x < xsize * frame_ncoeffs; x++) { // superblk-x + size_t pos = y * xsize * kBlockDim + x; + if (x < xsize * ncoeffs_all_done_from_earlier_passes && + y < ysize * ncoeffs_all_done_from_earlier_passes) { + // This coefficient was already included in an earlier pass, + // which included a genuinely smaller set of coefficients. + continue; + } + T v = block[pos]; + // Previous pass discarded some bits: do not encode them again. + if (previous_pass_shift != 0) { + T previous_v = shift_right_round0(v, previous_pass_shift) * + (1 << previous_pass_shift); + v -= previous_v; + } + output[num_pass][pos] = shift_right_round0(v, pass_shift); + } // superblk-x + } // superblk-y + // We just finished a pass. + // Hence, we are now guaranteed to have included all coeffs up to + // frame_ncoeffs in every block, unless the current pass is shifted. + if (mode_.passes[num_pass].shift == 0) { + ncoeffs_all_done_from_earlier_passes = frame_ncoeffs; + } + previous_pass_shift = mode_.passes[num_pass].shift; + } // num_pass +} + +template void ProgressiveSplitter::SplitACCoefficients<int32_t>( + const int32_t* JXL_RESTRICT, const AcStrategy&, size_t, size_t, + int32_t* JXL_RESTRICT[kMaxNumPasses]); + +template void ProgressiveSplitter::SplitACCoefficients<int16_t>( + const int16_t* JXL_RESTRICT, const AcStrategy&, size_t, size_t, + int16_t* JXL_RESTRICT[kMaxNumPasses]); + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_progressive_split.h b/third_party/jpeg-xl/lib/jxl/enc_progressive_split.h new file mode 100644 index 0000000000..06584fa916 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_progressive_split.h @@ -0,0 +1,131 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_PROGRESSIVE_SPLIT_H_ +#define LIB_JXL_PROGRESSIVE_SPLIT_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <limits> +#include <memory> +#include <vector> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" // kMaxNumPasses +#include "lib/jxl/dct_util.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/splines.h" + +// Functions to split DCT coefficients in multiple passes. All the passes of a +// single frame are added together. + +namespace jxl { + +constexpr size_t kNoDownsamplingFactor = std::numeric_limits<size_t>::max(); + +struct PassDefinition { + // Side of the square of the coefficients that should be kept in each 8x8 + // block. Must be greater than 1, and at most 8. Should be in non-decreasing + // order. + size_t num_coefficients; + + // How much to shift the encoded values by, with rounding. + size_t shift; + + // If specified, this indicates that if the requested downsampling factor is + // sufficiently high, then it is fine to stop decoding after this pass. + // By default, passes are not marked as being suitable for any downsampling. + size_t suitable_for_downsampling_of_at_least; +}; + +struct ProgressiveMode { + size_t num_passes = 1; + PassDefinition passes[kMaxNumPasses] = { + PassDefinition{/*num_coefficients=*/8, /*shift=*/0, + /*suitable_for_downsampling_of_at_least=*/1}}; + + ProgressiveMode() = default; + + template <size_t nump> + explicit ProgressiveMode(const PassDefinition (&p)[nump]) { + JXL_ASSERT(nump <= kMaxNumPasses); + num_passes = nump; + PassDefinition previous_pass{ + /*num_coefficients=*/1, /*shift=*/0, + /*suitable_for_downsampling_of_at_least=*/kNoDownsamplingFactor}; + size_t last_downsampling_factor = kNoDownsamplingFactor; + for (size_t i = 0; i < nump; i++) { + JXL_ASSERT(p[i].num_coefficients > previous_pass.num_coefficients || + (p[i].num_coefficients == previous_pass.num_coefficients && + p[i].shift < previous_pass.shift)); + JXL_ASSERT(p[i].suitable_for_downsampling_of_at_least == + kNoDownsamplingFactor || + p[i].suitable_for_downsampling_of_at_least <= + last_downsampling_factor); + // Only used inside assert. + (void)last_downsampling_factor; + if (p[i].suitable_for_downsampling_of_at_least != kNoDownsamplingFactor) { + last_downsampling_factor = p[i].suitable_for_downsampling_of_at_least; + } + previous_pass = passes[i] = p[i]; + } + } +}; + +class ProgressiveSplitter { + public: + void SetProgressiveMode(ProgressiveMode mode) { mode_ = mode; } + + size_t GetNumPasses() const { return mode_.num_passes; } + + void InitPasses(Passes* JXL_RESTRICT passes) const { + passes->num_passes = static_cast<uint32_t>(GetNumPasses()); + passes->num_downsample = 0; + JXL_ASSERT(passes->num_passes != 0); + passes->shift[passes->num_passes - 1] = 0; + if (passes->num_passes == 1) return; // Done, arrays are empty + + for (uint32_t i = 0; i < mode_.num_passes - 1; ++i) { + const size_t min_downsampling_factor = + mode_.passes[i].suitable_for_downsampling_of_at_least; + passes->shift[i] = mode_.passes[i].shift; + if (1 < min_downsampling_factor && + min_downsampling_factor != kNoDownsamplingFactor) { + passes->downsample[passes->num_downsample] = min_downsampling_factor; + passes->last_pass[passes->num_downsample] = i; + if (mode_.passes[i + 1].suitable_for_downsampling_of_at_least < + min_downsampling_factor) { + passes->num_downsample += 1; + } + } + } + } + + template <typename T> + void SplitACCoefficients(const T* JXL_RESTRICT block, const AcStrategy& acs, + size_t bx, size_t by, + T* JXL_RESTRICT output[kMaxNumPasses]); + + private: + ProgressiveMode mode_; +}; + +extern template void ProgressiveSplitter::SplitACCoefficients<int32_t>( + const int32_t* JXL_RESTRICT, const AcStrategy&, size_t, size_t, + int32_t* JXL_RESTRICT[kMaxNumPasses]); + +extern template void ProgressiveSplitter::SplitACCoefficients<int16_t>( + const int16_t* JXL_RESTRICT, const AcStrategy&, size_t, size_t, + int16_t* JXL_RESTRICT[kMaxNumPasses]); + +} // namespace jxl + +#endif // LIB_JXL_PROGRESSIVE_SPLIT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_quant_weights.cc b/third_party/jpeg-xl/lib/jxl/enc_quant_weights.cc new file mode 100644 index 0000000000..236ddaacfd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_quant_weights.cc @@ -0,0 +1,213 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_quant_weights.h" + +#include <stdlib.h> + +#include <algorithm> +#include <cmath> +#include <limits> +#include <utility> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +struct AuxOut; + +namespace { + +Status EncodeDctParams(const DctQuantWeightParams& params, BitWriter* writer) { + JXL_ASSERT(params.num_distance_bands >= 1); + writer->Write(DctQuantWeightParams::kLog2MaxDistanceBands, + params.num_distance_bands - 1); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < params.num_distance_bands; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Write( + params.distance_bands[c][i] * (i == 0 ? (1 / 64.0f) : 1.0f), writer)); + } + } + return true; +} + +Status EncodeQuant(const QuantEncoding& encoding, size_t idx, size_t size_x, + size_t size_y, BitWriter* writer, + ModularFrameEncoder* modular_frame_encoder) { + writer->Write(kLog2NumQuantModes, encoding.mode); + size_x *= kBlockDim; + size_y *= kBlockDim; + switch (encoding.mode) { + case QuantEncoding::kQuantModeLibrary: { + writer->Write(kCeilLog2NumPredefinedTables, encoding.predefined); + break; + } + case QuantEncoding::kQuantModeID: { + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 3; i++) { + JXL_RETURN_IF_ERROR( + F16Coder::Write(encoding.idweights[c][i] * (1.0f / 64), writer)); + } + } + break; + } + case QuantEncoding::kQuantModeDCT2: { + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 6; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Write( + encoding.dct2weights[c][i] * (1.0f / 64), writer)); + } + } + break; + } + case QuantEncoding::kQuantModeDCT4X8: { + for (size_t c = 0; c < 3; c++) { + JXL_RETURN_IF_ERROR( + F16Coder::Write(encoding.dct4x8multipliers[c], writer)); + } + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); + break; + } + case QuantEncoding::kQuantModeDCT4: { + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 2; i++) { + JXL_RETURN_IF_ERROR( + F16Coder::Write(encoding.dct4multipliers[c][i], writer)); + } + } + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); + break; + } + case QuantEncoding::kQuantModeDCT: { + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); + break; + } + case QuantEncoding::kQuantModeRAW: { + ModularFrameEncoder::EncodeQuantTable(size_x, size_y, writer, encoding, + idx, modular_frame_encoder); + break; + } + case QuantEncoding::kQuantModeAFV: { + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 9; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Write( + encoding.afv_weights[c][i] * (i < 6 ? 1.0f / 64 : 1.0f), writer)); + } + } + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); + JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params_afv_4x4, writer)); + break; + } + } + return true; +} + +} // namespace + +Status DequantMatricesEncode(const DequantMatrices& matrices, BitWriter* writer, + size_t layer, AuxOut* aux_out, + ModularFrameEncoder* modular_frame_encoder) { + bool all_default = true; + const std::vector<QuantEncoding>& encodings = matrices.encodings(); + + for (size_t i = 0; i < encodings.size(); i++) { + if (encodings[i].mode != QuantEncoding::kQuantModeLibrary || + encodings[i].predefined != 0) { + all_default = false; + } + } + // TODO(janwas): better bound + BitWriter::Allotment allotment(writer, 512 * 1024); + writer->Write(1, all_default); + if (!all_default) { + for (size_t i = 0; i < encodings.size(); i++) { + JXL_RETURN_IF_ERROR(EncodeQuant( + encodings[i], i, DequantMatrices::required_size_x[i], + DequantMatrices::required_size_y[i], writer, modular_frame_encoder)); + } + } + allotment.ReclaimAndCharge(writer, layer, aux_out); + return true; +} + +Status DequantMatricesEncodeDC(const DequantMatrices& matrices, + BitWriter* writer, size_t layer, + AuxOut* aux_out) { + bool all_default = true; + const float* dc_quant = matrices.DCQuants(); + for (size_t c = 0; c < 3; c++) { + if (dc_quant[c] != kDCQuant[c]) { + all_default = false; + } + } + BitWriter::Allotment allotment(writer, 1 + sizeof(float) * kBitsPerByte * 3); + writer->Write(1, all_default); + if (!all_default) { + for (size_t c = 0; c < 3; c++) { + JXL_RETURN_IF_ERROR(F16Coder::Write(dc_quant[c] * 128.0f, writer)); + } + } + allotment.ReclaimAndCharge(writer, layer, aux_out); + return true; +} + +void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc) { + matrices->SetDCQuant(dc); + // Roundtrip encode/decode DC to ensure same values as decoder. + BitWriter writer; + JXL_CHECK(DequantMatricesEncodeDC(*matrices, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + BitReader br(writer.GetSpan()); + // Called only in the encoder: should fail only for programmer errors. + JXL_CHECK(matrices->DecodeDC(&br)); + JXL_CHECK(br.Close()); +} + +void DequantMatricesScaleDC(DequantMatrices* matrices, const float scale) { + float dc[3]; + for (size_t c = 0; c < 3; ++c) { + dc[c] = matrices->InvDCQuant(c) * (1.0f / scale); + } + DequantMatricesSetCustomDC(matrices, dc); +} + +void DequantMatricesRoundtrip(DequantMatrices* matrices) { + // Do not pass modular en/decoder, as they only change entropy and not + // values. + BitWriter writer; + JXL_CHECK(DequantMatricesEncode(*matrices, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + BitReader br(writer.GetSpan()); + // Called only in the encoder: should fail only for programmer errors. + JXL_CHECK(matrices->Decode(&br)); + JXL_CHECK(br.Close()); +} + +void DequantMatricesSetCustom(DequantMatrices* matrices, + const std::vector<QuantEncoding>& encodings, + ModularFrameEncoder* encoder) { + JXL_ASSERT(encodings.size() == DequantMatrices::kNum); + matrices->SetEncodings(encodings); + for (size_t i = 0; i < encodings.size(); i++) { + if (encodings[i].mode == QuantEncodingInternal::kQuantModeRAW) { + encoder->AddQuantTable(DequantMatrices::required_size_x[i] * kBlockDim, + DequantMatrices::required_size_y[i] * kBlockDim, + encodings[i], i); + } + } + DequantMatricesRoundtrip(matrices); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_quant_weights.h b/third_party/jpeg-xl/lib/jxl/enc_quant_weights.h new file mode 100644 index 0000000000..a47dfd4988 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_quant_weights.h @@ -0,0 +1,39 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_QUANT_WEIGHTS_H_ +#define LIB_JXL_ENC_QUANT_WEIGHTS_H_ + +#include <cstddef> + +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +struct AuxOut; +struct BitWriter; + +Status DequantMatricesEncode( + const DequantMatrices& matrices, BitWriter* writer, size_t layer, + AuxOut* aux_out, ModularFrameEncoder* modular_frame_encoder = nullptr); +Status DequantMatricesEncodeDC(const DequantMatrices& matrices, + BitWriter* writer, size_t layer, + AuxOut* aux_out); +// For consistency with QuantEncoding, higher values correspond to more +// precision. +void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc); + +void DequantMatricesScaleDC(DequantMatrices* matrices, float scale); + +void DequantMatricesSetCustom(DequantMatrices* matrices, + const std::vector<QuantEncoding>& encodings, + ModularFrameEncoder* encoder); + +// Roundtrip encode/decode the matrices to ensure same values as decoder. +void DequantMatricesRoundtrip(DequantMatrices* matrices); + +} // namespace jxl + +#endif // LIB_JXL_ENC_QUANT_WEIGHTS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_splines.cc b/third_party/jpeg-xl/lib/jxl/enc_splines.cc new file mode 100644 index 0000000000..de6c9670ea --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_splines.cc @@ -0,0 +1,97 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <algorithm> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/pack_signed.h" +#include "lib/jxl/splines.h" + +namespace jxl { + +struct AuxOut; + +class QuantizedSplineEncoder { + public: + // Only call if HasAny(). + static void Tokenize(const QuantizedSpline& spline, + std::vector<Token>* const tokens) { + tokens->emplace_back(kNumControlPointsContext, + spline.control_points_.size()); + for (const auto& point : spline.control_points_) { + tokens->emplace_back(kControlPointsContext, PackSigned(point.first)); + tokens->emplace_back(kControlPointsContext, PackSigned(point.second)); + } + const auto encode_dct = [tokens](const int dct[32]) { + for (int i = 0; i < 32; ++i) { + tokens->emplace_back(kDCTContext, PackSigned(dct[i])); + } + }; + for (int c = 0; c < 3; ++c) { + encode_dct(spline.color_dct_[c]); + } + encode_dct(spline.sigma_dct_); + } +}; + +namespace { + +void EncodeAllStartingPoints(const std::vector<Spline::Point>& points, + std::vector<Token>* tokens) { + int64_t last_x = 0; + int64_t last_y = 0; + for (size_t i = 0; i < points.size(); i++) { + const int64_t x = lroundf(points[i].x); + const int64_t y = lroundf(points[i].y); + if (i == 0) { + tokens->emplace_back(kStartingPositionContext, x); + tokens->emplace_back(kStartingPositionContext, y); + } else { + tokens->emplace_back(kStartingPositionContext, PackSigned(x - last_x)); + tokens->emplace_back(kStartingPositionContext, PackSigned(y - last_y)); + } + last_x = x; + last_y = y; + } +} + +} // namespace + +void EncodeSplines(const Splines& splines, BitWriter* writer, + const size_t layer, const HistogramParams& histogram_params, + AuxOut* aux_out) { + JXL_ASSERT(splines.HasAny()); + + const std::vector<QuantizedSpline>& quantized_splines = + splines.QuantizedSplines(); + std::vector<std::vector<Token>> tokens(1); + tokens[0].emplace_back(kNumSplinesContext, quantized_splines.size() - 1); + EncodeAllStartingPoints(splines.StartingPoints(), &tokens[0]); + + tokens[0].emplace_back(kQuantizationAdjustmentContext, + PackSigned(splines.GetQuantizationAdjustment())); + + for (const QuantizedSpline& spline : quantized_splines) { + QuantizedSplineEncoder::Tokenize(spline, &tokens[0]); + } + + EntropyEncodingData codes; + std::vector<uint8_t> context_map; + BuildAndEncodeHistograms(histogram_params, kNumSplineContexts, tokens, &codes, + &context_map, writer, layer, aux_out); + WriteTokens(tokens[0], codes, context_map, 0, writer, layer, aux_out); +} + +Splines FindSplines(const Image3F& opsin) { + // TODO(user): implement spline detection. + return {}; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_splines.h b/third_party/jpeg-xl/lib/jxl/enc_splines.h new file mode 100644 index 0000000000..3f6ecc7c4f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_splines.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_SPLINES_H_ +#define LIB_JXL_ENC_SPLINES_H_ + +#include <cstddef> + +#include "lib/jxl/enc_ans_params.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/image.h" +#include "lib/jxl/splines.h" + +namespace jxl { + +struct AuxOut; + +// Only call if splines.HasAny(). +void EncodeSplines(const Splines& splines, BitWriter* writer, size_t layer, + const HistogramParams& histogram_params, AuxOut* aux_out); + +Splines FindSplines(const Image3F& opsin); + +} // namespace jxl + +#endif // LIB_JXL_ENC_SPLINES_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_toc.cc b/third_party/jpeg-xl/lib/jxl/enc_toc.cc new file mode 100644 index 0000000000..e79298ef31 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_toc.cc @@ -0,0 +1,45 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_toc.h" + +#include <stdint.h> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_coeff_order.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/toc.h" + +namespace jxl { +Status WriteGroupOffsets(const std::vector<BitWriter>& group_codes, + const std::vector<coeff_order_t>& permutation, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out) { + BitWriter::Allotment allotment(writer, MaxBits(group_codes.size())); + if (!permutation.empty() && !group_codes.empty()) { + // Don't write a permutation at all for an empty group_codes. + writer->Write(1, 1); // permutation + JXL_DASSERT(permutation.size() == group_codes.size()); + EncodePermutation(permutation.data(), /*skip=*/0, permutation.size(), + writer, /* layer= */ 0, aux_out); + + } else { + writer->Write(1, 0); // no permutation + } + writer->ZeroPadToByte(); // before TOC entries + + for (size_t i = 0; i < group_codes.size(); i++) { + JXL_ASSERT(group_codes[i].BitsWritten() % kBitsPerByte == 0); + const size_t group_size = group_codes[i].BitsWritten() / kBitsPerByte; + JXL_RETURN_IF_ERROR(U32Coder::Write(kTocDist, group_size, writer)); + } + writer->ZeroPadToByte(); // before first group + allotment.ReclaimAndCharge(writer, kLayerTOC, aux_out); + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_toc.h b/third_party/jpeg-xl/lib/jxl/enc_toc.h new file mode 100644 index 0000000000..aa222141be --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_toc.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_TOC_H_ +#define LIB_JXL_ENC_TOC_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/enc_bit_writer.h" + +namespace jxl { + +struct AuxOut; + +// Writes the group offsets. If the permutation vector is empty, the identity +// permutation will be used. +Status WriteGroupOffsets(const std::vector<BitWriter>& group_codes, + const std::vector<coeff_order_t>& permutation, + BitWriter* JXL_RESTRICT writer, AuxOut* aux_out); + +} // namespace jxl + +#endif // LIB_JXL_ENC_TOC_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_transforms-inl.h b/third_party/jpeg-xl/lib/jxl/enc_transforms-inl.h new file mode 100644 index 0000000000..469072eebd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_transforms-inl.h @@ -0,0 +1,800 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_ENC_TRANSFORMS_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_ENC_TRANSFORMS_INL_H_ +#undef LIB_JXL_ENC_TRANSFORMS_INL_H_ +#else +#define LIB_JXL_ENC_TRANSFORMS_INL_H_ +#endif + +#include <stddef.h> + +#include <hwy/highway.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/dct_scales.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// Inverse of ReinterpretingDCT. +template <size_t DCT_ROWS, size_t DCT_COLS, size_t LF_ROWS, size_t LF_COLS, + size_t ROWS, size_t COLS> +HWY_INLINE void ReinterpretingIDCT(const float* input, + const size_t input_stride, float* output, + const size_t output_stride) { + HWY_ALIGN float block[ROWS * COLS] = {}; + if (ROWS < COLS) { + for (size_t y = 0; y < LF_ROWS; y++) { + for (size_t x = 0; x < LF_COLS; x++) { + block[y * COLS + x] = input[y * input_stride + x] * + DCTTotalResampleScale<DCT_ROWS, ROWS>(y) * + DCTTotalResampleScale<DCT_COLS, COLS>(x); + } + } + } else { + for (size_t y = 0; y < LF_COLS; y++) { + for (size_t x = 0; x < LF_ROWS; x++) { + block[y * ROWS + x] = input[y * input_stride + x] * + DCTTotalResampleScale<DCT_COLS, COLS>(y) * + DCTTotalResampleScale<DCT_ROWS, ROWS>(x); + } + } + } + + // ROWS, COLS <= 8, so we can put scratch space on the stack. + HWY_ALIGN float scratch_space[ROWS * COLS * 3]; + ComputeScaledIDCT<ROWS, COLS>()(block, DCTTo(output, output_stride), + scratch_space); +} + +template <size_t S> +void DCT2TopBlock(const float* block, size_t stride, float* out) { + static_assert(kBlockDim % S == 0, "S should be a divisor of kBlockDim"); + static_assert(S % 2 == 0, "S should be even"); + float temp[kDCTBlockSize]; + constexpr size_t num_2x2 = S / 2; + for (size_t y = 0; y < num_2x2; y++) { + for (size_t x = 0; x < num_2x2; x++) { + float c00 = block[y * 2 * stride + x * 2]; + float c01 = block[y * 2 * stride + x * 2 + 1]; + float c10 = block[(y * 2 + 1) * stride + x * 2]; + float c11 = block[(y * 2 + 1) * stride + x * 2 + 1]; + float r00 = c00 + c01 + c10 + c11; + float r01 = c00 + c01 - c10 - c11; + float r10 = c00 - c01 + c10 - c11; + float r11 = c00 - c01 - c10 + c11; + r00 *= 0.25f; + r01 *= 0.25f; + r10 *= 0.25f; + r11 *= 0.25f; + temp[y * kBlockDim + x] = r00; + temp[y * kBlockDim + num_2x2 + x] = r01; + temp[(y + num_2x2) * kBlockDim + x] = r10; + temp[(y + num_2x2) * kBlockDim + num_2x2 + x] = r11; + } + } + for (size_t y = 0; y < S; y++) { + for (size_t x = 0; x < S; x++) { + out[y * kBlockDim + x] = temp[y * kBlockDim + x]; + } + } +} + +void AFVDCT4x4(const float* JXL_RESTRICT pixels, float* JXL_RESTRICT coeffs) { + HWY_ALIGN static constexpr float k4x4AFVBasisTranspose[16][16] = { + { + 0.2500000000000000, + 0.8769029297991420f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + -0.4105377591765233f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + }, + { + 0.2500000000000000, + 0.2206518106944235f, + 0.0000000000000000, + 0.0000000000000000, + -0.7071067811865474f, + 0.6235485373547691f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + 0.4067007583026075f, + -0.2125574805828875f, + 0.0000000000000000, + -0.0643507165794627f, + -0.4517556589999482f, + -0.3046847507248690f, + 0.3017929516615495f, + 0.4082482904638627f, + 0.1747866975480809f, + -0.2110560104933578f, + -0.1426608480880726f, + -0.1381354035075859f, + -0.1743760259965107f, + 0.1135498731499434f, + }, + { + 0.2500000000000000, + -0.1014005039375375f, + 0.4444481661973445f, + 0.3085497062849767f, + 0.0000000000000000f, + -0.0643507165794627f, + 0.1585450355184006f, + 0.5112616136591823f, + 0.2579236279634118f, + 0.0000000000000000, + 0.0812611176717539f, + 0.1856718091610980f, + -0.3416446842253372f, + 0.3302282550303788f, + 0.0702790691196284f, + -0.0741750459581035f, + }, + { + 0.2500000000000000, + 0.2206518106944236f, + 0.0000000000000000, + 0.0000000000000000, + 0.7071067811865476f, + 0.6235485373547694f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + }, + { + 0.2500000000000000, + -0.1014005039375378f, + 0.0000000000000000, + 0.4706702258572536f, + 0.0000000000000000, + -0.0643507165794628f, + -0.0403851516082220f, + 0.0000000000000000, + 0.1627234014286620f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.7367497537172237f, + 0.0875511500058708f, + -0.2921026642334881f, + 0.1940289303259434f, + }, + { + 0.2500000000000000, + -0.1014005039375377f, + 0.1957439937204294f, + -0.1621205195722993f, + 0.0000000000000000, + -0.0643507165794628f, + 0.0074182263792424f, + -0.2904801297289980f, + 0.0952002265347504f, + 0.0000000000000000, + -0.3675398009862027f, + 0.4921585901373873f, + 0.2462710772207515f, + -0.0794670660590957f, + 0.3623817333531167f, + -0.4351904965232280f, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + 0.2929100136981264f, + 0.0000000000000000, + 0.0000000000000000, + -0.0643507165794627f, + 0.3935103426921017f, + -0.0657870154914280f, + 0.0000000000000000, + -0.4082482904638628f, + -0.3078822139579090f, + -0.3852501370925192f, + -0.0857401903551931f, + -0.4613374887461511f, + 0.0000000000000000, + 0.2191868483885747f, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + -0.4067007583026072f, + -0.2125574805828705f, + 0.0000000000000000, + -0.0643507165794627f, + -0.4517556589999464f, + 0.3046847507248840f, + 0.3017929516615503f, + -0.4082482904638635f, + -0.1747866975480813f, + 0.2110560104933581f, + -0.1426608480880734f, + -0.1381354035075829f, + -0.1743760259965108f, + 0.1135498731499426f, + }, + { + 0.2500000000000000, + -0.1014005039375377f, + -0.1957439937204287f, + -0.1621205195722833f, + 0.0000000000000000, + -0.0643507165794628f, + 0.0074182263792444f, + 0.2904801297290076f, + 0.0952002265347505f, + 0.0000000000000000, + 0.3675398009862011f, + -0.4921585901373891f, + 0.2462710772207514f, + -0.0794670660591026f, + 0.3623817333531165f, + -0.4351904965232251f, + }, + { + 0.2500000000000000, + -0.1014005039375375f, + 0.0000000000000000, + -0.4706702258572528f, + 0.0000000000000000, + -0.0643507165794627f, + 0.1107416575309343f, + 0.0000000000000000, + -0.1627234014286617f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + 0.1488339922711357f, + 0.4972464710953509f, + 0.2921026642334879f, + 0.5550443808910661f, + }, + { + 0.2500000000000000, + -0.1014005039375377f, + 0.1137907446044809f, + -0.1464291867126764f, + 0.0000000000000000, + -0.0643507165794628f, + 0.0829816309488205f, + -0.2388977352334460f, + -0.3531238544981630f, + -0.4082482904638630f, + 0.4826689115059883f, + 0.1741941265991622f, + -0.0476868035022925f, + 0.1253805944856366f, + -0.4326608024727445f, + -0.2546827712406646f, + }, + { + 0.2500000000000000, + -0.1014005039375377f, + -0.4444481661973438f, + 0.3085497062849487f, + 0.0000000000000000, + -0.0643507165794628f, + 0.1585450355183970f, + -0.5112616136592012f, + 0.2579236279634129f, + 0.0000000000000000, + -0.0812611176717504f, + -0.1856718091610990f, + -0.3416446842253373f, + 0.3302282550303805f, + 0.0702790691196282f, + -0.0741750459581023f, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + -0.2929100136981264f, + 0.0000000000000000, + 0.0000000000000000, + -0.0643507165794627f, + 0.3935103426921022f, + 0.0657870154914254f, + 0.0000000000000000, + 0.4082482904638634f, + 0.3078822139579031f, + 0.3852501370925211f, + -0.0857401903551927f, + -0.4613374887461554f, + 0.0000000000000000, + 0.2191868483885728f, + }, + { + 0.2500000000000000, + -0.1014005039375376f, + -0.1137907446044814f, + -0.1464291867126654f, + 0.0000000000000000, + -0.0643507165794627f, + 0.0829816309488214f, + 0.2388977352334547f, + -0.3531238544981624f, + 0.4082482904638630f, + -0.4826689115059858f, + -0.1741941265991621f, + -0.0476868035022928f, + 0.1253805944856431f, + -0.4326608024727457f, + -0.2546827712406641f, + }, + { + 0.2500000000000000, + -0.1014005039375374f, + 0.0000000000000000, + 0.4251149611657548f, + 0.0000000000000000, + -0.0643507165794626f, + -0.4517556589999480f, + 0.0000000000000000, + -0.6035859033230976f, + 0.0000000000000000, + 0.0000000000000000, + 0.0000000000000000, + -0.1426608480880724f, + -0.1381354035075845f, + 0.3487520519930227f, + 0.1135498731499429f, + }, + }; + + const HWY_CAPPED(float, 16) d; + for (size_t i = 0; i < 16; i += Lanes(d)) { + auto scalar = Zero(d); + for (size_t j = 0; j < 16; j++) { + auto px = Set(d, pixels[j]); + auto basis = Load(d, k4x4AFVBasisTranspose[j] + i); + scalar = MulAdd(px, basis, scalar); + } + Store(scalar, d, coeffs + i); + } +} + +// Coefficient layout: +// - (even, even) positions hold AFV coefficients +// - (odd, even) positions hold DCT4x4 coefficients +// - (any, odd) positions hold DCT4x8 coefficients +template <size_t afv_kind> +void AFVTransformFromPixels(const float* JXL_RESTRICT pixels, + size_t pixels_stride, + float* JXL_RESTRICT coefficients) { + HWY_ALIGN float scratch_space[4 * 8 * 5]; + size_t afv_x = afv_kind & 1; + size_t afv_y = afv_kind / 2; + HWY_ALIGN float block[4 * 8] = {}; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + block[(afv_y == 1 ? 3 - iy : iy) * 4 + (afv_x == 1 ? 3 - ix : ix)] = + pixels[(iy + 4 * afv_y) * pixels_stride + ix + 4 * afv_x]; + } + } + // AFV coefficients in (even, even) positions. + HWY_ALIGN float coeff[4 * 4]; + AFVDCT4x4(block, coeff); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + coefficients[iy * 2 * 8 + ix * 2] = coeff[iy * 4 + ix]; + } + } + // 4x4 DCT of the block with same y and different x. + ComputeScaledDCT<4, 4>()( + DCTFrom(pixels + afv_y * 4 * pixels_stride + (afv_x == 1 ? 0 : 4), + pixels_stride), + block, scratch_space); + // ... in (odd, even) positions. + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + coefficients[iy * 2 * 8 + ix * 2 + 1] = block[iy * 4 + ix]; + } + } + // 4x8 DCT of the other half of the block. + ComputeScaledDCT<4, 8>()( + DCTFrom(pixels + (afv_y == 1 ? 0 : 4) * pixels_stride, pixels_stride), + block, scratch_space); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + coefficients[(1 + iy * 2) * 8 + ix] = block[iy * 8 + ix]; + } + } + float block00 = coefficients[0] * 0.25f; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + coefficients[0] = (block00 + block01 + 2 * block10) * 0.25f; + coefficients[1] = (block00 - block01) * 0.5f; + coefficients[8] = (block00 + block01 - 2 * block10) * 0.25f; +} + +HWY_MAYBE_UNUSED void TransformFromPixels(const AcStrategy::Type strategy, + const float* JXL_RESTRICT pixels, + size_t pixels_stride, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT scratch_space) { + using Type = AcStrategy::Type; + switch (strategy) { + case Type::IDENTITY: { + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + float block_dc = 0; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + block_dc += pixels[(y * 4 + iy) * pixels_stride + x * 4 + ix]; + } + } + block_dc *= 1.0f / 16; + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + if (ix == 1 && iy == 1) continue; + coefficients[(y + iy * 2) * 8 + x + ix * 2] = + pixels[(y * 4 + iy) * pixels_stride + x * 4 + ix] - + pixels[(y * 4 + 1) * pixels_stride + x * 4 + 1]; + } + } + coefficients[(y + 2) * 8 + x + 2] = coefficients[y * 8 + x]; + coefficients[y * 8 + x] = block_dc; + } + } + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + coefficients[0] = (block00 + block01 + block10 + block11) * 0.25f; + coefficients[1] = (block00 + block01 - block10 - block11) * 0.25f; + coefficients[8] = (block00 - block01 + block10 - block11) * 0.25f; + coefficients[9] = (block00 - block01 - block10 + block11) * 0.25f; + break; + } + case Type::DCT8X4: { + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 8]; + ComputeScaledDCT<8, 4>()(DCTFrom(pixels + x * 4, pixels_stride), block, + scratch_space); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + // Store transposed. + coefficients[(x + iy * 2) * 8 + ix] = block[iy * 8 + ix]; + } + } + } + float block0 = coefficients[0]; + float block1 = coefficients[8]; + coefficients[0] = (block0 + block1) * 0.5f; + coefficients[8] = (block0 - block1) * 0.5f; + break; + } + case Type::DCT4X8: { + for (size_t y = 0; y < 2; y++) { + HWY_ALIGN float block[4 * 8]; + ComputeScaledDCT<4, 8>()( + DCTFrom(pixels + y * 4 * pixels_stride, pixels_stride), block, + scratch_space); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 8; ix++) { + coefficients[(y + iy * 2) * 8 + ix] = block[iy * 8 + ix]; + } + } + } + float block0 = coefficients[0]; + float block1 = coefficients[8]; + coefficients[0] = (block0 + block1) * 0.5f; + coefficients[8] = (block0 - block1) * 0.5f; + break; + } + case Type::DCT4X4: { + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + HWY_ALIGN float block[4 * 4]; + ComputeScaledDCT<4, 4>()( + DCTFrom(pixels + y * 4 * pixels_stride + x * 4, pixels_stride), + block, scratch_space); + for (size_t iy = 0; iy < 4; iy++) { + for (size_t ix = 0; ix < 4; ix++) { + coefficients[(y + iy * 2) * 8 + x + ix * 2] = block[iy * 4 + ix]; + } + } + } + } + float block00 = coefficients[0]; + float block01 = coefficients[1]; + float block10 = coefficients[8]; + float block11 = coefficients[9]; + coefficients[0] = (block00 + block01 + block10 + block11) * 0.25f; + coefficients[1] = (block00 + block01 - block10 - block11) * 0.25f; + coefficients[8] = (block00 - block01 + block10 - block11) * 0.25f; + coefficients[9] = (block00 - block01 - block10 + block11) * 0.25f; + break; + } + case Type::DCT2X2: { + DCT2TopBlock<8>(pixels, pixels_stride, coefficients); + DCT2TopBlock<4>(coefficients, kBlockDim, coefficients); + DCT2TopBlock<2>(coefficients, kBlockDim, coefficients); + break; + } + case Type::DCT16X16: { + ComputeScaledDCT<16, 16>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT16X8: { + ComputeScaledDCT<16, 8>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT8X16: { + ComputeScaledDCT<8, 16>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT32X8: { + ComputeScaledDCT<32, 8>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT8X32: { + ComputeScaledDCT<8, 32>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT32X16: { + ComputeScaledDCT<32, 16>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT16X32: { + ComputeScaledDCT<16, 32>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT32X32: { + ComputeScaledDCT<32, 32>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT: { + ComputeScaledDCT<8, 8>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::AFV0: { + AFVTransformFromPixels<0>(pixels, pixels_stride, coefficients); + break; + } + case Type::AFV1: { + AFVTransformFromPixels<1>(pixels, pixels_stride, coefficients); + break; + } + case Type::AFV2: { + AFVTransformFromPixels<2>(pixels, pixels_stride, coefficients); + break; + } + case Type::AFV3: { + AFVTransformFromPixels<3>(pixels, pixels_stride, coefficients); + break; + } + case Type::DCT64X64: { + ComputeScaledDCT<64, 64>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT64X32: { + ComputeScaledDCT<64, 32>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT32X64: { + ComputeScaledDCT<32, 64>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT128X128: { + ComputeScaledDCT<128, 128>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT128X64: { + ComputeScaledDCT<128, 64>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT64X128: { + ComputeScaledDCT<64, 128>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT256X256: { + ComputeScaledDCT<256, 256>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT256X128: { + ComputeScaledDCT<256, 128>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::DCT128X256: { + ComputeScaledDCT<128, 256>()(DCTFrom(pixels, pixels_stride), coefficients, + scratch_space); + break; + } + case Type::kNumValidStrategies: + JXL_UNREACHABLE("Invalid strategy"); + } +} + +HWY_MAYBE_UNUSED void DCFromLowestFrequencies(const AcStrategy::Type strategy, + const float* block, float* dc, + size_t dc_stride) { + using Type = AcStrategy::Type; + switch (strategy) { + case Type::DCT16X8: { + ReinterpretingIDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/kBlockDim, + /*LF_ROWS=*/2, /*LF_COLS=*/1, /*ROWS=*/2, /*COLS=*/1>( + block, 2 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT8X16: { + ReinterpretingIDCT</*DCT_ROWS=*/kBlockDim, /*DCT_COLS=*/2 * kBlockDim, + /*LF_ROWS=*/1, /*LF_COLS=*/2, /*ROWS=*/1, /*COLS=*/2>( + block, 2 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT16X16: { + ReinterpretingIDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/2 * kBlockDim, + /*LF_ROWS=*/2, /*LF_COLS=*/2, /*ROWS=*/2, /*COLS=*/2>( + block, 2 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT32X8: { + ReinterpretingIDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/kBlockDim, + /*LF_ROWS=*/4, /*LF_COLS=*/1, /*ROWS=*/4, /*COLS=*/1>( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT8X32: { + ReinterpretingIDCT</*DCT_ROWS=*/kBlockDim, /*DCT_COLS=*/4 * kBlockDim, + /*LF_ROWS=*/1, /*LF_COLS=*/4, /*ROWS=*/1, /*COLS=*/4>( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT32X16: { + ReinterpretingIDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/2 * kBlockDim, + /*LF_ROWS=*/4, /*LF_COLS=*/2, /*ROWS=*/4, /*COLS=*/2>( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT16X32: { + ReinterpretingIDCT</*DCT_ROWS=*/2 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim, + /*LF_ROWS=*/2, /*LF_COLS=*/4, /*ROWS=*/2, /*COLS=*/4>( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT32X32: { + ReinterpretingIDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim, + /*LF_ROWS=*/4, /*LF_COLS=*/4, /*ROWS=*/4, /*COLS=*/4>( + block, 4 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT64X32: { + ReinterpretingIDCT</*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/4 * kBlockDim, + /*LF_ROWS=*/8, /*LF_COLS=*/4, /*ROWS=*/8, /*COLS=*/4>( + block, 8 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT32X64: { + ReinterpretingIDCT</*DCT_ROWS=*/4 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim, + /*LF_ROWS=*/4, /*LF_COLS=*/8, /*ROWS=*/4, /*COLS=*/8>( + block, 8 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT64X64: { + ReinterpretingIDCT</*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim, + /*LF_ROWS=*/8, /*LF_COLS=*/8, /*ROWS=*/8, /*COLS=*/8>( + block, 8 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT128X64: { + ReinterpretingIDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/8 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/8, /*ROWS=*/16, /*COLS=*/8>( + block, 16 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT64X128: { + ReinterpretingIDCT< + /*DCT_ROWS=*/8 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/8, /*LF_COLS=*/16, /*ROWS=*/8, /*COLS=*/16>( + block, 16 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT128X128: { + ReinterpretingIDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/16, /*ROWS=*/16, /*COLS=*/16>( + block, 16 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT256X128: { + ReinterpretingIDCT< + /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/16 * kBlockDim, + /*LF_ROWS=*/32, /*LF_COLS=*/16, /*ROWS=*/32, /*COLS=*/16>( + block, 32 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT128X256: { + ReinterpretingIDCT< + /*DCT_ROWS=*/16 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim, + /*LF_ROWS=*/16, /*LF_COLS=*/32, /*ROWS=*/16, /*COLS=*/32>( + block, 32 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT256X256: { + ReinterpretingIDCT< + /*DCT_ROWS=*/32 * kBlockDim, /*DCT_COLS=*/32 * kBlockDim, + /*LF_ROWS=*/32, /*LF_COLS=*/32, /*ROWS=*/32, /*COLS=*/32>( + block, 32 * kBlockDim, dc, dc_stride); + break; + } + case Type::DCT: + case Type::DCT2X2: + case Type::DCT4X4: + case Type::DCT4X8: + case Type::DCT8X4: + case Type::AFV0: + case Type::AFV1: + case Type::AFV2: + case Type::AFV3: + case Type::IDENTITY: + dc[0] = block[0]; + break; + case Type::kNumValidStrategies: + JXL_UNREACHABLE("Invalid strategy"); + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_ENC_TRANSFORMS_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_transforms.cc b/third_party/jpeg-xl/lib/jxl/enc_transforms.cc new file mode 100644 index 0000000000..8978ba1dcb --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_transforms.cc @@ -0,0 +1,41 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_transforms.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_transforms.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/enc_transforms-inl.h" + +namespace jxl { + +#if HWY_ONCE +HWY_EXPORT(TransformFromPixels); +void TransformFromPixels(const AcStrategy::Type strategy, + const float* JXL_RESTRICT pixels, size_t pixels_stride, + float* JXL_RESTRICT coefficients, + float* scratch_space) { + return HWY_DYNAMIC_DISPATCH(TransformFromPixels)( + strategy, pixels, pixels_stride, coefficients, scratch_space); +} + +HWY_EXPORT(DCFromLowestFrequencies); +void DCFromLowestFrequencies(AcStrategy::Type strategy, const float* block, + float* dc, size_t dc_stride) { + return HWY_DYNAMIC_DISPATCH(DCFromLowestFrequencies)(strategy, block, dc, + dc_stride); +} + +HWY_EXPORT(AFVDCT4x4); +void AFVDCT4x4(const float* JXL_RESTRICT pixels, float* JXL_RESTRICT coeffs) { + return HWY_DYNAMIC_DISPATCH(AFVDCT4x4)(pixels, coeffs); +} +#endif // HWY_ONCE + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/enc_transforms.h b/third_party/jpeg-xl/lib/jxl/enc_transforms.h new file mode 100644 index 0000000000..039ccc3893 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_transforms.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_TRANSFORMS_H_ +#define LIB_JXL_ENC_TRANSFORMS_H_ + +// Facade for (non-inlined) integral transforms. + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +void TransformFromPixels(const AcStrategy::Type strategy, + const float* JXL_RESTRICT pixels, size_t pixels_stride, + float* JXL_RESTRICT coefficients, + float* JXL_RESTRICT scratch_space); + +// Equivalent of the above for DC image. +void DCFromLowestFrequencies(AcStrategy::Type strategy, const float* block, + float* dc, size_t dc_stride); + +void AFVDCT4x4(const float* JXL_RESTRICT pixels, float* JXL_RESTRICT coeffs); + +} // namespace jxl + +#endif // LIB_JXL_ENC_TRANSFORMS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/enc_xyb.cc b/third_party/jpeg-xl/lib/jxl/enc_xyb.cc new file mode 100644 index 0000000000..e538e8c91d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_xyb.cc @@ -0,0 +1,448 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_xyb.h" + +#include <algorithm> +#include <atomic> +#include <cstdlib> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_xyb.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/fast_math-inl.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/cms/transfer_functions-inl.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_image_bundle.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +// 4x3 matrix * 3x1 SIMD vectors +template <class V> +JXL_INLINE void OpsinAbsorbance(const V r, const V g, const V b, + const float* JXL_RESTRICT premul_absorb, + V* JXL_RESTRICT mixed0, V* JXL_RESTRICT mixed1, + V* JXL_RESTRICT mixed2) { + const float* bias = &jxl::cms::kOpsinAbsorbanceBias[0]; + const HWY_FULL(float) d; + const size_t N = Lanes(d); + const auto m0 = Load(d, premul_absorb + 0 * N); + const auto m1 = Load(d, premul_absorb + 1 * N); + const auto m2 = Load(d, premul_absorb + 2 * N); + const auto m3 = Load(d, premul_absorb + 3 * N); + const auto m4 = Load(d, premul_absorb + 4 * N); + const auto m5 = Load(d, premul_absorb + 5 * N); + const auto m6 = Load(d, premul_absorb + 6 * N); + const auto m7 = Load(d, premul_absorb + 7 * N); + const auto m8 = Load(d, premul_absorb + 8 * N); + *mixed0 = MulAdd(m0, r, MulAdd(m1, g, MulAdd(m2, b, Set(d, bias[0])))); + *mixed1 = MulAdd(m3, r, MulAdd(m4, g, MulAdd(m5, b, Set(d, bias[1])))); + *mixed2 = MulAdd(m6, r, MulAdd(m7, g, MulAdd(m8, b, Set(d, bias[2])))); +} + +template <class V> +void StoreXYB(const V r, V g, const V b, float* JXL_RESTRICT valx, + float* JXL_RESTRICT valy, float* JXL_RESTRICT valz) { + const HWY_FULL(float) d; + const V half = Set(d, 0.5f); + Store(Mul(half, Sub(r, g)), d, valx); + Store(Mul(half, Add(r, g)), d, valy); + Store(b, d, valz); +} + +// Converts one RGB vector to XYB. +template <class V> +void LinearRGBToXYB(const V r, const V g, const V b, + const float* JXL_RESTRICT premul_absorb, + float* JXL_RESTRICT valx, float* JXL_RESTRICT valy, + float* JXL_RESTRICT valz) { + V mixed0, mixed1, mixed2; + OpsinAbsorbance(r, g, b, premul_absorb, &mixed0, &mixed1, &mixed2); + + // mixed* should be non-negative even for wide-gamut, so clamp to zero. + mixed0 = ZeroIfNegative(mixed0); + mixed1 = ZeroIfNegative(mixed1); + mixed2 = ZeroIfNegative(mixed2); + + const HWY_FULL(float) d; + const size_t N = Lanes(d); + mixed0 = CubeRootAndAdd(mixed0, Load(d, premul_absorb + 9 * N)); + mixed1 = CubeRootAndAdd(mixed1, Load(d, premul_absorb + 10 * N)); + mixed2 = CubeRootAndAdd(mixed2, Load(d, premul_absorb + 11 * N)); + StoreXYB(mixed0, mixed1, mixed2, valx, valy, valz); + + // For wide-gamut inputs, r/g/b and valx (but not y/z) are often negative. +} + +void LinearRGBRowToXYB(float* JXL_RESTRICT row0, float* JXL_RESTRICT row1, + float* JXL_RESTRICT row2, + const float* JXL_RESTRICT premul_absorb, size_t xsize) { + const HWY_FULL(float) d; + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto r = Load(d, row0 + x); + const auto g = Load(d, row1 + x); + const auto b = Load(d, row2 + x); + LinearRGBToXYB(r, g, b, premul_absorb, row0 + x, row1 + x, row2 + x); + } +} + +// Input/output uses the codec.h scaling: nominally 0-1 if in-gamut. +template <class V> +V LinearFromSRGB(V encoded) { + return TF_SRGB().DisplayFromEncoded(encoded); +} + +Status LinearSRGBToXYB(const float* JXL_RESTRICT premul_absorb, + ThreadPool* pool, Image3F* JXL_RESTRICT image) { + const size_t xsize = image->xsize(); + + const HWY_FULL(float) d; + return RunOnPool( + pool, 0, static_cast<uint32_t>(image->ysize()), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = static_cast<size_t>(task); + float* JXL_RESTRICT row0 = image->PlaneRow(0, y); + float* JXL_RESTRICT row1 = image->PlaneRow(1, y); + float* JXL_RESTRICT row2 = image->PlaneRow(2, y); + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto in_r = Load(d, row0 + x); + const auto in_g = Load(d, row1 + x); + const auto in_b = Load(d, row2 + x); + LinearRGBToXYB(in_r, in_g, in_b, premul_absorb, row0 + x, row1 + x, + row2 + x); + } + }, + "LinearToXYB"); +} + +Status SRGBToXYB(const float* JXL_RESTRICT premul_absorb, ThreadPool* pool, + Image3F* JXL_RESTRICT image) { + const size_t xsize = image->xsize(); + + const HWY_FULL(float) d; + return RunOnPool( + pool, 0, static_cast<uint32_t>(image->ysize()), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = static_cast<size_t>(task); + float* JXL_RESTRICT row0 = image->PlaneRow(0, y); + float* JXL_RESTRICT row1 = image->PlaneRow(1, y); + float* JXL_RESTRICT row2 = image->PlaneRow(2, y); + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto in_r = LinearFromSRGB(Load(d, row0 + x)); + const auto in_g = LinearFromSRGB(Load(d, row1 + x)); + const auto in_b = LinearFromSRGB(Load(d, row2 + x)); + LinearRGBToXYB(in_r, in_g, in_b, premul_absorb, row0 + x, row1 + x, + row2 + x); + } + }, + "SRGBToXYB"); +} + +Status SRGBToXYBAndLinear(const float* JXL_RESTRICT premul_absorb, + ThreadPool* pool, Image3F* JXL_RESTRICT image, + Image3F* JXL_RESTRICT linear) { + const size_t xsize = image->xsize(); + + const HWY_FULL(float) d; + return RunOnPool( + pool, 0, static_cast<uint32_t>(image->ysize()), ThreadPool::NoInit, + [&](const uint32_t task, size_t /*thread*/) { + const size_t y = static_cast<size_t>(task); + float* JXL_RESTRICT row_image0 = image->PlaneRow(0, y); + float* JXL_RESTRICT row_image1 = image->PlaneRow(1, y); + float* JXL_RESTRICT row_image2 = image->PlaneRow(2, y); + float* JXL_RESTRICT row_linear0 = linear->PlaneRow(0, y); + float* JXL_RESTRICT row_linear1 = linear->PlaneRow(1, y); + float* JXL_RESTRICT row_linear2 = linear->PlaneRow(2, y); + + for (size_t x = 0; x < xsize; x += Lanes(d)) { + const auto in_r = LinearFromSRGB(Load(d, row_image0 + x)); + const auto in_g = LinearFromSRGB(Load(d, row_image1 + x)); + const auto in_b = LinearFromSRGB(Load(d, row_image2 + x)); + + Store(in_r, d, row_linear0 + x); + Store(in_g, d, row_linear1 + x); + Store(in_b, d, row_linear2 + x); + + LinearRGBToXYB(in_r, in_g, in_b, premul_absorb, row_image0 + x, + row_image1 + x, row_image2 + x); + } + }, + "SRGBToXYBAndLinear"); +} + +void ComputePremulAbsorb(float intensity_target, float* premul_absorb) { + const HWY_FULL(float) d; + const size_t N = Lanes(d); + const float mul = intensity_target / 255.0f; + for (size_t i = 0; i < 9; ++i) { + const auto absorb = Set(d, jxl::cms::kOpsinAbsorbanceMatrix[i] * mul); + Store(absorb, d, premul_absorb + i * N); + } + for (size_t i = 0; i < 3; ++i) { + const auto neg_bias_cbrt = + Set(d, -cbrtf(jxl::cms::kOpsinAbsorbanceBias[i])); + Store(neg_bias_cbrt, d, premul_absorb + (9 + i) * N); + } +} + +Image3F TransformToLinearRGB(const Image3F& in, + const ColorEncoding& color_encoding, + float intensity_target, const JxlCmsInterface& cms, + ThreadPool* pool) { + ColorSpaceTransform c_transform(cms); + bool is_gray = color_encoding.IsGray(); + const ColorEncoding& c_desired = ColorEncoding::LinearSRGB(is_gray); + Image3F out(in.xsize(), in.ysize()); + std::atomic<bool> ok{true}; + JXL_CHECK(RunOnPool( + pool, 0, in.ysize(), + [&](const size_t num_threads) { + return c_transform.Init(color_encoding, c_desired, intensity_target, + in.xsize(), num_threads); + }, + [&](const uint32_t y, const size_t thread) { + float* mutable_src_buf = c_transform.BufSrc(thread); + const float* src_buf = mutable_src_buf; + // Interleave input. + if (is_gray) { + src_buf = in.ConstPlaneRow(0, y); + } else { + const float* JXL_RESTRICT row_in0 = in.ConstPlaneRow(0, y); + const float* JXL_RESTRICT row_in1 = in.ConstPlaneRow(1, y); + const float* JXL_RESTRICT row_in2 = in.ConstPlaneRow(2, y); + for (size_t x = 0; x < in.xsize(); x++) { + mutable_src_buf[3 * x + 0] = row_in0[x]; + mutable_src_buf[3 * x + 1] = row_in1[x]; + mutable_src_buf[3 * x + 2] = row_in2[x]; + } + } + float* JXL_RESTRICT dst_buf = c_transform.BufDst(thread); + if (!c_transform.Run(thread, src_buf, dst_buf)) { + ok.store(false); + return; + } + float* JXL_RESTRICT row_out0 = out.PlaneRow(0, y); + float* JXL_RESTRICT row_out1 = out.PlaneRow(1, y); + float* JXL_RESTRICT row_out2 = out.PlaneRow(2, y); + // De-interleave output and convert type. + if (is_gray) { + for (size_t x = 0; x < in.xsize(); x++) { + row_out0[x] = dst_buf[x]; + row_out1[x] = dst_buf[x]; + row_out2[x] = dst_buf[x]; + } + } else { + for (size_t x = 0; x < in.xsize(); x++) { + row_out0[x] = dst_buf[3 * x + 0]; + row_out1[x] = dst_buf[3 * x + 1]; + row_out2[x] = dst_buf[3 * x + 2]; + } + } + }, + "Colorspace transform")); + JXL_CHECK(ok.load()); + return out; +} + +// This is different from Butteraugli's OpsinDynamicsImage() in the sense that +// it does not contain a sensitivity multiplier based on the blurred image. +void ToXYB(const ColorEncoding& c_current, float intensity_target, + const ImageF* black, ThreadPool* pool, Image3F* JXL_RESTRICT image, + const JxlCmsInterface& cms, Image3F* const JXL_RESTRICT linear) { + if (black) JXL_ASSERT(SameSize(*image, *black)); + if (linear) JXL_ASSERT(SameSize(*image, *linear)); + + const HWY_FULL(float) d; + // Pre-broadcasted constants + HWY_ALIGN float premul_absorb[MaxLanes(d) * 12]; + ComputePremulAbsorb(intensity_target, premul_absorb); + + const bool want_linear = linear != nullptr; + + const ColorEncoding& c_linear_srgb = + ColorEncoding::LinearSRGB(c_current.IsGray()); + // Linear sRGB inputs are rare but can be useful for the fastest encoders, for + // which undoing the sRGB transfer function would be a large part of the cost. + if (c_linear_srgb.SameColorEncoding(c_current)) { + // This only happens if kitten or slower, moving ImageBundle might be + // possible but the encoder is much slower than this copy. + if (want_linear) { + CopyImageTo(*image, linear); + } + JXL_CHECK(LinearSRGBToXYB(premul_absorb, pool, image)); + return; + } + + // Common case: already sRGB, can avoid the color transform + if (c_current.IsSRGB()) { + // Common case: can avoid allocating/copying + if (want_linear) { + // Slow encoder also wants linear sRGB. + JXL_CHECK(SRGBToXYBAndLinear(premul_absorb, pool, image, linear)); + } else { + JXL_CHECK(SRGBToXYB(premul_absorb, pool, image)); + } + return; + } + + JXL_CHECK(ApplyColorTransform(c_current, intensity_target, *image, black, + Rect(*image), c_linear_srgb, cms, pool, + want_linear ? linear : image)); + if (want_linear) { + CopyImageTo(*linear, image); + } + JXL_CHECK(LinearSRGBToXYB(premul_absorb, pool, image)); +} + +// Transform RGB to YCbCr. +// Could be performed in-place (i.e. Y, Cb and Cr could alias R, B and B). +Status RgbToYcbcr(const ImageF& r_plane, const ImageF& g_plane, + const ImageF& b_plane, ImageF* y_plane, ImageF* cb_plane, + ImageF* cr_plane, ThreadPool* pool) { + const HWY_FULL(float) df; + const size_t S = Lanes(df); // Step. + + const size_t xsize = r_plane.xsize(); + const size_t ysize = r_plane.ysize(); + if ((xsize == 0) || (ysize == 0)) return true; + + // Full-range BT.601 as defined by JFIF Clause 7: + // https://www.itu.int/rec/T-REC-T.871-201105-I/en + const auto k128 = Set(df, 128.0f / 255); + const auto kR = Set(df, 0.299f); // NTSC luma + const auto kG = Set(df, 0.587f); + const auto kB = Set(df, 0.114f); + const auto kAmpR = Set(df, 0.701f); + const auto kAmpB = Set(df, 0.886f); + const auto kDiffR = Add(kAmpR, kR); + const auto kDiffB = Add(kAmpB, kB); + const auto kNormR = Div(Set(df, 1.0f), (Add(kAmpR, Add(kG, kB)))); + const auto kNormB = Div(Set(df, 1.0f), (Add(kR, Add(kG, kAmpB)))); + + constexpr size_t kGroupArea = kGroupDim * kGroupDim; + const size_t lines_per_group = DivCeil(kGroupArea, xsize); + const size_t num_stripes = DivCeil(ysize, lines_per_group); + const auto transform = [&](int idx, int /* thread*/) { + const size_t y0 = idx * lines_per_group; + const size_t y1 = std::min<size_t>(y0 + lines_per_group, ysize); + for (size_t y = y0; y < y1; ++y) { + const float* r_row = r_plane.ConstRow(y); + const float* g_row = g_plane.ConstRow(y); + const float* b_row = b_plane.ConstRow(y); + float* y_row = y_plane->Row(y); + float* cb_row = cb_plane->Row(y); + float* cr_row = cr_plane->Row(y); + for (size_t x = 0; x < xsize; x += S) { + const auto r = Load(df, r_row + x); + const auto g = Load(df, g_row + x); + const auto b = Load(df, b_row + x); + const auto r_base = Mul(r, kR); + const auto r_diff = Mul(r, kDiffR); + const auto g_base = Mul(g, kG); + const auto b_base = Mul(b, kB); + const auto b_diff = Mul(b, kDiffB); + const auto y_base = Add(r_base, Add(g_base, b_base)); + const auto y_vec = Sub(y_base, k128); + const auto cb_vec = Mul(Sub(b_diff, y_base), kNormB); + const auto cr_vec = Mul(Sub(r_diff, y_base), kNormR); + Store(y_vec, df, y_row + x); + Store(cb_vec, df, cb_row + x); + Store(cr_vec, df, cr_row + x); + } + } + }; + return RunOnPool(pool, 0, static_cast<int>(num_stripes), ThreadPool::NoInit, + transform, "RgbToYcbCr"); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(ToXYB); +void ToXYB(const ColorEncoding& c_current, float intensity_target, + const ImageF* black, ThreadPool* pool, Image3F* JXL_RESTRICT image, + const JxlCmsInterface& cms, Image3F* const JXL_RESTRICT linear) { + HWY_DYNAMIC_DISPATCH(ToXYB) + (c_current, intensity_target, black, pool, image, cms, linear); +} + +void ToXYB(const ImageBundle& in, ThreadPool* pool, Image3F* JXL_RESTRICT xyb, + const JxlCmsInterface& cms, Image3F* JXL_RESTRICT linear) { + *xyb = Image3F(in.xsize(), in.ysize()); + CopyImageTo(in.color(), xyb); + ToXYB(in.c_current(), in.metadata()->IntensityTarget(), + in.HasBlack() ? &in.black() : nullptr, pool, xyb, cms, linear); +} + +HWY_EXPORT(LinearRGBRowToXYB); +void LinearRGBRowToXYB(float* JXL_RESTRICT row0, float* JXL_RESTRICT row1, + float* JXL_RESTRICT row2, + const float* JXL_RESTRICT premul_absorb, size_t xsize) { + HWY_DYNAMIC_DISPATCH(LinearRGBRowToXYB) + (row0, row1, row2, premul_absorb, xsize); +} + +HWY_EXPORT(ComputePremulAbsorb); +void ComputePremulAbsorb(float intensity_target, float* premul_absorb) { + HWY_DYNAMIC_DISPATCH(ComputePremulAbsorb)(intensity_target, premul_absorb); +} + +void ScaleXYBRow(float* JXL_RESTRICT row0, float* JXL_RESTRICT row1, + float* JXL_RESTRICT row2, size_t xsize) { + for (size_t x = 0; x < xsize; x++) { + row2[x] = (row2[x] - row1[x] + jxl::cms::kScaledXYBOffset[2]) * + jxl::cms::kScaledXYBScale[2]; + row0[x] = (row0[x] + jxl::cms::kScaledXYBOffset[0]) * + jxl::cms::kScaledXYBScale[0]; + row1[x] = (row1[x] + jxl::cms::kScaledXYBOffset[1]) * + jxl::cms::kScaledXYBScale[1]; + } +} + +void ScaleXYB(Image3F* opsin) { + for (size_t y = 0; y < opsin->ysize(); y++) { + float* row0 = opsin->PlaneRow(0, y); + float* row1 = opsin->PlaneRow(1, y); + float* row2 = opsin->PlaneRow(2, y); + ScaleXYBRow(row0, row1, row2, opsin->xsize()); + } +} + +HWY_EXPORT(RgbToYcbcr); +Status RgbToYcbcr(const ImageF& r_plane, const ImageF& g_plane, + const ImageF& b_plane, ImageF* y_plane, ImageF* cb_plane, + ImageF* cr_plane, ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(RgbToYcbcr)(r_plane, g_plane, b_plane, y_plane, + cb_plane, cr_plane, pool); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/enc_xyb.h b/third_party/jpeg-xl/lib/jxl/enc_xyb.h new file mode 100644 index 0000000000..6a2e7c4123 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_xyb.h @@ -0,0 +1,53 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENC_XYB_H_ +#define LIB_JXL_ENC_XYB_H_ + +// Converts to XYB color space. + +#include <jxl/cms_interface.h> + +#include <cstddef> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" + +namespace jxl { + +// Converts any color space to XYB in-place. If `linear` is not null, fills it +// with a linear sRGB copy of `image`. +void ToXYB(const ColorEncoding& c_current, float intensity_target, + const ImageF* black, ThreadPool* pool, Image3F* JXL_RESTRICT image, + const JxlCmsInterface& cms, Image3F* JXL_RESTRICT linear); + +void ToXYB(const ImageBundle& in, ThreadPool* pool, Image3F* JXL_RESTRICT xyb, + const JxlCmsInterface& cms, Image3F* JXL_RESTRICT linear = nullptr); + +void LinearRGBRowToXYB(float* JXL_RESTRICT row0, float* JXL_RESTRICT row1, + float* JXL_RESTRICT row2, + const float* JXL_RESTRICT premul_absorb, size_t xsize); + +void ComputePremulAbsorb(float intensity_target, float* premul_absorb); + +// Transforms each color component of the given XYB image into the [0.0, 1.0] +// interval with an affine transform. +void ScaleXYB(Image3F* opsin); +void ScaleXYBRow(float* row0, float* row1, float* row2, size_t xsize); + +// Bt.601 to match JPEG/JFIF. Outputs _signed_ YCbCr values suitable for DCT, +// see F.1.1.3 of T.81 (because our data type is float, there is no need to add +// a bias to make the values unsigned). +Status RgbToYcbcr(const ImageF& r_plane, const ImageF& g_plane, + const ImageF& b_plane, ImageF* y_plane, ImageF* cb_plane, + ImageF* cr_plane, ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_ENC_XYB_H_ diff --git a/third_party/jpeg-xl/lib/jxl/encode.cc b/third_party/jpeg-xl/lib/jxl/encode.cc new file mode 100644 index 0000000000..76f2148d62 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/encode.cc @@ -0,0 +1,2690 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <brotli/encode.h> +#include <jxl/cms.h> +#include <jxl/codestream_header.h> +#include <jxl/encode.h> +#include <jxl/types.h> +#include <jxl/version.h> + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <cstring> + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/exif.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_fast_lossless.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/enc_icc_codec.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/luminance.h" +#include "lib/jxl/memory_manager_internal.h" +#include "lib/jxl/padded_bytes.h" +#include "lib/jxl/sanitizers.h" + +struct JxlErrorOrStatus { + // NOLINTNEXTLINE(google-explicit-constructor) + operator jxl::Status() const { + switch (error_) { + case JXL_ENC_SUCCESS: + return jxl::OkStatus(); + case JXL_ENC_NEED_MORE_OUTPUT: + return jxl::StatusCode::kNotEnoughBytes; + default: + return jxl::StatusCode::kGenericError; + } + } + // NOLINTNEXTLINE(google-explicit-constructor) + operator JxlEncoderStatus() const { return error_; } + + static JxlErrorOrStatus Success() { + return JxlErrorOrStatus(JXL_ENC_SUCCESS); + } + + static JxlErrorOrStatus MoreOutput() { + return JxlErrorOrStatus(JXL_ENC_NEED_MORE_OUTPUT); + } + + static JxlErrorOrStatus Error() { return JxlErrorOrStatus(JXL_ENC_ERROR); } + + private: + explicit JxlErrorOrStatus(JxlEncoderStatus error) : error_(error) {} + JxlEncoderStatus error_; +}; + +// Debug-printing failure macro similar to JXL_FAILURE, but for the status code +// JXL_ENC_ERROR +#ifdef JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(enc, error_code, format, ...) \ + (enc->error = error_code, \ + ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \ + ::jxl::Abort(), JxlErrorOrStatus::Error()) +#define JXL_API_ERROR_NOSET(format, ...) \ + (::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \ + ::jxl::Abort(), JxlErrorOrStatus::Error()) +#else // JXL_CRASH_ON_ERROR +#define JXL_API_ERROR(enc, error_code, format, ...) \ + (enc->error = error_code, \ + ((JXL_DEBUG_ON_ERROR) && \ + ::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__)), \ + JxlErrorOrStatus::Error()) +#define JXL_API_ERROR_NOSET(format, ...) \ + (::jxl::Debug(("%s:%d: " format "\n"), __FILE__, __LINE__, ##__VA_ARGS__), \ + JxlErrorOrStatus::Error()) +#endif // JXL_CRASH_ON_ERROR + +jxl::StatusOr<JxlOutputProcessorBuffer> +JxlEncoderOutputProcessorWrapper::GetBuffer(size_t min_size, + size_t requested_size) { + JXL_ASSERT(min_size > 0); + JXL_ASSERT(!has_buffer_); + if (stop_requested_) return jxl::StatusCode::kNotEnoughBytes; + requested_size = std::max(min_size, requested_size); + + // If we support seeking, output_position_ == position_. + if (external_output_processor_ && external_output_processor_->seek) { + JXL_ASSERT(output_position_ == position_); + } + // Otherwise, output_position_ <= position_. + JXL_ASSERT(output_position_ <= position_); + size_t additional_size = position_ - output_position_; + + if (external_output_processor_) { + // TODO(veluca): here, we cannot just ask for a larger buffer, as it will be + // released with a prefix of the buffer that has not been written yet. + // Figure out if there is a good way to do this more efficiently. + if (additional_size == 0) { + size_t size = requested_size; + uint8_t* user_buffer = + static_cast<uint8_t*>(external_output_processor_->get_buffer( + external_output_processor_->opaque, &size)); + if (size == 0 || user_buffer == nullptr) { + stop_requested_ = true; + return jxl::StatusCode::kNotEnoughBytes; + } + if (size < min_size) { + external_output_processor_->release_buffer( + external_output_processor_->opaque, 0); + } else { + internal_buffers_.emplace(position_, InternalBuffer()); + has_buffer_ = true; + return JxlOutputProcessorBuffer(user_buffer, size, 0, this); + } + } + } else { + if (min_size + additional_size < *avail_out_) { + internal_buffers_.emplace(position_, InternalBuffer()); + has_buffer_ = true; + return JxlOutputProcessorBuffer(*next_out_ + additional_size, + *avail_out_ - additional_size, 0, this); + } + } + + // Otherwise, we need to allocate our own buffer. + auto it = internal_buffers_.emplace(position_, InternalBuffer()).first; + InternalBuffer& buffer = it->second; + size_t alloc_size = requested_size; + it++; + if (it != internal_buffers_.end()) { + alloc_size = std::min(alloc_size, it->first - position_); + JXL_ASSERT(alloc_size >= min_size); + } + buffer.owned_data.resize(alloc_size); + has_buffer_ = true; + return JxlOutputProcessorBuffer(buffer.owned_data.data(), alloc_size, 0, + this); +} + +void JxlEncoderOutputProcessorWrapper::Seek(size_t pos) { + JXL_ASSERT(!has_buffer_); + if (external_output_processor_ && external_output_processor_->seek) { + external_output_processor_->seek(external_output_processor_->opaque, pos); + output_position_ = pos; + } + JXL_ASSERT(pos >= finalized_position_); + position_ = pos; +} + +void JxlEncoderOutputProcessorWrapper::SetFinalizedPosition() { + JXL_ASSERT(!has_buffer_); + if (external_output_processor_ && external_output_processor_->seek) { + external_output_processor_->set_finalized_position( + external_output_processor_->opaque, position_); + } + finalized_position_ = position_; + FlushOutput(); +} + +bool JxlEncoderOutputProcessorWrapper::SetAvailOut(uint8_t** next_out, + size_t* avail_out) { + if (external_output_processor_) return false; + avail_out_ = avail_out; + next_out_ = next_out; + FlushOutput(); + return true; +} + +void JxlEncoderOutputProcessorWrapper::CopyOutput(std::vector<uint8_t>& output, + uint8_t* next_out, + size_t& avail_out) { + while (HasOutputToWrite()) { + SetAvailOut(&next_out, &avail_out); + if (avail_out == 0) { + size_t offset = next_out - output.data(); + output.resize(output.size() * 2); + next_out = output.data() + offset; + avail_out = output.size() - offset; + } + } + output.resize(output.size() - avail_out); +} + +void JxlEncoderOutputProcessorWrapper::ReleaseBuffer(size_t bytes_used) { + JXL_ASSERT(has_buffer_); + has_buffer_ = false; + auto it = internal_buffers_.find(position_); + JXL_ASSERT(it != internal_buffers_.end()); + if (bytes_used == 0) { + if (external_output_processor_) { + external_output_processor_->release_buffer( + external_output_processor_->opaque, bytes_used); + } + internal_buffers_.erase(it); + return; + } + it->second.written_bytes = bytes_used; + position_ += bytes_used; + + auto it_to_next = it; + it_to_next++; + if (it_to_next != internal_buffers_.end()) { + JXL_ASSERT(it_to_next->first >= position_); + } + + if (external_output_processor_) { + // If the buffer was given by the user, tell the user it is not needed + // anymore. + if (it->second.owned_data.empty()) { + external_output_processor_->release_buffer( + external_output_processor_->opaque, bytes_used); + // If we don't support seeking, this implies we will never modify again + // the bytes that were written so far. Advance the finalized position and + // flush the output to clean up the internal buffers. + if (!external_output_processor_->seek) { + SetFinalizedPosition(); + JXL_ASSERT(output_position_ == finalized_position_); + JXL_ASSERT(output_position_ == position_); + } else { + // Otherwise, advance the output position accordingly. + output_position_ += bytes_used; + JXL_ASSERT(output_position_ >= finalized_position_); + JXL_ASSERT(output_position_ == position_); + } + } else if (external_output_processor_->seek) { + // If we had buffered the data internally, flush it out to the external + // processor if we can. + external_output_processor_->seek(external_output_processor_->opaque, + position_ - bytes_used); + output_position_ = position_ - bytes_used; + while (output_position_ < position_) { + size_t num_to_write = position_ - output_position_; + if (!AppendBufferToExternalProcessor(it->second.owned_data.data() + + output_position_ - position_ + + bytes_used, + num_to_write)) { + return; + } + } + it->second.owned_data.clear(); + } + } +} + +// Tries to write all the bytes up to the finalized position. +void JxlEncoderOutputProcessorWrapper::FlushOutput() { + JXL_ASSERT(!has_buffer_); + while (output_position_ < finalized_position_ && + (avail_out_ == nullptr || *avail_out_ > 0)) { + JXL_ASSERT(!internal_buffers_.empty()); + auto it = internal_buffers_.begin(); + // If this fails, we are trying to move the finalized position past data + // that was not written yet. This is a library programming error. + JXL_ASSERT(output_position_ >= it->first); + JXL_ASSERT(it->second.written_bytes != 0); + size_t buffer_last_byte = it->first + it->second.written_bytes; + if (!it->second.owned_data.empty()) { + size_t start_in_buffer = output_position_ - it->first; + // Guaranteed by the invariant on `internal_buffers_`. + JXL_ASSERT(buffer_last_byte > output_position_); + size_t num_to_write = + std::min(buffer_last_byte, finalized_position_) - output_position_; + if (avail_out_ != nullptr) { + size_t n = std::min(num_to_write, *avail_out_); + memcpy(*next_out_, it->second.owned_data.data() + start_in_buffer, n); + *avail_out_ -= n; + *next_out_ += n; + output_position_ += n; + } else { + if (!AppendBufferToExternalProcessor( + it->second.owned_data.data() + start_in_buffer, num_to_write)) { + return; + } + } + } else { + size_t advance = + std::min(buffer_last_byte, finalized_position_) - output_position_; + output_position_ += advance; + if (avail_out_ != nullptr) { + *next_out_ += advance; + *avail_out_ -= advance; + } + } + if (buffer_last_byte == output_position_) { + internal_buffers_.erase(it); + } + if (external_output_processor_ && !external_output_processor_->seek) { + external_output_processor_->set_finalized_position( + external_output_processor_->opaque, output_position_); + } + } +} + +bool JxlEncoderOutputProcessorWrapper::AppendBufferToExternalProcessor( + void* data, size_t count) { + JXL_ASSERT(external_output_processor_); + size_t n = count; + void* user_buffer = external_output_processor_->get_buffer( + external_output_processor_->opaque, &n); + if (user_buffer == nullptr || n == 0) { + stop_requested_ = true; + return false; + } + n = std::min(n, count); + memcpy(user_buffer, data, n); + external_output_processor_->release_buffer(external_output_processor_->opaque, + n); + output_position_ += n; + return true; +} + +namespace jxl { + +size_t WriteBoxHeader(const jxl::BoxType& type, size_t size, bool unbounded, + bool force_large_box, uint8_t* output) { + uint64_t box_size = 0; + bool large_size = false; + if (!unbounded) { + if (box_size >= kLargeBoxContentSizeThreshold || force_large_box) { + large_size = true; + // TODO(firsching): send a separate CL for this (+ test), + // quick fix in the old code: box_size += 8 + box_size = size + kLargeBoxHeaderSize; + } else { + box_size = size + kSmallBoxHeaderSize; + } + } + + size_t idx = 0; + { + const uint64_t store = large_size ? 1 : box_size; + for (size_t i = 0; i < 4; i++) { + output[idx++] = store >> (8 * (3 - i)) & 0xff; + } + } + for (size_t i = 0; i < 4; i++) { + output[idx++] = type[i]; + } + + if (large_size) { + for (size_t i = 0; i < 8; i++) { + output[idx++] = box_size >> (8 * (7 - i)) & 0xff; + } + } + return idx; +} +} // namespace jxl + +template <typename WriteBox> +jxl::Status JxlEncoderStruct::AppendBox(const jxl::BoxType& type, + bool unbounded, size_t box_max_size, + const WriteBox& write_box) { + size_t current_position = output_processor.CurrentPosition(); + bool large_box = false; + size_t box_header_size = 0; + if (box_max_size >= jxl::kLargeBoxContentSizeThreshold && !unbounded) { + box_header_size = jxl::kLargeBoxHeaderSize; + large_box = true; + } else { + box_header_size = jxl::kSmallBoxHeaderSize; + } + output_processor.Seek(current_position + box_header_size); + size_t box_contents_start = output_processor.CurrentPosition(); + JXL_RETURN_IF_ERROR(write_box()); + size_t box_contents_end = output_processor.CurrentPosition(); + output_processor.Seek(current_position); + JXL_ASSERT(box_contents_end >= box_contents_start); + if (box_contents_end - box_contents_start > box_max_size) { + return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, + "Internal error: upper bound on box size was " + "violated, upper bound: %" PRIuS ", actual: %" PRIuS, + box_max_size, box_contents_end - box_contents_start); + } + // We need to release the buffer before Seek. + { + JXL_ASSIGN_OR_RETURN( + auto buffer, + output_processor.GetBuffer(box_contents_start - current_position)); + const size_t n = + jxl::WriteBoxHeader(type, box_contents_end - box_contents_start, + unbounded, large_box, buffer.data()); + JXL_ASSERT(n == box_header_size); + buffer.advance(n); + } + output_processor.Seek(box_contents_end); + output_processor.SetFinalizedPosition(); + return jxl::OkStatus(); +} + +template <typename BoxContents> +jxl::Status JxlEncoderStruct::AppendBoxWithContents( + const jxl::BoxType& type, const BoxContents& contents) { + size_t size = std::end(contents) - std::begin(contents); + return AppendBox(type, /*unbounded=*/false, size, + [&]() { return AppendData(output_processor, contents); }); +} + +uint32_t JxlEncoderVersion(void) { + return JPEGXL_MAJOR_VERSION * 1000000 + JPEGXL_MINOR_VERSION * 1000 + + JPEGXL_PATCH_VERSION; +} + +namespace { + +void WriteJxlpBoxCounter(uint32_t counter, bool last, uint8_t* buffer) { + if (last) counter |= 0x80000000; + for (size_t i = 0; i < 4; i++) { + buffer[i] = counter >> (8 * (3 - i)) & 0xff; + } +} + +void WriteJxlpBoxCounter(uint32_t counter, bool last, + JxlOutputProcessorBuffer& buffer) { + uint8_t buf[4]; + WriteJxlpBoxCounter(counter, last, buf); + buffer.append(buf, 4); +} + +void QueueFrame( + const JxlEncoderFrameSettings* frame_settings, + jxl::MemoryManagerUniquePtr<jxl::JxlEncoderQueuedFrame>& frame) { + if (frame_settings->values.lossless) { + frame->option_values.cparams.SetLossless(); + } + + jxl::JxlEncoderQueuedInput queued_input(frame_settings->enc->memory_manager); + queued_input.frame = std::move(frame); + frame_settings->enc->input_queue.emplace_back(std::move(queued_input)); + frame_settings->enc->num_queued_frames++; +} + +void QueueFastLosslessFrame(const JxlEncoderFrameSettings* frame_settings, + JxlFastLosslessFrameState* fast_lossless_frame) { + jxl::JxlEncoderQueuedInput queued_input(frame_settings->enc->memory_manager); + queued_input.fast_lossless_frame.reset(fast_lossless_frame); + frame_settings->enc->input_queue.emplace_back(std::move(queued_input)); + frame_settings->enc->num_queued_frames++; +} + +void QueueBox(JxlEncoder* enc, + jxl::MemoryManagerUniquePtr<jxl::JxlEncoderQueuedBox>& box) { + jxl::JxlEncoderQueuedInput queued_input(enc->memory_manager); + queued_input.box = std::move(box); + enc->input_queue.emplace_back(std::move(queued_input)); + enc->num_queued_boxes++; +} + +// TODO(lode): share this code and the Brotli compression code in enc_jpeg_data +JxlEncoderStatus BrotliCompress(int quality, const uint8_t* in, size_t in_size, + jxl::PaddedBytes* out) { + std::unique_ptr<BrotliEncoderState, decltype(BrotliEncoderDestroyInstance)*> + enc(BrotliEncoderCreateInstance(nullptr, nullptr, nullptr), + BrotliEncoderDestroyInstance); + if (!enc) return JXL_API_ERROR_NOSET("BrotliEncoderCreateInstance failed"); + + BrotliEncoderSetParameter(enc.get(), BROTLI_PARAM_QUALITY, quality); + BrotliEncoderSetParameter(enc.get(), BROTLI_PARAM_SIZE_HINT, in_size); + + constexpr size_t kBufferSize = 128 * 1024; + jxl::PaddedBytes temp_buffer(kBufferSize); + + size_t avail_in = in_size; + const uint8_t* next_in = in; + + size_t total_out = 0; + + for (;;) { + size_t avail_out = kBufferSize; + uint8_t* next_out = temp_buffer.data(); + jxl::msan::MemoryIsInitialized(next_in, avail_in); + if (!BrotliEncoderCompressStream(enc.get(), BROTLI_OPERATION_FINISH, + &avail_in, &next_in, &avail_out, &next_out, + &total_out)) { + return JXL_API_ERROR_NOSET("Brotli compression failed"); + } + size_t out_size = next_out - temp_buffer.data(); + jxl::msan::UnpoisonMemory(next_out - out_size, out_size); + out->resize(out->size() + out_size); + memcpy(out->data() + out->size() - out_size, temp_buffer.data(), out_size); + if (BrotliEncoderIsFinished(enc.get())) break; + } + + return JxlErrorOrStatus::Success(); +} + +// The JXL codestream can have level 5 or level 10. Levels have certain +// restrictions such as max allowed image dimensions. This function checks the +// level required to support the current encoder settings. The debug_string is +// intended to be used for developer API error messages, and may be set to +// nullptr. +int VerifyLevelSettings(const JxlEncoder* enc, std::string* debug_string) { + const auto& m = enc->metadata.m; + + uint64_t xsize = enc->metadata.size.xsize(); + uint64_t ysize = enc->metadata.size.ysize(); + // The uncompressed ICC size, if it is used. + size_t icc_size = 0; + if (m.color_encoding.WantICC()) { + icc_size = m.color_encoding.ICC().size(); + } + + // Level 10 checks + + if (xsize > (1ull << 30ull) || ysize > (1ull << 30ull) || + xsize * ysize > (1ull << 40ull)) { + if (debug_string) *debug_string = "Too large image dimensions"; + return -1; + } + if (icc_size > (1ull << 28)) { + if (debug_string) *debug_string = "Too large ICC profile size"; + return -1; + } + if (m.num_extra_channels > 256) { + if (debug_string) *debug_string = "Too many extra channels"; + return -1; + } + + // Level 5 checks + + if (!m.modular_16_bit_buffer_sufficient) { + if (debug_string) *debug_string = "Too high modular bit depth"; + return 10; + } + if (xsize > (1ull << 18ull) || ysize > (1ull << 18ull) || + xsize * ysize > (1ull << 28ull)) { + if (debug_string) *debug_string = "Too large image dimensions"; + return 10; + } + if (icc_size > (1ull << 22)) { + if (debug_string) *debug_string = "Too large ICC profile"; + return 10; + } + if (m.num_extra_channels > 4) { + if (debug_string) *debug_string = "Too many extra channels"; + return 10; + } + for (size_t i = 0; i < m.extra_channel_info.size(); ++i) { + if (m.extra_channel_info[i].type == jxl::ExtraChannel::kBlack) { + if (debug_string) *debug_string = "CMYK channel not allowed"; + return 10; + } + } + + // TODO(lode): also need to check if consecutive composite-still frames total + // pixel amount doesn't exceed 2**28 in the case of level 5. This should be + // done when adding frame and requires ability to add composite still frames + // to be added first. + + // TODO(lode): also need to check animation duration of a frame. This should + // be done when adding frame, but first requires implementing setting the + // JxlFrameHeader for a frame. + + // TODO(lode): also need to check properties such as num_splines, num_patches, + // modular_16bit_buffers and multiple properties of modular trees. However + // these are not user-set properties so cannot be checked here, but decisions + // the C++ encoder should be able to make based on the level. + + // All level 5 checks passes, so can return the more compatible level 5 + return 5; +} + +JxlEncoderStatus CheckValidBitdepth(uint32_t bits_per_sample, + uint32_t exponent_bits_per_sample) { + if (!exponent_bits_per_sample) { + // The spec allows up to 31 for bits_per_sample here, but + // the code does not (yet) support it. + if (!(bits_per_sample > 0 && bits_per_sample <= 24)) { + return JXL_API_ERROR_NOSET("Invalid value for bits_per_sample"); + } + } else if ((exponent_bits_per_sample > 8) || + (bits_per_sample > 24 + exponent_bits_per_sample) || + (bits_per_sample < 3 + exponent_bits_per_sample)) { + return JXL_API_ERROR_NOSET("Invalid float description"); + } + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus VerifyInputBitDepth(JxlBitDepth bit_depth, + JxlPixelFormat format) { + return JxlErrorOrStatus::Success(); +} + +static inline bool EncodeVarInt(uint64_t value, size_t output_size, + size_t* output_pos, uint8_t* output) { + // While more than 7 bits of data are left, + // store 7 bits and set the next byte flag + while (value > 127) { + // TODO(eustas): should it be `>=` ? + if (*output_pos > output_size) return false; + // |128: Set the next byte flag + output[(*output_pos)++] = ((uint8_t)(value & 127)) | 128; + // Remove the seven bits we just wrote + value >>= 7; + } + // TODO(eustas): should it be `>=` ? + if (*output_pos > output_size) return false; + output[(*output_pos)++] = ((uint8_t)value) & 127; + return true; +} + +bool EncodeFrameIndexBox(const jxl::JxlEncoderFrameIndexBox& frame_index_box, + std::vector<uint8_t>& buffer_vec) { + bool ok = true; + int NF = 0; + for (size_t i = 0; i < frame_index_box.entries.size(); ++i) { + if (i == 0 || frame_index_box.entries[i].to_be_indexed) { + ++NF; + } + } + // Frame index box contents varint + 8 bytes + // continue with NF * 3 * varint + // varint max length is 10 for 64 bit numbers, and these numbers + // are limited to 63 bits. + static const int kVarintMaxLength = 10; + static const int kFrameIndexBoxHeaderLength = kVarintMaxLength + 8; + static const int kFrameIndexBoxElementLength = 3 * kVarintMaxLength; + const int buffer_size = + kFrameIndexBoxHeaderLength + NF * kFrameIndexBoxElementLength; + buffer_vec.resize(buffer_size); + uint8_t* buffer = buffer_vec.data(); + size_t output_pos = 0; + ok &= EncodeVarInt(NF, buffer_vec.size(), &output_pos, buffer); + StoreBE32(frame_index_box.TNUM, &buffer[output_pos]); + output_pos += 4; + StoreBE32(frame_index_box.TDEN, &buffer[output_pos]); + output_pos += 4; + // When we record a frame in the index, the record needs to know + // how many frames until the next indexed frame. That is why + // we store the 'prev' record. That 'prev' record needs to store + // the offset byte position to previously recorded indexed frame, + // that's why we also trace previous to the previous frame. + int prev_prev_ix = -1; // For position offset (OFFi) delta coding. + int prev_ix = 0; + int T_prev = 0; + int T = 0; + for (size_t i = 1; i < frame_index_box.entries.size(); ++i) { + if (frame_index_box.entries[i].to_be_indexed) { + // Now we can record the previous entry, since we need to store + // there how many frames until the next one. + int64_t OFFi = frame_index_box.entries[prev_ix].OFFi; + if (prev_prev_ix != -1) { + // Offi needs to be offset of start byte of this frame compared to start + // byte of previous frame from this index in the JPEG XL codestream. For + // the first frame, this is the offset from the first byte of the JPEG + // XL codestream. + OFFi -= frame_index_box.entries[prev_prev_ix].OFFi; + } + int32_t Ti = T_prev; + int32_t Fi = i - prev_ix; + ok &= EncodeVarInt(OFFi, buffer_vec.size(), &output_pos, buffer); + ok &= EncodeVarInt(Ti, buffer_vec.size(), &output_pos, buffer); + ok &= EncodeVarInt(Fi, buffer_vec.size(), &output_pos, buffer); + prev_prev_ix = prev_ix; + prev_ix = i; + T_prev = T; + T += frame_index_box.entries[i].duration; + } + } + { + // Last frame. + size_t i = frame_index_box.entries.size(); + int64_t OFFi = frame_index_box.entries[prev_ix].OFFi; + if (prev_prev_ix != -1) { + OFFi -= frame_index_box.entries[prev_prev_ix].OFFi; + } + int32_t Ti = T_prev; + int32_t Fi = i - prev_ix; + ok &= EncodeVarInt(OFFi, buffer_vec.size(), &output_pos, buffer); + ok &= EncodeVarInt(Ti, buffer_vec.size(), &output_pos, buffer); + ok &= EncodeVarInt(Fi, buffer_vec.size(), &output_pos, buffer); + } + // Enough buffer has been allocated, this function should never fail in + // writing. + JXL_ASSERT(ok); + buffer_vec.resize(output_pos); + return ok; +} + +} // namespace + +jxl::Status JxlEncoderStruct::ProcessOneEnqueuedInput() { + jxl::PaddedBytes header_bytes; + + jxl::JxlEncoderQueuedInput& input = input_queue[0]; + + // TODO(lode): split this into 3 functions: for adding the signature and other + // initial headers (jbrd, ...), one for adding frame, and one for adding user + // box. + + if (!wrote_bytes) { + // First time encoding any data, verify the level 5 vs level 10 settings + std::string level_message; + int required_level = VerifyLevelSettings(this, &level_message); + // Only level 5 and 10 are defined, and the function can return -1 to + // indicate full incompatibility. + JXL_ASSERT(required_level == -1 || required_level == 5 || + required_level == 10); + // codestream_level == -1 means auto-set to the required level + if (codestream_level == -1) codestream_level = required_level; + if (codestream_level == 5 && required_level != 5) { + // If the required level is 10, return error rather than automatically + // setting the level to 10, to avoid inadvertently creating a level 10 + // JXL file while intending to target a level 5 decoder. + return JXL_API_ERROR( + this, JXL_ENC_ERR_API_USAGE, "%s", + ("Codestream level verification for level 5 failed: " + level_message) + .c_str()); + } + if (required_level == -1) { + return JXL_API_ERROR( + this, JXL_ENC_ERR_API_USAGE, "%s", + ("Codestream level verification for level 10 failed: " + + level_message) + .c_str()); + } + jxl::AuxOut* aux_out = + input.frame ? input.frame->option_values.aux_out : nullptr; + jxl::BitWriter writer; + if (!WriteCodestreamHeaders(&metadata, &writer, aux_out)) { + return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, + "Failed to write codestream header"); + } + // Only send ICC (at least several hundred bytes) if fields aren't enough. + if (metadata.m.color_encoding.WantICC()) { + if (!jxl::WriteICC(metadata.m.color_encoding.ICC(), &writer, + jxl::kLayerHeader, aux_out)) { + return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, + "Failed to write ICC profile"); + } + } + // TODO(lode): preview should be added here if a preview image is added + + jxl::BitWriter::Allotment allotment(&writer, 8); + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, jxl::kLayerHeader, aux_out); + + header_bytes = std::move(writer).TakeBytes(); + + // Not actually the end of frame, but the end of metadata/ICC, but helps + // the next frame to start here for indexing purposes. + codestream_bytes_written_end_of_frame += header_bytes.size(); + + if (MustUseContainer()) { + // Add "JXL " and ftyp box. + { + JXL_ASSIGN_OR_RETURN(auto buffer, output_processor.GetBuffer( + jxl::kContainerHeader.size())); + buffer.append(jxl::kContainerHeader); + } + if (codestream_level != 5) { + // Add jxll box directly after the ftyp box to indicate the codestream + // level. + JXL_ASSIGN_OR_RETURN(auto buffer, output_processor.GetBuffer( + jxl::kLevelBoxHeader.size() + 1)); + buffer.append(jxl::kLevelBoxHeader); + uint8_t cl = codestream_level; + buffer.append(&cl, 1); + } + + // Whether to write the basic info and color profile header of the + // codestream into an early separate jxlp box, so that it comes before + // metadata or jpeg reconstruction boxes. In theory this could simply + // always be done, but there's no reason to add an extra box with box + // header overhead if the codestream will already come immediately after + // the signature and level boxes. + bool partial_header = + store_jpeg_metadata || + (use_boxes && (!input.frame && !input.fast_lossless_frame)); + + if (partial_header) { + JXL_RETURN_IF_ERROR(AppendBox( + jxl::MakeBoxType("jxlp"), /*unbounded=*/false, + header_bytes.size() + 4, [&]() { + JXL_ASSIGN_OR_RETURN(auto buffer, output_processor.GetBuffer( + header_bytes.size() + 4)); + WriteJxlpBoxCounter(jxlp_counter++, /*last=*/false, buffer); + buffer.append(header_bytes); + return jxl::OkStatus(); + })); + header_bytes.clear(); + } + + if (store_jpeg_metadata && !jpeg_metadata.empty()) { + JXL_RETURN_IF_ERROR( + AppendBoxWithContents(jxl::MakeBoxType("jbrd"), jpeg_metadata)); + } + } + wrote_bytes = true; + } + + output_processor.SetFinalizedPosition(); + + // Choose frame or box processing: exactly one of the two unique pointers (box + // or frame) in the input queue item is non-null. + if (input.frame || input.fast_lossless_frame) { + jxl::MemoryManagerUniquePtr<jxl::JxlEncoderQueuedFrame> input_frame = + std::move(input.frame); + jxl::FJXLFrameUniquePtr fast_lossless_frame = + std::move(input.fast_lossless_frame); + input_queue.erase(input_queue.begin()); + num_queued_frames--; + if (input_frame) { + for (unsigned idx = 0; idx < input_frame->ec_initialized.size(); idx++) { + if (!input_frame->ec_initialized[idx]) { + return JXL_API_ERROR(this, JXL_ENC_ERR_API_USAGE, + "Extra channel %u is not initialized", idx); + } + } + + // TODO(zond): If the input queue is empty and the frames_closed is true, + // then mark this frame as the last. + + // TODO(zond): Handle progressive mode like EncodeFile does it. + // TODO(zond): Handle animation like EncodeFile does it, by checking if + // JxlEncoderCloseFrames has been called and if the frame + // queue is empty (to see if it's the last animation frame). + + if (metadata.m.xyb_encoded) { + input_frame->option_values.cparams.color_transform = + jxl::ColorTransform::kXYB; + } else { + // TODO(zond): Figure out when to use kYCbCr instead. + input_frame->option_values.cparams.color_transform = + jxl::ColorTransform::kNone; + } + } + + uint32_t duration; + uint32_t timecode; + if (input_frame && metadata.m.have_animation) { + duration = input_frame->option_values.header.duration; + timecode = input_frame->option_values.header.timecode; + } else { + // If have_animation is false, the encoder should ignore the duration and + // timecode values. However, assigning them to ib will cause the encoder + // to write an invalid frame header that can't be decoded so ensure + // they're the default value of 0 here. + duration = 0; + timecode = 0; + } + + const bool last_frame = frames_closed && !num_queued_frames; + + uint32_t max_bits_per_sample = metadata.m.bit_depth.bits_per_sample; + for (const auto& info : metadata.m.extra_channel_info) { + max_bits_per_sample = + std::max(max_bits_per_sample, info.bit_depth.bits_per_sample); + } + // Heuristic upper bound on how many bits a single pixel in a single channel + // can use. + uint32_t bits_per_channels_estimate = + std::max(24u, max_bits_per_sample + 3); + size_t upper_bound_on_compressed_size_bits = + metadata.xsize() * metadata.ysize() * + (metadata.m.color_encoding.Channels() + metadata.m.num_extra_channels) * + bits_per_channels_estimate; + // Add a 1MB = 0x100000 for an heuristic upper bound on small sizes. + size_t upper_bound_on_compressed_size_bytes = + 0x100000 + (upper_bound_on_compressed_size_bits >> 3); + bool use_large_box = upper_bound_on_compressed_size_bytes >= + jxl::kLargeBoxContentSizeThreshold; + size_t box_header_size = + use_large_box ? jxl::kLargeBoxHeaderSize : jxl::kSmallBoxHeaderSize; + + const size_t frame_start_pos = output_processor.CurrentPosition(); + if (MustUseContainer()) { + if (!last_frame || jxlp_counter > 0) { + // If this is the last frame and no jxlp boxes were used yet, it's + // slightly more efficient to write a jxlc box since it has 4 bytes + // less overhead. + box_header_size += 4; // jxlp_counter field + } + output_processor.Seek(frame_start_pos + box_header_size); + } + const size_t frame_codestream_start = output_processor.CurrentPosition(); + + JXL_RETURN_IF_ERROR(AppendData(output_processor, header_bytes)); + + if (input_frame) { + frame_index_box.AddFrame(codestream_bytes_written_end_of_frame, duration, + input_frame->option_values.frame_index_box); + + size_t save_as_reference = + input_frame->option_values.header.layer_info.save_as_reference; + if (save_as_reference >= 3) { + return JXL_API_ERROR( + this, JXL_ENC_ERR_API_USAGE, + "Cannot use save_as_reference values >=3 (found: %d)", + (int)save_as_reference); + } + + jxl::FrameInfo frame_info; + frame_info.is_last = last_frame; + frame_info.save_as_reference = save_as_reference; + frame_info.source = + input_frame->option_values.header.layer_info.blend_info.source; + frame_info.clamp = + input_frame->option_values.header.layer_info.blend_info.clamp; + frame_info.alpha_channel = + input_frame->option_values.header.layer_info.blend_info.alpha; + frame_info.extra_channel_blending_info.resize( + metadata.m.num_extra_channels); + // If extra channel blend info has not been set, use the blend mode from + // the layer_info. + JxlBlendInfo default_blend_info = + input_frame->option_values.header.layer_info.blend_info; + for (size_t i = 0; i < metadata.m.num_extra_channels; ++i) { + auto& to = frame_info.extra_channel_blending_info[i]; + const auto& from = + i < input_frame->option_values.extra_channel_blend_info.size() + ? input_frame->option_values.extra_channel_blend_info[i] + : default_blend_info; + to.mode = static_cast<jxl::BlendMode>(from.blendmode); + to.source = from.source; + to.alpha_channel = from.alpha; + to.clamp = (from.clamp != 0); + } + frame_info.origin.x0 = + input_frame->option_values.header.layer_info.crop_x0; + frame_info.origin.y0 = + input_frame->option_values.header.layer_info.crop_y0; + frame_info.blendmode = static_cast<jxl::BlendMode>( + input_frame->option_values.header.layer_info.blend_info.blendmode); + frame_info.blend = + input_frame->option_values.header.layer_info.blend_info.blendmode != + JXL_BLEND_REPLACE; + frame_info.image_bit_depth = input_frame->option_values.image_bit_depth; + frame_info.duration = duration; + frame_info.timecode = timecode; + frame_info.name = input_frame->option_values.frame_name; + + if (!jxl::EncodeFrame(input_frame->option_values.cparams, frame_info, + &metadata, input_frame->frame_data, cms, + thread_pool.get(), &output_processor, + input_frame->option_values.aux_out)) { + return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, + "Failed to encode frame"); + } + } else { + JXL_CHECK(fast_lossless_frame); + auto runner = +[](void* void_pool, void* opaque, void fun(void*, size_t), + size_t count) { + auto* pool = reinterpret_cast<jxl::ThreadPool*>(void_pool); + JXL_CHECK(jxl::RunOnPool( + pool, 0, count, jxl::ThreadPool::NoInit, + [&](size_t i, size_t) { fun(opaque, i); }, "Encode fast lossless")); + }; + JxlFastLosslessProcessFrame(fast_lossless_frame.get(), last_frame, + thread_pool.get(), runner, &output_processor); + } + + const size_t frame_codestream_end = output_processor.CurrentPosition(); + const size_t frame_codestream_size = + frame_codestream_end - frame_codestream_start; + + codestream_bytes_written_end_of_frame += + frame_codestream_size - header_bytes.size(); + + if (MustUseContainer()) { + output_processor.Seek(frame_start_pos); + std::vector<uint8_t> box_header(box_header_size); + if (!use_large_box && + frame_codestream_size >= jxl::kLargeBoxContentSizeThreshold) { + // Assuming our upper bound estimate is correct, this should never + // happen. + return JXL_API_ERROR( + this, JXL_ENC_ERR_GENERIC, + "Box size was estimated to be small, but turned out to be large. " + "Please file this error in size estimation as a bug."); + } + if (last_frame && jxlp_counter == 0) { +#if JXL_ENABLE_ASSERT + const size_t n = +#endif + jxl::WriteBoxHeader(jxl::MakeBoxType("jxlc"), frame_codestream_size, + /*unbounded=*/false, use_large_box, + &box_header[0]); + JXL_ASSERT(n == box_header_size); + } else { +#if JXL_ENABLE_ASSERT + const size_t n = +#endif + jxl::WriteBoxHeader( + jxl::MakeBoxType("jxlp"), frame_codestream_size + 4, + /*unbounded=*/false, use_large_box, &box_header[0]); + JXL_ASSERT(n == box_header_size - 4); + WriteJxlpBoxCounter(jxlp_counter++, last_frame, + &box_header[box_header_size - 4]); + } + JXL_RETURN_IF_ERROR(AppendData(output_processor, box_header)); + JXL_ASSERT(output_processor.CurrentPosition() == frame_codestream_start); + output_processor.Seek(frame_codestream_end); + } + output_processor.SetFinalizedPosition(); + if (input_frame) { + last_used_cparams = input_frame->option_values.cparams; + } + if (last_frame && frame_index_box.StoreFrameIndexBox()) { + std::vector<uint8_t> index_box_content; + EncodeFrameIndexBox(frame_index_box, index_box_content); + JXL_RETURN_IF_ERROR(AppendBoxWithContents(jxl::MakeBoxType("jxli"), + jxl::Bytes(index_box_content))); + } + } else { + // Not a frame, so is a box instead + jxl::MemoryManagerUniquePtr<jxl::JxlEncoderQueuedBox> box = + std::move(input.box); + input_queue.erase(input_queue.begin()); + num_queued_boxes--; + + if (box->compress_box) { + jxl::PaddedBytes compressed(4); + // Prepend the original box type in the brob box contents + for (size_t i = 0; i < 4; i++) { + compressed[i] = static_cast<uint8_t>(box->type[i]); + } + if (JXL_ENC_SUCCESS != + BrotliCompress((brotli_effort >= 0 ? brotli_effort : 4), + box->contents.data(), box->contents.size(), + &compressed)) { + return JXL_API_ERROR(this, JXL_ENC_ERR_GENERIC, + "Brotli compression for brob box failed"); + } + + JXL_RETURN_IF_ERROR( + AppendBoxWithContents(jxl::MakeBoxType("brob"), compressed)); + } else { + JXL_RETURN_IF_ERROR(AppendBoxWithContents(box->type, box->contents)); + } + } + + return jxl::OkStatus(); +} + +JxlEncoderStatus JxlEncoderSetColorEncoding(JxlEncoder* enc, + const JxlColorEncoding* color) { + if (!enc->basic_info_set) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Basic info not yet set"); + } + if (enc->color_encoding_set) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Color encoding is already set"); + } + if (!enc->metadata.m.color_encoding.FromExternal(*color)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_GENERIC, "Error in color conversion"); + } + if (enc->metadata.m.color_encoding.GetColorSpace() == + jxl::ColorSpace::kGray) { + if (enc->basic_info.num_color_channels != 1) + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "Cannot use grayscale color encoding with num_color_channels != 1"); + } else { + if (enc->basic_info.num_color_channels != 3) + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "Cannot use RGB color encoding with num_color_channels != 3"); + } + enc->color_encoding_set = true; + if (!enc->intensity_target_set) { + jxl::SetIntensityTarget(&enc->metadata.m); + } + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderSetICCProfile(JxlEncoder* enc, + const uint8_t* icc_profile, + size_t size) { + if (!enc->basic_info_set) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Basic info not yet set"); + } + if (enc->color_encoding_set) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "ICC profile is already set"); + } + if (size == 0) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_BAD_INPUT, "Empty ICC profile"); + } + jxl::IccBytes icc; + icc.assign(icc_profile, icc_profile + size); + if (enc->cms_set) { + if (!enc->metadata.m.color_encoding.SetICC(std::move(icc), &enc->cms)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_BAD_INPUT, + "ICC profile could not be set"); + } + } else { + enc->metadata.m.color_encoding.SetICCRaw(std::move(icc)); + } + if (enc->metadata.m.color_encoding.GetColorSpace() == + jxl::ColorSpace::kGray) { + if (enc->basic_info.num_color_channels != 1) + return JXL_API_ERROR( + enc, JXL_ENC_ERR_BAD_INPUT, + "Cannot use grayscale ICC profile with num_color_channels != 1"); + } else { + if (enc->basic_info.num_color_channels != 3) + return JXL_API_ERROR( + enc, JXL_ENC_ERR_BAD_INPUT, + "Cannot use RGB ICC profile with num_color_channels != 3"); + // TODO(jon): also check that a kBlack extra channel is provided in the CMYK + // case + } + enc->color_encoding_set = true; + if (!enc->intensity_target_set) { + jxl::SetIntensityTarget(&enc->metadata.m); + } + + if (!enc->basic_info.uses_original_profile && enc->cms_set) { + enc->metadata.m.color_encoding.DecideIfWantICC(enc->cms); + } + + return JxlErrorOrStatus::Success(); +} + +void JxlEncoderInitBasicInfo(JxlBasicInfo* info) { + info->have_container = JXL_FALSE; + info->xsize = 0; + info->ysize = 0; + info->bits_per_sample = 8; + info->exponent_bits_per_sample = 0; + info->intensity_target = 0.f; + info->min_nits = 0.f; + info->relative_to_max_display = JXL_FALSE; + info->linear_below = 0.f; + info->uses_original_profile = JXL_FALSE; + info->have_preview = JXL_FALSE; + info->have_animation = JXL_FALSE; + info->orientation = JXL_ORIENT_IDENTITY; + info->num_color_channels = 3; + info->num_extra_channels = 0; + info->alpha_bits = 0; + info->alpha_exponent_bits = 0; + info->alpha_premultiplied = JXL_FALSE; + info->preview.xsize = 0; + info->preview.ysize = 0; + info->intrinsic_xsize = 0; + info->intrinsic_ysize = 0; + info->animation.tps_numerator = 10; + info->animation.tps_denominator = 1; + info->animation.num_loops = 0; + info->animation.have_timecodes = JXL_FALSE; +} + +void JxlEncoderInitFrameHeader(JxlFrameHeader* frame_header) { + // For each field, the default value of the specification is used. Depending + // on whether an animation frame, or a composite still blending frame, + // is used, different fields have to be set up by the user after initing + // the frame header. + frame_header->duration = 0; + frame_header->timecode = 0; + frame_header->name_length = 0; + // In the specification, the default value of is_last is !frame_type, and the + // default frame_type is kRegularFrame which has value 0, so is_last is true + // by default. However, the encoder does not use this value (the field exists + // for the decoder to set) since last frame is determined by usage of + // JxlEncoderCloseFrames instead. + frame_header->is_last = JXL_TRUE; + frame_header->layer_info.have_crop = JXL_FALSE; + frame_header->layer_info.crop_x0 = 0; + frame_header->layer_info.crop_y0 = 0; + // These must be set if have_crop is enabled, but the default value has + // have_crop false, and these dimensions 0. The user must set these to the + // desired size after enabling have_crop (which is not yet implemented). + frame_header->layer_info.xsize = 0; + frame_header->layer_info.ysize = 0; + JxlEncoderInitBlendInfo(&frame_header->layer_info.blend_info); + frame_header->layer_info.save_as_reference = 0; +} + +void JxlEncoderInitBlendInfo(JxlBlendInfo* blend_info) { + // Default blend mode in the specification is 0. Note that combining + // blend mode of replace with a duration is not useful, but the user has to + // manually set duration in case of animation, or manually change the blend + // mode in case of composite stills, so initing to a combination that is not + // useful on its own is not an issue. + blend_info->blendmode = JXL_BLEND_REPLACE; + blend_info->source = 0; + blend_info->alpha = 0; + blend_info->clamp = 0; +} + +JxlEncoderStatus JxlEncoderSetBasicInfo(JxlEncoder* enc, + const JxlBasicInfo* info) { + if (!enc->metadata.size.Set(info->xsize, info->ysize)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Invalid dimensions"); + } + if (JXL_ENC_SUCCESS != CheckValidBitdepth(info->bits_per_sample, + info->exponent_bits_per_sample)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Invalid bit depth"); + } + + enc->metadata.m.bit_depth.bits_per_sample = info->bits_per_sample; + enc->metadata.m.bit_depth.exponent_bits_per_sample = + info->exponent_bits_per_sample; + enc->metadata.m.bit_depth.floating_point_sample = + (info->exponent_bits_per_sample != 0u); + enc->metadata.m.modular_16_bit_buffer_sufficient = + (!info->uses_original_profile || info->bits_per_sample <= 12) && + info->alpha_bits <= 12; + if ((info->intrinsic_xsize > 0 || info->intrinsic_ysize > 0) && + (info->intrinsic_xsize != info->xsize || + info->intrinsic_ysize != info->ysize)) { + if (info->intrinsic_xsize > (1ull << 30ull) || + info->intrinsic_ysize > (1ull << 30ull) || + !enc->metadata.m.intrinsic_size.Set(info->intrinsic_xsize, + info->intrinsic_ysize)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid intrinsic dimensions"); + } + enc->metadata.m.have_intrinsic_size = true; + } + + // The number of extra channels includes the alpha channel, so for example and + // RGBA with no other extra channels, has exactly num_extra_channels == 1 + enc->metadata.m.num_extra_channels = info->num_extra_channels; + enc->metadata.m.extra_channel_info.resize(enc->metadata.m.num_extra_channels); + if (info->num_extra_channels == 0 && info->alpha_bits) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "when alpha_bits is non-zero, the number of channels must be at least " + "1"); + } + // If the user provides non-zero alpha_bits, we make the channel info at index + // zero the appropriate alpha channel. + if (info->alpha_bits) { + JxlExtraChannelInfo channel_info; + JxlEncoderInitExtraChannelInfo(JXL_CHANNEL_ALPHA, &channel_info); + channel_info.bits_per_sample = info->alpha_bits; + channel_info.exponent_bits_per_sample = info->alpha_exponent_bits; + if (JxlEncoderSetExtraChannelInfo(enc, 0, &channel_info)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Problem setting extra channel info for alpha"); + } + } + + enc->metadata.m.xyb_encoded = !info->uses_original_profile; + if (info->orientation > 0 && info->orientation <= 8) { + enc->metadata.m.orientation = info->orientation; + } else { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for orientation field"); + } + if (info->num_color_channels != 1 && info->num_color_channels != 3) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid number of color channels"); + } + if (info->intensity_target != 0) { + enc->metadata.m.SetIntensityTarget(info->intensity_target); + enc->intensity_target_set = true; + } else if (enc->color_encoding_set) { + // If this is false, JxlEncoderSetColorEncoding will be called later and we + // will get one more chance to call jxl::SetIntensityTarget, after the color + // encoding is indeed set. + jxl::SetIntensityTarget(&enc->metadata.m); + enc->intensity_target_set = true; + } + enc->metadata.m.tone_mapping.min_nits = info->min_nits; + enc->metadata.m.tone_mapping.relative_to_max_display = + info->relative_to_max_display; + enc->metadata.m.tone_mapping.linear_below = info->linear_below; + enc->basic_info = *info; + enc->basic_info_set = true; + + enc->metadata.m.have_animation = info->have_animation; + if (info->have_animation) { + if (info->animation.tps_denominator < 1) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "If animation is used, tps_denominator must be >= 1"); + } + if (info->animation.tps_numerator < 1) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "If animation is used, tps_numerator must be >= 1"); + } + enc->metadata.m.animation.tps_numerator = info->animation.tps_numerator; + enc->metadata.m.animation.tps_denominator = info->animation.tps_denominator; + enc->metadata.m.animation.num_loops = info->animation.num_loops; + enc->metadata.m.animation.have_timecodes = info->animation.have_timecodes; + } + std::string level_message; + int required_level = VerifyLevelSettings(enc, &level_message); + if (required_level == -1 || + (static_cast<int>(enc->codestream_level) < required_level && + enc->codestream_level != -1)) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, "%s", + ("Codestream level verification for level " + + std::to_string(enc->codestream_level) + " failed: " + level_message) + .c_str()); + } + return JxlErrorOrStatus::Success(); +} + +void JxlEncoderInitExtraChannelInfo(JxlExtraChannelType type, + JxlExtraChannelInfo* info) { + info->type = type; + info->bits_per_sample = 8; + info->exponent_bits_per_sample = 0; + info->dim_shift = 0; + info->name_length = 0; + info->alpha_premultiplied = JXL_FALSE; + info->spot_color[0] = 0; + info->spot_color[1] = 0; + info->spot_color[2] = 0; + info->spot_color[3] = 0; + info->cfa_channel = 0; +} + +JXL_EXPORT JxlEncoderStatus JxlEncoderSetUpsamplingMode(JxlEncoder* enc, + const int64_t factor, + const int64_t mode) { + // for convenience, allow calling this with factor 1 and just make it a no-op + if (factor == 1) return JxlErrorOrStatus::Success(); + if (factor != 2 && factor != 4 && factor != 8) + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid upsampling factor"); + if (mode < -1) + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Invalid upsampling mode"); + if (mode > 1) + return JXL_API_ERROR(enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Unsupported upsampling mode"); + + const size_t count = (factor == 2 ? 15 : (factor == 4 ? 55 : 210)); + auto& td = enc->metadata.transform_data; + float* weights = (factor == 2 ? td.upsampling2_weights + : (factor == 4 ? td.upsampling4_weights + : td.upsampling8_weights)); + if (mode == -1) { + // Default fancy upsampling: don't signal custom weights + enc->metadata.transform_data.custom_weights_mask &= ~(factor >> 1); + } else if (mode == 0) { + // Nearest neighbor upsampling + enc->metadata.transform_data.custom_weights_mask |= (factor >> 1); + memset(weights, 0, sizeof(float) * count); + if (factor == 2) { + weights[9] = 1.f; + } else if (factor == 4) { + for (int i : {19, 24, 49}) weights[i] = 1.f; + } else if (factor == 8) { + for (int i : {39, 44, 49, 54, 119, 124, 129, 174, 179, 204}) { + weights[i] = 1.f; + } + } + } else if (mode == 1) { + // 'Pixel dots' upsampling (nearest-neighbor with cut corners) + JxlEncoderSetUpsamplingMode(enc, factor, 0); + if (factor == 4) { + weights[19] = 0.f; + weights[24] = 0.5f; + } else if (factor == 8) { + for (int i : {39, 44, 49, 119}) weights[i] = 0.f; + for (int i : {54, 124}) weights[i] = 0.5f; + } + } + return JxlErrorOrStatus::Success(); +} + +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelInfo( + JxlEncoder* enc, size_t index, const JxlExtraChannelInfo* info) { + if (index >= enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for the index of extra channel"); + } + if (JXL_ENC_SUCCESS != CheckValidBitdepth(info->bits_per_sample, + info->exponent_bits_per_sample)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, "Invalid bit depth"); + } + + jxl::ExtraChannelInfo& channel = enc->metadata.m.extra_channel_info[index]; + channel.type = static_cast<jxl::ExtraChannel>(info->type); + channel.bit_depth.bits_per_sample = info->bits_per_sample; + enc->metadata.m.modular_16_bit_buffer_sufficient &= + info->bits_per_sample <= 12; + channel.bit_depth.exponent_bits_per_sample = info->exponent_bits_per_sample; + channel.bit_depth.floating_point_sample = info->exponent_bits_per_sample != 0; + channel.dim_shift = info->dim_shift; + channel.name = ""; + channel.alpha_associated = (info->alpha_premultiplied != 0); + channel.cfa_channel = info->cfa_channel; + channel.spot_color[0] = info->spot_color[0]; + channel.spot_color[1] = info->spot_color[1]; + channel.spot_color[2] = info->spot_color[2]; + channel.spot_color[3] = info->spot_color[3]; + std::string level_message; + int required_level = VerifyLevelSettings(enc, &level_message); + if (required_level == -1 || + (static_cast<int>(enc->codestream_level) < required_level && + enc->codestream_level != -1)) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, "%s", + ("Codestream level verification for level " + + std::to_string(enc->codestream_level) + " failed: " + level_message) + .c_str()); + } + return JxlErrorOrStatus::Success(); +} + +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelName(JxlEncoder* enc, + size_t index, + const char* name, + size_t size) { + if (index >= enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for the index of extra channel"); + } + enc->metadata.m.extra_channel_info[index].name = + std::string(name, name + size); + return JxlErrorOrStatus::Success(); +} + +JxlEncoderFrameSettings* JxlEncoderFrameSettingsCreate( + JxlEncoder* enc, const JxlEncoderFrameSettings* source) { + auto opts = jxl::MemoryManagerMakeUnique<JxlEncoderFrameSettings>( + &enc->memory_manager); + if (!opts) return nullptr; + opts->enc = enc; + if (source != nullptr) { + opts->values = source->values; + } else { + opts->values.lossless = false; + } + opts->values.cparams.level = enc->codestream_level; + opts->values.cparams.ec_distance.resize(enc->metadata.m.num_extra_channels, + -1); + + JxlEncoderFrameSettings* ret = opts.get(); + enc->encoder_options.emplace_back(std::move(opts)); + return ret; +} + +JxlEncoderStatus JxlEncoderSetFrameLossless( + JxlEncoderFrameSettings* frame_settings, const JXL_BOOL lossless) { + if (lossless && frame_settings->enc->basic_info_set && + frame_settings->enc->metadata.m.xyb_encoded) { + return JXL_API_ERROR( + frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Set uses_original_profile=true for lossless encoding"); + } + frame_settings->values.lossless = lossless; + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderSetFrameDistance( + JxlEncoderFrameSettings* frame_settings, float distance) { + if (distance < 0.f || distance > 25.f) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Distance has to be in [0.0..25.0] (corresponding to " + "quality in [0.0..100.0])"); + } + if (distance > 0.f && distance < 0.01f) { + distance = 0.01f; + } + frame_settings->values.cparams.butteraugli_distance = distance; + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderSetExtraChannelDistance( + JxlEncoderFrameSettings* frame_settings, size_t index, float distance) { + if (index >= frame_settings->enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for the index of extra channel"); + } + if (distance != -1.f && (distance < 0.f || distance > 25.f)) { + return JXL_API_ERROR( + frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Distance has to be -1 or in [0.0..25.0] (corresponding to " + "quality in [0.0..100.0])"); + } + if (distance > 0.f && distance < 0.01f) { + distance = 0.01f; + } + + if (index >= frame_settings->values.cparams.ec_distance.size()) { + // This can only happen if JxlEncoderFrameSettingsCreate() was called before + // JxlEncoderSetBasicInfo(). + frame_settings->values.cparams.ec_distance.resize( + frame_settings->enc->metadata.m.num_extra_channels, -1); + } + + frame_settings->values.cparams.ec_distance[index] = distance; + return JxlErrorOrStatus::Success(); +} + +float JxlEncoderDistanceFromQuality(float quality) { + return quality >= 100.0 ? 0.0 + : quality >= 30 + ? 0.1 + (100 - quality) * 0.09 + : 53.0 / 3000.0 * quality * quality - 23.0 / 20.0 * quality + 25.0; +} + +JxlEncoderStatus JxlEncoderFrameSettingsSetOption( + JxlEncoderFrameSettings* frame_settings, JxlEncoderFrameSettingId option, + int64_t value) { + // check if value is -1, 0 or 1 for Override-type options + switch (option) { + case JXL_ENC_FRAME_SETTING_NOISE: + case JXL_ENC_FRAME_SETTING_DOTS: + case JXL_ENC_FRAME_SETTING_PATCHES: + case JXL_ENC_FRAME_SETTING_GABORISH: + case JXL_ENC_FRAME_SETTING_MODULAR: + case JXL_ENC_FRAME_SETTING_KEEP_INVISIBLE: + case JXL_ENC_FRAME_SETTING_GROUP_ORDER: + case JXL_ENC_FRAME_SETTING_RESPONSIVE: + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC: + case JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC: + case JXL_ENC_FRAME_SETTING_LOSSY_PALETTE: + case JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL: + case JXL_ENC_FRAME_SETTING_JPEG_COMPRESS_BOXES: + case JXL_ENC_FRAME_SETTING_JPEG_KEEP_EXIF: + case JXL_ENC_FRAME_SETTING_JPEG_KEEP_XMP: + case JXL_ENC_FRAME_SETTING_JPEG_KEEP_JUMBF: + if (value < -1 || value > 1) { + return JXL_API_ERROR( + frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be -1 (default), 0 (off) or 1 (on)"); + } + break; + default: + break; + } + + switch (option) { + case JXL_ENC_FRAME_SETTING_EFFORT: + if (frame_settings->enc->allow_expert_options) { + if (value < 1 || value > 10) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Encode effort has to be in [1..10]"); + } + } else { + if (value < 1 || value > 9) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Encode effort has to be in [1..9]"); + } + } + frame_settings->values.cparams.speed_tier = + static_cast<jxl::SpeedTier>(10 - value); + break; + case JXL_ENC_FRAME_SETTING_BROTLI_EFFORT: + if (value < -1 || value > 11) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Brotli effort has to be in [-1..11]"); + } + // set cparams for brotli use in JPEG frames + frame_settings->values.cparams.brotli_effort = value; + // set enc option for brotli use in brob boxes + frame_settings->enc->brotli_effort = value; + break; + case JXL_ENC_FRAME_SETTING_DECODING_SPEED: + if (value < 0 || value > 4) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Decoding speed has to be in [0..4]"); + } + frame_settings->values.cparams.decoding_speed_tier = value; + break; + case JXL_ENC_FRAME_SETTING_RESAMPLING: + if (value != -1 && value != 1 && value != 2 && value != 4 && value != 8) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Resampling factor has to be 1, 2, 4 or 8"); + } + frame_settings->values.cparams.resampling = value; + break; + case JXL_ENC_FRAME_SETTING_EXTRA_CHANNEL_RESAMPLING: + // TODO(lode): the jxl codestream allows choosing a different resampling + // factor for each extra channel, independently per frame. Move this + // option to a JxlEncoderFrameSettings-option that can be set per extra + // channel, so needs its own function rather than + // JxlEncoderFrameSettingsSetOption due to the extra channel index + // argument required. + if (value != -1 && value != 1 && value != 2 && value != 4 && value != 8) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Resampling factor has to be 1, 2, 4 or 8"); + } + frame_settings->values.cparams.ec_resampling = value; + break; + case JXL_ENC_FRAME_SETTING_ALREADY_DOWNSAMPLED: + if (value < 0 || value > 1) { + return JxlErrorOrStatus::Error(); + } + frame_settings->values.cparams.already_downsampled = (value == 1); + break; + case JXL_ENC_FRAME_SETTING_NOISE: + frame_settings->values.cparams.noise = static_cast<jxl::Override>(value); + break; + case JXL_ENC_FRAME_SETTING_DOTS: + frame_settings->values.cparams.dots = static_cast<jxl::Override>(value); + break; + case JXL_ENC_FRAME_SETTING_PATCHES: + frame_settings->values.cparams.patches = + static_cast<jxl::Override>(value); + break; + case JXL_ENC_FRAME_SETTING_EPF: + if (value < -1 || value > 3) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "EPF value has to be in [-1..3]"); + } + frame_settings->values.cparams.epf = static_cast<int>(value); + break; + case JXL_ENC_FRAME_SETTING_GABORISH: + frame_settings->values.cparams.gaborish = + static_cast<jxl::Override>(value); + break; + case JXL_ENC_FRAME_SETTING_MODULAR: + frame_settings->values.cparams.modular_mode = (value == 1); + break; + case JXL_ENC_FRAME_SETTING_KEEP_INVISIBLE: + frame_settings->values.cparams.keep_invisible = + static_cast<jxl::Override>(value); + break; + case JXL_ENC_FRAME_SETTING_GROUP_ORDER: + frame_settings->values.cparams.centerfirst = (value == 1); + break; + case JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_X: + if (value < -1) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Center x coordinate has to be -1 or positive"); + } + frame_settings->values.cparams.center_x = static_cast<size_t>(value); + break; + case JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_Y: + if (value < -1) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Center y coordinate has to be -1 or positive"); + } + frame_settings->values.cparams.center_y = static_cast<size_t>(value); + break; + case JXL_ENC_FRAME_SETTING_RESPONSIVE: + frame_settings->values.cparams.responsive = value; + break; + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC: + frame_settings->values.cparams.progressive_mode = + static_cast<jxl::Override>(value); + break; + case JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC: + frame_settings->values.cparams.qprogressive_mode = + static_cast<jxl::Override>(value); + break; + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC: + if (value < -1 || value > 2) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Progressive DC has to be in [-1..2]"); + } + frame_settings->values.cparams.progressive_dc = value; + break; + case JXL_ENC_FRAME_SETTING_PALETTE_COLORS: + if (value < -1 || value > 70913) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..70913]"); + } + if (value == -1) { + frame_settings->values.cparams.palette_colors = 1 << 10; + } else { + frame_settings->values.cparams.palette_colors = value; + } + break; + case JXL_ENC_FRAME_SETTING_LOSSY_PALETTE: + // TODO(lode): the defaults of some palette settings depend on others. + // See the logic in cjxl. Similar for other settings. This should be + // handled in the encoder during JxlEncoderProcessOutput (or, + // alternatively, in the cjxl binary like now) + frame_settings->values.cparams.lossy_palette = (value == 1); + break; + case JXL_ENC_FRAME_SETTING_COLOR_TRANSFORM: + if (value < -1 || value > 2) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..2]"); + } + if (value == -1) { + frame_settings->values.cparams.color_transform = + jxl::ColorTransform::kXYB; + } else { + frame_settings->values.cparams.color_transform = + static_cast<jxl::ColorTransform>(value); + } + break; + case JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE: + if (value < -1 || value > 41) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..41]"); + } + frame_settings->values.cparams.colorspace = value; + break; + case JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE: + if (value < -1 || value > 3) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..3]"); + } + frame_settings->values.cparams.modular_group_size_shift = value; + break; + case JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR: + if (value < -1 || value > 15) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..15]"); + } + frame_settings->values.cparams.options.predictor = + static_cast<jxl::Predictor>(value); + break; + case JXL_ENC_FRAME_SETTING_MODULAR_NB_PREV_CHANNELS: + // The max allowed value can in theory be higher. However, it depends on + // the effort setting. 11 is the highest safe value that doesn't cause + // tree_samples to be >= 64 in the encoder. The specification may allow + // more than this. With more fine tuning higher values could be allowed. + // For N-channel images, the largest useful value is N-1. + if (value < -1 || value > 11) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..11]"); + } + if (value == -1) { + frame_settings->values.cparams.options.max_properties = 0; + } else { + frame_settings->values.cparams.options.max_properties = value; + } + break; + case JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL: + if (value == -1) { + frame_settings->values.cparams.force_cfl_jpeg_recompression = true; + } else { + frame_settings->values.cparams.force_cfl_jpeg_recompression = value; + } + break; + case JXL_ENC_FRAME_INDEX_BOX: + if (value < 0 || value > 1) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Option value has to be 0 or 1"); + } + frame_settings->values.frame_index_box = true; + break; + case JXL_ENC_FRAME_SETTING_PHOTON_NOISE: + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Float option, try setting it with " + "JxlEncoderFrameSettingsSetFloatOption"); + case JXL_ENC_FRAME_SETTING_JPEG_COMPRESS_BOXES: + frame_settings->values.cparams.jpeg_compress_boxes = value; + break; + case JXL_ENC_FRAME_SETTING_BUFFERING: + if (value < 0 || value > 3) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Buffering has to be in [0..3]"); + } + frame_settings->values.cparams.buffering = value; + break; + case JXL_ENC_FRAME_SETTING_JPEG_KEEP_EXIF: + frame_settings->values.cparams.jpeg_keep_exif = value; + break; + case JXL_ENC_FRAME_SETTING_JPEG_KEEP_XMP: + frame_settings->values.cparams.jpeg_keep_xmp = value; + break; + case JXL_ENC_FRAME_SETTING_JPEG_KEEP_JUMBF: + frame_settings->values.cparams.jpeg_keep_jumbf = value; + break; + case JXL_ENC_FRAME_SETTING_USE_FULL_IMAGE_HEURISTICS: + if (value < 0 || value > 1) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Option value has to be 0 or 1"); + } + frame_settings->values.cparams.use_full_image_heuristics = value; + break; + + default: + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Unknown option"); + } + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderFrameSettingsSetFloatOption( + JxlEncoderFrameSettings* frame_settings, JxlEncoderFrameSettingId option, + float value) { + switch (option) { + case JXL_ENC_FRAME_SETTING_PHOTON_NOISE: + if (value < 0) return JXL_ENC_ERROR; + // TODO(lode): add encoder setting to set the 8 floating point values of + // the noise synthesis parameters per frame for more fine grained control. + frame_settings->values.cparams.photon_noise_iso = value; + return JxlErrorOrStatus::Success(); + case JXL_ENC_FRAME_SETTING_MODULAR_MA_TREE_LEARNING_PERCENT: + if (value < -1.f || value > 100.f) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be smaller than 100"); + } + // This value is called "iterations" or "nb_repeats" in cjxl, but is in + // fact a fraction in range 0.0-1.0, with the default value 0.5. + // Convert from floating point percentage to floating point fraction here. + if (value < -.5f) { + // TODO(lode): for this and many other settings (also in + // JxlEncoderFrameSettingsSetOption), avoid duplicating the default + // values here and in enc_params.h and options.h, have one location + // where the defaults are specified. + frame_settings->values.cparams.options.nb_repeats = 0.5f; + } else { + frame_settings->values.cparams.options.nb_repeats = value * 0.01f; + } + return JxlErrorOrStatus::Success(); + case JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GLOBAL_PERCENT: + if (value < -1.f || value > 100.f) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..100]"); + } + if (value < -.5f) { + frame_settings->values.cparams.channel_colors_pre_transform_percent = + 95.0f; + } else { + frame_settings->values.cparams.channel_colors_pre_transform_percent = + value; + } + return JxlErrorOrStatus::Success(); + case JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT: + if (value < -1.f || value > 100.f) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Option value has to be in [-1..100]"); + } + if (value < -.5f) { + frame_settings->values.cparams.channel_colors_percent = 80.0f; + } else { + frame_settings->values.cparams.channel_colors_percent = value; + } + return JxlErrorOrStatus::Success(); + case JXL_ENC_FRAME_SETTING_EFFORT: + case JXL_ENC_FRAME_SETTING_DECODING_SPEED: + case JXL_ENC_FRAME_SETTING_RESAMPLING: + case JXL_ENC_FRAME_SETTING_EXTRA_CHANNEL_RESAMPLING: + case JXL_ENC_FRAME_SETTING_ALREADY_DOWNSAMPLED: + case JXL_ENC_FRAME_SETTING_NOISE: + case JXL_ENC_FRAME_SETTING_DOTS: + case JXL_ENC_FRAME_SETTING_PATCHES: + case JXL_ENC_FRAME_SETTING_EPF: + case JXL_ENC_FRAME_SETTING_GABORISH: + case JXL_ENC_FRAME_SETTING_MODULAR: + case JXL_ENC_FRAME_SETTING_KEEP_INVISIBLE: + case JXL_ENC_FRAME_SETTING_GROUP_ORDER: + case JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_X: + case JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_Y: + case JXL_ENC_FRAME_SETTING_RESPONSIVE: + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC: + case JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC: + case JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC: + case JXL_ENC_FRAME_SETTING_PALETTE_COLORS: + case JXL_ENC_FRAME_SETTING_LOSSY_PALETTE: + case JXL_ENC_FRAME_SETTING_COLOR_TRANSFORM: + case JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE: + case JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE: + case JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR: + case JXL_ENC_FRAME_SETTING_MODULAR_NB_PREV_CHANNELS: + case JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL: + case JXL_ENC_FRAME_INDEX_BOX: + case JXL_ENC_FRAME_SETTING_BROTLI_EFFORT: + case JXL_ENC_FRAME_SETTING_FILL_ENUM: + case JXL_ENC_FRAME_SETTING_JPEG_COMPRESS_BOXES: + case JXL_ENC_FRAME_SETTING_BUFFERING: + case JXL_ENC_FRAME_SETTING_JPEG_KEEP_EXIF: + case JXL_ENC_FRAME_SETTING_JPEG_KEEP_XMP: + case JXL_ENC_FRAME_SETTING_JPEG_KEEP_JUMBF: + case JXL_ENC_FRAME_SETTING_USE_FULL_IMAGE_HEURISTICS: + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Int option, try setting it with " + "JxlEncoderFrameSettingsSetOption"); + default: + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_NOT_SUPPORTED, + "Unknown option"); + } +} +JxlEncoder* JxlEncoderCreate(const JxlMemoryManager* memory_manager) { + JxlMemoryManager local_memory_manager; + if (!jxl::MemoryManagerInit(&local_memory_manager, memory_manager)) { + return nullptr; + } + + void* alloc = + jxl::MemoryManagerAlloc(&local_memory_manager, sizeof(JxlEncoder)); + if (!alloc) return nullptr; + JxlEncoder* enc = new (alloc) JxlEncoder(); + enc->memory_manager = local_memory_manager; + // TODO(sboukortt): add an API function to set this. + enc->cms = *JxlGetDefaultCms(); + enc->cms_set = true; + + // Initialize all the field values. + JxlEncoderReset(enc); + + return enc; +} + +void JxlEncoderReset(JxlEncoder* enc) { + enc->thread_pool.reset(); + enc->input_queue.clear(); + enc->num_queued_frames = 0; + enc->num_queued_boxes = 0; + enc->encoder_options.clear(); + enc->codestream_bytes_written_end_of_frame = 0; + enc->wrote_bytes = false; + enc->jxlp_counter = 0; + enc->metadata = jxl::CodecMetadata(); + enc->last_used_cparams = jxl::CompressParams(); + enc->frames_closed = false; + enc->boxes_closed = false; + enc->basic_info_set = false; + enc->color_encoding_set = false; + enc->intensity_target_set = false; + enc->use_container = false; + enc->use_boxes = false; + enc->codestream_level = -1; + enc->output_processor = JxlEncoderOutputProcessorWrapper(); + JxlEncoderInitBasicInfo(&enc->basic_info); +} + +void JxlEncoderDestroy(JxlEncoder* enc) { + if (enc) { + JxlMemoryManager local_memory_manager = enc->memory_manager; + // Call destructor directly since custom free function is used. + enc->~JxlEncoder(); + jxl::MemoryManagerFree(&local_memory_manager, enc); + } +} + +JxlEncoderError JxlEncoderGetError(JxlEncoder* enc) { return enc->error; } + +JxlEncoderStatus JxlEncoderUseContainer(JxlEncoder* enc, + JXL_BOOL use_container) { + if (enc->wrote_bytes) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "this setting can only be set at the beginning"); + } + enc->use_container = static_cast<bool>(use_container); + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderStoreJPEGMetadata(JxlEncoder* enc, + JXL_BOOL store_jpeg_metadata) { + if (enc->wrote_bytes) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "this setting can only be set at the beginning"); + } + enc->store_jpeg_metadata = static_cast<bool>(store_jpeg_metadata); + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderSetCodestreamLevel(JxlEncoder* enc, int level) { + if (level != -1 && level != 5 && level != 10) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_NOT_SUPPORTED, "invalid level"); + } + if (enc->wrote_bytes) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "this setting can only be set at the beginning"); + } + enc->codestream_level = level; + return JxlErrorOrStatus::Success(); +} + +int JxlEncoderGetRequiredCodestreamLevel(const JxlEncoder* enc) { + return VerifyLevelSettings(enc, nullptr); +} + +void JxlEncoderSetCms(JxlEncoder* enc, JxlCmsInterface cms) { + jxl::msan::MemoryIsInitialized(&cms, sizeof(cms)); + enc->cms = cms; + enc->cms_set = true; +} + +JxlEncoderStatus JxlEncoderSetParallelRunner(JxlEncoder* enc, + JxlParallelRunner parallel_runner, + void* parallel_runner_opaque) { + if (enc->thread_pool) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "parallel runner already set"); + } + enc->thread_pool = jxl::MemoryManagerMakeUnique<jxl::ThreadPool>( + &enc->memory_manager, parallel_runner, parallel_runner_opaque); + if (!enc->thread_pool) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_GENERIC, + "error setting parallel runner"); + } + return JxlErrorOrStatus::Success(); +} + +namespace { +JxlEncoderStatus GetCurrentDimensions( + const JxlEncoderFrameSettings* frame_settings, size_t& xsize, + size_t& ysize) { + xsize = frame_settings->enc->metadata.xsize(); + ysize = frame_settings->enc->metadata.ysize(); + if (frame_settings->values.header.layer_info.have_crop) { + xsize = frame_settings->values.header.layer_info.xsize; + ysize = frame_settings->values.header.layer_info.ysize; + } + if (frame_settings->values.cparams.already_downsampled) { + size_t factor = frame_settings->values.cparams.resampling; + xsize = jxl::DivCeil(xsize, factor); + ysize = jxl::DivCeil(ysize, factor); + } + if (xsize == 0 || ysize == 0) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "zero-sized frame is not allowed"); + } + return JxlErrorOrStatus::Success(); +} +} // namespace + +JxlEncoderStatus JxlEncoderAddJPEGFrame( + const JxlEncoderFrameSettings* frame_settings, const uint8_t* buffer, + size_t size) { + if (frame_settings->enc->frames_closed) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Frame input is already closed"); + } + + jxl::CodecInOut io; + if (!jxl::jpeg::DecodeImageJPG(jxl::Bytes(buffer, size), &io)) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_BAD_INPUT, + "Error during decode of input JPEG"); + } + + if (!frame_settings->enc->color_encoding_set) { + SetColorEncodingFromJpegData( + *io.Main().jpeg_data, &frame_settings->enc->metadata.m.color_encoding); + frame_settings->enc->color_encoding_set = true; + } + + if (!frame_settings->enc->basic_info_set) { + JxlBasicInfo basic_info; + JxlEncoderInitBasicInfo(&basic_info); + basic_info.xsize = io.Main().jpeg_data->width; + basic_info.ysize = io.Main().jpeg_data->height; + basic_info.uses_original_profile = true; + if (JxlEncoderSetBasicInfo(frame_settings->enc, &basic_info) != + JXL_ENC_SUCCESS) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "Error setting basic info"); + } + } + + size_t xsize, ysize; + if (GetCurrentDimensions(frame_settings, xsize, ysize) != JXL_ENC_SUCCESS) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "bad dimensions"); + } + if (xsize != static_cast<size_t>(io.Main().jpeg_data->width) || + ysize != static_cast<size_t>(io.Main().jpeg_data->height)) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "JPEG dimensions don't match frame dimensions"); + } + + if (frame_settings->enc->metadata.m.xyb_encoded) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Can't XYB encode a lossless JPEG"); + } + if (!io.blobs.exif.empty()) { + JxlOrientation orientation = static_cast<JxlOrientation>( + frame_settings->enc->metadata.m.orientation); + jxl::InterpretExif(io.blobs.exif, &orientation); + frame_settings->enc->metadata.m.orientation = orientation; + } + if (!io.blobs.exif.empty() && frame_settings->values.cparams.jpeg_keep_exif) { + size_t exif_size = io.blobs.exif.size(); + // Exif data in JPEG is limited to 64k + if (exif_size > 0xFFFF) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "Exif larger than possible in JPEG?"); + } + exif_size += 4; // prefix 4 zero bytes for tiff offset + std::vector<uint8_t> exif(exif_size); + memcpy(exif.data() + 4, io.blobs.exif.data(), io.blobs.exif.size()); + JxlEncoderUseBoxes(frame_settings->enc); + JxlEncoderAddBox(frame_settings->enc, "Exif", exif.data(), exif_size, + frame_settings->values.cparams.jpeg_compress_boxes); + } + if (!io.blobs.xmp.empty() && frame_settings->values.cparams.jpeg_keep_xmp) { + JxlEncoderUseBoxes(frame_settings->enc); + JxlEncoderAddBox(frame_settings->enc, "xml ", io.blobs.xmp.data(), + io.blobs.xmp.size(), + frame_settings->values.cparams.jpeg_compress_boxes); + } + if (!io.blobs.jumbf.empty() && + frame_settings->values.cparams.jpeg_keep_jumbf) { + JxlEncoderUseBoxes(frame_settings->enc); + JxlEncoderAddBox(frame_settings->enc, "jumb", io.blobs.jumbf.data(), + io.blobs.jumbf.size(), + frame_settings->values.cparams.jpeg_compress_boxes); + } + if (frame_settings->enc->store_jpeg_metadata) { + if (!frame_settings->values.cparams.jpeg_keep_exif || + !frame_settings->values.cparams.jpeg_keep_xmp) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Need to preserve EXIF and XMP to allow JPEG " + "bitstream reconstruction"); + } + jxl::jpeg::JPEGData data_in = *io.Main().jpeg_data; + std::vector<uint8_t> jpeg_data; + if (!jxl::jpeg::EncodeJPEGData(data_in, &jpeg_data, + frame_settings->values.cparams)) { + return JXL_API_ERROR( + frame_settings->enc, JXL_ENC_ERR_JBRD, + "JPEG bitstream reconstruction data cannot be encoded"); + } + frame_settings->enc->jpeg_metadata = jpeg_data; + } + + jxl::JxlEncoderChunkedFrameAdapter frame_data( + xsize, ysize, frame_settings->enc->metadata.m.num_extra_channels); + frame_data.SetJPEGData(*io.Main().jpeg_data); + + auto queued_frame = jxl::MemoryManagerMakeUnique<jxl::JxlEncoderQueuedFrame>( + &frame_settings->enc->memory_manager, + // JxlEncoderQueuedFrame is a struct with no constructors, so we use the + // default move constructor there. + jxl::JxlEncoderQueuedFrame{ + frame_settings->values, std::move(frame_data), {}}); + if (!queued_frame) { + // TODO(jon): when can this happen? is this an API usage error? + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "No frame queued?"); + } + queued_frame->ec_initialized.resize( + frame_settings->enc->metadata.m.num_extra_channels); + + QueueFrame(frame_settings, queued_frame); + return JxlErrorOrStatus::Success(); +} + +static bool CanDoFastLossless(const JxlEncoderFrameSettings* frame_settings, + const JxlPixelFormat* pixel_format, + bool has_alpha) { + if (!frame_settings->values.lossless) { + return false; + } + // TODO(veluca): many of the following options could be made to work, but are + // just not implemented in FJXL's frame header handling yet. + if (frame_settings->values.frame_index_box) { + return false; + } + if (frame_settings->values.header.layer_info.have_crop) { + return false; + } + if (frame_settings->enc->metadata.m.have_animation) { + return false; + } + if (frame_settings->values.cparams.speed_tier != jxl::SpeedTier::kLightning) { + return false; + } + if (frame_settings->values.image_bit_depth.type == + JxlBitDepthType::JXL_BIT_DEPTH_CUSTOM && + frame_settings->values.image_bit_depth.bits_per_sample != + frame_settings->enc->metadata.m.bit_depth.bits_per_sample) { + return false; + } + // TODO(veluca): implement support for LSB-padded input in fast_lossless. + if (frame_settings->values.image_bit_depth.type == + JxlBitDepthType::JXL_BIT_DEPTH_FROM_PIXEL_FORMAT && + frame_settings->values.image_bit_depth.bits_per_sample % 8 != 0) { + return false; + } + if (!frame_settings->values.frame_name.empty()) { + return false; + } + // No extra channels other than alpha. + if (!(has_alpha && frame_settings->enc->metadata.m.num_extra_channels == 1) && + frame_settings->enc->metadata.m.num_extra_channels != 0) { + return false; + } + if (frame_settings->enc->metadata.m.bit_depth.bits_per_sample > 16) { + return false; + } + if (pixel_format->data_type != JxlDataType::JXL_TYPE_FLOAT16 && + pixel_format->data_type != JxlDataType::JXL_TYPE_UINT16 && + pixel_format->data_type != JxlDataType::JXL_TYPE_UINT8) { + return false; + } + if ((frame_settings->enc->metadata.m.bit_depth.bits_per_sample > 8) != + (pixel_format->data_type == JxlDataType::JXL_TYPE_UINT16 || + pixel_format->data_type == JxlDataType::JXL_TYPE_FLOAT16)) { + return false; + } + if (!((pixel_format->num_channels == 1 || pixel_format->num_channels == 3) && + !has_alpha) && + !((pixel_format->num_channels == 2 || pixel_format->num_channels == 4) && + has_alpha)) { + return false; + } + + return true; +} + +namespace { +JxlEncoderStatus JxlEncoderAddImageFrameInternal( + const JxlEncoderFrameSettings* frame_settings, size_t xsize, size_t ysize, + bool streaming, jxl::JxlEncoderChunkedFrameAdapter frame_data) { + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + { + JxlChunkedFrameInputSource input = frame_data.GetInputSource(); + input.get_color_channels_pixel_format(input.opaque, &pixel_format); + } + uint32_t num_channels = pixel_format.num_channels; + size_t has_interleaved_alpha = + static_cast<size_t>(num_channels == 2 || num_channels == 4); + + if (!frame_settings->enc->basic_info_set) { + // Basic Info must be set. Otherwise, this is an API misuse. + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Basic info or color encoding not set yet"); + } + if (frame_settings->enc->frames_closed) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Frame input already closed"); + } + if (num_channels < 3) { + if (frame_settings->enc->basic_info.num_color_channels != 1) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Grayscale pixel format input for an RGB image"); + } + } else { + if (frame_settings->enc->basic_info.num_color_channels != 3) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "RGB pixel format input for a grayscale image"); + } + } + if (frame_settings->values.lossless && + frame_settings->enc->metadata.m.xyb_encoded) { + return JXL_API_ERROR( + frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Set uses_original_profile=true for lossless encoding"); + } + if (JXL_ENC_SUCCESS != + VerifyInputBitDepth(frame_settings->values.image_bit_depth, + pixel_format)) { + return JXL_API_ERROR_NOSET("Invalid input bit depth"); + } + if (has_interleaved_alpha > + frame_settings->enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR( + frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "number of extra channels mismatch (need 1 extra channel for alpha)"); + } + + bool has_alpha = frame_settings->enc->metadata.m.HasAlpha(); + + // All required conditions to do fast-lossless. + if (CanDoFastLossless(frame_settings, &pixel_format, has_alpha)) { + const bool big_endian = + pixel_format.endianness == JXL_BIG_ENDIAN || + (pixel_format.endianness == JXL_NATIVE_ENDIAN && !IsLittleEndian()); + + auto runner = +[](void* void_pool, void* opaque, void fun(void*, size_t), + size_t count) { + auto* pool = reinterpret_cast<jxl::ThreadPool*>(void_pool); + JXL_CHECK(jxl::RunOnPool( + pool, 0, count, jxl::ThreadPool::NoInit, + [&](size_t i, size_t) { fun(opaque, i); }, "Encode fast lossless")); + }; + auto frame_state = JxlFastLosslessPrepareFrame( + frame_data.GetInputSource(), xsize, ysize, num_channels, + frame_settings->enc->metadata.m.bit_depth.bits_per_sample, big_endian, + /*effort=*/2, /*oneshot=*/!frame_data.StreamingInput()); + if (!streaming) { + JxlFastLosslessProcessFrame(frame_state, /*is_last=*/false, + frame_settings->enc->thread_pool.get(), + runner, nullptr); + } + QueueFastLosslessFrame(frame_settings, frame_state); + return JxlErrorOrStatus::Success(); + } + + if (!streaming) { + // The input callbacks are only guaranteed to be available during frame + // encoding when both the input and the output is streaming. In all other + // cases we need to create an internal copy of the frame data. + if (!frame_data.CopyBuffers()) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Invalid chunked frame input source"); + } + } + + if (!frame_settings->enc->color_encoding_set) { + jxl::ColorEncoding c_current; + if ((pixel_format.data_type == JXL_TYPE_FLOAT) || + (pixel_format.data_type == JXL_TYPE_FLOAT16)) { + c_current = jxl::ColorEncoding::LinearSRGB(num_channels < 3); + } else { + c_current = jxl::ColorEncoding::SRGB(num_channels < 3); + } + frame_settings->enc->metadata.m.color_encoding = c_current; + } + + auto queued_frame = jxl::MemoryManagerMakeUnique<jxl::JxlEncoderQueuedFrame>( + &frame_settings->enc->memory_manager, + // JxlEncoderQueuedFrame is a struct with no constructors, so we use the + // default move constructor there. + jxl::JxlEncoderQueuedFrame{ + frame_settings->values, std::move(frame_data), {}}); + + if (!queued_frame) { + // TODO(jon): when can this happen? is this an API usage error? + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "No frame queued?"); + } + + for (auto& ec_info : frame_settings->enc->metadata.m.extra_channel_info) { + if (has_interleaved_alpha && ec_info.type == jxl::ExtraChannel::kAlpha) { + queued_frame->ec_initialized.push_back(1); + has_interleaved_alpha = 0; // only first Alpha is initialized + } else { + queued_frame->ec_initialized.push_back(0); + } + } + queued_frame->option_values.cparams.level = + frame_settings->enc->codestream_level; + + QueueFrame(frame_settings, queued_frame); + return JxlErrorOrStatus::Success(); +} +} // namespace + +JxlEncoderStatus JxlEncoderAddImageFrame( + const JxlEncoderFrameSettings* frame_settings, + const JxlPixelFormat* pixel_format, const void* buffer, size_t size) { + size_t xsize, ysize; + if (GetCurrentDimensions(frame_settings, xsize, ysize) != JXL_ENC_SUCCESS) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "bad dimensions"); + } + jxl::JxlEncoderChunkedFrameAdapter frame_data( + xsize, ysize, frame_settings->enc->metadata.m.num_extra_channels); + if (!frame_data.SetFromBuffer(0, reinterpret_cast<const uint8_t*>(buffer), + size, *pixel_format)) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "provided image buffer too small"); + } + return JxlEncoderAddImageFrameInternal(frame_settings, xsize, ysize, + /*streaming=*/false, + std::move(frame_data)); +} + +JxlEncoderStatus JxlEncoderAddChunkedFrame( + const JxlEncoderFrameSettings* frame_settings, JXL_BOOL is_last_frame, + JxlChunkedFrameInputSource chunked_frame_input) { + size_t xsize, ysize; + if (GetCurrentDimensions(frame_settings, xsize, ysize) != JXL_ENC_SUCCESS) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_GENERIC, + "bad dimensions"); + } + bool streaming = frame_settings->enc->output_processor.OutputProcessorSet(); + jxl::JxlEncoderChunkedFrameAdapter frame_data( + xsize, ysize, frame_settings->enc->metadata.m.num_extra_channels); + frame_data.SetInputSource(chunked_frame_input); + auto status = JxlEncoderAddImageFrameInternal(frame_settings, xsize, ysize, + streaming, frame_data); + if (status != JXL_ENC_SUCCESS) return status; + + auto& queued_frame = frame_settings->enc->input_queue.back(); + if (queued_frame.frame) { + for (auto& val : queued_frame.frame->ec_initialized) val = 1; + } + + if (is_last_frame) { + JxlEncoderCloseInput(frame_settings->enc); + } + if (streaming) { + return JxlEncoderFlushInput(frame_settings->enc); + } + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderUseBoxes(JxlEncoder* enc) { + if (enc->wrote_bytes) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "this setting can only be set at the beginning"); + } + enc->use_boxes = true; + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderAddBox(JxlEncoder* enc, const JxlBoxType type, + const uint8_t* contents, size_t size, + JXL_BOOL compress_box) { + if (!enc->use_boxes) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "must set JxlEncoderUseBoxes at the beginning to add boxes"); + } + if (enc->boxes_closed) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Box input already closed"); + } + if (compress_box) { + if (memcmp("jxl", type, 3) == 0) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "brob box may not contain a type starting with \"jxl\""); + } + if (memcmp("jbrd", type, 4) == 0) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "jbrd box may not be brob compressed"); + } + if (memcmp("brob", type, 4) == 0) { + // The compress_box will compress an existing non-brob box into a brob + // box. If already giving a valid brotli-compressed brob box, set + // compress_box to false since it is already compressed. + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "a brob box cannot contain another brob box"); + } + } + + auto box = jxl::MemoryManagerMakeUnique<jxl::JxlEncoderQueuedBox>( + &enc->memory_manager); + + box->type = jxl::MakeBoxType(type); + box->contents.assign(contents, contents + size); + box->compress_box = !!compress_box; + QueueBox(enc, box); + return JxlErrorOrStatus::Success(); +} + +JXL_EXPORT JxlEncoderStatus JxlEncoderSetExtraChannelBuffer( + const JxlEncoderFrameSettings* frame_settings, + const JxlPixelFormat* pixel_format, const void* buffer, size_t size, + uint32_t index) { + if (index >= frame_settings->enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for the index of extra channel"); + } + if (!frame_settings->enc->basic_info_set || + !frame_settings->enc->color_encoding_set) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Basic info has to be set first"); + } + if (frame_settings->enc->input_queue.empty()) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "First add image frame, then extra channels"); + } + if (frame_settings->enc->frames_closed) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Frame input already closed"); + } + JxlPixelFormat ec_format = *pixel_format; + ec_format.num_channels = 1; + if (JXL_ENC_SUCCESS != + VerifyInputBitDepth(frame_settings->values.image_bit_depth, ec_format)) { + return JXL_API_ERROR_NOSET("Invalid input bit depth"); + } + const uint8_t* uint8_buffer = reinterpret_cast<const uint8_t*>(buffer); + auto queued_frame = frame_settings->enc->input_queue.back().frame.get(); + if (!queued_frame->frame_data.SetFromBuffer(1 + index, uint8_buffer, size, + ec_format)) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "provided image buffer too small"); + } + queued_frame->ec_initialized[index] = 1; + + return JxlErrorOrStatus::Success(); +} + +void JxlEncoderCloseFrames(JxlEncoder* enc) { enc->frames_closed = true; } + +void JxlEncoderCloseBoxes(JxlEncoder* enc) { enc->boxes_closed = true; } + +void JxlEncoderCloseInput(JxlEncoder* enc) { + JxlEncoderCloseFrames(enc); + JxlEncoderCloseBoxes(enc); +} + +JXL_EXPORT JxlEncoderStatus JxlEncoderFlushInput(JxlEncoder* enc) { + if (!enc->output_processor.OutputProcessorSet()) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Cannot flush input without setting output " + "processor with JxlEncoderSetOutputProcessor"); + } + while (!enc->input_queue.empty()) { + if (!enc->ProcessOneEnqueuedInput()) { + return JxlErrorOrStatus::Error(); + } + } + return JxlErrorOrStatus::Success(); +} + +JXL_EXPORT JxlEncoderStatus JxlEncoderSetOutputProcessor( + JxlEncoder* enc, JxlEncoderOutputProcessor output_processor) { + if (enc->output_processor.HasAvailOut()) { + return JXL_API_ERROR( + enc, JXL_ENC_ERR_API_USAGE, + "Cannot set an output processor when some output was already produced"); + } + if (!output_processor.set_finalized_position || + !output_processor.get_buffer || !output_processor.release_buffer) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Missing output processor functions"); + } + enc->output_processor = JxlEncoderOutputProcessorWrapper(output_processor); + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderProcessOutput(JxlEncoder* enc, uint8_t** next_out, + size_t* avail_out) { + if (!enc->output_processor.SetAvailOut(next_out, avail_out)) { + return JXL_API_ERROR(enc, JXL_ENC_ERR_API_USAGE, + "Cannot call JxlEncoderProcessOutput after calling " + "JxlEncoderSetOutputProcessor"); + } + while (*avail_out != 0 && !enc->input_queue.empty()) { + if (!enc->ProcessOneEnqueuedInput()) { + return JxlErrorOrStatus::Error(); + } + } + + if (!enc->input_queue.empty() || enc->output_processor.HasOutputToWrite()) { + return JxlErrorOrStatus::MoreOutput(); + } + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderSetFrameHeader( + JxlEncoderFrameSettings* frame_settings, + const JxlFrameHeader* frame_header) { + if (frame_header->layer_info.blend_info.source > 3) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "invalid blending source index"); + } + // If there are no extra channels, it's ok for the value to be 0. + if (frame_header->layer_info.blend_info.alpha != 0 && + frame_header->layer_info.blend_info.alpha >= + frame_settings->enc->metadata.m.extra_channel_info.size()) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "alpha blend channel index out of bounds"); + } + + frame_settings->values.header = *frame_header; + // Setting the frame header resets the frame name, it must be set again with + // JxlEncoderSetFrameName if desired. + frame_settings->values.frame_name = ""; + + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderSetExtraChannelBlendInfo( + JxlEncoderFrameSettings* frame_settings, size_t index, + const JxlBlendInfo* blend_info) { + if (index >= frame_settings->enc->metadata.m.num_extra_channels) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "Invalid value for the index of extra channel"); + } + + if (frame_settings->values.extra_channel_blend_info.size() != + frame_settings->enc->metadata.m.num_extra_channels) { + JxlBlendInfo default_blend_info; + JxlEncoderInitBlendInfo(&default_blend_info); + frame_settings->values.extra_channel_blend_info.resize( + frame_settings->enc->metadata.m.num_extra_channels, default_blend_info); + } + frame_settings->values.extra_channel_blend_info[index] = *blend_info; + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderSetFrameName(JxlEncoderFrameSettings* frame_settings, + const char* frame_name) { + std::string str = frame_name ? frame_name : ""; + if (str.size() > 1071) { + return JXL_API_ERROR(frame_settings->enc, JXL_ENC_ERR_API_USAGE, + "frame name can be max 1071 bytes long"); + } + frame_settings->values.frame_name = str; + frame_settings->values.header.name_length = str.size(); + return JxlErrorOrStatus::Success(); +} + +JxlEncoderStatus JxlEncoderSetFrameBitDepth( + JxlEncoderFrameSettings* frame_settings, const JxlBitDepth* bit_depth) { + if (bit_depth->type != JXL_BIT_DEPTH_FROM_PIXEL_FORMAT && + bit_depth->type != JXL_BIT_DEPTH_FROM_CODESTREAM) { + return JXL_API_ERROR_NOSET( + "Only JXL_BIT_DEPTH_FROM_PIXEL_FORMAT and " + "JXL_BIT_DEPTH_FROM_CODESTREAM is implemented " + "for input buffers."); + } + frame_settings->values.image_bit_depth = *bit_depth; + return JxlErrorOrStatus::Success(); +} + +void JxlColorEncodingSetToSRGB(JxlColorEncoding* color_encoding, + JXL_BOOL is_gray) { + *color_encoding = jxl::ColorEncoding::SRGB(is_gray).ToExternal(); +} + +void JxlColorEncodingSetToLinearSRGB(JxlColorEncoding* color_encoding, + JXL_BOOL is_gray) { + *color_encoding = jxl::ColorEncoding::LinearSRGB(is_gray).ToExternal(); +} + +void JxlEncoderAllowExpertOptions(JxlEncoder* enc) { + enc->allow_expert_options = true; +} + +JXL_EXPORT void JxlEncoderSetDebugImageCallback( + JxlEncoderFrameSettings* frame_settings, JxlDebugImageCallback callback, + void* opaque) { + frame_settings->values.cparams.debug_image = callback; + frame_settings->values.cparams.debug_image_opaque = opaque; +} + +JXL_EXPORT JxlEncoderStats* JxlEncoderStatsCreate() { + return new JxlEncoderStats(); +} + +JXL_EXPORT void JxlEncoderStatsDestroy(JxlEncoderStats* stats) { + if (stats) delete stats; +} + +JXL_EXPORT void JxlEncoderCollectStats(JxlEncoderFrameSettings* frame_settings, + JxlEncoderStats* stats) { + if (!stats) return; + frame_settings->values.aux_out = &stats->aux_out; +} + +JXL_EXPORT size_t JxlEncoderStatsGet(const JxlEncoderStats* stats, + JxlEncoderStatsKey key) { + if (!stats) return 0; + const jxl::AuxOut& aux_out = stats->aux_out; + switch (key) { + case JXL_ENC_STAT_HEADER_BITS: + return aux_out.layers[jxl::kLayerHeader].total_bits; + case JXL_ENC_STAT_TOC_BITS: + return aux_out.layers[jxl::kLayerTOC].total_bits; + case JXL_ENC_STAT_DICTIONARY_BITS: + return aux_out.layers[jxl::kLayerDictionary].total_bits; + case JXL_ENC_STAT_SPLINES_BITS: + return aux_out.layers[jxl::kLayerSplines].total_bits; + case JXL_ENC_STAT_NOISE_BITS: + return aux_out.layers[jxl::kLayerNoise].total_bits; + case JXL_ENC_STAT_QUANT_BITS: + return aux_out.layers[jxl::kLayerQuant].total_bits; + case JXL_ENC_STAT_MODULAR_TREE_BITS: + return aux_out.layers[jxl::kLayerModularTree].total_bits; + case JXL_ENC_STAT_MODULAR_GLOBAL_BITS: + return aux_out.layers[jxl::kLayerModularGlobal].total_bits; + case JXL_ENC_STAT_DC_BITS: + return aux_out.layers[jxl::kLayerDC].total_bits; + case JXL_ENC_STAT_MODULAR_DC_GROUP_BITS: + return aux_out.layers[jxl::kLayerModularDcGroup].total_bits; + case JXL_ENC_STAT_CONTROL_FIELDS_BITS: + return aux_out.layers[jxl::kLayerControlFields].total_bits; + case JXL_ENC_STAT_COEF_ORDER_BITS: + return aux_out.layers[jxl::kLayerOrder].total_bits; + case JXL_ENC_STAT_AC_HISTOGRAM_BITS: + return aux_out.layers[jxl::kLayerAC].total_bits; + case JXL_ENC_STAT_AC_BITS: + return aux_out.layers[jxl::kLayerACTokens].total_bits; + case JXL_ENC_STAT_MODULAR_AC_GROUP_BITS: + return aux_out.layers[jxl::kLayerModularAcGroup].total_bits; + case JXL_ENC_STAT_NUM_SMALL_BLOCKS: + return aux_out.num_small_blocks; + case JXL_ENC_STAT_NUM_DCT4X8_BLOCKS: + return aux_out.num_dct4x8_blocks; + case JXL_ENC_STAT_NUM_AFV_BLOCKS: + return aux_out.num_afv_blocks; + case JXL_ENC_STAT_NUM_DCT8_BLOCKS: + return aux_out.num_dct8_blocks; + case JXL_ENC_STAT_NUM_DCT8X32_BLOCKS: + return aux_out.num_dct16_blocks; + case JXL_ENC_STAT_NUM_DCT16_BLOCKS: + return aux_out.num_dct16x32_blocks; + case JXL_ENC_STAT_NUM_DCT16X32_BLOCKS: + return aux_out.num_dct32_blocks; + case JXL_ENC_STAT_NUM_DCT32_BLOCKS: + return aux_out.num_dct32x64_blocks; + case JXL_ENC_STAT_NUM_DCT32X64_BLOCKS: + return aux_out.num_dct32x64_blocks; + case JXL_ENC_STAT_NUM_DCT64_BLOCKS: + return aux_out.num_dct64_blocks; + case JXL_ENC_STAT_NUM_BUTTERAUGLI_ITERS: + return aux_out.num_butteraugli_iters; + default: + return 0; + } +} + +JXL_EXPORT void JxlEncoderStatsMerge(JxlEncoderStats* stats, + const JxlEncoderStats* other) { + if (!stats || !other) return; + stats->aux_out.Assimilate(other->aux_out); +} diff --git a/third_party/jpeg-xl/lib/jxl/encode_internal.h b/third_party/jpeg-xl/lib/jxl/encode_internal.h new file mode 100644 index 0000000000..e89993f253 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/encode_internal.h @@ -0,0 +1,669 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +#ifndef LIB_JXL_ENCODE_INTERNAL_H_ +#define LIB_JXL_ENCODE_INTERNAL_H_ + +#include <jxl/cms_interface.h> +#include <jxl/codestream_header.h> +#include <jxl/encode.h> +#include <jxl/memory_manager.h> +#include <jxl/types.h> + +#include <algorithm> +#include <array> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <functional> +#include <map> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "lib/jxl/base/c_callback_support.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_fast_lossless.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/jpeg/jpeg_data.h" +#include "lib/jxl/memory_manager_internal.h" +#include "lib/jxl/padded_bytes.h" + +namespace jxl { + +/* Frame index box 'jxli' will start with Varint() for +NF: has type Varint(): number of frames listed in the index. +TNUM: has type u32: numerator of tick unit. +TDEN: has type u32: denominator of tick unit. Value 0 means the file is +ill-formed. per frame i listed: OFFi: has type Varint(): offset of start byte of +this frame compared to start byte of previous frame from this index in the JPEG +XL codestream. For the first frame, this is the offset from the first byte of +the JPEG XL codestream. Ti: has type Varint(): duration in ticks between the +start of this frame and the start of the next frame in the index. If this is the +last frame in the index, this is the duration in ticks between the start of this +frame and the end of the stream. A tick lasts TNUM / TDEN seconds. Fi: has type +Varint(): amount of frames the next frame in the index occurs after this frame. +If this is the last frame in the index, this is the amount of frames after this +frame in the remainder of the stream. Only frames that are presented by the +decoder are counted for this purpose, this excludes frames that are not intended +for display but for compositing with other frames, such as frames that aren't +the last frame with a duration of 0 ticks. + +All the frames listed in jxli are keyframes and the first frame is +present in the list. +There shall be either zero or one Frame Index boxes in a JPEG XL file. +The offsets OFFi per frame are given as bytes in the codestream, not as +bytes in the file format using the box structure. This means if JPEG XL Partial +Codestream boxes are used, the offset is counted within the concatenated +codestream, bytes from box headers or non-codestream boxes are not counted. +*/ + +typedef struct JxlEncoderFrameIndexBoxEntryStruct { + bool to_be_indexed; + uint32_t duration; + uint64_t OFFi; +} JxlEncoderFrameIndexBoxEntry; + +typedef struct JxlEncoderFrameIndexBoxStruct { + // We always need to record the first frame entry, so presence of the + // first entry alone is not an indication if it was requested to be + // stored. + bool index_box_requested_through_api = false; + + int64_t NF() const { return entries.size(); } + bool StoreFrameIndexBox() { + for (auto e : entries) { + if (e.to_be_indexed) { + return true; + } + } + return false; + } + int32_t TNUM = 1; + int32_t TDEN = 1000; + + std::vector<JxlEncoderFrameIndexBoxEntry> entries; + + // That way we can ensure that every index box will have the first frame. + // If the API user decides to mark it as an indexed frame, we call + // the AddFrame again, this time with requested. + void AddFrame(uint64_t OFFi, uint32_t duration, bool to_be_indexed) { + // We call AddFrame to every frame. + // Recording the first frame is required by the standard. + // Knowing the last frame is required, since the last indexed frame + // needs to know how many frames until the end. + // To be able to tell how many frames there are between each index + // entry we just record every frame here. + if (entries.size() == 1) { + if (OFFi == entries[0].OFFi) { + // API use for the first frame, let's clear the already recorded first + // frame. + entries.clear(); + } + } + JxlEncoderFrameIndexBoxEntry e; + e.to_be_indexed = to_be_indexed; + e.OFFi = OFFi; + e.duration = duration; + entries.push_back(e); + } +} JxlEncoderFrameIndexBox; + +// The encoder options (such as quality, compression speed, ...) for a single +// frame, but not encoder-wide options such as box-related options. +typedef struct JxlEncoderFrameSettingsValuesStruct { + // lossless is a separate setting from cparams because it is a combination + // setting that overrides multiple settings inside of cparams. + bool lossless; + CompressParams cparams; + JxlFrameHeader header; + std::vector<JxlBlendInfo> extra_channel_blend_info; + std::string frame_name; + JxlBitDepth image_bit_depth; + bool frame_index_box = false; + jxl::AuxOut* aux_out = nullptr; +} JxlEncoderFrameSettingsValues; + +typedef std::array<uint8_t, 4> BoxType; + +// Utility function that makes a BoxType from a string literal. The string must +// have 4 characters, a 5th null termination character is optional. +constexpr BoxType MakeBoxType(const char* type) { + return BoxType( + {{static_cast<uint8_t>(type[0]), static_cast<uint8_t>(type[1]), + static_cast<uint8_t>(type[2]), static_cast<uint8_t>(type[3])}}); +} + +constexpr std::array<unsigned char, 32> kContainerHeader = { + 0, 0, 0, 0xc, 'J', 'X', 'L', ' ', 0xd, 0xa, 0x87, + 0xa, 0, 0, 0, 0x14, 'f', 't', 'y', 'p', 'j', 'x', + 'l', ' ', 0, 0, 0, 0, 'j', 'x', 'l', ' '}; + +constexpr std::array<unsigned char, 8> kLevelBoxHeader = {0, 0, 0, 0x9, + 'j', 'x', 'l', 'l'}; + +static JXL_INLINE size_t BitsPerChannel(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_UINT8: + return 8; + case JXL_TYPE_UINT16: + return 16; + case JXL_TYPE_FLOAT: + return 32; + case JXL_TYPE_FLOAT16: + return 16; + default: + return 0; // signals unhandled JxlDataType + } +} + +static JXL_INLINE size_t BytesPerPixel(JxlPixelFormat format) { + return format.num_channels * BitsPerChannel(format.data_type) / 8; +} + +using ScopedInputBuffer = + std::unique_ptr<const void, std::function<void(const void*)>>; + +static JXL_INLINE ScopedInputBuffer +GetColorBuffer(JxlChunkedFrameInputSource& input, size_t xpos, size_t ypos, + size_t xsize, size_t ysize, size_t* row_offset) { + return ScopedInputBuffer( + input.get_color_channel_data_at(input.opaque, xpos, ypos, xsize, ysize, + row_offset), + [&input](const void* p) { input.release_buffer(input.opaque, p); }); +} + +static JXL_INLINE ScopedInputBuffer GetExtraChannelBuffer( + JxlChunkedFrameInputSource& input, size_t ec_index, size_t xpos, + size_t ypos, size_t xsize, size_t ysize, size_t* row_offset) { + return ScopedInputBuffer( + input.get_extra_channel_data_at(input.opaque, ec_index, xpos, ypos, xsize, + ysize, row_offset), + [&input](const void* p) { input.release_buffer(input.opaque, p); }); +} + +// Common adapter for an existing frame input source or a whole-image input +// buffer or a parsed JPEG file. +class JxlEncoderChunkedFrameAdapter { + public: + JxlEncoderChunkedFrameAdapter(size_t xs, size_t ys, size_t num_extra_channels) + : xsize(xs), ysize(ys), channels_(1 + num_extra_channels) {} + + void SetInputSource(JxlChunkedFrameInputSource input_source) { + input_source_ = input_source; + has_input_source_ = true; + } + + bool SetFromBuffer(size_t channel, const uint8_t* buffer, size_t size, + JxlPixelFormat format) { + if (channel >= channels_.size()) return false; + if (!channels_[channel].SetFromBuffer(buffer, size, format, xsize, ysize)) { + return false; + } + if (channel > 0) channels_[channel].CopyBuffer(); + return true; + } + + // TODO(szabadka) Move instead of copy. + void SetJPEGData(const jpeg::JPEGData jpeg_data) { + jpeg_data_ = jpeg_data; + has_jpeg_data_ = true; + } + bool IsJPEG() const { return has_jpeg_data_; } + + jpeg::JPEGData&& TakeJPEGData() { + JXL_ASSERT(has_jpeg_data_); + return std::move(jpeg_data_); + } + + JxlChunkedFrameInputSource GetInputSource() { + if (has_input_source_) { + return input_source_; + } + return JxlChunkedFrameInputSource{ + this, + METHOD_TO_C_CALLBACK( + &JxlEncoderChunkedFrameAdapter::GetColorChannelsPixelFormat), + METHOD_TO_C_CALLBACK( + &JxlEncoderChunkedFrameAdapter::GetColorChannelDataAt), + METHOD_TO_C_CALLBACK( + &JxlEncoderChunkedFrameAdapter::GetExtraChannelPixelFormat), + METHOD_TO_C_CALLBACK( + &JxlEncoderChunkedFrameAdapter::GetExtraChannelDataAt), + METHOD_TO_C_CALLBACK( + &JxlEncoderChunkedFrameAdapter::ReleaseCurrentData)}; + } + + bool CopyBuffers() { + if (has_input_source_) { + JxlPixelFormat format{4, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + input_source_.get_color_channels_pixel_format(input_source_.opaque, + &format); + size_t row_offset; + { + auto buffer = + GetColorBuffer(input_source_, 0, 0, xsize, ysize, &row_offset); + if (!buffer) return false; + channels_[0].CopyFromBuffer(buffer.get(), format, xsize, ysize, + row_offset); + } + for (size_t ec = 0; ec + 1 < channels_.size(); ++ec) { + input_source_.get_extra_channel_pixel_format(input_source_.opaque, ec, + &format); + auto buffer = GetExtraChannelBuffer(input_source_, ec, 0, 0, xsize, + ysize, &row_offset); + if (!buffer) continue; + channels_[1 + ec].CopyFromBuffer(buffer.get(), format, xsize, ysize, + row_offset); + } + has_input_source_ = false; + } else { + channels_[0].CopyBuffer(); + } + return true; + } + + bool StreamingInput() const { return has_input_source_; } + + const size_t xsize; + const size_t ysize; + + private: + void GetColorChannelsPixelFormat(JxlPixelFormat* pixel_format) { + *pixel_format = channels_[0].format_; + } + + const void* GetColorChannelDataAt(size_t xpos, size_t ypos, size_t xsize, + size_t ysize, size_t* row_offset) { + return channels_[0].GetDataAt(xpos, ypos, xsize, ysize, row_offset); + } + + void GetExtraChannelPixelFormat(size_t ec_index, + JxlPixelFormat* pixel_format) { + JXL_ASSERT(1 + ec_index < channels_.size()); + *pixel_format = channels_[1 + ec_index].format_; + } + + const void* GetExtraChannelDataAt(size_t ec_index, size_t xpos, size_t ypos, + size_t xsize, size_t ysize, + size_t* row_offset) { + JXL_ASSERT(1 + ec_index < channels_.size()); + return channels_[1 + ec_index].GetDataAt(xpos, ypos, xsize, ysize, + row_offset); + } + + void ReleaseCurrentData(const void* buffer) { + // No dynamic memory is allocated in GetColorChannelDataAt or + // GetExtraChannelDataAt. Therefore, no cleanup is required here. + } + + JxlChunkedFrameInputSource input_source_ = {}; + bool has_input_source_ = false; + jpeg::JPEGData jpeg_data_; + bool has_jpeg_data_ = false; + struct Channel { + const uint8_t* buffer_ = nullptr; + size_t buffer_size_; + JxlPixelFormat format_; + size_t xsize_; + size_t ysize_; + size_t bytes_per_pixel_; + size_t stride_; + std::vector<uint8_t> copy_; + + void SetFormatAndDimensions(JxlPixelFormat format, size_t xsize, + size_t ysize) { + format_ = format; + xsize_ = xsize; + ysize_ = ysize; + bytes_per_pixel_ = BytesPerPixel(format_); + const size_t last_row_size = xsize_ * bytes_per_pixel_; + const size_t align = format_.align; + stride_ = (align > 1 ? jxl::DivCeil(last_row_size, align) * align + : last_row_size); + } + + bool SetFromBuffer(const uint8_t* buffer, size_t size, + JxlPixelFormat format, size_t xsize, size_t ysize) { + SetFormatAndDimensions(format, xsize, ysize); + buffer_ = buffer; + buffer_size_ = size; + const size_t min_buffer_size = + stride_ * (ysize_ - 1) + xsize_ * bytes_per_pixel_; + return min_buffer_size <= size; + } + + void CopyFromBuffer(const void* buffer, JxlPixelFormat format, size_t xsize, + size_t ysize, size_t row_offset) { + SetFormatAndDimensions(format, xsize, ysize); + buffer_ = nullptr; + copy_.resize(ysize * stride_); + for (size_t y = 0; y < ysize; ++y) { + memcpy(copy_.data() + y * stride_, + reinterpret_cast<const uint8_t*>(buffer) + y * row_offset, + stride_); + } + } + + void CopyBuffer() { + if (buffer_) { + copy_ = std::vector<uint8_t>(buffer_, buffer_ + buffer_size_); + buffer_ = nullptr; + } + } + + const void* GetDataAt(size_t xpos, size_t ypos, size_t xsize, size_t ysize, + size_t* row_offset) const { + const uint8_t* buffer = copy_.empty() ? buffer_ : copy_.data(); + JXL_ASSERT(ypos + ysize <= ysize_); + JXL_ASSERT(xpos + xsize <= xsize_); + JXL_ASSERT(buffer); + *row_offset = stride_; + return buffer + ypos * stride_ + xpos * bytes_per_pixel_; + } + }; + std::vector<Channel> channels_; +}; + +struct JxlEncoderQueuedFrame { + JxlEncoderFrameSettingsValues option_values; + JxlEncoderChunkedFrameAdapter frame_data; + std::vector<uint8_t> ec_initialized; +}; + +struct JxlEncoderQueuedBox { + BoxType type; + std::vector<uint8_t> contents; + bool compress_box; +}; + +using FJXLFrameUniquePtr = + std::unique_ptr<JxlFastLosslessFrameState, + decltype(&JxlFastLosslessFreeFrameState)>; + +// Either a frame, or a box, not both. +// Can also be a FJXL frame. +struct JxlEncoderQueuedInput { + explicit JxlEncoderQueuedInput(const JxlMemoryManager& memory_manager) + : frame(nullptr, jxl::MemoryManagerDeleteHelper(&memory_manager)), + box(nullptr, jxl::MemoryManagerDeleteHelper(&memory_manager)) {} + MemoryManagerUniquePtr<JxlEncoderQueuedFrame> frame; + MemoryManagerUniquePtr<JxlEncoderQueuedBox> box; + FJXLFrameUniquePtr fast_lossless_frame = {nullptr, + JxlFastLosslessFreeFrameState}; +}; + +static constexpr size_t kSmallBoxHeaderSize = 8; +static constexpr size_t kLargeBoxHeaderSize = 16; +static constexpr size_t kLargeBoxContentSizeThreshold = + 0x100000000ull - kSmallBoxHeaderSize; + +size_t WriteBoxHeader(const jxl::BoxType& type, size_t size, bool unbounded, + bool force_large_box, uint8_t* output); + +// Appends a JXL container box header with given type, size, and unbounded +// properties to output. +template <typename T> +void AppendBoxHeader(const jxl::BoxType& type, size_t size, bool unbounded, + T* output) { + size_t current_size = output->size(); + output->resize(current_size + kLargeBoxHeaderSize); + size_t header_size = + WriteBoxHeader(type, size, unbounded, /*force_large_box=*/false, + output->data() + current_size); + output->resize(current_size + header_size); +} + +} // namespace jxl + +class JxlOutputProcessorBuffer; + +class JxlEncoderOutputProcessorWrapper { + friend class JxlOutputProcessorBuffer; + + public: + JxlEncoderOutputProcessorWrapper() = default; + explicit JxlEncoderOutputProcessorWrapper(JxlEncoderOutputProcessor processor) + : external_output_processor_( + jxl::make_unique<JxlEncoderOutputProcessor>(processor)) {} + + bool HasAvailOut() const { return avail_out_ != nullptr; } + + // Caller can never overwrite a previously-written buffer. Asking for a buffer + // with `min_size` such that `position + min_size` overlaps with a + // previously-written buffer is invalid. + jxl::StatusOr<JxlOutputProcessorBuffer> GetBuffer(size_t min_size, + size_t requested_size = 0); + + void Seek(size_t pos); + + void SetFinalizedPosition(); + + size_t CurrentPosition() const { return position_; } + + bool SetAvailOut(uint8_t** next_out, size_t* avail_out); + + bool WasStopRequested() const { return stop_requested_; } + bool OutputProcessorSet() const { + return external_output_processor_ != nullptr; + } + bool HasOutputToWrite() const { + return output_position_ < finalized_position_; + } + + void CopyOutput(std::vector<uint8_t>& output, uint8_t* next_out, + size_t& avail_out); + + private: + void ReleaseBuffer(size_t bytes_used); + + // Tries to write all the bytes up to the finalized position. + void FlushOutput(); + + bool AppendBufferToExternalProcessor(void* data, size_t count); + + struct InternalBuffer { + // Bytes in the range `[output_position_ - start_of_the_buffer, + // written_bytes)` need to be flushed out. + size_t written_bytes = 0; + // If data has been buffered, it is stored in `owned_data`. + jxl::PaddedBytes owned_data; + }; + + // Invariant: `internal_buffers_` does not contain chunks that are entirely + // below the output position. + std::map<size_t, InternalBuffer> internal_buffers_; + + uint8_t** next_out_ = nullptr; + size_t* avail_out_ = nullptr; + // Where the next GetBuffer call will write bytes to. + size_t position_ = 0; + // The position of the last SetFinalizedPosition call. + size_t finalized_position_ = 0; + // Either the position of the `external_output_processor_` or the position + // `next_out_` points to. + size_t output_position_ = 0; + + bool stop_requested_ = false; + bool has_buffer_ = false; + + std::unique_ptr<JxlEncoderOutputProcessor> external_output_processor_; +}; + +class JxlOutputProcessorBuffer { + public: + size_t size() const { return size_; }; + uint8_t* data() { return data_; } + + JxlOutputProcessorBuffer(uint8_t* buffer, size_t size, size_t bytes_used, + JxlEncoderOutputProcessorWrapper* wrapper) + : data_(buffer), + size_(size), + bytes_used_(bytes_used), + wrapper_(wrapper) {} + ~JxlOutputProcessorBuffer() { release(); } + + JxlOutputProcessorBuffer(const JxlOutputProcessorBuffer&) = delete; + JxlOutputProcessorBuffer(JxlOutputProcessorBuffer&& other) noexcept + : JxlOutputProcessorBuffer(other.data_, other.size_, other.bytes_used_, + other.wrapper_) { + other.data_ = nullptr; + other.size_ = 0; + } + + void advance(size_t count) { + JXL_ASSERT(count <= size_); + data_ += count; + size_ -= count; + bytes_used_ += count; + } + + void release() { + if (this->data_) { + wrapper_->ReleaseBuffer(bytes_used_); + } + data_ = nullptr; + size_ = 0; + } + + void append(const void* data, size_t count) { + memcpy(data_, data, count); + advance(count); + } + + template <typename T> + void append(const T& data) { + static_assert(sizeof(*std::begin(data)) == 1, "Cannot append non-bytes"); + append(&*std::begin(data), std::end(data) - std::begin(data)); + } + + JxlOutputProcessorBuffer& operator=(const JxlOutputProcessorBuffer&) = delete; + JxlOutputProcessorBuffer& operator=( + JxlOutputProcessorBuffer&& other) noexcept { + data_ = other.data_; + size_ = other.size_; + wrapper_ = other.wrapper_; + return *this; + } + + private: + uint8_t* data_; + size_t size_; + size_t bytes_used_; + JxlEncoderOutputProcessorWrapper* wrapper_; +}; + +template <typename T> +jxl::Status AppendData(JxlEncoderOutputProcessorWrapper& output_processor, + const T& data) { + size_t size = std::end(data) - std::begin(data); + size_t written = 0; + while (written < size) { + JXL_ASSIGN_OR_RETURN(auto buffer, + output_processor.GetBuffer(1, size - written)); + size_t n = std::min(size - written, buffer.size()); + buffer.append(data.data() + written, n); + written += n; + } + return jxl::OkStatus(); +} + +// Internal use only struct, can only be initialized correctly by +// JxlEncoderCreate. +struct JxlEncoderStruct { + JxlEncoderError error = JxlEncoderError::JXL_ENC_ERR_OK; + JxlMemoryManager memory_manager; + jxl::MemoryManagerUniquePtr<jxl::ThreadPool> thread_pool{ + nullptr, jxl::MemoryManagerDeleteHelper(&memory_manager)}; + JxlCmsInterface cms; + bool cms_set; + std::vector<jxl::MemoryManagerUniquePtr<JxlEncoderFrameSettings>> + encoder_options; + + size_t num_queued_frames; + size_t num_queued_boxes; + std::vector<jxl::JxlEncoderQueuedInput> input_queue; + JxlEncoderOutputProcessorWrapper output_processor; + + // How many codestream bytes have been written, i.e., + // content of jxlc and jxlp boxes. Frame index box jxli + // requires position indices to point to codestream bytes, + // so we need to keep track of the total of flushed or queue + // codestream bytes. These bytes may be in a single jxlc box + // or across multiple jxlp boxes. + size_t codestream_bytes_written_end_of_frame; + jxl::JxlEncoderFrameIndexBox frame_index_box; + + // Force using the container even if not needed + bool use_container; + // User declared they will add metadata boxes + bool use_boxes; + + // TODO(lode): move level into jxl::CompressParams since some C++ + // implementation decisions should be based on it: level 10 allows more + // features to be used. + int32_t codestream_level; + bool store_jpeg_metadata; + jxl::CodecMetadata metadata; + std::vector<uint8_t> jpeg_metadata; + + // Wrote any output at all, so wrote the data before the first user added + // frame or box, such as signature, basic info, ICC profile or jpeg + // reconstruction box. + bool wrote_bytes; + jxl::CompressParams last_used_cparams; + JxlBasicInfo basic_info; + + // Encoder wrote a jxlp (partial codestream) box, so any next codestream + // parts must also be written in jxlp boxes, a single jxlc box cannot be + // used. The counter is used for the 4-byte jxlp box index header. + size_t jxlp_counter; + + bool frames_closed; + bool boxes_closed; + bool basic_info_set; + bool color_encoding_set; + bool intensity_target_set; + bool allow_expert_options = false; + int brotli_effort = -1; + + // Takes the first frame in the input_queue, encodes it, and appends + // the bytes to the output_byte_queue. + jxl::Status ProcessOneEnqueuedInput(); + + bool MustUseContainer() const { + return use_container || (codestream_level != 5 && codestream_level != -1) || + store_jpeg_metadata || use_boxes; + } + + // `write_box` must never seek before the position the output wrapper was at + // the moment of the call, and must leave the output wrapper such that its + // position is one byte past the end of the written box. + template <typename WriteBox> + jxl::Status AppendBox(const jxl::BoxType& type, bool unbounded, + size_t box_max_size, const WriteBox& write_box); + + template <typename BoxContents> + jxl::Status AppendBoxWithContents(const jxl::BoxType& type, + const BoxContents& contents); +}; + +struct JxlEncoderFrameSettingsStruct { + JxlEncoder* enc; + jxl::JxlEncoderFrameSettingsValues values; +}; + +struct JxlEncoderStatsStruct { + jxl::AuxOut aux_out; +}; + +#endif // LIB_JXL_ENCODE_INTERNAL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/encode_test.cc b/third_party/jpeg-xl/lib/jxl/encode_test.cc new file mode 100644 index 0000000000..2c17fcab21 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/encode_test.cc @@ -0,0 +1,2083 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> +#include <jxl/cms_interface.h> +#include <jxl/codestream_header.h> +#include <jxl/color_encoding.h> +#include <jxl/decode.h> +#include <jxl/decode_cxx.h> +#include <jxl/encode.h> +#include <jxl/encode_cxx.h> +#include <jxl/memory_manager.h> +#include <jxl/types.h> + +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <cstdlib> +#include <cstring> +#include <mutex> +#include <ostream> +#include <set> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#include "lib/extras/codec.h" +#include "lib/extras/dec/jxl.h" +#include "lib/extras/metrics.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/c_callback_support.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" // JXL_HIGH_PRECISION +#include "lib/jxl/enc_params.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/test_image.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace { +bool SameDecodedPixels(const std::vector<uint8_t>& compressed0, + const std::vector<uint8_t>& compressed1) { + jxl::extras::JXLDecompressParams dparams; + dparams.accepted_formats = { + {3, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}, + {4, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}, + }; + jxl::extras::PackedPixelFile ppf0; + EXPECT_TRUE(DecodeImageJXL(compressed0.data(), compressed0.size(), dparams, + nullptr, &ppf0, nullptr)); + jxl::extras::PackedPixelFile ppf1; + EXPECT_TRUE(DecodeImageJXL(compressed1.data(), compressed1.size(), dparams, + nullptr, &ppf1, nullptr)); + return jxl::test::SamePixels(ppf0, ppf1); +} +} // namespace + +TEST(EncodeTest, AddFrameAfterCloseInputTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderCloseInput(enc.get()); + + size_t xsize = 64; + size_t ysize = 64; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + jxl::CodecInOut input_io = + jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); +} + +TEST(EncodeTest, AddJPEGAfterCloseTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderCloseInput(enc.get()); + + const std::string jpeg_path = "jxl/flower/flower.png.im_q85_420.jpg"; + const std::vector<uint8_t> orig = jxl::test::ReadTestData(jpeg_path); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderAddJPEGFrame(frame_settings, orig.data(), orig.size())); +} + +TEST(EncodeTest, AddFrameBeforeBasicInfoTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + size_t xsize = 64; + size_t ysize = 64; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + jxl::CodecInOut input_io = + jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize); + + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); +} + +TEST(EncodeTest, DefaultAllocTest) { + JxlEncoder* enc = JxlEncoderCreate(nullptr); + EXPECT_NE(nullptr, enc); + JxlEncoderDestroy(enc); +} + +TEST(EncodeTest, CustomAllocTest) { + struct CalledCounters { + int allocs = 0; + int frees = 0; + } counters; + + JxlMemoryManager mm; + mm.opaque = &counters; + mm.alloc = [](void* opaque, size_t size) { + reinterpret_cast<CalledCounters*>(opaque)->allocs++; + return malloc(size); + }; + mm.free = [](void* opaque, void* address) { + reinterpret_cast<CalledCounters*>(opaque)->frees++; + free(address); + }; + + { + JxlEncoderPtr enc = JxlEncoderMake(&mm); + EXPECT_NE(nullptr, enc.get()); + EXPECT_LE(1, counters.allocs); + EXPECT_EQ(0, counters.frees); + } + EXPECT_LE(1, counters.frees); +} + +TEST(EncodeTest, DefaultParallelRunnerTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetParallelRunner(enc.get(), nullptr, nullptr)); +} + +void VerifyFrameEncoding(size_t xsize, size_t ysize, JxlEncoder* enc, + const JxlEncoderFrameSettings* frame_settings, + size_t max_compressed_size, + bool lossy_use_original_profile) { + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + jxl::CodecInOut input_io = + jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + if (frame_settings->values.lossless || lossy_use_original_profile) { + basic_info.uses_original_profile = true; + } else { + basic_info.uses_original_profile = false; + } + // 16-bit alpha means this requires level 10 + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc, 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, true); + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetColorEncoding(enc, &color_encoding)); + JxlColorEncodingSetToSRGB(&color_encoding, false); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetColorEncoding(enc, &color_encoding)); + pixel_format.num_channels = 1; + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + pixel_format.num_channels = 4; + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseInput(enc); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc, &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_LE(compressed.size(), max_compressed_size); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + jxl::CodecInOut decoded_io; + EXPECT_TRUE(jxl::test::DecodeFile( + {}, jxl::Bytes(compressed.data(), compressed.size()), &decoded_io)); + + EXPECT_LE( + ComputeDistance2(input_io.Main(), decoded_io.Main(), *JxlGetDefaultCms()), +#if JXL_HIGH_PRECISION + 1.84); +#else + 8.7); +#endif +} + +void VerifyFrameEncoding(JxlEncoder* enc, + const JxlEncoderFrameSettings* frame_settings) { + VerifyFrameEncoding(63, 129, enc, frame_settings, 2700, + /*lossy_use_original_profile=*/false); +} + +TEST(EncodeTest, FrameEncodingTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + VerifyFrameEncoding(enc.get(), + JxlEncoderFrameSettingsCreate(enc.get(), nullptr)); +} + +TEST(EncodeTest, EncoderResetTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + VerifyFrameEncoding(50, 200, enc.get(), + JxlEncoderFrameSettingsCreate(enc.get(), nullptr), 4300, + false); + // Encoder should become reusable for a new image from scratch after using + // reset. + JxlEncoderReset(enc.get()); + VerifyFrameEncoding(157, 77, enc.get(), + JxlEncoderFrameSettingsCreate(enc.get(), nullptr), 2300, + false); +} + +TEST(EncodeTest, CmsTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + bool cms_called = false; + JxlCmsInterface cms = *JxlGetDefaultCms(); + struct InitData { + void* original_init_data; + jpegxl_cms_init_func original_init; + bool* cms_called; + }; + InitData init_data = {/*original_init_data=*/cms.init_data, + /*original_init=*/cms.init, + /*cms_called=*/&cms_called}; + cms.init_data = &init_data; + cms.init = +[](void* raw_init_data, size_t num_threads, + size_t pixels_per_thread, const JxlColorProfile* input_profile, + const JxlColorProfile* output_profile, + float intensity_target) { + const InitData* init_data = static_cast<const InitData*>(raw_init_data); + *init_data->cms_called = true; + return init_data->original_init(init_data->original_init_data, num_threads, + pixels_per_thread, input_profile, + output_profile, intensity_target); + }; + JxlEncoderSetCms(enc.get(), cms); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), nullptr); + JxlEncoderSetFrameLossless(frame_settings, false); + ASSERT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption(frame_settings, + JXL_ENC_FRAME_SETTING_EFFORT, 8)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_TRUE(cms_called); +} + +TEST(EncodeTest, frame_settingsTest) { + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, 5)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(jxl::SpeedTier::kHare, enc->last_used_cparams.speed_tier); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + const size_t nb_options = 23; + const JxlEncoderFrameSettingId options[nb_options] = { + JXL_ENC_FRAME_SETTING_EFFORT, + JXL_ENC_FRAME_SETTING_BROTLI_EFFORT, + JXL_ENC_FRAME_SETTING_DECODING_SPEED, + JXL_ENC_FRAME_SETTING_RESAMPLING, + JXL_ENC_FRAME_SETTING_EXTRA_CHANNEL_RESAMPLING, + JXL_ENC_FRAME_SETTING_ALREADY_DOWNSAMPLED, + JXL_ENC_FRAME_SETTING_EPF, + JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_X, + JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_Y, + JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC, + JXL_ENC_FRAME_SETTING_PALETTE_COLORS, + JXL_ENC_FRAME_SETTING_COLOR_TRANSFORM, + JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE, + JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE, + JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR, + JXL_ENC_FRAME_SETTING_MODULAR_NB_PREV_CHANNELS, + JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL, + JXL_ENC_FRAME_INDEX_BOX, + JXL_ENC_FRAME_SETTING_JPEG_COMPRESS_BOXES, + JXL_ENC_FRAME_SETTING_BUFFERING, + JXL_ENC_FRAME_SETTING_JPEG_KEEP_EXIF, + JXL_ENC_FRAME_SETTING_JPEG_KEEP_XMP, + JXL_ENC_FRAME_SETTING_JPEG_KEEP_JUMBF}; + const int too_low[nb_options] = {0, -2, -2, 3, -2, -2, -2, -2, + -2, -2, -2, -2, -2, -2, -2, -2, + -2, -1, -2, -1, -2, -2, -2}; + const int too_high[nb_options] = {11, 12, 5, 16, 6, 2, 4, -3, + -3, 3, 70914, 3, 42, 4, 16, 12, + 2, 2, 2, 4, 2, 2, 2}; + const int in_range[nb_options] = {5, 5, 3, 1, 1, 1, 3, -1, + 0, 1, -1, -1, 3, 2, 15, -1, + -1, 1, 0, 0, -1, -1, -1}; + for (size_t i = 0; i < nb_options; i++) { + // Lower than currently supported values + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderFrameSettingsSetOption( + frame_settings, options[i], too_low[i])); + // Higher than currently supported values + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderFrameSettingsSetOption( + frame_settings, options[i], too_high[i])); + // Using SetFloatOption on integer options + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderFrameSettingsSetFloatOption( + frame_settings, options[i], 1.0f)); + // Within range of the currently supported values + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderFrameSettingsSetOption( + frame_settings, options[i], in_range[i])); + } + // Effort 10 should only work when expert options are allowed + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, 10)); + JxlEncoderAllowExpertOptions(enc.get()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, 10)); + + // Non-existing option + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_FILL_ENUM, 0)); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, JXL_ENC_FRAME_SETTING_FILL_ENUM, 0.f)); + + // Float options + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, JXL_ENC_FRAME_SETTING_PHOTON_NOISE, -1.0f)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, JXL_ENC_FRAME_SETTING_PHOTON_NOISE, 100.0f)); + EXPECT_EQ( + JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_MODULAR_MA_TREE_LEARNING_PERCENT, 101.0f)); + EXPECT_EQ( + JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_MODULAR_MA_TREE_LEARNING_PERCENT, -2.0f)); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_MODULAR_MA_TREE_LEARNING_PERCENT, -1.0f)); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GLOBAL_PERCENT, 101.0f)); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GLOBAL_PERCENT, -2.0f)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GLOBAL_PERCENT, -1.0f)); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT, 101.0f)); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT, -2.0f)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT, -1.0f)); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT, 50.0f)); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_PHOTON_NOISE, 50.0f)); + + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 2500, false); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetFrameLossless(frame_settings, JXL_TRUE)); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 3000, false); + EXPECT_EQ(true, enc->last_used_cparams.IsLossless()); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetFrameDistance(frame_settings, 0.5)); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 3030, false); + EXPECT_EQ(0.5, enc->last_used_cparams.butteraugli_distance); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + // Disallowed negative distance + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetFrameDistance(frame_settings, -1)); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_DECODING_SPEED, 2)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(2u, enc->last_used_cparams.decoding_speed_tier); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_ERROR, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_GROUP_ORDER, 100)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_GROUP_ORDER, 1)); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_GROUP_ORDER_CENTER_X, 5)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(true, enc->last_used_cparams.centerfirst); + EXPECT_EQ(5, enc->last_used_cparams.center_x); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_RESPONSIVE, 0)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC, 1)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_QPROGRESSIVE_AC, -1)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC, 2)); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 2830, + /*lossy_use_original_profile=*/false); + EXPECT_EQ(false, enc->last_used_cparams.responsive); + EXPECT_EQ(jxl::Override::kOn, enc->last_used_cparams.progressive_mode); + EXPECT_EQ(2, enc->last_used_cparams.progressive_dc); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, JXL_ENC_FRAME_SETTING_PHOTON_NOISE, 1777.777)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_NEAR(1777.777f, enc->last_used_cparams.photon_noise_iso, 1E-4); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GLOBAL_PERCENT, 55.0f)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_CHANNEL_COLORS_GROUP_PERCENT, 25.0f)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_PALETTE_COLORS, 70000)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_LOSSY_PALETTE, 1)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_NEAR(55.0f, + enc->last_used_cparams.channel_colors_pre_transform_percent, + 1E-6); + EXPECT_NEAR(25.0f, enc->last_used_cparams.channel_colors_percent, 1E-6); + EXPECT_EQ(70000, enc->last_used_cparams.palette_colors); + EXPECT_EQ(true, enc->last_used_cparams.lossy_palette); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE, 30)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE, 2)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR, 14)); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetFloatOption( + frame_settings, + JXL_ENC_FRAME_SETTING_MODULAR_MA_TREE_LEARNING_PERCENT, 77.0f)); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_MODULAR_NB_PREV_CHANNELS, 7)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(30, enc->last_used_cparams.colorspace); + EXPECT_EQ(2, enc->last_used_cparams.modular_group_size_shift); + EXPECT_EQ(jxl::Predictor::Best, enc->last_used_cparams.options.predictor); + EXPECT_NEAR(0.77f, enc->last_used_cparams.options.nb_repeats, 1E-6); + EXPECT_EQ(7, enc->last_used_cparams.options.max_properties); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL, 0)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(false, enc->last_used_cparams.force_cfl_jpeg_recompression); + } + + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_JPEG_RECON_CFL, 1)); + VerifyFrameEncoding(enc.get(), frame_settings); + EXPECT_EQ(true, enc->last_used_cparams.force_cfl_jpeg_recompression); + } +} + +TEST(EncodeTest, LossyEncoderUseOriginalProfileTest) { + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 7897, true); + } + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC, 2)); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 8310, true); + } + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + ASSERT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, 8)); + VerifyFrameEncoding(63, 129, enc.get(), frame_settings, 7228, true); + } +} + +namespace { +// Returns a copy of buf from offset to offset+size, or a new zeroed vector if +// the result would have been out of bounds taking integer overflow into +// account. +std::vector<uint8_t> SliceSpan(const jxl::Span<const uint8_t>& buf, + size_t offset, size_t size) { + if (offset + size >= buf.size()) { + return std::vector<uint8_t>(size, 0); + } + if (offset + size < offset) { + return std::vector<uint8_t>(size, 0); + } + return std::vector<uint8_t>(buf.data() + offset, buf.data() + offset + size); +} + +struct Box { + // The type of the box. + // If "uuid", use extended_type instead + char type[4] = {0, 0, 0, 0}; + + // The extended_type is only used when type == "uuid". + // Extended types are not used in JXL. However, the box format itself + // supports this so they are handled correctly. + char extended_type[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + // Box data. + jxl::Span<const uint8_t> data = jxl::Bytes(nullptr, 0); + + // If the size is not given, the datasize extends to the end of the file. + // If this field is false, the size field is not encoded when the box is + // serialized. + bool data_size_given = true; + + // If successful, returns true and sets `in` to be the rest data (if any). + // If `in` contains a box with a size larger than `in.size()`, will not + // modify `in`, and will return true but the data `Span<uint8_t>` will + // remain set to nullptr. + // If unsuccessful, returns error and doesn't modify `in`. + jxl::Status Decode(jxl::Span<const uint8_t>* in) { + // Total box_size including this header itself. + uint64_t box_size = LoadBE32(SliceSpan(*in, 0, 4).data()); + size_t pos = 4; + + memcpy(type, SliceSpan(*in, pos, 4).data(), 4); + pos += 4; + + if (box_size == 1) { + // If the size is 1, it indicates extended size read from 64-bit integer. + box_size = LoadBE64(SliceSpan(*in, pos, 8).data()); + pos += 8; + } + + if (!memcmp("uuid", type, 4)) { + memcpy(extended_type, SliceSpan(*in, pos, 16).data(), 16); + pos += 16; + } + + // This is the end of the box header, the box data begins here. Handle + // the data size now. + const size_t header_size = pos; + + if (box_size != 0) { + if (box_size < header_size) { + return JXL_FAILURE("Invalid box size"); + } + if (box_size > in->size()) { + // The box is fine, but the input is too short. + return true; + } + data_size_given = true; + data = jxl::Bytes(in->data() + header_size, box_size - header_size); + } else { + data_size_given = false; + data = jxl::Bytes(in->data() + header_size, in->size() - header_size); + } + + *in = jxl::Bytes(in->data() + header_size + data.size(), + in->size() - header_size - data.size()); + return true; + } +}; + +struct Container { + std::vector<Box> boxes; + + // If successful, returns true and sets `in` to be the rest data (if any). + // If unsuccessful, returns error and doesn't modify `in`. + jxl::Status Decode(jxl::Span<const uint8_t>* in) { + boxes.clear(); + + Box signature_box; + JXL_RETURN_IF_ERROR(signature_box.Decode(in)); + if (memcmp("JXL ", signature_box.type, 4) != 0) { + return JXL_FAILURE("Invalid magic signature"); + } + if (signature_box.data.size() != 4) + return JXL_FAILURE("Invalid magic signature"); + if (signature_box.data[0] != 0xd || signature_box.data[1] != 0xa || + signature_box.data[2] != 0x87 || signature_box.data[3] != 0xa) { + return JXL_FAILURE("Invalid magic signature"); + } + + Box ftyp_box; + JXL_RETURN_IF_ERROR(ftyp_box.Decode(in)); + if (memcmp("ftyp", ftyp_box.type, 4) != 0) { + return JXL_FAILURE("Invalid ftyp"); + } + if (ftyp_box.data.size() != 12) return JXL_FAILURE("Invalid ftyp"); + const char* expected = "jxl \0\0\0\0jxl "; + if (memcmp(expected, ftyp_box.data.data(), 12) != 0) + return JXL_FAILURE("Invalid ftyp"); + + while (!in->empty()) { + Box box = {}; + JXL_RETURN_IF_ERROR(box.Decode(in)); + if (box.data.data() == nullptr) { + // The decoding encountered a box, but not enough data yet. + return true; + } + boxes.emplace_back(box); + } + + return true; + } +}; + +} // namespace + +TEST(EncodeTest, SingleFrameBoundedJXLCTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc.get(), true)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + size_t xsize = 71; + size_t ysize = 23; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + + Container container = {}; + jxl::Span<const uint8_t> encoded_span = + jxl::Bytes(compressed.data(), compressed.size()); + EXPECT_TRUE(container.Decode(&encoded_span)); + EXPECT_EQ(0u, encoded_span.size()); + bool found_jxlc = false; + bool found_jxlp = false; + // The encoder is allowed to either emit a jxlc or one or more jxlp. + for (size_t i = 0; i < container.boxes.size(); ++i) { + if (memcmp("jxlc", container.boxes[i].type, 4) == 0) { + EXPECT_EQ(false, found_jxlc); // Max 1 jxlc + EXPECT_EQ(false, found_jxlp); // Can't mix jxlc and jxlp + found_jxlc = true; + } + if (memcmp("jxlp", container.boxes[i].type, 4) == 0) { + EXPECT_EQ(false, found_jxlc); // Can't mix jxlc and jxlp + found_jxlp = true; + } + // The encoder shouldn't create an unbounded box in this case, with the + // single frame it knows the full size in time, so can help make decoding + // more efficient by giving the full box size of the final box. + EXPECT_EQ(true, container.boxes[i].data_size_given); + } + EXPECT_EQ(true, found_jxlc || found_jxlp); +} + +TEST(EncodeTest, CodestreamLevelTest) { + size_t xsize = 64; + size_t ysize = 64; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + + jxl::CodecInOut input_io = + jxl::test::SomeTestImageToCodecInOut(pixels, 4, xsize, ysize); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + + Container container = {}; + jxl::Span<const uint8_t> encoded_span = + jxl::Bytes(compressed.data(), compressed.size()); + EXPECT_TRUE(container.Decode(&encoded_span)); + EXPECT_EQ(0u, encoded_span.size()); + EXPECT_EQ(0, memcmp("jxll", container.boxes[0].type, 4)); +} + +TEST(EncodeTest, CodestreamLevelVerificationTest) { + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT8, JXL_BIG_ENDIAN, 0}; + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = 64; + basic_info.ysize = 64; + basic_info.uses_original_profile = false; + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + + EXPECT_EQ(5, JxlEncoderGetRequiredCodestreamLevel(enc.get())); + + // Set an image dimension that is too large for level 5, but fits in level 10 + + basic_info.xsize = 1ull << 30ull; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 5)); + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + EXPECT_EQ(10, JxlEncoderGetRequiredCodestreamLevel(enc.get())); + + // Set an image dimension that is too large even for level 10 + + basic_info.xsize = 1ull << 31ull; + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); +} + +TEST(EncodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGReconstructionTest)) { + const std::string jpeg_path = "jxl/flower/flower.png.im_q85_420.jpg"; + const std::vector<uint8_t> orig = jxl::test::ReadTestData(jpeg_path); + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderStoreJPEGMetadata(enc.get(), JXL_TRUE)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddJPEGFrame(frame_settings, orig.data(), orig.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + + jxl::extras::JXLDecompressParams dparams; + jxl::test::DefaultAcceptedFormats(dparams); + std::vector<uint8_t> decoded_jpeg_bytes; + jxl::extras::PackedPixelFile ppf; + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), dparams, + nullptr, &ppf, &decoded_jpeg_bytes)); + + EXPECT_EQ(decoded_jpeg_bytes.size(), orig.size()); + EXPECT_EQ(0, memcmp(decoded_jpeg_bytes.data(), orig.data(), orig.size())); +} + +TEST(EncodeTest, JXL_TRANSCODE_JPEG_TEST(ProgressiveJPEGReconstructionTest)) { + const std::string jpeg_path = "jxl/flower/flower.png.im_q85_420.jpg"; + const std::vector<uint8_t> orig = jxl::test::ReadTestData(jpeg_path); + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + frame_settings->values.cparams.progressive_mode = jxl::Override::kOn; + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderStoreJPEGMetadata(enc.get(), JXL_TRUE)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddJPEGFrame(frame_settings, orig.data(), orig.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + + jxl::extras::JXLDecompressParams dparams; + jxl::test::DefaultAcceptedFormats(dparams); + std::vector<uint8_t> decoded_jpeg_bytes; + jxl::extras::PackedPixelFile ppf; + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), dparams, + nullptr, &ppf, &decoded_jpeg_bytes)); + + EXPECT_EQ(decoded_jpeg_bytes.size(), orig.size()); + EXPECT_EQ(0, memcmp(decoded_jpeg_bytes.data(), orig.data(), orig.size())); +} + +static void ProcessEncoder(JxlEncoder* enc, std::vector<uint8_t>& compressed, + uint8_t*& next_out, size_t& avail_out) { + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc, &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + size_t offset = next_out - compressed.data(); + compressed.resize(next_out - compressed.data()); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); +} + +TEST(EncodeTest, BasicInfoTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + size_t xsize = 1; + size_t ysize = 1; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + basic_info.have_animation = true; + basic_info.intensity_target = 123.4; + basic_info.min_nits = 5.0; + basic_info.linear_below = 12.7; + basic_info.orientation = JXL_ORIENT_ROTATE_90_CW; + basic_info.intrinsic_xsize = 88; + basic_info.intrinsic_ysize = 99; + basic_info.animation.tps_numerator = 55; + basic_info.animation.tps_denominator = 77; + basic_info.animation.num_loops = 10; + basic_info.animation.have_timecodes = JXL_TRUE; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseFrames(enc.get()); + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Decode to verify the boxes, we don't decode to pixels, only the boxes. + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_NE(nullptr, dec.get()); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_BASIC_INFO)); + // Allow testing the orientation field, without this setting it will be + // overridden to identity. + JxlDecoderSetKeepOrientation(dec.get(), JXL_TRUE); + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + JxlDecoderCloseInput(dec.get()); + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_SUCCESS) { + break; + } else if (status == JXL_DEC_BASIC_INFO) { + JxlBasicInfo basic_info2; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetBasicInfo(dec.get(), &basic_info2)); + EXPECT_EQ(basic_info.xsize, basic_info2.xsize); + EXPECT_EQ(basic_info.ysize, basic_info2.ysize); + EXPECT_EQ(basic_info.bits_per_sample, basic_info2.bits_per_sample); + EXPECT_EQ(basic_info.exponent_bits_per_sample, + basic_info2.exponent_bits_per_sample); + EXPECT_NEAR(basic_info.intensity_target, basic_info2.intensity_target, + 0.5); + EXPECT_NEAR(basic_info.min_nits, basic_info2.min_nits, 0.5); + EXPECT_NEAR(basic_info.linear_below, basic_info2.linear_below, 0.5); + EXPECT_EQ(basic_info.relative_to_max_display, + basic_info2.relative_to_max_display); + EXPECT_EQ(basic_info.uses_original_profile, + basic_info2.uses_original_profile); + EXPECT_EQ(basic_info.orientation, basic_info2.orientation); + EXPECT_EQ(basic_info.intrinsic_xsize, basic_info2.intrinsic_xsize); + EXPECT_EQ(basic_info.intrinsic_ysize, basic_info2.intrinsic_ysize); + EXPECT_EQ(basic_info.num_color_channels, basic_info2.num_color_channels); + // TODO(lode): also test num_extra_channels, but currently there may be a + // mismatch between 0 and 1 if there is alpha, until encoder support for + // extra channels is fully implemented. + EXPECT_EQ(basic_info.alpha_bits, basic_info2.alpha_bits); + EXPECT_EQ(basic_info.alpha_exponent_bits, + basic_info2.alpha_exponent_bits); + EXPECT_EQ(basic_info.alpha_premultiplied, + basic_info2.alpha_premultiplied); + + EXPECT_EQ(basic_info.have_preview, basic_info2.have_preview); + if (basic_info.have_preview) { + EXPECT_EQ(basic_info.preview.xsize, basic_info2.preview.xsize); + EXPECT_EQ(basic_info.preview.ysize, basic_info2.preview.ysize); + } + + EXPECT_EQ(basic_info.have_animation, basic_info2.have_animation); + if (basic_info.have_animation) { + EXPECT_EQ(basic_info.animation.tps_numerator, + basic_info2.animation.tps_numerator); + EXPECT_EQ(basic_info.animation.tps_denominator, + basic_info2.animation.tps_denominator); + EXPECT_EQ(basic_info.animation.num_loops, + basic_info2.animation.num_loops); + EXPECT_EQ(basic_info.animation.have_timecodes, + basic_info2.animation.have_timecodes); + } + } else { + FAIL(); // unexpected status + } + } +} + +TEST(EncodeTest, AnimationHeaderTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + size_t xsize = 1; + size_t ysize = 1; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.have_animation = true; + basic_info.animation.tps_numerator = 1000; + basic_info.animation.tps_denominator = 1; + basic_info.animation.have_timecodes = JXL_TRUE; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + + std::string frame_name = "test frame"; + JxlFrameHeader header; + JxlEncoderInitFrameHeader(&header); + header.duration = 50; + header.timecode = 800; + header.layer_info.blend_info.blendmode = JXL_BLEND_BLEND; + header.layer_info.blend_info.source = 2; + header.layer_info.blend_info.clamp = 1; + JxlBlendInfo extra_channel_blend_info; + JxlEncoderInitBlendInfo(&extra_channel_blend_info); + extra_channel_blend_info.blendmode = JXL_BLEND_MULADD; + JxlEncoderSetFrameHeader(frame_settings, &header); + JxlEncoderSetExtraChannelBlendInfo(frame_settings, 0, + &extra_channel_blend_info); + JxlEncoderSetFrameName(frame_settings, frame_name.c_str()); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseFrames(enc.get()); + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Decode to verify the boxes, we don't decode to pixels, only the boxes. + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_NE(nullptr, dec.get()); + + // To test the blend_info fields, coalescing must be set to false in the + // decoder. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec.get(), JXL_FALSE)); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_FRAME)); + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + JxlDecoderCloseInput(dec.get()); + + bool seen_frame = false; + + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_SUCCESS) { + break; + } else if (status == JXL_DEC_FRAME) { + seen_frame = true; + JxlFrameHeader header2; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec.get(), &header2)); + EXPECT_EQ(header.duration, header2.duration); + EXPECT_EQ(header.timecode, header2.timecode); + EXPECT_EQ(header.layer_info.blend_info.blendmode, + header2.layer_info.blend_info.blendmode); + EXPECT_EQ(header.layer_info.blend_info.clamp, + header2.layer_info.blend_info.clamp); + EXPECT_EQ(header.layer_info.blend_info.source, + header2.layer_info.blend_info.source); + EXPECT_EQ(frame_name.size(), header2.name_length); + JxlBlendInfo extra_channel_blend_info2; + JxlDecoderGetExtraChannelBlendInfo(dec.get(), 0, + &extra_channel_blend_info2); + EXPECT_EQ(extra_channel_blend_info.blendmode, + extra_channel_blend_info2.blendmode); + if (header2.name_length > 0) { + std::string frame_name2(header2.name_length + 1, '\0'); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetFrameName(dec.get(), &frame_name2.front(), + frame_name2.size())); + frame_name2.resize(header2.name_length); + EXPECT_EQ(frame_name, frame_name2); + } + } else { + FAIL(); // unexpected status + } + } + + EXPECT_EQ(true, seen_frame); +} +TEST(EncodeTest, CroppedFrameTest) { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + size_t xsize = 300; + size_t ysize = 300; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + std::vector<uint8_t> pixels2(pixels.size()); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + // Encoding a 300x300 frame in an image that is only 100x100 + basic_info.xsize = 100; + basic_info.ysize = 100; + basic_info.uses_original_profile = JXL_TRUE; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + + JxlFrameHeader header; + JxlEncoderInitFrameHeader(&header); + header.layer_info.have_crop = JXL_TRUE; + header.layer_info.xsize = xsize; + header.layer_info.ysize = ysize; + header.layer_info.crop_x0 = -50; + header.layer_info.crop_y0 = -250; + JxlEncoderSetFrameLossless(frame_settings, JXL_TRUE); + JxlEncoderSetFrameHeader(frame_settings, &header); + JxlEncoderFrameSettingsSetOption(frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, + 1); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(100); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + JxlEncoderCloseFrames(enc.get()); + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_NE(nullptr, dec.get()); + // Non-coalesced decoding so we can get the full uncropped frame + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetCoalescing(dec.get(), JXL_FALSE)); + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_FRAME | JXL_DEC_FULL_IMAGE)); + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + JxlDecoderCloseInput(dec.get()); + + bool seen_frame = false; + bool checked_frame = false; + for (;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_SUCCESS) { + break; + } else if (status == JXL_DEC_FRAME) { + seen_frame = true; + JxlFrameHeader header2; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetFrameHeader(dec.get(), &header2)); + EXPECT_EQ(header.layer_info.xsize, header2.layer_info.xsize); + EXPECT_EQ(header.layer_info.ysize, header2.layer_info.ysize); + EXPECT_EQ(header.layer_info.crop_x0, header2.layer_info.crop_x0); + EXPECT_EQ(header.layer_info.crop_y0, header2.layer_info.crop_y0); + } else if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec.get(), &pixel_format, + pixels2.data(), pixels2.size())); + } else if (status == JXL_DEC_FULL_IMAGE) { + EXPECT_EQ(0, memcmp(pixels.data(), pixels2.data(), pixels.size())); + checked_frame = true; + } else { + FAIL(); // unexpected status + } + } + EXPECT_EQ(true, checked_frame); + EXPECT_EQ(true, seen_frame); +} + +struct EncodeBoxTest : public testing::TestWithParam<std::tuple<bool, size_t>> { +}; + +TEST_P(EncodeBoxTest, JXL_BOXES_TEST(BoxTest)) { + // Test with uncompressed boxes and with brob boxes + bool compress_box = std::get<0>(GetParam()); + size_t xml_box_size = std::get<1>(GetParam()); + // TODO(firsching): use xml_box_size + (void)xml_box_size; + // Tests adding two metadata boxes with the encoder: an exif box before the + // image frame, and an xml box after the image frame. Then verifies the + // decoder can decode them, they are in the expected place, and have the + // correct content after decoding. + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + EXPECT_NE(nullptr, enc.get()); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseBoxes(enc.get())); + + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + size_t xsize = 50; + size_t ysize = 17; + JxlPixelFormat pixel_format = {4, JXL_TYPE_UINT16, JXL_BIG_ENDIAN, 0}; + std::vector<uint8_t> pixels = jxl::test::GetSomeTestImage(xsize, ysize, 4, 0); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc.get(), 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + + // Add an early metadata box. Also add a valid 4-byte TIFF offset header + // before the fake exif data of these box contents. + constexpr const char* exif_test_string = "\0\0\0\0exif test data"; + const uint8_t* exif_data = reinterpret_cast<const uint8_t*>(exif_test_string); + // Skip the 4 zeroes for strlen + const size_t exif_size = 4 + strlen(exif_test_string + 4); + JxlEncoderAddBox(enc.get(), "Exif", exif_data, exif_size, compress_box); + + // Write to output + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Add image frame + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + pixels.data(), pixels.size())); + // Indicate this is the last frame + JxlEncoderCloseFrames(enc.get()); + + // Write to output + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Add a late metadata box + constexpr const char* xml_test_string = "<some random xml data>"; + const uint8_t* xml_data = reinterpret_cast<const uint8_t*>(xml_test_string); + size_t xml_size = strlen(xml_test_string); + JxlEncoderAddBox(enc.get(), "XML ", xml_data, xml_size, compress_box); + + // Indicate this is the last box + JxlEncoderCloseBoxes(enc.get()); + + // Write to output + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + + // Decode to verify the boxes, we don't decode to pixels, only the boxes. + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_NE(nullptr, dec.get()); + + if (compress_box) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetDecompressBoxes(dec.get(), JXL_TRUE)); + } + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec.get(), JXL_DEC_FRAME | JXL_DEC_BOX)); + + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + JxlDecoderCloseInput(dec.get()); + + std::vector<uint8_t> dec_exif_box(exif_size); + std::vector<uint8_t> dec_xml_box(xml_size); + + for (bool post_frame = false;;) { + JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); + if (status == JXL_DEC_ERROR) { + FAIL(); + } else if (status == JXL_DEC_SUCCESS) { + EXPECT_EQ(0, JxlDecoderReleaseBoxBuffer(dec.get())); + break; + } else if (status == JXL_DEC_FRAME) { + post_frame = true; + } else if (status == JXL_DEC_BOX) { + // Since we gave the exif/xml box output buffer of the exact known + // correct size, 0 bytes should be released. Same when no buffer was + // set. + EXPECT_EQ(0, JxlDecoderReleaseBoxBuffer(dec.get())); + JxlBoxType type; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBoxType(dec.get(), type, true)); + if (!memcmp(type, "Exif", 4)) { + // This box should have been encoded before the image frame + EXPECT_EQ(false, post_frame); + JxlDecoderSetBoxBuffer(dec.get(), dec_exif_box.data(), + dec_exif_box.size()); + } else if (!memcmp(type, "XML ", 4)) { + // This box should have been encoded after the image frame + EXPECT_EQ(true, post_frame); + JxlDecoderSetBoxBuffer(dec.get(), dec_xml_box.data(), + dec_xml_box.size()); + } + } else { + FAIL(); // unexpected status + } + } + + EXPECT_EQ(0, memcmp(exif_data, dec_exif_box.data(), exif_size)); + EXPECT_EQ(0, memcmp(xml_data, dec_xml_box.data(), xml_size)); +} + +std::string nameBoxTest( + const ::testing::TestParamInfo<std::tuple<bool, size_t>>& info) { + return (std::get<0>(info.param) ? "C" : "Unc") + std::string("ompressed") + + "_BoxSize_" + std::to_string((std::get<1>(info.param))); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + EncodeBoxParamsTest, EncodeBoxTest, + testing::Combine(testing::Values(false, true), + testing::Values(256, + jxl::kLargeBoxContentSizeThreshold + 77)), + nameBoxTest); + +TEST(EncodeTest, JXL_TRANSCODE_JPEG_TEST(JPEGFrameTest)) { + TEST_LIBJPEG_SUPPORT(); + for (int skip_basic_info = 0; skip_basic_info < 2; skip_basic_info++) { + for (int skip_color_encoding = 0; skip_color_encoding < 2; + skip_color_encoding++) { + // cannot set color encoding if basic info is not set + if (skip_basic_info && !skip_color_encoding) continue; + const std::string jpeg_path = "jxl/flower/flower_cropped.jpg"; + const std::vector<uint8_t> orig = jxl::test::ReadTestData(jpeg_path); + jxl::CodecInOut orig_io; + ASSERT_TRUE(SetFromBytes(jxl::Bytes(orig), &orig_io, + /*pool=*/nullptr)); + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + JxlEncoderFrameSettingsSetOption(frame_settings, + JXL_ENC_FRAME_SETTING_EFFORT, 1); + if (!skip_basic_info) { + JxlBasicInfo basic_info; + JxlEncoderInitBasicInfo(&basic_info); + basic_info.xsize = orig_io.xsize(); + basic_info.ysize = orig_io.ysize(); + basic_info.uses_original_profile = true; + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetBasicInfo(enc.get(), &basic_info)); + } + if (!skip_color_encoding) { + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc.get(), &color_encoding)); + } + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderAddJPEGFrame( + frame_settings, orig.data(), orig.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size() - (next_out - compressed.data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = + JxlEncoderProcessOutput(enc.get(), &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed.data(); + compressed.resize(compressed.size() * 2); + next_out = compressed.data() + offset; + avail_out = compressed.size() - offset; + } + } + compressed.resize(next_out - compressed.data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); + + jxl::CodecInOut decoded_io; + EXPECT_TRUE(jxl::test::DecodeFile( + {}, jxl::Bytes(compressed.data(), compressed.size()), &decoded_io)); + + EXPECT_LE(ComputeDistance2(orig_io.Main(), decoded_io.Main(), + *JxlGetDefaultCms()), + 3.5); + } + } +} + +namespace { +class JxlStreamingAdapter { + public: + JxlStreamingAdapter(JxlEncoder* encoder, bool return_large_buffers, + bool can_seek) + : return_large_buffers_(return_large_buffers) { + struct JxlEncoderOutputProcessor output_processor; + output_processor.opaque = this; + output_processor.get_buffer = + METHOD_TO_C_CALLBACK(&JxlStreamingAdapter::GetBuffer); + if (can_seek) { + output_processor.seek = METHOD_TO_C_CALLBACK(&JxlStreamingAdapter::Seek); + } else { + output_processor.seek = nullptr; + } + output_processor.set_finalized_position = + METHOD_TO_C_CALLBACK(&JxlStreamingAdapter::SetFinalizedPosition); + output_processor.release_buffer = + METHOD_TO_C_CALLBACK(&JxlStreamingAdapter::ReleaseBuffer); + EXPECT_EQ(JxlEncoderSetOutputProcessor(encoder, output_processor), + JXL_ENC_SUCCESS); + } + + std::vector<uint8_t> output() && { + output_.resize(position_); + return std::move(output_); + } + + void* GetBuffer(size_t* size) { + if (!return_large_buffers_) { + *size = 1; + } + if (position_ + *size > output_.size()) { + output_.resize(position_ + *size, 0xDA); + } + if (return_large_buffers_) { + *size = output_.size() - position_; + } + return output_.data() + position_; + } + + void ReleaseBuffer(size_t written_bytes) { + // TODO(veluca): check no more bytes were written. + Seek(position_ + written_bytes); + } + + void Seek(uint64_t position) { + EXPECT_GE(position, finalized_position_); + position_ = position; + } + + void SetFinalizedPosition(uint64_t finalized_position) { + EXPECT_GE(finalized_position, finalized_position_); + finalized_position_ = finalized_position; + EXPECT_GE(position_, finalized_position_); + } + + void CheckFinalWatermarkPosition() const { + EXPECT_EQ(finalized_position_, position_); + } + + private: + std::vector<uint8_t> output_; + size_t position_ = 0; + size_t finalized_position_ = 0; + bool return_large_buffers_; +}; + +class JxlChunkedFrameInputSourceAdapter { + private: + static const void* GetDataAt(const jxl::extras::PackedImage& image, + size_t xpos, size_t ypos, size_t* row_offset) { + JxlDataType data_type = image.format.data_type; + size_t num_channels = image.format.num_channels; + size_t bytes_per_pixel = + num_channels * jxl::extras::PackedImage::BitsPerChannel(data_type) / 8; + *row_offset = image.stride; + return static_cast<uint8_t*>(image.pixels()) + bytes_per_pixel * xpos + + ypos * image.stride; + } + + public: + // Constructor to wrap the image data or any other state + explicit JxlChunkedFrameInputSourceAdapter( + jxl::extras::PackedImage color_channel, + jxl::extras::PackedImage extra_channel) + : colorchannel_(std::move(color_channel)), + extra_channel_(std::move(extra_channel)) {} + ~JxlChunkedFrameInputSourceAdapter() { EXPECT_TRUE(active_buffers_.empty()); } + + void GetColorChannelsPixelFormat(JxlPixelFormat* pixel_format) { + *pixel_format = colorchannel_.format; + } + + const void* GetColorChannelDataAt(size_t xpos, size_t ypos, size_t xsize, + size_t ysize, size_t* row_offset) { + const void* p = GetDataAt(colorchannel_, xpos, ypos, row_offset); + std::lock_guard<std::mutex> lock(mtx_); + active_buffers_.insert(p); + return p; + } + + void GetExtraChannelPixelFormat(size_t ec_index, + JxlPixelFormat* pixel_format) { + // In this test, we we the same color channel data, so `ec_index` is never + // used + *pixel_format = extra_channel_.format; + } + + const void* GetExtraChannelDataAt(size_t ec_index, size_t xpos, size_t ypos, + size_t xsize, size_t ysize, + size_t* row_offset) { + // In this test, we we the same color channel data, so `ec_index` is never + // used + const void* p = GetDataAt(extra_channel_, xpos, ypos, row_offset); + std::lock_guard<std::mutex> lock(mtx_); + active_buffers_.insert(p); + return p; + } + void ReleaseCurrentData(const void* buffer) { + std::lock_guard<std::mutex> lock(mtx_); + auto iter = active_buffers_.find(buffer); + if (iter != active_buffers_.end()) { + active_buffers_.erase(iter); + } + } + + JxlChunkedFrameInputSource GetInputSource() { + return JxlChunkedFrameInputSource{ + this, + METHOD_TO_C_CALLBACK( + &JxlChunkedFrameInputSourceAdapter::GetColorChannelsPixelFormat), + METHOD_TO_C_CALLBACK( + &JxlChunkedFrameInputSourceAdapter::GetColorChannelDataAt), + METHOD_TO_C_CALLBACK( + &JxlChunkedFrameInputSourceAdapter::GetExtraChannelPixelFormat), + METHOD_TO_C_CALLBACK( + &JxlChunkedFrameInputSourceAdapter::GetExtraChannelDataAt), + METHOD_TO_C_CALLBACK( + &JxlChunkedFrameInputSourceAdapter::ReleaseCurrentData)}; + } + + private: + const jxl::extras::PackedImage colorchannel_; + const jxl::extras::PackedImage extra_channel_; + std::mutex mtx_; + std::set<const void*> active_buffers_; +}; + +struct StreamingTestParam { + size_t bitmask; + bool use_container() const { return bitmask & 0x1; } + bool return_large_buffers() const { return bitmask & 0x2; } + bool multiple_frames() const { return bitmask & 0x4; } + bool fast_lossless() const { return bitmask & 0x8; } + bool can_seek() const { return bitmask & 0x10; } + bool with_extra_channels() const { return bitmask & 0x20; } + bool color_includes_alpha() const { return bitmask & 0x40; } + bool onegroup() const { return bitmask & 0x80; } + + bool is_lossless() const { return fast_lossless(); } + + static std::vector<StreamingTestParam> All() { + std::vector<StreamingTestParam> params; + for (size_t bitmask = 0; bitmask < 256; bitmask++) { + params.push_back(StreamingTestParam{bitmask}); + } + return params; + } +}; + +std::ostream& operator<<(std::ostream& out, StreamingTestParam p) { + if (p.use_container()) { + out << "WithContainer_"; + } else { + out << "WithoutContainer_"; + } + if (p.return_large_buffers()) { + out << "WithLargeBuffers_"; + } else { + out << "WithSmallBuffers_"; + } + if (p.multiple_frames()) out << "WithMultipleFrames_"; + if (p.fast_lossless()) out << "FastLossless_"; + if (!p.can_seek()) { + out << "CannotSeek_"; + } else { + out << "CanSeek_"; + } + if (p.with_extra_channels()) { + out << "WithExtraChannels_"; + } else { + out << "WithoutExtraChannels_"; + } + if (p.color_includes_alpha()) { + out << "ColorIncludesAlpha_"; + } else { + out << "ColorWithoutAlpha_"; + } + if (p.onegroup()) { + out << "OneGroup_"; + } else { + out << "MultiGroup_"; + } + return out; +} + +} // namespace + +class EncoderStreamingTest : public testing::TestWithParam<StreamingTestParam> { + public: + static void SetupImage(const StreamingTestParam& p, size_t xsize, + size_t ysize, size_t num_channels, + size_t bits_per_sample, jxl::test::TestImage& image) { + image.SetDimensions(xsize, ysize) + .SetDataType(JXL_TYPE_UINT8) + .SetChannels(num_channels) + .SetAllBitDepths(bits_per_sample); + if (p.onegroup()) { + image.SetRowAlignment(128); + } + image.AddFrame().RandomFill(); + } + static void SetUpBasicInfo(JxlBasicInfo& basic_info, size_t xsize, + size_t ysize, size_t number_extra_channels, + bool include_alpha, bool is_lossless) { + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.num_extra_channels = number_extra_channels + include_alpha; + basic_info.uses_original_profile = is_lossless; + } + + static void SetupEncoder(JxlEncoderFrameSettings* frame_settings, + const StreamingTestParam& p, + const JxlBasicInfo& basic_info, + size_t number_extra_channels, bool streaming) { + JxlEncoderStruct* enc = frame_settings->enc; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + if (p.fast_lossless()) { + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetFrameLossless(frame_settings, JXL_TRUE)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_EFFORT, 1)); + } + JxlColorEncoding color_encoding; + JxlColorEncodingSetToSRGB(&color_encoding, /*is_gray=*/false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc, &color_encoding)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption(frame_settings, + JXL_ENC_FRAME_SETTING_BUFFERING, + streaming ? 3 : 0)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, + JXL_ENC_FRAME_SETTING_USE_FULL_IMAGE_HEURISTICS, 0)); + if (p.use_container()) { + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc, 10)); + } + for (size_t i = 0; i < number_extra_channels; i++) { + JxlExtraChannelInfo channel_info; + JxlExtraChannelType channel_type = JXL_CHANNEL_THERMAL; + JxlEncoderInitExtraChannelInfo(channel_type, &channel_info); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetExtraChannelInfo(enc, i, &channel_info)); + } + } + + static void SetupInputNonStreaming(JxlEncoderFrameSettings* frame_settings, + const StreamingTestParam& p, + size_t number_extra_channels, + const jxl::extras::PackedImage& frame, + const jxl::extras::PackedImage& ec_frame) { + size_t frame_count = static_cast<int>(p.multiple_frames()) + 1; + for (size_t i = 0; i < frame_count; i++) { + { + // Copy pixel data here because it is only guaranteed to be available + // during the call to JxlEncoderAddImageFrame(). + std::vector<uint8_t> pixels(frame.pixels_size); + memcpy(&pixels[0], frame.pixels(), pixels.size()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &frame.format, + pixels.data(), pixels.size())); + } + for (size_t i = 0; i < number_extra_channels; i++) { + // Copy pixel data here because it is only guaranteed to be available + // during the call to JxlEncoderSetExtraChannelBuffer(). + std::vector<uint8_t> ec_pixels(ec_frame.pixels_size); + memcpy(&ec_pixels[0], ec_frame.pixels(), ec_pixels.size()); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetExtraChannelBuffer( + frame_settings, &ec_frame.format, + ec_pixels.data(), ec_pixels.size(), i)); + } + } + JxlEncoderCloseInput(frame_settings->enc); + } + + static void SetupInputStreaming(JxlEncoderFrameSettings* frame_settings, + const StreamingTestParam& p, + size_t number_extra_channels, + const jxl::extras::PackedImage& frame, + const jxl::extras::PackedImage& ec_frame) { + size_t frame_count = static_cast<int>(p.multiple_frames()) + 1; + for (size_t i = 0; i < frame_count; i++) { + // Create local copy of pixels and adapter because they are only + // guarantted to be available during the JxlEncoderAddChunkedFrame() call. + JxlChunkedFrameInputSourceAdapter chunked_frame_adapter(frame.Copy(), + ec_frame.Copy()); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddChunkedFrame( + // should only set `JXL_TRUE` in the lass pass of the loop + frame_settings, i + 1 == frame_count ? JXL_TRUE : JXL_FALSE, + chunked_frame_adapter.GetInputSource())); + } + } +}; + +TEST_P(EncoderStreamingTest, OutputCallback) { + const StreamingTestParam p = GetParam(); + size_t xsize = p.onegroup() ? 17 : 257; + size_t ysize = p.onegroup() ? 19 : 259; + size_t number_extra_channels = p.with_extra_channels() ? 5 : 0; + jxl::test::TestImage image; + SetupImage(p, xsize, ysize, p.color_includes_alpha() ? 4 : 3, + p.use_container() ? 16 : 8, image); + jxl::test::TestImage ec_image; + SetupImage(p, xsize, ysize, 1, 8, ec_image); + const auto& frame = image.ppf().frames[0].color; + const auto& ec_frame = ec_image.ppf().frames[0].color; + JxlBasicInfo basic_info = image.ppf().info; + SetUpBasicInfo(basic_info, xsize, ysize, number_extra_channels, + p.color_includes_alpha(), p.is_lossless()); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + // without sreaming + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + SetupEncoder(frame_settings, p, basic_info, number_extra_channels, false); + SetupInputNonStreaming(frame_settings, p, number_extra_channels, frame, + ec_frame); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size(); + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + } + + std::vector<uint8_t> streaming_compressed; + // with streaming + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + SetupEncoder(frame_settings, p, basic_info, number_extra_channels, true); + SetupInputNonStreaming(frame_settings, p, number_extra_channels, frame, + ec_frame); + JxlStreamingAdapter streaming_adapter(enc.get(), p.return_large_buffers(), + p.can_seek()); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderFlushInput(enc.get())); + streaming_adapter.CheckFinalWatermarkPosition(); + streaming_compressed = std::move(streaming_adapter).output(); + } + + EXPECT_TRUE(SameDecodedPixels(compressed, streaming_compressed)); + EXPECT_LE(streaming_compressed.size(), compressed.size() + 1024); +} + +TEST_P(EncoderStreamingTest, ChunkedFrame) { + const StreamingTestParam p = GetParam(); + size_t xsize = p.onegroup() ? 17 : 257; + size_t ysize = p.onegroup() ? 19 : 259; + size_t number_extra_channels = p.with_extra_channels() ? 5 : 0; + jxl::test::TestImage image; + SetupImage(p, xsize, ysize, p.color_includes_alpha() ? 4 : 3, + p.use_container() ? 16 : 8, image); + jxl::test::TestImage ec_image; + SetupImage(p, xsize, ysize, 1, 8, ec_image); + const auto& frame = image.ppf().frames[0].color; + const auto& ec_frame = ec_image.ppf().frames[0].color; + JxlBasicInfo basic_info = image.ppf().info; + SetUpBasicInfo(basic_info, xsize, ysize, number_extra_channels, + p.color_includes_alpha(), p.is_lossless()); + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + std::vector<uint8_t> streaming_compressed = std::vector<uint8_t>(64); + + // without streaming + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + SetupEncoder(frame_settings, p, basic_info, number_extra_channels, false); + SetupInputNonStreaming(frame_settings, p, number_extra_channels, frame, + ec_frame); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size(); + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + } + + // with streaming + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + SetupEncoder(frame_settings, p, basic_info, number_extra_channels, true); + SetupInputStreaming(frame_settings, p, number_extra_channels, frame, + ec_frame); + uint8_t* next_out = streaming_compressed.data(); + size_t avail_out = streaming_compressed.size(); + ProcessEncoder(enc.get(), streaming_compressed, next_out, avail_out); + } + + EXPECT_TRUE(SameDecodedPixels(compressed, streaming_compressed)); + EXPECT_LE(streaming_compressed.size(), compressed.size() + 1024); +} + +TEST_P(EncoderStreamingTest, ChunkedAndOutputCallback) { + const StreamingTestParam p = GetParam(); + size_t xsize = p.onegroup() ? 17 : 257; + size_t ysize = p.onegroup() ? 19 : 259; + size_t number_extra_channels = p.with_extra_channels() ? 5 : 0; + jxl::test::TestImage image; + SetupImage(p, xsize, ysize, p.color_includes_alpha() ? 4 : 3, + p.use_container() ? 16 : 8, image); + jxl::test::TestImage ec_image; + SetupImage(p, xsize, ysize, 1, 8, ec_image); + const auto& frame = image.ppf().frames[0].color; + const auto& ec_frame = ec_image.ppf().frames[0].color; + JxlBasicInfo basic_info = image.ppf().info; + SetUpBasicInfo(basic_info, xsize, ysize, number_extra_channels, + p.color_includes_alpha(), p.is_lossless()); + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + + // without streaming + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + SetupEncoder(frame_settings, p, basic_info, number_extra_channels, false); + SetupInputNonStreaming(frame_settings, p, number_extra_channels, frame, + ec_frame); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size(); + ProcessEncoder(enc.get(), compressed, next_out, avail_out); + } + + std::vector<uint8_t> streaming_compressed; + // with streaming + { + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + ASSERT_NE(nullptr, enc.get()); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + SetupEncoder(frame_settings, p, basic_info, number_extra_channels, true); + JxlStreamingAdapter streaming_adapter = + JxlStreamingAdapter(enc.get(), p.return_large_buffers(), p.can_seek()); + SetupInputStreaming(frame_settings, p, number_extra_channels, frame, + ec_frame); + streaming_adapter.CheckFinalWatermarkPosition(); + streaming_compressed = std::move(streaming_adapter).output(); + } + + EXPECT_TRUE(SameDecodedPixels(compressed, streaming_compressed)); + EXPECT_LE(streaming_compressed.size(), compressed.size() + 1024); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + EncoderStreamingTest, EncoderStreamingTest, + testing::ValuesIn(StreamingTestParam::All())); + +TEST(EncoderTest, CMYK) { + size_t xsize = 257; + size_t ysize = 259; + jxl::test::TestImage image; + image.SetDimensions(xsize, ysize) + .SetDataType(JXL_TYPE_UINT8) + .SetChannels(3) + .SetAllBitDepths(8); + image.AddFrame().RandomFill(); + jxl::test::TestImage ec_image; + ec_image.SetDataType(JXL_TYPE_UINT8) + .SetDimensions(xsize, ysize) + .SetChannels(1) + .SetAllBitDepths(8); + ec_image.AddFrame().RandomFill(); + const auto& frame = image.ppf().frames[0].color; + const auto& ec_frame = ec_image.ppf().frames[0].color; + JxlBasicInfo basic_info = image.ppf().info; + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.num_extra_channels = 1; + basic_info.uses_original_profile = JXL_TRUE; + + std::vector<uint8_t> compressed = std::vector<uint8_t>(64); + JxlEncoderPtr enc_ptr = JxlEncoderMake(nullptr); + JxlEncoderStruct* enc = enc_ptr.get(); + ASSERT_NE(nullptr, enc); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc, NULL); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + JxlExtraChannelInfo channel_info; + JxlExtraChannelType channel_type = JXL_CHANNEL_BLACK; + JxlEncoderInitExtraChannelInfo(channel_type, &channel_info); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetExtraChannelInfo(enc, 0, &channel_info)); + const std::vector<uint8_t> icc = jxl::test::ReadTestData( + "external/Compact-ICC-Profiles/profiles/" + "CGATS001Compat-v2-micro.icc"); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetICCProfile(enc, icc.data(), icc.size())); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &frame.format, + frame.pixels(), frame.pixels_size)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetExtraChannelBuffer( + frame_settings, &ec_frame.format, + ec_frame.pixels(), ec_frame.pixels_size, 0)); + JxlEncoderCloseInput(frame_settings->enc); + uint8_t* next_out = compressed.data(); + size_t avail_out = compressed.size(); + ProcessEncoder(enc, compressed, next_out, avail_out); + + jxl::extras::JXLDecompressParams dparams; + dparams.accepted_formats = { + {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}, + }; + jxl::extras::PackedPixelFile ppf; + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), dparams, + nullptr, &ppf, nullptr)); +} diff --git a/third_party/jpeg-xl/lib/jxl/entropy_coder.cc b/third_party/jpeg-xl/lib/jxl/entropy_coder.cc new file mode 100644 index 0000000000..a90ed0257a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/entropy_coder.cc @@ -0,0 +1,69 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/entropy_coder.h" + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_context_map.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/pack_signed.h" + +namespace jxl { + +Status DecodeBlockCtxMap(BitReader* br, BlockCtxMap* block_ctx_map) { + auto& dct = block_ctx_map->dc_thresholds; + auto& qft = block_ctx_map->qf_thresholds; + auto& ctx_map = block_ctx_map->ctx_map; + bool is_default = br->ReadFixedBits<1>(); + if (is_default) { + *block_ctx_map = BlockCtxMap(); + return true; + } + block_ctx_map->num_dc_ctxs = 1; + for (int j : {0, 1, 2}) { + dct[j].resize(br->ReadFixedBits<4>()); + block_ctx_map->num_dc_ctxs *= dct[j].size() + 1; + for (int& i : dct[j]) { + i = UnpackSigned(U32Coder::Read(kDCThresholdDist, br)); + } + } + qft.resize(br->ReadFixedBits<4>()); + for (uint32_t& i : qft) { + i = U32Coder::Read(kQFThresholdDist, br) + 1; + } + + if (block_ctx_map->num_dc_ctxs * (qft.size() + 1) > 64) { + return JXL_FAILURE("Invalid block context map: too big"); + } + + ctx_map.resize(3 * kNumOrders * block_ctx_map->num_dc_ctxs * + (qft.size() + 1)); + JXL_RETURN_IF_ERROR(DecodeContextMap(&ctx_map, &block_ctx_map->num_ctxs, br)); + if (block_ctx_map->num_ctxs > 16) { + return JXL_FAILURE("Invalid block context map: too many distinct contexts"); + } + return true; +} + +constexpr uint8_t BlockCtxMap::kDefaultCtxMap[]; // from ac_context.h + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/entropy_coder.h b/third_party/jpeg-xl/lib/jxl/entropy_coder.h new file mode 100644 index 0000000000..e4afa7a631 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/entropy_coder.h @@ -0,0 +1,45 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ENTROPY_CODER_H_ +#define LIB_JXL_ENTROPY_CODER_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/field_encodings.h" + +// Entropy coding and context modeling of DC and AC coefficients, as well as AC +// strategy and quantization field. + +namespace jxl { + +static JXL_INLINE int32_t PredictFromTopAndLeft( + const int32_t* const JXL_RESTRICT row_top, + const int32_t* const JXL_RESTRICT row, size_t x, int32_t default_val) { + if (x == 0) { + return row_top == nullptr ? default_val : row_top[x]; + } + if (row_top == nullptr) { + return row[x - 1]; + } + return (row_top[x] + row[x - 1] + 1) / 2; +} + +static constexpr U32Enc kDCThresholdDist(Bits(4), BitsOffset(8, 16), + BitsOffset(16, 272), + BitsOffset(32, 65808)); + +static constexpr U32Enc kQFThresholdDist(Bits(2), BitsOffset(3, 4), + BitsOffset(5, 12), BitsOffset(8, 44)); + +Status DecodeBlockCtxMap(BitReader* br, BlockCtxMap* block_ctx_map); + +} // namespace jxl + +#endif // LIB_JXL_ENTROPY_CODER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/entropy_coder_test.cc b/third_party/jpeg-xl/lib/jxl/entropy_coder_test.cc new file mode 100644 index 0000000000..d32fe1b26b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/entropy_coder_test.cc @@ -0,0 +1,68 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO(deymo): Move these tests to dec_ans.h and common.h + +#include <stdint.h> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/pack_signed.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +TEST(EntropyCoderTest, PackUnpack) { + for (int32_t i = -31; i < 32; ++i) { + uint32_t packed = PackSigned(i); + EXPECT_LT(packed, 63u); + int32_t unpacked = UnpackSigned(packed); + EXPECT_EQ(i, unpacked); + } +} + +struct MockBitReader { + uint32_t nbits, bits; + void Consume(uint32_t nbits) {} + uint32_t PeekBits(uint32_t n) { + EXPECT_EQ(n, nbits); + return bits; + } +}; + +void HybridUintRoundtrip(HybridUintConfig config, size_t limit = 1 << 24) { + Rng rng(0); + constexpr size_t kNumIntegers = 1 << 20; + std::vector<uint32_t> integers(kNumIntegers); + std::vector<uint32_t> token(kNumIntegers); + std::vector<uint32_t> nbits(kNumIntegers); + std::vector<uint32_t> bits(kNumIntegers); + for (size_t i = 0; i < kNumIntegers; i++) { + integers[i] = rng.UniformU(0, limit + 1); + config.Encode(integers[i], &token[i], &nbits[i], &bits[i]); + } + for (size_t i = 0; i < kNumIntegers; i++) { + MockBitReader br{nbits[i], bits[i]}; + EXPECT_EQ(integers[i], + ANSSymbolReader::ReadHybridUintConfig(config, token[i], &br)); + } +} + +TEST(HybridUintTest, Test000) { + HybridUintRoundtrip(HybridUintConfig{0, 0, 0}); +} +TEST(HybridUintTest, Test411) { + HybridUintRoundtrip(HybridUintConfig{4, 1, 1}); +} +TEST(HybridUintTest, Test420) { + HybridUintRoundtrip(HybridUintConfig{4, 2, 0}); +} +TEST(HybridUintTest, Test421) { + HybridUintRoundtrip(HybridUintConfig{4, 2, 1}, 256); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/epf.cc b/third_party/jpeg-xl/lib/jxl/epf.cc new file mode 100644 index 0000000000..78ef38bfd5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/epf.cc @@ -0,0 +1,144 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Edge-preserving smoothing: weighted average based on L1 patch similarity. + +#include "lib/jxl/epf.h" + +#include <math.h> +#include <stdint.h> +#include <stdlib.h> +#include <string.h> + +#include <algorithm> +#include <atomic> +#include <numeric> // std::accumulate +#include <vector> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { + +// Mirror n floats starting at *p and store them before p. +JXL_INLINE void LeftMirror(float* p, size_t n) { + for (size_t i = 0; i < n; i++) { + *(p - 1 - i) = p[i]; + } +} + +// Mirror n floats starting at *(p - n) and store them at *p. +JXL_INLINE void RightMirror(float* p, size_t n) { + for (size_t i = 0; i < n; i++) { + p[i] = *(p - 1 - i); + } +} + +void ComputeSigma(const LoopFilter& lf, const Rect& block_rect, + PassesDecoderState* state) { + JXL_CHECK(lf.epf_iters > 0); + const AcStrategyImage& ac_strategy = state->shared->ac_strategy; + const float quant_scale = state->shared->quantizer.Scale(); + + const size_t sigma_stride = state->sigma.PixelsPerRow(); + const size_t sharpness_stride = state->shared->epf_sharpness.PixelsPerRow(); + + for (size_t by = 0; by < block_rect.ysize(); ++by) { + float* JXL_RESTRICT sigma_row = block_rect.Row(&state->sigma, by); + const uint8_t* JXL_RESTRICT sharpness_row = + block_rect.ConstRow(state->shared->epf_sharpness, by); + AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by); + const int32_t* const JXL_RESTRICT row_quant = + block_rect.ConstRow(state->shared->raw_quant_field, by); + + for (size_t bx = 0; bx < block_rect.xsize(); bx++) { + AcStrategy acs = acs_row[bx]; + size_t llf_x = acs.covered_blocks_x(); + if (!acs.IsFirstBlock()) continue; + // quant_scale is smaller for low quality. + // quant_scale is roughly 0.08 / butteraugli score. + // + // row_quant is smaller for low quality. + // row_quant is a quantization multiplier of form 1.0 / + // row_quant[bx] + // + // lf.epf_quant_mul is a parameter in the format + // kInvSigmaNum is a constant + float sigma_quant = + lf.epf_quant_mul / (quant_scale * row_quant[bx] * kInvSigmaNum); + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + float sigma = + sigma_quant * + lf.epf_sharp_lut[sharpness_row[bx + ix + iy * sharpness_stride]]; + // Avoid infinities. + sigma = std::min(-1e-4f, sigma); // TODO(veluca): remove this. + sigma_row[bx + ix + kSigmaPadding + + (iy + kSigmaPadding) * sigma_stride] = 1.0f / sigma; + } + } + // TODO(veluca): remove this padding. + // Left padding with mirroring. + if (bx + block_rect.x0() == 0) { + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + LeftMirror( + sigma_row + kSigmaPadding + (iy + kSigmaPadding) * sigma_stride, + kSigmaBorder); + } + } + // Right padding with mirroring. + if (bx + block_rect.x0() + llf_x == + state->shared->frame_dim.xsize_blocks) { + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + RightMirror(sigma_row + kSigmaPadding + bx + llf_x + + (iy + kSigmaPadding) * sigma_stride, + kSigmaBorder); + } + } + // Offsets for row copying, in blocks. + size_t offset_before = bx + block_rect.x0() == 0 ? 1 : bx + kSigmaPadding; + size_t offset_after = + bx + block_rect.x0() + llf_x == state->shared->frame_dim.xsize_blocks + ? kSigmaPadding + llf_x + bx + kSigmaBorder + : kSigmaPadding + llf_x + bx; + size_t num = offset_after - offset_before; + // Above + if (by + block_rect.y0() == 0) { + for (size_t iy = 0; iy < kSigmaBorder; iy++) { + memcpy( + sigma_row + offset_before + + (kSigmaPadding - 1 - iy) * sigma_stride, + sigma_row + offset_before + (kSigmaPadding + iy) * sigma_stride, + num * sizeof(*sigma_row)); + } + } + // Below + if (by + block_rect.y0() + acs.covered_blocks_y() == + state->shared->frame_dim.ysize_blocks) { + for (size_t iy = 0; iy < kSigmaBorder; iy++) { + memcpy( + sigma_row + offset_before + + sigma_stride * (acs.covered_blocks_y() + kSigmaPadding + iy), + sigma_row + offset_before + + sigma_stride * + (acs.covered_blocks_y() + kSigmaPadding - 1 - iy), + num * sizeof(*sigma_row)); + } + } + } + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/epf.h b/third_party/jpeg-xl/lib/jxl/epf.h new file mode 100644 index 0000000000..808dde10dc --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/epf.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_EPF_H_ +#define LIB_JXL_EPF_H_ + +// Fast SIMD "in-loop" edge preserving filter (adaptive, nonlinear). + +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/image.h" +#include "lib/jxl/loop_filter.h" + +namespace jxl { + +// 4 * (sqrt(0.5)-1), so that Weight(sigma) = 0.5. +static constexpr float kInvSigmaNum = -1.1715728752538099024f; + +// kInvSigmaNum / 0.3 +constexpr float kMinSigma = -3.90524291751269967465540850526868f; + +// Fills the `state->filter_weights.sigma` image with the precomputed sigma +// values in the area inside `block_rect`. Accesses the AC strategy, quant field +// and epf_sharpness fields in the corresponding positions. +void ComputeSigma(const LoopFilter& lf, const Rect& block_rect, + PassesDecoderState* state); + +} // namespace jxl + +#endif // LIB_JXL_EPF_H_ diff --git a/third_party/jpeg-xl/lib/jxl/fake_parallel_runner_testonly.h b/third_party/jpeg-xl/lib/jxl/fake_parallel_runner_testonly.h new file mode 100644 index 0000000000..508d808cc5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fake_parallel_runner_testonly.h @@ -0,0 +1,79 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_FAKE_PARALLEL_RUNNER_TESTONLY_H_ +#define LIB_JXL_FAKE_PARALLEL_RUNNER_TESTONLY_H_ + +#include <jxl/parallel_runner.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/random.h" + +namespace jxl { + +// A parallel runner implementation that runs all the jobs in a single thread +// (the caller thread) but runs them pretending to use multiple threads and +// potentially out of order. This is useful for testing conditions that only +// occur under heavy load where the order of operations is different. +class FakeParallelRunner { + public: + FakeParallelRunner(uint32_t order_seed, uint32_t num_threads) + : order_seed_(order_seed), rng_(order_seed), num_threads_(num_threads) { + if (num_threads_ < 1) num_threads_ = 1; + } + + JxlParallelRetCode Run(void* jxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start, + uint32_t end) { + JxlParallelRetCode ret = init(jxl_opaque, num_threads_); + if (ret != 0) return ret; + + if (order_seed_ == 0) { + for (uint32_t i = start; i < end; i++) { + func(jxl_opaque, i, i % num_threads_); + } + } else { + std::vector<uint32_t> order(end - start); + for (uint32_t i = start; i < end; i++) { + order[i - start] = i; + } + rng_.Shuffle(order.data(), order.size()); + for (uint32_t i = start; i < end; i++) { + func(jxl_opaque, order[i - start], i % num_threads_); + } + } + return ret; + } + + private: + // Seed for the RNG for defining the execution order. A value of 0 means + // sequential order from start to end. + uint32_t order_seed_; + + // The PRNG object, initialized with the order_seed_. Only used if the seed is + // not 0. + Rng rng_; + + // Number of fake threads. All the tasks are run on the same thread, but using + // different thread_id values based on this num_threads. + uint32_t num_threads_; +}; + +} // namespace jxl + +extern "C" { +// Function to pass as the parallel runner. +JXL_INLINE JxlParallelRetCode JxlFakeParallelRunner( + void* runner_opaque, void* jpegxl_opaque, JxlParallelRunInit init, + JxlParallelRunFunction func, uint32_t start_range, uint32_t end_range) { + return static_cast<jxl::FakeParallelRunner*>(runner_opaque) + ->Run(jpegxl_opaque, init, func, start_range, end_range); +} +} + +#endif // LIB_JXL_FAKE_PARALLEL_RUNNER_TESTONLY_H_ diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct-inl.h b/third_party/jpeg-xl/lib/jxl/fast_dct-inl.h new file mode 100644 index 0000000000..de1f845901 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct-inl.h @@ -0,0 +1,239 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_FAST_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_FAST_DCT_INL_H_ +#undef LIB_JXL_FAST_DCT_INL_H_ +#else +#define LIB_JXL_FAST_DCT_INL_H_ +#endif + +#include <cmath> +#include <hwy/aligned_allocator.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/status.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +#if HWY_TARGET == HWY_NEON +HWY_NOINLINE void FastTransposeBlock(const int16_t* JXL_RESTRICT data_in, + size_t stride_in, size_t N, size_t M, + int16_t* JXL_RESTRICT data_out, + size_t stride_out) { + JXL_DASSERT(N % 8 == 0); + JXL_DASSERT(M % 8 == 0); + for (size_t i = 0; i < N; i += 8) { + for (size_t j = 0; j < M; j += 8) { + // TODO(veluca): one could optimize the M==8, stride_in==8 case further + // with vld4. + // This code is about 40% faster for N == M == stride_in == + // stride_out == 8 + // Using loads + stores to reshuffle things to be able to + // use vld4 doesn't help. + /* + auto a0 = vld4q_s16(data_in); auto a1 = vld4q_s16(data_in + 32); + int16x8x4_t out0; + int16x8x4_t out1; + out0.val[0] = vuzp1q_s16(a0.val[0], a1.val[0]); + out0.val[1] = vuzp1q_s16(a0.val[1], a1.val[1]); + out0.val[2] = vuzp1q_s16(a0.val[2], a1.val[2]); + out0.val[3] = vuzp1q_s16(a0.val[3], a1.val[3]); + out1.val[0] = vuzp2q_s16(a0.val[0], a1.val[0]); + out1.val[1] = vuzp2q_s16(a0.val[1], a1.val[1]); + out1.val[2] = vuzp2q_s16(a0.val[2], a1.val[2]); + out1.val[3] = vuzp2q_s16(a0.val[3], a1.val[3]); + vst1q_s16_x4(data_out, out0); + vst1q_s16_x4(data_out + 32, out1); + */ + auto a0 = vld1q_s16(data_in + i * stride_in + j); + auto a1 = vld1q_s16(data_in + (i + 1) * stride_in + j); + auto a2 = vld1q_s16(data_in + (i + 2) * stride_in + j); + auto a3 = vld1q_s16(data_in + (i + 3) * stride_in + j); + + auto a01 = vtrnq_s16(a0, a1); + auto a23 = vtrnq_s16(a2, a3); + + auto four0 = vtrnq_s32(vreinterpretq_s32_s16(a01.val[0]), + vreinterpretq_s32_s16(a23.val[0])); + auto four1 = vtrnq_s32(vreinterpretq_s32_s16(a01.val[1]), + vreinterpretq_s32_s16(a23.val[1])); + + auto a4 = vld1q_s16(data_in + (i + 4) * stride_in + j); + auto a5 = vld1q_s16(data_in + (i + 5) * stride_in + j); + auto a6 = vld1q_s16(data_in + (i + 6) * stride_in + j); + auto a7 = vld1q_s16(data_in + (i + 7) * stride_in + j); + + auto a45 = vtrnq_s16(a4, a5); + auto a67 = vtrnq_s16(a6, a7); + + auto four2 = vtrnq_s32(vreinterpretq_s32_s16(a45.val[0]), + vreinterpretq_s32_s16(a67.val[0])); + auto four3 = vtrnq_s32(vreinterpretq_s32_s16(a45.val[1]), + vreinterpretq_s32_s16(a67.val[1])); + + auto out0 = + vcombine_s32(vget_low_s32(four0.val[0]), vget_low_s32(four2.val[0])); + auto out1 = + vcombine_s32(vget_low_s32(four1.val[0]), vget_low_s32(four3.val[0])); + auto out2 = + vcombine_s32(vget_low_s32(four0.val[1]), vget_low_s32(four2.val[1])); + auto out3 = + vcombine_s32(vget_low_s32(four1.val[1]), vget_low_s32(four3.val[1])); + auto out4 = vcombine_s32(vget_high_s32(four0.val[0]), + vget_high_s32(four2.val[0])); + auto out5 = vcombine_s32(vget_high_s32(four1.val[0]), + vget_high_s32(four3.val[0])); + auto out6 = vcombine_s32(vget_high_s32(four0.val[1]), + vget_high_s32(four2.val[1])); + auto out7 = vcombine_s32(vget_high_s32(four1.val[1]), + vget_high_s32(four3.val[1])); + vst1q_s16(data_out + j * stride_out + i, vreinterpretq_s16_s32(out0)); + vst1q_s16(data_out + (j + 1) * stride_out + i, + vreinterpretq_s16_s32(out1)); + vst1q_s16(data_out + (j + 2) * stride_out + i, + vreinterpretq_s16_s32(out2)); + vst1q_s16(data_out + (j + 3) * stride_out + i, + vreinterpretq_s16_s32(out3)); + vst1q_s16(data_out + (j + 4) * stride_out + i, + vreinterpretq_s16_s32(out4)); + vst1q_s16(data_out + (j + 5) * stride_out + i, + vreinterpretq_s16_s32(out5)); + vst1q_s16(data_out + (j + 6) * stride_out + i, + vreinterpretq_s16_s32(out6)); + vst1q_s16(data_out + (j + 7) * stride_out + i, + vreinterpretq_s16_s32(out7)); + } + } +} + +template <size_t N> +struct FastDCTTag {}; + +#include "lib/jxl/fast_dct128-inl.h" +#include "lib/jxl/fast_dct16-inl.h" +#include "lib/jxl/fast_dct256-inl.h" +#include "lib/jxl/fast_dct32-inl.h" +#include "lib/jxl/fast_dct64-inl.h" +#include "lib/jxl/fast_dct8-inl.h" + +template <size_t ROWS, size_t COLS> +struct ComputeFastScaledIDCT { + // scratch_space must be aligned, and should have space for ROWS*COLS + // int16_ts. + HWY_MAYBE_UNUSED void operator()(int16_t* JXL_RESTRICT from, int16_t* to, + size_t to_stride, + int16_t* JXL_RESTRICT scratch_space) { + // Reverse the steps done in ComputeScaledDCT. + if (ROWS < COLS) { + FastTransposeBlock(from, COLS, ROWS, COLS, scratch_space, ROWS); + FastIDCT(FastDCTTag<COLS>(), scratch_space, ROWS, from, ROWS, ROWS); + FastTransposeBlock(from, ROWS, COLS, ROWS, scratch_space, COLS); + FastIDCT(FastDCTTag<ROWS>(), scratch_space, COLS, to, to_stride, COLS); + } else { + FastIDCT(FastDCTTag<COLS>(), from, ROWS, scratch_space, ROWS, ROWS); + FastTransposeBlock(scratch_space, ROWS, COLS, ROWS, from, COLS); + FastIDCT(FastDCTTag<ROWS>(), from, COLS, to, to_stride, COLS); + } + } +}; +#endif + +template <size_t N, size_t M> +HWY_NOINLINE void TestFastIDCT() { +#if HWY_TARGET == HWY_NEON + auto pixels_mem = hwy::AllocateAligned<float>(N * M); + float* pixels = pixels_mem.get(); + auto dct_mem = hwy::AllocateAligned<float>(N * M); + float* dct = dct_mem.get(); + auto dct_i_mem = hwy::AllocateAligned<int16_t>(N * M); + int16_t* dct_i = dct_i_mem.get(); + auto dct_in_mem = hwy::AllocateAligned<int16_t>(N * M); + int16_t* dct_in = dct_in_mem.get(); + auto idct_mem = hwy::AllocateAligned<int16_t>(N * M); + int16_t* idct = idct_mem.get(); + + const HWY_FULL(float) df; + auto scratch_space_mem = hwy::AllocateAligned<float>( + N * M * 2 + 3 * std::max(N, M) * MaxLanes(df)); + float* scratch_space = scratch_space_mem.get(); + auto scratch_space_i_mem = hwy::AllocateAligned<int16_t>(N * M * 2); + int16_t* scratch_space_i = scratch_space_i_mem.get(); + + Rng rng(0); + for (size_t i = 0; i < N * M; i++) { + pixels[i] = rng.UniformF(-1, 1); + } + ComputeScaledDCT<M, N>()(DCTFrom(pixels, N), dct, scratch_space); + size_t integer_bits = std::max(FastIDCTIntegerBits(FastDCTTag<N>()), + FastIDCTIntegerBits(FastDCTTag<M>())); + // Enough range for [-2, 2] output values. + JXL_ASSERT(integer_bits <= 14); + float scale = (1 << (14 - integer_bits)); + for (size_t i = 0; i < N * M; i++) { + dct_i[i] = std::round(dct[i] * scale); + } + + for (size_t j = 0; j < 40000000 / (M * N); j++) { + memcpy(dct_in, dct_i, sizeof(*dct_i) * N * M); + ComputeFastScaledIDCT<M, N>()(dct_in, idct, N, scratch_space_i); + } + float max_error = 0; + for (size_t i = 0; i < M * N; i++) { + float err = std::abs(idct[i] * (1.0f / scale) - pixels[i]); + if (std::abs(err) > max_error) { + max_error = std::abs(err); + } + } + printf("max error: %f mantissa bits: %d\n", max_error, + 14 - (int)integer_bits); +#endif +} + +template <size_t N, size_t M> +HWY_NOINLINE void TestFloatIDCT() { + auto pixels_mem = hwy::AllocateAligned<float>(N * M); + float* pixels = pixels_mem.get(); + auto dct_mem = hwy::AllocateAligned<float>(N * M); + float* dct = dct_mem.get(); + auto idct_mem = hwy::AllocateAligned<float>(N * M); + float* idct = idct_mem.get(); + + auto dct_in_mem = hwy::AllocateAligned<float>(N * M); + float* dct_in = dct_mem.get(); + + auto scratch_space_mem = hwy::AllocateAligned<float>(N * M * 5); + float* scratch_space = scratch_space_mem.get(); + + Rng rng(0); + for (size_t i = 0; i < N * M; i++) { + pixels[i] = rng.UniformF(-1, 1); + } + ComputeScaledDCT<M, N>()(DCTFrom(pixels, N), dct, scratch_space); + + for (size_t j = 0; j < 40000000 / (M * N); j++) { + memcpy(dct_in, dct, sizeof(*dct) * N * M); + ComputeScaledIDCT<M, N>()(dct_in, DCTTo(idct, N), scratch_space); + } + float max_error = 0; + for (size_t i = 0; i < M * N; i++) { + float err = std::abs(idct[i] - pixels[i]); + if (std::abs(err) > max_error) { + max_error = std::abs(err); + } + } + printf("max error: %e\n", max_error); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_FAST_DCT_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct.cc b/third_party/jpeg-xl/lib/jxl/fast_dct.cc new file mode 100644 index 0000000000..d796018fd0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct.cc @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/fast_dct.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/fast_dct-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { +void BenchmarkFloatIDCT32x32() { TestFloatIDCT<32, 32>(); } +void BenchmarkFastIDCT32x32() { TestFastIDCT<32, 32>(); } +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(BenchmarkFloatIDCT32x32); +HWY_EXPORT(BenchmarkFastIDCT32x32); +void BenchmarkFloatIDCT32x32() { + HWY_DYNAMIC_DISPATCH(BenchmarkFloatIDCT32x32)(); +} +void BenchmarkFastIDCT32x32() { + HWY_DYNAMIC_DISPATCH(BenchmarkFastIDCT32x32)(); +} +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct.h b/third_party/jpeg-xl/lib/jxl/fast_dct.h new file mode 100644 index 0000000000..641933d8a0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct.h @@ -0,0 +1,9 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +namespace jxl { +void BenchmarkFloatIDCT32x32(); +void BenchmarkFastIDCT32x32(); +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct128-inl.h b/third_party/jpeg-xl/lib/jxl/fast_dct128-inl.h new file mode 100644 index 0000000000..1a94d3ee92 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct128-inl.h @@ -0,0 +1,2137 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<128>) { return 2; } + +void FastIDCT(FastDCTTag<128>, const int16_t* in, size_t in_stride, + int16_t* out, size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 64 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 32 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 96 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 16 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 80 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 48 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vaddq_s16(v13, v10); + int16x8_t v17_tmp = vqrdmulhq_n_s16(v16, 13573); + int16x8_t v17 = vaddq_s16(v17_tmp, v16); + int16x8_t v18 = vld1q_s16(in + in_stride * 112 + i); + int16x8_t v19 = vaddq_s16(v18, v12); + int16x8_t v20 = vaddq_s16(v19, v16); + int16x8_t v21 = vaddq_s16(v17, v20); + int16x8_t v22 = vqrdmulhq_n_s16(v21, 17734); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v27_tmp = vqrdmulhq_n_s16(v26, 13573); + int16x8_t v27 = vaddq_s16(v27_tmp, v26); + int16x8_t v28 = vld1q_s16(in + in_stride * 72 + i); + int16x8_t v29 = vld1q_s16(in + in_stride * 56 + i); + int16x8_t v30 = vaddq_s16(v28, v29); + int16x8_t v31 = vaddq_s16(v27, v30); + int16x8_t v32 = vld1q_s16(in + in_stride * 40 + i); + int16x8_t v33 = vld1q_s16(in + in_stride * 24 + i); + int16x8_t v34 = vaddq_s16(v32, v33); + int16x8_t v35_tmp = vqrdmulhq_n_s16(v34, 13573); + int16x8_t v35 = vaddq_s16(v35_tmp, v34); + int16x8_t v36 = vld1q_s16(in + in_stride * 104 + i); + int16x8_t v37 = vld1q_s16(in + in_stride * 88 + i); + int16x8_t v38 = vaddq_s16(v36, v37); + int16x8_t v39 = vaddq_s16(v38, v34); + int16x8_t v40 = vaddq_s16(v35, v39); + int16x8_t v41 = vqrdmulhq_n_s16(v40, 17734); + int16x8_t v42 = vaddq_s16(v31, v41); + int16x8_t v43 = vaddq_s16(v33, v26); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v43, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v43); + int16x8_t v45 = vaddq_s16(v37, v28); + int16x8_t v46 = vaddq_s16(v29, v32); + int16x8_t v47 = vaddq_s16(v45, v46); + int16x8_t v48 = vaddq_s16(v44, v47); + int16x8_t v49 = vaddq_s16(v46, v43); + int16x8_t v50_tmp = vqrdmulhq_n_s16(v49, 13573); + int16x8_t v50 = vaddq_s16(v50_tmp, v49); + int16x8_t v51 = vld1q_s16(in + in_stride * 120 + i); + int16x8_t v52 = vaddq_s16(v51, v36); + int16x8_t v53 = vaddq_s16(v52, v45); + int16x8_t v54 = vaddq_s16(v53, v49); + int16x8_t v55 = vaddq_s16(v50, v54); + int16x8_t v56 = vqrdmulhq_n_s16(v55, 17734); + int16x8_t v57 = vaddq_s16(v48, v56); + int16x8_t v58 = vqrdmulhq_n_s16(v57, 16705); + int16x8_t v59 = vaddq_s16(v42, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v63_tmp = vqrdmulhq_n_s16(v62, 13573); + int16x8_t v63 = vaddq_s16(v63_tmp, v62); + int16x8_t v64 = vld1q_s16(in + in_stride * 68 + i); + int16x8_t v65 = vld1q_s16(in + in_stride * 60 + i); + int16x8_t v66 = vaddq_s16(v64, v65); + int16x8_t v67 = vaddq_s16(v63, v66); + int16x8_t v68 = vld1q_s16(in + in_stride * 36 + i); + int16x8_t v69 = vld1q_s16(in + in_stride * 28 + i); + int16x8_t v70 = vaddq_s16(v68, v69); + int16x8_t v71_tmp = vqrdmulhq_n_s16(v70, 13573); + int16x8_t v71 = vaddq_s16(v71_tmp, v70); + int16x8_t v72 = vld1q_s16(in + in_stride * 100 + i); + int16x8_t v73 = vld1q_s16(in + in_stride * 92 + i); + int16x8_t v74 = vaddq_s16(v72, v73); + int16x8_t v75 = vaddq_s16(v74, v70); + int16x8_t v76 = vaddq_s16(v71, v75); + int16x8_t v77 = vqrdmulhq_n_s16(v76, 17734); + int16x8_t v78 = vaddq_s16(v67, v77); + int16x8_t v79 = vld1q_s16(in + in_stride * 20 + i); + int16x8_t v80 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v81 = vaddq_s16(v79, v80); + int16x8_t v82_tmp = vqrdmulhq_n_s16(v81, 13573); + int16x8_t v82 = vaddq_s16(v82_tmp, v81); + int16x8_t v83 = vld1q_s16(in + in_stride * 84 + i); + int16x8_t v84 = vld1q_s16(in + in_stride * 76 + i); + int16x8_t v85 = vaddq_s16(v83, v84); + int16x8_t v86 = vld1q_s16(in + in_stride * 52 + i); + int16x8_t v87 = vld1q_s16(in + in_stride * 44 + i); + int16x8_t v88 = vaddq_s16(v86, v87); + int16x8_t v89 = vaddq_s16(v85, v88); + int16x8_t v90 = vaddq_s16(v82, v89); + int16x8_t v91 = vaddq_s16(v88, v81); + int16x8_t v92_tmp = vqrdmulhq_n_s16(v91, 13573); + int16x8_t v92 = vaddq_s16(v92_tmp, v91); + int16x8_t v93 = vld1q_s16(in + in_stride * 116 + i); + int16x8_t v94 = vld1q_s16(in + in_stride * 108 + i); + int16x8_t v95 = vaddq_s16(v93, v94); + int16x8_t v96 = vaddq_s16(v95, v85); + int16x8_t v97 = vaddq_s16(v96, v91); + int16x8_t v98 = vaddq_s16(v92, v97); + int16x8_t v99 = vqrdmulhq_n_s16(v98, 17734); + int16x8_t v100 = vaddq_s16(v90, v99); + int16x8_t v101 = vqrdmulhq_n_s16(v100, 16705); + int16x8_t v102 = vaddq_s16(v78, v101); + int16x8_t v103 = vaddq_s16(v80, v62); + int16x8_t v104_tmp = vqrdmulhq_n_s16(v103, 13573); + int16x8_t v104 = vaddq_s16(v104_tmp, v103); + int16x8_t v105 = vaddq_s16(v84, v64); + int16x8_t v106 = vaddq_s16(v65, v86); + int16x8_t v107 = vaddq_s16(v105, v106); + int16x8_t v108 = vaddq_s16(v104, v107); + int16x8_t v109 = vaddq_s16(v87, v68); + int16x8_t v110 = vaddq_s16(v69, v79); + int16x8_t v111 = vaddq_s16(v109, v110); + int16x8_t v112_tmp = vqrdmulhq_n_s16(v111, 13573); + int16x8_t v112 = vaddq_s16(v112_tmp, v111); + int16x8_t v113 = vaddq_s16(v94, v72); + int16x8_t v114 = vaddq_s16(v73, v83); + int16x8_t v115 = vaddq_s16(v113, v114); + int16x8_t v116 = vaddq_s16(v115, v111); + int16x8_t v117 = vaddq_s16(v112, v116); + int16x8_t v118 = vqrdmulhq_n_s16(v117, 17734); + int16x8_t v119 = vaddq_s16(v108, v118); + int16x8_t v120 = vaddq_s16(v110, v103); + int16x8_t v121_tmp = vqrdmulhq_n_s16(v120, 13573); + int16x8_t v121 = vaddq_s16(v121_tmp, v120); + int16x8_t v122 = vaddq_s16(v114, v105); + int16x8_t v123 = vaddq_s16(v106, v109); + int16x8_t v124 = vaddq_s16(v122, v123); + int16x8_t v125 = vaddq_s16(v121, v124); + int16x8_t v126 = vaddq_s16(v123, v120); + int16x8_t v127_tmp = vqrdmulhq_n_s16(v126, 13573); + int16x8_t v127 = vaddq_s16(v127_tmp, v126); + int16x8_t v128 = vld1q_s16(in + in_stride * 124 + i); + int16x8_t v129 = vaddq_s16(v128, v93); + int16x8_t v130 = vaddq_s16(v129, v113); + int16x8_t v131 = vaddq_s16(v130, v122); + int16x8_t v132 = vaddq_s16(v131, v126); + int16x8_t v133 = vaddq_s16(v127, v132); + int16x8_t v134 = vqrdmulhq_n_s16(v133, 17734); + int16x8_t v135 = vaddq_s16(v125, v134); + int16x8_t v136 = vqrdmulhq_n_s16(v135, 16705); + int16x8_t v137 = vaddq_s16(v119, v136); + int16x8_t v138 = vqrdmulhq_n_s16(v137, 16463); + int16x8_t v139 = vaddq_s16(v102, v138); + int16x8_t v140 = vqrdmulhq_n_s16(v139, 16404); + int16x8_t v141 = vaddq_s16(v61, v140); + int16x8_t v142 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v143_tmp = vqrdmulhq_n_s16(v142, 13573); + int16x8_t v143 = vaddq_s16(v143_tmp, v142); + int16x8_t v144 = vld1q_s16(in + in_stride * 66 + i); + int16x8_t v145 = vld1q_s16(in + in_stride * 62 + i); + int16x8_t v146 = vaddq_s16(v144, v145); + int16x8_t v147 = vaddq_s16(v143, v146); + int16x8_t v148 = vld1q_s16(in + in_stride * 34 + i); + int16x8_t v149 = vld1q_s16(in + in_stride * 30 + i); + int16x8_t v150 = vaddq_s16(v148, v149); + int16x8_t v151_tmp = vqrdmulhq_n_s16(v150, 13573); + int16x8_t v151 = vaddq_s16(v151_tmp, v150); + int16x8_t v152 = vld1q_s16(in + in_stride * 98 + i); + int16x8_t v153 = vld1q_s16(in + in_stride * 94 + i); + int16x8_t v154 = vaddq_s16(v152, v153); + int16x8_t v155 = vaddq_s16(v154, v150); + int16x8_t v156 = vaddq_s16(v151, v155); + int16x8_t v157 = vqrdmulhq_n_s16(v156, 17734); + int16x8_t v158 = vaddq_s16(v147, v157); + int16x8_t v159 = vld1q_s16(in + in_stride * 18 + i); + int16x8_t v160 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v161 = vaddq_s16(v159, v160); + int16x8_t v162_tmp = vqrdmulhq_n_s16(v161, 13573); + int16x8_t v162 = vaddq_s16(v162_tmp, v161); + int16x8_t v163 = vld1q_s16(in + in_stride * 82 + i); + int16x8_t v164 = vld1q_s16(in + in_stride * 78 + i); + int16x8_t v165 = vaddq_s16(v163, v164); + int16x8_t v166 = vld1q_s16(in + in_stride * 50 + i); + int16x8_t v167 = vld1q_s16(in + in_stride * 46 + i); + int16x8_t v168 = vaddq_s16(v166, v167); + int16x8_t v169 = vaddq_s16(v165, v168); + int16x8_t v170 = vaddq_s16(v162, v169); + int16x8_t v171 = vaddq_s16(v168, v161); + int16x8_t v172_tmp = vqrdmulhq_n_s16(v171, 13573); + int16x8_t v172 = vaddq_s16(v172_tmp, v171); + int16x8_t v173 = vld1q_s16(in + in_stride * 114 + i); + int16x8_t v174 = vld1q_s16(in + in_stride * 110 + i); + int16x8_t v175 = vaddq_s16(v173, v174); + int16x8_t v176 = vaddq_s16(v175, v165); + int16x8_t v177 = vaddq_s16(v176, v171); + int16x8_t v178 = vaddq_s16(v172, v177); + int16x8_t v179 = vqrdmulhq_n_s16(v178, 17734); + int16x8_t v180 = vaddq_s16(v170, v179); + int16x8_t v181 = vqrdmulhq_n_s16(v180, 16705); + int16x8_t v182 = vaddq_s16(v158, v181); + int16x8_t v183 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v184 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v185 = vaddq_s16(v183, v184); + int16x8_t v186_tmp = vqrdmulhq_n_s16(v185, 13573); + int16x8_t v186 = vaddq_s16(v186_tmp, v185); + int16x8_t v187 = vld1q_s16(in + in_stride * 74 + i); + int16x8_t v188 = vld1q_s16(in + in_stride * 70 + i); + int16x8_t v189 = vaddq_s16(v187, v188); + int16x8_t v190 = vld1q_s16(in + in_stride * 58 + i); + int16x8_t v191 = vld1q_s16(in + in_stride * 54 + i); + int16x8_t v192 = vaddq_s16(v190, v191); + int16x8_t v193 = vaddq_s16(v189, v192); + int16x8_t v194 = vaddq_s16(v186, v193); + int16x8_t v195 = vld1q_s16(in + in_stride * 42 + i); + int16x8_t v196 = vld1q_s16(in + in_stride * 38 + i); + int16x8_t v197 = vaddq_s16(v195, v196); + int16x8_t v198 = vld1q_s16(in + in_stride * 26 + i); + int16x8_t v199 = vld1q_s16(in + in_stride * 22 + i); + int16x8_t v200 = vaddq_s16(v198, v199); + int16x8_t v201 = vaddq_s16(v197, v200); + int16x8_t v202_tmp = vqrdmulhq_n_s16(v201, 13573); + int16x8_t v202 = vaddq_s16(v202_tmp, v201); + int16x8_t v203 = vld1q_s16(in + in_stride * 106 + i); + int16x8_t v204 = vld1q_s16(in + in_stride * 102 + i); + int16x8_t v205 = vaddq_s16(v203, v204); + int16x8_t v206 = vld1q_s16(in + in_stride * 90 + i); + int16x8_t v207 = vld1q_s16(in + in_stride * 86 + i); + int16x8_t v208 = vaddq_s16(v206, v207); + int16x8_t v209 = vaddq_s16(v205, v208); + int16x8_t v210 = vaddq_s16(v209, v201); + int16x8_t v211 = vaddq_s16(v202, v210); + int16x8_t v212 = vqrdmulhq_n_s16(v211, 17734); + int16x8_t v213 = vaddq_s16(v194, v212); + int16x8_t v214 = vaddq_s16(v200, v185); + int16x8_t v215_tmp = vqrdmulhq_n_s16(v214, 13573); + int16x8_t v215 = vaddq_s16(v215_tmp, v214); + int16x8_t v216 = vaddq_s16(v208, v189); + int16x8_t v217 = vaddq_s16(v192, v197); + int16x8_t v218 = vaddq_s16(v216, v217); + int16x8_t v219 = vaddq_s16(v215, v218); + int16x8_t v220 = vaddq_s16(v217, v214); + int16x8_t v221_tmp = vqrdmulhq_n_s16(v220, 13573); + int16x8_t v221 = vaddq_s16(v221_tmp, v220); + int16x8_t v222 = vld1q_s16(in + in_stride * 122 + i); + int16x8_t v223 = vld1q_s16(in + in_stride * 118 + i); + int16x8_t v224 = vaddq_s16(v222, v223); + int16x8_t v225 = vaddq_s16(v224, v205); + int16x8_t v226 = vaddq_s16(v225, v216); + int16x8_t v227 = vaddq_s16(v226, v220); + int16x8_t v228 = vaddq_s16(v221, v227); + int16x8_t v229 = vqrdmulhq_n_s16(v228, 17734); + int16x8_t v230 = vaddq_s16(v219, v229); + int16x8_t v231 = vqrdmulhq_n_s16(v230, 16705); + int16x8_t v232 = vaddq_s16(v213, v231); + int16x8_t v233 = vqrdmulhq_n_s16(v232, 16463); + int16x8_t v234 = vaddq_s16(v182, v233); + int16x8_t v235 = vaddq_s16(v184, v142); + int16x8_t v236_tmp = vqrdmulhq_n_s16(v235, 13573); + int16x8_t v236 = vaddq_s16(v236_tmp, v235); + int16x8_t v237 = vaddq_s16(v188, v144); + int16x8_t v238 = vaddq_s16(v145, v190); + int16x8_t v239 = vaddq_s16(v237, v238); + int16x8_t v240 = vaddq_s16(v236, v239); + int16x8_t v241 = vaddq_s16(v196, v148); + int16x8_t v242 = vaddq_s16(v149, v198); + int16x8_t v243 = vaddq_s16(v241, v242); + int16x8_t v244_tmp = vqrdmulhq_n_s16(v243, 13573); + int16x8_t v244 = vaddq_s16(v244_tmp, v243); + int16x8_t v245 = vaddq_s16(v204, v152); + int16x8_t v246 = vaddq_s16(v153, v206); + int16x8_t v247 = vaddq_s16(v245, v246); + int16x8_t v248 = vaddq_s16(v247, v243); + int16x8_t v249 = vaddq_s16(v244, v248); + int16x8_t v250 = vqrdmulhq_n_s16(v249, 17734); + int16x8_t v251 = vaddq_s16(v240, v250); + int16x8_t v252 = vaddq_s16(v199, v159); + int16x8_t v253 = vaddq_s16(v160, v183); + int16x8_t v254 = vaddq_s16(v252, v253); + int16x8_t v255_tmp = vqrdmulhq_n_s16(v254, 13573); + int16x8_t v255 = vaddq_s16(v255_tmp, v254); + int16x8_t v256 = vaddq_s16(v207, v163); + int16x8_t v257 = vaddq_s16(v164, v187); + int16x8_t v258 = vaddq_s16(v256, v257); + int16x8_t v259 = vaddq_s16(v191, v166); + int16x8_t v260 = vaddq_s16(v167, v195); + int16x8_t v261 = vaddq_s16(v259, v260); + int16x8_t v262 = vaddq_s16(v258, v261); + int16x8_t v263 = vaddq_s16(v255, v262); + int16x8_t v264 = vaddq_s16(v261, v254); + int16x8_t v265_tmp = vqrdmulhq_n_s16(v264, 13573); + int16x8_t v265 = vaddq_s16(v265_tmp, v264); + int16x8_t v266 = vaddq_s16(v223, v173); + int16x8_t v267 = vaddq_s16(v174, v203); + int16x8_t v268 = vaddq_s16(v266, v267); + int16x8_t v269 = vaddq_s16(v268, v258); + int16x8_t v270 = vaddq_s16(v269, v264); + int16x8_t v271 = vaddq_s16(v265, v270); + int16x8_t v272 = vqrdmulhq_n_s16(v271, 17734); + int16x8_t v273 = vaddq_s16(v263, v272); + int16x8_t v274 = vqrdmulhq_n_s16(v273, 16705); + int16x8_t v275 = vaddq_s16(v251, v274); + int16x8_t v276 = vaddq_s16(v253, v235); + int16x8_t v277_tmp = vqrdmulhq_n_s16(v276, 13573); + int16x8_t v277 = vaddq_s16(v277_tmp, v276); + int16x8_t v278 = vaddq_s16(v257, v237); + int16x8_t v279 = vaddq_s16(v238, v259); + int16x8_t v280 = vaddq_s16(v278, v279); + int16x8_t v281 = vaddq_s16(v277, v280); + int16x8_t v282 = vaddq_s16(v260, v241); + int16x8_t v283 = vaddq_s16(v242, v252); + int16x8_t v284 = vaddq_s16(v282, v283); + int16x8_t v285_tmp = vqrdmulhq_n_s16(v284, 13573); + int16x8_t v285 = vaddq_s16(v285_tmp, v284); + int16x8_t v286 = vaddq_s16(v267, v245); + int16x8_t v287 = vaddq_s16(v246, v256); + int16x8_t v288 = vaddq_s16(v286, v287); + int16x8_t v289 = vaddq_s16(v288, v284); + int16x8_t v290 = vaddq_s16(v285, v289); + int16x8_t v291 = vqrdmulhq_n_s16(v290, 17734); + int16x8_t v292 = vaddq_s16(v281, v291); + int16x8_t v293 = vaddq_s16(v283, v276); + int16x8_t v294_tmp = vqrdmulhq_n_s16(v293, 13573); + int16x8_t v294 = vaddq_s16(v294_tmp, v293); + int16x8_t v295 = vaddq_s16(v287, v278); + int16x8_t v296 = vaddq_s16(v279, v282); + int16x8_t v297 = vaddq_s16(v295, v296); + int16x8_t v298 = vaddq_s16(v294, v297); + int16x8_t v299 = vaddq_s16(v296, v293); + int16x8_t v300_tmp = vqrdmulhq_n_s16(v299, 13573); + int16x8_t v300 = vaddq_s16(v300_tmp, v299); + int16x8_t v301 = vld1q_s16(in + in_stride * 126 + i); + int16x8_t v302 = vaddq_s16(v301, v222); + int16x8_t v303 = vaddq_s16(v302, v266); + int16x8_t v304 = vaddq_s16(v303, v286); + int16x8_t v305 = vaddq_s16(v304, v295); + int16x8_t v306 = vaddq_s16(v305, v299); + int16x8_t v307 = vaddq_s16(v300, v306); + int16x8_t v308 = vqrdmulhq_n_s16(v307, 17734); + int16x8_t v309 = vaddq_s16(v298, v308); + int16x8_t v310 = vqrdmulhq_n_s16(v309, 16705); + int16x8_t v311 = vaddq_s16(v292, v310); + int16x8_t v312 = vqrdmulhq_n_s16(v311, 16463); + int16x8_t v313 = vaddq_s16(v275, v312); + int16x8_t v314 = vqrdmulhq_n_s16(v313, 16404); + int16x8_t v315 = vaddq_s16(v234, v314); + int16x8_t v316 = vqrdmulhq_n_s16(v315, 16389); + int16x8_t v317 = vaddq_s16(v141, v316); + int16x8_t v318 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v319_tmp = vqrdmulhq_n_s16(v318, 13573); + int16x8_t v319 = vaddq_s16(v319_tmp, v318); + int16x8_t v320 = vld1q_s16(in + in_stride * 65 + i); + int16x8_t v321 = vld1q_s16(in + in_stride * 63 + i); + int16x8_t v322 = vaddq_s16(v320, v321); + int16x8_t v323 = vaddq_s16(v319, v322); + int16x8_t v324 = vld1q_s16(in + in_stride * 33 + i); + int16x8_t v325 = vld1q_s16(in + in_stride * 31 + i); + int16x8_t v326 = vaddq_s16(v324, v325); + int16x8_t v327_tmp = vqrdmulhq_n_s16(v326, 13573); + int16x8_t v327 = vaddq_s16(v327_tmp, v326); + int16x8_t v328 = vld1q_s16(in + in_stride * 97 + i); + int16x8_t v329 = vld1q_s16(in + in_stride * 95 + i); + int16x8_t v330 = vaddq_s16(v328, v329); + int16x8_t v331 = vaddq_s16(v330, v326); + int16x8_t v332 = vaddq_s16(v327, v331); + int16x8_t v333 = vqrdmulhq_n_s16(v332, 17734); + int16x8_t v334 = vaddq_s16(v323, v333); + int16x8_t v335 = vld1q_s16(in + in_stride * 17 + i); + int16x8_t v336 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v337 = vaddq_s16(v335, v336); + int16x8_t v338_tmp = vqrdmulhq_n_s16(v337, 13573); + int16x8_t v338 = vaddq_s16(v338_tmp, v337); + int16x8_t v339 = vld1q_s16(in + in_stride * 81 + i); + int16x8_t v340 = vld1q_s16(in + in_stride * 79 + i); + int16x8_t v341 = vaddq_s16(v339, v340); + int16x8_t v342 = vld1q_s16(in + in_stride * 49 + i); + int16x8_t v343 = vld1q_s16(in + in_stride * 47 + i); + int16x8_t v344 = vaddq_s16(v342, v343); + int16x8_t v345 = vaddq_s16(v341, v344); + int16x8_t v346 = vaddq_s16(v338, v345); + int16x8_t v347 = vaddq_s16(v344, v337); + int16x8_t v348_tmp = vqrdmulhq_n_s16(v347, 13573); + int16x8_t v348 = vaddq_s16(v348_tmp, v347); + int16x8_t v349 = vld1q_s16(in + in_stride * 113 + i); + int16x8_t v350 = vld1q_s16(in + in_stride * 111 + i); + int16x8_t v351 = vaddq_s16(v349, v350); + int16x8_t v352 = vaddq_s16(v351, v341); + int16x8_t v353 = vaddq_s16(v352, v347); + int16x8_t v354 = vaddq_s16(v348, v353); + int16x8_t v355 = vqrdmulhq_n_s16(v354, 17734); + int16x8_t v356 = vaddq_s16(v346, v355); + int16x8_t v357 = vqrdmulhq_n_s16(v356, 16705); + int16x8_t v358 = vaddq_s16(v334, v357); + int16x8_t v359 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v360 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v361 = vaddq_s16(v359, v360); + int16x8_t v362_tmp = vqrdmulhq_n_s16(v361, 13573); + int16x8_t v362 = vaddq_s16(v362_tmp, v361); + int16x8_t v363 = vld1q_s16(in + in_stride * 73 + i); + int16x8_t v364 = vld1q_s16(in + in_stride * 71 + i); + int16x8_t v365 = vaddq_s16(v363, v364); + int16x8_t v366 = vld1q_s16(in + in_stride * 57 + i); + int16x8_t v367 = vld1q_s16(in + in_stride * 55 + i); + int16x8_t v368 = vaddq_s16(v366, v367); + int16x8_t v369 = vaddq_s16(v365, v368); + int16x8_t v370 = vaddq_s16(v362, v369); + int16x8_t v371 = vld1q_s16(in + in_stride * 41 + i); + int16x8_t v372 = vld1q_s16(in + in_stride * 39 + i); + int16x8_t v373 = vaddq_s16(v371, v372); + int16x8_t v374 = vld1q_s16(in + in_stride * 25 + i); + int16x8_t v375 = vld1q_s16(in + in_stride * 23 + i); + int16x8_t v376 = vaddq_s16(v374, v375); + int16x8_t v377 = vaddq_s16(v373, v376); + int16x8_t v378_tmp = vqrdmulhq_n_s16(v377, 13573); + int16x8_t v378 = vaddq_s16(v378_tmp, v377); + int16x8_t v379 = vld1q_s16(in + in_stride * 105 + i); + int16x8_t v380 = vld1q_s16(in + in_stride * 103 + i); + int16x8_t v381 = vaddq_s16(v379, v380); + int16x8_t v382 = vld1q_s16(in + in_stride * 89 + i); + int16x8_t v383 = vld1q_s16(in + in_stride * 87 + i); + int16x8_t v384 = vaddq_s16(v382, v383); + int16x8_t v385 = vaddq_s16(v381, v384); + int16x8_t v386 = vaddq_s16(v385, v377); + int16x8_t v387 = vaddq_s16(v378, v386); + int16x8_t v388 = vqrdmulhq_n_s16(v387, 17734); + int16x8_t v389 = vaddq_s16(v370, v388); + int16x8_t v390 = vaddq_s16(v376, v361); + int16x8_t v391_tmp = vqrdmulhq_n_s16(v390, 13573); + int16x8_t v391 = vaddq_s16(v391_tmp, v390); + int16x8_t v392 = vaddq_s16(v384, v365); + int16x8_t v393 = vaddq_s16(v368, v373); + int16x8_t v394 = vaddq_s16(v392, v393); + int16x8_t v395 = vaddq_s16(v391, v394); + int16x8_t v396 = vaddq_s16(v393, v390); + int16x8_t v397_tmp = vqrdmulhq_n_s16(v396, 13573); + int16x8_t v397 = vaddq_s16(v397_tmp, v396); + int16x8_t v398 = vld1q_s16(in + in_stride * 121 + i); + int16x8_t v399 = vld1q_s16(in + in_stride * 119 + i); + int16x8_t v400 = vaddq_s16(v398, v399); + int16x8_t v401 = vaddq_s16(v400, v381); + int16x8_t v402 = vaddq_s16(v401, v392); + int16x8_t v403 = vaddq_s16(v402, v396); + int16x8_t v404 = vaddq_s16(v397, v403); + int16x8_t v405 = vqrdmulhq_n_s16(v404, 17734); + int16x8_t v406 = vaddq_s16(v395, v405); + int16x8_t v407 = vqrdmulhq_n_s16(v406, 16705); + int16x8_t v408 = vaddq_s16(v389, v407); + int16x8_t v409 = vqrdmulhq_n_s16(v408, 16463); + int16x8_t v410 = vaddq_s16(v358, v409); + int16x8_t v411 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v412 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v413 = vaddq_s16(v411, v412); + int16x8_t v414_tmp = vqrdmulhq_n_s16(v413, 13573); + int16x8_t v414 = vaddq_s16(v414_tmp, v413); + int16x8_t v415 = vld1q_s16(in + in_stride * 69 + i); + int16x8_t v416 = vld1q_s16(in + in_stride * 67 + i); + int16x8_t v417 = vaddq_s16(v415, v416); + int16x8_t v418 = vld1q_s16(in + in_stride * 61 + i); + int16x8_t v419 = vld1q_s16(in + in_stride * 59 + i); + int16x8_t v420 = vaddq_s16(v418, v419); + int16x8_t v421 = vaddq_s16(v417, v420); + int16x8_t v422 = vaddq_s16(v414, v421); + int16x8_t v423 = vld1q_s16(in + in_stride * 37 + i); + int16x8_t v424 = vld1q_s16(in + in_stride * 35 + i); + int16x8_t v425 = vaddq_s16(v423, v424); + int16x8_t v426 = vld1q_s16(in + in_stride * 29 + i); + int16x8_t v427 = vld1q_s16(in + in_stride * 27 + i); + int16x8_t v428 = vaddq_s16(v426, v427); + int16x8_t v429 = vaddq_s16(v425, v428); + int16x8_t v430_tmp = vqrdmulhq_n_s16(v429, 13573); + int16x8_t v430 = vaddq_s16(v430_tmp, v429); + int16x8_t v431 = vld1q_s16(in + in_stride * 101 + i); + int16x8_t v432 = vld1q_s16(in + in_stride * 99 + i); + int16x8_t v433 = vaddq_s16(v431, v432); + int16x8_t v434 = vld1q_s16(in + in_stride * 93 + i); + int16x8_t v435 = vld1q_s16(in + in_stride * 91 + i); + int16x8_t v436 = vaddq_s16(v434, v435); + int16x8_t v437 = vaddq_s16(v433, v436); + int16x8_t v438 = vaddq_s16(v437, v429); + int16x8_t v439 = vaddq_s16(v430, v438); + int16x8_t v440 = vqrdmulhq_n_s16(v439, 17734); + int16x8_t v441 = vaddq_s16(v422, v440); + int16x8_t v442 = vld1q_s16(in + in_stride * 21 + i); + int16x8_t v443 = vld1q_s16(in + in_stride * 19 + i); + int16x8_t v444 = vaddq_s16(v442, v443); + int16x8_t v445 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v446 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v447 = vaddq_s16(v445, v446); + int16x8_t v448 = vaddq_s16(v444, v447); + int16x8_t v449_tmp = vqrdmulhq_n_s16(v448, 13573); + int16x8_t v449 = vaddq_s16(v449_tmp, v448); + int16x8_t v450 = vld1q_s16(in + in_stride * 85 + i); + int16x8_t v451 = vld1q_s16(in + in_stride * 83 + i); + int16x8_t v452 = vaddq_s16(v450, v451); + int16x8_t v453 = vld1q_s16(in + in_stride * 77 + i); + int16x8_t v454 = vld1q_s16(in + in_stride * 75 + i); + int16x8_t v455 = vaddq_s16(v453, v454); + int16x8_t v456 = vaddq_s16(v452, v455); + int16x8_t v457 = vld1q_s16(in + in_stride * 53 + i); + int16x8_t v458 = vld1q_s16(in + in_stride * 51 + i); + int16x8_t v459 = vaddq_s16(v457, v458); + int16x8_t v460 = vld1q_s16(in + in_stride * 45 + i); + int16x8_t v461 = vld1q_s16(in + in_stride * 43 + i); + int16x8_t v462 = vaddq_s16(v460, v461); + int16x8_t v463 = vaddq_s16(v459, v462); + int16x8_t v464 = vaddq_s16(v456, v463); + int16x8_t v465 = vaddq_s16(v449, v464); + int16x8_t v466 = vaddq_s16(v463, v448); + int16x8_t v467_tmp = vqrdmulhq_n_s16(v466, 13573); + int16x8_t v467 = vaddq_s16(v467_tmp, v466); + int16x8_t v468 = vld1q_s16(in + in_stride * 117 + i); + int16x8_t v469 = vld1q_s16(in + in_stride * 115 + i); + int16x8_t v470 = vaddq_s16(v468, v469); + int16x8_t v471 = vld1q_s16(in + in_stride * 109 + i); + int16x8_t v472 = vld1q_s16(in + in_stride * 107 + i); + int16x8_t v473 = vaddq_s16(v471, v472); + int16x8_t v474 = vaddq_s16(v470, v473); + int16x8_t v475 = vaddq_s16(v474, v456); + int16x8_t v476 = vaddq_s16(v475, v466); + int16x8_t v477 = vaddq_s16(v467, v476); + int16x8_t v478 = vqrdmulhq_n_s16(v477, 17734); + int16x8_t v479 = vaddq_s16(v465, v478); + int16x8_t v480 = vqrdmulhq_n_s16(v479, 16705); + int16x8_t v481 = vaddq_s16(v441, v480); + int16x8_t v482 = vaddq_s16(v447, v413); + int16x8_t v483_tmp = vqrdmulhq_n_s16(v482, 13573); + int16x8_t v483 = vaddq_s16(v483_tmp, v482); + int16x8_t v484 = vaddq_s16(v455, v417); + int16x8_t v485 = vaddq_s16(v420, v459); + int16x8_t v486 = vaddq_s16(v484, v485); + int16x8_t v487 = vaddq_s16(v483, v486); + int16x8_t v488 = vaddq_s16(v462, v425); + int16x8_t v489 = vaddq_s16(v428, v444); + int16x8_t v490 = vaddq_s16(v488, v489); + int16x8_t v491_tmp = vqrdmulhq_n_s16(v490, 13573); + int16x8_t v491 = vaddq_s16(v491_tmp, v490); + int16x8_t v492 = vaddq_s16(v473, v433); + int16x8_t v493 = vaddq_s16(v436, v452); + int16x8_t v494 = vaddq_s16(v492, v493); + int16x8_t v495 = vaddq_s16(v494, v490); + int16x8_t v496 = vaddq_s16(v491, v495); + int16x8_t v497 = vqrdmulhq_n_s16(v496, 17734); + int16x8_t v498 = vaddq_s16(v487, v497); + int16x8_t v499 = vaddq_s16(v489, v482); + int16x8_t v500_tmp = vqrdmulhq_n_s16(v499, 13573); + int16x8_t v500 = vaddq_s16(v500_tmp, v499); + int16x8_t v501 = vaddq_s16(v493, v484); + int16x8_t v502 = vaddq_s16(v485, v488); + int16x8_t v503 = vaddq_s16(v501, v502); + int16x8_t v504 = vaddq_s16(v500, v503); + int16x8_t v505 = vaddq_s16(v502, v499); + int16x8_t v506_tmp = vqrdmulhq_n_s16(v505, 13573); + int16x8_t v506 = vaddq_s16(v506_tmp, v505); + int16x8_t v507 = vld1q_s16(in + in_stride * 125 + i); + int16x8_t v508 = vld1q_s16(in + in_stride * 123 + i); + int16x8_t v509 = vaddq_s16(v507, v508); + int16x8_t v510 = vaddq_s16(v509, v470); + int16x8_t v511 = vaddq_s16(v510, v492); + int16x8_t v512 = vaddq_s16(v511, v501); + int16x8_t v513 = vaddq_s16(v512, v505); + int16x8_t v514 = vaddq_s16(v506, v513); + int16x8_t v515 = vqrdmulhq_n_s16(v514, 17734); + int16x8_t v516 = vaddq_s16(v504, v515); + int16x8_t v517 = vqrdmulhq_n_s16(v516, 16705); + int16x8_t v518 = vaddq_s16(v498, v517); + int16x8_t v519 = vqrdmulhq_n_s16(v518, 16463); + int16x8_t v520 = vaddq_s16(v481, v519); + int16x8_t v521 = vqrdmulhq_n_s16(v520, 16404); + int16x8_t v522 = vaddq_s16(v410, v521); + int16x8_t v523 = vaddq_s16(v412, v318); + int16x8_t v524_tmp = vqrdmulhq_n_s16(v523, 13573); + int16x8_t v524 = vaddq_s16(v524_tmp, v523); + int16x8_t v525 = vaddq_s16(v416, v320); + int16x8_t v526 = vaddq_s16(v321, v418); + int16x8_t v527 = vaddq_s16(v525, v526); + int16x8_t v528 = vaddq_s16(v524, v527); + int16x8_t v529 = vaddq_s16(v424, v324); + int16x8_t v530 = vaddq_s16(v325, v426); + int16x8_t v531 = vaddq_s16(v529, v530); + int16x8_t v532_tmp = vqrdmulhq_n_s16(v531, 13573); + int16x8_t v532 = vaddq_s16(v532_tmp, v531); + int16x8_t v533 = vaddq_s16(v432, v328); + int16x8_t v534 = vaddq_s16(v329, v434); + int16x8_t v535 = vaddq_s16(v533, v534); + int16x8_t v536 = vaddq_s16(v535, v531); + int16x8_t v537 = vaddq_s16(v532, v536); + int16x8_t v538 = vqrdmulhq_n_s16(v537, 17734); + int16x8_t v539 = vaddq_s16(v528, v538); + int16x8_t v540 = vaddq_s16(v443, v335); + int16x8_t v541 = vaddq_s16(v336, v445); + int16x8_t v542 = vaddq_s16(v540, v541); + int16x8_t v543_tmp = vqrdmulhq_n_s16(v542, 13573); + int16x8_t v543 = vaddq_s16(v543_tmp, v542); + int16x8_t v544 = vaddq_s16(v451, v339); + int16x8_t v545 = vaddq_s16(v340, v453); + int16x8_t v546 = vaddq_s16(v544, v545); + int16x8_t v547 = vaddq_s16(v458, v342); + int16x8_t v548 = vaddq_s16(v343, v460); + int16x8_t v549 = vaddq_s16(v547, v548); + int16x8_t v550 = vaddq_s16(v546, v549); + int16x8_t v551 = vaddq_s16(v543, v550); + int16x8_t v552 = vaddq_s16(v549, v542); + int16x8_t v553_tmp = vqrdmulhq_n_s16(v552, 13573); + int16x8_t v553 = vaddq_s16(v553_tmp, v552); + int16x8_t v554 = vaddq_s16(v469, v349); + int16x8_t v555 = vaddq_s16(v350, v471); + int16x8_t v556 = vaddq_s16(v554, v555); + int16x8_t v557 = vaddq_s16(v556, v546); + int16x8_t v558 = vaddq_s16(v557, v552); + int16x8_t v559 = vaddq_s16(v553, v558); + int16x8_t v560 = vqrdmulhq_n_s16(v559, 17734); + int16x8_t v561 = vaddq_s16(v551, v560); + int16x8_t v562 = vqrdmulhq_n_s16(v561, 16705); + int16x8_t v563 = vaddq_s16(v539, v562); + int16x8_t v564 = vaddq_s16(v446, v359); + int16x8_t v565 = vaddq_s16(v360, v411); + int16x8_t v566 = vaddq_s16(v564, v565); + int16x8_t v567_tmp = vqrdmulhq_n_s16(v566, 13573); + int16x8_t v567 = vaddq_s16(v567_tmp, v566); + int16x8_t v568 = vaddq_s16(v454, v363); + int16x8_t v569 = vaddq_s16(v364, v415); + int16x8_t v570 = vaddq_s16(v568, v569); + int16x8_t v571 = vaddq_s16(v419, v366); + int16x8_t v572 = vaddq_s16(v367, v457); + int16x8_t v573 = vaddq_s16(v571, v572); + int16x8_t v574 = vaddq_s16(v570, v573); + int16x8_t v575 = vaddq_s16(v567, v574); + int16x8_t v576 = vaddq_s16(v461, v371); + int16x8_t v577 = vaddq_s16(v372, v423); + int16x8_t v578 = vaddq_s16(v576, v577); + int16x8_t v579 = vaddq_s16(v427, v374); + int16x8_t v580 = vaddq_s16(v375, v442); + int16x8_t v581 = vaddq_s16(v579, v580); + int16x8_t v582 = vaddq_s16(v578, v581); + int16x8_t v583_tmp = vqrdmulhq_n_s16(v582, 13573); + int16x8_t v583 = vaddq_s16(v583_tmp, v582); + int16x8_t v584 = vaddq_s16(v472, v379); + int16x8_t v585 = vaddq_s16(v380, v431); + int16x8_t v586 = vaddq_s16(v584, v585); + int16x8_t v587 = vaddq_s16(v435, v382); + int16x8_t v588 = vaddq_s16(v383, v450); + int16x8_t v589 = vaddq_s16(v587, v588); + int16x8_t v590 = vaddq_s16(v586, v589); + int16x8_t v591 = vaddq_s16(v590, v582); + int16x8_t v592 = vaddq_s16(v583, v591); + int16x8_t v593 = vqrdmulhq_n_s16(v592, 17734); + int16x8_t v594 = vaddq_s16(v575, v593); + int16x8_t v595 = vaddq_s16(v581, v566); + int16x8_t v596_tmp = vqrdmulhq_n_s16(v595, 13573); + int16x8_t v596 = vaddq_s16(v596_tmp, v595); + int16x8_t v597 = vaddq_s16(v589, v570); + int16x8_t v598 = vaddq_s16(v573, v578); + int16x8_t v599 = vaddq_s16(v597, v598); + int16x8_t v600 = vaddq_s16(v596, v599); + int16x8_t v601 = vaddq_s16(v598, v595); + int16x8_t v602_tmp = vqrdmulhq_n_s16(v601, 13573); + int16x8_t v602 = vaddq_s16(v602_tmp, v601); + int16x8_t v603 = vaddq_s16(v508, v398); + int16x8_t v604 = vaddq_s16(v399, v468); + int16x8_t v605 = vaddq_s16(v603, v604); + int16x8_t v606 = vaddq_s16(v605, v586); + int16x8_t v607 = vaddq_s16(v606, v597); + int16x8_t v608 = vaddq_s16(v607, v601); + int16x8_t v609 = vaddq_s16(v602, v608); + int16x8_t v610 = vqrdmulhq_n_s16(v609, 17734); + int16x8_t v611 = vaddq_s16(v600, v610); + int16x8_t v612 = vqrdmulhq_n_s16(v611, 16705); + int16x8_t v613 = vaddq_s16(v594, v612); + int16x8_t v614 = vqrdmulhq_n_s16(v613, 16463); + int16x8_t v615 = vaddq_s16(v563, v614); + int16x8_t v616 = vaddq_s16(v565, v523); + int16x8_t v617_tmp = vqrdmulhq_n_s16(v616, 13573); + int16x8_t v617 = vaddq_s16(v617_tmp, v616); + int16x8_t v618 = vaddq_s16(v569, v525); + int16x8_t v619 = vaddq_s16(v526, v571); + int16x8_t v620 = vaddq_s16(v618, v619); + int16x8_t v621 = vaddq_s16(v617, v620); + int16x8_t v622 = vaddq_s16(v577, v529); + int16x8_t v623 = vaddq_s16(v530, v579); + int16x8_t v624 = vaddq_s16(v622, v623); + int16x8_t v625_tmp = vqrdmulhq_n_s16(v624, 13573); + int16x8_t v625 = vaddq_s16(v625_tmp, v624); + int16x8_t v626 = vaddq_s16(v585, v533); + int16x8_t v627 = vaddq_s16(v534, v587); + int16x8_t v628 = vaddq_s16(v626, v627); + int16x8_t v629 = vaddq_s16(v628, v624); + int16x8_t v630 = vaddq_s16(v625, v629); + int16x8_t v631 = vqrdmulhq_n_s16(v630, 17734); + int16x8_t v632 = vaddq_s16(v621, v631); + int16x8_t v633 = vaddq_s16(v580, v540); + int16x8_t v634 = vaddq_s16(v541, v564); + int16x8_t v635 = vaddq_s16(v633, v634); + int16x8_t v636_tmp = vqrdmulhq_n_s16(v635, 13573); + int16x8_t v636 = vaddq_s16(v636_tmp, v635); + int16x8_t v637 = vaddq_s16(v588, v544); + int16x8_t v638 = vaddq_s16(v545, v568); + int16x8_t v639 = vaddq_s16(v637, v638); + int16x8_t v640 = vaddq_s16(v572, v547); + int16x8_t v641 = vaddq_s16(v548, v576); + int16x8_t v642 = vaddq_s16(v640, v641); + int16x8_t v643 = vaddq_s16(v639, v642); + int16x8_t v644 = vaddq_s16(v636, v643); + int16x8_t v645 = vaddq_s16(v642, v635); + int16x8_t v646_tmp = vqrdmulhq_n_s16(v645, 13573); + int16x8_t v646 = vaddq_s16(v646_tmp, v645); + int16x8_t v647 = vaddq_s16(v604, v554); + int16x8_t v648 = vaddq_s16(v555, v584); + int16x8_t v649 = vaddq_s16(v647, v648); + int16x8_t v650 = vaddq_s16(v649, v639); + int16x8_t v651 = vaddq_s16(v650, v645); + int16x8_t v652 = vaddq_s16(v646, v651); + int16x8_t v653 = vqrdmulhq_n_s16(v652, 17734); + int16x8_t v654 = vaddq_s16(v644, v653); + int16x8_t v655 = vqrdmulhq_n_s16(v654, 16705); + int16x8_t v656 = vaddq_s16(v632, v655); + int16x8_t v657 = vaddq_s16(v634, v616); + int16x8_t v658_tmp = vqrdmulhq_n_s16(v657, 13573); + int16x8_t v658 = vaddq_s16(v658_tmp, v657); + int16x8_t v659 = vaddq_s16(v638, v618); + int16x8_t v660 = vaddq_s16(v619, v640); + int16x8_t v661 = vaddq_s16(v659, v660); + int16x8_t v662 = vaddq_s16(v658, v661); + int16x8_t v663 = vaddq_s16(v641, v622); + int16x8_t v664 = vaddq_s16(v623, v633); + int16x8_t v665 = vaddq_s16(v663, v664); + int16x8_t v666_tmp = vqrdmulhq_n_s16(v665, 13573); + int16x8_t v666 = vaddq_s16(v666_tmp, v665); + int16x8_t v667 = vaddq_s16(v648, v626); + int16x8_t v668 = vaddq_s16(v627, v637); + int16x8_t v669 = vaddq_s16(v667, v668); + int16x8_t v670 = vaddq_s16(v669, v665); + int16x8_t v671 = vaddq_s16(v666, v670); + int16x8_t v672 = vqrdmulhq_n_s16(v671, 17734); + int16x8_t v673 = vaddq_s16(v662, v672); + int16x8_t v674 = vaddq_s16(v664, v657); + int16x8_t v675_tmp = vqrdmulhq_n_s16(v674, 13573); + int16x8_t v675 = vaddq_s16(v675_tmp, v674); + int16x8_t v676 = vaddq_s16(v668, v659); + int16x8_t v677 = vaddq_s16(v660, v663); + int16x8_t v678 = vaddq_s16(v676, v677); + int16x8_t v679 = vaddq_s16(v675, v678); + int16x8_t v680 = vaddq_s16(v677, v674); + int16x8_t v681_tmp = vqrdmulhq_n_s16(v680, 13573); + int16x8_t v681 = vaddq_s16(v681_tmp, v680); + int16x8_t v682 = vld1q_s16(in + in_stride * 127 + i); + int16x8_t v683 = vaddq_s16(v682, v507); + int16x8_t v684 = vaddq_s16(v683, v603); + int16x8_t v685 = vaddq_s16(v684, v647); + int16x8_t v686 = vaddq_s16(v685, v667); + int16x8_t v687 = vaddq_s16(v686, v676); + int16x8_t v688 = vaddq_s16(v687, v680); + int16x8_t v689 = vaddq_s16(v681, v688); + int16x8_t v690 = vqrdmulhq_n_s16(v689, 17734); + int16x8_t v691 = vaddq_s16(v679, v690); + int16x8_t v692 = vqrdmulhq_n_s16(v691, 16705); + int16x8_t v693 = vaddq_s16(v673, v692); + int16x8_t v694 = vqrdmulhq_n_s16(v693, 16463); + int16x8_t v695 = vaddq_s16(v656, v694); + int16x8_t v696 = vqrdmulhq_n_s16(v695, 16404); + int16x8_t v697 = vaddq_s16(v615, v696); + int16x8_t v698 = vqrdmulhq_n_s16(v697, 16389); + int16x8_t v699 = vaddq_s16(v522, v698); + int16x8_t v700 = vqrdmulhq_n_s16(v699, 16385); + int16x8_t v701 = vaddq_s16(v317, v700); + int16x8_t v702 = vsubq_s16(v0, v1); + int16x8_t v703 = vsubq_s16(v4, v6); + int16x8_t v704_tmp = vqrdmulhq_n_s16(v703, 10045); + int16x8_t v704 = vaddq_s16(v704_tmp, v703); + int16x8_t v705 = vaddq_s16(v702, v704); + int16x8_t v706 = vsubq_s16(v11, v14); + int16x8_t v707 = vsubq_s16(v17, v20); + int16x8_t v708_tmp = vqrdmulhq_n_s16(v707, 10045); + int16x8_t v708 = vaddq_s16(v708_tmp, v707); + int16x8_t v709 = vaddq_s16(v706, v708); + int16x8_t v710 = vqrdmulhq_n_s16(v709, 19705); + int16x8_t v711 = vaddq_s16(v705, v710); + int16x8_t v712 = vsubq_s16(v27, v30); + int16x8_t v713 = vsubq_s16(v35, v39); + int16x8_t v714_tmp = vqrdmulhq_n_s16(v713, 10045); + int16x8_t v714 = vaddq_s16(v714_tmp, v713); + int16x8_t v715 = vaddq_s16(v712, v714); + int16x8_t v716 = vsubq_s16(v44, v47); + int16x8_t v717 = vsubq_s16(v50, v54); + int16x8_t v718_tmp = vqrdmulhq_n_s16(v717, 10045); + int16x8_t v718 = vaddq_s16(v718_tmp, v717); + int16x8_t v719 = vaddq_s16(v716, v718); + int16x8_t v720 = vqrdmulhq_n_s16(v719, 19705); + int16x8_t v721 = vaddq_s16(v715, v720); + int16x8_t v722 = vqrdmulhq_n_s16(v721, 17121); + int16x8_t v723 = vaddq_s16(v711, v722); + int16x8_t v724 = vsubq_s16(v63, v66); + int16x8_t v725 = vsubq_s16(v71, v75); + int16x8_t v726_tmp = vqrdmulhq_n_s16(v725, 10045); + int16x8_t v726 = vaddq_s16(v726_tmp, v725); + int16x8_t v727 = vaddq_s16(v724, v726); + int16x8_t v728 = vsubq_s16(v82, v89); + int16x8_t v729 = vsubq_s16(v92, v97); + int16x8_t v730_tmp = vqrdmulhq_n_s16(v729, 10045); + int16x8_t v730 = vaddq_s16(v730_tmp, v729); + int16x8_t v731 = vaddq_s16(v728, v730); + int16x8_t v732 = vqrdmulhq_n_s16(v731, 19705); + int16x8_t v733 = vaddq_s16(v727, v732); + int16x8_t v734 = vsubq_s16(v104, v107); + int16x8_t v735 = vsubq_s16(v112, v116); + int16x8_t v736_tmp = vqrdmulhq_n_s16(v735, 10045); + int16x8_t v736 = vaddq_s16(v736_tmp, v735); + int16x8_t v737 = vaddq_s16(v734, v736); + int16x8_t v738 = vsubq_s16(v121, v124); + int16x8_t v739 = vsubq_s16(v127, v132); + int16x8_t v740_tmp = vqrdmulhq_n_s16(v739, 10045); + int16x8_t v740 = vaddq_s16(v740_tmp, v739); + int16x8_t v741 = vaddq_s16(v738, v740); + int16x8_t v742 = vqrdmulhq_n_s16(v741, 19705); + int16x8_t v743 = vaddq_s16(v737, v742); + int16x8_t v744 = vqrdmulhq_n_s16(v743, 17121); + int16x8_t v745 = vaddq_s16(v733, v744); + int16x8_t v746 = vqrdmulhq_n_s16(v745, 16563); + int16x8_t v747 = vaddq_s16(v723, v746); + int16x8_t v748 = vsubq_s16(v143, v146); + int16x8_t v749 = vsubq_s16(v151, v155); + int16x8_t v750_tmp = vqrdmulhq_n_s16(v749, 10045); + int16x8_t v750 = vaddq_s16(v750_tmp, v749); + int16x8_t v751 = vaddq_s16(v748, v750); + int16x8_t v752 = vsubq_s16(v162, v169); + int16x8_t v753 = vqrdmulhq_n_s16(v752, 19705); + int16x8_t v754 = vsubq_s16(v172, v177); + int16x8_t v755 = vqrdmulhq_n_s16(v754, 25746); + int16x8_t v756 = vaddq_s16(v753, v755); + int16x8_t v757 = vaddq_s16(v751, v756); + int16x8_t v758 = vsubq_s16(v186, v193); + int16x8_t v759 = vsubq_s16(v202, v210); + int16x8_t v760_tmp = vqrdmulhq_n_s16(v759, 10045); + int16x8_t v760 = vaddq_s16(v760_tmp, v759); + int16x8_t v761 = vaddq_s16(v758, v760); + int16x8_t v762 = vsubq_s16(v215, v218); + int16x8_t v763 = vsubq_s16(v221, v227); + int16x8_t v764_tmp = vqrdmulhq_n_s16(v763, 10045); + int16x8_t v764 = vaddq_s16(v764_tmp, v763); + int16x8_t v765 = vaddq_s16(v762, v764); + int16x8_t v766 = vqrdmulhq_n_s16(v765, 19705); + int16x8_t v767 = vaddq_s16(v761, v766); + int16x8_t v768 = vqrdmulhq_n_s16(v767, 17121); + int16x8_t v769 = vaddq_s16(v757, v768); + int16x8_t v770 = vsubq_s16(v236, v239); + int16x8_t v771 = vsubq_s16(v244, v248); + int16x8_t v772_tmp = vqrdmulhq_n_s16(v771, 10045); + int16x8_t v772 = vaddq_s16(v772_tmp, v771); + int16x8_t v773 = vaddq_s16(v770, v772); + int16x8_t v774 = vsubq_s16(v255, v262); + int16x8_t v775 = vsubq_s16(v265, v270); + int16x8_t v776_tmp = vqrdmulhq_n_s16(v775, 10045); + int16x8_t v776 = vaddq_s16(v776_tmp, v775); + int16x8_t v777 = vaddq_s16(v774, v776); + int16x8_t v778 = vqrdmulhq_n_s16(v777, 19705); + int16x8_t v779 = vaddq_s16(v773, v778); + int16x8_t v780 = vsubq_s16(v277, v280); + int16x8_t v781 = vsubq_s16(v285, v289); + int16x8_t v782_tmp = vqrdmulhq_n_s16(v781, 10045); + int16x8_t v782 = vaddq_s16(v782_tmp, v781); + int16x8_t v783 = vaddq_s16(v780, v782); + int16x8_t v784 = vsubq_s16(v294, v297); + int16x8_t v785 = vsubq_s16(v300, v306); + int16x8_t v786_tmp = vqrdmulhq_n_s16(v785, 10045); + int16x8_t v786 = vaddq_s16(v786_tmp, v785); + int16x8_t v787 = vaddq_s16(v784, v786); + int16x8_t v788 = vqrdmulhq_n_s16(v787, 19705); + int16x8_t v789 = vaddq_s16(v783, v788); + int16x8_t v790 = vqrdmulhq_n_s16(v789, 17121); + int16x8_t v791 = vaddq_s16(v779, v790); + int16x8_t v792 = vqrdmulhq_n_s16(v791, 16563); + int16x8_t v793 = vaddq_s16(v769, v792); + int16x8_t v794 = vqrdmulhq_n_s16(v793, 16429); + int16x8_t v795 = vaddq_s16(v747, v794); + int16x8_t v796 = vsubq_s16(v319, v322); + int16x8_t v797 = vsubq_s16(v327, v331); + int16x8_t v798_tmp = vqrdmulhq_n_s16(v797, 10045); + int16x8_t v798 = vaddq_s16(v798_tmp, v797); + int16x8_t v799 = vaddq_s16(v796, v798); + int16x8_t v800 = vsubq_s16(v338, v345); + int16x8_t v801 = vsubq_s16(v348, v353); + int16x8_t v802_tmp = vqrdmulhq_n_s16(v801, 10045); + int16x8_t v802 = vaddq_s16(v802_tmp, v801); + int16x8_t v803 = vaddq_s16(v800, v802); + int16x8_t v804 = vqrdmulhq_n_s16(v803, 19705); + int16x8_t v805 = vaddq_s16(v799, v804); + int16x8_t v806 = vsubq_s16(v362, v369); + int16x8_t v807 = vsubq_s16(v378, v386); + int16x8_t v808_tmp = vqrdmulhq_n_s16(v807, 10045); + int16x8_t v808 = vaddq_s16(v808_tmp, v807); + int16x8_t v809 = vaddq_s16(v806, v808); + int16x8_t v810 = vsubq_s16(v391, v394); + int16x8_t v811 = vsubq_s16(v397, v403); + int16x8_t v812_tmp = vqrdmulhq_n_s16(v811, 10045); + int16x8_t v812 = vaddq_s16(v812_tmp, v811); + int16x8_t v813 = vaddq_s16(v810, v812); + int16x8_t v814 = vqrdmulhq_n_s16(v813, 19705); + int16x8_t v815 = vaddq_s16(v809, v814); + int16x8_t v816 = vqrdmulhq_n_s16(v815, 17121); + int16x8_t v817 = vaddq_s16(v805, v816); + int16x8_t v818 = vsubq_s16(v414, v421); + int16x8_t v819 = vsubq_s16(v430, v438); + int16x8_t v820_tmp = vqrdmulhq_n_s16(v819, 10045); + int16x8_t v820 = vaddq_s16(v820_tmp, v819); + int16x8_t v821 = vaddq_s16(v818, v820); + int16x8_t v822 = vsubq_s16(v449, v464); + int16x8_t v823 = vsubq_s16(v467, v476); + int16x8_t v824_tmp = vqrdmulhq_n_s16(v823, 10045); + int16x8_t v824 = vaddq_s16(v824_tmp, v823); + int16x8_t v825 = vaddq_s16(v822, v824); + int16x8_t v826 = vqrdmulhq_n_s16(v825, 19705); + int16x8_t v827 = vaddq_s16(v821, v826); + int16x8_t v828 = vsubq_s16(v483, v486); + int16x8_t v829 = vsubq_s16(v491, v495); + int16x8_t v830_tmp = vqrdmulhq_n_s16(v829, 10045); + int16x8_t v830 = vaddq_s16(v830_tmp, v829); + int16x8_t v831 = vaddq_s16(v828, v830); + int16x8_t v832 = vsubq_s16(v500, v503); + int16x8_t v833 = vsubq_s16(v506, v513); + int16x8_t v834_tmp = vqrdmulhq_n_s16(v833, 10045); + int16x8_t v834 = vaddq_s16(v834_tmp, v833); + int16x8_t v835 = vaddq_s16(v832, v834); + int16x8_t v836 = vqrdmulhq_n_s16(v835, 19705); + int16x8_t v837 = vaddq_s16(v831, v836); + int16x8_t v838 = vqrdmulhq_n_s16(v837, 17121); + int16x8_t v839 = vaddq_s16(v827, v838); + int16x8_t v840 = vqrdmulhq_n_s16(v839, 16563); + int16x8_t v841 = vaddq_s16(v817, v840); + int16x8_t v842 = vsubq_s16(v524, v527); + int16x8_t v843 = vsubq_s16(v532, v536); + int16x8_t v844_tmp = vqrdmulhq_n_s16(v843, 10045); + int16x8_t v844 = vaddq_s16(v844_tmp, v843); + int16x8_t v845 = vaddq_s16(v842, v844); + int16x8_t v846 = vsubq_s16(v543, v550); + int16x8_t v847 = vsubq_s16(v553, v558); + int16x8_t v848_tmp = vqrdmulhq_n_s16(v847, 10045); + int16x8_t v848 = vaddq_s16(v848_tmp, v847); + int16x8_t v849 = vaddq_s16(v846, v848); + int16x8_t v850 = vqrdmulhq_n_s16(v849, 19705); + int16x8_t v851 = vaddq_s16(v845, v850); + int16x8_t v852 = vsubq_s16(v567, v574); + int16x8_t v853 = vsubq_s16(v583, v591); + int16x8_t v854_tmp = vqrdmulhq_n_s16(v853, 10045); + int16x8_t v854 = vaddq_s16(v854_tmp, v853); + int16x8_t v855 = vaddq_s16(v852, v854); + int16x8_t v856 = vsubq_s16(v596, v599); + int16x8_t v857 = vsubq_s16(v602, v608); + int16x8_t v858_tmp = vqrdmulhq_n_s16(v857, 10045); + int16x8_t v858 = vaddq_s16(v858_tmp, v857); + int16x8_t v859 = vaddq_s16(v856, v858); + int16x8_t v860 = vqrdmulhq_n_s16(v859, 19705); + int16x8_t v861 = vaddq_s16(v855, v860); + int16x8_t v862 = vqrdmulhq_n_s16(v861, 17121); + int16x8_t v863 = vaddq_s16(v851, v862); + int16x8_t v864 = vsubq_s16(v617, v620); + int16x8_t v865 = vsubq_s16(v625, v629); + int16x8_t v866_tmp = vqrdmulhq_n_s16(v865, 10045); + int16x8_t v866 = vaddq_s16(v866_tmp, v865); + int16x8_t v867 = vaddq_s16(v864, v866); + int16x8_t v868 = vsubq_s16(v636, v643); + int16x8_t v869 = vsubq_s16(v646, v651); + int16x8_t v870_tmp = vqrdmulhq_n_s16(v869, 10045); + int16x8_t v870 = vaddq_s16(v870_tmp, v869); + int16x8_t v871 = vaddq_s16(v868, v870); + int16x8_t v872 = vqrdmulhq_n_s16(v871, 19705); + int16x8_t v873 = vaddq_s16(v867, v872); + int16x8_t v874 = vsubq_s16(v658, v661); + int16x8_t v875 = vsubq_s16(v666, v670); + int16x8_t v876_tmp = vqrdmulhq_n_s16(v875, 10045); + int16x8_t v876 = vaddq_s16(v876_tmp, v875); + int16x8_t v877 = vaddq_s16(v874, v876); + int16x8_t v878 = vsubq_s16(v675, v678); + int16x8_t v879 = vsubq_s16(v681, v688); + int16x8_t v880_tmp = vqrdmulhq_n_s16(v879, 10045); + int16x8_t v880 = vaddq_s16(v880_tmp, v879); + int16x8_t v881 = vaddq_s16(v878, v880); + int16x8_t v882 = vqrdmulhq_n_s16(v881, 19705); + int16x8_t v883 = vaddq_s16(v877, v882); + int16x8_t v884 = vqrdmulhq_n_s16(v883, 17121); + int16x8_t v885 = vaddq_s16(v873, v884); + int16x8_t v886 = vqrdmulhq_n_s16(v885, 16563); + int16x8_t v887 = vaddq_s16(v863, v886); + int16x8_t v888 = vqrdmulhq_n_s16(v887, 16429); + int16x8_t v889 = vaddq_s16(v841, v888); + int16x8_t v890 = vqrdmulhq_n_s16(v889, 16395); + int16x8_t v891 = vaddq_s16(v795, v890); + int16x8_t v892 = vsubq_s16(v702, v704); + int16x8_t v893 = vsubq_s16(v706, v708); + int16x8_t v894 = vqrdmulhq_n_s16(v893, 29490); + int16x8_t v895 = vaddq_s16(v892, v894); + int16x8_t v896 = vsubq_s16(v712, v714); + int16x8_t v897 = vsubq_s16(v716, v718); + int16x8_t v898 = vqrdmulhq_n_s16(v897, 29490); + int16x8_t v899 = vaddq_s16(v896, v898); + int16x8_t v900 = vqrdmulhq_n_s16(v899, 18578); + int16x8_t v901 = vaddq_s16(v895, v900); + int16x8_t v902 = vsubq_s16(v724, v726); + int16x8_t v903 = vsubq_s16(v728, v730); + int16x8_t v904 = vqrdmulhq_n_s16(v903, 29490); + int16x8_t v905 = vaddq_s16(v902, v904); + int16x8_t v906 = vsubq_s16(v734, v736); + int16x8_t v907 = vsubq_s16(v738, v740); + int16x8_t v908 = vqrdmulhq_n_s16(v907, 29490); + int16x8_t v909 = vaddq_s16(v906, v908); + int16x8_t v910 = vqrdmulhq_n_s16(v909, 18578); + int16x8_t v911 = vaddq_s16(v905, v910); + int16x8_t v912 = vqrdmulhq_n_s16(v911, 16890); + int16x8_t v913 = vaddq_s16(v901, v912); + int16x8_t v914 = vsubq_s16(v748, v750); + int16x8_t v915_tmp = vqrdmulhq_n_s16(v754, 10045); + int16x8_t v915 = vaddq_s16(v915_tmp, v754); + int16x8_t v916 = vsubq_s16(v752, v915); + int16x8_t v917 = vqrdmulhq_n_s16(v916, 29490); + int16x8_t v918 = vaddq_s16(v914, v917); + int16x8_t v919 = vsubq_s16(v758, v760); + int16x8_t v920 = vsubq_s16(v762, v764); + int16x8_t v921 = vqrdmulhq_n_s16(v920, 29490); + int16x8_t v922 = vaddq_s16(v919, v921); + int16x8_t v923 = vqrdmulhq_n_s16(v922, 18578); + int16x8_t v924 = vaddq_s16(v918, v923); + int16x8_t v925 = vsubq_s16(v770, v772); + int16x8_t v926 = vsubq_s16(v774, v776); + int16x8_t v927 = vqrdmulhq_n_s16(v926, 29490); + int16x8_t v928 = vaddq_s16(v925, v927); + int16x8_t v929 = vsubq_s16(v780, v782); + int16x8_t v930 = vsubq_s16(v784, v786); + int16x8_t v931 = vqrdmulhq_n_s16(v930, 29490); + int16x8_t v932 = vaddq_s16(v929, v931); + int16x8_t v933 = vqrdmulhq_n_s16(v932, 18578); + int16x8_t v934 = vaddq_s16(v928, v933); + int16x8_t v935 = vqrdmulhq_n_s16(v934, 16890); + int16x8_t v936 = vaddq_s16(v924, v935); + int16x8_t v937 = vqrdmulhq_n_s16(v936, 16508); + int16x8_t v938 = vaddq_s16(v913, v937); + int16x8_t v939 = vsubq_s16(v796, v798); + int16x8_t v940 = vsubq_s16(v800, v802); + int16x8_t v941 = vqrdmulhq_n_s16(v940, 29490); + int16x8_t v942 = vaddq_s16(v939, v941); + int16x8_t v943 = vsubq_s16(v806, v808); + int16x8_t v944 = vsubq_s16(v810, v812); + int16x8_t v945 = vqrdmulhq_n_s16(v944, 29490); + int16x8_t v946 = vaddq_s16(v943, v945); + int16x8_t v947 = vqrdmulhq_n_s16(v946, 18578); + int16x8_t v948 = vaddq_s16(v942, v947); + int16x8_t v949 = vsubq_s16(v818, v820); + int16x8_t v950 = vsubq_s16(v822, v824); + int16x8_t v951 = vqrdmulhq_n_s16(v950, 29490); + int16x8_t v952 = vaddq_s16(v949, v951); + int16x8_t v953 = vsubq_s16(v828, v830); + int16x8_t v954 = vsubq_s16(v832, v834); + int16x8_t v955 = vqrdmulhq_n_s16(v954, 29490); + int16x8_t v956 = vaddq_s16(v953, v955); + int16x8_t v957 = vqrdmulhq_n_s16(v956, 18578); + int16x8_t v958 = vaddq_s16(v952, v957); + int16x8_t v959 = vqrdmulhq_n_s16(v958, 16890); + int16x8_t v960 = vaddq_s16(v948, v959); + int16x8_t v961 = vsubq_s16(v842, v844); + int16x8_t v962 = vsubq_s16(v846, v848); + int16x8_t v963 = vqrdmulhq_n_s16(v962, 29490); + int16x8_t v964 = vaddq_s16(v961, v963); + int16x8_t v965 = vsubq_s16(v852, v854); + int16x8_t v966 = vsubq_s16(v856, v858); + int16x8_t v967 = vqrdmulhq_n_s16(v966, 29490); + int16x8_t v968 = vaddq_s16(v965, v967); + int16x8_t v969 = vqrdmulhq_n_s16(v968, 18578); + int16x8_t v970 = vaddq_s16(v964, v969); + int16x8_t v971 = vsubq_s16(v864, v866); + int16x8_t v972 = vsubq_s16(v868, v870); + int16x8_t v973 = vqrdmulhq_n_s16(v972, 29490); + int16x8_t v974 = vaddq_s16(v971, v973); + int16x8_t v975 = vsubq_s16(v874, v876); + int16x8_t v976 = vsubq_s16(v878, v880); + int16x8_t v977 = vqrdmulhq_n_s16(v976, 29490); + int16x8_t v978 = vaddq_s16(v975, v977); + int16x8_t v979 = vqrdmulhq_n_s16(v978, 18578); + int16x8_t v980 = vaddq_s16(v974, v979); + int16x8_t v981 = vqrdmulhq_n_s16(v980, 16890); + int16x8_t v982 = vaddq_s16(v970, v981); + int16x8_t v983 = vqrdmulhq_n_s16(v982, 16508); + int16x8_t v984 = vaddq_s16(v960, v983); + int16x8_t v985 = vqrdmulhq_n_s16(v984, 16415); + int16x8_t v986 = vaddq_s16(v938, v985); + int16x8_t v987 = vsubq_s16(v2, v8); + int16x8_t v988 = vsubq_s16(v15, v22); + int16x8_t v989_tmp = vqrdmulhq_n_s16(v988, 18446); + int16x8_t v989 = vmlaq_n_s16(v989_tmp, v988, 2); + int16x8_t v990 = vaddq_s16(v987, v989); + int16x8_t v991 = vsubq_s16(v31, v41); + int16x8_t v992 = vsubq_s16(v48, v56); + int16x8_t v993_tmp = vqrdmulhq_n_s16(v992, 18446); + int16x8_t v993 = vmlaq_n_s16(v993_tmp, v992, 2); + int16x8_t v994 = vaddq_s16(v991, v993); + int16x8_t v995 = vqrdmulhq_n_s16(v994, 21195); + int16x8_t v996 = vaddq_s16(v990, v995); + int16x8_t v997 = vsubq_s16(v67, v77); + int16x8_t v998 = vsubq_s16(v90, v99); + int16x8_t v999_tmp = vqrdmulhq_n_s16(v998, 18446); + int16x8_t v999 = vmlaq_n_s16(v999_tmp, v998, 2); + int16x8_t v1000 = vaddq_s16(v997, v999); + int16x8_t v1001 = vsubq_s16(v108, v118); + int16x8_t v1002 = vsubq_s16(v125, v134); + int16x8_t v1003_tmp = vqrdmulhq_n_s16(v1002, 18446); + int16x8_t v1003 = vmlaq_n_s16(v1003_tmp, v1002, 2); + int16x8_t v1004 = vaddq_s16(v1001, v1003); + int16x8_t v1005 = vqrdmulhq_n_s16(v1004, 21195); + int16x8_t v1006 = vaddq_s16(v1000, v1005); + int16x8_t v1007 = vqrdmulhq_n_s16(v1006, 17401); + int16x8_t v1008 = vaddq_s16(v996, v1007); + int16x8_t v1009 = vsubq_s16(v147, v157); + int16x8_t v1010 = vsubq_s16(v170, v179); + int16x8_t v1011_tmp = vqrdmulhq_n_s16(v1010, 18446); + int16x8_t v1011 = vmlaq_n_s16(v1011_tmp, v1010, 2); + int16x8_t v1012 = vaddq_s16(v1009, v1011); + int16x8_t v1013 = vsubq_s16(v194, v212); + int16x8_t v1014 = vsubq_s16(v219, v229); + int16x8_t v1015_tmp = vqrdmulhq_n_s16(v1014, 18446); + int16x8_t v1015 = vmlaq_n_s16(v1015_tmp, v1014, 2); + int16x8_t v1016 = vaddq_s16(v1013, v1015); + int16x8_t v1017 = vqrdmulhq_n_s16(v1016, 21195); + int16x8_t v1018 = vaddq_s16(v1012, v1017); + int16x8_t v1019 = vsubq_s16(v240, v250); + int16x8_t v1020 = vsubq_s16(v263, v272); + int16x8_t v1021_tmp = vqrdmulhq_n_s16(v1020, 18446); + int16x8_t v1021 = vmlaq_n_s16(v1021_tmp, v1020, 2); + int16x8_t v1022 = vaddq_s16(v1019, v1021); + int16x8_t v1023 = vsubq_s16(v281, v291); + int16x8_t v1024 = vsubq_s16(v298, v308); + int16x8_t v1025_tmp = vqrdmulhq_n_s16(v1024, 18446); + int16x8_t v1025 = vmlaq_n_s16(v1025_tmp, v1024, 2); + int16x8_t v1026 = vaddq_s16(v1023, v1025); + int16x8_t v1027 = vqrdmulhq_n_s16(v1026, 21195); + int16x8_t v1028 = vaddq_s16(v1022, v1027); + int16x8_t v1029 = vqrdmulhq_n_s16(v1028, 17401); + int16x8_t v1030 = vaddq_s16(v1018, v1029); + int16x8_t v1031 = vqrdmulhq_n_s16(v1030, 16629); + int16x8_t v1032 = vaddq_s16(v1008, v1031); + int16x8_t v1033 = vsubq_s16(v323, v333); + int16x8_t v1034 = vsubq_s16(v346, v355); + int16x8_t v1035_tmp = vqrdmulhq_n_s16(v1034, 18446); + int16x8_t v1035 = vmlaq_n_s16(v1035_tmp, v1034, 2); + int16x8_t v1036 = vaddq_s16(v1033, v1035); + int16x8_t v1037 = vsubq_s16(v370, v388); + int16x8_t v1038 = vsubq_s16(v395, v405); + int16x8_t v1039_tmp = vqrdmulhq_n_s16(v1038, 18446); + int16x8_t v1039 = vmlaq_n_s16(v1039_tmp, v1038, 2); + int16x8_t v1040 = vaddq_s16(v1037, v1039); + int16x8_t v1041 = vqrdmulhq_n_s16(v1040, 21195); + int16x8_t v1042 = vaddq_s16(v1036, v1041); + int16x8_t v1043 = vsubq_s16(v422, v440); + int16x8_t v1044 = vsubq_s16(v465, v478); + int16x8_t v1045_tmp = vqrdmulhq_n_s16(v1044, 18446); + int16x8_t v1045 = vmlaq_n_s16(v1045_tmp, v1044, 2); + int16x8_t v1046 = vaddq_s16(v1043, v1045); + int16x8_t v1047 = vsubq_s16(v487, v497); + int16x8_t v1048 = vsubq_s16(v504, v515); + int16x8_t v1049_tmp = vqrdmulhq_n_s16(v1048, 18446); + int16x8_t v1049 = vmlaq_n_s16(v1049_tmp, v1048, 2); + int16x8_t v1050 = vaddq_s16(v1047, v1049); + int16x8_t v1051 = vqrdmulhq_n_s16(v1050, 21195); + int16x8_t v1052 = vaddq_s16(v1046, v1051); + int16x8_t v1053 = vqrdmulhq_n_s16(v1052, 17401); + int16x8_t v1054 = vaddq_s16(v1042, v1053); + int16x8_t v1055 = vsubq_s16(v528, v538); + int16x8_t v1056 = vsubq_s16(v551, v560); + int16x8_t v1057_tmp = vqrdmulhq_n_s16(v1056, 18446); + int16x8_t v1057 = vmlaq_n_s16(v1057_tmp, v1056, 2); + int16x8_t v1058 = vaddq_s16(v1055, v1057); + int16x8_t v1059 = vsubq_s16(v575, v593); + int16x8_t v1060 = vsubq_s16(v600, v610); + int16x8_t v1061_tmp = vqrdmulhq_n_s16(v1060, 18446); + int16x8_t v1061 = vmlaq_n_s16(v1061_tmp, v1060, 2); + int16x8_t v1062 = vaddq_s16(v1059, v1061); + int16x8_t v1063 = vqrdmulhq_n_s16(v1062, 21195); + int16x8_t v1064 = vaddq_s16(v1058, v1063); + int16x8_t v1065 = vsubq_s16(v621, v631); + int16x8_t v1066 = vsubq_s16(v644, v653); + int16x8_t v1067_tmp = vqrdmulhq_n_s16(v1066, 18446); + int16x8_t v1067 = vmlaq_n_s16(v1067_tmp, v1066, 2); + int16x8_t v1068 = vaddq_s16(v1065, v1067); + int16x8_t v1069 = vsubq_s16(v662, v672); + int16x8_t v1070 = vsubq_s16(v679, v690); + int16x8_t v1071_tmp = vqrdmulhq_n_s16(v1070, 18446); + int16x8_t v1071 = vmlaq_n_s16(v1071_tmp, v1070, 2); + int16x8_t v1072 = vaddq_s16(v1069, v1071); + int16x8_t v1073 = vqrdmulhq_n_s16(v1072, 21195); + int16x8_t v1074 = vaddq_s16(v1068, v1073); + int16x8_t v1075 = vqrdmulhq_n_s16(v1074, 17401); + int16x8_t v1076 = vaddq_s16(v1064, v1075); + int16x8_t v1077 = vqrdmulhq_n_s16(v1076, 16629); + int16x8_t v1078 = vaddq_s16(v1054, v1077); + int16x8_t v1079 = vqrdmulhq_n_s16(v1078, 16445); + int16x8_t v1080 = vaddq_s16(v1032, v1079); + int16x8_t v1081 = vsubq_s16(v987, v989); + int16x8_t v1082 = vsubq_s16(v991, v993); + int16x8_t v1083 = vqrdmulhq_n_s16(v1082, 25826); + int16x8_t v1084 = vaddq_s16(v1081, v1083); + int16x8_t v1085 = vsubq_s16(v997, v999); + int16x8_t v1086 = vsubq_s16(v1001, v1003); + int16x8_t v1087 = vqrdmulhq_n_s16(v1086, 25826); + int16x8_t v1088 = vaddq_s16(v1085, v1087); + int16x8_t v1089 = vqrdmulhq_n_s16(v1088, 18124); + int16x8_t v1090 = vaddq_s16(v1084, v1089); + int16x8_t v1091 = vsubq_s16(v1009, v1011); + int16x8_t v1092 = vsubq_s16(v1013, v1015); + int16x8_t v1093 = vqrdmulhq_n_s16(v1092, 25826); + int16x8_t v1094 = vaddq_s16(v1091, v1093); + int16x8_t v1095 = vsubq_s16(v1019, v1021); + int16x8_t v1096 = vsubq_s16(v1023, v1025); + int16x8_t v1097 = vqrdmulhq_n_s16(v1096, 25826); + int16x8_t v1098 = vaddq_s16(v1095, v1097); + int16x8_t v1099 = vqrdmulhq_n_s16(v1098, 18124); + int16x8_t v1100 = vaddq_s16(v1094, v1099); + int16x8_t v1101 = vqrdmulhq_n_s16(v1100, 16792); + int16x8_t v1102 = vaddq_s16(v1090, v1101); + int16x8_t v1103 = vsubq_s16(v1033, v1035); + int16x8_t v1104 = vsubq_s16(v1037, v1039); + int16x8_t v1105 = vqrdmulhq_n_s16(v1104, 25826); + int16x8_t v1106 = vaddq_s16(v1103, v1105); + int16x8_t v1107 = vsubq_s16(v1043, v1045); + int16x8_t v1108 = vsubq_s16(v1047, v1049); + int16x8_t v1109 = vqrdmulhq_n_s16(v1108, 25826); + int16x8_t v1110 = vaddq_s16(v1107, v1109); + int16x8_t v1111 = vqrdmulhq_n_s16(v1110, 18124); + int16x8_t v1112 = vaddq_s16(v1106, v1111); + int16x8_t v1113 = vsubq_s16(v1055, v1057); + int16x8_t v1114 = vsubq_s16(v1059, v1061); + int16x8_t v1115 = vqrdmulhq_n_s16(v1114, 25826); + int16x8_t v1116 = vaddq_s16(v1113, v1115); + int16x8_t v1117 = vsubq_s16(v1065, v1067); + int16x8_t v1118 = vsubq_s16(v1069, v1071); + int16x8_t v1119 = vqrdmulhq_n_s16(v1118, 25826); + int16x8_t v1120 = vaddq_s16(v1117, v1119); + int16x8_t v1121 = vqrdmulhq_n_s16(v1120, 18124); + int16x8_t v1122 = vaddq_s16(v1116, v1121); + int16x8_t v1123 = vqrdmulhq_n_s16(v1122, 16792); + int16x8_t v1124 = vaddq_s16(v1112, v1123); + int16x8_t v1125 = vqrdmulhq_n_s16(v1124, 16484); + int16x8_t v1126 = vaddq_s16(v1102, v1125); + int16x8_t v1127 = vsubq_s16(v892, v894); + int16x8_t v1128 = vsubq_s16(v896, v898); + int16x8_t v1129_tmp = vqrdmulhq_n_s16(v1128, 1988); + int16x8_t v1129 = vaddq_s16(v1129_tmp, v1128); + int16x8_t v1130 = vaddq_s16(v1127, v1129); + int16x8_t v1131 = vsubq_s16(v902, v904); + int16x8_t v1132 = vsubq_s16(v906, v908); + int16x8_t v1133_tmp = vqrdmulhq_n_s16(v1132, 1988); + int16x8_t v1133 = vaddq_s16(v1133_tmp, v1132); + int16x8_t v1134 = vaddq_s16(v1131, v1133); + int16x8_t v1135 = vqrdmulhq_n_s16(v1134, 19102); + int16x8_t v1136 = vaddq_s16(v1130, v1135); + int16x8_t v1137 = vsubq_s16(v914, v917); + int16x8_t v1138 = vsubq_s16(v919, v921); + int16x8_t v1139_tmp = vqrdmulhq_n_s16(v1138, 1988); + int16x8_t v1139 = vaddq_s16(v1139_tmp, v1138); + int16x8_t v1140 = vaddq_s16(v1137, v1139); + int16x8_t v1141 = vsubq_s16(v925, v927); + int16x8_t v1142 = vsubq_s16(v929, v931); + int16x8_t v1143_tmp = vqrdmulhq_n_s16(v1142, 1988); + int16x8_t v1143 = vaddq_s16(v1143_tmp, v1142); + int16x8_t v1144 = vaddq_s16(v1141, v1143); + int16x8_t v1145 = vqrdmulhq_n_s16(v1144, 19102); + int16x8_t v1146 = vaddq_s16(v1140, v1145); + int16x8_t v1147 = vqrdmulhq_n_s16(v1146, 17000); + int16x8_t v1148 = vaddq_s16(v1136, v1147); + int16x8_t v1149 = vsubq_s16(v939, v941); + int16x8_t v1150 = vsubq_s16(v943, v945); + int16x8_t v1151_tmp = vqrdmulhq_n_s16(v1150, 1988); + int16x8_t v1151 = vaddq_s16(v1151_tmp, v1150); + int16x8_t v1152 = vaddq_s16(v1149, v1151); + int16x8_t v1153 = vsubq_s16(v949, v951); + int16x8_t v1154 = vsubq_s16(v953, v955); + int16x8_t v1155_tmp = vqrdmulhq_n_s16(v1154, 1988); + int16x8_t v1155 = vaddq_s16(v1155_tmp, v1154); + int16x8_t v1156 = vaddq_s16(v1153, v1155); + int16x8_t v1157 = vqrdmulhq_n_s16(v1156, 19102); + int16x8_t v1158 = vaddq_s16(v1152, v1157); + int16x8_t v1159 = vsubq_s16(v961, v963); + int16x8_t v1160 = vsubq_s16(v965, v967); + int16x8_t v1161_tmp = vqrdmulhq_n_s16(v1160, 1988); + int16x8_t v1161 = vaddq_s16(v1161_tmp, v1160); + int16x8_t v1162 = vaddq_s16(v1159, v1161); + int16x8_t v1163 = vsubq_s16(v971, v973); + int16x8_t v1164 = vsubq_s16(v975, v977); + int16x8_t v1165_tmp = vqrdmulhq_n_s16(v1164, 1988); + int16x8_t v1165 = vaddq_s16(v1165_tmp, v1164); + int16x8_t v1166 = vaddq_s16(v1163, v1165); + int16x8_t v1167 = vqrdmulhq_n_s16(v1166, 19102); + int16x8_t v1168 = vaddq_s16(v1162, v1167); + int16x8_t v1169 = vqrdmulhq_n_s16(v1168, 17000); + int16x8_t v1170 = vaddq_s16(v1158, v1169); + int16x8_t v1171 = vqrdmulhq_n_s16(v1170, 16534); + int16x8_t v1172 = vaddq_s16(v1148, v1171); + int16x8_t v1173 = vsubq_s16(v705, v710); + int16x8_t v1174 = vsubq_s16(v715, v720); + int16x8_t v1175_tmp = vqrdmulhq_n_s16(v1174, 23673); + int16x8_t v1175 = vaddq_s16(v1175_tmp, v1174); + int16x8_t v1176 = vaddq_s16(v1173, v1175); + int16x8_t v1177 = vsubq_s16(v727, v732); + int16x8_t v1178 = vsubq_s16(v737, v742); + int16x8_t v1179_tmp = vqrdmulhq_n_s16(v1178, 23673); + int16x8_t v1179 = vaddq_s16(v1179_tmp, v1178); + int16x8_t v1180 = vaddq_s16(v1177, v1179); + int16x8_t v1181 = vqrdmulhq_n_s16(v1180, 20398); + int16x8_t v1182 = vaddq_s16(v1176, v1181); + int16x8_t v1183 = vsubq_s16(v751, v756); + int16x8_t v1184 = vsubq_s16(v761, v766); + int16x8_t v1185_tmp = vqrdmulhq_n_s16(v1184, 23673); + int16x8_t v1185 = vaddq_s16(v1185_tmp, v1184); + int16x8_t v1186 = vaddq_s16(v1183, v1185); + int16x8_t v1187 = vsubq_s16(v773, v778); + int16x8_t v1188 = vsubq_s16(v783, v788); + int16x8_t v1189_tmp = vqrdmulhq_n_s16(v1188, 23673); + int16x8_t v1189 = vaddq_s16(v1189_tmp, v1188); + int16x8_t v1190 = vaddq_s16(v1187, v1189); + int16x8_t v1191 = vqrdmulhq_n_s16(v1190, 20398); + int16x8_t v1192 = vaddq_s16(v1186, v1191); + int16x8_t v1193 = vqrdmulhq_n_s16(v1192, 17255); + int16x8_t v1194 = vaddq_s16(v1182, v1193); + int16x8_t v1195 = vsubq_s16(v799, v804); + int16x8_t v1196 = vsubq_s16(v809, v814); + int16x8_t v1197_tmp = vqrdmulhq_n_s16(v1196, 23673); + int16x8_t v1197 = vaddq_s16(v1197_tmp, v1196); + int16x8_t v1198 = vaddq_s16(v1195, v1197); + int16x8_t v1199 = vsubq_s16(v821, v826); + int16x8_t v1200 = vsubq_s16(v831, v836); + int16x8_t v1201_tmp = vqrdmulhq_n_s16(v1200, 23673); + int16x8_t v1201 = vaddq_s16(v1201_tmp, v1200); + int16x8_t v1202 = vaddq_s16(v1199, v1201); + int16x8_t v1203 = vqrdmulhq_n_s16(v1202, 20398); + int16x8_t v1204 = vaddq_s16(v1198, v1203); + int16x8_t v1205 = vsubq_s16(v845, v850); + int16x8_t v1206 = vsubq_s16(v855, v860); + int16x8_t v1207_tmp = vqrdmulhq_n_s16(v1206, 23673); + int16x8_t v1207 = vaddq_s16(v1207_tmp, v1206); + int16x8_t v1208 = vaddq_s16(v1205, v1207); + int16x8_t v1209 = vsubq_s16(v867, v872); + int16x8_t v1210 = vsubq_s16(v877, v882); + int16x8_t v1211_tmp = vqrdmulhq_n_s16(v1210, 23673); + int16x8_t v1211 = vaddq_s16(v1211_tmp, v1210); + int16x8_t v1212 = vaddq_s16(v1209, v1211); + int16x8_t v1213 = vqrdmulhq_n_s16(v1212, 20398); + int16x8_t v1214 = vaddq_s16(v1208, v1213); + int16x8_t v1215 = vqrdmulhq_n_s16(v1214, 17255); + int16x8_t v1216 = vaddq_s16(v1204, v1215); + int16x8_t v1217 = vqrdmulhq_n_s16(v1216, 16595); + int16x8_t v1218 = vaddq_s16(v1194, v1217); + int16x8_t v1219 = vsubq_s16(v9, v24); + int16x8_t v1220 = vsubq_s16(v42, v58); + int16x8_t v1221_tmp = vqrdmulhq_n_s16(v1220, 3314); + int16x8_t v1221 = vmlaq_n_s16(v1221_tmp, v1220, 5); + int16x8_t v1222 = vaddq_s16(v1219, v1221); + int16x8_t v1223 = vsubq_s16(v78, v101); + int16x8_t v1224 = vsubq_s16(v119, v136); + int16x8_t v1225_tmp = vqrdmulhq_n_s16(v1224, 3314); + int16x8_t v1225 = vmlaq_n_s16(v1225_tmp, v1224, 5); + int16x8_t v1226 = vaddq_s16(v1223, v1225); + int16x8_t v1227 = vqrdmulhq_n_s16(v1226, 22112); + int16x8_t v1228 = vaddq_s16(v1222, v1227); + int16x8_t v1229 = vsubq_s16(v158, v181); + int16x8_t v1230 = vsubq_s16(v213, v231); + int16x8_t v1231_tmp = vqrdmulhq_n_s16(v1230, 3314); + int16x8_t v1231 = vmlaq_n_s16(v1231_tmp, v1230, 5); + int16x8_t v1232 = vaddq_s16(v1229, v1231); + int16x8_t v1233 = vsubq_s16(v251, v274); + int16x8_t v1234 = vsubq_s16(v292, v310); + int16x8_t v1235_tmp = vqrdmulhq_n_s16(v1234, 3314); + int16x8_t v1235 = vmlaq_n_s16(v1235_tmp, v1234, 5); + int16x8_t v1236 = vaddq_s16(v1233, v1235); + int16x8_t v1237 = vqrdmulhq_n_s16(v1236, 22112); + int16x8_t v1238 = vaddq_s16(v1232, v1237); + int16x8_t v1239 = vqrdmulhq_n_s16(v1238, 17561); + int16x8_t v1240 = vaddq_s16(v1228, v1239); + int16x8_t v1241 = vsubq_s16(v334, v357); + int16x8_t v1242 = vsubq_s16(v389, v407); + int16x8_t v1243_tmp = vqrdmulhq_n_s16(v1242, 3314); + int16x8_t v1243 = vmlaq_n_s16(v1243_tmp, v1242, 5); + int16x8_t v1244 = vaddq_s16(v1241, v1243); + int16x8_t v1245 = vsubq_s16(v441, v480); + int16x8_t v1246 = vsubq_s16(v498, v517); + int16x8_t v1247_tmp = vqrdmulhq_n_s16(v1246, 3314); + int16x8_t v1247 = vmlaq_n_s16(v1247_tmp, v1246, 5); + int16x8_t v1248 = vaddq_s16(v1245, v1247); + int16x8_t v1249 = vqrdmulhq_n_s16(v1248, 22112); + int16x8_t v1250 = vaddq_s16(v1244, v1249); + int16x8_t v1251 = vsubq_s16(v539, v562); + int16x8_t v1252 = vsubq_s16(v594, v612); + int16x8_t v1253_tmp = vqrdmulhq_n_s16(v1252, 3314); + int16x8_t v1253 = vmlaq_n_s16(v1253_tmp, v1252, 5); + int16x8_t v1254 = vaddq_s16(v1251, v1253); + int16x8_t v1255 = vsubq_s16(v632, v655); + int16x8_t v1256 = vsubq_s16(v673, v692); + int16x8_t v1257_tmp = vqrdmulhq_n_s16(v1256, 3314); + int16x8_t v1257 = vmlaq_n_s16(v1257_tmp, v1256, 5); + int16x8_t v1258 = vaddq_s16(v1255, v1257); + int16x8_t v1259 = vqrdmulhq_n_s16(v1258, 22112); + int16x8_t v1260 = vaddq_s16(v1254, v1259); + int16x8_t v1261 = vqrdmulhq_n_s16(v1260, 17561); + int16x8_t v1262 = vaddq_s16(v1250, v1261); + int16x8_t v1263 = vqrdmulhq_n_s16(v1262, 16666); + int16x8_t v1264 = vaddq_s16(v1240, v1263); + int16x8_t v1265 = vsubq_s16(v1219, v1221); + int16x8_t v1266 = vsubq_s16(v1223, v1225); + int16x8_t v1267 = vqrdmulhq_n_s16(v1266, 24397); + int16x8_t v1268 = vaddq_s16(v1265, v1267); + int16x8_t v1269 = vsubq_s16(v1229, v1231); + int16x8_t v1270 = vsubq_s16(v1233, v1235); + int16x8_t v1271 = vqrdmulhq_n_s16(v1270, 24397); + int16x8_t v1272 = vaddq_s16(v1269, v1271); + int16x8_t v1273 = vqrdmulhq_n_s16(v1272, 17921); + int16x8_t v1274 = vaddq_s16(v1268, v1273); + int16x8_t v1275 = vsubq_s16(v1241, v1243); + int16x8_t v1276 = vsubq_s16(v1245, v1247); + int16x8_t v1277 = vqrdmulhq_n_s16(v1276, 24397); + int16x8_t v1278 = vaddq_s16(v1275, v1277); + int16x8_t v1279 = vsubq_s16(v1251, v1253); + int16x8_t v1280 = vsubq_s16(v1255, v1257); + int16x8_t v1281 = vqrdmulhq_n_s16(v1280, 24397); + int16x8_t v1282 = vaddq_s16(v1279, v1281); + int16x8_t v1283 = vqrdmulhq_n_s16(v1282, 17921); + int16x8_t v1284 = vaddq_s16(v1278, v1283); + int16x8_t v1285 = vqrdmulhq_n_s16(v1284, 16747); + int16x8_t v1286 = vaddq_s16(v1274, v1285); + int16x8_t v1287 = vsubq_s16(v1173, v1175); + int16x8_t v1288 = vsubq_s16(v1177, v1179); + int16x8_t v1289 = vqrdmulhq_n_s16(v1288, 27504); + int16x8_t v1290 = vaddq_s16(v1287, v1289); + int16x8_t v1291 = vsubq_s16(v1183, v1185); + int16x8_t v1292 = vsubq_s16(v1187, v1189); + int16x8_t v1293 = vqrdmulhq_n_s16(v1292, 27504); + int16x8_t v1294 = vaddq_s16(v1291, v1293); + int16x8_t v1295 = vqrdmulhq_n_s16(v1294, 18343); + int16x8_t v1296 = vaddq_s16(v1290, v1295); + int16x8_t v1297 = vsubq_s16(v1195, v1197); + int16x8_t v1298 = vsubq_s16(v1199, v1201); + int16x8_t v1299 = vqrdmulhq_n_s16(v1298, 27504); + int16x8_t v1300 = vaddq_s16(v1297, v1299); + int16x8_t v1301 = vsubq_s16(v1205, v1207); + int16x8_t v1302 = vsubq_s16(v1209, v1211); + int16x8_t v1303 = vqrdmulhq_n_s16(v1302, 27504); + int16x8_t v1304 = vaddq_s16(v1301, v1303); + int16x8_t v1305 = vqrdmulhq_n_s16(v1304, 18343); + int16x8_t v1306 = vaddq_s16(v1300, v1305); + int16x8_t v1307 = vqrdmulhq_n_s16(v1306, 16840); + int16x8_t v1308 = vaddq_s16(v1296, v1307); + int16x8_t v1309 = vsubq_s16(v1127, v1129); + int16x8_t v1310 = vsubq_s16(v1131, v1133); + int16x8_t v1311 = vqrdmulhq_n_s16(v1310, 31869); + int16x8_t v1312 = vaddq_s16(v1309, v1311); + int16x8_t v1313 = vsubq_s16(v1137, v1139); + int16x8_t v1314 = vsubq_s16(v1141, v1143); + int16x8_t v1315 = vqrdmulhq_n_s16(v1314, 31869); + int16x8_t v1316 = vaddq_s16(v1313, v1315); + int16x8_t v1317 = vqrdmulhq_n_s16(v1316, 18830); + int16x8_t v1318 = vaddq_s16(v1312, v1317); + int16x8_t v1319 = vsubq_s16(v1149, v1151); + int16x8_t v1320 = vsubq_s16(v1153, v1155); + int16x8_t v1321 = vqrdmulhq_n_s16(v1320, 31869); + int16x8_t v1322 = vaddq_s16(v1319, v1321); + int16x8_t v1323 = vsubq_s16(v1159, v1161); + int16x8_t v1324 = vsubq_s16(v1163, v1165); + int16x8_t v1325 = vqrdmulhq_n_s16(v1324, 31869); + int16x8_t v1326 = vaddq_s16(v1323, v1325); + int16x8_t v1327 = vqrdmulhq_n_s16(v1326, 18830); + int16x8_t v1328 = vaddq_s16(v1322, v1327); + int16x8_t v1329 = vqrdmulhq_n_s16(v1328, 16944); + int16x8_t v1330 = vaddq_s16(v1318, v1329); + int16x8_t v1331 = vsubq_s16(v1081, v1083); + int16x8_t v1332 = vsubq_s16(v1085, v1087); + int16x8_t v1333_tmp = vqrdmulhq_n_s16(v1332, 5552); + int16x8_t v1333 = vaddq_s16(v1333_tmp, v1332); + int16x8_t v1334 = vaddq_s16(v1331, v1333); + int16x8_t v1335 = vsubq_s16(v1091, v1093); + int16x8_t v1336 = vsubq_s16(v1095, v1097); + int16x8_t v1337_tmp = vqrdmulhq_n_s16(v1336, 5552); + int16x8_t v1337 = vaddq_s16(v1337_tmp, v1336); + int16x8_t v1338 = vaddq_s16(v1335, v1337); + int16x8_t v1339 = vqrdmulhq_n_s16(v1338, 19393); + int16x8_t v1340 = vaddq_s16(v1334, v1339); + int16x8_t v1341 = vsubq_s16(v1103, v1105); + int16x8_t v1342 = vsubq_s16(v1107, v1109); + int16x8_t v1343_tmp = vqrdmulhq_n_s16(v1342, 5552); + int16x8_t v1343 = vaddq_s16(v1343_tmp, v1342); + int16x8_t v1344 = vaddq_s16(v1341, v1343); + int16x8_t v1345 = vsubq_s16(v1113, v1115); + int16x8_t v1346 = vsubq_s16(v1117, v1119); + int16x8_t v1347_tmp = vqrdmulhq_n_s16(v1346, 5552); + int16x8_t v1347 = vaddq_s16(v1347_tmp, v1346); + int16x8_t v1348 = vaddq_s16(v1345, v1347); + int16x8_t v1349 = vqrdmulhq_n_s16(v1348, 19393); + int16x8_t v1350 = vaddq_s16(v1344, v1349); + int16x8_t v1351 = vqrdmulhq_n_s16(v1350, 17059); + int16x8_t v1352 = vaddq_s16(v1340, v1351); + int16x8_t v1353 = vsubq_s16(v990, v995); + int16x8_t v1354 = vsubq_s16(v1000, v1005); + int16x8_t v1355_tmp = vqrdmulhq_n_s16(v1354, 15865); + int16x8_t v1355 = vaddq_s16(v1355_tmp, v1354); + int16x8_t v1356 = vaddq_s16(v1353, v1355); + int16x8_t v1357 = vsubq_s16(v1012, v1017); + int16x8_t v1358 = vsubq_s16(v1022, v1027); + int16x8_t v1359_tmp = vqrdmulhq_n_s16(v1358, 15865); + int16x8_t v1359 = vaddq_s16(v1359_tmp, v1358); + int16x8_t v1360 = vaddq_s16(v1357, v1359); + int16x8_t v1361 = vqrdmulhq_n_s16(v1360, 20040); + int16x8_t v1362 = vaddq_s16(v1356, v1361); + int16x8_t v1363 = vsubq_s16(v1036, v1041); + int16x8_t v1364 = vsubq_s16(v1046, v1051); + int16x8_t v1365_tmp = vqrdmulhq_n_s16(v1364, 15865); + int16x8_t v1365 = vaddq_s16(v1365_tmp, v1364); + int16x8_t v1366 = vaddq_s16(v1363, v1365); + int16x8_t v1367 = vsubq_s16(v1058, v1063); + int16x8_t v1368 = vsubq_s16(v1068, v1073); + int16x8_t v1369_tmp = vqrdmulhq_n_s16(v1368, 15865); + int16x8_t v1369 = vaddq_s16(v1369_tmp, v1368); + int16x8_t v1370 = vaddq_s16(v1367, v1369); + int16x8_t v1371 = vqrdmulhq_n_s16(v1370, 20040); + int16x8_t v1372 = vaddq_s16(v1366, v1371); + int16x8_t v1373 = vqrdmulhq_n_s16(v1372, 17187); + int16x8_t v1374 = vaddq_s16(v1362, v1373); + int16x8_t v1375 = vsubq_s16(v895, v900); + int16x8_t v1376 = vsubq_s16(v905, v910); + int16x8_t v1377_tmp = vqrdmulhq_n_s16(v1376, 1893); + int16x8_t v1377 = vmlaq_n_s16(v1377_tmp, v1376, 2); + int16x8_t v1378 = vaddq_s16(v1375, v1377); + int16x8_t v1379 = vsubq_s16(v918, v923); + int16x8_t v1380 = vsubq_s16(v928, v933); + int16x8_t v1381_tmp = vqrdmulhq_n_s16(v1380, 1893); + int16x8_t v1381 = vmlaq_n_s16(v1381_tmp, v1380, 2); + int16x8_t v1382 = vaddq_s16(v1379, v1381); + int16x8_t v1383 = vqrdmulhq_n_s16(v1382, 20783); + int16x8_t v1384 = vaddq_s16(v1378, v1383); + int16x8_t v1385 = vsubq_s16(v942, v947); + int16x8_t v1386 = vsubq_s16(v952, v957); + int16x8_t v1387_tmp = vqrdmulhq_n_s16(v1386, 1893); + int16x8_t v1387 = vmlaq_n_s16(v1387_tmp, v1386, 2); + int16x8_t v1388 = vaddq_s16(v1385, v1387); + int16x8_t v1389 = vsubq_s16(v964, v969); + int16x8_t v1390 = vsubq_s16(v974, v979); + int16x8_t v1391_tmp = vqrdmulhq_n_s16(v1390, 1893); + int16x8_t v1391 = vmlaq_n_s16(v1391_tmp, v1390, 2); + int16x8_t v1392 = vaddq_s16(v1389, v1391); + int16x8_t v1393 = vqrdmulhq_n_s16(v1392, 20783); + int16x8_t v1394 = vaddq_s16(v1388, v1393); + int16x8_t v1395 = vqrdmulhq_n_s16(v1394, 17326); + int16x8_t v1396 = vaddq_s16(v1384, v1395); + int16x8_t v1397 = vsubq_s16(v711, v722); + int16x8_t v1398 = vsubq_s16(v733, v744); + int16x8_t v1399_tmp = vqrdmulhq_n_s16(v1398, 13357); + int16x8_t v1399 = vmlaq_n_s16(v1399_tmp, v1398, 3); + int16x8_t v1400 = vaddq_s16(v1397, v1399); + int16x8_t v1401 = vsubq_s16(v757, v768); + int16x8_t v1402 = vsubq_s16(v779, v790); + int16x8_t v1403_tmp = vqrdmulhq_n_s16(v1402, 13357); + int16x8_t v1403 = vmlaq_n_s16(v1403_tmp, v1402, 3); + int16x8_t v1404 = vaddq_s16(v1401, v1403); + int16x8_t v1405 = vqrdmulhq_n_s16(v1404, 21637); + int16x8_t v1406 = vaddq_s16(v1400, v1405); + int16x8_t v1407 = vsubq_s16(v805, v816); + int16x8_t v1408 = vsubq_s16(v827, v838); + int16x8_t v1409_tmp = vqrdmulhq_n_s16(v1408, 13357); + int16x8_t v1409 = vmlaq_n_s16(v1409_tmp, v1408, 3); + int16x8_t v1410 = vaddq_s16(v1407, v1409); + int16x8_t v1411 = vsubq_s16(v851, v862); + int16x8_t v1412 = vsubq_s16(v873, v884); + int16x8_t v1413_tmp = vqrdmulhq_n_s16(v1412, 13357); + int16x8_t v1413 = vmlaq_n_s16(v1413_tmp, v1412, 3); + int16x8_t v1414 = vaddq_s16(v1411, v1413); + int16x8_t v1415 = vqrdmulhq_n_s16(v1414, 21637); + int16x8_t v1416 = vaddq_s16(v1410, v1415); + int16x8_t v1417 = vqrdmulhq_n_s16(v1416, 17479); + int16x8_t v1418 = vaddq_s16(v1406, v1417); + int16x8_t v1419 = vsubq_s16(v25, v60); + int16x8_t v1420 = vsubq_s16(v102, v138); + int16x8_t v1421_tmp = vqrdmulhq_n_s16(v1420, 6226); + int16x8_t v1421 = vmlaq_n_s16(v1421_tmp, v1420, 10); + int16x8_t v1422 = vaddq_s16(v1419, v1421); + int16x8_t v1423 = vsubq_s16(v182, v233); + int16x8_t v1424 = vsubq_s16(v275, v312); + int16x8_t v1425_tmp = vqrdmulhq_n_s16(v1424, 6226); + int16x8_t v1425 = vmlaq_n_s16(v1425_tmp, v1424, 10); + int16x8_t v1426 = vaddq_s16(v1423, v1425); + int16x8_t v1427 = vqrdmulhq_n_s16(v1426, 22622); + int16x8_t v1428 = vaddq_s16(v1422, v1427); + int16x8_t v1429 = vsubq_s16(v358, v409); + int16x8_t v1430 = vsubq_s16(v481, v519); + int16x8_t v1431_tmp = vqrdmulhq_n_s16(v1430, 6226); + int16x8_t v1431 = vmlaq_n_s16(v1431_tmp, v1430, 10); + int16x8_t v1432 = vaddq_s16(v1429, v1431); + int16x8_t v1433 = vsubq_s16(v563, v614); + int16x8_t v1434 = vsubq_s16(v656, v694); + int16x8_t v1435_tmp = vqrdmulhq_n_s16(v1434, 6226); + int16x8_t v1435 = vmlaq_n_s16(v1435_tmp, v1434, 10); + int16x8_t v1436 = vaddq_s16(v1433, v1435); + int16x8_t v1437 = vqrdmulhq_n_s16(v1436, 22622); + int16x8_t v1438 = vaddq_s16(v1432, v1437); + int16x8_t v1439 = vqrdmulhq_n_s16(v1438, 17646); + int16x8_t v1440 = vaddq_s16(v1428, v1439); + int16x8_t v1441 = vsubq_s16(v1419, v1421); + int16x8_t v1442 = vsubq_s16(v1423, v1425); + int16x8_t v1443 = vqrdmulhq_n_s16(v1442, 23761); + int16x8_t v1444 = vaddq_s16(v1441, v1443); + int16x8_t v1445 = vsubq_s16(v1429, v1431); + int16x8_t v1446 = vsubq_s16(v1433, v1435); + int16x8_t v1447 = vqrdmulhq_n_s16(v1446, 23761); + int16x8_t v1448 = vaddq_s16(v1445, v1447); + int16x8_t v1449 = vqrdmulhq_n_s16(v1448, 17826); + int16x8_t v1450 = vaddq_s16(v1444, v1449); + int16x8_t v1451 = vsubq_s16(v1397, v1399); + int16x8_t v1452 = vsubq_s16(v1401, v1403); + int16x8_t v1453 = vqrdmulhq_n_s16(v1452, 25084); + int16x8_t v1454 = vaddq_s16(v1451, v1453); + int16x8_t v1455 = vsubq_s16(v1407, v1409); + int16x8_t v1456 = vsubq_s16(v1411, v1413); + int16x8_t v1457 = vqrdmulhq_n_s16(v1456, 25084); + int16x8_t v1458 = vaddq_s16(v1455, v1457); + int16x8_t v1459 = vqrdmulhq_n_s16(v1458, 18021); + int16x8_t v1460 = vaddq_s16(v1454, v1459); + int16x8_t v1461 = vsubq_s16(v1375, v1377); + int16x8_t v1462 = vsubq_s16(v1379, v1381); + int16x8_t v1463 = vqrdmulhq_n_s16(v1462, 26631); + int16x8_t v1464 = vaddq_s16(v1461, v1463); + int16x8_t v1465 = vsubq_s16(v1385, v1387); + int16x8_t v1466 = vsubq_s16(v1389, v1391); + int16x8_t v1467 = vqrdmulhq_n_s16(v1466, 26631); + int16x8_t v1468 = vaddq_s16(v1465, v1467); + int16x8_t v1469 = vqrdmulhq_n_s16(v1468, 18231); + int16x8_t v1470 = vaddq_s16(v1464, v1469); + int16x8_t v1471 = vsubq_s16(v1353, v1355); + int16x8_t v1472 = vsubq_s16(v1357, v1359); + int16x8_t v1473 = vqrdmulhq_n_s16(v1472, 28454); + int16x8_t v1474 = vaddq_s16(v1471, v1473); + int16x8_t v1475 = vsubq_s16(v1363, v1365); + int16x8_t v1476 = vsubq_s16(v1367, v1369); + int16x8_t v1477 = vqrdmulhq_n_s16(v1476, 28454); + int16x8_t v1478 = vaddq_s16(v1475, v1477); + int16x8_t v1479 = vqrdmulhq_n_s16(v1478, 18458); + int16x8_t v1480 = vaddq_s16(v1474, v1479); + int16x8_t v1481 = vsubq_s16(v1331, v1333); + int16x8_t v1482 = vsubq_s16(v1335, v1337); + int16x8_t v1483 = vqrdmulhq_n_s16(v1482, 30624); + int16x8_t v1484 = vaddq_s16(v1481, v1483); + int16x8_t v1485 = vsubq_s16(v1341, v1343); + int16x8_t v1486 = vsubq_s16(v1345, v1347); + int16x8_t v1487 = vqrdmulhq_n_s16(v1486, 30624); + int16x8_t v1488 = vaddq_s16(v1485, v1487); + int16x8_t v1489 = vqrdmulhq_n_s16(v1488, 18702); + int16x8_t v1490 = vaddq_s16(v1484, v1489); + int16x8_t v1491 = vsubq_s16(v1309, v1311); + int16x8_t v1492 = vsubq_s16(v1313, v1315); + int16x8_t v1493_tmp = vqrdmulhq_n_s16(v1492, 472); + int16x8_t v1493 = vaddq_s16(v1493_tmp, v1492); + int16x8_t v1494 = vaddq_s16(v1491, v1493); + int16x8_t v1495 = vsubq_s16(v1319, v1321); + int16x8_t v1496 = vsubq_s16(v1323, v1325); + int16x8_t v1497_tmp = vqrdmulhq_n_s16(v1496, 472); + int16x8_t v1497 = vaddq_s16(v1497_tmp, v1496); + int16x8_t v1498 = vaddq_s16(v1495, v1497); + int16x8_t v1499 = vqrdmulhq_n_s16(v1498, 18964); + int16x8_t v1500 = vaddq_s16(v1494, v1499); + int16x8_t v1501 = vsubq_s16(v1287, v1289); + int16x8_t v1502 = vsubq_s16(v1291, v1293); + int16x8_t v1503_tmp = vqrdmulhq_n_s16(v1502, 3672); + int16x8_t v1503 = vaddq_s16(v1503_tmp, v1502); + int16x8_t v1504 = vaddq_s16(v1501, v1503); + int16x8_t v1505 = vsubq_s16(v1297, v1299); + int16x8_t v1506 = vsubq_s16(v1301, v1303); + int16x8_t v1507_tmp = vqrdmulhq_n_s16(v1506, 3672); + int16x8_t v1507 = vaddq_s16(v1507_tmp, v1506); + int16x8_t v1508 = vaddq_s16(v1505, v1507); + int16x8_t v1509 = vqrdmulhq_n_s16(v1508, 19245); + int16x8_t v1510 = vaddq_s16(v1504, v1509); + int16x8_t v1511 = vsubq_s16(v1265, v1267); + int16x8_t v1512 = vsubq_s16(v1269, v1271); + int16x8_t v1513_tmp = vqrdmulhq_n_s16(v1512, 7662); + int16x8_t v1513 = vaddq_s16(v1513_tmp, v1512); + int16x8_t v1514 = vaddq_s16(v1511, v1513); + int16x8_t v1515 = vsubq_s16(v1275, v1277); + int16x8_t v1516 = vsubq_s16(v1279, v1281); + int16x8_t v1517_tmp = vqrdmulhq_n_s16(v1516, 7662); + int16x8_t v1517 = vaddq_s16(v1517_tmp, v1516); + int16x8_t v1518 = vaddq_s16(v1515, v1517); + int16x8_t v1519 = vqrdmulhq_n_s16(v1518, 19546); + int16x8_t v1520 = vaddq_s16(v1514, v1519); + int16x8_t v1521 = vsubq_s16(v1222, v1227); + int16x8_t v1522 = vsubq_s16(v1232, v1237); + int16x8_t v1523_tmp = vqrdmulhq_n_s16(v1522, 12756); + int16x8_t v1523 = vaddq_s16(v1523_tmp, v1522); + int16x8_t v1524 = vaddq_s16(v1521, v1523); + int16x8_t v1525 = vsubq_s16(v1244, v1249); + int16x8_t v1526 = vsubq_s16(v1254, v1259); + int16x8_t v1527_tmp = vqrdmulhq_n_s16(v1526, 12756); + int16x8_t v1527 = vaddq_s16(v1527_tmp, v1526); + int16x8_t v1528 = vaddq_s16(v1525, v1527); + int16x8_t v1529 = vqrdmulhq_n_s16(v1528, 19869); + int16x8_t v1530 = vaddq_s16(v1524, v1529); + int16x8_t v1531 = vsubq_s16(v1176, v1181); + int16x8_t v1532 = vsubq_s16(v1186, v1191); + int16x8_t v1533_tmp = vqrdmulhq_n_s16(v1532, 19463); + int16x8_t v1533 = vaddq_s16(v1533_tmp, v1532); + int16x8_t v1534 = vaddq_s16(v1531, v1533); + int16x8_t v1535 = vsubq_s16(v1198, v1203); + int16x8_t v1536 = vsubq_s16(v1208, v1213); + int16x8_t v1537_tmp = vqrdmulhq_n_s16(v1536, 19463); + int16x8_t v1537 = vaddq_s16(v1537_tmp, v1536); + int16x8_t v1538 = vaddq_s16(v1535, v1537); + int16x8_t v1539 = vqrdmulhq_n_s16(v1538, 20216); + int16x8_t v1540 = vaddq_s16(v1534, v1539); + int16x8_t v1541 = vsubq_s16(v1130, v1135); + int16x8_t v1542 = vsubq_s16(v1140, v1145); + int16x8_t v1543_tmp = vqrdmulhq_n_s16(v1542, 28661); + int16x8_t v1543 = vaddq_s16(v1543_tmp, v1542); + int16x8_t v1544 = vaddq_s16(v1541, v1543); + int16x8_t v1545 = vsubq_s16(v1152, v1157); + int16x8_t v1546 = vsubq_s16(v1162, v1167); + int16x8_t v1547_tmp = vqrdmulhq_n_s16(v1546, 28661); + int16x8_t v1547 = vaddq_s16(v1547_tmp, v1546); + int16x8_t v1548 = vaddq_s16(v1545, v1547); + int16x8_t v1549 = vqrdmulhq_n_s16(v1548, 20587); + int16x8_t v1550 = vaddq_s16(v1544, v1549); + int16x8_t v1551 = vsubq_s16(v1084, v1089); + int16x8_t v1552 = vsubq_s16(v1094, v1099); + int16x8_t v1553_tmp = vqrdmulhq_n_s16(v1552, 9242); + int16x8_t v1553 = vmlaq_n_s16(v1553_tmp, v1552, 2); + int16x8_t v1554 = vaddq_s16(v1551, v1553); + int16x8_t v1555 = vsubq_s16(v1106, v1111); + int16x8_t v1556 = vsubq_s16(v1116, v1121); + int16x8_t v1557_tmp = vqrdmulhq_n_s16(v1556, 9242); + int16x8_t v1557 = vmlaq_n_s16(v1557_tmp, v1556, 2); + int16x8_t v1558 = vaddq_s16(v1555, v1557); + int16x8_t v1559 = vqrdmulhq_n_s16(v1558, 20985); + int16x8_t v1560 = vaddq_s16(v1554, v1559); + int16x8_t v1561 = vsubq_s16(v996, v1007); + int16x8_t v1562 = vsubq_s16(v1018, v1029); + int16x8_t v1563_tmp = vqrdmulhq_n_s16(v1562, 30298); + int16x8_t v1563 = vmlaq_n_s16(v1563_tmp, v1562, 2); + int16x8_t v1564 = vaddq_s16(v1561, v1563); + int16x8_t v1565 = vsubq_s16(v1042, v1053); + int16x8_t v1566 = vsubq_s16(v1064, v1075); + int16x8_t v1567_tmp = vqrdmulhq_n_s16(v1566, 30298); + int16x8_t v1567 = vmlaq_n_s16(v1567_tmp, v1566, 2); + int16x8_t v1568 = vaddq_s16(v1565, v1567); + int16x8_t v1569 = vqrdmulhq_n_s16(v1568, 21412); + int16x8_t v1570 = vaddq_s16(v1564, v1569); + int16x8_t v1571 = vsubq_s16(v901, v912); + int16x8_t v1572 = vsubq_s16(v924, v935); + int16x8_t v1573_tmp = vqrdmulhq_n_s16(v1572, 2773); + int16x8_t v1573 = vmlaq_n_s16(v1573_tmp, v1572, 4); + int16x8_t v1574 = vaddq_s16(v1571, v1573); + int16x8_t v1575 = vsubq_s16(v948, v959); + int16x8_t v1576 = vsubq_s16(v970, v981); + int16x8_t v1577_tmp = vqrdmulhq_n_s16(v1576, 2773); + int16x8_t v1577 = vmlaq_n_s16(v1577_tmp, v1576, 4); + int16x8_t v1578 = vaddq_s16(v1575, v1577); + int16x8_t v1579 = vqrdmulhq_n_s16(v1578, 21871); + int16x8_t v1580 = vaddq_s16(v1574, v1579); + int16x8_t v1581 = vsubq_s16(v723, v746); + int16x8_t v1582 = vsubq_s16(v769, v792); + int16x8_t v1583_tmp = vqrdmulhq_n_s16(v1582, 26108); + int16x8_t v1583 = vmlaq_n_s16(v1583_tmp, v1582, 6); + int16x8_t v1584 = vaddq_s16(v1581, v1583); + int16x8_t v1585 = vsubq_s16(v817, v840); + int16x8_t v1586 = vsubq_s16(v863, v886); + int16x8_t v1587_tmp = vqrdmulhq_n_s16(v1586, 26108); + int16x8_t v1587 = vmlaq_n_s16(v1587_tmp, v1586, 6); + int16x8_t v1588 = vaddq_s16(v1585, v1587); + int16x8_t v1589 = vqrdmulhq_n_s16(v1588, 22363); + int16x8_t v1590 = vaddq_s16(v1584, v1589); + int16x8_t v1591 = vsubq_s16(v61, v140); + int16x8_t v1592 = vsubq_s16(v234, v314); + int16x8_t v1593_tmp = vqrdmulhq_n_s16(v1592, 12251); + int16x8_t v1593 = vmlaq_n_s16(v1593_tmp, v1592, 20); + int16x8_t v1594 = vaddq_s16(v1591, v1593); + int16x8_t v1595 = vsubq_s16(v410, v521); + int16x8_t v1596 = vsubq_s16(v615, v696); + int16x8_t v1597_tmp = vqrdmulhq_n_s16(v1596, 12251); + int16x8_t v1597 = vmlaq_n_s16(v1597_tmp, v1596, 20); + int16x8_t v1598 = vaddq_s16(v1595, v1597); + int16x8_t v1599 = vqrdmulhq_n_s16(v1598, 22891); + int16x8_t v1600 = vaddq_s16(v1594, v1599); + int16x8_t v1601 = vsubq_s16(v1591, v1593); + int16x8_t v1602 = vsubq_s16(v1595, v1597); + int16x8_t v1603 = vqrdmulhq_n_s16(v1602, 23460); + int16x8_t v1604 = vaddq_s16(v1601, v1603); + int16x8_t v1605 = vsubq_s16(v1581, v1583); + int16x8_t v1606 = vsubq_s16(v1585, v1587); + int16x8_t v1607 = vqrdmulhq_n_s16(v1606, 24073); + int16x8_t v1608 = vaddq_s16(v1605, v1607); + int16x8_t v1609 = vsubq_s16(v1571, v1573); + int16x8_t v1610 = vsubq_s16(v1575, v1577); + int16x8_t v1611 = vqrdmulhq_n_s16(v1610, 24734); + int16x8_t v1612 = vaddq_s16(v1609, v1611); + int16x8_t v1613 = vsubq_s16(v1561, v1563); + int16x8_t v1614 = vsubq_s16(v1565, v1567); + int16x8_t v1615 = vqrdmulhq_n_s16(v1614, 25448); + int16x8_t v1616 = vaddq_s16(v1613, v1615); + int16x8_t v1617 = vsubq_s16(v1551, v1553); + int16x8_t v1618 = vsubq_s16(v1555, v1557); + int16x8_t v1619 = vqrdmulhq_n_s16(v1618, 26220); + int16x8_t v1620 = vaddq_s16(v1617, v1619); + int16x8_t v1621 = vsubq_s16(v1541, v1543); + int16x8_t v1622 = vsubq_s16(v1545, v1547); + int16x8_t v1623 = vqrdmulhq_n_s16(v1622, 27058); + int16x8_t v1624 = vaddq_s16(v1621, v1623); + int16x8_t v1625 = vsubq_s16(v1531, v1533); + int16x8_t v1626 = vsubq_s16(v1535, v1537); + int16x8_t v1627 = vqrdmulhq_n_s16(v1626, 27969); + int16x8_t v1628 = vaddq_s16(v1625, v1627); + int16x8_t v1629 = vsubq_s16(v1521, v1523); + int16x8_t v1630 = vsubq_s16(v1525, v1527); + int16x8_t v1631 = vqrdmulhq_n_s16(v1630, 28961); + int16x8_t v1632 = vaddq_s16(v1629, v1631); + int16x8_t v1633 = vsubq_s16(v1511, v1513); + int16x8_t v1634 = vsubq_s16(v1515, v1517); + int16x8_t v1635 = vqrdmulhq_n_s16(v1634, 30044); + int16x8_t v1636 = vaddq_s16(v1633, v1635); + int16x8_t v1637 = vsubq_s16(v1501, v1503); + int16x8_t v1638 = vsubq_s16(v1505, v1507); + int16x8_t v1639 = vqrdmulhq_n_s16(v1638, 31232); + int16x8_t v1640 = vaddq_s16(v1637, v1639); + int16x8_t v1641 = vsubq_s16(v1491, v1493); + int16x8_t v1642 = vsubq_s16(v1495, v1497); + int16x8_t v1643 = vqrdmulhq_n_s16(v1642, 32538); + int16x8_t v1644 = vaddq_s16(v1641, v1643); + int16x8_t v1645 = vsubq_s16(v1481, v1483); + int16x8_t v1646 = vsubq_s16(v1485, v1487); + int16x8_t v1647_tmp = vqrdmulhq_n_s16(v1646, 1211); + int16x8_t v1647 = vaddq_s16(v1647_tmp, v1646); + int16x8_t v1648 = vaddq_s16(v1645, v1647); + int16x8_t v1649 = vsubq_s16(v1471, v1473); + int16x8_t v1650 = vsubq_s16(v1475, v1477); + int16x8_t v1651_tmp = vqrdmulhq_n_s16(v1650, 2808); + int16x8_t v1651 = vaddq_s16(v1651_tmp, v1650); + int16x8_t v1652 = vaddq_s16(v1649, v1651); + int16x8_t v1653 = vsubq_s16(v1461, v1463); + int16x8_t v1654 = vsubq_s16(v1465, v1467); + int16x8_t v1655_tmp = vqrdmulhq_n_s16(v1654, 4586); + int16x8_t v1655 = vaddq_s16(v1655_tmp, v1654); + int16x8_t v1656 = vaddq_s16(v1653, v1655); + int16x8_t v1657 = vsubq_s16(v1451, v1453); + int16x8_t v1658 = vsubq_s16(v1455, v1457); + int16x8_t v1659_tmp = vqrdmulhq_n_s16(v1658, 6576); + int16x8_t v1659 = vaddq_s16(v1659_tmp, v1658); + int16x8_t v1660 = vaddq_s16(v1657, v1659); + int16x8_t v1661 = vsubq_s16(v1441, v1443); + int16x8_t v1662 = vsubq_s16(v1445, v1447); + int16x8_t v1663_tmp = vqrdmulhq_n_s16(v1662, 8817); + int16x8_t v1663 = vaddq_s16(v1663_tmp, v1662); + int16x8_t v1664 = vaddq_s16(v1661, v1663); + int16x8_t v1665 = vsubq_s16(v1422, v1427); + int16x8_t v1666 = vsubq_s16(v1432, v1437); + int16x8_t v1667_tmp = vqrdmulhq_n_s16(v1666, 11356); + int16x8_t v1667 = vaddq_s16(v1667_tmp, v1666); + int16x8_t v1668 = vaddq_s16(v1665, v1667); + int16x8_t v1669 = vsubq_s16(v1400, v1405); + int16x8_t v1670 = vsubq_s16(v1410, v1415); + int16x8_t v1671_tmp = vqrdmulhq_n_s16(v1670, 14256); + int16x8_t v1671 = vaddq_s16(v1671_tmp, v1670); + int16x8_t v1672 = vaddq_s16(v1669, v1671); + int16x8_t v1673 = vsubq_s16(v1378, v1383); + int16x8_t v1674 = vsubq_s16(v1388, v1393); + int16x8_t v1675_tmp = vqrdmulhq_n_s16(v1674, 17596); + int16x8_t v1675 = vaddq_s16(v1675_tmp, v1674); + int16x8_t v1676 = vaddq_s16(v1673, v1675); + int16x8_t v1677 = vsubq_s16(v1356, v1361); + int16x8_t v1678 = vsubq_s16(v1366, v1371); + int16x8_t v1679_tmp = vqrdmulhq_n_s16(v1678, 21483); + int16x8_t v1679 = vaddq_s16(v1679_tmp, v1678); + int16x8_t v1680 = vaddq_s16(v1677, v1679); + int16x8_t v1681 = vsubq_s16(v1334, v1339); + int16x8_t v1682 = vsubq_s16(v1344, v1349); + int16x8_t v1683_tmp = vqrdmulhq_n_s16(v1682, 26057); + int16x8_t v1683 = vaddq_s16(v1683_tmp, v1682); + int16x8_t v1684 = vaddq_s16(v1681, v1683); + int16x8_t v1685 = vsubq_s16(v1312, v1317); + int16x8_t v1686 = vsubq_s16(v1322, v1327); + int16x8_t v1687_tmp = vqrdmulhq_n_s16(v1686, 31517); + int16x8_t v1687 = vaddq_s16(v1687_tmp, v1686); + int16x8_t v1688 = vaddq_s16(v1685, v1687); + int16x8_t v1689 = vsubq_s16(v1290, v1295); + int16x8_t v1690 = vsubq_s16(v1300, v1305); + int16x8_t v1691_tmp = vqrdmulhq_n_s16(v1690, 5373); + int16x8_t v1691 = vmlaq_n_s16(v1691_tmp, v1690, 2); + int16x8_t v1692 = vaddq_s16(v1689, v1691); + int16x8_t v1693 = vsubq_s16(v1268, v1273); + int16x8_t v1694 = vsubq_s16(v1278, v1283); + int16x8_t v1695_tmp = vqrdmulhq_n_s16(v1694, 13571); + int16x8_t v1695 = vmlaq_n_s16(v1695_tmp, v1694, 2); + int16x8_t v1696 = vaddq_s16(v1693, v1695); + int16x8_t v1697 = vsubq_s16(v1228, v1239); + int16x8_t v1698 = vsubq_s16(v1250, v1261); + int16x8_t v1699_tmp = vqrdmulhq_n_s16(v1698, 23975); + int16x8_t v1699 = vmlaq_n_s16(v1699_tmp, v1698, 2); + int16x8_t v1700 = vaddq_s16(v1697, v1699); + int16x8_t v1701 = vsubq_s16(v1182, v1193); + int16x8_t v1702 = vsubq_s16(v1204, v1215); + int16x8_t v1703_tmp = vqrdmulhq_n_s16(v1702, 4832); + int16x8_t v1703 = vmlaq_n_s16(v1703_tmp, v1702, 3); + int16x8_t v1704 = vaddq_s16(v1701, v1703); + int16x8_t v1705 = vsubq_s16(v1136, v1147); + int16x8_t v1706 = vsubq_s16(v1158, v1169); + int16x8_t v1707_tmp = vqrdmulhq_n_s16(v1706, 23437); + int16x8_t v1707 = vmlaq_n_s16(v1707_tmp, v1706, 3); + int16x8_t v1708 = vaddq_s16(v1705, v1707); + int16x8_t v1709 = vsubq_s16(v1090, v1101); + int16x8_t v1710 = vsubq_s16(v1112, v1123); + int16x8_t v1711_tmp = vqrdmulhq_n_s16(v1710, 17573); + int16x8_t v1711 = vmlaq_n_s16(v1711_tmp, v1710, 4); + int16x8_t v1712 = vaddq_s16(v1709, v1711); + int16x8_t v1713 = vsubq_s16(v1008, v1031); + int16x8_t v1714 = vsubq_s16(v1054, v1077); + int16x8_t v1715_tmp = vqrdmulhq_n_s16(v1714, 27122); + int16x8_t v1715 = vmlaq_n_s16(v1715_tmp, v1714, 5); + int16x8_t v1716 = vaddq_s16(v1713, v1715); + int16x8_t v1717 = vsubq_s16(v913, v937); + int16x8_t v1718 = vsubq_s16(v960, v983); + int16x8_t v1719_tmp = vqrdmulhq_n_s16(v1718, 5041); + int16x8_t v1719 = vmlaq_n_s16(v1719_tmp, v1718, 8); + int16x8_t v1720 = vaddq_s16(v1717, v1719); + int16x8_t v1721 = vsubq_s16(v747, v794); + int16x8_t v1722 = vsubq_s16(v841, v888); + int16x8_t v1723_tmp = vqrdmulhq_n_s16(v1722, 19146); + int16x8_t v1723 = vmlaq_n_s16(v1723_tmp, v1722, 13); + int16x8_t v1724 = vaddq_s16(v1721, v1723); + int16x8_t v1725 = vsubq_s16(v141, v316); + int16x8_t v1726 = vsubq_s16(v522, v698); + int16x8_t v1727_tmp = vqrdmulhq_n_s16(v1726, 24402); + int16x8_t v1727 = vmlaq_n_s16(v1727_tmp, v1726, 40); + int16x8_t v1728 = vaddq_s16(v1725, v1727); + int16x8_t v1729 = vsubq_s16(v1725, v1727); + int16x8_t v1730 = vsubq_s16(v1721, v1723); + int16x8_t v1731 = vsubq_s16(v1717, v1719); + int16x8_t v1732 = vsubq_s16(v1713, v1715); + int16x8_t v1733 = vsubq_s16(v1709, v1711); + int16x8_t v1734 = vsubq_s16(v1705, v1707); + int16x8_t v1735 = vsubq_s16(v1701, v1703); + int16x8_t v1736 = vsubq_s16(v1697, v1699); + int16x8_t v1737 = vsubq_s16(v1693, v1695); + int16x8_t v1738 = vsubq_s16(v1689, v1691); + int16x8_t v1739 = vsubq_s16(v1685, v1687); + int16x8_t v1740 = vsubq_s16(v1681, v1683); + int16x8_t v1741 = vsubq_s16(v1677, v1679); + int16x8_t v1742 = vsubq_s16(v1673, v1675); + int16x8_t v1743 = vsubq_s16(v1669, v1671); + int16x8_t v1744 = vsubq_s16(v1665, v1667); + int16x8_t v1745 = vsubq_s16(v1661, v1663); + int16x8_t v1746 = vsubq_s16(v1657, v1659); + int16x8_t v1747 = vsubq_s16(v1653, v1655); + int16x8_t v1748 = vsubq_s16(v1649, v1651); + int16x8_t v1749 = vsubq_s16(v1645, v1647); + int16x8_t v1750 = vsubq_s16(v1641, v1643); + int16x8_t v1751 = vsubq_s16(v1637, v1639); + int16x8_t v1752 = vsubq_s16(v1633, v1635); + int16x8_t v1753 = vsubq_s16(v1629, v1631); + int16x8_t v1754 = vsubq_s16(v1625, v1627); + int16x8_t v1755 = vsubq_s16(v1621, v1623); + int16x8_t v1756 = vsubq_s16(v1617, v1619); + int16x8_t v1757 = vsubq_s16(v1613, v1615); + int16x8_t v1758 = vsubq_s16(v1609, v1611); + int16x8_t v1759 = vsubq_s16(v1605, v1607); + int16x8_t v1760 = vsubq_s16(v1601, v1603); + int16x8_t v1761 = vsubq_s16(v1594, v1599); + int16x8_t v1762 = vsubq_s16(v1584, v1589); + int16x8_t v1763 = vsubq_s16(v1574, v1579); + int16x8_t v1764 = vsubq_s16(v1564, v1569); + int16x8_t v1765 = vsubq_s16(v1554, v1559); + int16x8_t v1766 = vsubq_s16(v1544, v1549); + int16x8_t v1767 = vsubq_s16(v1534, v1539); + int16x8_t v1768 = vsubq_s16(v1524, v1529); + int16x8_t v1769 = vsubq_s16(v1514, v1519); + int16x8_t v1770 = vsubq_s16(v1504, v1509); + int16x8_t v1771 = vsubq_s16(v1494, v1499); + int16x8_t v1772 = vsubq_s16(v1484, v1489); + int16x8_t v1773 = vsubq_s16(v1474, v1479); + int16x8_t v1774 = vsubq_s16(v1464, v1469); + int16x8_t v1775 = vsubq_s16(v1454, v1459); + int16x8_t v1776 = vsubq_s16(v1444, v1449); + int16x8_t v1777 = vsubq_s16(v1428, v1439); + int16x8_t v1778 = vsubq_s16(v1406, v1417); + int16x8_t v1779 = vsubq_s16(v1384, v1395); + int16x8_t v1780 = vsubq_s16(v1362, v1373); + int16x8_t v1781 = vsubq_s16(v1340, v1351); + int16x8_t v1782 = vsubq_s16(v1318, v1329); + int16x8_t v1783 = vsubq_s16(v1296, v1307); + int16x8_t v1784 = vsubq_s16(v1274, v1285); + int16x8_t v1785 = vsubq_s16(v1240, v1263); + int16x8_t v1786 = vsubq_s16(v1194, v1217); + int16x8_t v1787 = vsubq_s16(v1148, v1171); + int16x8_t v1788 = vsubq_s16(v1102, v1125); + int16x8_t v1789 = vsubq_s16(v1032, v1079); + int16x8_t v1790 = vsubq_s16(v938, v985); + int16x8_t v1791 = vsubq_s16(v795, v890); + int16x8_t v1792 = vsubq_s16(v317, v700); + vst1q_s16(out + out_stride * 0 + i, v701); + vst1q_s16(out + out_stride * 1 + i, v891); + vst1q_s16(out + out_stride * 2 + i, v986); + vst1q_s16(out + out_stride * 3 + i, v1080); + vst1q_s16(out + out_stride * 4 + i, v1126); + vst1q_s16(out + out_stride * 5 + i, v1172); + vst1q_s16(out + out_stride * 6 + i, v1218); + vst1q_s16(out + out_stride * 7 + i, v1264); + vst1q_s16(out + out_stride * 8 + i, v1286); + vst1q_s16(out + out_stride * 9 + i, v1308); + vst1q_s16(out + out_stride * 10 + i, v1330); + vst1q_s16(out + out_stride * 11 + i, v1352); + vst1q_s16(out + out_stride * 12 + i, v1374); + vst1q_s16(out + out_stride * 13 + i, v1396); + vst1q_s16(out + out_stride * 14 + i, v1418); + vst1q_s16(out + out_stride * 15 + i, v1440); + vst1q_s16(out + out_stride * 16 + i, v1450); + vst1q_s16(out + out_stride * 17 + i, v1460); + vst1q_s16(out + out_stride * 18 + i, v1470); + vst1q_s16(out + out_stride * 19 + i, v1480); + vst1q_s16(out + out_stride * 20 + i, v1490); + vst1q_s16(out + out_stride * 21 + i, v1500); + vst1q_s16(out + out_stride * 22 + i, v1510); + vst1q_s16(out + out_stride * 23 + i, v1520); + vst1q_s16(out + out_stride * 24 + i, v1530); + vst1q_s16(out + out_stride * 25 + i, v1540); + vst1q_s16(out + out_stride * 26 + i, v1550); + vst1q_s16(out + out_stride * 27 + i, v1560); + vst1q_s16(out + out_stride * 28 + i, v1570); + vst1q_s16(out + out_stride * 29 + i, v1580); + vst1q_s16(out + out_stride * 30 + i, v1590); + vst1q_s16(out + out_stride * 31 + i, v1600); + vst1q_s16(out + out_stride * 32 + i, v1604); + vst1q_s16(out + out_stride * 33 + i, v1608); + vst1q_s16(out + out_stride * 34 + i, v1612); + vst1q_s16(out + out_stride * 35 + i, v1616); + vst1q_s16(out + out_stride * 36 + i, v1620); + vst1q_s16(out + out_stride * 37 + i, v1624); + vst1q_s16(out + out_stride * 38 + i, v1628); + vst1q_s16(out + out_stride * 39 + i, v1632); + vst1q_s16(out + out_stride * 40 + i, v1636); + vst1q_s16(out + out_stride * 41 + i, v1640); + vst1q_s16(out + out_stride * 42 + i, v1644); + vst1q_s16(out + out_stride * 43 + i, v1648); + vst1q_s16(out + out_stride * 44 + i, v1652); + vst1q_s16(out + out_stride * 45 + i, v1656); + vst1q_s16(out + out_stride * 46 + i, v1660); + vst1q_s16(out + out_stride * 47 + i, v1664); + vst1q_s16(out + out_stride * 48 + i, v1668); + vst1q_s16(out + out_stride * 49 + i, v1672); + vst1q_s16(out + out_stride * 50 + i, v1676); + vst1q_s16(out + out_stride * 51 + i, v1680); + vst1q_s16(out + out_stride * 52 + i, v1684); + vst1q_s16(out + out_stride * 53 + i, v1688); + vst1q_s16(out + out_stride * 54 + i, v1692); + vst1q_s16(out + out_stride * 55 + i, v1696); + vst1q_s16(out + out_stride * 56 + i, v1700); + vst1q_s16(out + out_stride * 57 + i, v1704); + vst1q_s16(out + out_stride * 58 + i, v1708); + vst1q_s16(out + out_stride * 59 + i, v1712); + vst1q_s16(out + out_stride * 60 + i, v1716); + vst1q_s16(out + out_stride * 61 + i, v1720); + vst1q_s16(out + out_stride * 62 + i, v1724); + vst1q_s16(out + out_stride * 63 + i, v1728); + vst1q_s16(out + out_stride * 64 + i, v1729); + vst1q_s16(out + out_stride * 65 + i, v1730); + vst1q_s16(out + out_stride * 66 + i, v1731); + vst1q_s16(out + out_stride * 67 + i, v1732); + vst1q_s16(out + out_stride * 68 + i, v1733); + vst1q_s16(out + out_stride * 69 + i, v1734); + vst1q_s16(out + out_stride * 70 + i, v1735); + vst1q_s16(out + out_stride * 71 + i, v1736); + vst1q_s16(out + out_stride * 72 + i, v1737); + vst1q_s16(out + out_stride * 73 + i, v1738); + vst1q_s16(out + out_stride * 74 + i, v1739); + vst1q_s16(out + out_stride * 75 + i, v1740); + vst1q_s16(out + out_stride * 76 + i, v1741); + vst1q_s16(out + out_stride * 77 + i, v1742); + vst1q_s16(out + out_stride * 78 + i, v1743); + vst1q_s16(out + out_stride * 79 + i, v1744); + vst1q_s16(out + out_stride * 80 + i, v1745); + vst1q_s16(out + out_stride * 81 + i, v1746); + vst1q_s16(out + out_stride * 82 + i, v1747); + vst1q_s16(out + out_stride * 83 + i, v1748); + vst1q_s16(out + out_stride * 84 + i, v1749); + vst1q_s16(out + out_stride * 85 + i, v1750); + vst1q_s16(out + out_stride * 86 + i, v1751); + vst1q_s16(out + out_stride * 87 + i, v1752); + vst1q_s16(out + out_stride * 88 + i, v1753); + vst1q_s16(out + out_stride * 89 + i, v1754); + vst1q_s16(out + out_stride * 90 + i, v1755); + vst1q_s16(out + out_stride * 91 + i, v1756); + vst1q_s16(out + out_stride * 92 + i, v1757); + vst1q_s16(out + out_stride * 93 + i, v1758); + vst1q_s16(out + out_stride * 94 + i, v1759); + vst1q_s16(out + out_stride * 95 + i, v1760); + vst1q_s16(out + out_stride * 96 + i, v1761); + vst1q_s16(out + out_stride * 97 + i, v1762); + vst1q_s16(out + out_stride * 98 + i, v1763); + vst1q_s16(out + out_stride * 99 + i, v1764); + vst1q_s16(out + out_stride * 100 + i, v1765); + vst1q_s16(out + out_stride * 101 + i, v1766); + vst1q_s16(out + out_stride * 102 + i, v1767); + vst1q_s16(out + out_stride * 103 + i, v1768); + vst1q_s16(out + out_stride * 104 + i, v1769); + vst1q_s16(out + out_stride * 105 + i, v1770); + vst1q_s16(out + out_stride * 106 + i, v1771); + vst1q_s16(out + out_stride * 107 + i, v1772); + vst1q_s16(out + out_stride * 108 + i, v1773); + vst1q_s16(out + out_stride * 109 + i, v1774); + vst1q_s16(out + out_stride * 110 + i, v1775); + vst1q_s16(out + out_stride * 111 + i, v1776); + vst1q_s16(out + out_stride * 112 + i, v1777); + vst1q_s16(out + out_stride * 113 + i, v1778); + vst1q_s16(out + out_stride * 114 + i, v1779); + vst1q_s16(out + out_stride * 115 + i, v1780); + vst1q_s16(out + out_stride * 116 + i, v1781); + vst1q_s16(out + out_stride * 117 + i, v1782); + vst1q_s16(out + out_stride * 118 + i, v1783); + vst1q_s16(out + out_stride * 119 + i, v1784); + vst1q_s16(out + out_stride * 120 + i, v1785); + vst1q_s16(out + out_stride * 121 + i, v1786); + vst1q_s16(out + out_stride * 122 + i, v1787); + vst1q_s16(out + out_stride * 123 + i, v1788); + vst1q_s16(out + out_stride * 124 + i, v1789); + vst1q_s16(out + out_stride * 125 + i, v1790); + vst1q_s16(out + out_stride * 126 + i, v1791); + vst1q_s16(out + out_stride * 127 + i, v1792); + } +} diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct16-inl.h b/third_party/jpeg-xl/lib/jxl/fast_dct16-inl.h new file mode 100644 index 0000000000..472ec20d42 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct16-inl.h @@ -0,0 +1,180 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<16>) { return 1; } + +void FastIDCT(FastDCTTag<16>, const int16_t* in, size_t in_stride, int16_t* out, + size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vaddq_s16(v13, v10); + int16x8_t v17 = vqrdmulhq_n_s16(v16, 25080); + int16x8_t v18 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v19 = vaddq_s16(v18, v12); + int16x8_t v20 = vaddq_s16(v16, v19); + int16x8_t v21 = vqrdmulhq_n_s16(v20, 17734); + int16x8_t v22 = vaddq_s16(v17, v21); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v27 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v28 = vaddq_s16(v26, v27); + int16x8_t v29 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v30 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v31 = vaddq_s16(v29, v30); + int16x8_t v32 = vaddq_s16(v28, v31); + int16x8_t v33 = vqrdmulhq_n_s16(v32, 17734); + int16x8_t v34 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v35 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v36 = vaddq_s16(v34, v35); + int16x8_t v37 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v38 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v39 = vaddq_s16(v37, v38); + int16x8_t v40 = vaddq_s16(v36, v39); + int16x8_t v41_tmp = vqrdmulhq_n_s16(v40, 10045); + int16x8_t v41 = vaddq_s16(v41_tmp, v40); + int16x8_t v42 = vaddq_s16(v33, v41); + int16x8_t v43 = vqrdmulhq_n_s16(v42, 16705); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v36, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v36); + int16x8_t v45 = vaddq_s16(v39, v31); + int16x8_t v46 = vaddq_s16(v44, v45); + int16x8_t v47 = vqrdmulhq_n_s16(v46, 16705); + int16x8_t v48 = vaddq_s16(v43, v47); + int16x8_t v49_tmp = vqrdmulhq_n_s16(v35, 13573); + int16x8_t v49 = vaddq_s16(v49_tmp, v35); + int16x8_t v50 = vaddq_s16(v30, v37); + int16x8_t v51 = vaddq_s16(v49, v50); + int16x8_t v52 = vaddq_s16(v38, v34); + int16x8_t v53 = vaddq_s16(v27, v29); + int16x8_t v54 = vaddq_s16(v52, v53); + int16x8_t v55 = vqrdmulhq_n_s16(v54, 17734); + int16x8_t v56 = vqrdmulhq_n_s16(v52, 25080); + int16x8_t v57 = vaddq_s16(v55, v56); + int16x8_t v58 = vaddq_s16(v51, v57); + int16x8_t v59 = vaddq_s16(v48, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vsubq_s16(v0, v1); + int16x8_t v63 = vsubq_s16(v4, v6); + int16x8_t v64_tmp = vqrdmulhq_n_s16(v63, 10045); + int16x8_t v64 = vaddq_s16(v64_tmp, v63); + int16x8_t v65 = vaddq_s16(v62, v64); + int16x8_t v66 = vsubq_s16(v11, v14); + int16x8_t v67 = vqrdmulhq_n_s16(v16, 17734); + int16x8_t v68_tmp = vqrdmulhq_n_s16(v19, 10045); + int16x8_t v68 = vaddq_s16(v68_tmp, v19); + int16x8_t v69 = vsubq_s16(v67, v68); + int16x8_t v70 = vaddq_s16(v66, v69); + int16x8_t v71 = vqrdmulhq_n_s16(v70, 19705); + int16x8_t v72 = vaddq_s16(v65, v71); + int16x8_t v73 = vsubq_s16(v49, v50); + int16x8_t v74 = vqrdmulhq_n_s16(v52, 17734); + int16x8_t v75_tmp = vqrdmulhq_n_s16(v53, 10045); + int16x8_t v75 = vaddq_s16(v75_tmp, v53); + int16x8_t v76 = vsubq_s16(v74, v75); + int16x8_t v77 = vaddq_s16(v73, v76); + int16x8_t v78 = vsubq_s16(v44, v45); + int16x8_t v79 = vqrdmulhq_n_s16(v78, 19705); + int16x8_t v80 = vqrdmulhq_n_s16(v40, 13573); + int16x8_t v81 = vsubq_s16(v80, v32); + int16x8_t v82 = vqrdmulhq_n_s16(v81, 25746); + int16x8_t v83 = vaddq_s16(v79, v82); + int16x8_t v84 = vaddq_s16(v77, v83); + int16x8_t v85 = vqrdmulhq_n_s16(v84, 17121); + int16x8_t v86 = vaddq_s16(v72, v85); + int16x8_t v87 = vsubq_s16(v62, v64); + int16x8_t v88 = vsubq_s16(v66, v69); + int16x8_t v89 = vqrdmulhq_n_s16(v88, 29490); + int16x8_t v90 = vaddq_s16(v87, v89); + int16x8_t v91 = vsubq_s16(v73, v76); + int16x8_t v92 = vqrdmulhq_n_s16(v78, 29490); + int16x8_t v93_tmp = vqrdmulhq_n_s16(v81, 5763); + int16x8_t v93 = vaddq_s16(v93_tmp, v81); + int16x8_t v94 = vsubq_s16(v92, v93); + int16x8_t v95 = vaddq_s16(v91, v94); + int16x8_t v96 = vqrdmulhq_n_s16(v95, 18578); + int16x8_t v97 = vaddq_s16(v90, v96); + int16x8_t v98 = vsubq_s16(v46, v42); + int16x8_t v99_tmp = vqrdmulhq_n_s16(v98, 18446); + int16x8_t v99 = vmlaq_n_s16(v99_tmp, v98, 2); + int16x8_t v100 = vsubq_s16(v51, v57); + int16x8_t v101 = vaddq_s16(v99, v100); + int16x8_t v102 = vqrdmulhq_n_s16(v101, 21195); + int16x8_t v103 = vsubq_s16(v2, v8); + int16x8_t v104 = vsubq_s16(v15, v22); + int16x8_t v105_tmp = vqrdmulhq_n_s16(v104, 18446); + int16x8_t v105 = vmlaq_n_s16(v105_tmp, v104, 2); + int16x8_t v106 = vaddq_s16(v103, v105); + int16x8_t v107 = vaddq_s16(v102, v106); + int16x8_t v108 = vsubq_s16(v103, v105); + int16x8_t v109 = vsubq_s16(v100, v99); + int16x8_t v110 = vqrdmulhq_n_s16(v109, 25826); + int16x8_t v111 = vaddq_s16(v108, v110); + int16x8_t v112 = vsubq_s16(v87, v89); + int16x8_t v113 = vsubq_s16(v91, v94); + int16x8_t v114_tmp = vqrdmulhq_n_s16(v113, 1988); + int16x8_t v114 = vaddq_s16(v114_tmp, v113); + int16x8_t v115 = vaddq_s16(v112, v114); + int16x8_t v116 = vsubq_s16(v65, v71); + int16x8_t v117 = vsubq_s16(v77, v83); + int16x8_t v118_tmp = vqrdmulhq_n_s16(v117, 23673); + int16x8_t v118 = vaddq_s16(v118_tmp, v117); + int16x8_t v119 = vaddq_s16(v116, v118); + int16x8_t v120 = vsubq_s16(v58, v48); + int16x8_t v121_tmp = vqrdmulhq_n_s16(v120, 3314); + int16x8_t v121 = vmlaq_n_s16(v121_tmp, v120, 5); + int16x8_t v122 = vsubq_s16(v9, v24); + int16x8_t v123 = vaddq_s16(v121, v122); + int16x8_t v124 = vsubq_s16(v122, v121); + int16x8_t v125 = vsubq_s16(v116, v118); + int16x8_t v126 = vsubq_s16(v112, v114); + int16x8_t v127 = vsubq_s16(v108, v110); + int16x8_t v128 = vsubq_s16(v106, v102); + int16x8_t v129 = vsubq_s16(v90, v96); + int16x8_t v130 = vsubq_s16(v72, v85); + int16x8_t v131 = vsubq_s16(v25, v60); + vst1q_s16(out + out_stride * 0 + i, v61); + vst1q_s16(out + out_stride * 1 + i, v86); + vst1q_s16(out + out_stride * 2 + i, v97); + vst1q_s16(out + out_stride * 3 + i, v107); + vst1q_s16(out + out_stride * 4 + i, v111); + vst1q_s16(out + out_stride * 5 + i, v115); + vst1q_s16(out + out_stride * 6 + i, v119); + vst1q_s16(out + out_stride * 7 + i, v123); + vst1q_s16(out + out_stride * 8 + i, v124); + vst1q_s16(out + out_stride * 9 + i, v125); + vst1q_s16(out + out_stride * 10 + i, v126); + vst1q_s16(out + out_stride * 11 + i, v127); + vst1q_s16(out + out_stride * 12 + i, v128); + vst1q_s16(out + out_stride * 13 + i, v129); + vst1q_s16(out + out_stride * 14 + i, v130); + vst1q_s16(out + out_stride * 15 + i, v131); + } +} diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct256-inl.h b/third_party/jpeg-xl/lib/jxl/fast_dct256-inl.h new file mode 100644 index 0000000000..a823440af2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct256-inl.h @@ -0,0 +1,4811 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<256>) { return 3; } + +void FastIDCT(FastDCTTag<256>, const int16_t* in, size_t in_stride, + int16_t* out, size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 128 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 64 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 192 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 32 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 160 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 96 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vaddq_s16(v13, v10); + int16x8_t v17_tmp = vqrdmulhq_n_s16(v16, 13573); + int16x8_t v17 = vaddq_s16(v17_tmp, v16); + int16x8_t v18 = vld1q_s16(in + in_stride * 224 + i); + int16x8_t v19 = vaddq_s16(v18, v12); + int16x8_t v20 = vaddq_s16(v19, v16); + int16x8_t v21 = vaddq_s16(v17, v20); + int16x8_t v22 = vqrdmulhq_n_s16(v21, 17734); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 16 + i); + int16x8_t v27_tmp = vqrdmulhq_n_s16(v26, 13573); + int16x8_t v27 = vaddq_s16(v27_tmp, v26); + int16x8_t v28 = vld1q_s16(in + in_stride * 144 + i); + int16x8_t v29 = vld1q_s16(in + in_stride * 112 + i); + int16x8_t v30 = vaddq_s16(v28, v29); + int16x8_t v31 = vaddq_s16(v27, v30); + int16x8_t v32 = vld1q_s16(in + in_stride * 80 + i); + int16x8_t v33 = vld1q_s16(in + in_stride * 48 + i); + int16x8_t v34 = vaddq_s16(v32, v33); + int16x8_t v35_tmp = vqrdmulhq_n_s16(v34, 13573); + int16x8_t v35 = vaddq_s16(v35_tmp, v34); + int16x8_t v36 = vld1q_s16(in + in_stride * 208 + i); + int16x8_t v37 = vld1q_s16(in + in_stride * 176 + i); + int16x8_t v38 = vaddq_s16(v36, v37); + int16x8_t v39 = vaddq_s16(v38, v34); + int16x8_t v40 = vaddq_s16(v35, v39); + int16x8_t v41 = vqrdmulhq_n_s16(v40, 17734); + int16x8_t v42 = vaddq_s16(v31, v41); + int16x8_t v43 = vaddq_s16(v33, v26); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v43, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v43); + int16x8_t v45 = vaddq_s16(v37, v28); + int16x8_t v46 = vaddq_s16(v29, v32); + int16x8_t v47 = vaddq_s16(v45, v46); + int16x8_t v48 = vaddq_s16(v44, v47); + int16x8_t v49 = vaddq_s16(v46, v43); + int16x8_t v50_tmp = vqrdmulhq_n_s16(v49, 13573); + int16x8_t v50 = vaddq_s16(v50_tmp, v49); + int16x8_t v51 = vld1q_s16(in + in_stride * 240 + i); + int16x8_t v52 = vaddq_s16(v51, v36); + int16x8_t v53 = vaddq_s16(v52, v45); + int16x8_t v54 = vaddq_s16(v53, v49); + int16x8_t v55 = vaddq_s16(v50, v54); + int16x8_t v56 = vqrdmulhq_n_s16(v55, 17734); + int16x8_t v57 = vaddq_s16(v48, v56); + int16x8_t v58 = vqrdmulhq_n_s16(v57, 16705); + int16x8_t v59 = vaddq_s16(v42, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v63_tmp = vqrdmulhq_n_s16(v62, 13573); + int16x8_t v63 = vaddq_s16(v63_tmp, v62); + int16x8_t v64 = vld1q_s16(in + in_stride * 136 + i); + int16x8_t v65 = vld1q_s16(in + in_stride * 120 + i); + int16x8_t v66 = vaddq_s16(v64, v65); + int16x8_t v67 = vaddq_s16(v63, v66); + int16x8_t v68 = vld1q_s16(in + in_stride * 72 + i); + int16x8_t v69 = vld1q_s16(in + in_stride * 56 + i); + int16x8_t v70 = vaddq_s16(v68, v69); + int16x8_t v71_tmp = vqrdmulhq_n_s16(v70, 13573); + int16x8_t v71 = vaddq_s16(v71_tmp, v70); + int16x8_t v72 = vld1q_s16(in + in_stride * 200 + i); + int16x8_t v73 = vld1q_s16(in + in_stride * 184 + i); + int16x8_t v74 = vaddq_s16(v72, v73); + int16x8_t v75 = vaddq_s16(v74, v70); + int16x8_t v76 = vaddq_s16(v71, v75); + int16x8_t v77 = vqrdmulhq_n_s16(v76, 17734); + int16x8_t v78 = vaddq_s16(v67, v77); + int16x8_t v79 = vld1q_s16(in + in_stride * 40 + i); + int16x8_t v80 = vld1q_s16(in + in_stride * 24 + i); + int16x8_t v81 = vaddq_s16(v79, v80); + int16x8_t v82_tmp = vqrdmulhq_n_s16(v81, 13573); + int16x8_t v82 = vaddq_s16(v82_tmp, v81); + int16x8_t v83 = vld1q_s16(in + in_stride * 168 + i); + int16x8_t v84 = vld1q_s16(in + in_stride * 152 + i); + int16x8_t v85 = vaddq_s16(v83, v84); + int16x8_t v86 = vld1q_s16(in + in_stride * 104 + i); + int16x8_t v87 = vld1q_s16(in + in_stride * 88 + i); + int16x8_t v88 = vaddq_s16(v86, v87); + int16x8_t v89 = vaddq_s16(v85, v88); + int16x8_t v90 = vaddq_s16(v82, v89); + int16x8_t v91 = vaddq_s16(v88, v81); + int16x8_t v92_tmp = vqrdmulhq_n_s16(v91, 13573); + int16x8_t v92 = vaddq_s16(v92_tmp, v91); + int16x8_t v93 = vld1q_s16(in + in_stride * 232 + i); + int16x8_t v94 = vld1q_s16(in + in_stride * 216 + i); + int16x8_t v95 = vaddq_s16(v93, v94); + int16x8_t v96 = vaddq_s16(v95, v85); + int16x8_t v97 = vaddq_s16(v96, v91); + int16x8_t v98 = vaddq_s16(v92, v97); + int16x8_t v99 = vqrdmulhq_n_s16(v98, 17734); + int16x8_t v100 = vaddq_s16(v90, v99); + int16x8_t v101 = vqrdmulhq_n_s16(v100, 16705); + int16x8_t v102 = vaddq_s16(v78, v101); + int16x8_t v103 = vaddq_s16(v80, v62); + int16x8_t v104_tmp = vqrdmulhq_n_s16(v103, 13573); + int16x8_t v104 = vaddq_s16(v104_tmp, v103); + int16x8_t v105 = vaddq_s16(v84, v64); + int16x8_t v106 = vaddq_s16(v65, v86); + int16x8_t v107 = vaddq_s16(v105, v106); + int16x8_t v108 = vaddq_s16(v104, v107); + int16x8_t v109 = vaddq_s16(v87, v68); + int16x8_t v110 = vaddq_s16(v69, v79); + int16x8_t v111 = vaddq_s16(v109, v110); + int16x8_t v112_tmp = vqrdmulhq_n_s16(v111, 13573); + int16x8_t v112 = vaddq_s16(v112_tmp, v111); + int16x8_t v113 = vaddq_s16(v94, v72); + int16x8_t v114 = vaddq_s16(v73, v83); + int16x8_t v115 = vaddq_s16(v113, v114); + int16x8_t v116 = vaddq_s16(v115, v111); + int16x8_t v117 = vaddq_s16(v112, v116); + int16x8_t v118 = vqrdmulhq_n_s16(v117, 17734); + int16x8_t v119 = vaddq_s16(v108, v118); + int16x8_t v120 = vaddq_s16(v110, v103); + int16x8_t v121_tmp = vqrdmulhq_n_s16(v120, 13573); + int16x8_t v121 = vaddq_s16(v121_tmp, v120); + int16x8_t v122 = vaddq_s16(v114, v105); + int16x8_t v123 = vaddq_s16(v106, v109); + int16x8_t v124 = vaddq_s16(v122, v123); + int16x8_t v125 = vaddq_s16(v121, v124); + int16x8_t v126 = vaddq_s16(v123, v120); + int16x8_t v127_tmp = vqrdmulhq_n_s16(v126, 13573); + int16x8_t v127 = vaddq_s16(v127_tmp, v126); + int16x8_t v128 = vld1q_s16(in + in_stride * 248 + i); + int16x8_t v129 = vaddq_s16(v128, v93); + int16x8_t v130 = vaddq_s16(v129, v113); + int16x8_t v131 = vaddq_s16(v130, v122); + int16x8_t v132 = vaddq_s16(v131, v126); + int16x8_t v133 = vaddq_s16(v127, v132); + int16x8_t v134 = vqrdmulhq_n_s16(v133, 17734); + int16x8_t v135 = vaddq_s16(v125, v134); + int16x8_t v136 = vqrdmulhq_n_s16(v135, 16705); + int16x8_t v137 = vaddq_s16(v119, v136); + int16x8_t v138 = vqrdmulhq_n_s16(v137, 16463); + int16x8_t v139 = vaddq_s16(v102, v138); + int16x8_t v140 = vqrdmulhq_n_s16(v139, 16404); + int16x8_t v141 = vaddq_s16(v61, v140); + int16x8_t v142 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v143_tmp = vqrdmulhq_n_s16(v142, 13573); + int16x8_t v143 = vaddq_s16(v143_tmp, v142); + int16x8_t v144 = vld1q_s16(in + in_stride * 132 + i); + int16x8_t v145 = vld1q_s16(in + in_stride * 124 + i); + int16x8_t v146 = vaddq_s16(v144, v145); + int16x8_t v147 = vaddq_s16(v143, v146); + int16x8_t v148 = vld1q_s16(in + in_stride * 68 + i); + int16x8_t v149 = vld1q_s16(in + in_stride * 60 + i); + int16x8_t v150 = vaddq_s16(v148, v149); + int16x8_t v151_tmp = vqrdmulhq_n_s16(v150, 13573); + int16x8_t v151 = vaddq_s16(v151_tmp, v150); + int16x8_t v152 = vld1q_s16(in + in_stride * 196 + i); + int16x8_t v153 = vld1q_s16(in + in_stride * 188 + i); + int16x8_t v154 = vaddq_s16(v152, v153); + int16x8_t v155 = vaddq_s16(v154, v150); + int16x8_t v156 = vaddq_s16(v151, v155); + int16x8_t v157 = vqrdmulhq_n_s16(v156, 17734); + int16x8_t v158 = vaddq_s16(v147, v157); + int16x8_t v159 = vld1q_s16(in + in_stride * 36 + i); + int16x8_t v160 = vld1q_s16(in + in_stride * 28 + i); + int16x8_t v161 = vaddq_s16(v159, v160); + int16x8_t v162_tmp = vqrdmulhq_n_s16(v161, 13573); + int16x8_t v162 = vaddq_s16(v162_tmp, v161); + int16x8_t v163 = vld1q_s16(in + in_stride * 164 + i); + int16x8_t v164 = vld1q_s16(in + in_stride * 156 + i); + int16x8_t v165 = vaddq_s16(v163, v164); + int16x8_t v166 = vld1q_s16(in + in_stride * 100 + i); + int16x8_t v167 = vld1q_s16(in + in_stride * 92 + i); + int16x8_t v168 = vaddq_s16(v166, v167); + int16x8_t v169 = vaddq_s16(v165, v168); + int16x8_t v170 = vaddq_s16(v162, v169); + int16x8_t v171 = vaddq_s16(v168, v161); + int16x8_t v172_tmp = vqrdmulhq_n_s16(v171, 13573); + int16x8_t v172 = vaddq_s16(v172_tmp, v171); + int16x8_t v173 = vld1q_s16(in + in_stride * 228 + i); + int16x8_t v174 = vld1q_s16(in + in_stride * 220 + i); + int16x8_t v175 = vaddq_s16(v173, v174); + int16x8_t v176 = vaddq_s16(v175, v165); + int16x8_t v177 = vaddq_s16(v176, v171); + int16x8_t v178 = vaddq_s16(v172, v177); + int16x8_t v179 = vqrdmulhq_n_s16(v178, 17734); + int16x8_t v180 = vaddq_s16(v170, v179); + int16x8_t v181 = vqrdmulhq_n_s16(v180, 16705); + int16x8_t v182 = vaddq_s16(v158, v181); + int16x8_t v183 = vld1q_s16(in + in_stride * 20 + i); + int16x8_t v184 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v185 = vaddq_s16(v183, v184); + int16x8_t v186_tmp = vqrdmulhq_n_s16(v185, 13573); + int16x8_t v186 = vaddq_s16(v186_tmp, v185); + int16x8_t v187 = vld1q_s16(in + in_stride * 148 + i); + int16x8_t v188 = vld1q_s16(in + in_stride * 140 + i); + int16x8_t v189 = vaddq_s16(v187, v188); + int16x8_t v190 = vld1q_s16(in + in_stride * 116 + i); + int16x8_t v191 = vld1q_s16(in + in_stride * 108 + i); + int16x8_t v192 = vaddq_s16(v190, v191); + int16x8_t v193 = vaddq_s16(v189, v192); + int16x8_t v194 = vaddq_s16(v186, v193); + int16x8_t v195 = vld1q_s16(in + in_stride * 84 + i); + int16x8_t v196 = vld1q_s16(in + in_stride * 76 + i); + int16x8_t v197 = vaddq_s16(v195, v196); + int16x8_t v198 = vld1q_s16(in + in_stride * 52 + i); + int16x8_t v199 = vld1q_s16(in + in_stride * 44 + i); + int16x8_t v200 = vaddq_s16(v198, v199); + int16x8_t v201 = vaddq_s16(v197, v200); + int16x8_t v202_tmp = vqrdmulhq_n_s16(v201, 13573); + int16x8_t v202 = vaddq_s16(v202_tmp, v201); + int16x8_t v203 = vld1q_s16(in + in_stride * 212 + i); + int16x8_t v204 = vld1q_s16(in + in_stride * 204 + i); + int16x8_t v205 = vaddq_s16(v203, v204); + int16x8_t v206 = vld1q_s16(in + in_stride * 180 + i); + int16x8_t v207 = vld1q_s16(in + in_stride * 172 + i); + int16x8_t v208 = vaddq_s16(v206, v207); + int16x8_t v209 = vaddq_s16(v205, v208); + int16x8_t v210 = vaddq_s16(v209, v201); + int16x8_t v211 = vaddq_s16(v202, v210); + int16x8_t v212 = vqrdmulhq_n_s16(v211, 17734); + int16x8_t v213 = vaddq_s16(v194, v212); + int16x8_t v214 = vaddq_s16(v200, v185); + int16x8_t v215_tmp = vqrdmulhq_n_s16(v214, 13573); + int16x8_t v215 = vaddq_s16(v215_tmp, v214); + int16x8_t v216 = vaddq_s16(v208, v189); + int16x8_t v217 = vaddq_s16(v192, v197); + int16x8_t v218 = vaddq_s16(v216, v217); + int16x8_t v219 = vaddq_s16(v215, v218); + int16x8_t v220 = vaddq_s16(v217, v214); + int16x8_t v221_tmp = vqrdmulhq_n_s16(v220, 13573); + int16x8_t v221 = vaddq_s16(v221_tmp, v220); + int16x8_t v222 = vld1q_s16(in + in_stride * 244 + i); + int16x8_t v223 = vld1q_s16(in + in_stride * 236 + i); + int16x8_t v224 = vaddq_s16(v222, v223); + int16x8_t v225 = vaddq_s16(v224, v205); + int16x8_t v226 = vaddq_s16(v225, v216); + int16x8_t v227 = vaddq_s16(v226, v220); + int16x8_t v228 = vaddq_s16(v221, v227); + int16x8_t v229 = vqrdmulhq_n_s16(v228, 17734); + int16x8_t v230 = vaddq_s16(v219, v229); + int16x8_t v231 = vqrdmulhq_n_s16(v230, 16705); + int16x8_t v232 = vaddq_s16(v213, v231); + int16x8_t v233 = vqrdmulhq_n_s16(v232, 16463); + int16x8_t v234 = vaddq_s16(v182, v233); + int16x8_t v235 = vaddq_s16(v184, v142); + int16x8_t v236_tmp = vqrdmulhq_n_s16(v235, 13573); + int16x8_t v236 = vaddq_s16(v236_tmp, v235); + int16x8_t v237 = vaddq_s16(v188, v144); + int16x8_t v238 = vaddq_s16(v145, v190); + int16x8_t v239 = vaddq_s16(v237, v238); + int16x8_t v240 = vaddq_s16(v236, v239); + int16x8_t v241 = vaddq_s16(v196, v148); + int16x8_t v242 = vaddq_s16(v149, v198); + int16x8_t v243 = vaddq_s16(v241, v242); + int16x8_t v244_tmp = vqrdmulhq_n_s16(v243, 13573); + int16x8_t v244 = vaddq_s16(v244_tmp, v243); + int16x8_t v245 = vaddq_s16(v204, v152); + int16x8_t v246 = vaddq_s16(v153, v206); + int16x8_t v247 = vaddq_s16(v245, v246); + int16x8_t v248 = vaddq_s16(v247, v243); + int16x8_t v249 = vaddq_s16(v244, v248); + int16x8_t v250 = vqrdmulhq_n_s16(v249, 17734); + int16x8_t v251 = vaddq_s16(v240, v250); + int16x8_t v252 = vaddq_s16(v199, v159); + int16x8_t v253 = vaddq_s16(v160, v183); + int16x8_t v254 = vaddq_s16(v252, v253); + int16x8_t v255_tmp = vqrdmulhq_n_s16(v254, 13573); + int16x8_t v255 = vaddq_s16(v255_tmp, v254); + int16x8_t v256 = vaddq_s16(v207, v163); + int16x8_t v257 = vaddq_s16(v164, v187); + int16x8_t v258 = vaddq_s16(v256, v257); + int16x8_t v259 = vaddq_s16(v191, v166); + int16x8_t v260 = vaddq_s16(v167, v195); + int16x8_t v261 = vaddq_s16(v259, v260); + int16x8_t v262 = vaddq_s16(v258, v261); + int16x8_t v263 = vaddq_s16(v255, v262); + int16x8_t v264 = vaddq_s16(v261, v254); + int16x8_t v265_tmp = vqrdmulhq_n_s16(v264, 13573); + int16x8_t v265 = vaddq_s16(v265_tmp, v264); + int16x8_t v266 = vaddq_s16(v223, v173); + int16x8_t v267 = vaddq_s16(v174, v203); + int16x8_t v268 = vaddq_s16(v266, v267); + int16x8_t v269 = vaddq_s16(v268, v258); + int16x8_t v270 = vaddq_s16(v269, v264); + int16x8_t v271 = vaddq_s16(v265, v270); + int16x8_t v272 = vqrdmulhq_n_s16(v271, 17734); + int16x8_t v273 = vaddq_s16(v263, v272); + int16x8_t v274 = vqrdmulhq_n_s16(v273, 16705); + int16x8_t v275 = vaddq_s16(v251, v274); + int16x8_t v276 = vaddq_s16(v253, v235); + int16x8_t v277_tmp = vqrdmulhq_n_s16(v276, 13573); + int16x8_t v277 = vaddq_s16(v277_tmp, v276); + int16x8_t v278 = vaddq_s16(v257, v237); + int16x8_t v279 = vaddq_s16(v238, v259); + int16x8_t v280 = vaddq_s16(v278, v279); + int16x8_t v281 = vaddq_s16(v277, v280); + int16x8_t v282 = vaddq_s16(v260, v241); + int16x8_t v283 = vaddq_s16(v242, v252); + int16x8_t v284 = vaddq_s16(v282, v283); + int16x8_t v285_tmp = vqrdmulhq_n_s16(v284, 13573); + int16x8_t v285 = vaddq_s16(v285_tmp, v284); + int16x8_t v286 = vaddq_s16(v267, v245); + int16x8_t v287 = vaddq_s16(v246, v256); + int16x8_t v288 = vaddq_s16(v286, v287); + int16x8_t v289 = vaddq_s16(v288, v284); + int16x8_t v290 = vaddq_s16(v285, v289); + int16x8_t v291 = vqrdmulhq_n_s16(v290, 17734); + int16x8_t v292 = vaddq_s16(v281, v291); + int16x8_t v293 = vaddq_s16(v283, v276); + int16x8_t v294_tmp = vqrdmulhq_n_s16(v293, 13573); + int16x8_t v294 = vaddq_s16(v294_tmp, v293); + int16x8_t v295 = vaddq_s16(v287, v278); + int16x8_t v296 = vaddq_s16(v279, v282); + int16x8_t v297 = vaddq_s16(v295, v296); + int16x8_t v298 = vaddq_s16(v294, v297); + int16x8_t v299 = vaddq_s16(v296, v293); + int16x8_t v300_tmp = vqrdmulhq_n_s16(v299, 13573); + int16x8_t v300 = vaddq_s16(v300_tmp, v299); + int16x8_t v301 = vld1q_s16(in + in_stride * 252 + i); + int16x8_t v302 = vaddq_s16(v301, v222); + int16x8_t v303 = vaddq_s16(v302, v266); + int16x8_t v304 = vaddq_s16(v303, v286); + int16x8_t v305 = vaddq_s16(v304, v295); + int16x8_t v306 = vaddq_s16(v305, v299); + int16x8_t v307 = vaddq_s16(v300, v306); + int16x8_t v308 = vqrdmulhq_n_s16(v307, 17734); + int16x8_t v309 = vaddq_s16(v298, v308); + int16x8_t v310 = vqrdmulhq_n_s16(v309, 16705); + int16x8_t v311 = vaddq_s16(v292, v310); + int16x8_t v312 = vqrdmulhq_n_s16(v311, 16463); + int16x8_t v313 = vaddq_s16(v275, v312); + int16x8_t v314 = vqrdmulhq_n_s16(v313, 16404); + int16x8_t v315 = vaddq_s16(v234, v314); + int16x8_t v316 = vqrdmulhq_n_s16(v315, 16389); + int16x8_t v317 = vaddq_s16(v141, v316); + int16x8_t v318 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v319_tmp = vqrdmulhq_n_s16(v318, 13573); + int16x8_t v319 = vaddq_s16(v319_tmp, v318); + int16x8_t v320 = vld1q_s16(in + in_stride * 130 + i); + int16x8_t v321 = vld1q_s16(in + in_stride * 126 + i); + int16x8_t v322 = vaddq_s16(v320, v321); + int16x8_t v323 = vaddq_s16(v319, v322); + int16x8_t v324 = vld1q_s16(in + in_stride * 66 + i); + int16x8_t v325 = vld1q_s16(in + in_stride * 62 + i); + int16x8_t v326 = vaddq_s16(v324, v325); + int16x8_t v327_tmp = vqrdmulhq_n_s16(v326, 13573); + int16x8_t v327 = vaddq_s16(v327_tmp, v326); + int16x8_t v328 = vld1q_s16(in + in_stride * 194 + i); + int16x8_t v329 = vld1q_s16(in + in_stride * 190 + i); + int16x8_t v330 = vaddq_s16(v328, v329); + int16x8_t v331 = vaddq_s16(v330, v326); + int16x8_t v332 = vaddq_s16(v327, v331); + int16x8_t v333 = vqrdmulhq_n_s16(v332, 17734); + int16x8_t v334 = vaddq_s16(v323, v333); + int16x8_t v335 = vld1q_s16(in + in_stride * 34 + i); + int16x8_t v336 = vld1q_s16(in + in_stride * 30 + i); + int16x8_t v337 = vaddq_s16(v335, v336); + int16x8_t v338_tmp = vqrdmulhq_n_s16(v337, 13573); + int16x8_t v338 = vaddq_s16(v338_tmp, v337); + int16x8_t v339 = vld1q_s16(in + in_stride * 162 + i); + int16x8_t v340 = vld1q_s16(in + in_stride * 158 + i); + int16x8_t v341 = vaddq_s16(v339, v340); + int16x8_t v342 = vld1q_s16(in + in_stride * 98 + i); + int16x8_t v343 = vld1q_s16(in + in_stride * 94 + i); + int16x8_t v344 = vaddq_s16(v342, v343); + int16x8_t v345 = vaddq_s16(v341, v344); + int16x8_t v346 = vaddq_s16(v338, v345); + int16x8_t v347 = vaddq_s16(v344, v337); + int16x8_t v348_tmp = vqrdmulhq_n_s16(v347, 13573); + int16x8_t v348 = vaddq_s16(v348_tmp, v347); + int16x8_t v349 = vld1q_s16(in + in_stride * 226 + i); + int16x8_t v350 = vld1q_s16(in + in_stride * 222 + i); + int16x8_t v351 = vaddq_s16(v349, v350); + int16x8_t v352 = vaddq_s16(v351, v341); + int16x8_t v353 = vaddq_s16(v352, v347); + int16x8_t v354 = vaddq_s16(v348, v353); + int16x8_t v355 = vqrdmulhq_n_s16(v354, 17734); + int16x8_t v356 = vaddq_s16(v346, v355); + int16x8_t v357 = vqrdmulhq_n_s16(v356, 16705); + int16x8_t v358 = vaddq_s16(v334, v357); + int16x8_t v359 = vld1q_s16(in + in_stride * 18 + i); + int16x8_t v360 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v361 = vaddq_s16(v359, v360); + int16x8_t v362_tmp = vqrdmulhq_n_s16(v361, 13573); + int16x8_t v362 = vaddq_s16(v362_tmp, v361); + int16x8_t v363 = vld1q_s16(in + in_stride * 146 + i); + int16x8_t v364 = vld1q_s16(in + in_stride * 142 + i); + int16x8_t v365 = vaddq_s16(v363, v364); + int16x8_t v366 = vld1q_s16(in + in_stride * 114 + i); + int16x8_t v367 = vld1q_s16(in + in_stride * 110 + i); + int16x8_t v368 = vaddq_s16(v366, v367); + int16x8_t v369 = vaddq_s16(v365, v368); + int16x8_t v370 = vaddq_s16(v362, v369); + int16x8_t v371 = vld1q_s16(in + in_stride * 82 + i); + int16x8_t v372 = vld1q_s16(in + in_stride * 78 + i); + int16x8_t v373 = vaddq_s16(v371, v372); + int16x8_t v374 = vld1q_s16(in + in_stride * 50 + i); + int16x8_t v375 = vld1q_s16(in + in_stride * 46 + i); + int16x8_t v376 = vaddq_s16(v374, v375); + int16x8_t v377 = vaddq_s16(v373, v376); + int16x8_t v378_tmp = vqrdmulhq_n_s16(v377, 13573); + int16x8_t v378 = vaddq_s16(v378_tmp, v377); + int16x8_t v379 = vld1q_s16(in + in_stride * 210 + i); + int16x8_t v380 = vld1q_s16(in + in_stride * 206 + i); + int16x8_t v381 = vaddq_s16(v379, v380); + int16x8_t v382 = vld1q_s16(in + in_stride * 178 + i); + int16x8_t v383 = vld1q_s16(in + in_stride * 174 + i); + int16x8_t v384 = vaddq_s16(v382, v383); + int16x8_t v385 = vaddq_s16(v381, v384); + int16x8_t v386 = vaddq_s16(v385, v377); + int16x8_t v387 = vaddq_s16(v378, v386); + int16x8_t v388 = vqrdmulhq_n_s16(v387, 17734); + int16x8_t v389 = vaddq_s16(v370, v388); + int16x8_t v390 = vaddq_s16(v376, v361); + int16x8_t v391_tmp = vqrdmulhq_n_s16(v390, 13573); + int16x8_t v391 = vaddq_s16(v391_tmp, v390); + int16x8_t v392 = vaddq_s16(v384, v365); + int16x8_t v393 = vaddq_s16(v368, v373); + int16x8_t v394 = vaddq_s16(v392, v393); + int16x8_t v395 = vaddq_s16(v391, v394); + int16x8_t v396 = vaddq_s16(v393, v390); + int16x8_t v397_tmp = vqrdmulhq_n_s16(v396, 13573); + int16x8_t v397 = vaddq_s16(v397_tmp, v396); + int16x8_t v398 = vld1q_s16(in + in_stride * 242 + i); + int16x8_t v399 = vld1q_s16(in + in_stride * 238 + i); + int16x8_t v400 = vaddq_s16(v398, v399); + int16x8_t v401 = vaddq_s16(v400, v381); + int16x8_t v402 = vaddq_s16(v401, v392); + int16x8_t v403 = vaddq_s16(v402, v396); + int16x8_t v404 = vaddq_s16(v397, v403); + int16x8_t v405 = vqrdmulhq_n_s16(v404, 17734); + int16x8_t v406 = vaddq_s16(v395, v405); + int16x8_t v407 = vqrdmulhq_n_s16(v406, 16705); + int16x8_t v408 = vaddq_s16(v389, v407); + int16x8_t v409 = vqrdmulhq_n_s16(v408, 16463); + int16x8_t v410 = vaddq_s16(v358, v409); + int16x8_t v411 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v412 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v413 = vaddq_s16(v411, v412); + int16x8_t v414_tmp = vqrdmulhq_n_s16(v413, 13573); + int16x8_t v414 = vaddq_s16(v414_tmp, v413); + int16x8_t v415 = vld1q_s16(in + in_stride * 138 + i); + int16x8_t v416 = vld1q_s16(in + in_stride * 134 + i); + int16x8_t v417 = vaddq_s16(v415, v416); + int16x8_t v418 = vld1q_s16(in + in_stride * 122 + i); + int16x8_t v419 = vld1q_s16(in + in_stride * 118 + i); + int16x8_t v420 = vaddq_s16(v418, v419); + int16x8_t v421 = vaddq_s16(v417, v420); + int16x8_t v422 = vaddq_s16(v414, v421); + int16x8_t v423 = vld1q_s16(in + in_stride * 74 + i); + int16x8_t v424 = vld1q_s16(in + in_stride * 70 + i); + int16x8_t v425 = vaddq_s16(v423, v424); + int16x8_t v426 = vld1q_s16(in + in_stride * 58 + i); + int16x8_t v427 = vld1q_s16(in + in_stride * 54 + i); + int16x8_t v428 = vaddq_s16(v426, v427); + int16x8_t v429 = vaddq_s16(v425, v428); + int16x8_t v430_tmp = vqrdmulhq_n_s16(v429, 13573); + int16x8_t v430 = vaddq_s16(v430_tmp, v429); + int16x8_t v431 = vld1q_s16(in + in_stride * 202 + i); + int16x8_t v432 = vld1q_s16(in + in_stride * 198 + i); + int16x8_t v433 = vaddq_s16(v431, v432); + int16x8_t v434 = vld1q_s16(in + in_stride * 186 + i); + int16x8_t v435 = vld1q_s16(in + in_stride * 182 + i); + int16x8_t v436 = vaddq_s16(v434, v435); + int16x8_t v437 = vaddq_s16(v433, v436); + int16x8_t v438 = vaddq_s16(v437, v429); + int16x8_t v439 = vaddq_s16(v430, v438); + int16x8_t v440 = vqrdmulhq_n_s16(v439, 17734); + int16x8_t v441 = vaddq_s16(v422, v440); + int16x8_t v442 = vld1q_s16(in + in_stride * 42 + i); + int16x8_t v443 = vld1q_s16(in + in_stride * 38 + i); + int16x8_t v444 = vaddq_s16(v442, v443); + int16x8_t v445 = vld1q_s16(in + in_stride * 26 + i); + int16x8_t v446 = vld1q_s16(in + in_stride * 22 + i); + int16x8_t v447 = vaddq_s16(v445, v446); + int16x8_t v448 = vaddq_s16(v444, v447); + int16x8_t v449_tmp = vqrdmulhq_n_s16(v448, 13573); + int16x8_t v449 = vaddq_s16(v449_tmp, v448); + int16x8_t v450 = vld1q_s16(in + in_stride * 170 + i); + int16x8_t v451 = vld1q_s16(in + in_stride * 166 + i); + int16x8_t v452 = vaddq_s16(v450, v451); + int16x8_t v453 = vld1q_s16(in + in_stride * 154 + i); + int16x8_t v454 = vld1q_s16(in + in_stride * 150 + i); + int16x8_t v455 = vaddq_s16(v453, v454); + int16x8_t v456 = vaddq_s16(v452, v455); + int16x8_t v457 = vld1q_s16(in + in_stride * 106 + i); + int16x8_t v458 = vld1q_s16(in + in_stride * 102 + i); + int16x8_t v459 = vaddq_s16(v457, v458); + int16x8_t v460 = vld1q_s16(in + in_stride * 90 + i); + int16x8_t v461 = vld1q_s16(in + in_stride * 86 + i); + int16x8_t v462 = vaddq_s16(v460, v461); + int16x8_t v463 = vaddq_s16(v459, v462); + int16x8_t v464 = vaddq_s16(v456, v463); + int16x8_t v465 = vaddq_s16(v449, v464); + int16x8_t v466 = vaddq_s16(v463, v448); + int16x8_t v467_tmp = vqrdmulhq_n_s16(v466, 13573); + int16x8_t v467 = vaddq_s16(v467_tmp, v466); + int16x8_t v468 = vld1q_s16(in + in_stride * 234 + i); + int16x8_t v469 = vld1q_s16(in + in_stride * 230 + i); + int16x8_t v470 = vaddq_s16(v468, v469); + int16x8_t v471 = vld1q_s16(in + in_stride * 218 + i); + int16x8_t v472 = vld1q_s16(in + in_stride * 214 + i); + int16x8_t v473 = vaddq_s16(v471, v472); + int16x8_t v474 = vaddq_s16(v470, v473); + int16x8_t v475 = vaddq_s16(v474, v456); + int16x8_t v476 = vaddq_s16(v475, v466); + int16x8_t v477 = vaddq_s16(v467, v476); + int16x8_t v478 = vqrdmulhq_n_s16(v477, 17734); + int16x8_t v479 = vaddq_s16(v465, v478); + int16x8_t v480 = vqrdmulhq_n_s16(v479, 16705); + int16x8_t v481 = vaddq_s16(v441, v480); + int16x8_t v482 = vaddq_s16(v447, v413); + int16x8_t v483_tmp = vqrdmulhq_n_s16(v482, 13573); + int16x8_t v483 = vaddq_s16(v483_tmp, v482); + int16x8_t v484 = vaddq_s16(v455, v417); + int16x8_t v485 = vaddq_s16(v420, v459); + int16x8_t v486 = vaddq_s16(v484, v485); + int16x8_t v487 = vaddq_s16(v483, v486); + int16x8_t v488 = vaddq_s16(v462, v425); + int16x8_t v489 = vaddq_s16(v428, v444); + int16x8_t v490 = vaddq_s16(v488, v489); + int16x8_t v491_tmp = vqrdmulhq_n_s16(v490, 13573); + int16x8_t v491 = vaddq_s16(v491_tmp, v490); + int16x8_t v492 = vaddq_s16(v473, v433); + int16x8_t v493 = vaddq_s16(v436, v452); + int16x8_t v494 = vaddq_s16(v492, v493); + int16x8_t v495 = vaddq_s16(v494, v490); + int16x8_t v496 = vaddq_s16(v491, v495); + int16x8_t v497 = vqrdmulhq_n_s16(v496, 17734); + int16x8_t v498 = vaddq_s16(v487, v497); + int16x8_t v499 = vaddq_s16(v489, v482); + int16x8_t v500_tmp = vqrdmulhq_n_s16(v499, 13573); + int16x8_t v500 = vaddq_s16(v500_tmp, v499); + int16x8_t v501 = vaddq_s16(v493, v484); + int16x8_t v502 = vaddq_s16(v485, v488); + int16x8_t v503 = vaddq_s16(v501, v502); + int16x8_t v504 = vaddq_s16(v500, v503); + int16x8_t v505 = vaddq_s16(v502, v499); + int16x8_t v506_tmp = vqrdmulhq_n_s16(v505, 13573); + int16x8_t v506 = vaddq_s16(v506_tmp, v505); + int16x8_t v507 = vld1q_s16(in + in_stride * 250 + i); + int16x8_t v508 = vld1q_s16(in + in_stride * 246 + i); + int16x8_t v509 = vaddq_s16(v507, v508); + int16x8_t v510 = vaddq_s16(v509, v470); + int16x8_t v511 = vaddq_s16(v510, v492); + int16x8_t v512 = vaddq_s16(v511, v501); + int16x8_t v513 = vaddq_s16(v512, v505); + int16x8_t v514 = vaddq_s16(v506, v513); + int16x8_t v515 = vqrdmulhq_n_s16(v514, 17734); + int16x8_t v516 = vaddq_s16(v504, v515); + int16x8_t v517 = vqrdmulhq_n_s16(v516, 16705); + int16x8_t v518 = vaddq_s16(v498, v517); + int16x8_t v519 = vqrdmulhq_n_s16(v518, 16463); + int16x8_t v520 = vaddq_s16(v481, v519); + int16x8_t v521 = vqrdmulhq_n_s16(v520, 16404); + int16x8_t v522 = vaddq_s16(v410, v521); + int16x8_t v523 = vaddq_s16(v412, v318); + int16x8_t v524_tmp = vqrdmulhq_n_s16(v523, 13573); + int16x8_t v524 = vaddq_s16(v524_tmp, v523); + int16x8_t v525 = vaddq_s16(v416, v320); + int16x8_t v526 = vaddq_s16(v321, v418); + int16x8_t v527 = vaddq_s16(v525, v526); + int16x8_t v528 = vaddq_s16(v524, v527); + int16x8_t v529 = vaddq_s16(v424, v324); + int16x8_t v530 = vaddq_s16(v325, v426); + int16x8_t v531 = vaddq_s16(v529, v530); + int16x8_t v532_tmp = vqrdmulhq_n_s16(v531, 13573); + int16x8_t v532 = vaddq_s16(v532_tmp, v531); + int16x8_t v533 = vaddq_s16(v432, v328); + int16x8_t v534 = vaddq_s16(v329, v434); + int16x8_t v535 = vaddq_s16(v533, v534); + int16x8_t v536 = vaddq_s16(v535, v531); + int16x8_t v537 = vaddq_s16(v532, v536); + int16x8_t v538 = vqrdmulhq_n_s16(v537, 17734); + int16x8_t v539 = vaddq_s16(v528, v538); + int16x8_t v540 = vaddq_s16(v443, v335); + int16x8_t v541 = vaddq_s16(v336, v445); + int16x8_t v542 = vaddq_s16(v540, v541); + int16x8_t v543_tmp = vqrdmulhq_n_s16(v542, 13573); + int16x8_t v543 = vaddq_s16(v543_tmp, v542); + int16x8_t v544 = vaddq_s16(v451, v339); + int16x8_t v545 = vaddq_s16(v340, v453); + int16x8_t v546 = vaddq_s16(v544, v545); + int16x8_t v547 = vaddq_s16(v458, v342); + int16x8_t v548 = vaddq_s16(v343, v460); + int16x8_t v549 = vaddq_s16(v547, v548); + int16x8_t v550 = vaddq_s16(v546, v549); + int16x8_t v551 = vaddq_s16(v543, v550); + int16x8_t v552 = vaddq_s16(v549, v542); + int16x8_t v553_tmp = vqrdmulhq_n_s16(v552, 13573); + int16x8_t v553 = vaddq_s16(v553_tmp, v552); + int16x8_t v554 = vaddq_s16(v469, v349); + int16x8_t v555 = vaddq_s16(v350, v471); + int16x8_t v556 = vaddq_s16(v554, v555); + int16x8_t v557 = vaddq_s16(v556, v546); + int16x8_t v558 = vaddq_s16(v557, v552); + int16x8_t v559 = vaddq_s16(v553, v558); + int16x8_t v560 = vqrdmulhq_n_s16(v559, 17734); + int16x8_t v561 = vaddq_s16(v551, v560); + int16x8_t v562 = vqrdmulhq_n_s16(v561, 16705); + int16x8_t v563 = vaddq_s16(v539, v562); + int16x8_t v564 = vaddq_s16(v446, v359); + int16x8_t v565 = vaddq_s16(v360, v411); + int16x8_t v566 = vaddq_s16(v564, v565); + int16x8_t v567_tmp = vqrdmulhq_n_s16(v566, 13573); + int16x8_t v567 = vaddq_s16(v567_tmp, v566); + int16x8_t v568 = vaddq_s16(v454, v363); + int16x8_t v569 = vaddq_s16(v364, v415); + int16x8_t v570 = vaddq_s16(v568, v569); + int16x8_t v571 = vaddq_s16(v419, v366); + int16x8_t v572 = vaddq_s16(v367, v457); + int16x8_t v573 = vaddq_s16(v571, v572); + int16x8_t v574 = vaddq_s16(v570, v573); + int16x8_t v575 = vaddq_s16(v567, v574); + int16x8_t v576 = vaddq_s16(v461, v371); + int16x8_t v577 = vaddq_s16(v372, v423); + int16x8_t v578 = vaddq_s16(v576, v577); + int16x8_t v579 = vaddq_s16(v427, v374); + int16x8_t v580 = vaddq_s16(v375, v442); + int16x8_t v581 = vaddq_s16(v579, v580); + int16x8_t v582 = vaddq_s16(v578, v581); + int16x8_t v583_tmp = vqrdmulhq_n_s16(v582, 13573); + int16x8_t v583 = vaddq_s16(v583_tmp, v582); + int16x8_t v584 = vaddq_s16(v472, v379); + int16x8_t v585 = vaddq_s16(v380, v431); + int16x8_t v586 = vaddq_s16(v584, v585); + int16x8_t v587 = vaddq_s16(v435, v382); + int16x8_t v588 = vaddq_s16(v383, v450); + int16x8_t v589 = vaddq_s16(v587, v588); + int16x8_t v590 = vaddq_s16(v586, v589); + int16x8_t v591 = vaddq_s16(v590, v582); + int16x8_t v592 = vaddq_s16(v583, v591); + int16x8_t v593 = vqrdmulhq_n_s16(v592, 17734); + int16x8_t v594 = vaddq_s16(v575, v593); + int16x8_t v595 = vaddq_s16(v581, v566); + int16x8_t v596_tmp = vqrdmulhq_n_s16(v595, 13573); + int16x8_t v596 = vaddq_s16(v596_tmp, v595); + int16x8_t v597 = vaddq_s16(v589, v570); + int16x8_t v598 = vaddq_s16(v573, v578); + int16x8_t v599 = vaddq_s16(v597, v598); + int16x8_t v600 = vaddq_s16(v596, v599); + int16x8_t v601 = vaddq_s16(v598, v595); + int16x8_t v602_tmp = vqrdmulhq_n_s16(v601, 13573); + int16x8_t v602 = vaddq_s16(v602_tmp, v601); + int16x8_t v603 = vaddq_s16(v508, v398); + int16x8_t v604 = vaddq_s16(v399, v468); + int16x8_t v605 = vaddq_s16(v603, v604); + int16x8_t v606 = vaddq_s16(v605, v586); + int16x8_t v607 = vaddq_s16(v606, v597); + int16x8_t v608 = vaddq_s16(v607, v601); + int16x8_t v609 = vaddq_s16(v602, v608); + int16x8_t v610 = vqrdmulhq_n_s16(v609, 17734); + int16x8_t v611 = vaddq_s16(v600, v610); + int16x8_t v612 = vqrdmulhq_n_s16(v611, 16705); + int16x8_t v613 = vaddq_s16(v594, v612); + int16x8_t v614 = vqrdmulhq_n_s16(v613, 16463); + int16x8_t v615 = vaddq_s16(v563, v614); + int16x8_t v616 = vaddq_s16(v565, v523); + int16x8_t v617_tmp = vqrdmulhq_n_s16(v616, 13573); + int16x8_t v617 = vaddq_s16(v617_tmp, v616); + int16x8_t v618 = vaddq_s16(v569, v525); + int16x8_t v619 = vaddq_s16(v526, v571); + int16x8_t v620 = vaddq_s16(v618, v619); + int16x8_t v621 = vaddq_s16(v617, v620); + int16x8_t v622 = vaddq_s16(v577, v529); + int16x8_t v623 = vaddq_s16(v530, v579); + int16x8_t v624 = vaddq_s16(v622, v623); + int16x8_t v625_tmp = vqrdmulhq_n_s16(v624, 13573); + int16x8_t v625 = vaddq_s16(v625_tmp, v624); + int16x8_t v626 = vaddq_s16(v585, v533); + int16x8_t v627 = vaddq_s16(v534, v587); + int16x8_t v628 = vaddq_s16(v626, v627); + int16x8_t v629 = vaddq_s16(v628, v624); + int16x8_t v630 = vaddq_s16(v625, v629); + int16x8_t v631 = vqrdmulhq_n_s16(v630, 17734); + int16x8_t v632 = vaddq_s16(v621, v631); + int16x8_t v633 = vaddq_s16(v580, v540); + int16x8_t v634 = vaddq_s16(v541, v564); + int16x8_t v635 = vaddq_s16(v633, v634); + int16x8_t v636_tmp = vqrdmulhq_n_s16(v635, 13573); + int16x8_t v636 = vaddq_s16(v636_tmp, v635); + int16x8_t v637 = vaddq_s16(v588, v544); + int16x8_t v638 = vaddq_s16(v545, v568); + int16x8_t v639 = vaddq_s16(v637, v638); + int16x8_t v640 = vaddq_s16(v572, v547); + int16x8_t v641 = vaddq_s16(v548, v576); + int16x8_t v642 = vaddq_s16(v640, v641); + int16x8_t v643 = vaddq_s16(v639, v642); + int16x8_t v644 = vaddq_s16(v636, v643); + int16x8_t v645 = vaddq_s16(v642, v635); + int16x8_t v646_tmp = vqrdmulhq_n_s16(v645, 13573); + int16x8_t v646 = vaddq_s16(v646_tmp, v645); + int16x8_t v647 = vaddq_s16(v604, v554); + int16x8_t v648 = vaddq_s16(v555, v584); + int16x8_t v649 = vaddq_s16(v647, v648); + int16x8_t v650 = vaddq_s16(v649, v639); + int16x8_t v651 = vaddq_s16(v650, v645); + int16x8_t v652 = vaddq_s16(v646, v651); + int16x8_t v653 = vqrdmulhq_n_s16(v652, 17734); + int16x8_t v654 = vaddq_s16(v644, v653); + int16x8_t v655 = vqrdmulhq_n_s16(v654, 16705); + int16x8_t v656 = vaddq_s16(v632, v655); + int16x8_t v657 = vaddq_s16(v634, v616); + int16x8_t v658_tmp = vqrdmulhq_n_s16(v657, 13573); + int16x8_t v658 = vaddq_s16(v658_tmp, v657); + int16x8_t v659 = vaddq_s16(v638, v618); + int16x8_t v660 = vaddq_s16(v619, v640); + int16x8_t v661 = vaddq_s16(v659, v660); + int16x8_t v662 = vaddq_s16(v658, v661); + int16x8_t v663 = vaddq_s16(v641, v622); + int16x8_t v664 = vaddq_s16(v623, v633); + int16x8_t v665 = vaddq_s16(v663, v664); + int16x8_t v666_tmp = vqrdmulhq_n_s16(v665, 13573); + int16x8_t v666 = vaddq_s16(v666_tmp, v665); + int16x8_t v667 = vaddq_s16(v648, v626); + int16x8_t v668 = vaddq_s16(v627, v637); + int16x8_t v669 = vaddq_s16(v667, v668); + int16x8_t v670 = vaddq_s16(v669, v665); + int16x8_t v671 = vaddq_s16(v666, v670); + int16x8_t v672 = vqrdmulhq_n_s16(v671, 17734); + int16x8_t v673 = vaddq_s16(v662, v672); + int16x8_t v674 = vaddq_s16(v664, v657); + int16x8_t v675_tmp = vqrdmulhq_n_s16(v674, 13573); + int16x8_t v675 = vaddq_s16(v675_tmp, v674); + int16x8_t v676 = vaddq_s16(v668, v659); + int16x8_t v677 = vaddq_s16(v660, v663); + int16x8_t v678 = vaddq_s16(v676, v677); + int16x8_t v679 = vaddq_s16(v675, v678); + int16x8_t v680 = vaddq_s16(v677, v674); + int16x8_t v681_tmp = vqrdmulhq_n_s16(v680, 13573); + int16x8_t v681 = vaddq_s16(v681_tmp, v680); + int16x8_t v682 = vld1q_s16(in + in_stride * 254 + i); + int16x8_t v683 = vaddq_s16(v682, v507); + int16x8_t v684 = vaddq_s16(v683, v603); + int16x8_t v685 = vaddq_s16(v684, v647); + int16x8_t v686 = vaddq_s16(v685, v667); + int16x8_t v687 = vaddq_s16(v686, v676); + int16x8_t v688 = vaddq_s16(v687, v680); + int16x8_t v689 = vaddq_s16(v681, v688); + int16x8_t v690 = vqrdmulhq_n_s16(v689, 17734); + int16x8_t v691 = vaddq_s16(v679, v690); + int16x8_t v692 = vqrdmulhq_n_s16(v691, 16705); + int16x8_t v693 = vaddq_s16(v673, v692); + int16x8_t v694 = vqrdmulhq_n_s16(v693, 16463); + int16x8_t v695 = vaddq_s16(v656, v694); + int16x8_t v696 = vqrdmulhq_n_s16(v695, 16404); + int16x8_t v697 = vaddq_s16(v615, v696); + int16x8_t v698 = vqrdmulhq_n_s16(v697, 16389); + int16x8_t v699 = vaddq_s16(v522, v698); + int16x8_t v700 = vqrdmulhq_n_s16(v699, 16385); + int16x8_t v701 = vaddq_s16(v317, v700); + int16x8_t v702 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v703_tmp = vqrdmulhq_n_s16(v702, 13573); + int16x8_t v703 = vaddq_s16(v703_tmp, v702); + int16x8_t v704 = vld1q_s16(in + in_stride * 129 + i); + int16x8_t v705 = vld1q_s16(in + in_stride * 127 + i); + int16x8_t v706 = vaddq_s16(v704, v705); + int16x8_t v707 = vaddq_s16(v703, v706); + int16x8_t v708 = vld1q_s16(in + in_stride * 65 + i); + int16x8_t v709 = vld1q_s16(in + in_stride * 63 + i); + int16x8_t v710 = vaddq_s16(v708, v709); + int16x8_t v711_tmp = vqrdmulhq_n_s16(v710, 13573); + int16x8_t v711 = vaddq_s16(v711_tmp, v710); + int16x8_t v712 = vld1q_s16(in + in_stride * 193 + i); + int16x8_t v713 = vld1q_s16(in + in_stride * 191 + i); + int16x8_t v714 = vaddq_s16(v712, v713); + int16x8_t v715 = vaddq_s16(v714, v710); + int16x8_t v716 = vaddq_s16(v711, v715); + int16x8_t v717 = vqrdmulhq_n_s16(v716, 17734); + int16x8_t v718 = vaddq_s16(v707, v717); + int16x8_t v719 = vld1q_s16(in + in_stride * 33 + i); + int16x8_t v720 = vld1q_s16(in + in_stride * 31 + i); + int16x8_t v721 = vaddq_s16(v719, v720); + int16x8_t v722_tmp = vqrdmulhq_n_s16(v721, 13573); + int16x8_t v722 = vaddq_s16(v722_tmp, v721); + int16x8_t v723 = vld1q_s16(in + in_stride * 161 + i); + int16x8_t v724 = vld1q_s16(in + in_stride * 159 + i); + int16x8_t v725 = vaddq_s16(v723, v724); + int16x8_t v726 = vld1q_s16(in + in_stride * 97 + i); + int16x8_t v727 = vld1q_s16(in + in_stride * 95 + i); + int16x8_t v728 = vaddq_s16(v726, v727); + int16x8_t v729 = vaddq_s16(v725, v728); + int16x8_t v730 = vaddq_s16(v722, v729); + int16x8_t v731 = vaddq_s16(v728, v721); + int16x8_t v732_tmp = vqrdmulhq_n_s16(v731, 13573); + int16x8_t v732 = vaddq_s16(v732_tmp, v731); + int16x8_t v733 = vld1q_s16(in + in_stride * 225 + i); + int16x8_t v734 = vld1q_s16(in + in_stride * 223 + i); + int16x8_t v735 = vaddq_s16(v733, v734); + int16x8_t v736 = vaddq_s16(v735, v725); + int16x8_t v737 = vaddq_s16(v736, v731); + int16x8_t v738 = vaddq_s16(v732, v737); + int16x8_t v739 = vqrdmulhq_n_s16(v738, 17734); + int16x8_t v740 = vaddq_s16(v730, v739); + int16x8_t v741 = vqrdmulhq_n_s16(v740, 16705); + int16x8_t v742 = vaddq_s16(v718, v741); + int16x8_t v743 = vld1q_s16(in + in_stride * 17 + i); + int16x8_t v744 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v745 = vaddq_s16(v743, v744); + int16x8_t v746_tmp = vqrdmulhq_n_s16(v745, 13573); + int16x8_t v746 = vaddq_s16(v746_tmp, v745); + int16x8_t v747 = vld1q_s16(in + in_stride * 145 + i); + int16x8_t v748 = vld1q_s16(in + in_stride * 143 + i); + int16x8_t v749 = vaddq_s16(v747, v748); + int16x8_t v750 = vld1q_s16(in + in_stride * 113 + i); + int16x8_t v751 = vld1q_s16(in + in_stride * 111 + i); + int16x8_t v752 = vaddq_s16(v750, v751); + int16x8_t v753 = vaddq_s16(v749, v752); + int16x8_t v754 = vaddq_s16(v746, v753); + int16x8_t v755 = vld1q_s16(in + in_stride * 81 + i); + int16x8_t v756 = vld1q_s16(in + in_stride * 79 + i); + int16x8_t v757 = vaddq_s16(v755, v756); + int16x8_t v758 = vld1q_s16(in + in_stride * 49 + i); + int16x8_t v759 = vld1q_s16(in + in_stride * 47 + i); + int16x8_t v760 = vaddq_s16(v758, v759); + int16x8_t v761 = vaddq_s16(v757, v760); + int16x8_t v762_tmp = vqrdmulhq_n_s16(v761, 13573); + int16x8_t v762 = vaddq_s16(v762_tmp, v761); + int16x8_t v763 = vld1q_s16(in + in_stride * 209 + i); + int16x8_t v764 = vld1q_s16(in + in_stride * 207 + i); + int16x8_t v765 = vaddq_s16(v763, v764); + int16x8_t v766 = vld1q_s16(in + in_stride * 177 + i); + int16x8_t v767 = vld1q_s16(in + in_stride * 175 + i); + int16x8_t v768 = vaddq_s16(v766, v767); + int16x8_t v769 = vaddq_s16(v765, v768); + int16x8_t v770 = vaddq_s16(v769, v761); + int16x8_t v771 = vaddq_s16(v762, v770); + int16x8_t v772 = vqrdmulhq_n_s16(v771, 17734); + int16x8_t v773 = vaddq_s16(v754, v772); + int16x8_t v774 = vaddq_s16(v760, v745); + int16x8_t v775_tmp = vqrdmulhq_n_s16(v774, 13573); + int16x8_t v775 = vaddq_s16(v775_tmp, v774); + int16x8_t v776 = vaddq_s16(v768, v749); + int16x8_t v777 = vaddq_s16(v752, v757); + int16x8_t v778 = vaddq_s16(v776, v777); + int16x8_t v779 = vaddq_s16(v775, v778); + int16x8_t v780 = vaddq_s16(v777, v774); + int16x8_t v781_tmp = vqrdmulhq_n_s16(v780, 13573); + int16x8_t v781 = vaddq_s16(v781_tmp, v780); + int16x8_t v782 = vld1q_s16(in + in_stride * 241 + i); + int16x8_t v783 = vld1q_s16(in + in_stride * 239 + i); + int16x8_t v784 = vaddq_s16(v782, v783); + int16x8_t v785 = vaddq_s16(v784, v765); + int16x8_t v786 = vaddq_s16(v785, v776); + int16x8_t v787 = vaddq_s16(v786, v780); + int16x8_t v788 = vaddq_s16(v781, v787); + int16x8_t v789 = vqrdmulhq_n_s16(v788, 17734); + int16x8_t v790 = vaddq_s16(v779, v789); + int16x8_t v791 = vqrdmulhq_n_s16(v790, 16705); + int16x8_t v792 = vaddq_s16(v773, v791); + int16x8_t v793 = vqrdmulhq_n_s16(v792, 16463); + int16x8_t v794 = vaddq_s16(v742, v793); + int16x8_t v795 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v796 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v797 = vaddq_s16(v795, v796); + int16x8_t v798_tmp = vqrdmulhq_n_s16(v797, 13573); + int16x8_t v798 = vaddq_s16(v798_tmp, v797); + int16x8_t v799 = vld1q_s16(in + in_stride * 137 + i); + int16x8_t v800 = vld1q_s16(in + in_stride * 135 + i); + int16x8_t v801 = vaddq_s16(v799, v800); + int16x8_t v802 = vld1q_s16(in + in_stride * 121 + i); + int16x8_t v803 = vld1q_s16(in + in_stride * 119 + i); + int16x8_t v804 = vaddq_s16(v802, v803); + int16x8_t v805 = vaddq_s16(v801, v804); + int16x8_t v806 = vaddq_s16(v798, v805); + int16x8_t v807 = vld1q_s16(in + in_stride * 73 + i); + int16x8_t v808 = vld1q_s16(in + in_stride * 71 + i); + int16x8_t v809 = vaddq_s16(v807, v808); + int16x8_t v810 = vld1q_s16(in + in_stride * 57 + i); + int16x8_t v811 = vld1q_s16(in + in_stride * 55 + i); + int16x8_t v812 = vaddq_s16(v810, v811); + int16x8_t v813 = vaddq_s16(v809, v812); + int16x8_t v814_tmp = vqrdmulhq_n_s16(v813, 13573); + int16x8_t v814 = vaddq_s16(v814_tmp, v813); + int16x8_t v815 = vld1q_s16(in + in_stride * 201 + i); + int16x8_t v816 = vld1q_s16(in + in_stride * 199 + i); + int16x8_t v817 = vaddq_s16(v815, v816); + int16x8_t v818 = vld1q_s16(in + in_stride * 185 + i); + int16x8_t v819 = vld1q_s16(in + in_stride * 183 + i); + int16x8_t v820 = vaddq_s16(v818, v819); + int16x8_t v821 = vaddq_s16(v817, v820); + int16x8_t v822 = vaddq_s16(v821, v813); + int16x8_t v823 = vaddq_s16(v814, v822); + int16x8_t v824 = vqrdmulhq_n_s16(v823, 17734); + int16x8_t v825 = vaddq_s16(v806, v824); + int16x8_t v826 = vld1q_s16(in + in_stride * 41 + i); + int16x8_t v827 = vld1q_s16(in + in_stride * 39 + i); + int16x8_t v828 = vaddq_s16(v826, v827); + int16x8_t v829 = vld1q_s16(in + in_stride * 25 + i); + int16x8_t v830 = vld1q_s16(in + in_stride * 23 + i); + int16x8_t v831 = vaddq_s16(v829, v830); + int16x8_t v832 = vaddq_s16(v828, v831); + int16x8_t v833_tmp = vqrdmulhq_n_s16(v832, 13573); + int16x8_t v833 = vaddq_s16(v833_tmp, v832); + int16x8_t v834 = vld1q_s16(in + in_stride * 169 + i); + int16x8_t v835 = vld1q_s16(in + in_stride * 167 + i); + int16x8_t v836 = vaddq_s16(v834, v835); + int16x8_t v837 = vld1q_s16(in + in_stride * 153 + i); + int16x8_t v838 = vld1q_s16(in + in_stride * 151 + i); + int16x8_t v839 = vaddq_s16(v837, v838); + int16x8_t v840 = vaddq_s16(v836, v839); + int16x8_t v841 = vld1q_s16(in + in_stride * 105 + i); + int16x8_t v842 = vld1q_s16(in + in_stride * 103 + i); + int16x8_t v843 = vaddq_s16(v841, v842); + int16x8_t v844 = vld1q_s16(in + in_stride * 89 + i); + int16x8_t v845 = vld1q_s16(in + in_stride * 87 + i); + int16x8_t v846 = vaddq_s16(v844, v845); + int16x8_t v847 = vaddq_s16(v843, v846); + int16x8_t v848 = vaddq_s16(v840, v847); + int16x8_t v849 = vaddq_s16(v833, v848); + int16x8_t v850 = vaddq_s16(v847, v832); + int16x8_t v851_tmp = vqrdmulhq_n_s16(v850, 13573); + int16x8_t v851 = vaddq_s16(v851_tmp, v850); + int16x8_t v852 = vld1q_s16(in + in_stride * 233 + i); + int16x8_t v853 = vld1q_s16(in + in_stride * 231 + i); + int16x8_t v854 = vaddq_s16(v852, v853); + int16x8_t v855 = vld1q_s16(in + in_stride * 217 + i); + int16x8_t v856 = vld1q_s16(in + in_stride * 215 + i); + int16x8_t v857 = vaddq_s16(v855, v856); + int16x8_t v858 = vaddq_s16(v854, v857); + int16x8_t v859 = vaddq_s16(v858, v840); + int16x8_t v860 = vaddq_s16(v859, v850); + int16x8_t v861 = vaddq_s16(v851, v860); + int16x8_t v862 = vqrdmulhq_n_s16(v861, 17734); + int16x8_t v863 = vaddq_s16(v849, v862); + int16x8_t v864 = vqrdmulhq_n_s16(v863, 16705); + int16x8_t v865 = vaddq_s16(v825, v864); + int16x8_t v866 = vaddq_s16(v831, v797); + int16x8_t v867_tmp = vqrdmulhq_n_s16(v866, 13573); + int16x8_t v867 = vaddq_s16(v867_tmp, v866); + int16x8_t v868 = vaddq_s16(v839, v801); + int16x8_t v869 = vaddq_s16(v804, v843); + int16x8_t v870 = vaddq_s16(v868, v869); + int16x8_t v871 = vaddq_s16(v867, v870); + int16x8_t v872 = vaddq_s16(v846, v809); + int16x8_t v873 = vaddq_s16(v812, v828); + int16x8_t v874 = vaddq_s16(v872, v873); + int16x8_t v875_tmp = vqrdmulhq_n_s16(v874, 13573); + int16x8_t v875 = vaddq_s16(v875_tmp, v874); + int16x8_t v876 = vaddq_s16(v857, v817); + int16x8_t v877 = vaddq_s16(v820, v836); + int16x8_t v878 = vaddq_s16(v876, v877); + int16x8_t v879 = vaddq_s16(v878, v874); + int16x8_t v880 = vaddq_s16(v875, v879); + int16x8_t v881 = vqrdmulhq_n_s16(v880, 17734); + int16x8_t v882 = vaddq_s16(v871, v881); + int16x8_t v883 = vaddq_s16(v873, v866); + int16x8_t v884_tmp = vqrdmulhq_n_s16(v883, 13573); + int16x8_t v884 = vaddq_s16(v884_tmp, v883); + int16x8_t v885 = vaddq_s16(v877, v868); + int16x8_t v886 = vaddq_s16(v869, v872); + int16x8_t v887 = vaddq_s16(v885, v886); + int16x8_t v888 = vaddq_s16(v884, v887); + int16x8_t v889 = vaddq_s16(v886, v883); + int16x8_t v890_tmp = vqrdmulhq_n_s16(v889, 13573); + int16x8_t v890 = vaddq_s16(v890_tmp, v889); + int16x8_t v891 = vld1q_s16(in + in_stride * 249 + i); + int16x8_t v892 = vld1q_s16(in + in_stride * 247 + i); + int16x8_t v893 = vaddq_s16(v891, v892); + int16x8_t v894 = vaddq_s16(v893, v854); + int16x8_t v895 = vaddq_s16(v894, v876); + int16x8_t v896 = vaddq_s16(v895, v885); + int16x8_t v897 = vaddq_s16(v896, v889); + int16x8_t v898 = vaddq_s16(v890, v897); + int16x8_t v899 = vqrdmulhq_n_s16(v898, 17734); + int16x8_t v900 = vaddq_s16(v888, v899); + int16x8_t v901 = vqrdmulhq_n_s16(v900, 16705); + int16x8_t v902 = vaddq_s16(v882, v901); + int16x8_t v903 = vqrdmulhq_n_s16(v902, 16463); + int16x8_t v904 = vaddq_s16(v865, v903); + int16x8_t v905 = vqrdmulhq_n_s16(v904, 16404); + int16x8_t v906 = vaddq_s16(v794, v905); + int16x8_t v907 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v908 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v909 = vaddq_s16(v907, v908); + int16x8_t v910_tmp = vqrdmulhq_n_s16(v909, 13573); + int16x8_t v910 = vaddq_s16(v910_tmp, v909); + int16x8_t v911 = vld1q_s16(in + in_stride * 133 + i); + int16x8_t v912 = vld1q_s16(in + in_stride * 131 + i); + int16x8_t v913 = vaddq_s16(v911, v912); + int16x8_t v914 = vld1q_s16(in + in_stride * 125 + i); + int16x8_t v915 = vld1q_s16(in + in_stride * 123 + i); + int16x8_t v916 = vaddq_s16(v914, v915); + int16x8_t v917 = vaddq_s16(v913, v916); + int16x8_t v918 = vaddq_s16(v910, v917); + int16x8_t v919 = vld1q_s16(in + in_stride * 69 + i); + int16x8_t v920 = vld1q_s16(in + in_stride * 67 + i); + int16x8_t v921 = vaddq_s16(v919, v920); + int16x8_t v922 = vld1q_s16(in + in_stride * 61 + i); + int16x8_t v923 = vld1q_s16(in + in_stride * 59 + i); + int16x8_t v924 = vaddq_s16(v922, v923); + int16x8_t v925 = vaddq_s16(v921, v924); + int16x8_t v926_tmp = vqrdmulhq_n_s16(v925, 13573); + int16x8_t v926 = vaddq_s16(v926_tmp, v925); + int16x8_t v927 = vld1q_s16(in + in_stride * 197 + i); + int16x8_t v928 = vld1q_s16(in + in_stride * 195 + i); + int16x8_t v929 = vaddq_s16(v927, v928); + int16x8_t v930 = vld1q_s16(in + in_stride * 189 + i); + int16x8_t v931 = vld1q_s16(in + in_stride * 187 + i); + int16x8_t v932 = vaddq_s16(v930, v931); + int16x8_t v933 = vaddq_s16(v929, v932); + int16x8_t v934 = vaddq_s16(v933, v925); + int16x8_t v935 = vaddq_s16(v926, v934); + int16x8_t v936 = vqrdmulhq_n_s16(v935, 17734); + int16x8_t v937 = vaddq_s16(v918, v936); + int16x8_t v938 = vld1q_s16(in + in_stride * 37 + i); + int16x8_t v939 = vld1q_s16(in + in_stride * 35 + i); + int16x8_t v940 = vaddq_s16(v938, v939); + int16x8_t v941 = vld1q_s16(in + in_stride * 29 + i); + int16x8_t v942 = vld1q_s16(in + in_stride * 27 + i); + int16x8_t v943 = vaddq_s16(v941, v942); + int16x8_t v944 = vaddq_s16(v940, v943); + int16x8_t v945_tmp = vqrdmulhq_n_s16(v944, 13573); + int16x8_t v945 = vaddq_s16(v945_tmp, v944); + int16x8_t v946 = vld1q_s16(in + in_stride * 165 + i); + int16x8_t v947 = vld1q_s16(in + in_stride * 163 + i); + int16x8_t v948 = vaddq_s16(v946, v947); + int16x8_t v949 = vld1q_s16(in + in_stride * 157 + i); + int16x8_t v950 = vld1q_s16(in + in_stride * 155 + i); + int16x8_t v951 = vaddq_s16(v949, v950); + int16x8_t v952 = vaddq_s16(v948, v951); + int16x8_t v953 = vld1q_s16(in + in_stride * 101 + i); + int16x8_t v954 = vld1q_s16(in + in_stride * 99 + i); + int16x8_t v955 = vaddq_s16(v953, v954); + int16x8_t v956 = vld1q_s16(in + in_stride * 93 + i); + int16x8_t v957 = vld1q_s16(in + in_stride * 91 + i); + int16x8_t v958 = vaddq_s16(v956, v957); + int16x8_t v959 = vaddq_s16(v955, v958); + int16x8_t v960 = vaddq_s16(v952, v959); + int16x8_t v961 = vaddq_s16(v945, v960); + int16x8_t v962 = vaddq_s16(v959, v944); + int16x8_t v963_tmp = vqrdmulhq_n_s16(v962, 13573); + int16x8_t v963 = vaddq_s16(v963_tmp, v962); + int16x8_t v964 = vld1q_s16(in + in_stride * 229 + i); + int16x8_t v965 = vld1q_s16(in + in_stride * 227 + i); + int16x8_t v966 = vaddq_s16(v964, v965); + int16x8_t v967 = vld1q_s16(in + in_stride * 221 + i); + int16x8_t v968 = vld1q_s16(in + in_stride * 219 + i); + int16x8_t v969 = vaddq_s16(v967, v968); + int16x8_t v970 = vaddq_s16(v966, v969); + int16x8_t v971 = vaddq_s16(v970, v952); + int16x8_t v972 = vaddq_s16(v971, v962); + int16x8_t v973 = vaddq_s16(v963, v972); + int16x8_t v974 = vqrdmulhq_n_s16(v973, 17734); + int16x8_t v975 = vaddq_s16(v961, v974); + int16x8_t v976 = vqrdmulhq_n_s16(v975, 16705); + int16x8_t v977 = vaddq_s16(v937, v976); + int16x8_t v978 = vld1q_s16(in + in_stride * 21 + i); + int16x8_t v979 = vld1q_s16(in + in_stride * 19 + i); + int16x8_t v980 = vaddq_s16(v978, v979); + int16x8_t v981 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v982 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v983 = vaddq_s16(v981, v982); + int16x8_t v984 = vaddq_s16(v980, v983); + int16x8_t v985_tmp = vqrdmulhq_n_s16(v984, 13573); + int16x8_t v985 = vaddq_s16(v985_tmp, v984); + int16x8_t v986 = vld1q_s16(in + in_stride * 149 + i); + int16x8_t v987 = vld1q_s16(in + in_stride * 147 + i); + int16x8_t v988 = vaddq_s16(v986, v987); + int16x8_t v989 = vld1q_s16(in + in_stride * 141 + i); + int16x8_t v990 = vld1q_s16(in + in_stride * 139 + i); + int16x8_t v991 = vaddq_s16(v989, v990); + int16x8_t v992 = vaddq_s16(v988, v991); + int16x8_t v993 = vld1q_s16(in + in_stride * 117 + i); + int16x8_t v994 = vld1q_s16(in + in_stride * 115 + i); + int16x8_t v995 = vaddq_s16(v993, v994); + int16x8_t v996 = vld1q_s16(in + in_stride * 109 + i); + int16x8_t v997 = vld1q_s16(in + in_stride * 107 + i); + int16x8_t v998 = vaddq_s16(v996, v997); + int16x8_t v999 = vaddq_s16(v995, v998); + int16x8_t v1000 = vaddq_s16(v992, v999); + int16x8_t v1001 = vaddq_s16(v985, v1000); + int16x8_t v1002 = vld1q_s16(in + in_stride * 85 + i); + int16x8_t v1003 = vld1q_s16(in + in_stride * 83 + i); + int16x8_t v1004 = vaddq_s16(v1002, v1003); + int16x8_t v1005 = vld1q_s16(in + in_stride * 77 + i); + int16x8_t v1006 = vld1q_s16(in + in_stride * 75 + i); + int16x8_t v1007 = vaddq_s16(v1005, v1006); + int16x8_t v1008 = vaddq_s16(v1004, v1007); + int16x8_t v1009 = vld1q_s16(in + in_stride * 53 + i); + int16x8_t v1010 = vld1q_s16(in + in_stride * 51 + i); + int16x8_t v1011 = vaddq_s16(v1009, v1010); + int16x8_t v1012 = vld1q_s16(in + in_stride * 45 + i); + int16x8_t v1013 = vld1q_s16(in + in_stride * 43 + i); + int16x8_t v1014 = vaddq_s16(v1012, v1013); + int16x8_t v1015 = vaddq_s16(v1011, v1014); + int16x8_t v1016 = vaddq_s16(v1008, v1015); + int16x8_t v1017_tmp = vqrdmulhq_n_s16(v1016, 13573); + int16x8_t v1017 = vaddq_s16(v1017_tmp, v1016); + int16x8_t v1018 = vld1q_s16(in + in_stride * 213 + i); + int16x8_t v1019 = vld1q_s16(in + in_stride * 211 + i); + int16x8_t v1020 = vaddq_s16(v1018, v1019); + int16x8_t v1021 = vld1q_s16(in + in_stride * 205 + i); + int16x8_t v1022 = vld1q_s16(in + in_stride * 203 + i); + int16x8_t v1023 = vaddq_s16(v1021, v1022); + int16x8_t v1024 = vaddq_s16(v1020, v1023); + int16x8_t v1025 = vld1q_s16(in + in_stride * 181 + i); + int16x8_t v1026 = vld1q_s16(in + in_stride * 179 + i); + int16x8_t v1027 = vaddq_s16(v1025, v1026); + int16x8_t v1028 = vld1q_s16(in + in_stride * 173 + i); + int16x8_t v1029 = vld1q_s16(in + in_stride * 171 + i); + int16x8_t v1030 = vaddq_s16(v1028, v1029); + int16x8_t v1031 = vaddq_s16(v1027, v1030); + int16x8_t v1032 = vaddq_s16(v1024, v1031); + int16x8_t v1033 = vaddq_s16(v1032, v1016); + int16x8_t v1034 = vaddq_s16(v1017, v1033); + int16x8_t v1035 = vqrdmulhq_n_s16(v1034, 17734); + int16x8_t v1036 = vaddq_s16(v1001, v1035); + int16x8_t v1037 = vaddq_s16(v1015, v984); + int16x8_t v1038_tmp = vqrdmulhq_n_s16(v1037, 13573); + int16x8_t v1038 = vaddq_s16(v1038_tmp, v1037); + int16x8_t v1039 = vaddq_s16(v1031, v992); + int16x8_t v1040 = vaddq_s16(v999, v1008); + int16x8_t v1041 = vaddq_s16(v1039, v1040); + int16x8_t v1042 = vaddq_s16(v1038, v1041); + int16x8_t v1043 = vaddq_s16(v1040, v1037); + int16x8_t v1044_tmp = vqrdmulhq_n_s16(v1043, 13573); + int16x8_t v1044 = vaddq_s16(v1044_tmp, v1043); + int16x8_t v1045 = vld1q_s16(in + in_stride * 245 + i); + int16x8_t v1046 = vld1q_s16(in + in_stride * 243 + i); + int16x8_t v1047 = vaddq_s16(v1045, v1046); + int16x8_t v1048 = vld1q_s16(in + in_stride * 237 + i); + int16x8_t v1049 = vld1q_s16(in + in_stride * 235 + i); + int16x8_t v1050 = vaddq_s16(v1048, v1049); + int16x8_t v1051 = vaddq_s16(v1047, v1050); + int16x8_t v1052 = vaddq_s16(v1051, v1024); + int16x8_t v1053 = vaddq_s16(v1052, v1039); + int16x8_t v1054 = vaddq_s16(v1053, v1043); + int16x8_t v1055 = vaddq_s16(v1044, v1054); + int16x8_t v1056 = vqrdmulhq_n_s16(v1055, 17734); + int16x8_t v1057 = vaddq_s16(v1042, v1056); + int16x8_t v1058 = vqrdmulhq_n_s16(v1057, 16705); + int16x8_t v1059 = vaddq_s16(v1036, v1058); + int16x8_t v1060 = vqrdmulhq_n_s16(v1059, 16463); + int16x8_t v1061 = vaddq_s16(v977, v1060); + int16x8_t v1062 = vaddq_s16(v983, v909); + int16x8_t v1063_tmp = vqrdmulhq_n_s16(v1062, 13573); + int16x8_t v1063 = vaddq_s16(v1063_tmp, v1062); + int16x8_t v1064 = vaddq_s16(v991, v913); + int16x8_t v1065 = vaddq_s16(v916, v995); + int16x8_t v1066 = vaddq_s16(v1064, v1065); + int16x8_t v1067 = vaddq_s16(v1063, v1066); + int16x8_t v1068 = vaddq_s16(v1007, v921); + int16x8_t v1069 = vaddq_s16(v924, v1011); + int16x8_t v1070 = vaddq_s16(v1068, v1069); + int16x8_t v1071_tmp = vqrdmulhq_n_s16(v1070, 13573); + int16x8_t v1071 = vaddq_s16(v1071_tmp, v1070); + int16x8_t v1072 = vaddq_s16(v1023, v929); + int16x8_t v1073 = vaddq_s16(v932, v1027); + int16x8_t v1074 = vaddq_s16(v1072, v1073); + int16x8_t v1075 = vaddq_s16(v1074, v1070); + int16x8_t v1076 = vaddq_s16(v1071, v1075); + int16x8_t v1077 = vqrdmulhq_n_s16(v1076, 17734); + int16x8_t v1078 = vaddq_s16(v1067, v1077); + int16x8_t v1079 = vaddq_s16(v1014, v940); + int16x8_t v1080 = vaddq_s16(v943, v980); + int16x8_t v1081 = vaddq_s16(v1079, v1080); + int16x8_t v1082_tmp = vqrdmulhq_n_s16(v1081, 13573); + int16x8_t v1082 = vaddq_s16(v1082_tmp, v1081); + int16x8_t v1083 = vaddq_s16(v1030, v948); + int16x8_t v1084 = vaddq_s16(v951, v988); + int16x8_t v1085 = vaddq_s16(v1083, v1084); + int16x8_t v1086 = vaddq_s16(v998, v955); + int16x8_t v1087 = vaddq_s16(v958, v1004); + int16x8_t v1088 = vaddq_s16(v1086, v1087); + int16x8_t v1089 = vaddq_s16(v1085, v1088); + int16x8_t v1090 = vaddq_s16(v1082, v1089); + int16x8_t v1091 = vaddq_s16(v1088, v1081); + int16x8_t v1092_tmp = vqrdmulhq_n_s16(v1091, 13573); + int16x8_t v1092 = vaddq_s16(v1092_tmp, v1091); + int16x8_t v1093 = vaddq_s16(v1050, v966); + int16x8_t v1094 = vaddq_s16(v969, v1020); + int16x8_t v1095 = vaddq_s16(v1093, v1094); + int16x8_t v1096 = vaddq_s16(v1095, v1085); + int16x8_t v1097 = vaddq_s16(v1096, v1091); + int16x8_t v1098 = vaddq_s16(v1092, v1097); + int16x8_t v1099 = vqrdmulhq_n_s16(v1098, 17734); + int16x8_t v1100 = vaddq_s16(v1090, v1099); + int16x8_t v1101 = vqrdmulhq_n_s16(v1100, 16705); + int16x8_t v1102 = vaddq_s16(v1078, v1101); + int16x8_t v1103 = vaddq_s16(v1080, v1062); + int16x8_t v1104_tmp = vqrdmulhq_n_s16(v1103, 13573); + int16x8_t v1104 = vaddq_s16(v1104_tmp, v1103); + int16x8_t v1105 = vaddq_s16(v1084, v1064); + int16x8_t v1106 = vaddq_s16(v1065, v1086); + int16x8_t v1107 = vaddq_s16(v1105, v1106); + int16x8_t v1108 = vaddq_s16(v1104, v1107); + int16x8_t v1109 = vaddq_s16(v1087, v1068); + int16x8_t v1110 = vaddq_s16(v1069, v1079); + int16x8_t v1111 = vaddq_s16(v1109, v1110); + int16x8_t v1112_tmp = vqrdmulhq_n_s16(v1111, 13573); + int16x8_t v1112 = vaddq_s16(v1112_tmp, v1111); + int16x8_t v1113 = vaddq_s16(v1094, v1072); + int16x8_t v1114 = vaddq_s16(v1073, v1083); + int16x8_t v1115 = vaddq_s16(v1113, v1114); + int16x8_t v1116 = vaddq_s16(v1115, v1111); + int16x8_t v1117 = vaddq_s16(v1112, v1116); + int16x8_t v1118 = vqrdmulhq_n_s16(v1117, 17734); + int16x8_t v1119 = vaddq_s16(v1108, v1118); + int16x8_t v1120 = vaddq_s16(v1110, v1103); + int16x8_t v1121_tmp = vqrdmulhq_n_s16(v1120, 13573); + int16x8_t v1121 = vaddq_s16(v1121_tmp, v1120); + int16x8_t v1122 = vaddq_s16(v1114, v1105); + int16x8_t v1123 = vaddq_s16(v1106, v1109); + int16x8_t v1124 = vaddq_s16(v1122, v1123); + int16x8_t v1125 = vaddq_s16(v1121, v1124); + int16x8_t v1126 = vaddq_s16(v1123, v1120); + int16x8_t v1127_tmp = vqrdmulhq_n_s16(v1126, 13573); + int16x8_t v1127 = vaddq_s16(v1127_tmp, v1126); + int16x8_t v1128 = vld1q_s16(in + in_stride * 253 + i); + int16x8_t v1129 = vld1q_s16(in + in_stride * 251 + i); + int16x8_t v1130 = vaddq_s16(v1128, v1129); + int16x8_t v1131 = vaddq_s16(v1130, v1047); + int16x8_t v1132 = vaddq_s16(v1131, v1093); + int16x8_t v1133 = vaddq_s16(v1132, v1113); + int16x8_t v1134 = vaddq_s16(v1133, v1122); + int16x8_t v1135 = vaddq_s16(v1134, v1126); + int16x8_t v1136 = vaddq_s16(v1127, v1135); + int16x8_t v1137 = vqrdmulhq_n_s16(v1136, 17734); + int16x8_t v1138 = vaddq_s16(v1125, v1137); + int16x8_t v1139 = vqrdmulhq_n_s16(v1138, 16705); + int16x8_t v1140 = vaddq_s16(v1119, v1139); + int16x8_t v1141 = vqrdmulhq_n_s16(v1140, 16463); + int16x8_t v1142 = vaddq_s16(v1102, v1141); + int16x8_t v1143 = vqrdmulhq_n_s16(v1142, 16404); + int16x8_t v1144 = vaddq_s16(v1061, v1143); + int16x8_t v1145 = vqrdmulhq_n_s16(v1144, 16389); + int16x8_t v1146 = vaddq_s16(v906, v1145); + int16x8_t v1147 = vaddq_s16(v908, v702); + int16x8_t v1148_tmp = vqrdmulhq_n_s16(v1147, 13573); + int16x8_t v1148 = vaddq_s16(v1148_tmp, v1147); + int16x8_t v1149 = vaddq_s16(v912, v704); + int16x8_t v1150 = vaddq_s16(v705, v914); + int16x8_t v1151 = vaddq_s16(v1149, v1150); + int16x8_t v1152 = vaddq_s16(v1148, v1151); + int16x8_t v1153 = vaddq_s16(v920, v708); + int16x8_t v1154 = vaddq_s16(v709, v922); + int16x8_t v1155 = vaddq_s16(v1153, v1154); + int16x8_t v1156_tmp = vqrdmulhq_n_s16(v1155, 13573); + int16x8_t v1156 = vaddq_s16(v1156_tmp, v1155); + int16x8_t v1157 = vaddq_s16(v928, v712); + int16x8_t v1158 = vaddq_s16(v713, v930); + int16x8_t v1159 = vaddq_s16(v1157, v1158); + int16x8_t v1160 = vaddq_s16(v1159, v1155); + int16x8_t v1161 = vaddq_s16(v1156, v1160); + int16x8_t v1162 = vqrdmulhq_n_s16(v1161, 17734); + int16x8_t v1163 = vaddq_s16(v1152, v1162); + int16x8_t v1164 = vaddq_s16(v939, v719); + int16x8_t v1165 = vaddq_s16(v720, v941); + int16x8_t v1166 = vaddq_s16(v1164, v1165); + int16x8_t v1167_tmp = vqrdmulhq_n_s16(v1166, 13573); + int16x8_t v1167 = vaddq_s16(v1167_tmp, v1166); + int16x8_t v1168 = vaddq_s16(v947, v723); + int16x8_t v1169 = vaddq_s16(v724, v949); + int16x8_t v1170 = vaddq_s16(v1168, v1169); + int16x8_t v1171 = vaddq_s16(v954, v726); + int16x8_t v1172 = vaddq_s16(v727, v956); + int16x8_t v1173 = vaddq_s16(v1171, v1172); + int16x8_t v1174 = vaddq_s16(v1170, v1173); + int16x8_t v1175 = vaddq_s16(v1167, v1174); + int16x8_t v1176 = vaddq_s16(v1173, v1166); + int16x8_t v1177_tmp = vqrdmulhq_n_s16(v1176, 13573); + int16x8_t v1177 = vaddq_s16(v1177_tmp, v1176); + int16x8_t v1178 = vaddq_s16(v965, v733); + int16x8_t v1179 = vaddq_s16(v734, v967); + int16x8_t v1180 = vaddq_s16(v1178, v1179); + int16x8_t v1181 = vaddq_s16(v1180, v1170); + int16x8_t v1182 = vaddq_s16(v1181, v1176); + int16x8_t v1183 = vaddq_s16(v1177, v1182); + int16x8_t v1184 = vqrdmulhq_n_s16(v1183, 17734); + int16x8_t v1185 = vaddq_s16(v1175, v1184); + int16x8_t v1186 = vqrdmulhq_n_s16(v1185, 16705); + int16x8_t v1187 = vaddq_s16(v1163, v1186); + int16x8_t v1188 = vaddq_s16(v979, v743); + int16x8_t v1189 = vaddq_s16(v744, v981); + int16x8_t v1190 = vaddq_s16(v1188, v1189); + int16x8_t v1191_tmp = vqrdmulhq_n_s16(v1190, 13573); + int16x8_t v1191 = vaddq_s16(v1191_tmp, v1190); + int16x8_t v1192 = vaddq_s16(v987, v747); + int16x8_t v1193 = vaddq_s16(v748, v989); + int16x8_t v1194 = vaddq_s16(v1192, v1193); + int16x8_t v1195 = vaddq_s16(v994, v750); + int16x8_t v1196 = vaddq_s16(v751, v996); + int16x8_t v1197 = vaddq_s16(v1195, v1196); + int16x8_t v1198 = vaddq_s16(v1194, v1197); + int16x8_t v1199 = vaddq_s16(v1191, v1198); + int16x8_t v1200 = vaddq_s16(v1003, v755); + int16x8_t v1201 = vaddq_s16(v756, v1005); + int16x8_t v1202 = vaddq_s16(v1200, v1201); + int16x8_t v1203 = vaddq_s16(v1010, v758); + int16x8_t v1204 = vaddq_s16(v759, v1012); + int16x8_t v1205 = vaddq_s16(v1203, v1204); + int16x8_t v1206 = vaddq_s16(v1202, v1205); + int16x8_t v1207_tmp = vqrdmulhq_n_s16(v1206, 13573); + int16x8_t v1207 = vaddq_s16(v1207_tmp, v1206); + int16x8_t v1208 = vaddq_s16(v1019, v763); + int16x8_t v1209 = vaddq_s16(v764, v1021); + int16x8_t v1210 = vaddq_s16(v1208, v1209); + int16x8_t v1211 = vaddq_s16(v1026, v766); + int16x8_t v1212 = vaddq_s16(v767, v1028); + int16x8_t v1213 = vaddq_s16(v1211, v1212); + int16x8_t v1214 = vaddq_s16(v1210, v1213); + int16x8_t v1215 = vaddq_s16(v1214, v1206); + int16x8_t v1216 = vaddq_s16(v1207, v1215); + int16x8_t v1217 = vqrdmulhq_n_s16(v1216, 17734); + int16x8_t v1218 = vaddq_s16(v1199, v1217); + int16x8_t v1219 = vaddq_s16(v1205, v1190); + int16x8_t v1220_tmp = vqrdmulhq_n_s16(v1219, 13573); + int16x8_t v1220 = vaddq_s16(v1220_tmp, v1219); + int16x8_t v1221 = vaddq_s16(v1213, v1194); + int16x8_t v1222 = vaddq_s16(v1197, v1202); + int16x8_t v1223 = vaddq_s16(v1221, v1222); + int16x8_t v1224 = vaddq_s16(v1220, v1223); + int16x8_t v1225 = vaddq_s16(v1222, v1219); + int16x8_t v1226_tmp = vqrdmulhq_n_s16(v1225, 13573); + int16x8_t v1226 = vaddq_s16(v1226_tmp, v1225); + int16x8_t v1227 = vaddq_s16(v1046, v782); + int16x8_t v1228 = vaddq_s16(v783, v1048); + int16x8_t v1229 = vaddq_s16(v1227, v1228); + int16x8_t v1230 = vaddq_s16(v1229, v1210); + int16x8_t v1231 = vaddq_s16(v1230, v1221); + int16x8_t v1232 = vaddq_s16(v1231, v1225); + int16x8_t v1233 = vaddq_s16(v1226, v1232); + int16x8_t v1234 = vqrdmulhq_n_s16(v1233, 17734); + int16x8_t v1235 = vaddq_s16(v1224, v1234); + int16x8_t v1236 = vqrdmulhq_n_s16(v1235, 16705); + int16x8_t v1237 = vaddq_s16(v1218, v1236); + int16x8_t v1238 = vqrdmulhq_n_s16(v1237, 16463); + int16x8_t v1239 = vaddq_s16(v1187, v1238); + int16x8_t v1240 = vaddq_s16(v982, v795); + int16x8_t v1241 = vaddq_s16(v796, v907); + int16x8_t v1242 = vaddq_s16(v1240, v1241); + int16x8_t v1243_tmp = vqrdmulhq_n_s16(v1242, 13573); + int16x8_t v1243 = vaddq_s16(v1243_tmp, v1242); + int16x8_t v1244 = vaddq_s16(v990, v799); + int16x8_t v1245 = vaddq_s16(v800, v911); + int16x8_t v1246 = vaddq_s16(v1244, v1245); + int16x8_t v1247 = vaddq_s16(v915, v802); + int16x8_t v1248 = vaddq_s16(v803, v993); + int16x8_t v1249 = vaddq_s16(v1247, v1248); + int16x8_t v1250 = vaddq_s16(v1246, v1249); + int16x8_t v1251 = vaddq_s16(v1243, v1250); + int16x8_t v1252 = vaddq_s16(v1006, v807); + int16x8_t v1253 = vaddq_s16(v808, v919); + int16x8_t v1254 = vaddq_s16(v1252, v1253); + int16x8_t v1255 = vaddq_s16(v923, v810); + int16x8_t v1256 = vaddq_s16(v811, v1009); + int16x8_t v1257 = vaddq_s16(v1255, v1256); + int16x8_t v1258 = vaddq_s16(v1254, v1257); + int16x8_t v1259_tmp = vqrdmulhq_n_s16(v1258, 13573); + int16x8_t v1259 = vaddq_s16(v1259_tmp, v1258); + int16x8_t v1260 = vaddq_s16(v1022, v815); + int16x8_t v1261 = vaddq_s16(v816, v927); + int16x8_t v1262 = vaddq_s16(v1260, v1261); + int16x8_t v1263 = vaddq_s16(v931, v818); + int16x8_t v1264 = vaddq_s16(v819, v1025); + int16x8_t v1265 = vaddq_s16(v1263, v1264); + int16x8_t v1266 = vaddq_s16(v1262, v1265); + int16x8_t v1267 = vaddq_s16(v1266, v1258); + int16x8_t v1268 = vaddq_s16(v1259, v1267); + int16x8_t v1269 = vqrdmulhq_n_s16(v1268, 17734); + int16x8_t v1270 = vaddq_s16(v1251, v1269); + int16x8_t v1271 = vaddq_s16(v1013, v826); + int16x8_t v1272 = vaddq_s16(v827, v938); + int16x8_t v1273 = vaddq_s16(v1271, v1272); + int16x8_t v1274 = vaddq_s16(v942, v829); + int16x8_t v1275 = vaddq_s16(v830, v978); + int16x8_t v1276 = vaddq_s16(v1274, v1275); + int16x8_t v1277 = vaddq_s16(v1273, v1276); + int16x8_t v1278_tmp = vqrdmulhq_n_s16(v1277, 13573); + int16x8_t v1278 = vaddq_s16(v1278_tmp, v1277); + int16x8_t v1279 = vaddq_s16(v1029, v834); + int16x8_t v1280 = vaddq_s16(v835, v946); + int16x8_t v1281 = vaddq_s16(v1279, v1280); + int16x8_t v1282 = vaddq_s16(v950, v837); + int16x8_t v1283 = vaddq_s16(v838, v986); + int16x8_t v1284 = vaddq_s16(v1282, v1283); + int16x8_t v1285 = vaddq_s16(v1281, v1284); + int16x8_t v1286 = vaddq_s16(v997, v841); + int16x8_t v1287 = vaddq_s16(v842, v953); + int16x8_t v1288 = vaddq_s16(v1286, v1287); + int16x8_t v1289 = vaddq_s16(v957, v844); + int16x8_t v1290 = vaddq_s16(v845, v1002); + int16x8_t v1291 = vaddq_s16(v1289, v1290); + int16x8_t v1292 = vaddq_s16(v1288, v1291); + int16x8_t v1293 = vaddq_s16(v1285, v1292); + int16x8_t v1294 = vaddq_s16(v1278, v1293); + int16x8_t v1295 = vaddq_s16(v1292, v1277); + int16x8_t v1296_tmp = vqrdmulhq_n_s16(v1295, 13573); + int16x8_t v1296 = vaddq_s16(v1296_tmp, v1295); + int16x8_t v1297 = vaddq_s16(v1049, v852); + int16x8_t v1298 = vaddq_s16(v853, v964); + int16x8_t v1299 = vaddq_s16(v1297, v1298); + int16x8_t v1300 = vaddq_s16(v968, v855); + int16x8_t v1301 = vaddq_s16(v856, v1018); + int16x8_t v1302 = vaddq_s16(v1300, v1301); + int16x8_t v1303 = vaddq_s16(v1299, v1302); + int16x8_t v1304 = vaddq_s16(v1303, v1285); + int16x8_t v1305 = vaddq_s16(v1304, v1295); + int16x8_t v1306 = vaddq_s16(v1296, v1305); + int16x8_t v1307 = vqrdmulhq_n_s16(v1306, 17734); + int16x8_t v1308 = vaddq_s16(v1294, v1307); + int16x8_t v1309 = vqrdmulhq_n_s16(v1308, 16705); + int16x8_t v1310 = vaddq_s16(v1270, v1309); + int16x8_t v1311 = vaddq_s16(v1276, v1242); + int16x8_t v1312_tmp = vqrdmulhq_n_s16(v1311, 13573); + int16x8_t v1312 = vaddq_s16(v1312_tmp, v1311); + int16x8_t v1313 = vaddq_s16(v1284, v1246); + int16x8_t v1314 = vaddq_s16(v1249, v1288); + int16x8_t v1315 = vaddq_s16(v1313, v1314); + int16x8_t v1316 = vaddq_s16(v1312, v1315); + int16x8_t v1317 = vaddq_s16(v1291, v1254); + int16x8_t v1318 = vaddq_s16(v1257, v1273); + int16x8_t v1319 = vaddq_s16(v1317, v1318); + int16x8_t v1320_tmp = vqrdmulhq_n_s16(v1319, 13573); + int16x8_t v1320 = vaddq_s16(v1320_tmp, v1319); + int16x8_t v1321 = vaddq_s16(v1302, v1262); + int16x8_t v1322 = vaddq_s16(v1265, v1281); + int16x8_t v1323 = vaddq_s16(v1321, v1322); + int16x8_t v1324 = vaddq_s16(v1323, v1319); + int16x8_t v1325 = vaddq_s16(v1320, v1324); + int16x8_t v1326 = vqrdmulhq_n_s16(v1325, 17734); + int16x8_t v1327 = vaddq_s16(v1316, v1326); + int16x8_t v1328 = vaddq_s16(v1318, v1311); + int16x8_t v1329_tmp = vqrdmulhq_n_s16(v1328, 13573); + int16x8_t v1329 = vaddq_s16(v1329_tmp, v1328); + int16x8_t v1330 = vaddq_s16(v1322, v1313); + int16x8_t v1331 = vaddq_s16(v1314, v1317); + int16x8_t v1332 = vaddq_s16(v1330, v1331); + int16x8_t v1333 = vaddq_s16(v1329, v1332); + int16x8_t v1334 = vaddq_s16(v1331, v1328); + int16x8_t v1335_tmp = vqrdmulhq_n_s16(v1334, 13573); + int16x8_t v1335 = vaddq_s16(v1335_tmp, v1334); + int16x8_t v1336 = vaddq_s16(v1129, v891); + int16x8_t v1337 = vaddq_s16(v892, v1045); + int16x8_t v1338 = vaddq_s16(v1336, v1337); + int16x8_t v1339 = vaddq_s16(v1338, v1299); + int16x8_t v1340 = vaddq_s16(v1339, v1321); + int16x8_t v1341 = vaddq_s16(v1340, v1330); + int16x8_t v1342 = vaddq_s16(v1341, v1334); + int16x8_t v1343 = vaddq_s16(v1335, v1342); + int16x8_t v1344 = vqrdmulhq_n_s16(v1343, 17734); + int16x8_t v1345 = vaddq_s16(v1333, v1344); + int16x8_t v1346 = vqrdmulhq_n_s16(v1345, 16705); + int16x8_t v1347 = vaddq_s16(v1327, v1346); + int16x8_t v1348 = vqrdmulhq_n_s16(v1347, 16463); + int16x8_t v1349 = vaddq_s16(v1310, v1348); + int16x8_t v1350 = vqrdmulhq_n_s16(v1349, 16404); + int16x8_t v1351 = vaddq_s16(v1239, v1350); + int16x8_t v1352 = vaddq_s16(v1241, v1147); + int16x8_t v1353_tmp = vqrdmulhq_n_s16(v1352, 13573); + int16x8_t v1353 = vaddq_s16(v1353_tmp, v1352); + int16x8_t v1354 = vaddq_s16(v1245, v1149); + int16x8_t v1355 = vaddq_s16(v1150, v1247); + int16x8_t v1356 = vaddq_s16(v1354, v1355); + int16x8_t v1357 = vaddq_s16(v1353, v1356); + int16x8_t v1358 = vaddq_s16(v1253, v1153); + int16x8_t v1359 = vaddq_s16(v1154, v1255); + int16x8_t v1360 = vaddq_s16(v1358, v1359); + int16x8_t v1361_tmp = vqrdmulhq_n_s16(v1360, 13573); + int16x8_t v1361 = vaddq_s16(v1361_tmp, v1360); + int16x8_t v1362 = vaddq_s16(v1261, v1157); + int16x8_t v1363 = vaddq_s16(v1158, v1263); + int16x8_t v1364 = vaddq_s16(v1362, v1363); + int16x8_t v1365 = vaddq_s16(v1364, v1360); + int16x8_t v1366 = vaddq_s16(v1361, v1365); + int16x8_t v1367 = vqrdmulhq_n_s16(v1366, 17734); + int16x8_t v1368 = vaddq_s16(v1357, v1367); + int16x8_t v1369 = vaddq_s16(v1272, v1164); + int16x8_t v1370 = vaddq_s16(v1165, v1274); + int16x8_t v1371 = vaddq_s16(v1369, v1370); + int16x8_t v1372_tmp = vqrdmulhq_n_s16(v1371, 13573); + int16x8_t v1372 = vaddq_s16(v1372_tmp, v1371); + int16x8_t v1373 = vaddq_s16(v1280, v1168); + int16x8_t v1374 = vaddq_s16(v1169, v1282); + int16x8_t v1375 = vaddq_s16(v1373, v1374); + int16x8_t v1376 = vaddq_s16(v1287, v1171); + int16x8_t v1377 = vaddq_s16(v1172, v1289); + int16x8_t v1378 = vaddq_s16(v1376, v1377); + int16x8_t v1379 = vaddq_s16(v1375, v1378); + int16x8_t v1380 = vaddq_s16(v1372, v1379); + int16x8_t v1381 = vaddq_s16(v1378, v1371); + int16x8_t v1382_tmp = vqrdmulhq_n_s16(v1381, 13573); + int16x8_t v1382 = vaddq_s16(v1382_tmp, v1381); + int16x8_t v1383 = vaddq_s16(v1298, v1178); + int16x8_t v1384 = vaddq_s16(v1179, v1300); + int16x8_t v1385 = vaddq_s16(v1383, v1384); + int16x8_t v1386 = vaddq_s16(v1385, v1375); + int16x8_t v1387 = vaddq_s16(v1386, v1381); + int16x8_t v1388 = vaddq_s16(v1382, v1387); + int16x8_t v1389 = vqrdmulhq_n_s16(v1388, 17734); + int16x8_t v1390 = vaddq_s16(v1380, v1389); + int16x8_t v1391 = vqrdmulhq_n_s16(v1390, 16705); + int16x8_t v1392 = vaddq_s16(v1368, v1391); + int16x8_t v1393 = vaddq_s16(v1275, v1188); + int16x8_t v1394 = vaddq_s16(v1189, v1240); + int16x8_t v1395 = vaddq_s16(v1393, v1394); + int16x8_t v1396_tmp = vqrdmulhq_n_s16(v1395, 13573); + int16x8_t v1396 = vaddq_s16(v1396_tmp, v1395); + int16x8_t v1397 = vaddq_s16(v1283, v1192); + int16x8_t v1398 = vaddq_s16(v1193, v1244); + int16x8_t v1399 = vaddq_s16(v1397, v1398); + int16x8_t v1400 = vaddq_s16(v1248, v1195); + int16x8_t v1401 = vaddq_s16(v1196, v1286); + int16x8_t v1402 = vaddq_s16(v1400, v1401); + int16x8_t v1403 = vaddq_s16(v1399, v1402); + int16x8_t v1404 = vaddq_s16(v1396, v1403); + int16x8_t v1405 = vaddq_s16(v1290, v1200); + int16x8_t v1406 = vaddq_s16(v1201, v1252); + int16x8_t v1407 = vaddq_s16(v1405, v1406); + int16x8_t v1408 = vaddq_s16(v1256, v1203); + int16x8_t v1409 = vaddq_s16(v1204, v1271); + int16x8_t v1410 = vaddq_s16(v1408, v1409); + int16x8_t v1411 = vaddq_s16(v1407, v1410); + int16x8_t v1412_tmp = vqrdmulhq_n_s16(v1411, 13573); + int16x8_t v1412 = vaddq_s16(v1412_tmp, v1411); + int16x8_t v1413 = vaddq_s16(v1301, v1208); + int16x8_t v1414 = vaddq_s16(v1209, v1260); + int16x8_t v1415 = vaddq_s16(v1413, v1414); + int16x8_t v1416 = vaddq_s16(v1264, v1211); + int16x8_t v1417 = vaddq_s16(v1212, v1279); + int16x8_t v1418 = vaddq_s16(v1416, v1417); + int16x8_t v1419 = vaddq_s16(v1415, v1418); + int16x8_t v1420 = vaddq_s16(v1419, v1411); + int16x8_t v1421 = vaddq_s16(v1412, v1420); + int16x8_t v1422 = vqrdmulhq_n_s16(v1421, 17734); + int16x8_t v1423 = vaddq_s16(v1404, v1422); + int16x8_t v1424 = vaddq_s16(v1410, v1395); + int16x8_t v1425_tmp = vqrdmulhq_n_s16(v1424, 13573); + int16x8_t v1425 = vaddq_s16(v1425_tmp, v1424); + int16x8_t v1426 = vaddq_s16(v1418, v1399); + int16x8_t v1427 = vaddq_s16(v1402, v1407); + int16x8_t v1428 = vaddq_s16(v1426, v1427); + int16x8_t v1429 = vaddq_s16(v1425, v1428); + int16x8_t v1430 = vaddq_s16(v1427, v1424); + int16x8_t v1431_tmp = vqrdmulhq_n_s16(v1430, 13573); + int16x8_t v1431 = vaddq_s16(v1431_tmp, v1430); + int16x8_t v1432 = vaddq_s16(v1337, v1227); + int16x8_t v1433 = vaddq_s16(v1228, v1297); + int16x8_t v1434 = vaddq_s16(v1432, v1433); + int16x8_t v1435 = vaddq_s16(v1434, v1415); + int16x8_t v1436 = vaddq_s16(v1435, v1426); + int16x8_t v1437 = vaddq_s16(v1436, v1430); + int16x8_t v1438 = vaddq_s16(v1431, v1437); + int16x8_t v1439 = vqrdmulhq_n_s16(v1438, 17734); + int16x8_t v1440 = vaddq_s16(v1429, v1439); + int16x8_t v1441 = vqrdmulhq_n_s16(v1440, 16705); + int16x8_t v1442 = vaddq_s16(v1423, v1441); + int16x8_t v1443 = vqrdmulhq_n_s16(v1442, 16463); + int16x8_t v1444 = vaddq_s16(v1392, v1443); + int16x8_t v1445 = vaddq_s16(v1394, v1352); + int16x8_t v1446_tmp = vqrdmulhq_n_s16(v1445, 13573); + int16x8_t v1446 = vaddq_s16(v1446_tmp, v1445); + int16x8_t v1447 = vaddq_s16(v1398, v1354); + int16x8_t v1448 = vaddq_s16(v1355, v1400); + int16x8_t v1449 = vaddq_s16(v1447, v1448); + int16x8_t v1450 = vaddq_s16(v1446, v1449); + int16x8_t v1451 = vaddq_s16(v1406, v1358); + int16x8_t v1452 = vaddq_s16(v1359, v1408); + int16x8_t v1453 = vaddq_s16(v1451, v1452); + int16x8_t v1454_tmp = vqrdmulhq_n_s16(v1453, 13573); + int16x8_t v1454 = vaddq_s16(v1454_tmp, v1453); + int16x8_t v1455 = vaddq_s16(v1414, v1362); + int16x8_t v1456 = vaddq_s16(v1363, v1416); + int16x8_t v1457 = vaddq_s16(v1455, v1456); + int16x8_t v1458 = vaddq_s16(v1457, v1453); + int16x8_t v1459 = vaddq_s16(v1454, v1458); + int16x8_t v1460 = vqrdmulhq_n_s16(v1459, 17734); + int16x8_t v1461 = vaddq_s16(v1450, v1460); + int16x8_t v1462 = vaddq_s16(v1409, v1369); + int16x8_t v1463 = vaddq_s16(v1370, v1393); + int16x8_t v1464 = vaddq_s16(v1462, v1463); + int16x8_t v1465_tmp = vqrdmulhq_n_s16(v1464, 13573); + int16x8_t v1465 = vaddq_s16(v1465_tmp, v1464); + int16x8_t v1466 = vaddq_s16(v1417, v1373); + int16x8_t v1467 = vaddq_s16(v1374, v1397); + int16x8_t v1468 = vaddq_s16(v1466, v1467); + int16x8_t v1469 = vaddq_s16(v1401, v1376); + int16x8_t v1470 = vaddq_s16(v1377, v1405); + int16x8_t v1471 = vaddq_s16(v1469, v1470); + int16x8_t v1472 = vaddq_s16(v1468, v1471); + int16x8_t v1473 = vaddq_s16(v1465, v1472); + int16x8_t v1474 = vaddq_s16(v1471, v1464); + int16x8_t v1475_tmp = vqrdmulhq_n_s16(v1474, 13573); + int16x8_t v1475 = vaddq_s16(v1475_tmp, v1474); + int16x8_t v1476 = vaddq_s16(v1433, v1383); + int16x8_t v1477 = vaddq_s16(v1384, v1413); + int16x8_t v1478 = vaddq_s16(v1476, v1477); + int16x8_t v1479 = vaddq_s16(v1478, v1468); + int16x8_t v1480 = vaddq_s16(v1479, v1474); + int16x8_t v1481 = vaddq_s16(v1475, v1480); + int16x8_t v1482 = vqrdmulhq_n_s16(v1481, 17734); + int16x8_t v1483 = vaddq_s16(v1473, v1482); + int16x8_t v1484 = vqrdmulhq_n_s16(v1483, 16705); + int16x8_t v1485 = vaddq_s16(v1461, v1484); + int16x8_t v1486 = vaddq_s16(v1463, v1445); + int16x8_t v1487_tmp = vqrdmulhq_n_s16(v1486, 13573); + int16x8_t v1487 = vaddq_s16(v1487_tmp, v1486); + int16x8_t v1488 = vaddq_s16(v1467, v1447); + int16x8_t v1489 = vaddq_s16(v1448, v1469); + int16x8_t v1490 = vaddq_s16(v1488, v1489); + int16x8_t v1491 = vaddq_s16(v1487, v1490); + int16x8_t v1492 = vaddq_s16(v1470, v1451); + int16x8_t v1493 = vaddq_s16(v1452, v1462); + int16x8_t v1494 = vaddq_s16(v1492, v1493); + int16x8_t v1495_tmp = vqrdmulhq_n_s16(v1494, 13573); + int16x8_t v1495 = vaddq_s16(v1495_tmp, v1494); + int16x8_t v1496 = vaddq_s16(v1477, v1455); + int16x8_t v1497 = vaddq_s16(v1456, v1466); + int16x8_t v1498 = vaddq_s16(v1496, v1497); + int16x8_t v1499 = vaddq_s16(v1498, v1494); + int16x8_t v1500 = vaddq_s16(v1495, v1499); + int16x8_t v1501 = vqrdmulhq_n_s16(v1500, 17734); + int16x8_t v1502 = vaddq_s16(v1491, v1501); + int16x8_t v1503 = vaddq_s16(v1493, v1486); + int16x8_t v1504_tmp = vqrdmulhq_n_s16(v1503, 13573); + int16x8_t v1504 = vaddq_s16(v1504_tmp, v1503); + int16x8_t v1505 = vaddq_s16(v1497, v1488); + int16x8_t v1506 = vaddq_s16(v1489, v1492); + int16x8_t v1507 = vaddq_s16(v1505, v1506); + int16x8_t v1508 = vaddq_s16(v1504, v1507); + int16x8_t v1509 = vaddq_s16(v1506, v1503); + int16x8_t v1510_tmp = vqrdmulhq_n_s16(v1509, 13573); + int16x8_t v1510 = vaddq_s16(v1510_tmp, v1509); + int16x8_t v1511 = vld1q_s16(in + in_stride * 255 + i); + int16x8_t v1512 = vaddq_s16(v1511, v1128); + int16x8_t v1513 = vaddq_s16(v1512, v1336); + int16x8_t v1514 = vaddq_s16(v1513, v1432); + int16x8_t v1515 = vaddq_s16(v1514, v1476); + int16x8_t v1516 = vaddq_s16(v1515, v1496); + int16x8_t v1517 = vaddq_s16(v1516, v1505); + int16x8_t v1518 = vaddq_s16(v1517, v1509); + int16x8_t v1519 = vaddq_s16(v1510, v1518); + int16x8_t v1520 = vqrdmulhq_n_s16(v1519, 17734); + int16x8_t v1521 = vaddq_s16(v1508, v1520); + int16x8_t v1522 = vqrdmulhq_n_s16(v1521, 16705); + int16x8_t v1523 = vaddq_s16(v1502, v1522); + int16x8_t v1524 = vqrdmulhq_n_s16(v1523, 16463); + int16x8_t v1525 = vaddq_s16(v1485, v1524); + int16x8_t v1526 = vqrdmulhq_n_s16(v1525, 16404); + int16x8_t v1527 = vaddq_s16(v1444, v1526); + int16x8_t v1528 = vqrdmulhq_n_s16(v1527, 16389); + int16x8_t v1529 = vaddq_s16(v1351, v1528); + int16x8_t v1530 = vqrdmulhq_n_s16(v1529, 16385); + int16x8_t v1531 = vaddq_s16(v1146, v1530); + int16x8_t v1532 = vqrdmulhq_n_s16(v1531, 16384); + int16x8_t v1533 = vaddq_s16(v701, v1532); + int16x8_t v1534 = vsubq_s16(v0, v1); + int16x8_t v1535 = vsubq_s16(v4, v6); + int16x8_t v1536_tmp = vqrdmulhq_n_s16(v1535, 10045); + int16x8_t v1536 = vaddq_s16(v1536_tmp, v1535); + int16x8_t v1537 = vaddq_s16(v1534, v1536); + int16x8_t v1538 = vsubq_s16(v11, v14); + int16x8_t v1539 = vsubq_s16(v17, v20); + int16x8_t v1540_tmp = vqrdmulhq_n_s16(v1539, 10045); + int16x8_t v1540 = vaddq_s16(v1540_tmp, v1539); + int16x8_t v1541 = vaddq_s16(v1538, v1540); + int16x8_t v1542 = vqrdmulhq_n_s16(v1541, 19705); + int16x8_t v1543 = vaddq_s16(v1537, v1542); + int16x8_t v1544 = vsubq_s16(v27, v30); + int16x8_t v1545 = vsubq_s16(v35, v39); + int16x8_t v1546_tmp = vqrdmulhq_n_s16(v1545, 10045); + int16x8_t v1546 = vaddq_s16(v1546_tmp, v1545); + int16x8_t v1547 = vaddq_s16(v1544, v1546); + int16x8_t v1548 = vsubq_s16(v44, v47); + int16x8_t v1549 = vsubq_s16(v50, v54); + int16x8_t v1550_tmp = vqrdmulhq_n_s16(v1549, 10045); + int16x8_t v1550 = vaddq_s16(v1550_tmp, v1549); + int16x8_t v1551 = vaddq_s16(v1548, v1550); + int16x8_t v1552 = vqrdmulhq_n_s16(v1551, 19705); + int16x8_t v1553 = vaddq_s16(v1547, v1552); + int16x8_t v1554 = vqrdmulhq_n_s16(v1553, 17121); + int16x8_t v1555 = vaddq_s16(v1543, v1554); + int16x8_t v1556 = vsubq_s16(v63, v66); + int16x8_t v1557 = vsubq_s16(v71, v75); + int16x8_t v1558_tmp = vqrdmulhq_n_s16(v1557, 10045); + int16x8_t v1558 = vaddq_s16(v1558_tmp, v1557); + int16x8_t v1559 = vaddq_s16(v1556, v1558); + int16x8_t v1560 = vsubq_s16(v82, v89); + int16x8_t v1561 = vsubq_s16(v92, v97); + int16x8_t v1562_tmp = vqrdmulhq_n_s16(v1561, 10045); + int16x8_t v1562 = vaddq_s16(v1562_tmp, v1561); + int16x8_t v1563 = vaddq_s16(v1560, v1562); + int16x8_t v1564 = vqrdmulhq_n_s16(v1563, 19705); + int16x8_t v1565 = vaddq_s16(v1559, v1564); + int16x8_t v1566 = vsubq_s16(v104, v107); + int16x8_t v1567 = vsubq_s16(v112, v116); + int16x8_t v1568_tmp = vqrdmulhq_n_s16(v1567, 10045); + int16x8_t v1568 = vaddq_s16(v1568_tmp, v1567); + int16x8_t v1569 = vaddq_s16(v1566, v1568); + int16x8_t v1570 = vsubq_s16(v121, v124); + int16x8_t v1571 = vsubq_s16(v127, v132); + int16x8_t v1572_tmp = vqrdmulhq_n_s16(v1571, 10045); + int16x8_t v1572 = vaddq_s16(v1572_tmp, v1571); + int16x8_t v1573 = vaddq_s16(v1570, v1572); + int16x8_t v1574 = vqrdmulhq_n_s16(v1573, 19705); + int16x8_t v1575 = vaddq_s16(v1569, v1574); + int16x8_t v1576 = vqrdmulhq_n_s16(v1575, 17121); + int16x8_t v1577 = vaddq_s16(v1565, v1576); + int16x8_t v1578 = vqrdmulhq_n_s16(v1577, 16563); + int16x8_t v1579 = vaddq_s16(v1555, v1578); + int16x8_t v1580 = vsubq_s16(v143, v146); + int16x8_t v1581 = vsubq_s16(v151, v155); + int16x8_t v1582_tmp = vqrdmulhq_n_s16(v1581, 10045); + int16x8_t v1582 = vaddq_s16(v1582_tmp, v1581); + int16x8_t v1583 = vaddq_s16(v1580, v1582); + int16x8_t v1584 = vsubq_s16(v162, v169); + int16x8_t v1585 = vsubq_s16(v172, v177); + int16x8_t v1586_tmp = vqrdmulhq_n_s16(v1585, 10045); + int16x8_t v1586 = vaddq_s16(v1586_tmp, v1585); + int16x8_t v1587 = vaddq_s16(v1584, v1586); + int16x8_t v1588 = vqrdmulhq_n_s16(v1587, 19705); + int16x8_t v1589 = vaddq_s16(v1583, v1588); + int16x8_t v1590 = vsubq_s16(v186, v193); + int16x8_t v1591 = vsubq_s16(v202, v210); + int16x8_t v1592_tmp = vqrdmulhq_n_s16(v1591, 10045); + int16x8_t v1592 = vaddq_s16(v1592_tmp, v1591); + int16x8_t v1593 = vaddq_s16(v1590, v1592); + int16x8_t v1594 = vsubq_s16(v215, v218); + int16x8_t v1595 = vsubq_s16(v221, v227); + int16x8_t v1596_tmp = vqrdmulhq_n_s16(v1595, 10045); + int16x8_t v1596 = vaddq_s16(v1596_tmp, v1595); + int16x8_t v1597 = vaddq_s16(v1594, v1596); + int16x8_t v1598 = vqrdmulhq_n_s16(v1597, 19705); + int16x8_t v1599 = vaddq_s16(v1593, v1598); + int16x8_t v1600 = vqrdmulhq_n_s16(v1599, 17121); + int16x8_t v1601 = vaddq_s16(v1589, v1600); + int16x8_t v1602 = vsubq_s16(v236, v239); + int16x8_t v1603 = vsubq_s16(v244, v248); + int16x8_t v1604_tmp = vqrdmulhq_n_s16(v1603, 10045); + int16x8_t v1604 = vaddq_s16(v1604_tmp, v1603); + int16x8_t v1605 = vaddq_s16(v1602, v1604); + int16x8_t v1606 = vsubq_s16(v255, v262); + int16x8_t v1607 = vsubq_s16(v265, v270); + int16x8_t v1608_tmp = vqrdmulhq_n_s16(v1607, 10045); + int16x8_t v1608 = vaddq_s16(v1608_tmp, v1607); + int16x8_t v1609 = vaddq_s16(v1606, v1608); + int16x8_t v1610 = vqrdmulhq_n_s16(v1609, 19705); + int16x8_t v1611 = vaddq_s16(v1605, v1610); + int16x8_t v1612 = vsubq_s16(v277, v280); + int16x8_t v1613 = vsubq_s16(v285, v289); + int16x8_t v1614_tmp = vqrdmulhq_n_s16(v1613, 10045); + int16x8_t v1614 = vaddq_s16(v1614_tmp, v1613); + int16x8_t v1615 = vaddq_s16(v1612, v1614); + int16x8_t v1616 = vsubq_s16(v294, v297); + int16x8_t v1617 = vsubq_s16(v300, v306); + int16x8_t v1618_tmp = vqrdmulhq_n_s16(v1617, 10045); + int16x8_t v1618 = vaddq_s16(v1618_tmp, v1617); + int16x8_t v1619 = vaddq_s16(v1616, v1618); + int16x8_t v1620 = vqrdmulhq_n_s16(v1619, 19705); + int16x8_t v1621 = vaddq_s16(v1615, v1620); + int16x8_t v1622 = vqrdmulhq_n_s16(v1621, 17121); + int16x8_t v1623 = vaddq_s16(v1611, v1622); + int16x8_t v1624 = vqrdmulhq_n_s16(v1623, 16563); + int16x8_t v1625 = vaddq_s16(v1601, v1624); + int16x8_t v1626 = vqrdmulhq_n_s16(v1625, 16429); + int16x8_t v1627 = vaddq_s16(v1579, v1626); + int16x8_t v1628 = vsubq_s16(v319, v322); + int16x8_t v1629 = vsubq_s16(v327, v331); + int16x8_t v1630_tmp = vqrdmulhq_n_s16(v1629, 10045); + int16x8_t v1630 = vaddq_s16(v1630_tmp, v1629); + int16x8_t v1631 = vaddq_s16(v1628, v1630); + int16x8_t v1632 = vsubq_s16(v338, v345); + int16x8_t v1633 = vsubq_s16(v348, v353); + int16x8_t v1634_tmp = vqrdmulhq_n_s16(v1633, 10045); + int16x8_t v1634 = vaddq_s16(v1634_tmp, v1633); + int16x8_t v1635 = vaddq_s16(v1632, v1634); + int16x8_t v1636 = vqrdmulhq_n_s16(v1635, 19705); + int16x8_t v1637 = vaddq_s16(v1631, v1636); + int16x8_t v1638 = vsubq_s16(v362, v369); + int16x8_t v1639 = vsubq_s16(v378, v386); + int16x8_t v1640_tmp = vqrdmulhq_n_s16(v1639, 10045); + int16x8_t v1640 = vaddq_s16(v1640_tmp, v1639); + int16x8_t v1641 = vaddq_s16(v1638, v1640); + int16x8_t v1642 = vsubq_s16(v391, v394); + int16x8_t v1643 = vsubq_s16(v397, v403); + int16x8_t v1644_tmp = vqrdmulhq_n_s16(v1643, 10045); + int16x8_t v1644 = vaddq_s16(v1644_tmp, v1643); + int16x8_t v1645 = vaddq_s16(v1642, v1644); + int16x8_t v1646 = vqrdmulhq_n_s16(v1645, 19705); + int16x8_t v1647 = vaddq_s16(v1641, v1646); + int16x8_t v1648 = vqrdmulhq_n_s16(v1647, 17121); + int16x8_t v1649 = vaddq_s16(v1637, v1648); + int16x8_t v1650 = vsubq_s16(v414, v421); + int16x8_t v1651 = vsubq_s16(v430, v438); + int16x8_t v1652_tmp = vqrdmulhq_n_s16(v1651, 10045); + int16x8_t v1652 = vaddq_s16(v1652_tmp, v1651); + int16x8_t v1653 = vaddq_s16(v1650, v1652); + int16x8_t v1654 = vsubq_s16(v449, v464); + int16x8_t v1655 = vsubq_s16(v467, v476); + int16x8_t v1656_tmp = vqrdmulhq_n_s16(v1655, 10045); + int16x8_t v1656 = vaddq_s16(v1656_tmp, v1655); + int16x8_t v1657 = vaddq_s16(v1654, v1656); + int16x8_t v1658 = vqrdmulhq_n_s16(v1657, 19705); + int16x8_t v1659 = vaddq_s16(v1653, v1658); + int16x8_t v1660 = vsubq_s16(v483, v486); + int16x8_t v1661 = vsubq_s16(v491, v495); + int16x8_t v1662_tmp = vqrdmulhq_n_s16(v1661, 10045); + int16x8_t v1662 = vaddq_s16(v1662_tmp, v1661); + int16x8_t v1663 = vaddq_s16(v1660, v1662); + int16x8_t v1664 = vsubq_s16(v500, v503); + int16x8_t v1665 = vsubq_s16(v506, v513); + int16x8_t v1666_tmp = vqrdmulhq_n_s16(v1665, 10045); + int16x8_t v1666 = vaddq_s16(v1666_tmp, v1665); + int16x8_t v1667 = vaddq_s16(v1664, v1666); + int16x8_t v1668 = vqrdmulhq_n_s16(v1667, 19705); + int16x8_t v1669 = vaddq_s16(v1663, v1668); + int16x8_t v1670 = vqrdmulhq_n_s16(v1669, 17121); + int16x8_t v1671 = vaddq_s16(v1659, v1670); + int16x8_t v1672 = vqrdmulhq_n_s16(v1671, 16563); + int16x8_t v1673 = vaddq_s16(v1649, v1672); + int16x8_t v1674 = vsubq_s16(v524, v527); + int16x8_t v1675 = vsubq_s16(v532, v536); + int16x8_t v1676_tmp = vqrdmulhq_n_s16(v1675, 10045); + int16x8_t v1676 = vaddq_s16(v1676_tmp, v1675); + int16x8_t v1677 = vaddq_s16(v1674, v1676); + int16x8_t v1678 = vsubq_s16(v543, v550); + int16x8_t v1679 = vsubq_s16(v553, v558); + int16x8_t v1680_tmp = vqrdmulhq_n_s16(v1679, 10045); + int16x8_t v1680 = vaddq_s16(v1680_tmp, v1679); + int16x8_t v1681 = vaddq_s16(v1678, v1680); + int16x8_t v1682 = vqrdmulhq_n_s16(v1681, 19705); + int16x8_t v1683 = vaddq_s16(v1677, v1682); + int16x8_t v1684 = vsubq_s16(v567, v574); + int16x8_t v1685 = vsubq_s16(v583, v591); + int16x8_t v1686_tmp = vqrdmulhq_n_s16(v1685, 10045); + int16x8_t v1686 = vaddq_s16(v1686_tmp, v1685); + int16x8_t v1687 = vaddq_s16(v1684, v1686); + int16x8_t v1688 = vsubq_s16(v596, v599); + int16x8_t v1689 = vsubq_s16(v602, v608); + int16x8_t v1690_tmp = vqrdmulhq_n_s16(v1689, 10045); + int16x8_t v1690 = vaddq_s16(v1690_tmp, v1689); + int16x8_t v1691 = vaddq_s16(v1688, v1690); + int16x8_t v1692 = vqrdmulhq_n_s16(v1691, 19705); + int16x8_t v1693 = vaddq_s16(v1687, v1692); + int16x8_t v1694 = vqrdmulhq_n_s16(v1693, 17121); + int16x8_t v1695 = vaddq_s16(v1683, v1694); + int16x8_t v1696 = vsubq_s16(v617, v620); + int16x8_t v1697 = vsubq_s16(v625, v629); + int16x8_t v1698_tmp = vqrdmulhq_n_s16(v1697, 10045); + int16x8_t v1698 = vaddq_s16(v1698_tmp, v1697); + int16x8_t v1699 = vaddq_s16(v1696, v1698); + int16x8_t v1700 = vsubq_s16(v636, v643); + int16x8_t v1701 = vsubq_s16(v646, v651); + int16x8_t v1702_tmp = vqrdmulhq_n_s16(v1701, 10045); + int16x8_t v1702 = vaddq_s16(v1702_tmp, v1701); + int16x8_t v1703 = vaddq_s16(v1700, v1702); + int16x8_t v1704 = vqrdmulhq_n_s16(v1703, 19705); + int16x8_t v1705 = vaddq_s16(v1699, v1704); + int16x8_t v1706 = vsubq_s16(v658, v661); + int16x8_t v1707 = vsubq_s16(v666, v670); + int16x8_t v1708_tmp = vqrdmulhq_n_s16(v1707, 10045); + int16x8_t v1708 = vaddq_s16(v1708_tmp, v1707); + int16x8_t v1709 = vaddq_s16(v1706, v1708); + int16x8_t v1710 = vsubq_s16(v675, v678); + int16x8_t v1711 = vsubq_s16(v681, v688); + int16x8_t v1712_tmp = vqrdmulhq_n_s16(v1711, 10045); + int16x8_t v1712 = vaddq_s16(v1712_tmp, v1711); + int16x8_t v1713 = vaddq_s16(v1710, v1712); + int16x8_t v1714 = vqrdmulhq_n_s16(v1713, 19705); + int16x8_t v1715 = vaddq_s16(v1709, v1714); + int16x8_t v1716 = vqrdmulhq_n_s16(v1715, 17121); + int16x8_t v1717 = vaddq_s16(v1705, v1716); + int16x8_t v1718 = vqrdmulhq_n_s16(v1717, 16563); + int16x8_t v1719 = vaddq_s16(v1695, v1718); + int16x8_t v1720 = vqrdmulhq_n_s16(v1719, 16429); + int16x8_t v1721 = vaddq_s16(v1673, v1720); + int16x8_t v1722 = vqrdmulhq_n_s16(v1721, 16395); + int16x8_t v1723 = vaddq_s16(v1627, v1722); + int16x8_t v1724 = vsubq_s16(v703, v706); + int16x8_t v1725 = vsubq_s16(v711, v715); + int16x8_t v1726_tmp = vqrdmulhq_n_s16(v1725, 10045); + int16x8_t v1726 = vaddq_s16(v1726_tmp, v1725); + int16x8_t v1727 = vaddq_s16(v1724, v1726); + int16x8_t v1728 = vsubq_s16(v722, v729); + int16x8_t v1729 = vsubq_s16(v732, v737); + int16x8_t v1730_tmp = vqrdmulhq_n_s16(v1729, 10045); + int16x8_t v1730 = vaddq_s16(v1730_tmp, v1729); + int16x8_t v1731 = vaddq_s16(v1728, v1730); + int16x8_t v1732 = vqrdmulhq_n_s16(v1731, 19705); + int16x8_t v1733 = vaddq_s16(v1727, v1732); + int16x8_t v1734 = vsubq_s16(v746, v753); + int16x8_t v1735 = vsubq_s16(v762, v770); + int16x8_t v1736_tmp = vqrdmulhq_n_s16(v1735, 10045); + int16x8_t v1736 = vaddq_s16(v1736_tmp, v1735); + int16x8_t v1737 = vaddq_s16(v1734, v1736); + int16x8_t v1738 = vsubq_s16(v775, v778); + int16x8_t v1739 = vsubq_s16(v781, v787); + int16x8_t v1740_tmp = vqrdmulhq_n_s16(v1739, 10045); + int16x8_t v1740 = vaddq_s16(v1740_tmp, v1739); + int16x8_t v1741 = vaddq_s16(v1738, v1740); + int16x8_t v1742 = vqrdmulhq_n_s16(v1741, 19705); + int16x8_t v1743 = vaddq_s16(v1737, v1742); + int16x8_t v1744 = vqrdmulhq_n_s16(v1743, 17121); + int16x8_t v1745 = vaddq_s16(v1733, v1744); + int16x8_t v1746 = vsubq_s16(v798, v805); + int16x8_t v1747 = vsubq_s16(v814, v822); + int16x8_t v1748_tmp = vqrdmulhq_n_s16(v1747, 10045); + int16x8_t v1748 = vaddq_s16(v1748_tmp, v1747); + int16x8_t v1749 = vaddq_s16(v1746, v1748); + int16x8_t v1750 = vsubq_s16(v833, v848); + int16x8_t v1751 = vsubq_s16(v851, v860); + int16x8_t v1752_tmp = vqrdmulhq_n_s16(v1751, 10045); + int16x8_t v1752 = vaddq_s16(v1752_tmp, v1751); + int16x8_t v1753 = vaddq_s16(v1750, v1752); + int16x8_t v1754 = vqrdmulhq_n_s16(v1753, 19705); + int16x8_t v1755 = vaddq_s16(v1749, v1754); + int16x8_t v1756 = vsubq_s16(v867, v870); + int16x8_t v1757 = vsubq_s16(v875, v879); + int16x8_t v1758_tmp = vqrdmulhq_n_s16(v1757, 10045); + int16x8_t v1758 = vaddq_s16(v1758_tmp, v1757); + int16x8_t v1759 = vaddq_s16(v1756, v1758); + int16x8_t v1760 = vsubq_s16(v884, v887); + int16x8_t v1761 = vsubq_s16(v890, v897); + int16x8_t v1762_tmp = vqrdmulhq_n_s16(v1761, 10045); + int16x8_t v1762 = vaddq_s16(v1762_tmp, v1761); + int16x8_t v1763 = vaddq_s16(v1760, v1762); + int16x8_t v1764 = vqrdmulhq_n_s16(v1763, 19705); + int16x8_t v1765 = vaddq_s16(v1759, v1764); + int16x8_t v1766 = vqrdmulhq_n_s16(v1765, 17121); + int16x8_t v1767 = vaddq_s16(v1755, v1766); + int16x8_t v1768 = vqrdmulhq_n_s16(v1767, 16563); + int16x8_t v1769 = vaddq_s16(v1745, v1768); + int16x8_t v1770 = vsubq_s16(v910, v917); + int16x8_t v1771 = vsubq_s16(v926, v934); + int16x8_t v1772_tmp = vqrdmulhq_n_s16(v1771, 10045); + int16x8_t v1772 = vaddq_s16(v1772_tmp, v1771); + int16x8_t v1773 = vaddq_s16(v1770, v1772); + int16x8_t v1774 = vsubq_s16(v945, v960); + int16x8_t v1775 = vsubq_s16(v963, v972); + int16x8_t v1776_tmp = vqrdmulhq_n_s16(v1775, 10045); + int16x8_t v1776 = vaddq_s16(v1776_tmp, v1775); + int16x8_t v1777 = vaddq_s16(v1774, v1776); + int16x8_t v1778 = vqrdmulhq_n_s16(v1777, 19705); + int16x8_t v1779 = vaddq_s16(v1773, v1778); + int16x8_t v1780 = vsubq_s16(v985, v1000); + int16x8_t v1781 = vsubq_s16(v1017, v1033); + int16x8_t v1782_tmp = vqrdmulhq_n_s16(v1781, 10045); + int16x8_t v1782 = vaddq_s16(v1782_tmp, v1781); + int16x8_t v1783 = vaddq_s16(v1780, v1782); + int16x8_t v1784 = vsubq_s16(v1038, v1041); + int16x8_t v1785 = vsubq_s16(v1044, v1054); + int16x8_t v1786_tmp = vqrdmulhq_n_s16(v1785, 10045); + int16x8_t v1786 = vaddq_s16(v1786_tmp, v1785); + int16x8_t v1787 = vaddq_s16(v1784, v1786); + int16x8_t v1788 = vqrdmulhq_n_s16(v1787, 19705); + int16x8_t v1789 = vaddq_s16(v1783, v1788); + int16x8_t v1790 = vqrdmulhq_n_s16(v1789, 17121); + int16x8_t v1791 = vaddq_s16(v1779, v1790); + int16x8_t v1792 = vsubq_s16(v1063, v1066); + int16x8_t v1793 = vsubq_s16(v1071, v1075); + int16x8_t v1794_tmp = vqrdmulhq_n_s16(v1793, 10045); + int16x8_t v1794 = vaddq_s16(v1794_tmp, v1793); + int16x8_t v1795 = vaddq_s16(v1792, v1794); + int16x8_t v1796 = vsubq_s16(v1082, v1089); + int16x8_t v1797 = vsubq_s16(v1092, v1097); + int16x8_t v1798_tmp = vqrdmulhq_n_s16(v1797, 10045); + int16x8_t v1798 = vaddq_s16(v1798_tmp, v1797); + int16x8_t v1799 = vaddq_s16(v1796, v1798); + int16x8_t v1800 = vqrdmulhq_n_s16(v1799, 19705); + int16x8_t v1801 = vaddq_s16(v1795, v1800); + int16x8_t v1802 = vsubq_s16(v1104, v1107); + int16x8_t v1803 = vsubq_s16(v1112, v1116); + int16x8_t v1804_tmp = vqrdmulhq_n_s16(v1803, 10045); + int16x8_t v1804 = vaddq_s16(v1804_tmp, v1803); + int16x8_t v1805 = vaddq_s16(v1802, v1804); + int16x8_t v1806 = vsubq_s16(v1121, v1124); + int16x8_t v1807 = vsubq_s16(v1127, v1135); + int16x8_t v1808_tmp = vqrdmulhq_n_s16(v1807, 10045); + int16x8_t v1808 = vaddq_s16(v1808_tmp, v1807); + int16x8_t v1809 = vaddq_s16(v1806, v1808); + int16x8_t v1810 = vqrdmulhq_n_s16(v1809, 19705); + int16x8_t v1811 = vaddq_s16(v1805, v1810); + int16x8_t v1812 = vqrdmulhq_n_s16(v1811, 17121); + int16x8_t v1813 = vaddq_s16(v1801, v1812); + int16x8_t v1814 = vqrdmulhq_n_s16(v1813, 16563); + int16x8_t v1815 = vaddq_s16(v1791, v1814); + int16x8_t v1816 = vqrdmulhq_n_s16(v1815, 16429); + int16x8_t v1817 = vaddq_s16(v1769, v1816); + int16x8_t v1818 = vsubq_s16(v1148, v1151); + int16x8_t v1819 = vsubq_s16(v1156, v1160); + int16x8_t v1820_tmp = vqrdmulhq_n_s16(v1819, 10045); + int16x8_t v1820 = vaddq_s16(v1820_tmp, v1819); + int16x8_t v1821 = vaddq_s16(v1818, v1820); + int16x8_t v1822 = vsubq_s16(v1167, v1174); + int16x8_t v1823 = vsubq_s16(v1177, v1182); + int16x8_t v1824_tmp = vqrdmulhq_n_s16(v1823, 10045); + int16x8_t v1824 = vaddq_s16(v1824_tmp, v1823); + int16x8_t v1825 = vaddq_s16(v1822, v1824); + int16x8_t v1826 = vqrdmulhq_n_s16(v1825, 19705); + int16x8_t v1827 = vaddq_s16(v1821, v1826); + int16x8_t v1828 = vsubq_s16(v1191, v1198); + int16x8_t v1829 = vsubq_s16(v1207, v1215); + int16x8_t v1830_tmp = vqrdmulhq_n_s16(v1829, 10045); + int16x8_t v1830 = vaddq_s16(v1830_tmp, v1829); + int16x8_t v1831 = vaddq_s16(v1828, v1830); + int16x8_t v1832 = vsubq_s16(v1220, v1223); + int16x8_t v1833 = vsubq_s16(v1226, v1232); + int16x8_t v1834_tmp = vqrdmulhq_n_s16(v1833, 10045); + int16x8_t v1834 = vaddq_s16(v1834_tmp, v1833); + int16x8_t v1835 = vaddq_s16(v1832, v1834); + int16x8_t v1836 = vqrdmulhq_n_s16(v1835, 19705); + int16x8_t v1837 = vaddq_s16(v1831, v1836); + int16x8_t v1838 = vqrdmulhq_n_s16(v1837, 17121); + int16x8_t v1839 = vaddq_s16(v1827, v1838); + int16x8_t v1840 = vsubq_s16(v1243, v1250); + int16x8_t v1841 = vsubq_s16(v1259, v1267); + int16x8_t v1842_tmp = vqrdmulhq_n_s16(v1841, 10045); + int16x8_t v1842 = vaddq_s16(v1842_tmp, v1841); + int16x8_t v1843 = vaddq_s16(v1840, v1842); + int16x8_t v1844 = vsubq_s16(v1278, v1293); + int16x8_t v1845 = vsubq_s16(v1296, v1305); + int16x8_t v1846_tmp = vqrdmulhq_n_s16(v1845, 10045); + int16x8_t v1846 = vaddq_s16(v1846_tmp, v1845); + int16x8_t v1847 = vaddq_s16(v1844, v1846); + int16x8_t v1848 = vqrdmulhq_n_s16(v1847, 19705); + int16x8_t v1849 = vaddq_s16(v1843, v1848); + int16x8_t v1850 = vsubq_s16(v1312, v1315); + int16x8_t v1851 = vsubq_s16(v1320, v1324); + int16x8_t v1852_tmp = vqrdmulhq_n_s16(v1851, 10045); + int16x8_t v1852 = vaddq_s16(v1852_tmp, v1851); + int16x8_t v1853 = vaddq_s16(v1850, v1852); + int16x8_t v1854 = vsubq_s16(v1329, v1332); + int16x8_t v1855 = vsubq_s16(v1335, v1342); + int16x8_t v1856_tmp = vqrdmulhq_n_s16(v1855, 10045); + int16x8_t v1856 = vaddq_s16(v1856_tmp, v1855); + int16x8_t v1857 = vaddq_s16(v1854, v1856); + int16x8_t v1858 = vqrdmulhq_n_s16(v1857, 19705); + int16x8_t v1859 = vaddq_s16(v1853, v1858); + int16x8_t v1860 = vqrdmulhq_n_s16(v1859, 17121); + int16x8_t v1861 = vaddq_s16(v1849, v1860); + int16x8_t v1862 = vqrdmulhq_n_s16(v1861, 16563); + int16x8_t v1863 = vaddq_s16(v1839, v1862); + int16x8_t v1864 = vsubq_s16(v1353, v1356); + int16x8_t v1865 = vsubq_s16(v1361, v1365); + int16x8_t v1866_tmp = vqrdmulhq_n_s16(v1865, 10045); + int16x8_t v1866 = vaddq_s16(v1866_tmp, v1865); + int16x8_t v1867 = vaddq_s16(v1864, v1866); + int16x8_t v1868 = vsubq_s16(v1372, v1379); + int16x8_t v1869 = vsubq_s16(v1382, v1387); + int16x8_t v1870_tmp = vqrdmulhq_n_s16(v1869, 10045); + int16x8_t v1870 = vaddq_s16(v1870_tmp, v1869); + int16x8_t v1871 = vaddq_s16(v1868, v1870); + int16x8_t v1872 = vqrdmulhq_n_s16(v1871, 19705); + int16x8_t v1873 = vaddq_s16(v1867, v1872); + int16x8_t v1874 = vsubq_s16(v1396, v1403); + int16x8_t v1875 = vsubq_s16(v1412, v1420); + int16x8_t v1876_tmp = vqrdmulhq_n_s16(v1875, 10045); + int16x8_t v1876 = vaddq_s16(v1876_tmp, v1875); + int16x8_t v1877 = vaddq_s16(v1874, v1876); + int16x8_t v1878 = vsubq_s16(v1425, v1428); + int16x8_t v1879 = vsubq_s16(v1431, v1437); + int16x8_t v1880_tmp = vqrdmulhq_n_s16(v1879, 10045); + int16x8_t v1880 = vaddq_s16(v1880_tmp, v1879); + int16x8_t v1881 = vaddq_s16(v1878, v1880); + int16x8_t v1882 = vqrdmulhq_n_s16(v1881, 19705); + int16x8_t v1883 = vaddq_s16(v1877, v1882); + int16x8_t v1884 = vqrdmulhq_n_s16(v1883, 17121); + int16x8_t v1885 = vaddq_s16(v1873, v1884); + int16x8_t v1886 = vsubq_s16(v1446, v1449); + int16x8_t v1887 = vsubq_s16(v1454, v1458); + int16x8_t v1888_tmp = vqrdmulhq_n_s16(v1887, 10045); + int16x8_t v1888 = vaddq_s16(v1888_tmp, v1887); + int16x8_t v1889 = vaddq_s16(v1886, v1888); + int16x8_t v1890 = vsubq_s16(v1465, v1472); + int16x8_t v1891 = vsubq_s16(v1475, v1480); + int16x8_t v1892_tmp = vqrdmulhq_n_s16(v1891, 10045); + int16x8_t v1892 = vaddq_s16(v1892_tmp, v1891); + int16x8_t v1893 = vaddq_s16(v1890, v1892); + int16x8_t v1894 = vqrdmulhq_n_s16(v1893, 19705); + int16x8_t v1895 = vaddq_s16(v1889, v1894); + int16x8_t v1896 = vsubq_s16(v1487, v1490); + int16x8_t v1897 = vsubq_s16(v1495, v1499); + int16x8_t v1898_tmp = vqrdmulhq_n_s16(v1897, 10045); + int16x8_t v1898 = vaddq_s16(v1898_tmp, v1897); + int16x8_t v1899 = vaddq_s16(v1896, v1898); + int16x8_t v1900 = vsubq_s16(v1504, v1507); + int16x8_t v1901 = vsubq_s16(v1510, v1518); + int16x8_t v1902_tmp = vqrdmulhq_n_s16(v1901, 10045); + int16x8_t v1902 = vaddq_s16(v1902_tmp, v1901); + int16x8_t v1903 = vaddq_s16(v1900, v1902); + int16x8_t v1904 = vqrdmulhq_n_s16(v1903, 19705); + int16x8_t v1905 = vaddq_s16(v1899, v1904); + int16x8_t v1906 = vqrdmulhq_n_s16(v1905, 17121); + int16x8_t v1907 = vaddq_s16(v1895, v1906); + int16x8_t v1908 = vqrdmulhq_n_s16(v1907, 16563); + int16x8_t v1909 = vaddq_s16(v1885, v1908); + int16x8_t v1910 = vqrdmulhq_n_s16(v1909, 16429); + int16x8_t v1911 = vaddq_s16(v1863, v1910); + int16x8_t v1912 = vqrdmulhq_n_s16(v1911, 16395); + int16x8_t v1913 = vaddq_s16(v1817, v1912); + int16x8_t v1914 = vqrdmulhq_n_s16(v1913, 16387); + int16x8_t v1915 = vaddq_s16(v1723, v1914); + int16x8_t v1916 = vsubq_s16(v1534, v1536); + int16x8_t v1917 = vsubq_s16(v1538, v1540); + int16x8_t v1918 = vqrdmulhq_n_s16(v1917, 29490); + int16x8_t v1919 = vaddq_s16(v1916, v1918); + int16x8_t v1920 = vsubq_s16(v1544, v1546); + int16x8_t v1921 = vsubq_s16(v1548, v1550); + int16x8_t v1922 = vqrdmulhq_n_s16(v1921, 29490); + int16x8_t v1923 = vaddq_s16(v1920, v1922); + int16x8_t v1924 = vqrdmulhq_n_s16(v1923, 18578); + int16x8_t v1925 = vaddq_s16(v1919, v1924); + int16x8_t v1926 = vsubq_s16(v1556, v1558); + int16x8_t v1927 = vsubq_s16(v1560, v1562); + int16x8_t v1928 = vqrdmulhq_n_s16(v1927, 29490); + int16x8_t v1929 = vaddq_s16(v1926, v1928); + int16x8_t v1930 = vsubq_s16(v1566, v1568); + int16x8_t v1931 = vsubq_s16(v1570, v1572); + int16x8_t v1932 = vqrdmulhq_n_s16(v1931, 29490); + int16x8_t v1933 = vaddq_s16(v1930, v1932); + int16x8_t v1934 = vqrdmulhq_n_s16(v1933, 18578); + int16x8_t v1935 = vaddq_s16(v1929, v1934); + int16x8_t v1936 = vqrdmulhq_n_s16(v1935, 16890); + int16x8_t v1937 = vaddq_s16(v1925, v1936); + int16x8_t v1938 = vsubq_s16(v1580, v1582); + int16x8_t v1939 = vsubq_s16(v1584, v1586); + int16x8_t v1940 = vqrdmulhq_n_s16(v1939, 29490); + int16x8_t v1941 = vaddq_s16(v1938, v1940); + int16x8_t v1942 = vsubq_s16(v1590, v1592); + int16x8_t v1943 = vsubq_s16(v1594, v1596); + int16x8_t v1944 = vqrdmulhq_n_s16(v1943, 29490); + int16x8_t v1945 = vaddq_s16(v1942, v1944); + int16x8_t v1946 = vqrdmulhq_n_s16(v1945, 18578); + int16x8_t v1947 = vaddq_s16(v1941, v1946); + int16x8_t v1948 = vsubq_s16(v1602, v1604); + int16x8_t v1949 = vsubq_s16(v1606, v1608); + int16x8_t v1950 = vqrdmulhq_n_s16(v1949, 29490); + int16x8_t v1951 = vaddq_s16(v1948, v1950); + int16x8_t v1952 = vsubq_s16(v1612, v1614); + int16x8_t v1953 = vsubq_s16(v1616, v1618); + int16x8_t v1954 = vqrdmulhq_n_s16(v1953, 29490); + int16x8_t v1955 = vaddq_s16(v1952, v1954); + int16x8_t v1956 = vqrdmulhq_n_s16(v1955, 18578); + int16x8_t v1957 = vaddq_s16(v1951, v1956); + int16x8_t v1958 = vqrdmulhq_n_s16(v1957, 16890); + int16x8_t v1959 = vaddq_s16(v1947, v1958); + int16x8_t v1960 = vqrdmulhq_n_s16(v1959, 16508); + int16x8_t v1961 = vaddq_s16(v1937, v1960); + int16x8_t v1962 = vsubq_s16(v1628, v1630); + int16x8_t v1963 = vsubq_s16(v1632, v1634); + int16x8_t v1964 = vqrdmulhq_n_s16(v1963, 29490); + int16x8_t v1965 = vaddq_s16(v1962, v1964); + int16x8_t v1966 = vsubq_s16(v1638, v1640); + int16x8_t v1967 = vsubq_s16(v1642, v1644); + int16x8_t v1968 = vqrdmulhq_n_s16(v1967, 29490); + int16x8_t v1969 = vaddq_s16(v1966, v1968); + int16x8_t v1970 = vqrdmulhq_n_s16(v1969, 18578); + int16x8_t v1971 = vaddq_s16(v1965, v1970); + int16x8_t v1972 = vsubq_s16(v1650, v1652); + int16x8_t v1973 = vsubq_s16(v1654, v1656); + int16x8_t v1974 = vqrdmulhq_n_s16(v1973, 29490); + int16x8_t v1975 = vaddq_s16(v1972, v1974); + int16x8_t v1976 = vsubq_s16(v1660, v1662); + int16x8_t v1977 = vsubq_s16(v1664, v1666); + int16x8_t v1978 = vqrdmulhq_n_s16(v1977, 29490); + int16x8_t v1979 = vaddq_s16(v1976, v1978); + int16x8_t v1980 = vqrdmulhq_n_s16(v1979, 18578); + int16x8_t v1981 = vaddq_s16(v1975, v1980); + int16x8_t v1982 = vqrdmulhq_n_s16(v1981, 16890); + int16x8_t v1983 = vaddq_s16(v1971, v1982); + int16x8_t v1984 = vsubq_s16(v1674, v1676); + int16x8_t v1985 = vsubq_s16(v1678, v1680); + int16x8_t v1986 = vqrdmulhq_n_s16(v1985, 29490); + int16x8_t v1987 = vaddq_s16(v1984, v1986); + int16x8_t v1988 = vsubq_s16(v1684, v1686); + int16x8_t v1989 = vsubq_s16(v1688, v1690); + int16x8_t v1990 = vqrdmulhq_n_s16(v1989, 29490); + int16x8_t v1991 = vaddq_s16(v1988, v1990); + int16x8_t v1992 = vqrdmulhq_n_s16(v1991, 18578); + int16x8_t v1993 = vaddq_s16(v1987, v1992); + int16x8_t v1994 = vsubq_s16(v1696, v1698); + int16x8_t v1995 = vsubq_s16(v1700, v1702); + int16x8_t v1996 = vqrdmulhq_n_s16(v1995, 29490); + int16x8_t v1997 = vaddq_s16(v1994, v1996); + int16x8_t v1998 = vsubq_s16(v1706, v1708); + int16x8_t v1999 = vsubq_s16(v1710, v1712); + int16x8_t v2000 = vqrdmulhq_n_s16(v1999, 29490); + int16x8_t v2001 = vaddq_s16(v1998, v2000); + int16x8_t v2002 = vqrdmulhq_n_s16(v2001, 18578); + int16x8_t v2003 = vaddq_s16(v1997, v2002); + int16x8_t v2004 = vqrdmulhq_n_s16(v2003, 16890); + int16x8_t v2005 = vaddq_s16(v1993, v2004); + int16x8_t v2006 = vqrdmulhq_n_s16(v2005, 16508); + int16x8_t v2007 = vaddq_s16(v1983, v2006); + int16x8_t v2008 = vqrdmulhq_n_s16(v2007, 16415); + int16x8_t v2009 = vaddq_s16(v1961, v2008); + int16x8_t v2010 = vsubq_s16(v1724, v1726); + int16x8_t v2011 = vsubq_s16(v1728, v1730); + int16x8_t v2012 = vqrdmulhq_n_s16(v2011, 29490); + int16x8_t v2013 = vaddq_s16(v2010, v2012); + int16x8_t v2014 = vsubq_s16(v1734, v1736); + int16x8_t v2015 = vsubq_s16(v1738, v1740); + int16x8_t v2016 = vqrdmulhq_n_s16(v2015, 29490); + int16x8_t v2017 = vaddq_s16(v2014, v2016); + int16x8_t v2018 = vqrdmulhq_n_s16(v2017, 18578); + int16x8_t v2019 = vaddq_s16(v2013, v2018); + int16x8_t v2020 = vsubq_s16(v1746, v1748); + int16x8_t v2021 = vsubq_s16(v1750, v1752); + int16x8_t v2022 = vqrdmulhq_n_s16(v2021, 29490); + int16x8_t v2023 = vaddq_s16(v2020, v2022); + int16x8_t v2024 = vsubq_s16(v1756, v1758); + int16x8_t v2025 = vsubq_s16(v1760, v1762); + int16x8_t v2026 = vqrdmulhq_n_s16(v2025, 29490); + int16x8_t v2027 = vaddq_s16(v2024, v2026); + int16x8_t v2028 = vqrdmulhq_n_s16(v2027, 18578); + int16x8_t v2029 = vaddq_s16(v2023, v2028); + int16x8_t v2030 = vqrdmulhq_n_s16(v2029, 16890); + int16x8_t v2031 = vaddq_s16(v2019, v2030); + int16x8_t v2032 = vsubq_s16(v1770, v1772); + int16x8_t v2033 = vsubq_s16(v1774, v1776); + int16x8_t v2034 = vqrdmulhq_n_s16(v2033, 29490); + int16x8_t v2035 = vaddq_s16(v2032, v2034); + int16x8_t v2036 = vsubq_s16(v1780, v1782); + int16x8_t v2037 = vsubq_s16(v1784, v1786); + int16x8_t v2038 = vqrdmulhq_n_s16(v2037, 29490); + int16x8_t v2039 = vaddq_s16(v2036, v2038); + int16x8_t v2040 = vqrdmulhq_n_s16(v2039, 18578); + int16x8_t v2041 = vaddq_s16(v2035, v2040); + int16x8_t v2042 = vsubq_s16(v1792, v1794); + int16x8_t v2043 = vsubq_s16(v1796, v1798); + int16x8_t v2044 = vqrdmulhq_n_s16(v2043, 29490); + int16x8_t v2045 = vaddq_s16(v2042, v2044); + int16x8_t v2046 = vsubq_s16(v1802, v1804); + int16x8_t v2047 = vsubq_s16(v1806, v1808); + int16x8_t v2048 = vqrdmulhq_n_s16(v2047, 29490); + int16x8_t v2049 = vaddq_s16(v2046, v2048); + int16x8_t v2050 = vqrdmulhq_n_s16(v2049, 18578); + int16x8_t v2051 = vaddq_s16(v2045, v2050); + int16x8_t v2052 = vqrdmulhq_n_s16(v2051, 16890); + int16x8_t v2053 = vaddq_s16(v2041, v2052); + int16x8_t v2054 = vqrdmulhq_n_s16(v2053, 16508); + int16x8_t v2055 = vaddq_s16(v2031, v2054); + int16x8_t v2056 = vsubq_s16(v1818, v1820); + int16x8_t v2057 = vsubq_s16(v1822, v1824); + int16x8_t v2058 = vqrdmulhq_n_s16(v2057, 29490); + int16x8_t v2059 = vaddq_s16(v2056, v2058); + int16x8_t v2060 = vsubq_s16(v1828, v1830); + int16x8_t v2061 = vsubq_s16(v1832, v1834); + int16x8_t v2062 = vqrdmulhq_n_s16(v2061, 29490); + int16x8_t v2063 = vaddq_s16(v2060, v2062); + int16x8_t v2064 = vqrdmulhq_n_s16(v2063, 18578); + int16x8_t v2065 = vaddq_s16(v2059, v2064); + int16x8_t v2066 = vsubq_s16(v1840, v1842); + int16x8_t v2067 = vsubq_s16(v1844, v1846); + int16x8_t v2068 = vqrdmulhq_n_s16(v2067, 29490); + int16x8_t v2069 = vaddq_s16(v2066, v2068); + int16x8_t v2070 = vsubq_s16(v1850, v1852); + int16x8_t v2071 = vqrdmulhq_n_s16(v2070, 18578); + int16x8_t v2072 = vsubq_s16(v1854, v1856); + int16x8_t v2073 = vqrdmulhq_n_s16(v2072, 16719); + int16x8_t v2074 = vaddq_s16(v2071, v2073); + int16x8_t v2075 = vaddq_s16(v2069, v2074); + int16x8_t v2076 = vqrdmulhq_n_s16(v2075, 16890); + int16x8_t v2077 = vaddq_s16(v2065, v2076); + int16x8_t v2078 = vsubq_s16(v1864, v1866); + int16x8_t v2079 = vsubq_s16(v1868, v1870); + int16x8_t v2080 = vqrdmulhq_n_s16(v2079, 29490); + int16x8_t v2081 = vaddq_s16(v2078, v2080); + int16x8_t v2082 = vsubq_s16(v1874, v1876); + int16x8_t v2083 = vsubq_s16(v1878, v1880); + int16x8_t v2084 = vqrdmulhq_n_s16(v2083, 29490); + int16x8_t v2085 = vaddq_s16(v2082, v2084); + int16x8_t v2086 = vqrdmulhq_n_s16(v2085, 18578); + int16x8_t v2087 = vaddq_s16(v2081, v2086); + int16x8_t v2088 = vsubq_s16(v1886, v1888); + int16x8_t v2089 = vsubq_s16(v1890, v1892); + int16x8_t v2090 = vqrdmulhq_n_s16(v2089, 29490); + int16x8_t v2091 = vaddq_s16(v2088, v2090); + int16x8_t v2092 = vsubq_s16(v1896, v1898); + int16x8_t v2093 = vsubq_s16(v1900, v1902); + int16x8_t v2094 = vqrdmulhq_n_s16(v2093, 29490); + int16x8_t v2095 = vaddq_s16(v2092, v2094); + int16x8_t v2096 = vqrdmulhq_n_s16(v2095, 18578); + int16x8_t v2097 = vaddq_s16(v2091, v2096); + int16x8_t v2098 = vqrdmulhq_n_s16(v2097, 16890); + int16x8_t v2099 = vaddq_s16(v2087, v2098); + int16x8_t v2100 = vqrdmulhq_n_s16(v2099, 16508); + int16x8_t v2101 = vaddq_s16(v2077, v2100); + int16x8_t v2102 = vqrdmulhq_n_s16(v2101, 16415); + int16x8_t v2103 = vaddq_s16(v2055, v2102); + int16x8_t v2104 = vqrdmulhq_n_s16(v2103, 16392); + int16x8_t v2105 = vaddq_s16(v2009, v2104); + int16x8_t v2106 = vsubq_s16(v2, v8); + int16x8_t v2107 = vsubq_s16(v15, v22); + int16x8_t v2108_tmp = vqrdmulhq_n_s16(v2107, 18446); + int16x8_t v2108 = vmlaq_n_s16(v2108_tmp, v2107, 2); + int16x8_t v2109 = vaddq_s16(v2106, v2108); + int16x8_t v2110 = vsubq_s16(v31, v41); + int16x8_t v2111 = vsubq_s16(v48, v56); + int16x8_t v2112_tmp = vqrdmulhq_n_s16(v2111, 18446); + int16x8_t v2112 = vmlaq_n_s16(v2112_tmp, v2111, 2); + int16x8_t v2113 = vaddq_s16(v2110, v2112); + int16x8_t v2114 = vqrdmulhq_n_s16(v2113, 21195); + int16x8_t v2115 = vaddq_s16(v2109, v2114); + int16x8_t v2116 = vsubq_s16(v67, v77); + int16x8_t v2117 = vsubq_s16(v90, v99); + int16x8_t v2118_tmp = vqrdmulhq_n_s16(v2117, 18446); + int16x8_t v2118 = vmlaq_n_s16(v2118_tmp, v2117, 2); + int16x8_t v2119 = vaddq_s16(v2116, v2118); + int16x8_t v2120 = vsubq_s16(v108, v118); + int16x8_t v2121 = vsubq_s16(v125, v134); + int16x8_t v2122_tmp = vqrdmulhq_n_s16(v2121, 18446); + int16x8_t v2122 = vmlaq_n_s16(v2122_tmp, v2121, 2); + int16x8_t v2123 = vaddq_s16(v2120, v2122); + int16x8_t v2124 = vqrdmulhq_n_s16(v2123, 21195); + int16x8_t v2125 = vaddq_s16(v2119, v2124); + int16x8_t v2126 = vqrdmulhq_n_s16(v2125, 17401); + int16x8_t v2127 = vaddq_s16(v2115, v2126); + int16x8_t v2128 = vsubq_s16(v147, v157); + int16x8_t v2129 = vsubq_s16(v170, v179); + int16x8_t v2130_tmp = vqrdmulhq_n_s16(v2129, 18446); + int16x8_t v2130 = vmlaq_n_s16(v2130_tmp, v2129, 2); + int16x8_t v2131 = vaddq_s16(v2128, v2130); + int16x8_t v2132 = vsubq_s16(v194, v212); + int16x8_t v2133 = vsubq_s16(v219, v229); + int16x8_t v2134_tmp = vqrdmulhq_n_s16(v2133, 18446); + int16x8_t v2134 = vmlaq_n_s16(v2134_tmp, v2133, 2); + int16x8_t v2135 = vaddq_s16(v2132, v2134); + int16x8_t v2136 = vqrdmulhq_n_s16(v2135, 21195); + int16x8_t v2137 = vaddq_s16(v2131, v2136); + int16x8_t v2138 = vsubq_s16(v240, v250); + int16x8_t v2139 = vsubq_s16(v263, v272); + int16x8_t v2140_tmp = vqrdmulhq_n_s16(v2139, 18446); + int16x8_t v2140 = vmlaq_n_s16(v2140_tmp, v2139, 2); + int16x8_t v2141 = vaddq_s16(v2138, v2140); + int16x8_t v2142 = vsubq_s16(v281, v291); + int16x8_t v2143 = vsubq_s16(v298, v308); + int16x8_t v2144_tmp = vqrdmulhq_n_s16(v2143, 18446); + int16x8_t v2144 = vmlaq_n_s16(v2144_tmp, v2143, 2); + int16x8_t v2145 = vaddq_s16(v2142, v2144); + int16x8_t v2146 = vqrdmulhq_n_s16(v2145, 21195); + int16x8_t v2147 = vaddq_s16(v2141, v2146); + int16x8_t v2148 = vqrdmulhq_n_s16(v2147, 17401); + int16x8_t v2149 = vaddq_s16(v2137, v2148); + int16x8_t v2150 = vqrdmulhq_n_s16(v2149, 16629); + int16x8_t v2151 = vaddq_s16(v2127, v2150); + int16x8_t v2152 = vsubq_s16(v323, v333); + int16x8_t v2153 = vsubq_s16(v346, v355); + int16x8_t v2154_tmp = vqrdmulhq_n_s16(v2153, 18446); + int16x8_t v2154 = vmlaq_n_s16(v2154_tmp, v2153, 2); + int16x8_t v2155 = vaddq_s16(v2152, v2154); + int16x8_t v2156 = vsubq_s16(v370, v388); + int16x8_t v2157 = vsubq_s16(v395, v405); + int16x8_t v2158_tmp = vqrdmulhq_n_s16(v2157, 18446); + int16x8_t v2158 = vmlaq_n_s16(v2158_tmp, v2157, 2); + int16x8_t v2159 = vaddq_s16(v2156, v2158); + int16x8_t v2160 = vqrdmulhq_n_s16(v2159, 21195); + int16x8_t v2161 = vaddq_s16(v2155, v2160); + int16x8_t v2162 = vsubq_s16(v422, v440); + int16x8_t v2163 = vsubq_s16(v465, v478); + int16x8_t v2164_tmp = vqrdmulhq_n_s16(v2163, 18446); + int16x8_t v2164 = vmlaq_n_s16(v2164_tmp, v2163, 2); + int16x8_t v2165 = vaddq_s16(v2162, v2164); + int16x8_t v2166 = vsubq_s16(v487, v497); + int16x8_t v2167 = vsubq_s16(v504, v515); + int16x8_t v2168_tmp = vqrdmulhq_n_s16(v2167, 18446); + int16x8_t v2168 = vmlaq_n_s16(v2168_tmp, v2167, 2); + int16x8_t v2169 = vaddq_s16(v2166, v2168); + int16x8_t v2170 = vqrdmulhq_n_s16(v2169, 21195); + int16x8_t v2171 = vaddq_s16(v2165, v2170); + int16x8_t v2172 = vqrdmulhq_n_s16(v2171, 17401); + int16x8_t v2173 = vaddq_s16(v2161, v2172); + int16x8_t v2174 = vsubq_s16(v528, v538); + int16x8_t v2175 = vsubq_s16(v551, v560); + int16x8_t v2176_tmp = vqrdmulhq_n_s16(v2175, 18446); + int16x8_t v2176 = vmlaq_n_s16(v2176_tmp, v2175, 2); + int16x8_t v2177 = vaddq_s16(v2174, v2176); + int16x8_t v2178 = vsubq_s16(v575, v593); + int16x8_t v2179 = vsubq_s16(v600, v610); + int16x8_t v2180_tmp = vqrdmulhq_n_s16(v2179, 18446); + int16x8_t v2180 = vmlaq_n_s16(v2180_tmp, v2179, 2); + int16x8_t v2181 = vaddq_s16(v2178, v2180); + int16x8_t v2182 = vqrdmulhq_n_s16(v2181, 21195); + int16x8_t v2183 = vaddq_s16(v2177, v2182); + int16x8_t v2184 = vsubq_s16(v621, v631); + int16x8_t v2185 = vsubq_s16(v644, v653); + int16x8_t v2186_tmp = vqrdmulhq_n_s16(v2185, 18446); + int16x8_t v2186 = vmlaq_n_s16(v2186_tmp, v2185, 2); + int16x8_t v2187 = vaddq_s16(v2184, v2186); + int16x8_t v2188 = vsubq_s16(v662, v672); + int16x8_t v2189 = vsubq_s16(v679, v690); + int16x8_t v2190_tmp = vqrdmulhq_n_s16(v2189, 18446); + int16x8_t v2190 = vmlaq_n_s16(v2190_tmp, v2189, 2); + int16x8_t v2191 = vaddq_s16(v2188, v2190); + int16x8_t v2192 = vqrdmulhq_n_s16(v2191, 21195); + int16x8_t v2193 = vaddq_s16(v2187, v2192); + int16x8_t v2194 = vqrdmulhq_n_s16(v2193, 17401); + int16x8_t v2195 = vaddq_s16(v2183, v2194); + int16x8_t v2196 = vqrdmulhq_n_s16(v2195, 16629); + int16x8_t v2197 = vaddq_s16(v2173, v2196); + int16x8_t v2198 = vqrdmulhq_n_s16(v2197, 16445); + int16x8_t v2199 = vaddq_s16(v2151, v2198); + int16x8_t v2200 = vsubq_s16(v707, v717); + int16x8_t v2201 = vsubq_s16(v730, v739); + int16x8_t v2202_tmp = vqrdmulhq_n_s16(v2201, 18446); + int16x8_t v2202 = vmlaq_n_s16(v2202_tmp, v2201, 2); + int16x8_t v2203 = vaddq_s16(v2200, v2202); + int16x8_t v2204 = vsubq_s16(v754, v772); + int16x8_t v2205 = vsubq_s16(v779, v789); + int16x8_t v2206_tmp = vqrdmulhq_n_s16(v2205, 18446); + int16x8_t v2206 = vmlaq_n_s16(v2206_tmp, v2205, 2); + int16x8_t v2207 = vaddq_s16(v2204, v2206); + int16x8_t v2208 = vqrdmulhq_n_s16(v2207, 21195); + int16x8_t v2209 = vaddq_s16(v2203, v2208); + int16x8_t v2210 = vsubq_s16(v806, v824); + int16x8_t v2211 = vsubq_s16(v849, v862); + int16x8_t v2212_tmp = vqrdmulhq_n_s16(v2211, 18446); + int16x8_t v2212 = vmlaq_n_s16(v2212_tmp, v2211, 2); + int16x8_t v2213 = vaddq_s16(v2210, v2212); + int16x8_t v2214 = vsubq_s16(v871, v881); + int16x8_t v2215 = vsubq_s16(v888, v899); + int16x8_t v2216_tmp = vqrdmulhq_n_s16(v2215, 18446); + int16x8_t v2216 = vmlaq_n_s16(v2216_tmp, v2215, 2); + int16x8_t v2217 = vaddq_s16(v2214, v2216); + int16x8_t v2218 = vqrdmulhq_n_s16(v2217, 21195); + int16x8_t v2219 = vaddq_s16(v2213, v2218); + int16x8_t v2220 = vqrdmulhq_n_s16(v2219, 17401); + int16x8_t v2221 = vaddq_s16(v2209, v2220); + int16x8_t v2222 = vsubq_s16(v918, v936); + int16x8_t v2223 = vsubq_s16(v961, v974); + int16x8_t v2224_tmp = vqrdmulhq_n_s16(v2223, 18446); + int16x8_t v2224 = vmlaq_n_s16(v2224_tmp, v2223, 2); + int16x8_t v2225 = vaddq_s16(v2222, v2224); + int16x8_t v2226 = vsubq_s16(v1001, v1035); + int16x8_t v2227 = vsubq_s16(v1042, v1056); + int16x8_t v2228_tmp = vqrdmulhq_n_s16(v2227, 18446); + int16x8_t v2228 = vmlaq_n_s16(v2228_tmp, v2227, 2); + int16x8_t v2229 = vaddq_s16(v2226, v2228); + int16x8_t v2230 = vqrdmulhq_n_s16(v2229, 21195); + int16x8_t v2231 = vaddq_s16(v2225, v2230); + int16x8_t v2232 = vsubq_s16(v1067, v1077); + int16x8_t v2233 = vsubq_s16(v1090, v1099); + int16x8_t v2234_tmp = vqrdmulhq_n_s16(v2233, 18446); + int16x8_t v2234 = vmlaq_n_s16(v2234_tmp, v2233, 2); + int16x8_t v2235 = vaddq_s16(v2232, v2234); + int16x8_t v2236 = vsubq_s16(v1108, v1118); + int16x8_t v2237 = vsubq_s16(v1125, v1137); + int16x8_t v2238_tmp = vqrdmulhq_n_s16(v2237, 18446); + int16x8_t v2238 = vmlaq_n_s16(v2238_tmp, v2237, 2); + int16x8_t v2239 = vaddq_s16(v2236, v2238); + int16x8_t v2240 = vqrdmulhq_n_s16(v2239, 21195); + int16x8_t v2241 = vaddq_s16(v2235, v2240); + int16x8_t v2242 = vqrdmulhq_n_s16(v2241, 17401); + int16x8_t v2243 = vaddq_s16(v2231, v2242); + int16x8_t v2244 = vqrdmulhq_n_s16(v2243, 16629); + int16x8_t v2245 = vaddq_s16(v2221, v2244); + int16x8_t v2246 = vsubq_s16(v1152, v1162); + int16x8_t v2247 = vsubq_s16(v1175, v1184); + int16x8_t v2248_tmp = vqrdmulhq_n_s16(v2247, 18446); + int16x8_t v2248 = vmlaq_n_s16(v2248_tmp, v2247, 2); + int16x8_t v2249 = vaddq_s16(v2246, v2248); + int16x8_t v2250 = vsubq_s16(v1199, v1217); + int16x8_t v2251 = vsubq_s16(v1224, v1234); + int16x8_t v2252_tmp = vqrdmulhq_n_s16(v2251, 18446); + int16x8_t v2252 = vmlaq_n_s16(v2252_tmp, v2251, 2); + int16x8_t v2253 = vaddq_s16(v2250, v2252); + int16x8_t v2254 = vqrdmulhq_n_s16(v2253, 21195); + int16x8_t v2255 = vaddq_s16(v2249, v2254); + int16x8_t v2256 = vsubq_s16(v1251, v1269); + int16x8_t v2257 = vsubq_s16(v1294, v1307); + int16x8_t v2258_tmp = vqrdmulhq_n_s16(v2257, 18446); + int16x8_t v2258 = vmlaq_n_s16(v2258_tmp, v2257, 2); + int16x8_t v2259 = vaddq_s16(v2256, v2258); + int16x8_t v2260 = vsubq_s16(v1316, v1326); + int16x8_t v2261 = vsubq_s16(v1333, v1344); + int16x8_t v2262_tmp = vqrdmulhq_n_s16(v2261, 18446); + int16x8_t v2262 = vmlaq_n_s16(v2262_tmp, v2261, 2); + int16x8_t v2263 = vaddq_s16(v2260, v2262); + int16x8_t v2264 = vqrdmulhq_n_s16(v2263, 21195); + int16x8_t v2265 = vaddq_s16(v2259, v2264); + int16x8_t v2266 = vqrdmulhq_n_s16(v2265, 17401); + int16x8_t v2267 = vaddq_s16(v2255, v2266); + int16x8_t v2268 = vsubq_s16(v1357, v1367); + int16x8_t v2269 = vsubq_s16(v1380, v1389); + int16x8_t v2270_tmp = vqrdmulhq_n_s16(v2269, 18446); + int16x8_t v2270 = vmlaq_n_s16(v2270_tmp, v2269, 2); + int16x8_t v2271 = vaddq_s16(v2268, v2270); + int16x8_t v2272 = vsubq_s16(v1404, v1422); + int16x8_t v2273 = vsubq_s16(v1429, v1439); + int16x8_t v2274_tmp = vqrdmulhq_n_s16(v2273, 18446); + int16x8_t v2274 = vmlaq_n_s16(v2274_tmp, v2273, 2); + int16x8_t v2275 = vaddq_s16(v2272, v2274); + int16x8_t v2276 = vqrdmulhq_n_s16(v2275, 21195); + int16x8_t v2277 = vaddq_s16(v2271, v2276); + int16x8_t v2278 = vsubq_s16(v1450, v1460); + int16x8_t v2279 = vsubq_s16(v1473, v1482); + int16x8_t v2280_tmp = vqrdmulhq_n_s16(v2279, 18446); + int16x8_t v2280 = vmlaq_n_s16(v2280_tmp, v2279, 2); + int16x8_t v2281 = vaddq_s16(v2278, v2280); + int16x8_t v2282 = vsubq_s16(v1491, v1501); + int16x8_t v2283 = vsubq_s16(v1508, v1520); + int16x8_t v2284_tmp = vqrdmulhq_n_s16(v2283, 18446); + int16x8_t v2284 = vmlaq_n_s16(v2284_tmp, v2283, 2); + int16x8_t v2285 = vaddq_s16(v2282, v2284); + int16x8_t v2286 = vqrdmulhq_n_s16(v2285, 21195); + int16x8_t v2287 = vaddq_s16(v2281, v2286); + int16x8_t v2288 = vqrdmulhq_n_s16(v2287, 17401); + int16x8_t v2289 = vaddq_s16(v2277, v2288); + int16x8_t v2290 = vqrdmulhq_n_s16(v2289, 16629); + int16x8_t v2291 = vaddq_s16(v2267, v2290); + int16x8_t v2292 = vqrdmulhq_n_s16(v2291, 16445); + int16x8_t v2293 = vaddq_s16(v2245, v2292); + int16x8_t v2294 = vqrdmulhq_n_s16(v2293, 16399); + int16x8_t v2295 = vaddq_s16(v2199, v2294); + int16x8_t v2296 = vsubq_s16(v2106, v2108); + int16x8_t v2297 = vsubq_s16(v2110, v2112); + int16x8_t v2298 = vqrdmulhq_n_s16(v2297, 25826); + int16x8_t v2299 = vaddq_s16(v2296, v2298); + int16x8_t v2300 = vsubq_s16(v2116, v2118); + int16x8_t v2301 = vsubq_s16(v2120, v2122); + int16x8_t v2302 = vqrdmulhq_n_s16(v2301, 25826); + int16x8_t v2303 = vaddq_s16(v2300, v2302); + int16x8_t v2304 = vqrdmulhq_n_s16(v2303, 18124); + int16x8_t v2305 = vaddq_s16(v2299, v2304); + int16x8_t v2306 = vsubq_s16(v2128, v2130); + int16x8_t v2307 = vsubq_s16(v2132, v2134); + int16x8_t v2308 = vqrdmulhq_n_s16(v2307, 25826); + int16x8_t v2309 = vaddq_s16(v2306, v2308); + int16x8_t v2310 = vsubq_s16(v2138, v2140); + int16x8_t v2311 = vsubq_s16(v2142, v2144); + int16x8_t v2312 = vqrdmulhq_n_s16(v2311, 25826); + int16x8_t v2313 = vaddq_s16(v2310, v2312); + int16x8_t v2314 = vqrdmulhq_n_s16(v2313, 18124); + int16x8_t v2315 = vaddq_s16(v2309, v2314); + int16x8_t v2316 = vqrdmulhq_n_s16(v2315, 16792); + int16x8_t v2317 = vaddq_s16(v2305, v2316); + int16x8_t v2318 = vsubq_s16(v2152, v2154); + int16x8_t v2319 = vsubq_s16(v2156, v2158); + int16x8_t v2320 = vqrdmulhq_n_s16(v2319, 25826); + int16x8_t v2321 = vaddq_s16(v2318, v2320); + int16x8_t v2322 = vsubq_s16(v2162, v2164); + int16x8_t v2323 = vsubq_s16(v2166, v2168); + int16x8_t v2324 = vqrdmulhq_n_s16(v2323, 25826); + int16x8_t v2325 = vaddq_s16(v2322, v2324); + int16x8_t v2326 = vqrdmulhq_n_s16(v2325, 18124); + int16x8_t v2327 = vaddq_s16(v2321, v2326); + int16x8_t v2328 = vsubq_s16(v2174, v2176); + int16x8_t v2329 = vsubq_s16(v2178, v2180); + int16x8_t v2330 = vqrdmulhq_n_s16(v2329, 25826); + int16x8_t v2331 = vaddq_s16(v2328, v2330); + int16x8_t v2332 = vsubq_s16(v2184, v2186); + int16x8_t v2333 = vsubq_s16(v2188, v2190); + int16x8_t v2334 = vqrdmulhq_n_s16(v2333, 25826); + int16x8_t v2335 = vaddq_s16(v2332, v2334); + int16x8_t v2336 = vqrdmulhq_n_s16(v2335, 18124); + int16x8_t v2337 = vaddq_s16(v2331, v2336); + int16x8_t v2338 = vqrdmulhq_n_s16(v2337, 16792); + int16x8_t v2339 = vaddq_s16(v2327, v2338); + int16x8_t v2340 = vqrdmulhq_n_s16(v2339, 16484); + int16x8_t v2341 = vaddq_s16(v2317, v2340); + int16x8_t v2342 = vsubq_s16(v2200, v2202); + int16x8_t v2343 = vsubq_s16(v2204, v2206); + int16x8_t v2344 = vqrdmulhq_n_s16(v2343, 25826); + int16x8_t v2345 = vaddq_s16(v2342, v2344); + int16x8_t v2346 = vsubq_s16(v2210, v2212); + int16x8_t v2347 = vsubq_s16(v2214, v2216); + int16x8_t v2348 = vqrdmulhq_n_s16(v2347, 25826); + int16x8_t v2349 = vaddq_s16(v2346, v2348); + int16x8_t v2350 = vqrdmulhq_n_s16(v2349, 18124); + int16x8_t v2351 = vaddq_s16(v2345, v2350); + int16x8_t v2352 = vsubq_s16(v2222, v2224); + int16x8_t v2353 = vsubq_s16(v2226, v2228); + int16x8_t v2354 = vqrdmulhq_n_s16(v2353, 25826); + int16x8_t v2355 = vaddq_s16(v2352, v2354); + int16x8_t v2356 = vsubq_s16(v2232, v2234); + int16x8_t v2357 = vsubq_s16(v2236, v2238); + int16x8_t v2358 = vqrdmulhq_n_s16(v2357, 25826); + int16x8_t v2359 = vaddq_s16(v2356, v2358); + int16x8_t v2360 = vqrdmulhq_n_s16(v2359, 18124); + int16x8_t v2361 = vaddq_s16(v2355, v2360); + int16x8_t v2362 = vqrdmulhq_n_s16(v2361, 16792); + int16x8_t v2363 = vaddq_s16(v2351, v2362); + int16x8_t v2364 = vsubq_s16(v2246, v2248); + int16x8_t v2365 = vsubq_s16(v2250, v2252); + int16x8_t v2366 = vqrdmulhq_n_s16(v2365, 25826); + int16x8_t v2367 = vaddq_s16(v2364, v2366); + int16x8_t v2368 = vsubq_s16(v2256, v2258); + int16x8_t v2369 = vsubq_s16(v2260, v2262); + int16x8_t v2370 = vqrdmulhq_n_s16(v2369, 25826); + int16x8_t v2371 = vaddq_s16(v2368, v2370); + int16x8_t v2372 = vqrdmulhq_n_s16(v2371, 18124); + int16x8_t v2373 = vaddq_s16(v2367, v2372); + int16x8_t v2374 = vsubq_s16(v2268, v2270); + int16x8_t v2375 = vsubq_s16(v2272, v2274); + int16x8_t v2376 = vqrdmulhq_n_s16(v2375, 25826); + int16x8_t v2377 = vaddq_s16(v2374, v2376); + int16x8_t v2378 = vsubq_s16(v2278, v2280); + int16x8_t v2379 = vsubq_s16(v2282, v2284); + int16x8_t v2380 = vqrdmulhq_n_s16(v2379, 25826); + int16x8_t v2381 = vaddq_s16(v2378, v2380); + int16x8_t v2382 = vqrdmulhq_n_s16(v2381, 18124); + int16x8_t v2383 = vaddq_s16(v2377, v2382); + int16x8_t v2384 = vqrdmulhq_n_s16(v2383, 16792); + int16x8_t v2385 = vaddq_s16(v2373, v2384); + int16x8_t v2386 = vqrdmulhq_n_s16(v2385, 16484); + int16x8_t v2387 = vaddq_s16(v2363, v2386); + int16x8_t v2388 = vqrdmulhq_n_s16(v2387, 16409); + int16x8_t v2389 = vaddq_s16(v2341, v2388); + int16x8_t v2390 = vsubq_s16(v1916, v1918); + int16x8_t v2391 = vsubq_s16(v1920, v1922); + int16x8_t v2392_tmp = vqrdmulhq_n_s16(v2391, 1988); + int16x8_t v2392 = vaddq_s16(v2392_tmp, v2391); + int16x8_t v2393 = vaddq_s16(v2390, v2392); + int16x8_t v2394 = vsubq_s16(v1926, v1928); + int16x8_t v2395 = vsubq_s16(v1930, v1932); + int16x8_t v2396_tmp = vqrdmulhq_n_s16(v2395, 1988); + int16x8_t v2396 = vaddq_s16(v2396_tmp, v2395); + int16x8_t v2397 = vaddq_s16(v2394, v2396); + int16x8_t v2398 = vqrdmulhq_n_s16(v2397, 19102); + int16x8_t v2399 = vaddq_s16(v2393, v2398); + int16x8_t v2400 = vsubq_s16(v1938, v1940); + int16x8_t v2401 = vsubq_s16(v1942, v1944); + int16x8_t v2402_tmp = vqrdmulhq_n_s16(v2401, 1988); + int16x8_t v2402 = vaddq_s16(v2402_tmp, v2401); + int16x8_t v2403 = vaddq_s16(v2400, v2402); + int16x8_t v2404 = vsubq_s16(v1948, v1950); + int16x8_t v2405 = vsubq_s16(v1952, v1954); + int16x8_t v2406_tmp = vqrdmulhq_n_s16(v2405, 1988); + int16x8_t v2406 = vaddq_s16(v2406_tmp, v2405); + int16x8_t v2407 = vaddq_s16(v2404, v2406); + int16x8_t v2408 = vqrdmulhq_n_s16(v2407, 19102); + int16x8_t v2409 = vaddq_s16(v2403, v2408); + int16x8_t v2410 = vqrdmulhq_n_s16(v2409, 17000); + int16x8_t v2411 = vaddq_s16(v2399, v2410); + int16x8_t v2412 = vsubq_s16(v1962, v1964); + int16x8_t v2413 = vsubq_s16(v1966, v1968); + int16x8_t v2414_tmp = vqrdmulhq_n_s16(v2413, 1988); + int16x8_t v2414 = vaddq_s16(v2414_tmp, v2413); + int16x8_t v2415 = vaddq_s16(v2412, v2414); + int16x8_t v2416 = vsubq_s16(v1972, v1974); + int16x8_t v2417 = vsubq_s16(v1976, v1978); + int16x8_t v2418_tmp = vqrdmulhq_n_s16(v2417, 1988); + int16x8_t v2418 = vaddq_s16(v2418_tmp, v2417); + int16x8_t v2419 = vaddq_s16(v2416, v2418); + int16x8_t v2420 = vqrdmulhq_n_s16(v2419, 19102); + int16x8_t v2421 = vaddq_s16(v2415, v2420); + int16x8_t v2422 = vsubq_s16(v1984, v1986); + int16x8_t v2423 = vsubq_s16(v1988, v1990); + int16x8_t v2424_tmp = vqrdmulhq_n_s16(v2423, 1988); + int16x8_t v2424 = vaddq_s16(v2424_tmp, v2423); + int16x8_t v2425 = vaddq_s16(v2422, v2424); + int16x8_t v2426 = vsubq_s16(v1994, v1996); + int16x8_t v2427 = vsubq_s16(v1998, v2000); + int16x8_t v2428_tmp = vqrdmulhq_n_s16(v2427, 1988); + int16x8_t v2428 = vaddq_s16(v2428_tmp, v2427); + int16x8_t v2429 = vaddq_s16(v2426, v2428); + int16x8_t v2430 = vqrdmulhq_n_s16(v2429, 19102); + int16x8_t v2431 = vaddq_s16(v2425, v2430); + int16x8_t v2432 = vqrdmulhq_n_s16(v2431, 17000); + int16x8_t v2433 = vaddq_s16(v2421, v2432); + int16x8_t v2434 = vqrdmulhq_n_s16(v2433, 16534); + int16x8_t v2435 = vaddq_s16(v2411, v2434); + int16x8_t v2436 = vsubq_s16(v2010, v2012); + int16x8_t v2437 = vsubq_s16(v2014, v2016); + int16x8_t v2438_tmp = vqrdmulhq_n_s16(v2437, 1988); + int16x8_t v2438 = vaddq_s16(v2438_tmp, v2437); + int16x8_t v2439 = vaddq_s16(v2436, v2438); + int16x8_t v2440 = vsubq_s16(v2020, v2022); + int16x8_t v2441 = vsubq_s16(v2024, v2026); + int16x8_t v2442_tmp = vqrdmulhq_n_s16(v2441, 1988); + int16x8_t v2442 = vaddq_s16(v2442_tmp, v2441); + int16x8_t v2443 = vaddq_s16(v2440, v2442); + int16x8_t v2444 = vqrdmulhq_n_s16(v2443, 19102); + int16x8_t v2445 = vaddq_s16(v2439, v2444); + int16x8_t v2446 = vsubq_s16(v2032, v2034); + int16x8_t v2447 = vsubq_s16(v2036, v2038); + int16x8_t v2448_tmp = vqrdmulhq_n_s16(v2447, 1988); + int16x8_t v2448 = vaddq_s16(v2448_tmp, v2447); + int16x8_t v2449 = vaddq_s16(v2446, v2448); + int16x8_t v2450 = vsubq_s16(v2042, v2044); + int16x8_t v2451 = vsubq_s16(v2046, v2048); + int16x8_t v2452_tmp = vqrdmulhq_n_s16(v2451, 1988); + int16x8_t v2452 = vaddq_s16(v2452_tmp, v2451); + int16x8_t v2453 = vaddq_s16(v2450, v2452); + int16x8_t v2454 = vqrdmulhq_n_s16(v2453, 19102); + int16x8_t v2455 = vaddq_s16(v2449, v2454); + int16x8_t v2456 = vqrdmulhq_n_s16(v2455, 17000); + int16x8_t v2457 = vaddq_s16(v2445, v2456); + int16x8_t v2458 = vsubq_s16(v2056, v2058); + int16x8_t v2459 = vsubq_s16(v2060, v2062); + int16x8_t v2460_tmp = vqrdmulhq_n_s16(v2459, 1988); + int16x8_t v2460 = vaddq_s16(v2460_tmp, v2459); + int16x8_t v2461 = vaddq_s16(v2458, v2460); + int16x8_t v2462 = vsubq_s16(v2066, v2068); + int16x8_t v2463 = vqrdmulhq_n_s16(v2072, 29490); + int16x8_t v2464 = vsubq_s16(v2070, v2463); + int16x8_t v2465_tmp = vqrdmulhq_n_s16(v2464, 1988); + int16x8_t v2465 = vaddq_s16(v2465_tmp, v2464); + int16x8_t v2466 = vaddq_s16(v2462, v2465); + int16x8_t v2467 = vqrdmulhq_n_s16(v2466, 19102); + int16x8_t v2468 = vaddq_s16(v2461, v2467); + int16x8_t v2469 = vsubq_s16(v2078, v2080); + int16x8_t v2470 = vsubq_s16(v2082, v2084); + int16x8_t v2471_tmp = vqrdmulhq_n_s16(v2470, 1988); + int16x8_t v2471 = vaddq_s16(v2471_tmp, v2470); + int16x8_t v2472 = vaddq_s16(v2469, v2471); + int16x8_t v2473 = vsubq_s16(v2088, v2090); + int16x8_t v2474 = vsubq_s16(v2092, v2094); + int16x8_t v2475_tmp = vqrdmulhq_n_s16(v2474, 1988); + int16x8_t v2475 = vaddq_s16(v2475_tmp, v2474); + int16x8_t v2476 = vaddq_s16(v2473, v2475); + int16x8_t v2477 = vqrdmulhq_n_s16(v2476, 19102); + int16x8_t v2478 = vaddq_s16(v2472, v2477); + int16x8_t v2479 = vqrdmulhq_n_s16(v2478, 17000); + int16x8_t v2480 = vaddq_s16(v2468, v2479); + int16x8_t v2481 = vqrdmulhq_n_s16(v2480, 16534); + int16x8_t v2482 = vaddq_s16(v2457, v2481); + int16x8_t v2483 = vqrdmulhq_n_s16(v2482, 16421); + int16x8_t v2484 = vaddq_s16(v2435, v2483); + int16x8_t v2485 = vsubq_s16(v1537, v1542); + int16x8_t v2486 = vsubq_s16(v1547, v1552); + int16x8_t v2487_tmp = vqrdmulhq_n_s16(v2486, 23673); + int16x8_t v2487 = vaddq_s16(v2487_tmp, v2486); + int16x8_t v2488 = vaddq_s16(v2485, v2487); + int16x8_t v2489 = vsubq_s16(v1559, v1564); + int16x8_t v2490 = vsubq_s16(v1569, v1574); + int16x8_t v2491_tmp = vqrdmulhq_n_s16(v2490, 23673); + int16x8_t v2491 = vaddq_s16(v2491_tmp, v2490); + int16x8_t v2492 = vaddq_s16(v2489, v2491); + int16x8_t v2493 = vqrdmulhq_n_s16(v2492, 20398); + int16x8_t v2494 = vaddq_s16(v2488, v2493); + int16x8_t v2495 = vsubq_s16(v1583, v1588); + int16x8_t v2496 = vsubq_s16(v1593, v1598); + int16x8_t v2497_tmp = vqrdmulhq_n_s16(v2496, 23673); + int16x8_t v2497 = vaddq_s16(v2497_tmp, v2496); + int16x8_t v2498 = vaddq_s16(v2495, v2497); + int16x8_t v2499 = vsubq_s16(v1605, v1610); + int16x8_t v2500 = vsubq_s16(v1615, v1620); + int16x8_t v2501_tmp = vqrdmulhq_n_s16(v2500, 23673); + int16x8_t v2501 = vaddq_s16(v2501_tmp, v2500); + int16x8_t v2502 = vaddq_s16(v2499, v2501); + int16x8_t v2503 = vqrdmulhq_n_s16(v2502, 20398); + int16x8_t v2504 = vaddq_s16(v2498, v2503); + int16x8_t v2505 = vqrdmulhq_n_s16(v2504, 17255); + int16x8_t v2506 = vaddq_s16(v2494, v2505); + int16x8_t v2507 = vsubq_s16(v1631, v1636); + int16x8_t v2508 = vsubq_s16(v1641, v1646); + int16x8_t v2509_tmp = vqrdmulhq_n_s16(v2508, 23673); + int16x8_t v2509 = vaddq_s16(v2509_tmp, v2508); + int16x8_t v2510 = vaddq_s16(v2507, v2509); + int16x8_t v2511 = vsubq_s16(v1653, v1658); + int16x8_t v2512 = vsubq_s16(v1663, v1668); + int16x8_t v2513_tmp = vqrdmulhq_n_s16(v2512, 23673); + int16x8_t v2513 = vaddq_s16(v2513_tmp, v2512); + int16x8_t v2514 = vaddq_s16(v2511, v2513); + int16x8_t v2515 = vqrdmulhq_n_s16(v2514, 20398); + int16x8_t v2516 = vaddq_s16(v2510, v2515); + int16x8_t v2517 = vsubq_s16(v1677, v1682); + int16x8_t v2518 = vsubq_s16(v1687, v1692); + int16x8_t v2519_tmp = vqrdmulhq_n_s16(v2518, 23673); + int16x8_t v2519 = vaddq_s16(v2519_tmp, v2518); + int16x8_t v2520 = vaddq_s16(v2517, v2519); + int16x8_t v2521 = vsubq_s16(v1699, v1704); + int16x8_t v2522 = vsubq_s16(v1709, v1714); + int16x8_t v2523_tmp = vqrdmulhq_n_s16(v2522, 23673); + int16x8_t v2523 = vaddq_s16(v2523_tmp, v2522); + int16x8_t v2524 = vaddq_s16(v2521, v2523); + int16x8_t v2525 = vqrdmulhq_n_s16(v2524, 20398); + int16x8_t v2526 = vaddq_s16(v2520, v2525); + int16x8_t v2527 = vqrdmulhq_n_s16(v2526, 17255); + int16x8_t v2528 = vaddq_s16(v2516, v2527); + int16x8_t v2529 = vqrdmulhq_n_s16(v2528, 16595); + int16x8_t v2530 = vaddq_s16(v2506, v2529); + int16x8_t v2531 = vsubq_s16(v1727, v1732); + int16x8_t v2532 = vsubq_s16(v1737, v1742); + int16x8_t v2533_tmp = vqrdmulhq_n_s16(v2532, 23673); + int16x8_t v2533 = vaddq_s16(v2533_tmp, v2532); + int16x8_t v2534 = vaddq_s16(v2531, v2533); + int16x8_t v2535 = vsubq_s16(v1749, v1754); + int16x8_t v2536 = vsubq_s16(v1759, v1764); + int16x8_t v2537_tmp = vqrdmulhq_n_s16(v2536, 23673); + int16x8_t v2537 = vaddq_s16(v2537_tmp, v2536); + int16x8_t v2538 = vaddq_s16(v2535, v2537); + int16x8_t v2539 = vqrdmulhq_n_s16(v2538, 20398); + int16x8_t v2540 = vaddq_s16(v2534, v2539); + int16x8_t v2541 = vsubq_s16(v1773, v1778); + int16x8_t v2542 = vsubq_s16(v1783, v1788); + int16x8_t v2543_tmp = vqrdmulhq_n_s16(v2542, 23673); + int16x8_t v2543 = vaddq_s16(v2543_tmp, v2542); + int16x8_t v2544 = vaddq_s16(v2541, v2543); + int16x8_t v2545 = vsubq_s16(v1795, v1800); + int16x8_t v2546 = vsubq_s16(v1805, v1810); + int16x8_t v2547_tmp = vqrdmulhq_n_s16(v2546, 23673); + int16x8_t v2547 = vaddq_s16(v2547_tmp, v2546); + int16x8_t v2548 = vaddq_s16(v2545, v2547); + int16x8_t v2549 = vqrdmulhq_n_s16(v2548, 20398); + int16x8_t v2550 = vaddq_s16(v2544, v2549); + int16x8_t v2551 = vqrdmulhq_n_s16(v2550, 17255); + int16x8_t v2552 = vaddq_s16(v2540, v2551); + int16x8_t v2553 = vsubq_s16(v1821, v1826); + int16x8_t v2554 = vsubq_s16(v1831, v1836); + int16x8_t v2555_tmp = vqrdmulhq_n_s16(v2554, 23673); + int16x8_t v2555 = vaddq_s16(v2555_tmp, v2554); + int16x8_t v2556 = vaddq_s16(v2553, v2555); + int16x8_t v2557 = vsubq_s16(v1843, v1848); + int16x8_t v2558 = vsubq_s16(v1853, v1858); + int16x8_t v2559_tmp = vqrdmulhq_n_s16(v2558, 23673); + int16x8_t v2559 = vaddq_s16(v2559_tmp, v2558); + int16x8_t v2560 = vaddq_s16(v2557, v2559); + int16x8_t v2561 = vqrdmulhq_n_s16(v2560, 20398); + int16x8_t v2562 = vaddq_s16(v2556, v2561); + int16x8_t v2563 = vsubq_s16(v1867, v1872); + int16x8_t v2564 = vsubq_s16(v1877, v1882); + int16x8_t v2565_tmp = vqrdmulhq_n_s16(v2564, 23673); + int16x8_t v2565 = vaddq_s16(v2565_tmp, v2564); + int16x8_t v2566 = vaddq_s16(v2563, v2565); + int16x8_t v2567 = vsubq_s16(v1889, v1894); + int16x8_t v2568 = vsubq_s16(v1899, v1904); + int16x8_t v2569_tmp = vqrdmulhq_n_s16(v2568, 23673); + int16x8_t v2569 = vaddq_s16(v2569_tmp, v2568); + int16x8_t v2570 = vaddq_s16(v2567, v2569); + int16x8_t v2571 = vqrdmulhq_n_s16(v2570, 20398); + int16x8_t v2572 = vaddq_s16(v2566, v2571); + int16x8_t v2573 = vqrdmulhq_n_s16(v2572, 17255); + int16x8_t v2574 = vaddq_s16(v2562, v2573); + int16x8_t v2575 = vqrdmulhq_n_s16(v2574, 16595); + int16x8_t v2576 = vaddq_s16(v2552, v2575); + int16x8_t v2577 = vqrdmulhq_n_s16(v2576, 16436); + int16x8_t v2578 = vaddq_s16(v2530, v2577); + int16x8_t v2579 = vsubq_s16(v9, v24); + int16x8_t v2580 = vsubq_s16(v42, v58); + int16x8_t v2581_tmp = vqrdmulhq_n_s16(v2580, 3314); + int16x8_t v2581 = vmlaq_n_s16(v2581_tmp, v2580, 5); + int16x8_t v2582 = vaddq_s16(v2579, v2581); + int16x8_t v2583 = vsubq_s16(v78, v101); + int16x8_t v2584 = vsubq_s16(v119, v136); + int16x8_t v2585_tmp = vqrdmulhq_n_s16(v2584, 3314); + int16x8_t v2585 = vmlaq_n_s16(v2585_tmp, v2584, 5); + int16x8_t v2586 = vaddq_s16(v2583, v2585); + int16x8_t v2587 = vqrdmulhq_n_s16(v2586, 22112); + int16x8_t v2588 = vaddq_s16(v2582, v2587); + int16x8_t v2589 = vsubq_s16(v158, v181); + int16x8_t v2590 = vsubq_s16(v213, v231); + int16x8_t v2591_tmp = vqrdmulhq_n_s16(v2590, 3314); + int16x8_t v2591 = vmlaq_n_s16(v2591_tmp, v2590, 5); + int16x8_t v2592 = vaddq_s16(v2589, v2591); + int16x8_t v2593 = vsubq_s16(v251, v274); + int16x8_t v2594 = vsubq_s16(v292, v310); + int16x8_t v2595_tmp = vqrdmulhq_n_s16(v2594, 3314); + int16x8_t v2595 = vmlaq_n_s16(v2595_tmp, v2594, 5); + int16x8_t v2596 = vaddq_s16(v2593, v2595); + int16x8_t v2597 = vqrdmulhq_n_s16(v2596, 22112); + int16x8_t v2598 = vaddq_s16(v2592, v2597); + int16x8_t v2599 = vqrdmulhq_n_s16(v2598, 17561); + int16x8_t v2600 = vaddq_s16(v2588, v2599); + int16x8_t v2601 = vsubq_s16(v334, v357); + int16x8_t v2602 = vsubq_s16(v389, v407); + int16x8_t v2603_tmp = vqrdmulhq_n_s16(v2602, 3314); + int16x8_t v2603 = vmlaq_n_s16(v2603_tmp, v2602, 5); + int16x8_t v2604 = vaddq_s16(v2601, v2603); + int16x8_t v2605 = vsubq_s16(v441, v480); + int16x8_t v2606 = vsubq_s16(v498, v517); + int16x8_t v2607_tmp = vqrdmulhq_n_s16(v2606, 3314); + int16x8_t v2607 = vmlaq_n_s16(v2607_tmp, v2606, 5); + int16x8_t v2608 = vaddq_s16(v2605, v2607); + int16x8_t v2609 = vqrdmulhq_n_s16(v2608, 22112); + int16x8_t v2610 = vaddq_s16(v2604, v2609); + int16x8_t v2611 = vsubq_s16(v539, v562); + int16x8_t v2612 = vsubq_s16(v594, v612); + int16x8_t v2613_tmp = vqrdmulhq_n_s16(v2612, 3314); + int16x8_t v2613 = vmlaq_n_s16(v2613_tmp, v2612, 5); + int16x8_t v2614 = vaddq_s16(v2611, v2613); + int16x8_t v2615 = vsubq_s16(v632, v655); + int16x8_t v2616 = vsubq_s16(v673, v692); + int16x8_t v2617_tmp = vqrdmulhq_n_s16(v2616, 3314); + int16x8_t v2617 = vmlaq_n_s16(v2617_tmp, v2616, 5); + int16x8_t v2618 = vaddq_s16(v2615, v2617); + int16x8_t v2619 = vqrdmulhq_n_s16(v2618, 22112); + int16x8_t v2620 = vaddq_s16(v2614, v2619); + int16x8_t v2621 = vqrdmulhq_n_s16(v2620, 17561); + int16x8_t v2622 = vaddq_s16(v2610, v2621); + int16x8_t v2623 = vqrdmulhq_n_s16(v2622, 16666); + int16x8_t v2624 = vaddq_s16(v2600, v2623); + int16x8_t v2625 = vsubq_s16(v718, v741); + int16x8_t v2626 = vsubq_s16(v773, v791); + int16x8_t v2627_tmp = vqrdmulhq_n_s16(v2626, 3314); + int16x8_t v2627 = vmlaq_n_s16(v2627_tmp, v2626, 5); + int16x8_t v2628 = vaddq_s16(v2625, v2627); + int16x8_t v2629 = vsubq_s16(v825, v864); + int16x8_t v2630 = vsubq_s16(v882, v901); + int16x8_t v2631_tmp = vqrdmulhq_n_s16(v2630, 3314); + int16x8_t v2631 = vmlaq_n_s16(v2631_tmp, v2630, 5); + int16x8_t v2632 = vaddq_s16(v2629, v2631); + int16x8_t v2633 = vqrdmulhq_n_s16(v2632, 22112); + int16x8_t v2634 = vaddq_s16(v2628, v2633); + int16x8_t v2635 = vsubq_s16(v937, v976); + int16x8_t v2636 = vsubq_s16(v1036, v1058); + int16x8_t v2637_tmp = vqrdmulhq_n_s16(v2636, 3314); + int16x8_t v2637 = vmlaq_n_s16(v2637_tmp, v2636, 5); + int16x8_t v2638 = vaddq_s16(v2635, v2637); + int16x8_t v2639 = vsubq_s16(v1078, v1101); + int16x8_t v2640 = vsubq_s16(v1119, v1139); + int16x8_t v2641_tmp = vqrdmulhq_n_s16(v2640, 3314); + int16x8_t v2641 = vmlaq_n_s16(v2641_tmp, v2640, 5); + int16x8_t v2642 = vaddq_s16(v2639, v2641); + int16x8_t v2643 = vqrdmulhq_n_s16(v2642, 22112); + int16x8_t v2644 = vaddq_s16(v2638, v2643); + int16x8_t v2645 = vqrdmulhq_n_s16(v2644, 17561); + int16x8_t v2646 = vaddq_s16(v2634, v2645); + int16x8_t v2647 = vsubq_s16(v1163, v1186); + int16x8_t v2648 = vsubq_s16(v1218, v1236); + int16x8_t v2649_tmp = vqrdmulhq_n_s16(v2648, 3314); + int16x8_t v2649 = vmlaq_n_s16(v2649_tmp, v2648, 5); + int16x8_t v2650 = vaddq_s16(v2647, v2649); + int16x8_t v2651 = vsubq_s16(v1270, v1309); + int16x8_t v2652 = vsubq_s16(v1327, v1346); + int16x8_t v2653_tmp = vqrdmulhq_n_s16(v2652, 3314); + int16x8_t v2653 = vmlaq_n_s16(v2653_tmp, v2652, 5); + int16x8_t v2654 = vaddq_s16(v2651, v2653); + int16x8_t v2655 = vqrdmulhq_n_s16(v2654, 22112); + int16x8_t v2656 = vaddq_s16(v2650, v2655); + int16x8_t v2657 = vsubq_s16(v1368, v1391); + int16x8_t v2658 = vsubq_s16(v1423, v1441); + int16x8_t v2659_tmp = vqrdmulhq_n_s16(v2658, 3314); + int16x8_t v2659 = vmlaq_n_s16(v2659_tmp, v2658, 5); + int16x8_t v2660 = vaddq_s16(v2657, v2659); + int16x8_t v2661 = vsubq_s16(v1461, v1484); + int16x8_t v2662 = vsubq_s16(v1502, v1522); + int16x8_t v2663_tmp = vqrdmulhq_n_s16(v2662, 3314); + int16x8_t v2663 = vmlaq_n_s16(v2663_tmp, v2662, 5); + int16x8_t v2664 = vaddq_s16(v2661, v2663); + int16x8_t v2665 = vqrdmulhq_n_s16(v2664, 22112); + int16x8_t v2666 = vaddq_s16(v2660, v2665); + int16x8_t v2667 = vqrdmulhq_n_s16(v2666, 17561); + int16x8_t v2668 = vaddq_s16(v2656, v2667); + int16x8_t v2669 = vqrdmulhq_n_s16(v2668, 16666); + int16x8_t v2670 = vaddq_s16(v2646, v2669); + int16x8_t v2671 = vqrdmulhq_n_s16(v2670, 16454); + int16x8_t v2672 = vaddq_s16(v2624, v2671); + int16x8_t v2673 = vsubq_s16(v2579, v2581); + int16x8_t v2674 = vsubq_s16(v2583, v2585); + int16x8_t v2675 = vqrdmulhq_n_s16(v2674, 24397); + int16x8_t v2676 = vaddq_s16(v2673, v2675); + int16x8_t v2677 = vsubq_s16(v2589, v2591); + int16x8_t v2678 = vsubq_s16(v2593, v2595); + int16x8_t v2679 = vqrdmulhq_n_s16(v2678, 24397); + int16x8_t v2680 = vaddq_s16(v2677, v2679); + int16x8_t v2681 = vqrdmulhq_n_s16(v2680, 17921); + int16x8_t v2682 = vaddq_s16(v2676, v2681); + int16x8_t v2683 = vsubq_s16(v2601, v2603); + int16x8_t v2684 = vsubq_s16(v2605, v2607); + int16x8_t v2685 = vqrdmulhq_n_s16(v2684, 24397); + int16x8_t v2686 = vaddq_s16(v2683, v2685); + int16x8_t v2687 = vsubq_s16(v2611, v2613); + int16x8_t v2688 = vsubq_s16(v2615, v2617); + int16x8_t v2689 = vqrdmulhq_n_s16(v2688, 24397); + int16x8_t v2690 = vaddq_s16(v2687, v2689); + int16x8_t v2691 = vqrdmulhq_n_s16(v2690, 17921); + int16x8_t v2692 = vaddq_s16(v2686, v2691); + int16x8_t v2693 = vqrdmulhq_n_s16(v2692, 16747); + int16x8_t v2694 = vaddq_s16(v2682, v2693); + int16x8_t v2695 = vsubq_s16(v2625, v2627); + int16x8_t v2696 = vsubq_s16(v2629, v2631); + int16x8_t v2697 = vqrdmulhq_n_s16(v2696, 24397); + int16x8_t v2698 = vaddq_s16(v2695, v2697); + int16x8_t v2699 = vsubq_s16(v2635, v2637); + int16x8_t v2700 = vsubq_s16(v2639, v2641); + int16x8_t v2701 = vqrdmulhq_n_s16(v2700, 24397); + int16x8_t v2702 = vaddq_s16(v2699, v2701); + int16x8_t v2703 = vqrdmulhq_n_s16(v2702, 17921); + int16x8_t v2704 = vaddq_s16(v2698, v2703); + int16x8_t v2705 = vsubq_s16(v2647, v2649); + int16x8_t v2706 = vsubq_s16(v2651, v2653); + int16x8_t v2707 = vqrdmulhq_n_s16(v2706, 24397); + int16x8_t v2708 = vaddq_s16(v2705, v2707); + int16x8_t v2709 = vsubq_s16(v2657, v2659); + int16x8_t v2710 = vsubq_s16(v2661, v2663); + int16x8_t v2711 = vqrdmulhq_n_s16(v2710, 24397); + int16x8_t v2712 = vaddq_s16(v2709, v2711); + int16x8_t v2713 = vqrdmulhq_n_s16(v2712, 17921); + int16x8_t v2714 = vaddq_s16(v2708, v2713); + int16x8_t v2715 = vqrdmulhq_n_s16(v2714, 16747); + int16x8_t v2716 = vaddq_s16(v2704, v2715); + int16x8_t v2717 = vqrdmulhq_n_s16(v2716, 16474); + int16x8_t v2718 = vaddq_s16(v2694, v2717); + int16x8_t v2719 = vsubq_s16(v2485, v2487); + int16x8_t v2720 = vsubq_s16(v2489, v2491); + int16x8_t v2721 = vqrdmulhq_n_s16(v2720, 27504); + int16x8_t v2722 = vaddq_s16(v2719, v2721); + int16x8_t v2723 = vsubq_s16(v2495, v2497); + int16x8_t v2724 = vsubq_s16(v2499, v2501); + int16x8_t v2725 = vqrdmulhq_n_s16(v2724, 27504); + int16x8_t v2726 = vaddq_s16(v2723, v2725); + int16x8_t v2727 = vqrdmulhq_n_s16(v2726, 18343); + int16x8_t v2728 = vaddq_s16(v2722, v2727); + int16x8_t v2729 = vsubq_s16(v2507, v2509); + int16x8_t v2730 = vsubq_s16(v2511, v2513); + int16x8_t v2731 = vqrdmulhq_n_s16(v2730, 27504); + int16x8_t v2732 = vaddq_s16(v2729, v2731); + int16x8_t v2733 = vsubq_s16(v2517, v2519); + int16x8_t v2734 = vsubq_s16(v2521, v2523); + int16x8_t v2735 = vqrdmulhq_n_s16(v2734, 27504); + int16x8_t v2736 = vaddq_s16(v2733, v2735); + int16x8_t v2737 = vqrdmulhq_n_s16(v2736, 18343); + int16x8_t v2738 = vaddq_s16(v2732, v2737); + int16x8_t v2739 = vqrdmulhq_n_s16(v2738, 16840); + int16x8_t v2740 = vaddq_s16(v2728, v2739); + int16x8_t v2741 = vsubq_s16(v2531, v2533); + int16x8_t v2742 = vsubq_s16(v2535, v2537); + int16x8_t v2743 = vqrdmulhq_n_s16(v2742, 27504); + int16x8_t v2744 = vaddq_s16(v2741, v2743); + int16x8_t v2745 = vsubq_s16(v2541, v2543); + int16x8_t v2746 = vsubq_s16(v2545, v2547); + int16x8_t v2747 = vqrdmulhq_n_s16(v2746, 27504); + int16x8_t v2748 = vaddq_s16(v2745, v2747); + int16x8_t v2749 = vqrdmulhq_n_s16(v2748, 18343); + int16x8_t v2750 = vaddq_s16(v2744, v2749); + int16x8_t v2751 = vsubq_s16(v2553, v2555); + int16x8_t v2752 = vsubq_s16(v2557, v2559); + int16x8_t v2753 = vqrdmulhq_n_s16(v2752, 27504); + int16x8_t v2754 = vaddq_s16(v2751, v2753); + int16x8_t v2755 = vsubq_s16(v2563, v2565); + int16x8_t v2756 = vsubq_s16(v2567, v2569); + int16x8_t v2757 = vqrdmulhq_n_s16(v2756, 27504); + int16x8_t v2758 = vaddq_s16(v2755, v2757); + int16x8_t v2759 = vqrdmulhq_n_s16(v2758, 18343); + int16x8_t v2760 = vaddq_s16(v2754, v2759); + int16x8_t v2761 = vqrdmulhq_n_s16(v2760, 16840); + int16x8_t v2762 = vaddq_s16(v2750, v2761); + int16x8_t v2763 = vqrdmulhq_n_s16(v2762, 16496); + int16x8_t v2764 = vaddq_s16(v2740, v2763); + int16x8_t v2765 = vsubq_s16(v2390, v2392); + int16x8_t v2766 = vsubq_s16(v2394, v2396); + int16x8_t v2767 = vqrdmulhq_n_s16(v2766, 31869); + int16x8_t v2768 = vaddq_s16(v2765, v2767); + int16x8_t v2769 = vsubq_s16(v2400, v2402); + int16x8_t v2770 = vsubq_s16(v2404, v2406); + int16x8_t v2771 = vqrdmulhq_n_s16(v2770, 31869); + int16x8_t v2772 = vaddq_s16(v2769, v2771); + int16x8_t v2773 = vqrdmulhq_n_s16(v2772, 18830); + int16x8_t v2774 = vaddq_s16(v2768, v2773); + int16x8_t v2775 = vsubq_s16(v2412, v2414); + int16x8_t v2776 = vsubq_s16(v2416, v2418); + int16x8_t v2777 = vqrdmulhq_n_s16(v2776, 31869); + int16x8_t v2778 = vaddq_s16(v2775, v2777); + int16x8_t v2779 = vsubq_s16(v2422, v2424); + int16x8_t v2780 = vsubq_s16(v2426, v2428); + int16x8_t v2781 = vqrdmulhq_n_s16(v2780, 31869); + int16x8_t v2782 = vaddq_s16(v2779, v2781); + int16x8_t v2783 = vqrdmulhq_n_s16(v2782, 18830); + int16x8_t v2784 = vaddq_s16(v2778, v2783); + int16x8_t v2785 = vqrdmulhq_n_s16(v2784, 16944); + int16x8_t v2786 = vaddq_s16(v2774, v2785); + int16x8_t v2787 = vsubq_s16(v2436, v2438); + int16x8_t v2788 = vsubq_s16(v2440, v2442); + int16x8_t v2789 = vqrdmulhq_n_s16(v2788, 31869); + int16x8_t v2790 = vaddq_s16(v2787, v2789); + int16x8_t v2791 = vsubq_s16(v2446, v2448); + int16x8_t v2792 = vsubq_s16(v2450, v2452); + int16x8_t v2793 = vqrdmulhq_n_s16(v2792, 31869); + int16x8_t v2794 = vaddq_s16(v2791, v2793); + int16x8_t v2795 = vqrdmulhq_n_s16(v2794, 18830); + int16x8_t v2796 = vaddq_s16(v2790, v2795); + int16x8_t v2797 = vsubq_s16(v2458, v2460); + int16x8_t v2798 = vsubq_s16(v2462, v2465); + int16x8_t v2799 = vqrdmulhq_n_s16(v2798, 31869); + int16x8_t v2800 = vaddq_s16(v2797, v2799); + int16x8_t v2801 = vsubq_s16(v2469, v2471); + int16x8_t v2802 = vsubq_s16(v2473, v2475); + int16x8_t v2803 = vqrdmulhq_n_s16(v2802, 31869); + int16x8_t v2804 = vaddq_s16(v2801, v2803); + int16x8_t v2805 = vqrdmulhq_n_s16(v2804, 18830); + int16x8_t v2806 = vaddq_s16(v2800, v2805); + int16x8_t v2807 = vqrdmulhq_n_s16(v2806, 16944); + int16x8_t v2808 = vaddq_s16(v2796, v2807); + int16x8_t v2809 = vqrdmulhq_n_s16(v2808, 16521); + int16x8_t v2810 = vaddq_s16(v2786, v2809); + int16x8_t v2811 = vsubq_s16(v2296, v2298); + int16x8_t v2812 = vsubq_s16(v2300, v2302); + int16x8_t v2813_tmp = vqrdmulhq_n_s16(v2812, 5552); + int16x8_t v2813 = vaddq_s16(v2813_tmp, v2812); + int16x8_t v2814 = vaddq_s16(v2811, v2813); + int16x8_t v2815 = vsubq_s16(v2306, v2308); + int16x8_t v2816 = vsubq_s16(v2310, v2312); + int16x8_t v2817_tmp = vqrdmulhq_n_s16(v2816, 5552); + int16x8_t v2817 = vaddq_s16(v2817_tmp, v2816); + int16x8_t v2818 = vaddq_s16(v2815, v2817); + int16x8_t v2819 = vqrdmulhq_n_s16(v2818, 19393); + int16x8_t v2820 = vaddq_s16(v2814, v2819); + int16x8_t v2821 = vsubq_s16(v2318, v2320); + int16x8_t v2822 = vsubq_s16(v2322, v2324); + int16x8_t v2823_tmp = vqrdmulhq_n_s16(v2822, 5552); + int16x8_t v2823 = vaddq_s16(v2823_tmp, v2822); + int16x8_t v2824 = vaddq_s16(v2821, v2823); + int16x8_t v2825 = vsubq_s16(v2328, v2330); + int16x8_t v2826 = vsubq_s16(v2332, v2334); + int16x8_t v2827_tmp = vqrdmulhq_n_s16(v2826, 5552); + int16x8_t v2827 = vaddq_s16(v2827_tmp, v2826); + int16x8_t v2828 = vaddq_s16(v2825, v2827); + int16x8_t v2829 = vqrdmulhq_n_s16(v2828, 19393); + int16x8_t v2830 = vaddq_s16(v2824, v2829); + int16x8_t v2831 = vqrdmulhq_n_s16(v2830, 17059); + int16x8_t v2832 = vaddq_s16(v2820, v2831); + int16x8_t v2833 = vsubq_s16(v2342, v2344); + int16x8_t v2834 = vsubq_s16(v2346, v2348); + int16x8_t v2835_tmp = vqrdmulhq_n_s16(v2834, 5552); + int16x8_t v2835 = vaddq_s16(v2835_tmp, v2834); + int16x8_t v2836 = vaddq_s16(v2833, v2835); + int16x8_t v2837 = vsubq_s16(v2352, v2354); + int16x8_t v2838 = vsubq_s16(v2356, v2358); + int16x8_t v2839_tmp = vqrdmulhq_n_s16(v2838, 5552); + int16x8_t v2839 = vaddq_s16(v2839_tmp, v2838); + int16x8_t v2840 = vaddq_s16(v2837, v2839); + int16x8_t v2841 = vqrdmulhq_n_s16(v2840, 19393); + int16x8_t v2842 = vaddq_s16(v2836, v2841); + int16x8_t v2843 = vsubq_s16(v2364, v2366); + int16x8_t v2844 = vsubq_s16(v2368, v2370); + int16x8_t v2845_tmp = vqrdmulhq_n_s16(v2844, 5552); + int16x8_t v2845 = vaddq_s16(v2845_tmp, v2844); + int16x8_t v2846 = vaddq_s16(v2843, v2845); + int16x8_t v2847 = vsubq_s16(v2374, v2376); + int16x8_t v2848 = vsubq_s16(v2378, v2380); + int16x8_t v2849_tmp = vqrdmulhq_n_s16(v2848, 5552); + int16x8_t v2849 = vaddq_s16(v2849_tmp, v2848); + int16x8_t v2850 = vaddq_s16(v2847, v2849); + int16x8_t v2851 = vqrdmulhq_n_s16(v2850, 19393); + int16x8_t v2852 = vaddq_s16(v2846, v2851); + int16x8_t v2853 = vqrdmulhq_n_s16(v2852, 17059); + int16x8_t v2854 = vaddq_s16(v2842, v2853); + int16x8_t v2855 = vqrdmulhq_n_s16(v2854, 16549); + int16x8_t v2856 = vaddq_s16(v2832, v2855); + int16x8_t v2857 = vsubq_s16(v2109, v2114); + int16x8_t v2858 = vsubq_s16(v2119, v2124); + int16x8_t v2859_tmp = vqrdmulhq_n_s16(v2858, 15865); + int16x8_t v2859 = vaddq_s16(v2859_tmp, v2858); + int16x8_t v2860 = vaddq_s16(v2857, v2859); + int16x8_t v2861 = vsubq_s16(v2131, v2136); + int16x8_t v2862 = vsubq_s16(v2141, v2146); + int16x8_t v2863_tmp = vqrdmulhq_n_s16(v2862, 15865); + int16x8_t v2863 = vaddq_s16(v2863_tmp, v2862); + int16x8_t v2864 = vaddq_s16(v2861, v2863); + int16x8_t v2865 = vqrdmulhq_n_s16(v2864, 20040); + int16x8_t v2866 = vaddq_s16(v2860, v2865); + int16x8_t v2867 = vsubq_s16(v2155, v2160); + int16x8_t v2868 = vsubq_s16(v2165, v2170); + int16x8_t v2869_tmp = vqrdmulhq_n_s16(v2868, 15865); + int16x8_t v2869 = vaddq_s16(v2869_tmp, v2868); + int16x8_t v2870 = vaddq_s16(v2867, v2869); + int16x8_t v2871 = vsubq_s16(v2177, v2182); + int16x8_t v2872 = vsubq_s16(v2187, v2192); + int16x8_t v2873_tmp = vqrdmulhq_n_s16(v2872, 15865); + int16x8_t v2873 = vaddq_s16(v2873_tmp, v2872); + int16x8_t v2874 = vaddq_s16(v2871, v2873); + int16x8_t v2875 = vqrdmulhq_n_s16(v2874, 20040); + int16x8_t v2876 = vaddq_s16(v2870, v2875); + int16x8_t v2877 = vqrdmulhq_n_s16(v2876, 17187); + int16x8_t v2878 = vaddq_s16(v2866, v2877); + int16x8_t v2879 = vsubq_s16(v2203, v2208); + int16x8_t v2880 = vsubq_s16(v2213, v2218); + int16x8_t v2881_tmp = vqrdmulhq_n_s16(v2880, 15865); + int16x8_t v2881 = vaddq_s16(v2881_tmp, v2880); + int16x8_t v2882 = vaddq_s16(v2879, v2881); + int16x8_t v2883 = vsubq_s16(v2225, v2230); + int16x8_t v2884 = vsubq_s16(v2235, v2240); + int16x8_t v2885_tmp = vqrdmulhq_n_s16(v2884, 15865); + int16x8_t v2885 = vaddq_s16(v2885_tmp, v2884); + int16x8_t v2886 = vaddq_s16(v2883, v2885); + int16x8_t v2887 = vqrdmulhq_n_s16(v2886, 20040); + int16x8_t v2888 = vaddq_s16(v2882, v2887); + int16x8_t v2889 = vsubq_s16(v2249, v2254); + int16x8_t v2890 = vsubq_s16(v2259, v2264); + int16x8_t v2891_tmp = vqrdmulhq_n_s16(v2890, 15865); + int16x8_t v2891 = vaddq_s16(v2891_tmp, v2890); + int16x8_t v2892 = vaddq_s16(v2889, v2891); + int16x8_t v2893 = vsubq_s16(v2271, v2276); + int16x8_t v2894 = vsubq_s16(v2281, v2286); + int16x8_t v2895_tmp = vqrdmulhq_n_s16(v2894, 15865); + int16x8_t v2895 = vaddq_s16(v2895_tmp, v2894); + int16x8_t v2896 = vaddq_s16(v2893, v2895); + int16x8_t v2897 = vqrdmulhq_n_s16(v2896, 20040); + int16x8_t v2898 = vaddq_s16(v2892, v2897); + int16x8_t v2899 = vqrdmulhq_n_s16(v2898, 17187); + int16x8_t v2900 = vaddq_s16(v2888, v2899); + int16x8_t v2901 = vqrdmulhq_n_s16(v2900, 16579); + int16x8_t v2902 = vaddq_s16(v2878, v2901); + int16x8_t v2903 = vsubq_s16(v1919, v1924); + int16x8_t v2904 = vsubq_s16(v1929, v1934); + int16x8_t v2905_tmp = vqrdmulhq_n_s16(v2904, 1893); + int16x8_t v2905 = vmlaq_n_s16(v2905_tmp, v2904, 2); + int16x8_t v2906 = vaddq_s16(v2903, v2905); + int16x8_t v2907 = vsubq_s16(v1941, v1946); + int16x8_t v2908 = vsubq_s16(v1951, v1956); + int16x8_t v2909_tmp = vqrdmulhq_n_s16(v2908, 1893); + int16x8_t v2909 = vmlaq_n_s16(v2909_tmp, v2908, 2); + int16x8_t v2910 = vaddq_s16(v2907, v2909); + int16x8_t v2911 = vqrdmulhq_n_s16(v2910, 20783); + int16x8_t v2912 = vaddq_s16(v2906, v2911); + int16x8_t v2913 = vsubq_s16(v1965, v1970); + int16x8_t v2914 = vsubq_s16(v1975, v1980); + int16x8_t v2915_tmp = vqrdmulhq_n_s16(v2914, 1893); + int16x8_t v2915 = vmlaq_n_s16(v2915_tmp, v2914, 2); + int16x8_t v2916 = vaddq_s16(v2913, v2915); + int16x8_t v2917 = vsubq_s16(v1987, v1992); + int16x8_t v2918 = vsubq_s16(v1997, v2002); + int16x8_t v2919_tmp = vqrdmulhq_n_s16(v2918, 1893); + int16x8_t v2919 = vmlaq_n_s16(v2919_tmp, v2918, 2); + int16x8_t v2920 = vaddq_s16(v2917, v2919); + int16x8_t v2921 = vqrdmulhq_n_s16(v2920, 20783); + int16x8_t v2922 = vaddq_s16(v2916, v2921); + int16x8_t v2923 = vqrdmulhq_n_s16(v2922, 17326); + int16x8_t v2924 = vaddq_s16(v2912, v2923); + int16x8_t v2925 = vsubq_s16(v2013, v2018); + int16x8_t v2926 = vsubq_s16(v2023, v2028); + int16x8_t v2927_tmp = vqrdmulhq_n_s16(v2926, 1893); + int16x8_t v2927 = vmlaq_n_s16(v2927_tmp, v2926, 2); + int16x8_t v2928 = vaddq_s16(v2925, v2927); + int16x8_t v2929 = vsubq_s16(v2035, v2040); + int16x8_t v2930 = vsubq_s16(v2045, v2050); + int16x8_t v2931_tmp = vqrdmulhq_n_s16(v2930, 1893); + int16x8_t v2931 = vmlaq_n_s16(v2931_tmp, v2930, 2); + int16x8_t v2932 = vaddq_s16(v2929, v2931); + int16x8_t v2933 = vqrdmulhq_n_s16(v2932, 20783); + int16x8_t v2934 = vaddq_s16(v2928, v2933); + int16x8_t v2935 = vsubq_s16(v2059, v2064); + int16x8_t v2936 = vsubq_s16(v2069, v2074); + int16x8_t v2937_tmp = vqrdmulhq_n_s16(v2936, 1893); + int16x8_t v2937 = vmlaq_n_s16(v2937_tmp, v2936, 2); + int16x8_t v2938 = vaddq_s16(v2935, v2937); + int16x8_t v2939 = vsubq_s16(v2081, v2086); + int16x8_t v2940 = vsubq_s16(v2091, v2096); + int16x8_t v2941_tmp = vqrdmulhq_n_s16(v2940, 1893); + int16x8_t v2941 = vmlaq_n_s16(v2941_tmp, v2940, 2); + int16x8_t v2942 = vaddq_s16(v2939, v2941); + int16x8_t v2943 = vqrdmulhq_n_s16(v2942, 20783); + int16x8_t v2944 = vaddq_s16(v2938, v2943); + int16x8_t v2945 = vqrdmulhq_n_s16(v2944, 17326); + int16x8_t v2946 = vaddq_s16(v2934, v2945); + int16x8_t v2947 = vqrdmulhq_n_s16(v2946, 16611); + int16x8_t v2948 = vaddq_s16(v2924, v2947); + int16x8_t v2949 = vsubq_s16(v1543, v1554); + int16x8_t v2950 = vsubq_s16(v1565, v1576); + int16x8_t v2951_tmp = vqrdmulhq_n_s16(v2950, 13357); + int16x8_t v2951 = vmlaq_n_s16(v2951_tmp, v2950, 3); + int16x8_t v2952 = vaddq_s16(v2949, v2951); + int16x8_t v2953 = vsubq_s16(v1589, v1600); + int16x8_t v2954 = vsubq_s16(v1611, v1622); + int16x8_t v2955_tmp = vqrdmulhq_n_s16(v2954, 13357); + int16x8_t v2955 = vmlaq_n_s16(v2955_tmp, v2954, 3); + int16x8_t v2956 = vaddq_s16(v2953, v2955); + int16x8_t v2957 = vqrdmulhq_n_s16(v2956, 21637); + int16x8_t v2958 = vaddq_s16(v2952, v2957); + int16x8_t v2959 = vsubq_s16(v1637, v1648); + int16x8_t v2960 = vsubq_s16(v1659, v1670); + int16x8_t v2961_tmp = vqrdmulhq_n_s16(v2960, 13357); + int16x8_t v2961 = vmlaq_n_s16(v2961_tmp, v2960, 3); + int16x8_t v2962 = vaddq_s16(v2959, v2961); + int16x8_t v2963 = vsubq_s16(v1683, v1694); + int16x8_t v2964 = vsubq_s16(v1705, v1716); + int16x8_t v2965_tmp = vqrdmulhq_n_s16(v2964, 13357); + int16x8_t v2965 = vmlaq_n_s16(v2965_tmp, v2964, 3); + int16x8_t v2966 = vaddq_s16(v2963, v2965); + int16x8_t v2967 = vqrdmulhq_n_s16(v2966, 21637); + int16x8_t v2968 = vaddq_s16(v2962, v2967); + int16x8_t v2969 = vqrdmulhq_n_s16(v2968, 17479); + int16x8_t v2970 = vaddq_s16(v2958, v2969); + int16x8_t v2971 = vsubq_s16(v1733, v1744); + int16x8_t v2972 = vsubq_s16(v1755, v1766); + int16x8_t v2973_tmp = vqrdmulhq_n_s16(v2972, 13357); + int16x8_t v2973 = vmlaq_n_s16(v2973_tmp, v2972, 3); + int16x8_t v2974 = vaddq_s16(v2971, v2973); + int16x8_t v2975 = vsubq_s16(v1779, v1790); + int16x8_t v2976 = vsubq_s16(v1801, v1812); + int16x8_t v2977_tmp = vqrdmulhq_n_s16(v2976, 13357); + int16x8_t v2977 = vmlaq_n_s16(v2977_tmp, v2976, 3); + int16x8_t v2978 = vaddq_s16(v2975, v2977); + int16x8_t v2979 = vqrdmulhq_n_s16(v2978, 21637); + int16x8_t v2980 = vaddq_s16(v2974, v2979); + int16x8_t v2981 = vsubq_s16(v1827, v1838); + int16x8_t v2982 = vsubq_s16(v1849, v1860); + int16x8_t v2983_tmp = vqrdmulhq_n_s16(v2982, 13357); + int16x8_t v2983 = vmlaq_n_s16(v2983_tmp, v2982, 3); + int16x8_t v2984 = vaddq_s16(v2981, v2983); + int16x8_t v2985 = vsubq_s16(v1873, v1884); + int16x8_t v2986 = vsubq_s16(v1895, v1906); + int16x8_t v2987_tmp = vqrdmulhq_n_s16(v2986, 13357); + int16x8_t v2987 = vmlaq_n_s16(v2987_tmp, v2986, 3); + int16x8_t v2988 = vaddq_s16(v2985, v2987); + int16x8_t v2989 = vqrdmulhq_n_s16(v2988, 21637); + int16x8_t v2990 = vaddq_s16(v2984, v2989); + int16x8_t v2991 = vqrdmulhq_n_s16(v2990, 17479); + int16x8_t v2992 = vaddq_s16(v2980, v2991); + int16x8_t v2993 = vqrdmulhq_n_s16(v2992, 16647); + int16x8_t v2994 = vaddq_s16(v2970, v2993); + int16x8_t v2995 = vsubq_s16(v25, v60); + int16x8_t v2996 = vsubq_s16(v102, v138); + int16x8_t v2997_tmp = vqrdmulhq_n_s16(v2996, 6226); + int16x8_t v2997 = vmlaq_n_s16(v2997_tmp, v2996, 10); + int16x8_t v2998 = vaddq_s16(v2995, v2997); + int16x8_t v2999 = vsubq_s16(v182, v233); + int16x8_t v3000 = vsubq_s16(v275, v312); + int16x8_t v3001_tmp = vqrdmulhq_n_s16(v3000, 6226); + int16x8_t v3001 = vmlaq_n_s16(v3001_tmp, v3000, 10); + int16x8_t v3002 = vaddq_s16(v2999, v3001); + int16x8_t v3003 = vqrdmulhq_n_s16(v3002, 22622); + int16x8_t v3004 = vaddq_s16(v2998, v3003); + int16x8_t v3005 = vsubq_s16(v358, v409); + int16x8_t v3006 = vsubq_s16(v481, v519); + int16x8_t v3007_tmp = vqrdmulhq_n_s16(v3006, 6226); + int16x8_t v3007 = vmlaq_n_s16(v3007_tmp, v3006, 10); + int16x8_t v3008 = vaddq_s16(v3005, v3007); + int16x8_t v3009 = vsubq_s16(v563, v614); + int16x8_t v3010 = vsubq_s16(v656, v694); + int16x8_t v3011_tmp = vqrdmulhq_n_s16(v3010, 6226); + int16x8_t v3011 = vmlaq_n_s16(v3011_tmp, v3010, 10); + int16x8_t v3012 = vaddq_s16(v3009, v3011); + int16x8_t v3013 = vqrdmulhq_n_s16(v3012, 22622); + int16x8_t v3014 = vaddq_s16(v3008, v3013); + int16x8_t v3015 = vqrdmulhq_n_s16(v3014, 17646); + int16x8_t v3016 = vaddq_s16(v3004, v3015); + int16x8_t v3017 = vsubq_s16(v742, v793); + int16x8_t v3018 = vsubq_s16(v865, v903); + int16x8_t v3019_tmp = vqrdmulhq_n_s16(v3018, 6226); + int16x8_t v3019 = vmlaq_n_s16(v3019_tmp, v3018, 10); + int16x8_t v3020 = vaddq_s16(v3017, v3019); + int16x8_t v3021 = vsubq_s16(v977, v1060); + int16x8_t v3022 = vsubq_s16(v1102, v1141); + int16x8_t v3023_tmp = vqrdmulhq_n_s16(v3022, 6226); + int16x8_t v3023 = vmlaq_n_s16(v3023_tmp, v3022, 10); + int16x8_t v3024 = vaddq_s16(v3021, v3023); + int16x8_t v3025 = vqrdmulhq_n_s16(v3024, 22622); + int16x8_t v3026 = vaddq_s16(v3020, v3025); + int16x8_t v3027 = vsubq_s16(v1187, v1238); + int16x8_t v3028 = vsubq_s16(v1310, v1348); + int16x8_t v3029_tmp = vqrdmulhq_n_s16(v3028, 6226); + int16x8_t v3029 = vmlaq_n_s16(v3029_tmp, v3028, 10); + int16x8_t v3030 = vaddq_s16(v3027, v3029); + int16x8_t v3031 = vsubq_s16(v1392, v1443); + int16x8_t v3032 = vsubq_s16(v1485, v1524); + int16x8_t v3033_tmp = vqrdmulhq_n_s16(v3032, 6226); + int16x8_t v3033 = vmlaq_n_s16(v3033_tmp, v3032, 10); + int16x8_t v3034 = vaddq_s16(v3031, v3033); + int16x8_t v3035 = vqrdmulhq_n_s16(v3034, 22622); + int16x8_t v3036 = vaddq_s16(v3030, v3035); + int16x8_t v3037 = vqrdmulhq_n_s16(v3036, 17646); + int16x8_t v3038 = vaddq_s16(v3026, v3037); + int16x8_t v3039 = vqrdmulhq_n_s16(v3038, 16685); + int16x8_t v3040 = vaddq_s16(v3016, v3039); + int16x8_t v3041 = vsubq_s16(v2995, v2997); + int16x8_t v3042 = vsubq_s16(v2999, v3001); + int16x8_t v3043 = vqrdmulhq_n_s16(v3042, 23761); + int16x8_t v3044 = vaddq_s16(v3041, v3043); + int16x8_t v3045 = vsubq_s16(v3005, v3007); + int16x8_t v3046 = vsubq_s16(v3009, v3011); + int16x8_t v3047 = vqrdmulhq_n_s16(v3046, 23761); + int16x8_t v3048 = vaddq_s16(v3045, v3047); + int16x8_t v3049 = vqrdmulhq_n_s16(v3048, 17826); + int16x8_t v3050 = vaddq_s16(v3044, v3049); + int16x8_t v3051 = vsubq_s16(v3017, v3019); + int16x8_t v3052 = vsubq_s16(v3021, v3023); + int16x8_t v3053 = vqrdmulhq_n_s16(v3052, 23761); + int16x8_t v3054 = vaddq_s16(v3051, v3053); + int16x8_t v3055 = vsubq_s16(v3027, v3029); + int16x8_t v3056 = vsubq_s16(v3031, v3033); + int16x8_t v3057 = vqrdmulhq_n_s16(v3056, 23761); + int16x8_t v3058 = vaddq_s16(v3055, v3057); + int16x8_t v3059 = vqrdmulhq_n_s16(v3058, 17826); + int16x8_t v3060 = vaddq_s16(v3054, v3059); + int16x8_t v3061 = vqrdmulhq_n_s16(v3060, 16726); + int16x8_t v3062 = vaddq_s16(v3050, v3061); + int16x8_t v3063 = vsubq_s16(v2949, v2951); + int16x8_t v3064 = vsubq_s16(v2953, v2955); + int16x8_t v3065 = vqrdmulhq_n_s16(v3064, 25084); + int16x8_t v3066 = vaddq_s16(v3063, v3065); + int16x8_t v3067 = vsubq_s16(v2959, v2961); + int16x8_t v3068 = vsubq_s16(v2963, v2965); + int16x8_t v3069 = vqrdmulhq_n_s16(v3068, 25084); + int16x8_t v3070 = vaddq_s16(v3067, v3069); + int16x8_t v3071 = vqrdmulhq_n_s16(v3070, 18021); + int16x8_t v3072 = vaddq_s16(v3066, v3071); + int16x8_t v3073 = vsubq_s16(v2971, v2973); + int16x8_t v3074 = vsubq_s16(v2975, v2977); + int16x8_t v3075 = vqrdmulhq_n_s16(v3074, 25084); + int16x8_t v3076 = vaddq_s16(v3073, v3075); + int16x8_t v3077 = vsubq_s16(v2981, v2983); + int16x8_t v3078 = vsubq_s16(v2985, v2987); + int16x8_t v3079 = vqrdmulhq_n_s16(v3078, 25084); + int16x8_t v3080 = vaddq_s16(v3077, v3079); + int16x8_t v3081 = vqrdmulhq_n_s16(v3080, 18021); + int16x8_t v3082 = vaddq_s16(v3076, v3081); + int16x8_t v3083 = vqrdmulhq_n_s16(v3082, 16769); + int16x8_t v3084 = vaddq_s16(v3072, v3083); + int16x8_t v3085 = vsubq_s16(v2903, v2905); + int16x8_t v3086 = vsubq_s16(v2907, v2909); + int16x8_t v3087 = vqrdmulhq_n_s16(v3086, 26631); + int16x8_t v3088 = vaddq_s16(v3085, v3087); + int16x8_t v3089 = vsubq_s16(v2913, v2915); + int16x8_t v3090 = vsubq_s16(v2917, v2919); + int16x8_t v3091 = vqrdmulhq_n_s16(v3090, 26631); + int16x8_t v3092 = vaddq_s16(v3089, v3091); + int16x8_t v3093 = vqrdmulhq_n_s16(v3092, 18231); + int16x8_t v3094 = vaddq_s16(v3088, v3093); + int16x8_t v3095 = vsubq_s16(v2925, v2927); + int16x8_t v3096 = vsubq_s16(v2929, v2931); + int16x8_t v3097 = vqrdmulhq_n_s16(v3096, 26631); + int16x8_t v3098 = vaddq_s16(v3095, v3097); + int16x8_t v3099 = vsubq_s16(v2935, v2937); + int16x8_t v3100 = vsubq_s16(v2939, v2941); + int16x8_t v3101 = vqrdmulhq_n_s16(v3100, 26631); + int16x8_t v3102 = vaddq_s16(v3099, v3101); + int16x8_t v3103 = vqrdmulhq_n_s16(v3102, 18231); + int16x8_t v3104 = vaddq_s16(v3098, v3103); + int16x8_t v3105 = vqrdmulhq_n_s16(v3104, 16815); + int16x8_t v3106 = vaddq_s16(v3094, v3105); + int16x8_t v3107 = vsubq_s16(v2857, v2859); + int16x8_t v3108 = vsubq_s16(v2861, v2863); + int16x8_t v3109 = vqrdmulhq_n_s16(v3108, 28454); + int16x8_t v3110 = vaddq_s16(v3107, v3109); + int16x8_t v3111 = vsubq_s16(v2867, v2869); + int16x8_t v3112 = vsubq_s16(v2871, v2873); + int16x8_t v3113 = vqrdmulhq_n_s16(v3112, 28454); + int16x8_t v3114 = vaddq_s16(v3111, v3113); + int16x8_t v3115 = vqrdmulhq_n_s16(v3114, 18458); + int16x8_t v3116 = vaddq_s16(v3110, v3115); + int16x8_t v3117 = vsubq_s16(v2879, v2881); + int16x8_t v3118 = vsubq_s16(v2883, v2885); + int16x8_t v3119 = vqrdmulhq_n_s16(v3118, 28454); + int16x8_t v3120 = vaddq_s16(v3117, v3119); + int16x8_t v3121 = vsubq_s16(v2889, v2891); + int16x8_t v3122 = vsubq_s16(v2893, v2895); + int16x8_t v3123 = vqrdmulhq_n_s16(v3122, 28454); + int16x8_t v3124 = vaddq_s16(v3121, v3123); + int16x8_t v3125 = vqrdmulhq_n_s16(v3124, 18458); + int16x8_t v3126 = vaddq_s16(v3120, v3125); + int16x8_t v3127 = vqrdmulhq_n_s16(v3126, 16865); + int16x8_t v3128 = vaddq_s16(v3116, v3127); + int16x8_t v3129 = vsubq_s16(v2811, v2813); + int16x8_t v3130 = vsubq_s16(v2815, v2817); + int16x8_t v3131 = vqrdmulhq_n_s16(v3130, 30624); + int16x8_t v3132 = vaddq_s16(v3129, v3131); + int16x8_t v3133 = vsubq_s16(v2821, v2823); + int16x8_t v3134 = vsubq_s16(v2825, v2827); + int16x8_t v3135 = vqrdmulhq_n_s16(v3134, 30624); + int16x8_t v3136 = vaddq_s16(v3133, v3135); + int16x8_t v3137 = vqrdmulhq_n_s16(v3136, 18702); + int16x8_t v3138 = vaddq_s16(v3132, v3137); + int16x8_t v3139 = vsubq_s16(v2833, v2835); + int16x8_t v3140 = vsubq_s16(v2837, v2839); + int16x8_t v3141 = vqrdmulhq_n_s16(v3140, 30624); + int16x8_t v3142 = vaddq_s16(v3139, v3141); + int16x8_t v3143 = vsubq_s16(v2843, v2845); + int16x8_t v3144 = vsubq_s16(v2847, v2849); + int16x8_t v3145 = vqrdmulhq_n_s16(v3144, 30624); + int16x8_t v3146 = vaddq_s16(v3143, v3145); + int16x8_t v3147 = vqrdmulhq_n_s16(v3146, 18702); + int16x8_t v3148 = vaddq_s16(v3142, v3147); + int16x8_t v3149 = vqrdmulhq_n_s16(v3148, 16916); + int16x8_t v3150 = vaddq_s16(v3138, v3149); + int16x8_t v3151 = vsubq_s16(v2765, v2767); + int16x8_t v3152 = vsubq_s16(v2769, v2771); + int16x8_t v3153_tmp = vqrdmulhq_n_s16(v3152, 472); + int16x8_t v3153 = vaddq_s16(v3153_tmp, v3152); + int16x8_t v3154 = vaddq_s16(v3151, v3153); + int16x8_t v3155 = vsubq_s16(v2775, v2777); + int16x8_t v3156 = vsubq_s16(v2779, v2781); + int16x8_t v3157_tmp = vqrdmulhq_n_s16(v3156, 472); + int16x8_t v3157 = vaddq_s16(v3157_tmp, v3156); + int16x8_t v3158 = vaddq_s16(v3155, v3157); + int16x8_t v3159 = vqrdmulhq_n_s16(v3158, 18964); + int16x8_t v3160 = vaddq_s16(v3154, v3159); + int16x8_t v3161 = vsubq_s16(v2787, v2789); + int16x8_t v3162 = vsubq_s16(v2791, v2793); + int16x8_t v3163_tmp = vqrdmulhq_n_s16(v3162, 472); + int16x8_t v3163 = vaddq_s16(v3163_tmp, v3162); + int16x8_t v3164 = vaddq_s16(v3161, v3163); + int16x8_t v3165 = vsubq_s16(v2797, v2799); + int16x8_t v3166 = vsubq_s16(v2801, v2803); + int16x8_t v3167_tmp = vqrdmulhq_n_s16(v3166, 472); + int16x8_t v3167 = vaddq_s16(v3167_tmp, v3166); + int16x8_t v3168 = vaddq_s16(v3165, v3167); + int16x8_t v3169 = vqrdmulhq_n_s16(v3168, 18964); + int16x8_t v3170 = vaddq_s16(v3164, v3169); + int16x8_t v3171 = vqrdmulhq_n_s16(v3170, 16971); + int16x8_t v3172 = vaddq_s16(v3160, v3171); + int16x8_t v3173 = vsubq_s16(v2719, v2721); + int16x8_t v3174 = vsubq_s16(v2723, v2725); + int16x8_t v3175_tmp = vqrdmulhq_n_s16(v3174, 3672); + int16x8_t v3175 = vaddq_s16(v3175_tmp, v3174); + int16x8_t v3176 = vaddq_s16(v3173, v3175); + int16x8_t v3177 = vsubq_s16(v2729, v2731); + int16x8_t v3178 = vsubq_s16(v2733, v2735); + int16x8_t v3179_tmp = vqrdmulhq_n_s16(v3178, 3672); + int16x8_t v3179 = vaddq_s16(v3179_tmp, v3178); + int16x8_t v3180 = vaddq_s16(v3177, v3179); + int16x8_t v3181 = vqrdmulhq_n_s16(v3180, 19245); + int16x8_t v3182 = vaddq_s16(v3176, v3181); + int16x8_t v3183 = vsubq_s16(v2741, v2743); + int16x8_t v3184 = vsubq_s16(v2745, v2747); + int16x8_t v3185_tmp = vqrdmulhq_n_s16(v3184, 3672); + int16x8_t v3185 = vaddq_s16(v3185_tmp, v3184); + int16x8_t v3186 = vaddq_s16(v3183, v3185); + int16x8_t v3187 = vsubq_s16(v2751, v2753); + int16x8_t v3188 = vsubq_s16(v2755, v2757); + int16x8_t v3189_tmp = vqrdmulhq_n_s16(v3188, 3672); + int16x8_t v3189 = vaddq_s16(v3189_tmp, v3188); + int16x8_t v3190 = vaddq_s16(v3187, v3189); + int16x8_t v3191 = vqrdmulhq_n_s16(v3190, 19245); + int16x8_t v3192 = vaddq_s16(v3186, v3191); + int16x8_t v3193 = vqrdmulhq_n_s16(v3192, 17029); + int16x8_t v3194 = vaddq_s16(v3182, v3193); + int16x8_t v3195 = vsubq_s16(v2673, v2675); + int16x8_t v3196 = vsubq_s16(v2677, v2679); + int16x8_t v3197_tmp = vqrdmulhq_n_s16(v3196, 7662); + int16x8_t v3197 = vaddq_s16(v3197_tmp, v3196); + int16x8_t v3198 = vaddq_s16(v3195, v3197); + int16x8_t v3199 = vsubq_s16(v2683, v2685); + int16x8_t v3200 = vsubq_s16(v2687, v2689); + int16x8_t v3201_tmp = vqrdmulhq_n_s16(v3200, 7662); + int16x8_t v3201 = vaddq_s16(v3201_tmp, v3200); + int16x8_t v3202 = vaddq_s16(v3199, v3201); + int16x8_t v3203 = vqrdmulhq_n_s16(v3202, 19546); + int16x8_t v3204 = vaddq_s16(v3198, v3203); + int16x8_t v3205 = vsubq_s16(v2695, v2697); + int16x8_t v3206 = vsubq_s16(v2699, v2701); + int16x8_t v3207_tmp = vqrdmulhq_n_s16(v3206, 7662); + int16x8_t v3207 = vaddq_s16(v3207_tmp, v3206); + int16x8_t v3208 = vaddq_s16(v3205, v3207); + int16x8_t v3209 = vsubq_s16(v2705, v2707); + int16x8_t v3210 = vsubq_s16(v2709, v2711); + int16x8_t v3211_tmp = vqrdmulhq_n_s16(v3210, 7662); + int16x8_t v3211 = vaddq_s16(v3211_tmp, v3210); + int16x8_t v3212 = vaddq_s16(v3209, v3211); + int16x8_t v3213 = vqrdmulhq_n_s16(v3212, 19546); + int16x8_t v3214 = vaddq_s16(v3208, v3213); + int16x8_t v3215 = vqrdmulhq_n_s16(v3214, 17090); + int16x8_t v3216 = vaddq_s16(v3204, v3215); + int16x8_t v3217 = vsubq_s16(v2582, v2587); + int16x8_t v3218 = vsubq_s16(v2592, v2597); + int16x8_t v3219_tmp = vqrdmulhq_n_s16(v3218, 12756); + int16x8_t v3219 = vaddq_s16(v3219_tmp, v3218); + int16x8_t v3220 = vaddq_s16(v3217, v3219); + int16x8_t v3221 = vsubq_s16(v2604, v2609); + int16x8_t v3222 = vsubq_s16(v2614, v2619); + int16x8_t v3223_tmp = vqrdmulhq_n_s16(v3222, 12756); + int16x8_t v3223 = vaddq_s16(v3223_tmp, v3222); + int16x8_t v3224 = vaddq_s16(v3221, v3223); + int16x8_t v3225 = vqrdmulhq_n_s16(v3224, 19869); + int16x8_t v3226 = vaddq_s16(v3220, v3225); + int16x8_t v3227 = vsubq_s16(v2628, v2633); + int16x8_t v3228 = vsubq_s16(v2638, v2643); + int16x8_t v3229_tmp = vqrdmulhq_n_s16(v3228, 12756); + int16x8_t v3229 = vaddq_s16(v3229_tmp, v3228); + int16x8_t v3230 = vaddq_s16(v3227, v3229); + int16x8_t v3231 = vsubq_s16(v2650, v2655); + int16x8_t v3232 = vsubq_s16(v2660, v2665); + int16x8_t v3233_tmp = vqrdmulhq_n_s16(v3232, 12756); + int16x8_t v3233 = vaddq_s16(v3233_tmp, v3232); + int16x8_t v3234 = vaddq_s16(v3231, v3233); + int16x8_t v3235 = vqrdmulhq_n_s16(v3234, 19869); + int16x8_t v3236 = vaddq_s16(v3230, v3235); + int16x8_t v3237 = vqrdmulhq_n_s16(v3236, 17153); + int16x8_t v3238 = vaddq_s16(v3226, v3237); + int16x8_t v3239 = vsubq_s16(v2488, v2493); + int16x8_t v3240 = vsubq_s16(v2498, v2503); + int16x8_t v3241_tmp = vqrdmulhq_n_s16(v3240, 19463); + int16x8_t v3241 = vaddq_s16(v3241_tmp, v3240); + int16x8_t v3242 = vaddq_s16(v3239, v3241); + int16x8_t v3243 = vsubq_s16(v2510, v2515); + int16x8_t v3244 = vsubq_s16(v2520, v2525); + int16x8_t v3245_tmp = vqrdmulhq_n_s16(v3244, 19463); + int16x8_t v3245 = vaddq_s16(v3245_tmp, v3244); + int16x8_t v3246 = vaddq_s16(v3243, v3245); + int16x8_t v3247 = vqrdmulhq_n_s16(v3246, 20216); + int16x8_t v3248 = vaddq_s16(v3242, v3247); + int16x8_t v3249 = vsubq_s16(v2534, v2539); + int16x8_t v3250 = vsubq_s16(v2544, v2549); + int16x8_t v3251_tmp = vqrdmulhq_n_s16(v3250, 19463); + int16x8_t v3251 = vaddq_s16(v3251_tmp, v3250); + int16x8_t v3252 = vaddq_s16(v3249, v3251); + int16x8_t v3253 = vsubq_s16(v2556, v2561); + int16x8_t v3254 = vsubq_s16(v2566, v2571); + int16x8_t v3255_tmp = vqrdmulhq_n_s16(v3254, 19463); + int16x8_t v3255 = vaddq_s16(v3255_tmp, v3254); + int16x8_t v3256 = vaddq_s16(v3253, v3255); + int16x8_t v3257 = vqrdmulhq_n_s16(v3256, 20216); + int16x8_t v3258 = vaddq_s16(v3252, v3257); + int16x8_t v3259 = vqrdmulhq_n_s16(v3258, 17220); + int16x8_t v3260 = vaddq_s16(v3248, v3259); + int16x8_t v3261 = vsubq_s16(v2393, v2398); + int16x8_t v3262 = vsubq_s16(v2403, v2408); + int16x8_t v3263_tmp = vqrdmulhq_n_s16(v3262, 28661); + int16x8_t v3263 = vaddq_s16(v3263_tmp, v3262); + int16x8_t v3264 = vaddq_s16(v3261, v3263); + int16x8_t v3265 = vsubq_s16(v2415, v2420); + int16x8_t v3266 = vsubq_s16(v2425, v2430); + int16x8_t v3267_tmp = vqrdmulhq_n_s16(v3266, 28661); + int16x8_t v3267 = vaddq_s16(v3267_tmp, v3266); + int16x8_t v3268 = vaddq_s16(v3265, v3267); + int16x8_t v3269 = vqrdmulhq_n_s16(v3268, 20587); + int16x8_t v3270 = vaddq_s16(v3264, v3269); + int16x8_t v3271 = vsubq_s16(v2439, v2444); + int16x8_t v3272 = vsubq_s16(v2449, v2454); + int16x8_t v3273_tmp = vqrdmulhq_n_s16(v3272, 28661); + int16x8_t v3273 = vaddq_s16(v3273_tmp, v3272); + int16x8_t v3274 = vaddq_s16(v3271, v3273); + int16x8_t v3275 = vsubq_s16(v2461, v2467); + int16x8_t v3276 = vsubq_s16(v2472, v2477); + int16x8_t v3277_tmp = vqrdmulhq_n_s16(v3276, 28661); + int16x8_t v3277 = vaddq_s16(v3277_tmp, v3276); + int16x8_t v3278 = vaddq_s16(v3275, v3277); + int16x8_t v3279 = vqrdmulhq_n_s16(v3278, 20587); + int16x8_t v3280 = vaddq_s16(v3274, v3279); + int16x8_t v3281 = vqrdmulhq_n_s16(v3280, 17290); + int16x8_t v3282 = vaddq_s16(v3270, v3281); + int16x8_t v3283 = vsubq_s16(v2299, v2304); + int16x8_t v3284 = vsubq_s16(v2309, v2314); + int16x8_t v3285_tmp = vqrdmulhq_n_s16(v3284, 9242); + int16x8_t v3285 = vmlaq_n_s16(v3285_tmp, v3284, 2); + int16x8_t v3286 = vaddq_s16(v3283, v3285); + int16x8_t v3287 = vsubq_s16(v2321, v2326); + int16x8_t v3288 = vsubq_s16(v2331, v2336); + int16x8_t v3289_tmp = vqrdmulhq_n_s16(v3288, 9242); + int16x8_t v3289 = vmlaq_n_s16(v3289_tmp, v3288, 2); + int16x8_t v3290 = vaddq_s16(v3287, v3289); + int16x8_t v3291 = vqrdmulhq_n_s16(v3290, 20985); + int16x8_t v3292 = vaddq_s16(v3286, v3291); + int16x8_t v3293 = vsubq_s16(v2345, v2350); + int16x8_t v3294 = vsubq_s16(v2355, v2360); + int16x8_t v3295_tmp = vqrdmulhq_n_s16(v3294, 9242); + int16x8_t v3295 = vmlaq_n_s16(v3295_tmp, v3294, 2); + int16x8_t v3296 = vaddq_s16(v3293, v3295); + int16x8_t v3297 = vsubq_s16(v2367, v2372); + int16x8_t v3298 = vsubq_s16(v2377, v2382); + int16x8_t v3299_tmp = vqrdmulhq_n_s16(v3298, 9242); + int16x8_t v3299 = vmlaq_n_s16(v3299_tmp, v3298, 2); + int16x8_t v3300 = vaddq_s16(v3297, v3299); + int16x8_t v3301 = vqrdmulhq_n_s16(v3300, 20985); + int16x8_t v3302 = vaddq_s16(v3296, v3301); + int16x8_t v3303 = vqrdmulhq_n_s16(v3302, 17363); + int16x8_t v3304 = vaddq_s16(v3292, v3303); + int16x8_t v3305 = vsubq_s16(v2115, v2126); + int16x8_t v3306 = vsubq_s16(v2137, v2148); + int16x8_t v3307_tmp = vqrdmulhq_n_s16(v3306, 30298); + int16x8_t v3307 = vmlaq_n_s16(v3307_tmp, v3306, 2); + int16x8_t v3308 = vaddq_s16(v3305, v3307); + int16x8_t v3309 = vsubq_s16(v2161, v2172); + int16x8_t v3310 = vsubq_s16(v2183, v2194); + int16x8_t v3311_tmp = vqrdmulhq_n_s16(v3310, 30298); + int16x8_t v3311 = vmlaq_n_s16(v3311_tmp, v3310, 2); + int16x8_t v3312 = vaddq_s16(v3309, v3311); + int16x8_t v3313 = vqrdmulhq_n_s16(v3312, 21412); + int16x8_t v3314 = vaddq_s16(v3308, v3313); + int16x8_t v3315 = vsubq_s16(v2209, v2220); + int16x8_t v3316 = vsubq_s16(v2231, v2242); + int16x8_t v3317_tmp = vqrdmulhq_n_s16(v3316, 30298); + int16x8_t v3317 = vmlaq_n_s16(v3317_tmp, v3316, 2); + int16x8_t v3318 = vaddq_s16(v3315, v3317); + int16x8_t v3319 = vsubq_s16(v2255, v2266); + int16x8_t v3320 = vsubq_s16(v2277, v2288); + int16x8_t v3321_tmp = vqrdmulhq_n_s16(v3320, 30298); + int16x8_t v3321 = vmlaq_n_s16(v3321_tmp, v3320, 2); + int16x8_t v3322 = vaddq_s16(v3319, v3321); + int16x8_t v3323 = vqrdmulhq_n_s16(v3322, 21412); + int16x8_t v3324 = vaddq_s16(v3318, v3323); + int16x8_t v3325 = vqrdmulhq_n_s16(v3324, 17440); + int16x8_t v3326 = vaddq_s16(v3314, v3325); + int16x8_t v3327 = vsubq_s16(v1925, v1936); + int16x8_t v3328 = vsubq_s16(v1947, v1958); + int16x8_t v3329_tmp = vqrdmulhq_n_s16(v3328, 2773); + int16x8_t v3329 = vmlaq_n_s16(v3329_tmp, v3328, 4); + int16x8_t v3330 = vaddq_s16(v3327, v3329); + int16x8_t v3331 = vsubq_s16(v1971, v1982); + int16x8_t v3332 = vsubq_s16(v1993, v2004); + int16x8_t v3333_tmp = vqrdmulhq_n_s16(v3332, 2773); + int16x8_t v3333 = vmlaq_n_s16(v3333_tmp, v3332, 4); + int16x8_t v3334 = vaddq_s16(v3331, v3333); + int16x8_t v3335 = vqrdmulhq_n_s16(v3334, 21871); + int16x8_t v3336 = vaddq_s16(v3330, v3335); + int16x8_t v3337 = vsubq_s16(v2019, v2030); + int16x8_t v3338 = vsubq_s16(v2041, v2052); + int16x8_t v3339_tmp = vqrdmulhq_n_s16(v3338, 2773); + int16x8_t v3339 = vmlaq_n_s16(v3339_tmp, v3338, 4); + int16x8_t v3340 = vaddq_s16(v3337, v3339); + int16x8_t v3341 = vsubq_s16(v2065, v2076); + int16x8_t v3342 = vsubq_s16(v2087, v2098); + int16x8_t v3343_tmp = vqrdmulhq_n_s16(v3342, 2773); + int16x8_t v3343 = vmlaq_n_s16(v3343_tmp, v3342, 4); + int16x8_t v3344 = vaddq_s16(v3341, v3343); + int16x8_t v3345 = vqrdmulhq_n_s16(v3344, 21871); + int16x8_t v3346 = vaddq_s16(v3340, v3345); + int16x8_t v3347 = vqrdmulhq_n_s16(v3346, 17520); + int16x8_t v3348 = vaddq_s16(v3336, v3347); + int16x8_t v3349 = vsubq_s16(v1555, v1578); + int16x8_t v3350 = vsubq_s16(v1601, v1624); + int16x8_t v3351_tmp = vqrdmulhq_n_s16(v3350, 26108); + int16x8_t v3351 = vmlaq_n_s16(v3351_tmp, v3350, 6); + int16x8_t v3352 = vaddq_s16(v3349, v3351); + int16x8_t v3353 = vsubq_s16(v1649, v1672); + int16x8_t v3354 = vsubq_s16(v1695, v1718); + int16x8_t v3355_tmp = vqrdmulhq_n_s16(v3354, 26108); + int16x8_t v3355 = vmlaq_n_s16(v3355_tmp, v3354, 6); + int16x8_t v3356 = vaddq_s16(v3353, v3355); + int16x8_t v3357 = vqrdmulhq_n_s16(v3356, 22363); + int16x8_t v3358 = vaddq_s16(v3352, v3357); + int16x8_t v3359 = vsubq_s16(v1745, v1768); + int16x8_t v3360 = vsubq_s16(v1791, v1814); + int16x8_t v3361_tmp = vqrdmulhq_n_s16(v3360, 26108); + int16x8_t v3361 = vmlaq_n_s16(v3361_tmp, v3360, 6); + int16x8_t v3362 = vaddq_s16(v3359, v3361); + int16x8_t v3363 = vsubq_s16(v1839, v1862); + int16x8_t v3364 = vsubq_s16(v1885, v1908); + int16x8_t v3365_tmp = vqrdmulhq_n_s16(v3364, 26108); + int16x8_t v3365 = vmlaq_n_s16(v3365_tmp, v3364, 6); + int16x8_t v3366 = vaddq_s16(v3363, v3365); + int16x8_t v3367 = vqrdmulhq_n_s16(v3366, 22363); + int16x8_t v3368 = vaddq_s16(v3362, v3367); + int16x8_t v3369 = vqrdmulhq_n_s16(v3368, 17603); + int16x8_t v3370 = vaddq_s16(v3358, v3369); + int16x8_t v3371 = vsubq_s16(v61, v140); + int16x8_t v3372 = vsubq_s16(v234, v314); + int16x8_t v3373_tmp = vqrdmulhq_n_s16(v3372, 12251); + int16x8_t v3373 = vmlaq_n_s16(v3373_tmp, v3372, 20); + int16x8_t v3374 = vaddq_s16(v3371, v3373); + int16x8_t v3375 = vsubq_s16(v410, v521); + int16x8_t v3376 = vsubq_s16(v615, v696); + int16x8_t v3377_tmp = vqrdmulhq_n_s16(v3376, 12251); + int16x8_t v3377 = vmlaq_n_s16(v3377_tmp, v3376, 20); + int16x8_t v3378 = vaddq_s16(v3375, v3377); + int16x8_t v3379 = vqrdmulhq_n_s16(v3378, 22891); + int16x8_t v3380 = vaddq_s16(v3374, v3379); + int16x8_t v3381 = vsubq_s16(v794, v905); + int16x8_t v3382 = vsubq_s16(v1061, v1143); + int16x8_t v3383_tmp = vqrdmulhq_n_s16(v3382, 12251); + int16x8_t v3383 = vmlaq_n_s16(v3383_tmp, v3382, 20); + int16x8_t v3384 = vaddq_s16(v3381, v3383); + int16x8_t v3385 = vsubq_s16(v1239, v1350); + int16x8_t v3386 = vsubq_s16(v1444, v1526); + int16x8_t v3387_tmp = vqrdmulhq_n_s16(v3386, 12251); + int16x8_t v3387 = vmlaq_n_s16(v3387_tmp, v3386, 20); + int16x8_t v3388 = vaddq_s16(v3385, v3387); + int16x8_t v3389 = vqrdmulhq_n_s16(v3388, 22891); + int16x8_t v3390 = vaddq_s16(v3384, v3389); + int16x8_t v3391 = vqrdmulhq_n_s16(v3390, 17689); + int16x8_t v3392 = vaddq_s16(v3380, v3391); + int16x8_t v3393 = vsubq_s16(v3371, v3373); + int16x8_t v3394 = vsubq_s16(v3375, v3377); + int16x8_t v3395 = vqrdmulhq_n_s16(v3394, 23460); + int16x8_t v3396 = vaddq_s16(v3393, v3395); + int16x8_t v3397 = vsubq_s16(v3381, v3383); + int16x8_t v3398 = vsubq_s16(v3385, v3387); + int16x8_t v3399 = vqrdmulhq_n_s16(v3398, 23460); + int16x8_t v3400 = vaddq_s16(v3397, v3399); + int16x8_t v3401 = vqrdmulhq_n_s16(v3400, 17779); + int16x8_t v3402 = vaddq_s16(v3396, v3401); + int16x8_t v3403 = vsubq_s16(v3349, v3351); + int16x8_t v3404 = vsubq_s16(v3353, v3355); + int16x8_t v3405 = vqrdmulhq_n_s16(v3404, 24073); + int16x8_t v3406 = vaddq_s16(v3403, v3405); + int16x8_t v3407 = vsubq_s16(v3359, v3361); + int16x8_t v3408 = vsubq_s16(v3363, v3365); + int16x8_t v3409 = vqrdmulhq_n_s16(v3408, 24073); + int16x8_t v3410 = vaddq_s16(v3407, v3409); + int16x8_t v3411 = vqrdmulhq_n_s16(v3410, 17873); + int16x8_t v3412 = vaddq_s16(v3406, v3411); + int16x8_t v3413 = vsubq_s16(v3327, v3329); + int16x8_t v3414 = vsubq_s16(v3331, v3333); + int16x8_t v3415 = vqrdmulhq_n_s16(v3414, 24734); + int16x8_t v3416 = vaddq_s16(v3413, v3415); + int16x8_t v3417 = vsubq_s16(v3337, v3339); + int16x8_t v3418 = vsubq_s16(v3341, v3343); + int16x8_t v3419 = vqrdmulhq_n_s16(v3418, 24734); + int16x8_t v3420 = vaddq_s16(v3417, v3419); + int16x8_t v3421 = vqrdmulhq_n_s16(v3420, 17971); + int16x8_t v3422 = vaddq_s16(v3416, v3421); + int16x8_t v3423 = vsubq_s16(v3305, v3307); + int16x8_t v3424 = vsubq_s16(v3309, v3311); + int16x8_t v3425 = vqrdmulhq_n_s16(v3424, 25448); + int16x8_t v3426 = vaddq_s16(v3423, v3425); + int16x8_t v3427 = vsubq_s16(v3315, v3317); + int16x8_t v3428 = vsubq_s16(v3319, v3321); + int16x8_t v3429 = vqrdmulhq_n_s16(v3428, 25448); + int16x8_t v3430 = vaddq_s16(v3427, v3429); + int16x8_t v3431 = vqrdmulhq_n_s16(v3430, 18072); + int16x8_t v3432 = vaddq_s16(v3426, v3431); + int16x8_t v3433 = vsubq_s16(v3283, v3285); + int16x8_t v3434 = vsubq_s16(v3287, v3289); + int16x8_t v3435 = vqrdmulhq_n_s16(v3434, 26220); + int16x8_t v3436 = vaddq_s16(v3433, v3435); + int16x8_t v3437 = vsubq_s16(v3293, v3295); + int16x8_t v3438 = vsubq_s16(v3297, v3299); + int16x8_t v3439 = vqrdmulhq_n_s16(v3438, 26220); + int16x8_t v3440 = vaddq_s16(v3437, v3439); + int16x8_t v3441 = vqrdmulhq_n_s16(v3440, 18177); + int16x8_t v3442 = vaddq_s16(v3436, v3441); + int16x8_t v3443 = vsubq_s16(v3261, v3263); + int16x8_t v3444 = vsubq_s16(v3265, v3267); + int16x8_t v3445 = vqrdmulhq_n_s16(v3444, 27058); + int16x8_t v3446 = vaddq_s16(v3443, v3445); + int16x8_t v3447 = vsubq_s16(v3271, v3273); + int16x8_t v3448 = vsubq_s16(v3275, v3277); + int16x8_t v3449 = vqrdmulhq_n_s16(v3448, 27058); + int16x8_t v3450 = vaddq_s16(v3447, v3449); + int16x8_t v3451 = vqrdmulhq_n_s16(v3450, 18286); + int16x8_t v3452 = vaddq_s16(v3446, v3451); + int16x8_t v3453 = vsubq_s16(v3239, v3241); + int16x8_t v3454 = vsubq_s16(v3243, v3245); + int16x8_t v3455 = vqrdmulhq_n_s16(v3454, 27969); + int16x8_t v3456 = vaddq_s16(v3453, v3455); + int16x8_t v3457 = vsubq_s16(v3249, v3251); + int16x8_t v3458 = vsubq_s16(v3253, v3255); + int16x8_t v3459 = vqrdmulhq_n_s16(v3458, 27969); + int16x8_t v3460 = vaddq_s16(v3457, v3459); + int16x8_t v3461 = vqrdmulhq_n_s16(v3460, 18400); + int16x8_t v3462 = vaddq_s16(v3456, v3461); + int16x8_t v3463 = vsubq_s16(v3217, v3219); + int16x8_t v3464 = vsubq_s16(v3221, v3223); + int16x8_t v3465 = vqrdmulhq_n_s16(v3464, 28961); + int16x8_t v3466 = vaddq_s16(v3463, v3465); + int16x8_t v3467 = vsubq_s16(v3227, v3229); + int16x8_t v3468 = vsubq_s16(v3231, v3233); + int16x8_t v3469 = vqrdmulhq_n_s16(v3468, 28961); + int16x8_t v3470 = vaddq_s16(v3467, v3469); + int16x8_t v3471 = vqrdmulhq_n_s16(v3470, 18517); + int16x8_t v3472 = vaddq_s16(v3466, v3471); + int16x8_t v3473 = vsubq_s16(v3195, v3197); + int16x8_t v3474 = vsubq_s16(v3199, v3201); + int16x8_t v3475 = vqrdmulhq_n_s16(v3474, 30044); + int16x8_t v3476 = vaddq_s16(v3473, v3475); + int16x8_t v3477 = vsubq_s16(v3205, v3207); + int16x8_t v3478 = vsubq_s16(v3209, v3211); + int16x8_t v3479 = vqrdmulhq_n_s16(v3478, 30044); + int16x8_t v3480 = vaddq_s16(v3477, v3479); + int16x8_t v3481 = vqrdmulhq_n_s16(v3480, 18639); + int16x8_t v3482 = vaddq_s16(v3476, v3481); + int16x8_t v3483 = vsubq_s16(v3173, v3175); + int16x8_t v3484 = vsubq_s16(v3177, v3179); + int16x8_t v3485 = vqrdmulhq_n_s16(v3484, 31232); + int16x8_t v3486 = vaddq_s16(v3483, v3485); + int16x8_t v3487 = vsubq_s16(v3183, v3185); + int16x8_t v3488 = vsubq_s16(v3187, v3189); + int16x8_t v3489 = vqrdmulhq_n_s16(v3488, 31232); + int16x8_t v3490 = vaddq_s16(v3487, v3489); + int16x8_t v3491 = vqrdmulhq_n_s16(v3490, 18765); + int16x8_t v3492 = vaddq_s16(v3486, v3491); + int16x8_t v3493 = vsubq_s16(v3151, v3153); + int16x8_t v3494 = vsubq_s16(v3155, v3157); + int16x8_t v3495 = vqrdmulhq_n_s16(v3494, 32538); + int16x8_t v3496 = vaddq_s16(v3493, v3495); + int16x8_t v3497 = vsubq_s16(v3161, v3163); + int16x8_t v3498 = vsubq_s16(v3165, v3167); + int16x8_t v3499 = vqrdmulhq_n_s16(v3498, 32538); + int16x8_t v3500 = vaddq_s16(v3497, v3499); + int16x8_t v3501 = vqrdmulhq_n_s16(v3500, 18896); + int16x8_t v3502 = vaddq_s16(v3496, v3501); + int16x8_t v3503 = vsubq_s16(v3129, v3131); + int16x8_t v3504 = vsubq_s16(v3133, v3135); + int16x8_t v3505_tmp = vqrdmulhq_n_s16(v3504, 1211); + int16x8_t v3505 = vaddq_s16(v3505_tmp, v3504); + int16x8_t v3506 = vaddq_s16(v3503, v3505); + int16x8_t v3507 = vsubq_s16(v3139, v3141); + int16x8_t v3508 = vsubq_s16(v3143, v3145); + int16x8_t v3509_tmp = vqrdmulhq_n_s16(v3508, 1211); + int16x8_t v3509 = vaddq_s16(v3509_tmp, v3508); + int16x8_t v3510 = vaddq_s16(v3507, v3509); + int16x8_t v3511 = vqrdmulhq_n_s16(v3510, 19032); + int16x8_t v3512 = vaddq_s16(v3506, v3511); + int16x8_t v3513 = vsubq_s16(v3107, v3109); + int16x8_t v3514 = vsubq_s16(v3111, v3113); + int16x8_t v3515_tmp = vqrdmulhq_n_s16(v3514, 2808); + int16x8_t v3515 = vaddq_s16(v3515_tmp, v3514); + int16x8_t v3516 = vaddq_s16(v3513, v3515); + int16x8_t v3517 = vsubq_s16(v3117, v3119); + int16x8_t v3518 = vsubq_s16(v3121, v3123); + int16x8_t v3519_tmp = vqrdmulhq_n_s16(v3518, 2808); + int16x8_t v3519 = vaddq_s16(v3519_tmp, v3518); + int16x8_t v3520 = vaddq_s16(v3517, v3519); + int16x8_t v3521 = vqrdmulhq_n_s16(v3520, 19172); + int16x8_t v3522 = vaddq_s16(v3516, v3521); + int16x8_t v3523 = vsubq_s16(v3085, v3087); + int16x8_t v3524 = vsubq_s16(v3089, v3091); + int16x8_t v3525_tmp = vqrdmulhq_n_s16(v3524, 4586); + int16x8_t v3525 = vaddq_s16(v3525_tmp, v3524); + int16x8_t v3526 = vaddq_s16(v3523, v3525); + int16x8_t v3527 = vsubq_s16(v3095, v3097); + int16x8_t v3528 = vsubq_s16(v3099, v3101); + int16x8_t v3529_tmp = vqrdmulhq_n_s16(v3528, 4586); + int16x8_t v3529 = vaddq_s16(v3529_tmp, v3528); + int16x8_t v3530 = vaddq_s16(v3527, v3529); + int16x8_t v3531 = vqrdmulhq_n_s16(v3530, 19318); + int16x8_t v3532 = vaddq_s16(v3526, v3531); + int16x8_t v3533 = vsubq_s16(v3063, v3065); + int16x8_t v3534 = vsubq_s16(v3067, v3069); + int16x8_t v3535_tmp = vqrdmulhq_n_s16(v3534, 6576); + int16x8_t v3535 = vaddq_s16(v3535_tmp, v3534); + int16x8_t v3536 = vaddq_s16(v3533, v3535); + int16x8_t v3537 = vsubq_s16(v3073, v3075); + int16x8_t v3538 = vsubq_s16(v3077, v3079); + int16x8_t v3539_tmp = vqrdmulhq_n_s16(v3538, 6576); + int16x8_t v3539 = vaddq_s16(v3539_tmp, v3538); + int16x8_t v3540 = vaddq_s16(v3537, v3539); + int16x8_t v3541 = vqrdmulhq_n_s16(v3540, 19469); + int16x8_t v3542 = vaddq_s16(v3536, v3541); + int16x8_t v3543 = vsubq_s16(v3041, v3043); + int16x8_t v3544 = vsubq_s16(v3045, v3047); + int16x8_t v3545_tmp = vqrdmulhq_n_s16(v3544, 8817); + int16x8_t v3545 = vaddq_s16(v3545_tmp, v3544); + int16x8_t v3546 = vaddq_s16(v3543, v3545); + int16x8_t v3547 = vsubq_s16(v3051, v3053); + int16x8_t v3548 = vsubq_s16(v3055, v3057); + int16x8_t v3549_tmp = vqrdmulhq_n_s16(v3548, 8817); + int16x8_t v3549 = vaddq_s16(v3549_tmp, v3548); + int16x8_t v3550 = vaddq_s16(v3547, v3549); + int16x8_t v3551 = vqrdmulhq_n_s16(v3550, 19625); + int16x8_t v3552 = vaddq_s16(v3546, v3551); + int16x8_t v3553 = vsubq_s16(v2998, v3003); + int16x8_t v3554 = vsubq_s16(v3008, v3013); + int16x8_t v3555_tmp = vqrdmulhq_n_s16(v3554, 11356); + int16x8_t v3555 = vaddq_s16(v3555_tmp, v3554); + int16x8_t v3556 = vaddq_s16(v3553, v3555); + int16x8_t v3557 = vsubq_s16(v3020, v3025); + int16x8_t v3558 = vsubq_s16(v3030, v3035); + int16x8_t v3559_tmp = vqrdmulhq_n_s16(v3558, 11356); + int16x8_t v3559 = vaddq_s16(v3559_tmp, v3558); + int16x8_t v3560 = vaddq_s16(v3557, v3559); + int16x8_t v3561 = vqrdmulhq_n_s16(v3560, 19786); + int16x8_t v3562 = vaddq_s16(v3556, v3561); + int16x8_t v3563 = vsubq_s16(v2952, v2957); + int16x8_t v3564 = vsubq_s16(v2962, v2967); + int16x8_t v3565_tmp = vqrdmulhq_n_s16(v3564, 14256); + int16x8_t v3565 = vaddq_s16(v3565_tmp, v3564); + int16x8_t v3566 = vaddq_s16(v3563, v3565); + int16x8_t v3567 = vsubq_s16(v2974, v2979); + int16x8_t v3568 = vsubq_s16(v2984, v2989); + int16x8_t v3569_tmp = vqrdmulhq_n_s16(v3568, 14256); + int16x8_t v3569 = vaddq_s16(v3569_tmp, v3568); + int16x8_t v3570 = vaddq_s16(v3567, v3569); + int16x8_t v3571 = vqrdmulhq_n_s16(v3570, 19954); + int16x8_t v3572 = vaddq_s16(v3566, v3571); + int16x8_t v3573 = vsubq_s16(v2906, v2911); + int16x8_t v3574 = vsubq_s16(v2916, v2921); + int16x8_t v3575_tmp = vqrdmulhq_n_s16(v3574, 17596); + int16x8_t v3575 = vaddq_s16(v3575_tmp, v3574); + int16x8_t v3576 = vaddq_s16(v3573, v3575); + int16x8_t v3577 = vsubq_s16(v2928, v2933); + int16x8_t v3578 = vsubq_s16(v2938, v2943); + int16x8_t v3579_tmp = vqrdmulhq_n_s16(v3578, 17596); + int16x8_t v3579 = vaddq_s16(v3579_tmp, v3578); + int16x8_t v3580 = vaddq_s16(v3577, v3579); + int16x8_t v3581 = vqrdmulhq_n_s16(v3580, 20127); + int16x8_t v3582 = vaddq_s16(v3576, v3581); + int16x8_t v3583 = vsubq_s16(v2860, v2865); + int16x8_t v3584 = vsubq_s16(v2870, v2875); + int16x8_t v3585_tmp = vqrdmulhq_n_s16(v3584, 21483); + int16x8_t v3585 = vaddq_s16(v3585_tmp, v3584); + int16x8_t v3586 = vaddq_s16(v3583, v3585); + int16x8_t v3587 = vsubq_s16(v2882, v2887); + int16x8_t v3588 = vsubq_s16(v2892, v2897); + int16x8_t v3589_tmp = vqrdmulhq_n_s16(v3588, 21483); + int16x8_t v3589 = vaddq_s16(v3589_tmp, v3588); + int16x8_t v3590 = vaddq_s16(v3587, v3589); + int16x8_t v3591 = vqrdmulhq_n_s16(v3590, 20306); + int16x8_t v3592 = vaddq_s16(v3586, v3591); + int16x8_t v3593 = vsubq_s16(v2814, v2819); + int16x8_t v3594 = vsubq_s16(v2824, v2829); + int16x8_t v3595_tmp = vqrdmulhq_n_s16(v3594, 26057); + int16x8_t v3595 = vaddq_s16(v3595_tmp, v3594); + int16x8_t v3596 = vaddq_s16(v3593, v3595); + int16x8_t v3597 = vsubq_s16(v2836, v2841); + int16x8_t v3598 = vsubq_s16(v2846, v2851); + int16x8_t v3599_tmp = vqrdmulhq_n_s16(v3598, 26057); + int16x8_t v3599 = vaddq_s16(v3599_tmp, v3598); + int16x8_t v3600 = vaddq_s16(v3597, v3599); + int16x8_t v3601 = vqrdmulhq_n_s16(v3600, 20492); + int16x8_t v3602 = vaddq_s16(v3596, v3601); + int16x8_t v3603 = vsubq_s16(v2768, v2773); + int16x8_t v3604 = vsubq_s16(v2778, v2783); + int16x8_t v3605_tmp = vqrdmulhq_n_s16(v3604, 31517); + int16x8_t v3605 = vaddq_s16(v3605_tmp, v3604); + int16x8_t v3606 = vaddq_s16(v3603, v3605); + int16x8_t v3607 = vsubq_s16(v2790, v2795); + int16x8_t v3608 = vsubq_s16(v2800, v2805); + int16x8_t v3609_tmp = vqrdmulhq_n_s16(v3608, 31517); + int16x8_t v3609 = vaddq_s16(v3609_tmp, v3608); + int16x8_t v3610 = vaddq_s16(v3607, v3609); + int16x8_t v3611 = vqrdmulhq_n_s16(v3610, 20684); + int16x8_t v3612 = vaddq_s16(v3606, v3611); + int16x8_t v3613 = vsubq_s16(v2722, v2727); + int16x8_t v3614 = vsubq_s16(v2732, v2737); + int16x8_t v3615_tmp = vqrdmulhq_n_s16(v3614, 5373); + int16x8_t v3615 = vmlaq_n_s16(v3615_tmp, v3614, 2); + int16x8_t v3616 = vaddq_s16(v3613, v3615); + int16x8_t v3617 = vsubq_s16(v2744, v2749); + int16x8_t v3618 = vsubq_s16(v2754, v2759); + int16x8_t v3619_tmp = vqrdmulhq_n_s16(v3618, 5373); + int16x8_t v3619 = vmlaq_n_s16(v3619_tmp, v3618, 2); + int16x8_t v3620 = vaddq_s16(v3617, v3619); + int16x8_t v3621 = vqrdmulhq_n_s16(v3620, 20883); + int16x8_t v3622 = vaddq_s16(v3616, v3621); + int16x8_t v3623 = vsubq_s16(v2676, v2681); + int16x8_t v3624 = vsubq_s16(v2686, v2691); + int16x8_t v3625_tmp = vqrdmulhq_n_s16(v3624, 13571); + int16x8_t v3625 = vmlaq_n_s16(v3625_tmp, v3624, 2); + int16x8_t v3626 = vaddq_s16(v3623, v3625); + int16x8_t v3627 = vsubq_s16(v2698, v2703); + int16x8_t v3628 = vsubq_s16(v2708, v2713); + int16x8_t v3629_tmp = vqrdmulhq_n_s16(v3628, 13571); + int16x8_t v3629 = vmlaq_n_s16(v3629_tmp, v3628, 2); + int16x8_t v3630 = vaddq_s16(v3627, v3629); + int16x8_t v3631 = vqrdmulhq_n_s16(v3630, 21089); + int16x8_t v3632 = vaddq_s16(v3626, v3631); + int16x8_t v3633 = vsubq_s16(v2588, v2599); + int16x8_t v3634 = vsubq_s16(v2610, v2621); + int16x8_t v3635_tmp = vqrdmulhq_n_s16(v3634, 23975); + int16x8_t v3635 = vmlaq_n_s16(v3635_tmp, v3634, 2); + int16x8_t v3636 = vaddq_s16(v3633, v3635); + int16x8_t v3637 = vsubq_s16(v2634, v2645); + int16x8_t v3638 = vsubq_s16(v2656, v2667); + int16x8_t v3639_tmp = vqrdmulhq_n_s16(v3638, 23975); + int16x8_t v3639 = vmlaq_n_s16(v3639_tmp, v3638, 2); + int16x8_t v3640 = vaddq_s16(v3637, v3639); + int16x8_t v3641 = vqrdmulhq_n_s16(v3640, 21303); + int16x8_t v3642 = vaddq_s16(v3636, v3641); + int16x8_t v3643 = vsubq_s16(v2494, v2505); + int16x8_t v3644 = vsubq_s16(v2516, v2527); + int16x8_t v3645_tmp = vqrdmulhq_n_s16(v3644, 4832); + int16x8_t v3645 = vmlaq_n_s16(v3645_tmp, v3644, 3); + int16x8_t v3646 = vaddq_s16(v3643, v3645); + int16x8_t v3647 = vsubq_s16(v2540, v2551); + int16x8_t v3648 = vsubq_s16(v2562, v2573); + int16x8_t v3649_tmp = vqrdmulhq_n_s16(v3648, 4832); + int16x8_t v3649 = vmlaq_n_s16(v3649_tmp, v3648, 3); + int16x8_t v3650 = vaddq_s16(v3647, v3649); + int16x8_t v3651 = vqrdmulhq_n_s16(v3650, 21524); + int16x8_t v3652 = vaddq_s16(v3646, v3651); + int16x8_t v3653 = vsubq_s16(v2399, v2410); + int16x8_t v3654 = vsubq_s16(v2421, v2432); + int16x8_t v3655_tmp = vqrdmulhq_n_s16(v3654, 23437); + int16x8_t v3655 = vmlaq_n_s16(v3655_tmp, v3654, 3); + int16x8_t v3656 = vaddq_s16(v3653, v3655); + int16x8_t v3657 = vsubq_s16(v2445, v2456); + int16x8_t v3658 = vsubq_s16(v2468, v2479); + int16x8_t v3659_tmp = vqrdmulhq_n_s16(v3658, 23437); + int16x8_t v3659 = vmlaq_n_s16(v3659_tmp, v3658, 3); + int16x8_t v3660 = vaddq_s16(v3657, v3659); + int16x8_t v3661 = vqrdmulhq_n_s16(v3660, 21753); + int16x8_t v3662 = vaddq_s16(v3656, v3661); + int16x8_t v3663 = vsubq_s16(v2305, v2316); + int16x8_t v3664 = vsubq_s16(v2327, v2338); + int16x8_t v3665_tmp = vqrdmulhq_n_s16(v3664, 17573); + int16x8_t v3665 = vmlaq_n_s16(v3665_tmp, v3664, 4); + int16x8_t v3666 = vaddq_s16(v3663, v3665); + int16x8_t v3667 = vsubq_s16(v2351, v2362); + int16x8_t v3668 = vsubq_s16(v2373, v2384); + int16x8_t v3669_tmp = vqrdmulhq_n_s16(v3668, 17573); + int16x8_t v3669 = vmlaq_n_s16(v3669_tmp, v3668, 4); + int16x8_t v3670 = vaddq_s16(v3667, v3669); + int16x8_t v3671 = vqrdmulhq_n_s16(v3670, 21990); + int16x8_t v3672 = vaddq_s16(v3666, v3671); + int16x8_t v3673 = vsubq_s16(v2127, v2150); + int16x8_t v3674 = vsubq_s16(v2173, v2196); + int16x8_t v3675_tmp = vqrdmulhq_n_s16(v3674, 27122); + int16x8_t v3675 = vmlaq_n_s16(v3675_tmp, v3674, 5); + int16x8_t v3676 = vaddq_s16(v3673, v3675); + int16x8_t v3677 = vsubq_s16(v2221, v2244); + int16x8_t v3678 = vsubq_s16(v2267, v2290); + int16x8_t v3679_tmp = vqrdmulhq_n_s16(v3678, 27122); + int16x8_t v3679 = vmlaq_n_s16(v3679_tmp, v3678, 5); + int16x8_t v3680 = vaddq_s16(v3677, v3679); + int16x8_t v3681 = vqrdmulhq_n_s16(v3680, 22236); + int16x8_t v3682 = vaddq_s16(v3676, v3681); + int16x8_t v3683 = vsubq_s16(v1937, v1960); + int16x8_t v3684 = vsubq_s16(v1983, v2006); + int16x8_t v3685_tmp = vqrdmulhq_n_s16(v3684, 5041); + int16x8_t v3685 = vmlaq_n_s16(v3685_tmp, v3684, 8); + int16x8_t v3686 = vaddq_s16(v3683, v3685); + int16x8_t v3687 = vsubq_s16(v2031, v2054); + int16x8_t v3688 = vsubq_s16(v2077, v2100); + int16x8_t v3689_tmp = vqrdmulhq_n_s16(v3688, 5041); + int16x8_t v3689 = vmlaq_n_s16(v3689_tmp, v3688, 8); + int16x8_t v3690 = vaddq_s16(v3687, v3689); + int16x8_t v3691 = vqrdmulhq_n_s16(v3690, 22491); + int16x8_t v3692 = vaddq_s16(v3686, v3691); + int16x8_t v3693 = vsubq_s16(v1579, v1626); + int16x8_t v3694 = vsubq_s16(v1673, v1720); + int16x8_t v3695_tmp = vqrdmulhq_n_s16(v3694, 19146); + int16x8_t v3695 = vmlaq_n_s16(v3695_tmp, v3694, 13); + int16x8_t v3696 = vaddq_s16(v3693, v3695); + int16x8_t v3697 = vsubq_s16(v1769, v1816); + int16x8_t v3698 = vsubq_s16(v1863, v1910); + int16x8_t v3699_tmp = vqrdmulhq_n_s16(v3698, 19146); + int16x8_t v3699 = vmlaq_n_s16(v3699_tmp, v3698, 13); + int16x8_t v3700 = vaddq_s16(v3697, v3699); + int16x8_t v3701 = vqrdmulhq_n_s16(v3700, 22755); + int16x8_t v3702 = vaddq_s16(v3696, v3701); + int16x8_t v3703 = vsubq_s16(v141, v316); + int16x8_t v3704 = vsubq_s16(v522, v698); + int16x8_t v3705_tmp = vqrdmulhq_n_s16(v3704, 24402); + int16x8_t v3705 = vmlaq_n_s16(v3705_tmp, v3704, 40); + int16x8_t v3706 = vaddq_s16(v3703, v3705); + int16x8_t v3707 = vsubq_s16(v906, v1145); + int16x8_t v3708 = vsubq_s16(v1351, v1528); + int16x8_t v3709_tmp = vqrdmulhq_n_s16(v3708, 24402); + int16x8_t v3709 = vmlaq_n_s16(v3709_tmp, v3708, 40); + int16x8_t v3710 = vaddq_s16(v3707, v3709); + int16x8_t v3711 = vqrdmulhq_n_s16(v3710, 23030); + int16x8_t v3712 = vaddq_s16(v3706, v3711); + int16x8_t v3713 = vsubq_s16(v3703, v3705); + int16x8_t v3714 = vsubq_s16(v3707, v3709); + int16x8_t v3715 = vqrdmulhq_n_s16(v3714, 23314); + int16x8_t v3716 = vaddq_s16(v3713, v3715); + int16x8_t v3717 = vsubq_s16(v3693, v3695); + int16x8_t v3718 = vsubq_s16(v3697, v3699); + int16x8_t v3719 = vqrdmulhq_n_s16(v3718, 23609); + int16x8_t v3720 = vaddq_s16(v3717, v3719); + int16x8_t v3721 = vsubq_s16(v3683, v3685); + int16x8_t v3722 = vsubq_s16(v3687, v3689); + int16x8_t v3723 = vqrdmulhq_n_s16(v3722, 23915); + int16x8_t v3724 = vaddq_s16(v3721, v3723); + int16x8_t v3725 = vsubq_s16(v3673, v3675); + int16x8_t v3726 = vsubq_s16(v3677, v3679); + int16x8_t v3727 = vqrdmulhq_n_s16(v3726, 24233); + int16x8_t v3728 = vaddq_s16(v3725, v3727); + int16x8_t v3729 = vsubq_s16(v3663, v3665); + int16x8_t v3730 = vsubq_s16(v3667, v3669); + int16x8_t v3731 = vqrdmulhq_n_s16(v3730, 24564); + int16x8_t v3732 = vaddq_s16(v3729, v3731); + int16x8_t v3733 = vsubq_s16(v3653, v3655); + int16x8_t v3734 = vsubq_s16(v3657, v3659); + int16x8_t v3735 = vqrdmulhq_n_s16(v3734, 24907); + int16x8_t v3736 = vaddq_s16(v3733, v3735); + int16x8_t v3737 = vsubq_s16(v3643, v3645); + int16x8_t v3738 = vsubq_s16(v3647, v3649); + int16x8_t v3739 = vqrdmulhq_n_s16(v3738, 25264); + int16x8_t v3740 = vaddq_s16(v3737, v3739); + int16x8_t v3741 = vsubq_s16(v3633, v3635); + int16x8_t v3742 = vsubq_s16(v3637, v3639); + int16x8_t v3743 = vqrdmulhq_n_s16(v3742, 25635); + int16x8_t v3744 = vaddq_s16(v3741, v3743); + int16x8_t v3745 = vsubq_s16(v3623, v3625); + int16x8_t v3746 = vsubq_s16(v3627, v3629); + int16x8_t v3747 = vqrdmulhq_n_s16(v3746, 26021); + int16x8_t v3748 = vaddq_s16(v3745, v3747); + int16x8_t v3749 = vsubq_s16(v3613, v3615); + int16x8_t v3750 = vsubq_s16(v3617, v3619); + int16x8_t v3751 = vqrdmulhq_n_s16(v3750, 26423); + int16x8_t v3752 = vaddq_s16(v3749, v3751); + int16x8_t v3753 = vsubq_s16(v3603, v3605); + int16x8_t v3754 = vsubq_s16(v3607, v3609); + int16x8_t v3755 = vqrdmulhq_n_s16(v3754, 26842); + int16x8_t v3756 = vaddq_s16(v3753, v3755); + int16x8_t v3757 = vsubq_s16(v3593, v3595); + int16x8_t v3758 = vsubq_s16(v3597, v3599); + int16x8_t v3759 = vqrdmulhq_n_s16(v3758, 27279); + int16x8_t v3760 = vaddq_s16(v3757, v3759); + int16x8_t v3761 = vsubq_s16(v3583, v3585); + int16x8_t v3762 = vsubq_s16(v3587, v3589); + int16x8_t v3763 = vqrdmulhq_n_s16(v3762, 27734); + int16x8_t v3764 = vaddq_s16(v3761, v3763); + int16x8_t v3765 = vsubq_s16(v3573, v3575); + int16x8_t v3766 = vsubq_s16(v3577, v3579); + int16x8_t v3767 = vqrdmulhq_n_s16(v3766, 28209); + int16x8_t v3768 = vaddq_s16(v3765, v3767); + int16x8_t v3769 = vsubq_s16(v3563, v3565); + int16x8_t v3770 = vsubq_s16(v3567, v3569); + int16x8_t v3771 = vqrdmulhq_n_s16(v3770, 28705); + int16x8_t v3772 = vaddq_s16(v3769, v3771); + int16x8_t v3773 = vsubq_s16(v3553, v3555); + int16x8_t v3774 = vsubq_s16(v3557, v3559); + int16x8_t v3775 = vqrdmulhq_n_s16(v3774, 29223); + int16x8_t v3776 = vaddq_s16(v3773, v3775); + int16x8_t v3777 = vsubq_s16(v3543, v3545); + int16x8_t v3778 = vsubq_s16(v3547, v3549); + int16x8_t v3779 = vqrdmulhq_n_s16(v3778, 29764); + int16x8_t v3780 = vaddq_s16(v3777, v3779); + int16x8_t v3781 = vsubq_s16(v3533, v3535); + int16x8_t v3782 = vsubq_s16(v3537, v3539); + int16x8_t v3783 = vqrdmulhq_n_s16(v3782, 30331); + int16x8_t v3784 = vaddq_s16(v3781, v3783); + int16x8_t v3785 = vsubq_s16(v3523, v3525); + int16x8_t v3786 = vsubq_s16(v3527, v3529); + int16x8_t v3787 = vqrdmulhq_n_s16(v3786, 30925); + int16x8_t v3788 = vaddq_s16(v3785, v3787); + int16x8_t v3789 = vsubq_s16(v3513, v3515); + int16x8_t v3790 = vsubq_s16(v3517, v3519); + int16x8_t v3791 = vqrdmulhq_n_s16(v3790, 31547); + int16x8_t v3792 = vaddq_s16(v3789, v3791); + int16x8_t v3793 = vsubq_s16(v3503, v3505); + int16x8_t v3794 = vsubq_s16(v3507, v3509); + int16x8_t v3795 = vqrdmulhq_n_s16(v3794, 32199); + int16x8_t v3796 = vaddq_s16(v3793, v3795); + int16x8_t v3797 = vsubq_s16(v3493, v3495); + int16x8_t v3798 = vsubq_s16(v3497, v3499); + int16x8_t v3799_tmp = vqrdmulhq_n_s16(v3798, 117); + int16x8_t v3799 = vaddq_s16(v3799_tmp, v3798); + int16x8_t v3800 = vaddq_s16(v3797, v3799); + int16x8_t v3801 = vsubq_s16(v3483, v3485); + int16x8_t v3802 = vsubq_s16(v3487, v3489); + int16x8_t v3803_tmp = vqrdmulhq_n_s16(v3802, 837); + int16x8_t v3803 = vaddq_s16(v3803_tmp, v3802); + int16x8_t v3804 = vaddq_s16(v3801, v3803); + int16x8_t v3805 = vsubq_s16(v3473, v3475); + int16x8_t v3806 = vsubq_s16(v3477, v3479); + int16x8_t v3807_tmp = vqrdmulhq_n_s16(v3806, 1594); + int16x8_t v3807 = vaddq_s16(v3807_tmp, v3806); + int16x8_t v3808 = vaddq_s16(v3805, v3807); + int16x8_t v3809 = vsubq_s16(v3463, v3465); + int16x8_t v3810 = vsubq_s16(v3467, v3469); + int16x8_t v3811_tmp = vqrdmulhq_n_s16(v3810, 2393); + int16x8_t v3811 = vaddq_s16(v3811_tmp, v3810); + int16x8_t v3812 = vaddq_s16(v3809, v3811); + int16x8_t v3813 = vsubq_s16(v3453, v3455); + int16x8_t v3814 = vsubq_s16(v3457, v3459); + int16x8_t v3815_tmp = vqrdmulhq_n_s16(v3814, 3234); + int16x8_t v3815 = vaddq_s16(v3815_tmp, v3814); + int16x8_t v3816 = vaddq_s16(v3813, v3815); + int16x8_t v3817 = vsubq_s16(v3443, v3445); + int16x8_t v3818 = vsubq_s16(v3447, v3449); + int16x8_t v3819_tmp = vqrdmulhq_n_s16(v3818, 4123); + int16x8_t v3819 = vaddq_s16(v3819_tmp, v3818); + int16x8_t v3820 = vaddq_s16(v3817, v3819); + int16x8_t v3821 = vsubq_s16(v3433, v3435); + int16x8_t v3822 = vsubq_s16(v3437, v3439); + int16x8_t v3823_tmp = vqrdmulhq_n_s16(v3822, 5062); + int16x8_t v3823 = vaddq_s16(v3823_tmp, v3822); + int16x8_t v3824 = vaddq_s16(v3821, v3823); + int16x8_t v3825 = vsubq_s16(v3423, v3425); + int16x8_t v3826 = vsubq_s16(v3427, v3429); + int16x8_t v3827_tmp = vqrdmulhq_n_s16(v3826, 6057); + int16x8_t v3827 = vaddq_s16(v3827_tmp, v3826); + int16x8_t v3828 = vaddq_s16(v3825, v3827); + int16x8_t v3829 = vsubq_s16(v3413, v3415); + int16x8_t v3830 = vsubq_s16(v3417, v3419); + int16x8_t v3831_tmp = vqrdmulhq_n_s16(v3830, 7111); + int16x8_t v3831 = vaddq_s16(v3831_tmp, v3830); + int16x8_t v3832 = vaddq_s16(v3829, v3831); + int16x8_t v3833 = vsubq_s16(v3403, v3405); + int16x8_t v3834 = vsubq_s16(v3407, v3409); + int16x8_t v3835_tmp = vqrdmulhq_n_s16(v3834, 8231); + int16x8_t v3835 = vaddq_s16(v3835_tmp, v3834); + int16x8_t v3836 = vaddq_s16(v3833, v3835); + int16x8_t v3837 = vsubq_s16(v3393, v3395); + int16x8_t v3838 = vsubq_s16(v3397, v3399); + int16x8_t v3839_tmp = vqrdmulhq_n_s16(v3838, 9421); + int16x8_t v3839 = vaddq_s16(v3839_tmp, v3838); + int16x8_t v3840 = vaddq_s16(v3837, v3839); + int16x8_t v3841 = vsubq_s16(v3374, v3379); + int16x8_t v3842 = vsubq_s16(v3384, v3389); + int16x8_t v3843_tmp = vqrdmulhq_n_s16(v3842, 10690); + int16x8_t v3843 = vaddq_s16(v3843_tmp, v3842); + int16x8_t v3844 = vaddq_s16(v3841, v3843); + int16x8_t v3845 = vsubq_s16(v3352, v3357); + int16x8_t v3846 = vsubq_s16(v3362, v3367); + int16x8_t v3847_tmp = vqrdmulhq_n_s16(v3846, 12044); + int16x8_t v3847 = vaddq_s16(v3847_tmp, v3846); + int16x8_t v3848 = vaddq_s16(v3845, v3847); + int16x8_t v3849 = vsubq_s16(v3330, v3335); + int16x8_t v3850 = vsubq_s16(v3340, v3345); + int16x8_t v3851_tmp = vqrdmulhq_n_s16(v3850, 13493); + int16x8_t v3851 = vaddq_s16(v3851_tmp, v3850); + int16x8_t v3852 = vaddq_s16(v3849, v3851); + int16x8_t v3853 = vsubq_s16(v3308, v3313); + int16x8_t v3854 = vsubq_s16(v3318, v3323); + int16x8_t v3855_tmp = vqrdmulhq_n_s16(v3854, 15046); + int16x8_t v3855 = vaddq_s16(v3855_tmp, v3854); + int16x8_t v3856 = vaddq_s16(v3853, v3855); + int16x8_t v3857 = vsubq_s16(v3286, v3291); + int16x8_t v3858 = vsubq_s16(v3296, v3301); + int16x8_t v3859_tmp = vqrdmulhq_n_s16(v3858, 16715); + int16x8_t v3859 = vaddq_s16(v3859_tmp, v3858); + int16x8_t v3860 = vaddq_s16(v3857, v3859); + int16x8_t v3861 = vsubq_s16(v3264, v3269); + int16x8_t v3862 = vsubq_s16(v3274, v3279); + int16x8_t v3863_tmp = vqrdmulhq_n_s16(v3862, 18512); + int16x8_t v3863 = vaddq_s16(v3863_tmp, v3862); + int16x8_t v3864 = vaddq_s16(v3861, v3863); + int16x8_t v3865 = vsubq_s16(v3242, v3247); + int16x8_t v3866 = vsubq_s16(v3252, v3257); + int16x8_t v3867_tmp = vqrdmulhq_n_s16(v3866, 20453); + int16x8_t v3867 = vaddq_s16(v3867_tmp, v3866); + int16x8_t v3868 = vaddq_s16(v3865, v3867); + int16x8_t v3869 = vsubq_s16(v3220, v3225); + int16x8_t v3870 = vsubq_s16(v3230, v3235); + int16x8_t v3871_tmp = vqrdmulhq_n_s16(v3870, 22555); + int16x8_t v3871 = vaddq_s16(v3871_tmp, v3870); + int16x8_t v3872 = vaddq_s16(v3869, v3871); + int16x8_t v3873 = vsubq_s16(v3198, v3203); + int16x8_t v3874 = vsubq_s16(v3208, v3213); + int16x8_t v3875_tmp = vqrdmulhq_n_s16(v3874, 24839); + int16x8_t v3875 = vaddq_s16(v3875_tmp, v3874); + int16x8_t v3876 = vaddq_s16(v3873, v3875); + int16x8_t v3877 = vsubq_s16(v3176, v3181); + int16x8_t v3878 = vsubq_s16(v3186, v3191); + int16x8_t v3879_tmp = vqrdmulhq_n_s16(v3878, 27330); + int16x8_t v3879 = vaddq_s16(v3879_tmp, v3878); + int16x8_t v3880 = vaddq_s16(v3877, v3879); + int16x8_t v3881 = vsubq_s16(v3154, v3159); + int16x8_t v3882 = vsubq_s16(v3164, v3169); + int16x8_t v3883_tmp = vqrdmulhq_n_s16(v3882, 30056); + int16x8_t v3883 = vaddq_s16(v3883_tmp, v3882); + int16x8_t v3884 = vaddq_s16(v3881, v3883); + int16x8_t v3885 = vsubq_s16(v3132, v3137); + int16x8_t v3886 = vsubq_s16(v3142, v3147); + int16x8_t v3887_tmp = vqrdmulhq_n_s16(v3886, 282); + int16x8_t v3887 = vmlaq_n_s16(v3887_tmp, v3886, 2); + int16x8_t v3888 = vaddq_s16(v3885, v3887); + int16x8_t v3889 = vsubq_s16(v3110, v3115); + int16x8_t v3890 = vsubq_s16(v3120, v3125); + int16x8_t v3891_tmp = vqrdmulhq_n_s16(v3890, 3588); + int16x8_t v3891 = vmlaq_n_s16(v3891_tmp, v3890, 2); + int16x8_t v3892 = vaddq_s16(v3889, v3891); + int16x8_t v3893 = vsubq_s16(v3088, v3093); + int16x8_t v3894 = vsubq_s16(v3098, v3103); + int16x8_t v3895_tmp = vqrdmulhq_n_s16(v3894, 7255); + int16x8_t v3895 = vmlaq_n_s16(v3895_tmp, v3894, 2); + int16x8_t v3896 = vaddq_s16(v3893, v3895); + int16x8_t v3897 = vsubq_s16(v3066, v3071); + int16x8_t v3898 = vsubq_s16(v3076, v3081); + int16x8_t v3899_tmp = vqrdmulhq_n_s16(v3898, 11344); + int16x8_t v3899 = vmlaq_n_s16(v3899_tmp, v3898, 2); + int16x8_t v3900 = vaddq_s16(v3897, v3899); + int16x8_t v3901 = vsubq_s16(v3044, v3049); + int16x8_t v3902 = vsubq_s16(v3054, v3059); + int16x8_t v3903_tmp = vqrdmulhq_n_s16(v3902, 15934); + int16x8_t v3903 = vmlaq_n_s16(v3903_tmp, v3902, 2); + int16x8_t v3904 = vaddq_s16(v3901, v3903); + int16x8_t v3905 = vsubq_s16(v3004, v3015); + int16x8_t v3906 = vsubq_s16(v3026, v3037); + int16x8_t v3907_tmp = vqrdmulhq_n_s16(v3906, 21120); + int16x8_t v3907 = vmlaq_n_s16(v3907_tmp, v3906, 2); + int16x8_t v3908 = vaddq_s16(v3905, v3907); + int16x8_t v3909 = vsubq_s16(v2958, v2969); + int16x8_t v3910 = vsubq_s16(v2980, v2991); + int16x8_t v3911_tmp = vqrdmulhq_n_s16(v3910, 27027); + int16x8_t v3911 = vmlaq_n_s16(v3911_tmp, v3910, 2); + int16x8_t v3912 = vaddq_s16(v3909, v3911); + int16x8_t v3913 = vsubq_s16(v2912, v2923); + int16x8_t v3914 = vsubq_s16(v2934, v2945); + int16x8_t v3915_tmp = vqrdmulhq_n_s16(v3914, 1045); + int16x8_t v3915 = vmlaq_n_s16(v3915_tmp, v3914, 3); + int16x8_t v3916 = vaddq_s16(v3913, v3915); + int16x8_t v3917 = vsubq_s16(v2866, v2877); + int16x8_t v3918 = vsubq_s16(v2888, v2899); + int16x8_t v3919_tmp = vqrdmulhq_n_s16(v3918, 8923); + int16x8_t v3919 = vmlaq_n_s16(v3919_tmp, v3918, 3); + int16x8_t v3920 = vaddq_s16(v3917, v3919); + int16x8_t v3921 = vsubq_s16(v2820, v2831); + int16x8_t v3922 = vsubq_s16(v2842, v2853); + int16x8_t v3923_tmp = vqrdmulhq_n_s16(v3922, 18177); + int16x8_t v3923 = vmlaq_n_s16(v3923_tmp, v3922, 3); + int16x8_t v3924 = vaddq_s16(v3921, v3923); + int16x8_t v3925 = vsubq_s16(v2774, v2785); + int16x8_t v3926 = vsubq_s16(v2796, v2807); + int16x8_t v3927_tmp = vqrdmulhq_n_s16(v3926, 29200); + int16x8_t v3927 = vmlaq_n_s16(v3927_tmp, v3926, 3); + int16x8_t v3928 = vaddq_s16(v3925, v3927); + int16x8_t v3929 = vsubq_s16(v2728, v2739); + int16x8_t v3930 = vsubq_s16(v2750, v2761); + int16x8_t v3931_tmp = vqrdmulhq_n_s16(v3930, 9782); + int16x8_t v3931 = vmlaq_n_s16(v3931_tmp, v3930, 4); + int16x8_t v3932 = vaddq_s16(v3929, v3931); + int16x8_t v3933 = vsubq_s16(v2682, v2693); + int16x8_t v3934 = vsubq_s16(v2704, v2715); + int16x8_t v3935_tmp = vqrdmulhq_n_s16(v3934, 26282); + int16x8_t v3935 = vmlaq_n_s16(v3935_tmp, v3934, 4); + int16x8_t v3936 = vaddq_s16(v3933, v3935); + int16x8_t v3937 = vsubq_s16(v2600, v2623); + int16x8_t v3938 = vsubq_s16(v2646, v2669); + int16x8_t v3939_tmp = vqrdmulhq_n_s16(v3938, 14423); + int16x8_t v3939 = vmlaq_n_s16(v3939_tmp, v3938, 5); + int16x8_t v3940 = vaddq_s16(v3937, v3939); + int16x8_t v3941 = vsubq_s16(v2506, v2529); + int16x8_t v3942 = vsubq_s16(v2552, v2575); + int16x8_t v3943_tmp = vqrdmulhq_n_s16(v3942, 9008); + int16x8_t v3943 = vmlaq_n_s16(v3943_tmp, v3942, 6); + int16x8_t v3944 = vaddq_s16(v3941, v3943); + int16x8_t v3945 = vsubq_s16(v2411, v2434); + int16x8_t v3946 = vsubq_s16(v2457, v2481); + int16x8_t v3947_tmp = vqrdmulhq_n_s16(v3946, 13552); + int16x8_t v3947 = vmlaq_n_s16(v3947_tmp, v3946, 7); + int16x8_t v3948 = vaddq_s16(v3945, v3947); + int16x8_t v3949 = vsubq_s16(v2317, v2340); + int16x8_t v3950 = vsubq_s16(v2363, v2386); + int16x8_t v3951_tmp = vqrdmulhq_n_s16(v3950, 1925); + int16x8_t v3951 = vmlaq_n_s16(v3951_tmp, v3950, 9); + int16x8_t v3952 = vaddq_s16(v3949, v3951); + int16x8_t v3953 = vsubq_s16(v2151, v2198); + int16x8_t v3954 = vsubq_s16(v2245, v2292); + int16x8_t v3955_tmp = vqrdmulhq_n_s16(v3954, 21123); + int16x8_t v3955 = vmlaq_n_s16(v3955_tmp, v3954, 11); + int16x8_t v3956 = vaddq_s16(v3953, v3955); + int16x8_t v3957 = vsubq_s16(v1961, v2008); + int16x8_t v3958 = vsubq_s16(v2055, v2102); + int16x8_t v3959_tmp = vqrdmulhq_n_s16(v3958, 9831); + int16x8_t v3959 = vmlaq_n_s16(v3959_tmp, v3958, 16); + int16x8_t v3960 = vaddq_s16(v3957, v3959); + int16x8_t v3961 = vsubq_s16(v1627, v1722); + int16x8_t v3962 = vsubq_s16(v1817, v1912); + int16x8_t v3963_tmp = vqrdmulhq_n_s16(v3962, 5373); + int16x8_t v3963 = vmlaq_n_s16(v3963_tmp, v3962, 27); + int16x8_t v3964 = vaddq_s16(v3961, v3963); + int16x8_t v3965 = vsubq_s16(v317, v700); + int16x8_t v3966 = vsubq_s16(v1146, v1530); + int16x8_t v3967_tmp = vqrdmulhq_n_s16(v3966, 15986); + int16x8_t v3967 = vmlaq_n_s16(v3967_tmp, v3966, 81); + int16x8_t v3968 = vaddq_s16(v3965, v3967); + int16x8_t v3969 = vsubq_s16(v3965, v3967); + int16x8_t v3970 = vsubq_s16(v3961, v3963); + int16x8_t v3971 = vsubq_s16(v3957, v3959); + int16x8_t v3972 = vsubq_s16(v3953, v3955); + int16x8_t v3973 = vsubq_s16(v3949, v3951); + int16x8_t v3974 = vsubq_s16(v3945, v3947); + int16x8_t v3975 = vsubq_s16(v3941, v3943); + int16x8_t v3976 = vsubq_s16(v3937, v3939); + int16x8_t v3977 = vsubq_s16(v3933, v3935); + int16x8_t v3978 = vsubq_s16(v3929, v3931); + int16x8_t v3979 = vsubq_s16(v3925, v3927); + int16x8_t v3980 = vsubq_s16(v3921, v3923); + int16x8_t v3981 = vsubq_s16(v3917, v3919); + int16x8_t v3982 = vsubq_s16(v3913, v3915); + int16x8_t v3983 = vsubq_s16(v3909, v3911); + int16x8_t v3984 = vsubq_s16(v3905, v3907); + int16x8_t v3985 = vsubq_s16(v3901, v3903); + int16x8_t v3986 = vsubq_s16(v3897, v3899); + int16x8_t v3987 = vsubq_s16(v3893, v3895); + int16x8_t v3988 = vsubq_s16(v3889, v3891); + int16x8_t v3989 = vsubq_s16(v3885, v3887); + int16x8_t v3990 = vsubq_s16(v3881, v3883); + int16x8_t v3991 = vsubq_s16(v3877, v3879); + int16x8_t v3992 = vsubq_s16(v3873, v3875); + int16x8_t v3993 = vsubq_s16(v3869, v3871); + int16x8_t v3994 = vsubq_s16(v3865, v3867); + int16x8_t v3995 = vsubq_s16(v3861, v3863); + int16x8_t v3996 = vsubq_s16(v3857, v3859); + int16x8_t v3997 = vsubq_s16(v3853, v3855); + int16x8_t v3998 = vsubq_s16(v3849, v3851); + int16x8_t v3999 = vsubq_s16(v3845, v3847); + int16x8_t v4000 = vsubq_s16(v3841, v3843); + int16x8_t v4001 = vsubq_s16(v3837, v3839); + int16x8_t v4002 = vsubq_s16(v3833, v3835); + int16x8_t v4003 = vsubq_s16(v3829, v3831); + int16x8_t v4004 = vsubq_s16(v3825, v3827); + int16x8_t v4005 = vsubq_s16(v3821, v3823); + int16x8_t v4006 = vsubq_s16(v3817, v3819); + int16x8_t v4007 = vsubq_s16(v3813, v3815); + int16x8_t v4008 = vsubq_s16(v3809, v3811); + int16x8_t v4009 = vsubq_s16(v3805, v3807); + int16x8_t v4010 = vsubq_s16(v3801, v3803); + int16x8_t v4011 = vsubq_s16(v3797, v3799); + int16x8_t v4012 = vsubq_s16(v3793, v3795); + int16x8_t v4013 = vsubq_s16(v3789, v3791); + int16x8_t v4014 = vsubq_s16(v3785, v3787); + int16x8_t v4015 = vsubq_s16(v3781, v3783); + int16x8_t v4016 = vsubq_s16(v3777, v3779); + int16x8_t v4017 = vsubq_s16(v3773, v3775); + int16x8_t v4018 = vsubq_s16(v3769, v3771); + int16x8_t v4019 = vsubq_s16(v3765, v3767); + int16x8_t v4020 = vsubq_s16(v3761, v3763); + int16x8_t v4021 = vsubq_s16(v3757, v3759); + int16x8_t v4022 = vsubq_s16(v3753, v3755); + int16x8_t v4023 = vsubq_s16(v3749, v3751); + int16x8_t v4024 = vsubq_s16(v3745, v3747); + int16x8_t v4025 = vsubq_s16(v3741, v3743); + int16x8_t v4026 = vsubq_s16(v3737, v3739); + int16x8_t v4027 = vsubq_s16(v3733, v3735); + int16x8_t v4028 = vsubq_s16(v3729, v3731); + int16x8_t v4029 = vsubq_s16(v3725, v3727); + int16x8_t v4030 = vsubq_s16(v3721, v3723); + int16x8_t v4031 = vsubq_s16(v3717, v3719); + int16x8_t v4032 = vsubq_s16(v3713, v3715); + int16x8_t v4033 = vsubq_s16(v3706, v3711); + int16x8_t v4034 = vsubq_s16(v3696, v3701); + int16x8_t v4035 = vsubq_s16(v3686, v3691); + int16x8_t v4036 = vsubq_s16(v3676, v3681); + int16x8_t v4037 = vsubq_s16(v3666, v3671); + int16x8_t v4038 = vsubq_s16(v3656, v3661); + int16x8_t v4039 = vsubq_s16(v3646, v3651); + int16x8_t v4040 = vsubq_s16(v3636, v3641); + int16x8_t v4041 = vsubq_s16(v3626, v3631); + int16x8_t v4042 = vsubq_s16(v3616, v3621); + int16x8_t v4043 = vsubq_s16(v3606, v3611); + int16x8_t v4044 = vsubq_s16(v3596, v3601); + int16x8_t v4045 = vsubq_s16(v3586, v3591); + int16x8_t v4046 = vsubq_s16(v3576, v3581); + int16x8_t v4047 = vsubq_s16(v3566, v3571); + int16x8_t v4048 = vsubq_s16(v3556, v3561); + int16x8_t v4049 = vsubq_s16(v3546, v3551); + int16x8_t v4050 = vsubq_s16(v3536, v3541); + int16x8_t v4051 = vsubq_s16(v3526, v3531); + int16x8_t v4052 = vsubq_s16(v3516, v3521); + int16x8_t v4053 = vsubq_s16(v3506, v3511); + int16x8_t v4054 = vsubq_s16(v3496, v3501); + int16x8_t v4055 = vsubq_s16(v3486, v3491); + int16x8_t v4056 = vsubq_s16(v3476, v3481); + int16x8_t v4057 = vsubq_s16(v3466, v3471); + int16x8_t v4058 = vsubq_s16(v3456, v3461); + int16x8_t v4059 = vsubq_s16(v3446, v3451); + int16x8_t v4060 = vsubq_s16(v3436, v3441); + int16x8_t v4061 = vsubq_s16(v3426, v3431); + int16x8_t v4062 = vsubq_s16(v3416, v3421); + int16x8_t v4063 = vsubq_s16(v3406, v3411); + int16x8_t v4064 = vsubq_s16(v3396, v3401); + int16x8_t v4065 = vsubq_s16(v3380, v3391); + int16x8_t v4066 = vsubq_s16(v3358, v3369); + int16x8_t v4067 = vsubq_s16(v3336, v3347); + int16x8_t v4068 = vsubq_s16(v3314, v3325); + int16x8_t v4069 = vsubq_s16(v3292, v3303); + int16x8_t v4070 = vsubq_s16(v3270, v3281); + int16x8_t v4071 = vsubq_s16(v3248, v3259); + int16x8_t v4072 = vsubq_s16(v3226, v3237); + int16x8_t v4073 = vsubq_s16(v3204, v3215); + int16x8_t v4074 = vsubq_s16(v3182, v3193); + int16x8_t v4075 = vsubq_s16(v3160, v3171); + int16x8_t v4076 = vsubq_s16(v3138, v3149); + int16x8_t v4077 = vsubq_s16(v3116, v3127); + int16x8_t v4078 = vsubq_s16(v3094, v3105); + int16x8_t v4079 = vsubq_s16(v3072, v3083); + int16x8_t v4080 = vsubq_s16(v3050, v3061); + int16x8_t v4081 = vsubq_s16(v3016, v3039); + int16x8_t v4082 = vsubq_s16(v2970, v2993); + int16x8_t v4083 = vsubq_s16(v2924, v2947); + int16x8_t v4084 = vsubq_s16(v2878, v2901); + int16x8_t v4085 = vsubq_s16(v2832, v2855); + int16x8_t v4086 = vsubq_s16(v2786, v2809); + int16x8_t v4087 = vsubq_s16(v2740, v2763); + int16x8_t v4088 = vsubq_s16(v2694, v2717); + int16x8_t v4089 = vsubq_s16(v2624, v2671); + int16x8_t v4090 = vsubq_s16(v2530, v2577); + int16x8_t v4091 = vsubq_s16(v2435, v2483); + int16x8_t v4092 = vsubq_s16(v2341, v2388); + int16x8_t v4093 = vsubq_s16(v2199, v2294); + int16x8_t v4094 = vsubq_s16(v2009, v2104); + int16x8_t v4095 = vsubq_s16(v1723, v1914); + int16x8_t v4096 = vsubq_s16(v701, v1532); + vst1q_s16(out + out_stride * 0 + i, v1533); + vst1q_s16(out + out_stride * 1 + i, v1915); + vst1q_s16(out + out_stride * 2 + i, v2105); + vst1q_s16(out + out_stride * 3 + i, v2295); + vst1q_s16(out + out_stride * 4 + i, v2389); + vst1q_s16(out + out_stride * 5 + i, v2484); + vst1q_s16(out + out_stride * 6 + i, v2578); + vst1q_s16(out + out_stride * 7 + i, v2672); + vst1q_s16(out + out_stride * 8 + i, v2718); + vst1q_s16(out + out_stride * 9 + i, v2764); + vst1q_s16(out + out_stride * 10 + i, v2810); + vst1q_s16(out + out_stride * 11 + i, v2856); + vst1q_s16(out + out_stride * 12 + i, v2902); + vst1q_s16(out + out_stride * 13 + i, v2948); + vst1q_s16(out + out_stride * 14 + i, v2994); + vst1q_s16(out + out_stride * 15 + i, v3040); + vst1q_s16(out + out_stride * 16 + i, v3062); + vst1q_s16(out + out_stride * 17 + i, v3084); + vst1q_s16(out + out_stride * 18 + i, v3106); + vst1q_s16(out + out_stride * 19 + i, v3128); + vst1q_s16(out + out_stride * 20 + i, v3150); + vst1q_s16(out + out_stride * 21 + i, v3172); + vst1q_s16(out + out_stride * 22 + i, v3194); + vst1q_s16(out + out_stride * 23 + i, v3216); + vst1q_s16(out + out_stride * 24 + i, v3238); + vst1q_s16(out + out_stride * 25 + i, v3260); + vst1q_s16(out + out_stride * 26 + i, v3282); + vst1q_s16(out + out_stride * 27 + i, v3304); + vst1q_s16(out + out_stride * 28 + i, v3326); + vst1q_s16(out + out_stride * 29 + i, v3348); + vst1q_s16(out + out_stride * 30 + i, v3370); + vst1q_s16(out + out_stride * 31 + i, v3392); + vst1q_s16(out + out_stride * 32 + i, v3402); + vst1q_s16(out + out_stride * 33 + i, v3412); + vst1q_s16(out + out_stride * 34 + i, v3422); + vst1q_s16(out + out_stride * 35 + i, v3432); + vst1q_s16(out + out_stride * 36 + i, v3442); + vst1q_s16(out + out_stride * 37 + i, v3452); + vst1q_s16(out + out_stride * 38 + i, v3462); + vst1q_s16(out + out_stride * 39 + i, v3472); + vst1q_s16(out + out_stride * 40 + i, v3482); + vst1q_s16(out + out_stride * 41 + i, v3492); + vst1q_s16(out + out_stride * 42 + i, v3502); + vst1q_s16(out + out_stride * 43 + i, v3512); + vst1q_s16(out + out_stride * 44 + i, v3522); + vst1q_s16(out + out_stride * 45 + i, v3532); + vst1q_s16(out + out_stride * 46 + i, v3542); + vst1q_s16(out + out_stride * 47 + i, v3552); + vst1q_s16(out + out_stride * 48 + i, v3562); + vst1q_s16(out + out_stride * 49 + i, v3572); + vst1q_s16(out + out_stride * 50 + i, v3582); + vst1q_s16(out + out_stride * 51 + i, v3592); + vst1q_s16(out + out_stride * 52 + i, v3602); + vst1q_s16(out + out_stride * 53 + i, v3612); + vst1q_s16(out + out_stride * 54 + i, v3622); + vst1q_s16(out + out_stride * 55 + i, v3632); + vst1q_s16(out + out_stride * 56 + i, v3642); + vst1q_s16(out + out_stride * 57 + i, v3652); + vst1q_s16(out + out_stride * 58 + i, v3662); + vst1q_s16(out + out_stride * 59 + i, v3672); + vst1q_s16(out + out_stride * 60 + i, v3682); + vst1q_s16(out + out_stride * 61 + i, v3692); + vst1q_s16(out + out_stride * 62 + i, v3702); + vst1q_s16(out + out_stride * 63 + i, v3712); + vst1q_s16(out + out_stride * 64 + i, v3716); + vst1q_s16(out + out_stride * 65 + i, v3720); + vst1q_s16(out + out_stride * 66 + i, v3724); + vst1q_s16(out + out_stride * 67 + i, v3728); + vst1q_s16(out + out_stride * 68 + i, v3732); + vst1q_s16(out + out_stride * 69 + i, v3736); + vst1q_s16(out + out_stride * 70 + i, v3740); + vst1q_s16(out + out_stride * 71 + i, v3744); + vst1q_s16(out + out_stride * 72 + i, v3748); + vst1q_s16(out + out_stride * 73 + i, v3752); + vst1q_s16(out + out_stride * 74 + i, v3756); + vst1q_s16(out + out_stride * 75 + i, v3760); + vst1q_s16(out + out_stride * 76 + i, v3764); + vst1q_s16(out + out_stride * 77 + i, v3768); + vst1q_s16(out + out_stride * 78 + i, v3772); + vst1q_s16(out + out_stride * 79 + i, v3776); + vst1q_s16(out + out_stride * 80 + i, v3780); + vst1q_s16(out + out_stride * 81 + i, v3784); + vst1q_s16(out + out_stride * 82 + i, v3788); + vst1q_s16(out + out_stride * 83 + i, v3792); + vst1q_s16(out + out_stride * 84 + i, v3796); + vst1q_s16(out + out_stride * 85 + i, v3800); + vst1q_s16(out + out_stride * 86 + i, v3804); + vst1q_s16(out + out_stride * 87 + i, v3808); + vst1q_s16(out + out_stride * 88 + i, v3812); + vst1q_s16(out + out_stride * 89 + i, v3816); + vst1q_s16(out + out_stride * 90 + i, v3820); + vst1q_s16(out + out_stride * 91 + i, v3824); + vst1q_s16(out + out_stride * 92 + i, v3828); + vst1q_s16(out + out_stride * 93 + i, v3832); + vst1q_s16(out + out_stride * 94 + i, v3836); + vst1q_s16(out + out_stride * 95 + i, v3840); + vst1q_s16(out + out_stride * 96 + i, v3844); + vst1q_s16(out + out_stride * 97 + i, v3848); + vst1q_s16(out + out_stride * 98 + i, v3852); + vst1q_s16(out + out_stride * 99 + i, v3856); + vst1q_s16(out + out_stride * 100 + i, v3860); + vst1q_s16(out + out_stride * 101 + i, v3864); + vst1q_s16(out + out_stride * 102 + i, v3868); + vst1q_s16(out + out_stride * 103 + i, v3872); + vst1q_s16(out + out_stride * 104 + i, v3876); + vst1q_s16(out + out_stride * 105 + i, v3880); + vst1q_s16(out + out_stride * 106 + i, v3884); + vst1q_s16(out + out_stride * 107 + i, v3888); + vst1q_s16(out + out_stride * 108 + i, v3892); + vst1q_s16(out + out_stride * 109 + i, v3896); + vst1q_s16(out + out_stride * 110 + i, v3900); + vst1q_s16(out + out_stride * 111 + i, v3904); + vst1q_s16(out + out_stride * 112 + i, v3908); + vst1q_s16(out + out_stride * 113 + i, v3912); + vst1q_s16(out + out_stride * 114 + i, v3916); + vst1q_s16(out + out_stride * 115 + i, v3920); + vst1q_s16(out + out_stride * 116 + i, v3924); + vst1q_s16(out + out_stride * 117 + i, v3928); + vst1q_s16(out + out_stride * 118 + i, v3932); + vst1q_s16(out + out_stride * 119 + i, v3936); + vst1q_s16(out + out_stride * 120 + i, v3940); + vst1q_s16(out + out_stride * 121 + i, v3944); + vst1q_s16(out + out_stride * 122 + i, v3948); + vst1q_s16(out + out_stride * 123 + i, v3952); + vst1q_s16(out + out_stride * 124 + i, v3956); + vst1q_s16(out + out_stride * 125 + i, v3960); + vst1q_s16(out + out_stride * 126 + i, v3964); + vst1q_s16(out + out_stride * 127 + i, v3968); + vst1q_s16(out + out_stride * 128 + i, v3969); + vst1q_s16(out + out_stride * 129 + i, v3970); + vst1q_s16(out + out_stride * 130 + i, v3971); + vst1q_s16(out + out_stride * 131 + i, v3972); + vst1q_s16(out + out_stride * 132 + i, v3973); + vst1q_s16(out + out_stride * 133 + i, v3974); + vst1q_s16(out + out_stride * 134 + i, v3975); + vst1q_s16(out + out_stride * 135 + i, v3976); + vst1q_s16(out + out_stride * 136 + i, v3977); + vst1q_s16(out + out_stride * 137 + i, v3978); + vst1q_s16(out + out_stride * 138 + i, v3979); + vst1q_s16(out + out_stride * 139 + i, v3980); + vst1q_s16(out + out_stride * 140 + i, v3981); + vst1q_s16(out + out_stride * 141 + i, v3982); + vst1q_s16(out + out_stride * 142 + i, v3983); + vst1q_s16(out + out_stride * 143 + i, v3984); + vst1q_s16(out + out_stride * 144 + i, v3985); + vst1q_s16(out + out_stride * 145 + i, v3986); + vst1q_s16(out + out_stride * 146 + i, v3987); + vst1q_s16(out + out_stride * 147 + i, v3988); + vst1q_s16(out + out_stride * 148 + i, v3989); + vst1q_s16(out + out_stride * 149 + i, v3990); + vst1q_s16(out + out_stride * 150 + i, v3991); + vst1q_s16(out + out_stride * 151 + i, v3992); + vst1q_s16(out + out_stride * 152 + i, v3993); + vst1q_s16(out + out_stride * 153 + i, v3994); + vst1q_s16(out + out_stride * 154 + i, v3995); + vst1q_s16(out + out_stride * 155 + i, v3996); + vst1q_s16(out + out_stride * 156 + i, v3997); + vst1q_s16(out + out_stride * 157 + i, v3998); + vst1q_s16(out + out_stride * 158 + i, v3999); + vst1q_s16(out + out_stride * 159 + i, v4000); + vst1q_s16(out + out_stride * 160 + i, v4001); + vst1q_s16(out + out_stride * 161 + i, v4002); + vst1q_s16(out + out_stride * 162 + i, v4003); + vst1q_s16(out + out_stride * 163 + i, v4004); + vst1q_s16(out + out_stride * 164 + i, v4005); + vst1q_s16(out + out_stride * 165 + i, v4006); + vst1q_s16(out + out_stride * 166 + i, v4007); + vst1q_s16(out + out_stride * 167 + i, v4008); + vst1q_s16(out + out_stride * 168 + i, v4009); + vst1q_s16(out + out_stride * 169 + i, v4010); + vst1q_s16(out + out_stride * 170 + i, v4011); + vst1q_s16(out + out_stride * 171 + i, v4012); + vst1q_s16(out + out_stride * 172 + i, v4013); + vst1q_s16(out + out_stride * 173 + i, v4014); + vst1q_s16(out + out_stride * 174 + i, v4015); + vst1q_s16(out + out_stride * 175 + i, v4016); + vst1q_s16(out + out_stride * 176 + i, v4017); + vst1q_s16(out + out_stride * 177 + i, v4018); + vst1q_s16(out + out_stride * 178 + i, v4019); + vst1q_s16(out + out_stride * 179 + i, v4020); + vst1q_s16(out + out_stride * 180 + i, v4021); + vst1q_s16(out + out_stride * 181 + i, v4022); + vst1q_s16(out + out_stride * 182 + i, v4023); + vst1q_s16(out + out_stride * 183 + i, v4024); + vst1q_s16(out + out_stride * 184 + i, v4025); + vst1q_s16(out + out_stride * 185 + i, v4026); + vst1q_s16(out + out_stride * 186 + i, v4027); + vst1q_s16(out + out_stride * 187 + i, v4028); + vst1q_s16(out + out_stride * 188 + i, v4029); + vst1q_s16(out + out_stride * 189 + i, v4030); + vst1q_s16(out + out_stride * 190 + i, v4031); + vst1q_s16(out + out_stride * 191 + i, v4032); + vst1q_s16(out + out_stride * 192 + i, v4033); + vst1q_s16(out + out_stride * 193 + i, v4034); + vst1q_s16(out + out_stride * 194 + i, v4035); + vst1q_s16(out + out_stride * 195 + i, v4036); + vst1q_s16(out + out_stride * 196 + i, v4037); + vst1q_s16(out + out_stride * 197 + i, v4038); + vst1q_s16(out + out_stride * 198 + i, v4039); + vst1q_s16(out + out_stride * 199 + i, v4040); + vst1q_s16(out + out_stride * 200 + i, v4041); + vst1q_s16(out + out_stride * 201 + i, v4042); + vst1q_s16(out + out_stride * 202 + i, v4043); + vst1q_s16(out + out_stride * 203 + i, v4044); + vst1q_s16(out + out_stride * 204 + i, v4045); + vst1q_s16(out + out_stride * 205 + i, v4046); + vst1q_s16(out + out_stride * 206 + i, v4047); + vst1q_s16(out + out_stride * 207 + i, v4048); + vst1q_s16(out + out_stride * 208 + i, v4049); + vst1q_s16(out + out_stride * 209 + i, v4050); + vst1q_s16(out + out_stride * 210 + i, v4051); + vst1q_s16(out + out_stride * 211 + i, v4052); + vst1q_s16(out + out_stride * 212 + i, v4053); + vst1q_s16(out + out_stride * 213 + i, v4054); + vst1q_s16(out + out_stride * 214 + i, v4055); + vst1q_s16(out + out_stride * 215 + i, v4056); + vst1q_s16(out + out_stride * 216 + i, v4057); + vst1q_s16(out + out_stride * 217 + i, v4058); + vst1q_s16(out + out_stride * 218 + i, v4059); + vst1q_s16(out + out_stride * 219 + i, v4060); + vst1q_s16(out + out_stride * 220 + i, v4061); + vst1q_s16(out + out_stride * 221 + i, v4062); + vst1q_s16(out + out_stride * 222 + i, v4063); + vst1q_s16(out + out_stride * 223 + i, v4064); + vst1q_s16(out + out_stride * 224 + i, v4065); + vst1q_s16(out + out_stride * 225 + i, v4066); + vst1q_s16(out + out_stride * 226 + i, v4067); + vst1q_s16(out + out_stride * 227 + i, v4068); + vst1q_s16(out + out_stride * 228 + i, v4069); + vst1q_s16(out + out_stride * 229 + i, v4070); + vst1q_s16(out + out_stride * 230 + i, v4071); + vst1q_s16(out + out_stride * 231 + i, v4072); + vst1q_s16(out + out_stride * 232 + i, v4073); + vst1q_s16(out + out_stride * 233 + i, v4074); + vst1q_s16(out + out_stride * 234 + i, v4075); + vst1q_s16(out + out_stride * 235 + i, v4076); + vst1q_s16(out + out_stride * 236 + i, v4077); + vst1q_s16(out + out_stride * 237 + i, v4078); + vst1q_s16(out + out_stride * 238 + i, v4079); + vst1q_s16(out + out_stride * 239 + i, v4080); + vst1q_s16(out + out_stride * 240 + i, v4081); + vst1q_s16(out + out_stride * 241 + i, v4082); + vst1q_s16(out + out_stride * 242 + i, v4083); + vst1q_s16(out + out_stride * 243 + i, v4084); + vst1q_s16(out + out_stride * 244 + i, v4085); + vst1q_s16(out + out_stride * 245 + i, v4086); + vst1q_s16(out + out_stride * 246 + i, v4087); + vst1q_s16(out + out_stride * 247 + i, v4088); + vst1q_s16(out + out_stride * 248 + i, v4089); + vst1q_s16(out + out_stride * 249 + i, v4090); + vst1q_s16(out + out_stride * 250 + i, v4091); + vst1q_s16(out + out_stride * 251 + i, v4092); + vst1q_s16(out + out_stride * 252 + i, v4093); + vst1q_s16(out + out_stride * 253 + i, v4094); + vst1q_s16(out + out_stride * 254 + i, v4095); + vst1q_s16(out + out_stride * 255 + i, v4096); + } +} diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct32-inl.h b/third_party/jpeg-xl/lib/jxl/fast_dct32-inl.h new file mode 100644 index 0000000000..0f3b31cfea --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct32-inl.h @@ -0,0 +1,419 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<32>) { return 1; } + +void FastIDCT(FastDCTTag<32>, const int16_t* in, size_t in_stride, int16_t* out, + size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 16 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 24 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 20 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vld1q_s16(in + in_stride * 28 + i); + int16x8_t v17 = vaddq_s16(v16, v12); + int16x8_t v18 = vaddq_s16(v13, v10); + int16x8_t v19 = vaddq_s16(v17, v18); + int16x8_t v20 = vqrdmulhq_n_s16(v19, 17734); + int16x8_t v21 = vqrdmulhq_n_s16(v18, 25080); + int16x8_t v22 = vaddq_s16(v20, v21); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v27_tmp = vqrdmulhq_n_s16(v26, 13573); + int16x8_t v27 = vaddq_s16(v27_tmp, v26); + int16x8_t v28 = vld1q_s16(in + in_stride * 18 + i); + int16x8_t v29 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v30 = vaddq_s16(v28, v29); + int16x8_t v31 = vaddq_s16(v27, v30); + int16x8_t v32 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v33 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v34 = vaddq_s16(v32, v33); + int16x8_t v35 = vqrdmulhq_n_s16(v34, 25080); + int16x8_t v36 = vld1q_s16(in + in_stride * 26 + i); + int16x8_t v37 = vld1q_s16(in + in_stride * 22 + i); + int16x8_t v38 = vaddq_s16(v36, v37); + int16x8_t v39 = vaddq_s16(v38, v34); + int16x8_t v40 = vqrdmulhq_n_s16(v39, 17734); + int16x8_t v41 = vaddq_s16(v35, v40); + int16x8_t v42 = vaddq_s16(v31, v41); + int16x8_t v43 = vaddq_s16(v33, v26); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v43, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v43); + int16x8_t v45 = vaddq_s16(v29, v32); + int16x8_t v46 = vaddq_s16(v37, v28); + int16x8_t v47 = vaddq_s16(v45, v46); + int16x8_t v48 = vaddq_s16(v44, v47); + int16x8_t v49 = vqrdmulhq_n_s16(v48, 16705); + int16x8_t v50 = vld1q_s16(in + in_stride * 30 + i); + int16x8_t v51 = vaddq_s16(v50, v36); + int16x8_t v52 = vaddq_s16(v51, v46); + int16x8_t v53 = vqrdmulhq_n_s16(v52, 17734); + int16x8_t v54 = vaddq_s16(v45, v43); + int16x8_t v55_tmp = vqrdmulhq_n_s16(v54, 10045); + int16x8_t v55 = vaddq_s16(v55_tmp, v54); + int16x8_t v56 = vaddq_s16(v53, v55); + int16x8_t v57 = vqrdmulhq_n_s16(v56, 16705); + int16x8_t v58 = vaddq_s16(v49, v57); + int16x8_t v59 = vaddq_s16(v42, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v63 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v64 = vaddq_s16(v62, v63); + int16x8_t v65 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v66 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v67 = vaddq_s16(v65, v66); + int16x8_t v68 = vaddq_s16(v64, v67); + int16x8_t v69_tmp = vqrdmulhq_n_s16(v68, 10045); + int16x8_t v69 = vaddq_s16(v69_tmp, v68); + int16x8_t v70 = vld1q_s16(in + in_stride * 21 + i); + int16x8_t v71 = vld1q_s16(in + in_stride * 19 + i); + int16x8_t v72 = vaddq_s16(v70, v71); + int16x8_t v73 = vld1q_s16(in + in_stride * 29 + i); + int16x8_t v74 = vld1q_s16(in + in_stride * 27 + i); + int16x8_t v75 = vaddq_s16(v73, v74); + int16x8_t v76 = vaddq_s16(v72, v75); + int16x8_t v77 = vqrdmulhq_n_s16(v76, 17734); + int16x8_t v78 = vaddq_s16(v69, v77); + int16x8_t v79 = vqrdmulhq_n_s16(v78, 16705); + int16x8_t v80_tmp = vqrdmulhq_n_s16(v67, 13573); + int16x8_t v80 = vaddq_s16(v80_tmp, v67); + int16x8_t v81 = vaddq_s16(v64, v72); + int16x8_t v82 = vaddq_s16(v80, v81); + int16x8_t v83 = vqrdmulhq_n_s16(v82, 16705); + int16x8_t v84 = vaddq_s16(v79, v83); + int16x8_t v85 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v86_tmp = vqrdmulhq_n_s16(v85, 13573); + int16x8_t v86 = vaddq_s16(v86_tmp, v85); + int16x8_t v87 = vld1q_s16(in + in_stride * 17 + i); + int16x8_t v88 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v89 = vaddq_s16(v87, v88); + int16x8_t v90 = vaddq_s16(v86, v89); + int16x8_t v91 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v92 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v93 = vaddq_s16(v91, v92); + int16x8_t v94 = vqrdmulhq_n_s16(v93, 25080); + int16x8_t v95 = vld1q_s16(in + in_stride * 25 + i); + int16x8_t v96 = vld1q_s16(in + in_stride * 23 + i); + int16x8_t v97 = vaddq_s16(v95, v96); + int16x8_t v98 = vaddq_s16(v97, v93); + int16x8_t v99 = vqrdmulhq_n_s16(v98, 17734); + int16x8_t v100 = vaddq_s16(v94, v99); + int16x8_t v101 = vaddq_s16(v90, v100); + int16x8_t v102 = vaddq_s16(v84, v101); + int16x8_t v103 = vaddq_s16(v92, v65); + int16x8_t v104 = vaddq_s16(v66, v85); + int16x8_t v105 = vaddq_s16(v103, v104); + int16x8_t v106_tmp = vqrdmulhq_n_s16(v105, 13573); + int16x8_t v106 = vaddq_s16(v106_tmp, v105); + int16x8_t v107 = vaddq_s16(v96, v70); + int16x8_t v108 = vaddq_s16(v71, v87); + int16x8_t v109 = vaddq_s16(v107, v108); + int16x8_t v110 = vaddq_s16(v63, v91); + int16x8_t v111 = vaddq_s16(v88, v62); + int16x8_t v112 = vaddq_s16(v110, v111); + int16x8_t v113 = vaddq_s16(v109, v112); + int16x8_t v114 = vaddq_s16(v106, v113); + int16x8_t v115 = vqrdmulhq_n_s16(v114, 16705); + int16x8_t v116 = vaddq_s16(v112, v105); + int16x8_t v117 = vqrdmulhq_n_s16(v116, 25080); + int16x8_t v118 = vqrdmulhq_n_s16(v116, 17734); + int16x8_t v119 = vaddq_s16(v74, v95); + int16x8_t v120 = vld1q_s16(in + in_stride * 31 + i); + int16x8_t v121 = vaddq_s16(v120, v73); + int16x8_t v122 = vaddq_s16(v119, v121); + int16x8_t v123 = vaddq_s16(v122, v109); + int16x8_t v124 = vqrdmulhq_n_s16(v123, 17734); + int16x8_t v125 = vaddq_s16(v118, v124); + int16x8_t v126 = vaddq_s16(v117, v125); + int16x8_t v127 = vqrdmulhq_n_s16(v126, 16705); + int16x8_t v128 = vaddq_s16(v115, v127); + int16x8_t v129 = vqrdmulhq_n_s16(v128, 16463); + int16x8_t v130_tmp = vqrdmulhq_n_s16(v104, 13573); + int16x8_t v130 = vaddq_s16(v130_tmp, v104); + int16x8_t v131 = vaddq_s16(v108, v111); + int16x8_t v132 = vaddq_s16(v130, v131); + int16x8_t v133 = vaddq_s16(v119, v107); + int16x8_t v134 = vqrdmulhq_n_s16(v133, 17734); + int16x8_t v135 = vaddq_s16(v110, v103); + int16x8_t v136_tmp = vqrdmulhq_n_s16(v135, 10045); + int16x8_t v136 = vaddq_s16(v136_tmp, v135); + int16x8_t v137 = vaddq_s16(v134, v136); + int16x8_t v138 = vaddq_s16(v132, v137); + int16x8_t v139 = vqrdmulhq_n_s16(v138, 16463); + int16x8_t v140 = vaddq_s16(v129, v139); + int16x8_t v141 = vaddq_s16(v102, v140); + int16x8_t v142 = vqrdmulhq_n_s16(v141, 16404); + int16x8_t v143 = vaddq_s16(v61, v142); + int16x8_t v144 = vsubq_s16(v0, v1); + int16x8_t v145 = vsubq_s16(v4, v6); + int16x8_t v146_tmp = vqrdmulhq_n_s16(v145, 10045); + int16x8_t v146 = vaddq_s16(v146_tmp, v145); + int16x8_t v147 = vaddq_s16(v144, v146); + int16x8_t v148 = vsubq_s16(v11, v14); + int16x8_t v149 = vqrdmulhq_n_s16(v18, 17734); + int16x8_t v150_tmp = vqrdmulhq_n_s16(v17, 10045); + int16x8_t v150 = vaddq_s16(v150_tmp, v17); + int16x8_t v151 = vsubq_s16(v149, v150); + int16x8_t v152 = vaddq_s16(v148, v151); + int16x8_t v153 = vqrdmulhq_n_s16(v152, 19705); + int16x8_t v154 = vaddq_s16(v147, v153); + int16x8_t v155 = vsubq_s16(v27, v30); + int16x8_t v156 = vqrdmulhq_n_s16(v34, 17734); + int16x8_t v157_tmp = vqrdmulhq_n_s16(v38, 10045); + int16x8_t v157 = vaddq_s16(v157_tmp, v38); + int16x8_t v158 = vsubq_s16(v156, v157); + int16x8_t v159 = vaddq_s16(v155, v158); + int16x8_t v160 = vqrdmulhq_n_s16(v54, 13573); + int16x8_t v161 = vsubq_s16(v160, v52); + int16x8_t v162 = vqrdmulhq_n_s16(v161, 25746); + int16x8_t v163 = vsubq_s16(v44, v47); + int16x8_t v164 = vqrdmulhq_n_s16(v163, 19705); + int16x8_t v165 = vaddq_s16(v162, v164); + int16x8_t v166 = vaddq_s16(v159, v165); + int16x8_t v167 = vqrdmulhq_n_s16(v166, 17121); + int16x8_t v168 = vaddq_s16(v154, v167); + int16x8_t v169 = vsubq_s16(v86, v89); + int16x8_t v170 = vqrdmulhq_n_s16(v93, 17734); + int16x8_t v171_tmp = vqrdmulhq_n_s16(v97, 10045); + int16x8_t v171 = vaddq_s16(v171_tmp, v97); + int16x8_t v172 = vsubq_s16(v170, v171); + int16x8_t v173 = vaddq_s16(v169, v172); + int16x8_t v174 = vsubq_s16(v80, v81); + int16x8_t v175 = vqrdmulhq_n_s16(v174, 19705); + int16x8_t v176 = vqrdmulhq_n_s16(v68, 13573); + int16x8_t v177 = vsubq_s16(v176, v76); + int16x8_t v178 = vqrdmulhq_n_s16(v177, 25746); + int16x8_t v179 = vaddq_s16(v175, v178); + int16x8_t v180 = vaddq_s16(v173, v179); + int16x8_t v181 = vsubq_s16(v130, v131); + int16x8_t v182 = vqrdmulhq_n_s16(v135, 13573); + int16x8_t v183 = vsubq_s16(v182, v133); + int16x8_t v184_tmp = vqrdmulhq_n_s16(v183, 10045); + int16x8_t v184 = vaddq_s16(v184_tmp, v183); + int16x8_t v185 = vaddq_s16(v181, v184); + int16x8_t v186 = vqrdmulhq_n_s16(v185, 17121); + int16x8_t v187 = vqrdmulhq_n_s16(v105, 27867); + int16x8_t v188 = vqrdmulhq_n_s16(v113, 19705); + int16x8_t v189 = vsubq_s16(v187, v188); + int16x8_t v190 = vqrdmulhq_n_s16(v116, 13573); + int16x8_t v191 = vsubq_s16(v190, v123); + int16x8_t v192 = vqrdmulhq_n_s16(v191, 25746); + int16x8_t v193 = vaddq_s16(v189, v192); + int16x8_t v194 = vqrdmulhq_n_s16(v193, 17121); + int16x8_t v195 = vaddq_s16(v186, v194); + int16x8_t v196 = vaddq_s16(v180, v195); + int16x8_t v197 = vqrdmulhq_n_s16(v196, 16563); + int16x8_t v198 = vaddq_s16(v168, v197); + int16x8_t v199 = vsubq_s16(v144, v146); + int16x8_t v200 = vsubq_s16(v148, v151); + int16x8_t v201 = vqrdmulhq_n_s16(v200, 29490); + int16x8_t v202 = vaddq_s16(v199, v201); + int16x8_t v203 = vsubq_s16(v155, v158); + int16x8_t v204 = vqrdmulhq_n_s16(v163, 29490); + int16x8_t v205_tmp = vqrdmulhq_n_s16(v161, 5763); + int16x8_t v205 = vaddq_s16(v205_tmp, v161); + int16x8_t v206 = vsubq_s16(v204, v205); + int16x8_t v207 = vaddq_s16(v203, v206); + int16x8_t v208 = vqrdmulhq_n_s16(v207, 18578); + int16x8_t v209 = vaddq_s16(v202, v208); + int16x8_t v210 = vsubq_s16(v169, v172); + int16x8_t v211 = vqrdmulhq_n_s16(v174, 29490); + int16x8_t v212_tmp = vqrdmulhq_n_s16(v177, 5763); + int16x8_t v212 = vaddq_s16(v212_tmp, v177); + int16x8_t v213 = vsubq_s16(v211, v212); + int16x8_t v214 = vaddq_s16(v210, v213); + int16x8_t v215 = vsubq_s16(v181, v184); + int16x8_t v216 = vqrdmulhq_n_s16(v215, 18578); + int16x8_t v217 = vqrdmulhq_n_s16(v189, 27803); + int16x8_t v218 = vqrdmulhq_n_s16(v191, 21845); + int16x8_t v219 = vsubq_s16(v217, v218); + int16x8_t v220 = vaddq_s16(v216, v219); + int16x8_t v221 = vaddq_s16(v214, v220); + int16x8_t v222 = vqrdmulhq_n_s16(v221, 16890); + int16x8_t v223 = vaddq_s16(v209, v222); + int16x8_t v224 = vsubq_s16(v2, v8); + int16x8_t v225 = vsubq_s16(v15, v22); + int16x8_t v226_tmp = vqrdmulhq_n_s16(v225, 18446); + int16x8_t v226 = vmlaq_n_s16(v226_tmp, v225, 2); + int16x8_t v227 = vaddq_s16(v224, v226); + int16x8_t v228 = vsubq_s16(v31, v41); + int16x8_t v229 = vsubq_s16(v48, v56); + int16x8_t v230_tmp = vqrdmulhq_n_s16(v229, 18446); + int16x8_t v230 = vmlaq_n_s16(v230_tmp, v229, 2); + int16x8_t v231 = vaddq_s16(v228, v230); + int16x8_t v232 = vqrdmulhq_n_s16(v231, 21195); + int16x8_t v233 = vaddq_s16(v227, v232); + int16x8_t v234 = vsubq_s16(v82, v78); + int16x8_t v235_tmp = vqrdmulhq_n_s16(v234, 18446); + int16x8_t v235 = vmlaq_n_s16(v235_tmp, v234, 2); + int16x8_t v236 = vsubq_s16(v90, v100); + int16x8_t v237 = vaddq_s16(v235, v236); + int16x8_t v238 = vsubq_s16(v132, v137); + int16x8_t v239 = vsubq_s16(v114, v126); + int16x8_t v240_tmp = vqrdmulhq_n_s16(v239, 18446); + int16x8_t v240 = vmlaq_n_s16(v240_tmp, v239, 2); + int16x8_t v241 = vaddq_s16(v238, v240); + int16x8_t v242 = vqrdmulhq_n_s16(v241, 21195); + int16x8_t v243 = vaddq_s16(v237, v242); + int16x8_t v244 = vqrdmulhq_n_s16(v243, 17401); + int16x8_t v245 = vaddq_s16(v233, v244); + int16x8_t v246 = vsubq_s16(v228, v230); + int16x8_t v247 = vqrdmulhq_n_s16(v246, 25826); + int16x8_t v248 = vsubq_s16(v224, v226); + int16x8_t v249 = vaddq_s16(v247, v248); + int16x8_t v250 = vsubq_s16(v238, v240); + int16x8_t v251 = vqrdmulhq_n_s16(v250, 25826); + int16x8_t v252 = vsubq_s16(v236, v235); + int16x8_t v253 = vaddq_s16(v251, v252); + int16x8_t v254 = vqrdmulhq_n_s16(v253, 18124); + int16x8_t v255 = vaddq_s16(v249, v254); + int16x8_t v256 = vsubq_s16(v199, v201); + int16x8_t v257 = vsubq_s16(v203, v206); + int16x8_t v258_tmp = vqrdmulhq_n_s16(v257, 1988); + int16x8_t v258 = vaddq_s16(v258_tmp, v257); + int16x8_t v259 = vaddq_s16(v256, v258); + int16x8_t v260 = vsubq_s16(v210, v213); + int16x8_t v261_tmp = vqrdmulhq_n_s16(v219, 25030); + int16x8_t v261 = vaddq_s16(v261_tmp, v219); + int16x8_t v262 = vsubq_s16(v215, v261); + int16x8_t v263_tmp = vqrdmulhq_n_s16(v262, 1988); + int16x8_t v263 = vaddq_s16(v263_tmp, v262); + int16x8_t v264 = vaddq_s16(v260, v263); + int16x8_t v265 = vqrdmulhq_n_s16(v264, 19102); + int16x8_t v266 = vaddq_s16(v259, v265); + int16x8_t v267 = vsubq_s16(v147, v153); + int16x8_t v268 = vsubq_s16(v159, v165); + int16x8_t v269_tmp = vqrdmulhq_n_s16(v268, 23673); + int16x8_t v269 = vaddq_s16(v269_tmp, v268); + int16x8_t v270 = vaddq_s16(v267, v269); + int16x8_t v271 = vsubq_s16(v173, v179); + int16x8_t v272 = vsubq_s16(v185, v193); + int16x8_t v273_tmp = vqrdmulhq_n_s16(v272, 23673); + int16x8_t v273 = vaddq_s16(v273_tmp, v272); + int16x8_t v274 = vaddq_s16(v271, v273); + int16x8_t v275 = vqrdmulhq_n_s16(v274, 20398); + int16x8_t v276 = vaddq_s16(v270, v275); + int16x8_t v277 = vsubq_s16(v9, v24); + int16x8_t v278 = vsubq_s16(v42, v58); + int16x8_t v279_tmp = vqrdmulhq_n_s16(v278, 3314); + int16x8_t v279 = vmlaq_n_s16(v279_tmp, v278, 5); + int16x8_t v280 = vaddq_s16(v277, v279); + int16x8_t v281 = vsubq_s16(v138, v128); + int16x8_t v282_tmp = vqrdmulhq_n_s16(v281, 3314); + int16x8_t v282 = vmlaq_n_s16(v282_tmp, v281, 5); + int16x8_t v283 = vsubq_s16(v101, v84); + int16x8_t v284 = vaddq_s16(v282, v283); + int16x8_t v285 = vqrdmulhq_n_s16(v284, 22112); + int16x8_t v286 = vaddq_s16(v280, v285); + int16x8_t v287 = vsubq_s16(v277, v279); + int16x8_t v288 = vsubq_s16(v283, v282); + int16x8_t v289 = vqrdmulhq_n_s16(v288, 24397); + int16x8_t v290 = vaddq_s16(v287, v289); + int16x8_t v291 = vsubq_s16(v267, v269); + int16x8_t v292 = vsubq_s16(v271, v273); + int16x8_t v293 = vqrdmulhq_n_s16(v292, 27504); + int16x8_t v294 = vaddq_s16(v291, v293); + int16x8_t v295 = vsubq_s16(v260, v263); + int16x8_t v296 = vqrdmulhq_n_s16(v295, 31869); + int16x8_t v297 = vsubq_s16(v256, v258); + int16x8_t v298 = vaddq_s16(v296, v297); + int16x8_t v299 = vsubq_s16(v248, v247); + int16x8_t v300 = vsubq_s16(v252, v251); + int16x8_t v301_tmp = vqrdmulhq_n_s16(v300, 5552); + int16x8_t v301 = vaddq_s16(v301_tmp, v300); + int16x8_t v302 = vaddq_s16(v299, v301); + int16x8_t v303 = vsubq_s16(v227, v232); + int16x8_t v304 = vsubq_s16(v237, v242); + int16x8_t v305_tmp = vqrdmulhq_n_s16(v304, 15865); + int16x8_t v305 = vaddq_s16(v305_tmp, v304); + int16x8_t v306 = vaddq_s16(v303, v305); + int16x8_t v307 = vsubq_s16(v202, v208); + int16x8_t v308 = vsubq_s16(v214, v220); + int16x8_t v309_tmp = vqrdmulhq_n_s16(v308, 1893); + int16x8_t v309 = vmlaq_n_s16(v309_tmp, v308, 2); + int16x8_t v310 = vaddq_s16(v307, v309); + int16x8_t v311 = vsubq_s16(v154, v167); + int16x8_t v312 = vsubq_s16(v180, v195); + int16x8_t v313_tmp = vqrdmulhq_n_s16(v312, 13357); + int16x8_t v313 = vmlaq_n_s16(v313_tmp, v312, 3); + int16x8_t v314 = vaddq_s16(v311, v313); + int16x8_t v315 = vsubq_s16(v102, v140); + int16x8_t v316_tmp = vqrdmulhq_n_s16(v315, 6226); + int16x8_t v316 = vmlaq_n_s16(v316_tmp, v315, 10); + int16x8_t v317 = vsubq_s16(v25, v60); + int16x8_t v318 = vaddq_s16(v316, v317); + int16x8_t v319 = vsubq_s16(v317, v316); + int16x8_t v320 = vsubq_s16(v311, v313); + int16x8_t v321 = vsubq_s16(v307, v309); + int16x8_t v322 = vsubq_s16(v303, v305); + int16x8_t v323 = vsubq_s16(v299, v301); + int16x8_t v324 = vsubq_s16(v297, v296); + int16x8_t v325 = vsubq_s16(v291, v293); + int16x8_t v326 = vsubq_s16(v287, v289); + int16x8_t v327 = vsubq_s16(v280, v285); + int16x8_t v328 = vsubq_s16(v270, v275); + int16x8_t v329 = vsubq_s16(v259, v265); + int16x8_t v330 = vsubq_s16(v249, v254); + int16x8_t v331 = vsubq_s16(v233, v244); + int16x8_t v332 = vsubq_s16(v209, v222); + int16x8_t v333 = vsubq_s16(v168, v197); + int16x8_t v334 = vsubq_s16(v61, v142); + vst1q_s16(out + out_stride * 0 + i, v143); + vst1q_s16(out + out_stride * 1 + i, v198); + vst1q_s16(out + out_stride * 2 + i, v223); + vst1q_s16(out + out_stride * 3 + i, v245); + vst1q_s16(out + out_stride * 4 + i, v255); + vst1q_s16(out + out_stride * 5 + i, v266); + vst1q_s16(out + out_stride * 6 + i, v276); + vst1q_s16(out + out_stride * 7 + i, v286); + vst1q_s16(out + out_stride * 8 + i, v290); + vst1q_s16(out + out_stride * 9 + i, v294); + vst1q_s16(out + out_stride * 10 + i, v298); + vst1q_s16(out + out_stride * 11 + i, v302); + vst1q_s16(out + out_stride * 12 + i, v306); + vst1q_s16(out + out_stride * 13 + i, v310); + vst1q_s16(out + out_stride * 14 + i, v314); + vst1q_s16(out + out_stride * 15 + i, v318); + vst1q_s16(out + out_stride * 16 + i, v319); + vst1q_s16(out + out_stride * 17 + i, v320); + vst1q_s16(out + out_stride * 18 + i, v321); + vst1q_s16(out + out_stride * 19 + i, v322); + vst1q_s16(out + out_stride * 20 + i, v323); + vst1q_s16(out + out_stride * 21 + i, v324); + vst1q_s16(out + out_stride * 22 + i, v325); + vst1q_s16(out + out_stride * 23 + i, v326); + vst1q_s16(out + out_stride * 24 + i, v327); + vst1q_s16(out + out_stride * 25 + i, v328); + vst1q_s16(out + out_stride * 26 + i, v329); + vst1q_s16(out + out_stride * 27 + i, v330); + vst1q_s16(out + out_stride * 28 + i, v331); + vst1q_s16(out + out_stride * 29 + i, v332); + vst1q_s16(out + out_stride * 30 + i, v333); + vst1q_s16(out + out_stride * 31 + i, v334); + } +} diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct64-inl.h b/third_party/jpeg-xl/lib/jxl/fast_dct64-inl.h new file mode 100644 index 0000000000..400da1a9de --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct64-inl.h @@ -0,0 +1,985 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<64>) { return 1; } + +void FastIDCT(FastDCTTag<64>, const int16_t* in, size_t in_stride, int16_t* out, + size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 32 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 16 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 48 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 8 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 40 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 24 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vld1q_s16(in + in_stride * 56 + i); + int16x8_t v17 = vaddq_s16(v16, v12); + int16x8_t v18 = vaddq_s16(v13, v10); + int16x8_t v19 = vaddq_s16(v17, v18); + int16x8_t v20 = vqrdmulhq_n_s16(v19, 17734); + int16x8_t v21 = vqrdmulhq_n_s16(v18, 25080); + int16x8_t v22 = vaddq_s16(v20, v21); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v27_tmp = vqrdmulhq_n_s16(v26, 13573); + int16x8_t v27 = vaddq_s16(v27_tmp, v26); + int16x8_t v28 = vld1q_s16(in + in_stride * 36 + i); + int16x8_t v29 = vld1q_s16(in + in_stride * 28 + i); + int16x8_t v30 = vaddq_s16(v28, v29); + int16x8_t v31 = vaddq_s16(v27, v30); + int16x8_t v32 = vld1q_s16(in + in_stride * 20 + i); + int16x8_t v33 = vld1q_s16(in + in_stride * 12 + i); + int16x8_t v34 = vaddq_s16(v32, v33); + int16x8_t v35 = vqrdmulhq_n_s16(v34, 25080); + int16x8_t v36 = vld1q_s16(in + in_stride * 52 + i); + int16x8_t v37 = vld1q_s16(in + in_stride * 44 + i); + int16x8_t v38 = vaddq_s16(v36, v37); + int16x8_t v39 = vaddq_s16(v38, v34); + int16x8_t v40 = vqrdmulhq_n_s16(v39, 17734); + int16x8_t v41 = vaddq_s16(v35, v40); + int16x8_t v42 = vaddq_s16(v31, v41); + int16x8_t v43 = vaddq_s16(v33, v26); + int16x8_t v44_tmp = vqrdmulhq_n_s16(v43, 13573); + int16x8_t v44 = vaddq_s16(v44_tmp, v43); + int16x8_t v45 = vaddq_s16(v37, v28); + int16x8_t v46 = vaddq_s16(v29, v32); + int16x8_t v47 = vaddq_s16(v45, v46); + int16x8_t v48 = vaddq_s16(v44, v47); + int16x8_t v49 = vqrdmulhq_n_s16(v48, 16705); + int16x8_t v50 = vaddq_s16(v46, v43); + int16x8_t v51_tmp = vqrdmulhq_n_s16(v50, 10045); + int16x8_t v51 = vaddq_s16(v51_tmp, v50); + int16x8_t v52 = vld1q_s16(in + in_stride * 60 + i); + int16x8_t v53 = vaddq_s16(v52, v36); + int16x8_t v54 = vaddq_s16(v53, v45); + int16x8_t v55 = vqrdmulhq_n_s16(v54, 17734); + int16x8_t v56 = vaddq_s16(v51, v55); + int16x8_t v57 = vqrdmulhq_n_s16(v56, 16705); + int16x8_t v58 = vaddq_s16(v49, v57); + int16x8_t v59 = vaddq_s16(v42, v58); + int16x8_t v60 = vqrdmulhq_n_s16(v59, 16463); + int16x8_t v61 = vaddq_s16(v25, v60); + int16x8_t v62 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v63_tmp = vqrdmulhq_n_s16(v62, 13573); + int16x8_t v63 = vaddq_s16(v63_tmp, v62); + int16x8_t v64 = vld1q_s16(in + in_stride * 34 + i); + int16x8_t v65 = vld1q_s16(in + in_stride * 30 + i); + int16x8_t v66 = vaddq_s16(v64, v65); + int16x8_t v67 = vaddq_s16(v63, v66); + int16x8_t v68 = vld1q_s16(in + in_stride * 18 + i); + int16x8_t v69 = vld1q_s16(in + in_stride * 14 + i); + int16x8_t v70 = vaddq_s16(v68, v69); + int16x8_t v71 = vqrdmulhq_n_s16(v70, 25080); + int16x8_t v72 = vld1q_s16(in + in_stride * 50 + i); + int16x8_t v73 = vld1q_s16(in + in_stride * 46 + i); + int16x8_t v74 = vaddq_s16(v72, v73); + int16x8_t v75 = vaddq_s16(v74, v70); + int16x8_t v76 = vqrdmulhq_n_s16(v75, 17734); + int16x8_t v77 = vaddq_s16(v71, v76); + int16x8_t v78 = vaddq_s16(v67, v77); + int16x8_t v79 = vld1q_s16(in + in_stride * 10 + i); + int16x8_t v80 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v81 = vaddq_s16(v79, v80); + int16x8_t v82_tmp = vqrdmulhq_n_s16(v81, 13573); + int16x8_t v82 = vaddq_s16(v82_tmp, v81); + int16x8_t v83 = vld1q_s16(in + in_stride * 42 + i); + int16x8_t v84 = vld1q_s16(in + in_stride * 38 + i); + int16x8_t v85 = vaddq_s16(v83, v84); + int16x8_t v86 = vld1q_s16(in + in_stride * 26 + i); + int16x8_t v87 = vld1q_s16(in + in_stride * 22 + i); + int16x8_t v88 = vaddq_s16(v86, v87); + int16x8_t v89 = vaddq_s16(v85, v88); + int16x8_t v90 = vaddq_s16(v82, v89); + int16x8_t v91 = vqrdmulhq_n_s16(v90, 16705); + int16x8_t v92 = vaddq_s16(v88, v81); + int16x8_t v93_tmp = vqrdmulhq_n_s16(v92, 10045); + int16x8_t v93 = vaddq_s16(v93_tmp, v92); + int16x8_t v94 = vld1q_s16(in + in_stride * 58 + i); + int16x8_t v95 = vld1q_s16(in + in_stride * 54 + i); + int16x8_t v96 = vaddq_s16(v94, v95); + int16x8_t v97 = vaddq_s16(v96, v85); + int16x8_t v98 = vqrdmulhq_n_s16(v97, 17734); + int16x8_t v99 = vaddq_s16(v93, v98); + int16x8_t v100 = vqrdmulhq_n_s16(v99, 16705); + int16x8_t v101 = vaddq_s16(v91, v100); + int16x8_t v102 = vaddq_s16(v78, v101); + int16x8_t v103 = vaddq_s16(v69, v79); + int16x8_t v104 = vaddq_s16(v80, v62); + int16x8_t v105 = vaddq_s16(v103, v104); + int16x8_t v106_tmp = vqrdmulhq_n_s16(v105, 13573); + int16x8_t v106 = vaddq_s16(v106_tmp, v105); + int16x8_t v107 = vaddq_s16(v73, v83); + int16x8_t v108 = vaddq_s16(v84, v64); + int16x8_t v109 = vaddq_s16(v107, v108); + int16x8_t v110 = vaddq_s16(v65, v86); + int16x8_t v111 = vaddq_s16(v87, v68); + int16x8_t v112 = vaddq_s16(v110, v111); + int16x8_t v113 = vaddq_s16(v109, v112); + int16x8_t v114 = vaddq_s16(v106, v113); + int16x8_t v115 = vqrdmulhq_n_s16(v114, 16705); + int16x8_t v116 = vaddq_s16(v112, v105); + int16x8_t v117 = vqrdmulhq_n_s16(v116, 25080); + int16x8_t v118 = vqrdmulhq_n_s16(v116, 17734); + int16x8_t v119 = vld1q_s16(in + in_stride * 62 + i); + int16x8_t v120 = vaddq_s16(v119, v94); + int16x8_t v121 = vaddq_s16(v95, v72); + int16x8_t v122 = vaddq_s16(v120, v121); + int16x8_t v123 = vaddq_s16(v122, v109); + int16x8_t v124 = vqrdmulhq_n_s16(v123, 17734); + int16x8_t v125 = vaddq_s16(v118, v124); + int16x8_t v126 = vaddq_s16(v117, v125); + int16x8_t v127 = vqrdmulhq_n_s16(v126, 16705); + int16x8_t v128 = vaddq_s16(v115, v127); + int16x8_t v129 = vqrdmulhq_n_s16(v128, 16463); + int16x8_t v130_tmp = vqrdmulhq_n_s16(v104, 13573); + int16x8_t v130 = vaddq_s16(v130_tmp, v104); + int16x8_t v131 = vaddq_s16(v108, v110); + int16x8_t v132 = vaddq_s16(v130, v131); + int16x8_t v133 = vaddq_s16(v111, v103); + int16x8_t v134_tmp = vqrdmulhq_n_s16(v133, 10045); + int16x8_t v134 = vaddq_s16(v134_tmp, v133); + int16x8_t v135 = vaddq_s16(v121, v107); + int16x8_t v136 = vqrdmulhq_n_s16(v135, 17734); + int16x8_t v137 = vaddq_s16(v134, v136); + int16x8_t v138 = vaddq_s16(v132, v137); + int16x8_t v139 = vqrdmulhq_n_s16(v138, 16463); + int16x8_t v140 = vaddq_s16(v129, v139); + int16x8_t v141 = vaddq_s16(v102, v140); + int16x8_t v142 = vqrdmulhq_n_s16(v141, 16404); + int16x8_t v143 = vaddq_s16(v61, v142); + int16x8_t v144 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v145_tmp = vqrdmulhq_n_s16(v144, 13573); + int16x8_t v145 = vaddq_s16(v145_tmp, v144); + int16x8_t v146 = vld1q_s16(in + in_stride * 33 + i); + int16x8_t v147 = vld1q_s16(in + in_stride * 31 + i); + int16x8_t v148 = vaddq_s16(v146, v147); + int16x8_t v149 = vaddq_s16(v145, v148); + int16x8_t v150 = vld1q_s16(in + in_stride * 17 + i); + int16x8_t v151 = vld1q_s16(in + in_stride * 15 + i); + int16x8_t v152 = vaddq_s16(v150, v151); + int16x8_t v153 = vqrdmulhq_n_s16(v152, 25080); + int16x8_t v154 = vld1q_s16(in + in_stride * 49 + i); + int16x8_t v155 = vld1q_s16(in + in_stride * 47 + i); + int16x8_t v156 = vaddq_s16(v154, v155); + int16x8_t v157 = vaddq_s16(v156, v152); + int16x8_t v158 = vqrdmulhq_n_s16(v157, 17734); + int16x8_t v159 = vaddq_s16(v153, v158); + int16x8_t v160 = vaddq_s16(v149, v159); + int16x8_t v161 = vld1q_s16(in + in_stride * 9 + i); + int16x8_t v162 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v163 = vaddq_s16(v161, v162); + int16x8_t v164_tmp = vqrdmulhq_n_s16(v163, 13573); + int16x8_t v164 = vaddq_s16(v164_tmp, v163); + int16x8_t v165 = vld1q_s16(in + in_stride * 41 + i); + int16x8_t v166 = vld1q_s16(in + in_stride * 39 + i); + int16x8_t v167 = vaddq_s16(v165, v166); + int16x8_t v168 = vld1q_s16(in + in_stride * 25 + i); + int16x8_t v169 = vld1q_s16(in + in_stride * 23 + i); + int16x8_t v170 = vaddq_s16(v168, v169); + int16x8_t v171 = vaddq_s16(v167, v170); + int16x8_t v172 = vaddq_s16(v164, v171); + int16x8_t v173 = vqrdmulhq_n_s16(v172, 16705); + int16x8_t v174 = vaddq_s16(v170, v163); + int16x8_t v175_tmp = vqrdmulhq_n_s16(v174, 10045); + int16x8_t v175 = vaddq_s16(v175_tmp, v174); + int16x8_t v176 = vld1q_s16(in + in_stride * 57 + i); + int16x8_t v177 = vld1q_s16(in + in_stride * 55 + i); + int16x8_t v178 = vaddq_s16(v176, v177); + int16x8_t v179 = vaddq_s16(v178, v167); + int16x8_t v180 = vqrdmulhq_n_s16(v179, 17734); + int16x8_t v181 = vaddq_s16(v175, v180); + int16x8_t v182 = vqrdmulhq_n_s16(v181, 16705); + int16x8_t v183 = vaddq_s16(v173, v182); + int16x8_t v184 = vaddq_s16(v160, v183); + int16x8_t v185 = vld1q_s16(in + in_stride * 37 + i); + int16x8_t v186 = vld1q_s16(in + in_stride * 35 + i); + int16x8_t v187 = vaddq_s16(v185, v186); + int16x8_t v188 = vld1q_s16(in + in_stride * 45 + i); + int16x8_t v189 = vld1q_s16(in + in_stride * 43 + i); + int16x8_t v190 = vaddq_s16(v188, v189); + int16x8_t v191 = vaddq_s16(v187, v190); + int16x8_t v192 = vld1q_s16(in + in_stride * 29 + i); + int16x8_t v193 = vld1q_s16(in + in_stride * 27 + i); + int16x8_t v194 = vaddq_s16(v192, v193); + int16x8_t v195 = vld1q_s16(in + in_stride * 21 + i); + int16x8_t v196 = vld1q_s16(in + in_stride * 19 + i); + int16x8_t v197 = vaddq_s16(v195, v196); + int16x8_t v198 = vaddq_s16(v194, v197); + int16x8_t v199 = vaddq_s16(v191, v198); + int16x8_t v200 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v201 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v202 = vaddq_s16(v200, v201); + int16x8_t v203 = vld1q_s16(in + in_stride * 13 + i); + int16x8_t v204 = vld1q_s16(in + in_stride * 11 + i); + int16x8_t v205 = vaddq_s16(v203, v204); + int16x8_t v206 = vaddq_s16(v202, v205); + int16x8_t v207_tmp = vqrdmulhq_n_s16(v206, 13573); + int16x8_t v207 = vaddq_s16(v207_tmp, v206); + int16x8_t v208 = vaddq_s16(v199, v207); + int16x8_t v209 = vqrdmulhq_n_s16(v208, 16705); + int16x8_t v210 = vaddq_s16(v198, v206); + int16x8_t v211 = vqrdmulhq_n_s16(v210, 25080); + int16x8_t v212 = vqrdmulhq_n_s16(v210, 17734); + int16x8_t v213 = vld1q_s16(in + in_stride * 53 + i); + int16x8_t v214 = vld1q_s16(in + in_stride * 51 + i); + int16x8_t v215 = vaddq_s16(v213, v214); + int16x8_t v216 = vld1q_s16(in + in_stride * 61 + i); + int16x8_t v217 = vld1q_s16(in + in_stride * 59 + i); + int16x8_t v218 = vaddq_s16(v216, v217); + int16x8_t v219 = vaddq_s16(v215, v218); + int16x8_t v220 = vaddq_s16(v219, v191); + int16x8_t v221 = vqrdmulhq_n_s16(v220, 17734); + int16x8_t v222 = vaddq_s16(v212, v221); + int16x8_t v223 = vaddq_s16(v211, v222); + int16x8_t v224 = vqrdmulhq_n_s16(v223, 16705); + int16x8_t v225 = vaddq_s16(v209, v224); + int16x8_t v226 = vqrdmulhq_n_s16(v225, 16463); + int16x8_t v227_tmp = vqrdmulhq_n_s16(v202, 13573); + int16x8_t v227 = vaddq_s16(v227_tmp, v202); + int16x8_t v228 = vaddq_s16(v187, v194); + int16x8_t v229 = vaddq_s16(v227, v228); + int16x8_t v230 = vaddq_s16(v215, v190); + int16x8_t v231 = vqrdmulhq_n_s16(v230, 17734); + int16x8_t v232 = vaddq_s16(v197, v205); + int16x8_t v233_tmp = vqrdmulhq_n_s16(v232, 10045); + int16x8_t v233 = vaddq_s16(v233_tmp, v232); + int16x8_t v234 = vaddq_s16(v231, v233); + int16x8_t v235 = vaddq_s16(v229, v234); + int16x8_t v236 = vqrdmulhq_n_s16(v235, 16463); + int16x8_t v237 = vaddq_s16(v226, v236); + int16x8_t v238 = vaddq_s16(v184, v237); + int16x8_t v239 = vaddq_s16(v201, v144); + int16x8_t v240_tmp = vqrdmulhq_n_s16(v239, 13573); + int16x8_t v240 = vaddq_s16(v240_tmp, v239); + int16x8_t v241 = vaddq_s16(v186, v146); + int16x8_t v242 = vaddq_s16(v147, v192); + int16x8_t v243 = vaddq_s16(v241, v242); + int16x8_t v244 = vaddq_s16(v240, v243); + int16x8_t v245 = vaddq_s16(v196, v150); + int16x8_t v246 = vaddq_s16(v151, v203); + int16x8_t v247 = vaddq_s16(v245, v246); + int16x8_t v248_tmp = vqrdmulhq_n_s16(v247, 10045); + int16x8_t v248 = vaddq_s16(v248_tmp, v247); + int16x8_t v249 = vaddq_s16(v155, v188); + int16x8_t v250 = vaddq_s16(v214, v154); + int16x8_t v251 = vaddq_s16(v249, v250); + int16x8_t v252 = vqrdmulhq_n_s16(v251, 17734); + int16x8_t v253 = vaddq_s16(v248, v252); + int16x8_t v254 = vaddq_s16(v244, v253); + int16x8_t v255 = vaddq_s16(v204, v161); + int16x8_t v256 = vaddq_s16(v162, v200); + int16x8_t v257 = vaddq_s16(v255, v256); + int16x8_t v258_tmp = vqrdmulhq_n_s16(v257, 13573); + int16x8_t v258 = vaddq_s16(v258_tmp, v257); + int16x8_t v259 = vaddq_s16(v189, v165); + int16x8_t v260 = vaddq_s16(v166, v185); + int16x8_t v261 = vaddq_s16(v259, v260); + int16x8_t v262 = vaddq_s16(v169, v195); + int16x8_t v263 = vaddq_s16(v193, v168); + int16x8_t v264 = vaddq_s16(v262, v263); + int16x8_t v265 = vaddq_s16(v261, v264); + int16x8_t v266 = vaddq_s16(v258, v265); + int16x8_t v267 = vqrdmulhq_n_s16(v266, 16705); + int16x8_t v268 = vaddq_s16(v264, v257); + int16x8_t v269 = vqrdmulhq_n_s16(v268, 25080); + int16x8_t v270 = vaddq_s16(v217, v176); + int16x8_t v271 = vaddq_s16(v177, v213); + int16x8_t v272 = vaddq_s16(v270, v271); + int16x8_t v273 = vaddq_s16(v272, v261); + int16x8_t v274 = vqrdmulhq_n_s16(v273, 17734); + int16x8_t v275 = vqrdmulhq_n_s16(v268, 17734); + int16x8_t v276 = vaddq_s16(v274, v275); + int16x8_t v277 = vaddq_s16(v269, v276); + int16x8_t v278 = vqrdmulhq_n_s16(v277, 16705); + int16x8_t v279 = vaddq_s16(v267, v278); + int16x8_t v280 = vaddq_s16(v254, v279); + int16x8_t v281 = vqrdmulhq_n_s16(v280, 16404); + int16x8_t v282 = vaddq_s16(v256, v239); + int16x8_t v283_tmp = vqrdmulhq_n_s16(v282, 13573); + int16x8_t v283 = vaddq_s16(v283_tmp, v282); + int16x8_t v284 = vaddq_s16(v260, v241); + int16x8_t v285 = vaddq_s16(v242, v263); + int16x8_t v286 = vaddq_s16(v284, v285); + int16x8_t v287 = vaddq_s16(v283, v286); + int16x8_t v288 = vaddq_s16(v262, v245); + int16x8_t v289 = vaddq_s16(v246, v255); + int16x8_t v290 = vaddq_s16(v288, v289); + int16x8_t v291 = vqrdmulhq_n_s16(v290, 25080); + int16x8_t v292 = vqrdmulhq_n_s16(v290, 17734); + int16x8_t v293 = vaddq_s16(v271, v250); + int16x8_t v294 = vaddq_s16(v249, v259); + int16x8_t v295 = vaddq_s16(v293, v294); + int16x8_t v296 = vqrdmulhq_n_s16(v295, 17734); + int16x8_t v297 = vaddq_s16(v292, v296); + int16x8_t v298 = vaddq_s16(v291, v297); + int16x8_t v299 = vaddq_s16(v287, v298); + int16x8_t v300 = vqrdmulhq_n_s16(v299, 16463); + int16x8_t v301 = vaddq_s16(v289, v282); + int16x8_t v302 = vqrdmulhq_n_s16(v301, 23624); + int16x8_t v303 = vaddq_s16(v294, v284); + int16x8_t v304 = vqrdmulhq_n_s16(v303, 19705); + int16x8_t v305 = vaddq_s16(v285, v288); + int16x8_t v306 = vqrdmulhq_n_s16(v305, 19705); + int16x8_t v307 = vaddq_s16(v304, v306); + int16x8_t v308 = vqrdmulhq_n_s16(v307, 27779); + int16x8_t v309 = vaddq_s16(v302, v308); + int16x8_t v310 = vaddq_s16(v305, v301); + int16x8_t v311 = vqrdmulhq_n_s16(v310, 25080); + int16x8_t v312 = vqrdmulhq_n_s16(v310, 17734); + int16x8_t v313 = vld1q_s16(in + in_stride * 63 + i); + int16x8_t v314 = vaddq_s16(v313, v216); + int16x8_t v315 = vaddq_s16(v314, v270); + int16x8_t v316 = vaddq_s16(v315, v293); + int16x8_t v317 = vqrdmulhq_n_s16(v316, 25746); + int16x8_t v318 = vqrdmulhq_n_s16(v303, 25746); + int16x8_t v319 = vaddq_s16(v317, v318); + int16x8_t v320 = vqrdmulhq_n_s16(v319, 22571); + int16x8_t v321 = vaddq_s16(v312, v320); + int16x8_t v322 = vaddq_s16(v311, v321); + int16x8_t v323 = vqrdmulhq_n_s16(v322, 16705); + int16x8_t v324 = vaddq_s16(v309, v323); + int16x8_t v325 = vqrdmulhq_n_s16(v324, 16463); + int16x8_t v326 = vaddq_s16(v300, v325); + int16x8_t v327 = vqrdmulhq_n_s16(v326, 16404); + int16x8_t v328 = vaddq_s16(v281, v327); + int16x8_t v329 = vaddq_s16(v238, v328); + int16x8_t v330 = vqrdmulhq_n_s16(v329, 16389); + int16x8_t v331 = vaddq_s16(v143, v330); + int16x8_t v332 = vsubq_s16(v82, v89); + int16x8_t v333 = vqrdmulhq_n_s16(v332, 19705); + int16x8_t v334 = vqrdmulhq_n_s16(v92, 13573); + int16x8_t v335 = vsubq_s16(v334, v97); + int16x8_t v336 = vqrdmulhq_n_s16(v335, 25746); + int16x8_t v337 = vaddq_s16(v333, v336); + int16x8_t v338 = vsubq_s16(v63, v66); + int16x8_t v339 = vqrdmulhq_n_s16(v70, 17734); + int16x8_t v340_tmp = vqrdmulhq_n_s16(v74, 10045); + int16x8_t v340 = vaddq_s16(v340_tmp, v74); + int16x8_t v341 = vsubq_s16(v339, v340); + int16x8_t v342 = vaddq_s16(v338, v341); + int16x8_t v343 = vaddq_s16(v337, v342); + int16x8_t v344 = vsubq_s16(v130, v131); + int16x8_t v345 = vqrdmulhq_n_s16(v133, 13573); + int16x8_t v346 = vsubq_s16(v345, v135); + int16x8_t v347_tmp = vqrdmulhq_n_s16(v346, 10045); + int16x8_t v347 = vaddq_s16(v347_tmp, v346); + int16x8_t v348 = vaddq_s16(v344, v347); + int16x8_t v349 = vqrdmulhq_n_s16(v348, 17121); + int16x8_t v350 = vqrdmulhq_n_s16(v105, 27867); + int16x8_t v351 = vqrdmulhq_n_s16(v113, 19705); + int16x8_t v352 = vsubq_s16(v350, v351); + int16x8_t v353 = vqrdmulhq_n_s16(v116, 13573); + int16x8_t v354 = vsubq_s16(v353, v123); + int16x8_t v355 = vqrdmulhq_n_s16(v354, 25746); + int16x8_t v356 = vaddq_s16(v352, v355); + int16x8_t v357 = vqrdmulhq_n_s16(v356, 17121); + int16x8_t v358 = vaddq_s16(v349, v357); + int16x8_t v359 = vaddq_s16(v343, v358); + int16x8_t v360 = vqrdmulhq_n_s16(v359, 16563); + int16x8_t v361 = vsubq_s16(v27, v30); + int16x8_t v362 = vqrdmulhq_n_s16(v34, 17734); + int16x8_t v363_tmp = vqrdmulhq_n_s16(v38, 10045); + int16x8_t v363 = vaddq_s16(v363_tmp, v38); + int16x8_t v364 = vsubq_s16(v362, v363); + int16x8_t v365 = vaddq_s16(v361, v364); + int16x8_t v366 = vsubq_s16(v44, v47); + int16x8_t v367 = vqrdmulhq_n_s16(v366, 19705); + int16x8_t v368 = vqrdmulhq_n_s16(v50, 13573); + int16x8_t v369 = vsubq_s16(v368, v54); + int16x8_t v370 = vqrdmulhq_n_s16(v369, 25746); + int16x8_t v371 = vaddq_s16(v367, v370); + int16x8_t v372 = vaddq_s16(v365, v371); + int16x8_t v373 = vqrdmulhq_n_s16(v372, 17121); + int16x8_t v374 = vsubq_s16(v0, v1); + int16x8_t v375 = vsubq_s16(v4, v6); + int16x8_t v376_tmp = vqrdmulhq_n_s16(v375, 10045); + int16x8_t v376 = vaddq_s16(v376_tmp, v375); + int16x8_t v377 = vaddq_s16(v374, v376); + int16x8_t v378 = vsubq_s16(v11, v14); + int16x8_t v379 = vqrdmulhq_n_s16(v18, 17734); + int16x8_t v380_tmp = vqrdmulhq_n_s16(v17, 10045); + int16x8_t v380 = vaddq_s16(v380_tmp, v17); + int16x8_t v381 = vsubq_s16(v379, v380); + int16x8_t v382 = vaddq_s16(v378, v381); + int16x8_t v383 = vqrdmulhq_n_s16(v382, 19705); + int16x8_t v384 = vaddq_s16(v377, v383); + int16x8_t v385 = vaddq_s16(v373, v384); + int16x8_t v386 = vaddq_s16(v360, v385); + int16x8_t v387 = vsubq_s16(v145, v148); + int16x8_t v388 = vqrdmulhq_n_s16(v152, 17734); + int16x8_t v389_tmp = vqrdmulhq_n_s16(v156, 10045); + int16x8_t v389 = vaddq_s16(v389_tmp, v156); + int16x8_t v390 = vsubq_s16(v388, v389); + int16x8_t v391 = vaddq_s16(v387, v390); + int16x8_t v392 = vsubq_s16(v164, v171); + int16x8_t v393 = vqrdmulhq_n_s16(v392, 19705); + int16x8_t v394 = vqrdmulhq_n_s16(v174, 13573); + int16x8_t v395 = vsubq_s16(v394, v179); + int16x8_t v396 = vqrdmulhq_n_s16(v395, 25746); + int16x8_t v397 = vaddq_s16(v393, v396); + int16x8_t v398 = vaddq_s16(v391, v397); + int16x8_t v399 = vsubq_s16(v227, v228); + int16x8_t v400 = vqrdmulhq_n_s16(v232, 13573); + int16x8_t v401 = vsubq_s16(v400, v230); + int16x8_t v402_tmp = vqrdmulhq_n_s16(v401, 10045); + int16x8_t v402 = vaddq_s16(v402_tmp, v401); + int16x8_t v403 = vaddq_s16(v399, v402); + int16x8_t v404 = vqrdmulhq_n_s16(v403, 17121); + int16x8_t v405 = vqrdmulhq_n_s16(v206, 27867); + int16x8_t v406 = vqrdmulhq_n_s16(v199, 19705); + int16x8_t v407 = vsubq_s16(v405, v406); + int16x8_t v408 = vqrdmulhq_n_s16(v210, 13573); + int16x8_t v409 = vsubq_s16(v408, v220); + int16x8_t v410 = vqrdmulhq_n_s16(v409, 25746); + int16x8_t v411 = vaddq_s16(v407, v410); + int16x8_t v412 = vqrdmulhq_n_s16(v411, 17121); + int16x8_t v413 = vaddq_s16(v404, v412); + int16x8_t v414 = vaddq_s16(v398, v413); + int16x8_t v415 = vsubq_s16(v240, v243); + int16x8_t v416 = vqrdmulhq_n_s16(v247, 13573); + int16x8_t v417 = vsubq_s16(v416, v251); + int16x8_t v418_tmp = vqrdmulhq_n_s16(v417, 10045); + int16x8_t v418 = vaddq_s16(v418_tmp, v417); + int16x8_t v419 = vaddq_s16(v415, v418); + int16x8_t v420 = vqrdmulhq_n_s16(v257, 27867); + int16x8_t v421 = vqrdmulhq_n_s16(v265, 19705); + int16x8_t v422 = vsubq_s16(v420, v421); + int16x8_t v423 = vqrdmulhq_n_s16(v268, 13573); + int16x8_t v424 = vsubq_s16(v423, v273); + int16x8_t v425 = vqrdmulhq_n_s16(v424, 25746); + int16x8_t v426 = vaddq_s16(v422, v425); + int16x8_t v427 = vaddq_s16(v419, v426); + int16x8_t v428 = vqrdmulhq_n_s16(v427, 16563); + int16x8_t v429 = vqrdmulhq_n_s16(v301, 27867); + int16x8_t v430 = vsubq_s16(v429, v307); + int16x8_t v431 = vqrdmulhq_n_s16(v310, 10664); + int16x8_t v432 = vsubq_s16(v431, v319); + int16x8_t v433 = vaddq_s16(v430, v432); + int16x8_t v434 = vqrdmulhq_n_s16(v433, 17121); + int16x8_t v435 = vsubq_s16(v283, v286); + int16x8_t v436 = vqrdmulhq_n_s16(v290, 13573); + int16x8_t v437 = vsubq_s16(v436, v295); + int16x8_t v438_tmp = vqrdmulhq_n_s16(v437, 10045); + int16x8_t v438 = vaddq_s16(v438_tmp, v437); + int16x8_t v439 = vaddq_s16(v435, v438); + int16x8_t v440 = vqrdmulhq_n_s16(v439, 17121); + int16x8_t v441 = vaddq_s16(v434, v440); + int16x8_t v442 = vqrdmulhq_n_s16(v441, 16563); + int16x8_t v443 = vaddq_s16(v428, v442); + int16x8_t v444 = vaddq_s16(v414, v443); + int16x8_t v445 = vqrdmulhq_n_s16(v444, 16429); + int16x8_t v446 = vaddq_s16(v386, v445); + int16x8_t v447 = vsubq_s16(v374, v376); + int16x8_t v448 = vsubq_s16(v378, v381); + int16x8_t v449 = vqrdmulhq_n_s16(v448, 29490); + int16x8_t v450 = vaddq_s16(v447, v449); + int16x8_t v451 = vsubq_s16(v361, v364); + int16x8_t v452 = vqrdmulhq_n_s16(v366, 29490); + int16x8_t v453_tmp = vqrdmulhq_n_s16(v369, 5763); + int16x8_t v453 = vaddq_s16(v453_tmp, v369); + int16x8_t v454 = vsubq_s16(v452, v453); + int16x8_t v455 = vaddq_s16(v451, v454); + int16x8_t v456 = vqrdmulhq_n_s16(v455, 18578); + int16x8_t v457 = vaddq_s16(v450, v456); + int16x8_t v458 = vsubq_s16(v338, v341); + int16x8_t v459 = vqrdmulhq_n_s16(v332, 29490); + int16x8_t v460_tmp = vqrdmulhq_n_s16(v335, 5763); + int16x8_t v460 = vaddq_s16(v460_tmp, v335); + int16x8_t v461 = vsubq_s16(v459, v460); + int16x8_t v462 = vaddq_s16(v458, v461); + int16x8_t v463 = vqrdmulhq_n_s16(v352, 27803); + int16x8_t v464 = vqrdmulhq_n_s16(v354, 21845); + int16x8_t v465 = vsubq_s16(v463, v464); + int16x8_t v466 = vsubq_s16(v344, v347); + int16x8_t v467 = vqrdmulhq_n_s16(v466, 18578); + int16x8_t v468 = vaddq_s16(v465, v467); + int16x8_t v469 = vaddq_s16(v462, v468); + int16x8_t v470 = vqrdmulhq_n_s16(v469, 16890); + int16x8_t v471 = vaddq_s16(v457, v470); + int16x8_t v472 = vsubq_s16(v415, v418); + int16x8_t v473_tmp = vqrdmulhq_n_s16(v422, 16273); + int16x8_t v473 = vaddq_s16(v473_tmp, v422); + int16x8_t v474_tmp = vqrdmulhq_n_s16(v424, 5763); + int16x8_t v474 = vaddq_s16(v474_tmp, v424); + int16x8_t v475 = vsubq_s16(v473, v474); + int16x8_t v476 = vaddq_s16(v472, v475); + int16x8_t v477 = vqrdmulhq_n_s16(v476, 16890); + int16x8_t v478 = vqrdmulhq_n_s16(v435, 20261); + int16x8_t v479 = vqrdmulhq_n_s16(v437, 26472); + int16x8_t v480 = vsubq_s16(v478, v479); + int16x8_t v481 = vqrdmulhq_n_s16(v480, 30046); + int16x8_t v482 = vqrdmulhq_n_s16(v430, 30322); + int16x8_t v483 = vqrdmulhq_n_s16(v432, 30322); + int16x8_t v484 = vsubq_s16(v482, v483); + int16x8_t v485 = vqrdmulhq_n_s16(v484, 30046); + int16x8_t v486 = vaddq_s16(v481, v485); + int16x8_t v487 = vqrdmulhq_n_s16(v486, 16890); + int16x8_t v488 = vaddq_s16(v477, v487); + int16x8_t v489 = vsubq_s16(v387, v390); + int16x8_t v490 = vqrdmulhq_n_s16(v392, 29490); + int16x8_t v491_tmp = vqrdmulhq_n_s16(v395, 5763); + int16x8_t v491 = vaddq_s16(v491_tmp, v395); + int16x8_t v492 = vsubq_s16(v490, v491); + int16x8_t v493 = vaddq_s16(v489, v492); + int16x8_t v494 = vsubq_s16(v399, v402); + int16x8_t v495 = vqrdmulhq_n_s16(v494, 18578); + int16x8_t v496 = vqrdmulhq_n_s16(v407, 27803); + int16x8_t v497 = vqrdmulhq_n_s16(v409, 21845); + int16x8_t v498 = vsubq_s16(v496, v497); + int16x8_t v499 = vaddq_s16(v495, v498); + int16x8_t v500 = vaddq_s16(v493, v499); + int16x8_t v501 = vaddq_s16(v488, v500); + int16x8_t v502 = vqrdmulhq_n_s16(v501, 16508); + int16x8_t v503 = vaddq_s16(v471, v502); + int16x8_t v504 = vsubq_s16(v2, v8); + int16x8_t v505 = vsubq_s16(v15, v22); + int16x8_t v506_tmp = vqrdmulhq_n_s16(v505, 18446); + int16x8_t v506 = vmlaq_n_s16(v506_tmp, v505, 2); + int16x8_t v507 = vaddq_s16(v504, v506); + int16x8_t v508 = vsubq_s16(v31, v41); + int16x8_t v509 = vsubq_s16(v48, v56); + int16x8_t v510_tmp = vqrdmulhq_n_s16(v509, 18446); + int16x8_t v510 = vmlaq_n_s16(v510_tmp, v509, 2); + int16x8_t v511 = vaddq_s16(v508, v510); + int16x8_t v512 = vqrdmulhq_n_s16(v511, 21195); + int16x8_t v513 = vaddq_s16(v507, v512); + int16x8_t v514 = vsubq_s16(v67, v77); + int16x8_t v515 = vsubq_s16(v90, v99); + int16x8_t v516_tmp = vqrdmulhq_n_s16(v515, 18446); + int16x8_t v516 = vmlaq_n_s16(v516_tmp, v515, 2); + int16x8_t v517 = vaddq_s16(v514, v516); + int16x8_t v518 = vsubq_s16(v114, v126); + int16x8_t v519_tmp = vqrdmulhq_n_s16(v518, 18446); + int16x8_t v519 = vmlaq_n_s16(v519_tmp, v518, 2); + int16x8_t v520 = vsubq_s16(v132, v137); + int16x8_t v521 = vaddq_s16(v519, v520); + int16x8_t v522 = vqrdmulhq_n_s16(v521, 21195); + int16x8_t v523 = vaddq_s16(v517, v522); + int16x8_t v524 = vqrdmulhq_n_s16(v523, 17401); + int16x8_t v525 = vaddq_s16(v513, v524); + int16x8_t v526 = vsubq_s16(v172, v181); + int16x8_t v527_tmp = vqrdmulhq_n_s16(v526, 18446); + int16x8_t v527 = vmlaq_n_s16(v527_tmp, v526, 2); + int16x8_t v528 = vsubq_s16(v149, v159); + int16x8_t v529 = vaddq_s16(v527, v528); + int16x8_t v530 = vsubq_s16(v229, v234); + int16x8_t v531 = vsubq_s16(v208, v223); + int16x8_t v532_tmp = vqrdmulhq_n_s16(v531, 18446); + int16x8_t v532 = vmlaq_n_s16(v532_tmp, v531, 2); + int16x8_t v533 = vaddq_s16(v530, v532); + int16x8_t v534 = vqrdmulhq_n_s16(v533, 21195); + int16x8_t v535 = vaddq_s16(v529, v534); + int16x8_t v536 = vsubq_s16(v244, v253); + int16x8_t v537 = vsubq_s16(v266, v277); + int16x8_t v538_tmp = vqrdmulhq_n_s16(v537, 18446); + int16x8_t v538 = vmlaq_n_s16(v538_tmp, v537, 2); + int16x8_t v539 = vaddq_s16(v536, v538); + int16x8_t v540 = vqrdmulhq_n_s16(v539, 17401); + int16x8_t v541 = vqrdmulhq_n_s16(v287, 25826); + int16x8_t v542 = vqrdmulhq_n_s16(v298, 25826); + int16x8_t v543 = vsubq_s16(v541, v542); + int16x8_t v544 = vqrdmulhq_n_s16(v543, 14281); + int16x8_t v545_tmp = vqrdmulhq_n_s16(v309, 31509); + int16x8_t v545 = vaddq_s16(v545_tmp, v309); + int16x8_t v546 = vsubq_s16(v545, v322); + int16x8_t v547 = vqrdmulhq_n_s16(v546, 28847); + int16x8_t v548 = vaddq_s16(v544, v547); + int16x8_t v549 = vaddq_s16(v540, v548); + int16x8_t v550 = vaddq_s16(v535, v549); + int16x8_t v551 = vqrdmulhq_n_s16(v550, 16629); + int16x8_t v552 = vaddq_s16(v525, v551); + int16x8_t v553 = vsubq_s16(v504, v506); + int16x8_t v554 = vsubq_s16(v508, v510); + int16x8_t v555 = vqrdmulhq_n_s16(v554, 25826); + int16x8_t v556 = vaddq_s16(v553, v555); + int16x8_t v557 = vsubq_s16(v514, v516); + int16x8_t v558 = vsubq_s16(v520, v519); + int16x8_t v559 = vqrdmulhq_n_s16(v558, 25826); + int16x8_t v560 = vaddq_s16(v557, v559); + int16x8_t v561 = vqrdmulhq_n_s16(v560, 18124); + int16x8_t v562 = vaddq_s16(v556, v561); + int16x8_t v563 = vsubq_s16(v528, v527); + int16x8_t v564 = vsubq_s16(v530, v532); + int16x8_t v565 = vqrdmulhq_n_s16(v564, 25826); + int16x8_t v566 = vaddq_s16(v563, v565); + int16x8_t v567 = vsubq_s16(v536, v538); + int16x8_t v568 = vqrdmulhq_n_s16(v567, 18124); + int16x8_t v569_tmp = vqrdmulhq_n_s16(v546, 654); + int16x8_t v569 = vmlaq_n_s16(v569_tmp, v546, 2); + int16x8_t v570 = vsubq_s16(v543, v569); + int16x8_t v571 = vqrdmulhq_n_s16(v570, 18124); + int16x8_t v572 = vaddq_s16(v568, v571); + int16x8_t v573 = vaddq_s16(v566, v572); + int16x8_t v574 = vqrdmulhq_n_s16(v573, 16792); + int16x8_t v575 = vaddq_s16(v562, v574); + int16x8_t v576 = vsubq_s16(v458, v461); + int16x8_t v577_tmp = vqrdmulhq_n_s16(v465, 25030); + int16x8_t v577 = vaddq_s16(v577_tmp, v465); + int16x8_t v578 = vsubq_s16(v466, v577); + int16x8_t v579_tmp = vqrdmulhq_n_s16(v578, 1988); + int16x8_t v579 = vaddq_s16(v579_tmp, v578); + int16x8_t v580 = vaddq_s16(v576, v579); + int16x8_t v581 = vqrdmulhq_n_s16(v580, 19102); + int16x8_t v582 = vsubq_s16(v447, v449); + int16x8_t v583 = vsubq_s16(v451, v454); + int16x8_t v584_tmp = vqrdmulhq_n_s16(v583, 1988); + int16x8_t v584 = vaddq_s16(v584_tmp, v583); + int16x8_t v585 = vaddq_s16(v582, v584); + int16x8_t v586 = vaddq_s16(v581, v585); + int16x8_t v587 = vsubq_s16(v489, v492); + int16x8_t v588_tmp = vqrdmulhq_n_s16(v498, 25030); + int16x8_t v588 = vaddq_s16(v588_tmp, v498); + int16x8_t v589 = vsubq_s16(v494, v588); + int16x8_t v590_tmp = vqrdmulhq_n_s16(v589, 1988); + int16x8_t v590 = vaddq_s16(v590_tmp, v589); + int16x8_t v591 = vaddq_s16(v587, v590); + int16x8_t v592 = vsubq_s16(v472, v475); + int16x8_t v593 = vqrdmulhq_n_s16(v592, 19102); + int16x8_t v594 = vsubq_s16(v480, v484); + int16x8_t v595 = vaddq_s16(v593, v594); + int16x8_t v596 = vaddq_s16(v591, v595); + int16x8_t v597 = vqrdmulhq_n_s16(v596, 17000); + int16x8_t v598 = vaddq_s16(v586, v597); + int16x8_t v599 = vsubq_s16(v365, v371); + int16x8_t v600_tmp = vqrdmulhq_n_s16(v599, 23673); + int16x8_t v600 = vaddq_s16(v600_tmp, v599); + int16x8_t v601 = vsubq_s16(v377, v383); + int16x8_t v602 = vaddq_s16(v600, v601); + int16x8_t v603 = vsubq_s16(v348, v356); + int16x8_t v604_tmp = vqrdmulhq_n_s16(v603, 23673); + int16x8_t v604 = vaddq_s16(v604_tmp, v603); + int16x8_t v605 = vsubq_s16(v342, v337); + int16x8_t v606 = vaddq_s16(v604, v605); + int16x8_t v607 = vqrdmulhq_n_s16(v606, 20398); + int16x8_t v608 = vaddq_s16(v602, v607); + int16x8_t v609 = vsubq_s16(v391, v397); + int16x8_t v610 = vsubq_s16(v403, v411); + int16x8_t v611_tmp = vqrdmulhq_n_s16(v610, 23673); + int16x8_t v611 = vaddq_s16(v611_tmp, v610); + int16x8_t v612 = vaddq_s16(v609, v611); + int16x8_t v613 = vsubq_s16(v419, v426); + int16x8_t v614 = vqrdmulhq_n_s16(v613, 20398); + int16x8_t v615 = vsubq_s16(v439, v433); + int16x8_t v616_tmp = vqrdmulhq_n_s16(v615, 2367); + int16x8_t v616 = vaddq_s16(v616_tmp, v615); + int16x8_t v617 = vaddq_s16(v614, v616); + int16x8_t v618 = vaddq_s16(v612, v617); + int16x8_t v619 = vqrdmulhq_n_s16(v618, 17255); + int16x8_t v620 = vaddq_s16(v608, v619); + int16x8_t v621 = vsubq_s16(v160, v183); + int16x8_t v622 = vsubq_s16(v235, v225); + int16x8_t v623_tmp = vqrdmulhq_n_s16(v622, 3314); + int16x8_t v623 = vmlaq_n_s16(v623_tmp, v622, 5); + int16x8_t v624 = vaddq_s16(v621, v623); + int16x8_t v625 = vsubq_s16(v254, v279); + int16x8_t v626 = vsubq_s16(v299, v324); + int16x8_t v627_tmp = vqrdmulhq_n_s16(v626, 3314); + int16x8_t v627 = vmlaq_n_s16(v627_tmp, v626, 5); + int16x8_t v628 = vaddq_s16(v625, v627); + int16x8_t v629 = vqrdmulhq_n_s16(v628, 22112); + int16x8_t v630 = vaddq_s16(v624, v629); + int16x8_t v631 = vqrdmulhq_n_s16(v630, 17561); + int16x8_t v632 = vsubq_s16(v9, v24); + int16x8_t v633 = vsubq_s16(v42, v58); + int16x8_t v634_tmp = vqrdmulhq_n_s16(v633, 3314); + int16x8_t v634 = vmlaq_n_s16(v634_tmp, v633, 5); + int16x8_t v635 = vaddq_s16(v632, v634); + int16x8_t v636 = vsubq_s16(v78, v101); + int16x8_t v637 = vsubq_s16(v138, v128); + int16x8_t v638_tmp = vqrdmulhq_n_s16(v637, 3314); + int16x8_t v638 = vmlaq_n_s16(v638_tmp, v637, 5); + int16x8_t v639 = vaddq_s16(v636, v638); + int16x8_t v640 = vqrdmulhq_n_s16(v639, 22112); + int16x8_t v641 = vaddq_s16(v635, v640); + int16x8_t v642 = vaddq_s16(v631, v641); + int16x8_t v643 = vsubq_s16(v632, v634); + int16x8_t v644 = vsubq_s16(v636, v638); + int16x8_t v645 = vqrdmulhq_n_s16(v644, 24397); + int16x8_t v646 = vaddq_s16(v643, v645); + int16x8_t v647 = vsubq_s16(v621, v623); + int16x8_t v648 = vsubq_s16(v625, v627); + int16x8_t v649 = vqrdmulhq_n_s16(v648, 24397); + int16x8_t v650 = vaddq_s16(v647, v649); + int16x8_t v651 = vqrdmulhq_n_s16(v650, 17921); + int16x8_t v652 = vaddq_s16(v646, v651); + int16x8_t v653 = vsubq_s16(v601, v600); + int16x8_t v654 = vsubq_s16(v605, v604); + int16x8_t v655 = vqrdmulhq_n_s16(v654, 27504); + int16x8_t v656 = vaddq_s16(v653, v655); + int16x8_t v657 = vsubq_s16(v609, v611); + int16x8_t v658 = vqrdmulhq_n_s16(v613, 27504); + int16x8_t v659_tmp = vqrdmulhq_n_s16(v615, 14606); + int16x8_t v659 = vaddq_s16(v659_tmp, v615); + int16x8_t v660 = vsubq_s16(v658, v659); + int16x8_t v661 = vaddq_s16(v657, v660); + int16x8_t v662 = vqrdmulhq_n_s16(v661, 18343); + int16x8_t v663 = vaddq_s16(v656, v662); + int16x8_t v664 = vsubq_s16(v582, v584); + int16x8_t v665 = vsubq_s16(v576, v579); + int16x8_t v666 = vqrdmulhq_n_s16(v665, 31869); + int16x8_t v667 = vaddq_s16(v664, v666); + int16x8_t v668 = vsubq_s16(v587, v590); + int16x8_t v669_tmp = vqrdmulhq_n_s16(v594, 23444); + int16x8_t v669 = vaddq_s16(v669_tmp, v594); + int16x8_t v670 = vsubq_s16(v592, v669); + int16x8_t v671 = vqrdmulhq_n_s16(v670, 31869); + int16x8_t v672 = vaddq_s16(v668, v671); + int16x8_t v673 = vqrdmulhq_n_s16(v672, 18830); + int16x8_t v674 = vaddq_s16(v667, v673); + int16x8_t v675 = vsubq_s16(v553, v555); + int16x8_t v676 = vsubq_s16(v557, v559); + int16x8_t v677_tmp = vqrdmulhq_n_s16(v676, 5552); + int16x8_t v677 = vaddq_s16(v677_tmp, v676); + int16x8_t v678 = vaddq_s16(v675, v677); + int16x8_t v679 = vsubq_s16(v563, v565); + int16x8_t v680 = vsubq_s16(v567, v570); + int16x8_t v681_tmp = vqrdmulhq_n_s16(v680, 5552); + int16x8_t v681 = vaddq_s16(v681_tmp, v680); + int16x8_t v682 = vaddq_s16(v679, v681); + int16x8_t v683 = vqrdmulhq_n_s16(v682, 19393); + int16x8_t v684 = vaddq_s16(v678, v683); + int16x8_t v685 = vsubq_s16(v507, v512); + int16x8_t v686 = vsubq_s16(v517, v522); + int16x8_t v687_tmp = vqrdmulhq_n_s16(v686, 15865); + int16x8_t v687 = vaddq_s16(v687_tmp, v686); + int16x8_t v688 = vaddq_s16(v685, v687); + int16x8_t v689 = vsubq_s16(v529, v534); + int16x8_t v690_tmp = vqrdmulhq_n_s16(v548, 28937); + int16x8_t v690 = vaddq_s16(v690_tmp, v548); + int16x8_t v691 = vsubq_s16(v539, v690); + int16x8_t v692_tmp = vqrdmulhq_n_s16(v691, 15865); + int16x8_t v692 = vaddq_s16(v692_tmp, v691); + int16x8_t v693 = vaddq_s16(v689, v692); + int16x8_t v694 = vqrdmulhq_n_s16(v693, 20040); + int16x8_t v695 = vaddq_s16(v688, v694); + int16x8_t v696 = vsubq_s16(v476, v486); + int16x8_t v697_tmp = vqrdmulhq_n_s16(v696, 1893); + int16x8_t v697 = vmlaq_n_s16(v697_tmp, v696, 2); + int16x8_t v698 = vsubq_s16(v493, v499); + int16x8_t v699 = vaddq_s16(v697, v698); + int16x8_t v700 = vqrdmulhq_n_s16(v699, 20783); + int16x8_t v701 = vsubq_s16(v450, v456); + int16x8_t v702 = vsubq_s16(v462, v468); + int16x8_t v703_tmp = vqrdmulhq_n_s16(v702, 1893); + int16x8_t v703 = vmlaq_n_s16(v703_tmp, v702, 2); + int16x8_t v704 = vaddq_s16(v701, v703); + int16x8_t v705 = vaddq_s16(v700, v704); + int16x8_t v706 = vsubq_s16(v384, v373); + int16x8_t v707 = vsubq_s16(v343, v358); + int16x8_t v708_tmp = vqrdmulhq_n_s16(v707, 13357); + int16x8_t v708 = vmlaq_n_s16(v708_tmp, v707, 3); + int16x8_t v709 = vaddq_s16(v706, v708); + int16x8_t v710 = vsubq_s16(v398, v413); + int16x8_t v711 = vsubq_s16(v427, v441); + int16x8_t v712_tmp = vqrdmulhq_n_s16(v711, 13357); + int16x8_t v712 = vmlaq_n_s16(v712_tmp, v711, 3); + int16x8_t v713 = vaddq_s16(v710, v712); + int16x8_t v714 = vqrdmulhq_n_s16(v713, 21637); + int16x8_t v715 = vaddq_s16(v709, v714); + int16x8_t v716 = vsubq_s16(v25, v60); + int16x8_t v717 = vsubq_s16(v102, v140); + int16x8_t v718_tmp = vqrdmulhq_n_s16(v717, 6226); + int16x8_t v718 = vmlaq_n_s16(v718_tmp, v717, 10); + int16x8_t v719 = vaddq_s16(v716, v718); + int16x8_t v720 = vsubq_s16(v280, v326); + int16x8_t v721_tmp = vqrdmulhq_n_s16(v720, 6226); + int16x8_t v721 = vmlaq_n_s16(v721_tmp, v720, 10); + int16x8_t v722 = vsubq_s16(v184, v237); + int16x8_t v723 = vaddq_s16(v721, v722); + int16x8_t v724 = vqrdmulhq_n_s16(v723, 22622); + int16x8_t v725 = vaddq_s16(v719, v724); + int16x8_t v726 = vsubq_s16(v716, v718); + int16x8_t v727 = vsubq_s16(v722, v721); + int16x8_t v728 = vqrdmulhq_n_s16(v727, 23761); + int16x8_t v729 = vaddq_s16(v726, v728); + int16x8_t v730 = vsubq_s16(v706, v708); + int16x8_t v731 = vsubq_s16(v710, v712); + int16x8_t v732 = vqrdmulhq_n_s16(v731, 25084); + int16x8_t v733 = vaddq_s16(v730, v732); + int16x8_t v734 = vsubq_s16(v701, v703); + int16x8_t v735 = vsubq_s16(v698, v697); + int16x8_t v736 = vqrdmulhq_n_s16(v735, 26631); + int16x8_t v737 = vaddq_s16(v734, v736); + int16x8_t v738 = vsubq_s16(v685, v687); + int16x8_t v739 = vsubq_s16(v689, v692); + int16x8_t v740 = vqrdmulhq_n_s16(v739, 28454); + int16x8_t v741 = vaddq_s16(v738, v740); + int16x8_t v742 = vsubq_s16(v675, v677); + int16x8_t v743 = vsubq_s16(v679, v681); + int16x8_t v744 = vqrdmulhq_n_s16(v743, 30624); + int16x8_t v745 = vaddq_s16(v742, v744); + int16x8_t v746 = vsubq_s16(v664, v666); + int16x8_t v747 = vsubq_s16(v668, v671); + int16x8_t v748_tmp = vqrdmulhq_n_s16(v747, 472); + int16x8_t v748 = vaddq_s16(v748_tmp, v747); + int16x8_t v749 = vaddq_s16(v746, v748); + int16x8_t v750 = vsubq_s16(v653, v655); + int16x8_t v751 = vsubq_s16(v657, v660); + int16x8_t v752_tmp = vqrdmulhq_n_s16(v751, 3672); + int16x8_t v752 = vaddq_s16(v752_tmp, v751); + int16x8_t v753 = vaddq_s16(v750, v752); + int16x8_t v754 = vsubq_s16(v643, v645); + int16x8_t v755 = vsubq_s16(v647, v649); + int16x8_t v756_tmp = vqrdmulhq_n_s16(v755, 7662); + int16x8_t v756 = vaddq_s16(v756_tmp, v755); + int16x8_t v757 = vaddq_s16(v754, v756); + int16x8_t v758 = vsubq_s16(v635, v640); + int16x8_t v759 = vsubq_s16(v624, v629); + int16x8_t v760_tmp = vqrdmulhq_n_s16(v759, 12756); + int16x8_t v760 = vaddq_s16(v760_tmp, v759); + int16x8_t v761 = vaddq_s16(v758, v760); + int16x8_t v762 = vsubq_s16(v602, v607); + int16x8_t v763 = vsubq_s16(v612, v617); + int16x8_t v764_tmp = vqrdmulhq_n_s16(v763, 19463); + int16x8_t v764 = vaddq_s16(v764_tmp, v763); + int16x8_t v765 = vaddq_s16(v762, v764); + int16x8_t v766 = vsubq_s16(v585, v581); + int16x8_t v767 = vsubq_s16(v591, v595); + int16x8_t v768_tmp = vqrdmulhq_n_s16(v767, 28661); + int16x8_t v768 = vaddq_s16(v768_tmp, v767); + int16x8_t v769 = vaddq_s16(v766, v768); + int16x8_t v770 = vsubq_s16(v556, v561); + int16x8_t v771 = vsubq_s16(v566, v572); + int16x8_t v772_tmp = vqrdmulhq_n_s16(v771, 9242); + int16x8_t v772 = vmlaq_n_s16(v772_tmp, v771, 2); + int16x8_t v773 = vaddq_s16(v770, v772); + int16x8_t v774 = vsubq_s16(v513, v524); + int16x8_t v775 = vsubq_s16(v535, v549); + int16x8_t v776_tmp = vqrdmulhq_n_s16(v775, 30298); + int16x8_t v776 = vmlaq_n_s16(v776_tmp, v775, 2); + int16x8_t v777 = vaddq_s16(v774, v776); + int16x8_t v778 = vsubq_s16(v457, v470); + int16x8_t v779 = vsubq_s16(v500, v488); + int16x8_t v780_tmp = vqrdmulhq_n_s16(v779, 2773); + int16x8_t v780 = vmlaq_n_s16(v780_tmp, v779, 4); + int16x8_t v781 = vaddq_s16(v778, v780); + int16x8_t v782 = vsubq_s16(v385, v360); + int16x8_t v783 = vsubq_s16(v414, v443); + int16x8_t v784_tmp = vqrdmulhq_n_s16(v783, 26108); + int16x8_t v784 = vmlaq_n_s16(v784_tmp, v783, 6); + int16x8_t v785 = vaddq_s16(v782, v784); + int16x8_t v786 = vsubq_s16(v61, v142); + int16x8_t v787 = vsubq_s16(v238, v328); + int16x8_t v788_tmp = vqrdmulhq_n_s16(v787, 12251); + int16x8_t v788 = vmlaq_n_s16(v788_tmp, v787, 20); + int16x8_t v789 = vaddq_s16(v786, v788); + int16x8_t v790 = vsubq_s16(v786, v788); + int16x8_t v791 = vsubq_s16(v782, v784); + int16x8_t v792 = vsubq_s16(v778, v780); + int16x8_t v793 = vsubq_s16(v774, v776); + int16x8_t v794 = vsubq_s16(v770, v772); + int16x8_t v795 = vsubq_s16(v766, v768); + int16x8_t v796 = vsubq_s16(v762, v764); + int16x8_t v797 = vsubq_s16(v758, v760); + int16x8_t v798 = vsubq_s16(v754, v756); + int16x8_t v799 = vsubq_s16(v750, v752); + int16x8_t v800 = vsubq_s16(v746, v748); + int16x8_t v801 = vsubq_s16(v742, v744); + int16x8_t v802 = vsubq_s16(v738, v740); + int16x8_t v803 = vsubq_s16(v734, v736); + int16x8_t v804 = vsubq_s16(v730, v732); + int16x8_t v805 = vsubq_s16(v726, v728); + int16x8_t v806 = vsubq_s16(v719, v724); + int16x8_t v807 = vsubq_s16(v709, v714); + int16x8_t v808 = vsubq_s16(v704, v700); + int16x8_t v809 = vsubq_s16(v688, v694); + int16x8_t v810 = vsubq_s16(v678, v683); + int16x8_t v811 = vsubq_s16(v667, v673); + int16x8_t v812 = vsubq_s16(v656, v662); + int16x8_t v813 = vsubq_s16(v646, v651); + int16x8_t v814 = vsubq_s16(v641, v631); + int16x8_t v815 = vsubq_s16(v608, v619); + int16x8_t v816 = vsubq_s16(v586, v597); + int16x8_t v817 = vsubq_s16(v562, v574); + int16x8_t v818 = vsubq_s16(v525, v551); + int16x8_t v819 = vsubq_s16(v471, v502); + int16x8_t v820 = vsubq_s16(v386, v445); + int16x8_t v821 = vsubq_s16(v143, v330); + vst1q_s16(out + out_stride * 0 + i, v331); + vst1q_s16(out + out_stride * 1 + i, v446); + vst1q_s16(out + out_stride * 2 + i, v503); + vst1q_s16(out + out_stride * 3 + i, v552); + vst1q_s16(out + out_stride * 4 + i, v575); + vst1q_s16(out + out_stride * 5 + i, v598); + vst1q_s16(out + out_stride * 6 + i, v620); + vst1q_s16(out + out_stride * 7 + i, v642); + vst1q_s16(out + out_stride * 8 + i, v652); + vst1q_s16(out + out_stride * 9 + i, v663); + vst1q_s16(out + out_stride * 10 + i, v674); + vst1q_s16(out + out_stride * 11 + i, v684); + vst1q_s16(out + out_stride * 12 + i, v695); + vst1q_s16(out + out_stride * 13 + i, v705); + vst1q_s16(out + out_stride * 14 + i, v715); + vst1q_s16(out + out_stride * 15 + i, v725); + vst1q_s16(out + out_stride * 16 + i, v729); + vst1q_s16(out + out_stride * 17 + i, v733); + vst1q_s16(out + out_stride * 18 + i, v737); + vst1q_s16(out + out_stride * 19 + i, v741); + vst1q_s16(out + out_stride * 20 + i, v745); + vst1q_s16(out + out_stride * 21 + i, v749); + vst1q_s16(out + out_stride * 22 + i, v753); + vst1q_s16(out + out_stride * 23 + i, v757); + vst1q_s16(out + out_stride * 24 + i, v761); + vst1q_s16(out + out_stride * 25 + i, v765); + vst1q_s16(out + out_stride * 26 + i, v769); + vst1q_s16(out + out_stride * 27 + i, v773); + vst1q_s16(out + out_stride * 28 + i, v777); + vst1q_s16(out + out_stride * 29 + i, v781); + vst1q_s16(out + out_stride * 30 + i, v785); + vst1q_s16(out + out_stride * 31 + i, v789); + vst1q_s16(out + out_stride * 32 + i, v790); + vst1q_s16(out + out_stride * 33 + i, v791); + vst1q_s16(out + out_stride * 34 + i, v792); + vst1q_s16(out + out_stride * 35 + i, v793); + vst1q_s16(out + out_stride * 36 + i, v794); + vst1q_s16(out + out_stride * 37 + i, v795); + vst1q_s16(out + out_stride * 38 + i, v796); + vst1q_s16(out + out_stride * 39 + i, v797); + vst1q_s16(out + out_stride * 40 + i, v798); + vst1q_s16(out + out_stride * 41 + i, v799); + vst1q_s16(out + out_stride * 42 + i, v800); + vst1q_s16(out + out_stride * 43 + i, v801); + vst1q_s16(out + out_stride * 44 + i, v802); + vst1q_s16(out + out_stride * 45 + i, v803); + vst1q_s16(out + out_stride * 46 + i, v804); + vst1q_s16(out + out_stride * 47 + i, v805); + vst1q_s16(out + out_stride * 48 + i, v806); + vst1q_s16(out + out_stride * 49 + i, v807); + vst1q_s16(out + out_stride * 50 + i, v808); + vst1q_s16(out + out_stride * 51 + i, v809); + vst1q_s16(out + out_stride * 52 + i, v810); + vst1q_s16(out + out_stride * 53 + i, v811); + vst1q_s16(out + out_stride * 54 + i, v812); + vst1q_s16(out + out_stride * 55 + i, v813); + vst1q_s16(out + out_stride * 56 + i, v814); + vst1q_s16(out + out_stride * 57 + i, v815); + vst1q_s16(out + out_stride * 58 + i, v816); + vst1q_s16(out + out_stride * 59 + i, v817); + vst1q_s16(out + out_stride * 60 + i, v818); + vst1q_s16(out + out_stride * 61 + i, v819); + vst1q_s16(out + out_stride * 62 + i, v820); + vst1q_s16(out + out_stride * 63 + i, v821); + } +} diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct8-inl.h b/third_party/jpeg-xl/lib/jxl/fast_dct8-inl.h new file mode 100644 index 0000000000..946ace4a0c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct8-inl.h @@ -0,0 +1,80 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* This file is automatically generated. Do not modify it directly. */ +#if HWY_TARGET != HWY_NEON +#error "only include this file from fast_dct-inl.h" +#endif + +constexpr size_t FastIDCTIntegerBits(FastDCTTag<8>) { return 1; } + +void FastIDCT(FastDCTTag<8>, const int16_t* in, size_t in_stride, int16_t* out, + size_t out_stride, size_t count) { + JXL_ASSERT(count % 8 == 0); + for (size_t i = 0; i < count; i += 8) { + int16x8_t v0 = vld1q_s16(in + in_stride * 0 + i); + int16x8_t v1 = vld1q_s16(in + in_stride * 4 + i); + int16x8_t v2 = vaddq_s16(v0, v1); + int16x8_t v3 = vld1q_s16(in + in_stride * 2 + i); + int16x8_t v4_tmp = vqrdmulhq_n_s16(v3, 13573); + int16x8_t v4 = vaddq_s16(v4_tmp, v3); + int16x8_t v5 = vld1q_s16(in + in_stride * 6 + i); + int16x8_t v6 = vaddq_s16(v5, v3); + int16x8_t v7 = vaddq_s16(v4, v6); + int16x8_t v8 = vqrdmulhq_n_s16(v7, 17734); + int16x8_t v9 = vaddq_s16(v2, v8); + int16x8_t v10 = vld1q_s16(in + in_stride * 1 + i); + int16x8_t v11_tmp = vqrdmulhq_n_s16(v10, 13573); + int16x8_t v11 = vaddq_s16(v11_tmp, v10); + int16x8_t v12 = vld1q_s16(in + in_stride * 5 + i); + int16x8_t v13 = vld1q_s16(in + in_stride * 3 + i); + int16x8_t v14 = vaddq_s16(v12, v13); + int16x8_t v15 = vaddq_s16(v11, v14); + int16x8_t v16 = vaddq_s16(v13, v10); + int16x8_t v17 = vqrdmulhq_n_s16(v16, 25080); + int16x8_t v18 = vld1q_s16(in + in_stride * 7 + i); + int16x8_t v19 = vaddq_s16(v18, v12); + int16x8_t v20 = vaddq_s16(v16, v19); + int16x8_t v21 = vqrdmulhq_n_s16(v20, 17734); + int16x8_t v22 = vaddq_s16(v17, v21); + int16x8_t v23 = vaddq_s16(v15, v22); + int16x8_t v24 = vqrdmulhq_n_s16(v23, 16705); + int16x8_t v25 = vaddq_s16(v9, v24); + int16x8_t v26 = vsubq_s16(v0, v1); + int16x8_t v27 = vsubq_s16(v4, v6); + int16x8_t v28_tmp = vqrdmulhq_n_s16(v27, 10045); + int16x8_t v28 = vaddq_s16(v28_tmp, v27); + int16x8_t v29 = vaddq_s16(v26, v28); + int16x8_t v30 = vsubq_s16(v11, v14); + int16x8_t v31 = vqrdmulhq_n_s16(v16, 17734); + int16x8_t v32_tmp = vqrdmulhq_n_s16(v19, 10045); + int16x8_t v32 = vaddq_s16(v32_tmp, v19); + int16x8_t v33 = vsubq_s16(v31, v32); + int16x8_t v34 = vaddq_s16(v30, v33); + int16x8_t v35 = vqrdmulhq_n_s16(v34, 19705); + int16x8_t v36 = vaddq_s16(v29, v35); + int16x8_t v37 = vsubq_s16(v26, v28); + int16x8_t v38 = vsubq_s16(v30, v33); + int16x8_t v39 = vqrdmulhq_n_s16(v38, 29490); + int16x8_t v40 = vaddq_s16(v37, v39); + int16x8_t v41 = vsubq_s16(v2, v8); + int16x8_t v42 = vsubq_s16(v15, v22); + int16x8_t v43_tmp = vqrdmulhq_n_s16(v42, 18446); + int16x8_t v43 = vmlaq_n_s16(v43_tmp, v42, 2); + int16x8_t v44 = vaddq_s16(v41, v43); + int16x8_t v45 = vsubq_s16(v41, v43); + int16x8_t v46 = vsubq_s16(v37, v39); + int16x8_t v47 = vsubq_s16(v29, v35); + int16x8_t v48 = vsubq_s16(v9, v24); + vst1q_s16(out + out_stride * 0 + i, v25); + vst1q_s16(out + out_stride * 1 + i, v36); + vst1q_s16(out + out_stride * 2 + i, v40); + vst1q_s16(out + out_stride * 3 + i, v44); + vst1q_s16(out + out_stride * 4 + i, v45); + vst1q_s16(out + out_stride * 5 + i, v46); + vst1q_s16(out + out_stride * 6 + i, v47); + vst1q_s16(out + out_stride * 7 + i, v48); + } +} diff --git a/third_party/jpeg-xl/lib/jxl/fast_dct_test.cc b/third_party/jpeg-xl/lib/jxl/fast_dct_test.cc new file mode 100644 index 0000000000..a55b67afb2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_dct_test.cc @@ -0,0 +1,377 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <numeric> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/fast_dct_test.cc" +#include <hwy/foreach_target.h> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/dct-inl.h" +#include "lib/jxl/fast_dct-inl.h" +#include "lib/jxl/fast_dct.h" +#include "lib/jxl/testing.h" +#include "lib/jxl/transpose-inl.h" + +// Test utils +#include <hwy/highway.h> +#include <hwy/tests/hwy_gtest.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +template <size_t N, size_t M> +HWY_NOINLINE void TestFastTranspose() { +#if HWY_TARGET == HWY_NEON + auto array_mem = hwy::AllocateAligned<int16_t>(N * M); + int16_t* array = array_mem.get(); + auto transposed_mem = hwy::AllocateAligned<int16_t>(N * M); + int16_t* transposed = transposed_mem.get(); + std::iota(array, array + N * M, 0); + for (size_t j = 0; j < 100000000 / (N * M); j++) { + FastTransposeBlock(array, M, N, M, transposed, N); + } + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + EXPECT_EQ(array[j * M + i], transposed[i * N + j]); + } + } +#endif +} + +template <size_t N, size_t M> +HWY_NOINLINE void TestFloatTranspose() { + auto array_mem = hwy::AllocateAligned<float>(N * M); + float* array = array_mem.get(); + auto transposed_mem = hwy::AllocateAligned<float>(N * M); + float* transposed = transposed_mem.get(); + std::iota(array, array + N * M, 0); + for (size_t j = 0; j < 100000000 / (N * M); j++) { + Transpose<N, M>::Run(DCTFrom(array, M), DCTTo(transposed, N)); + } + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + EXPECT_EQ(array[j * M + i], transposed[i * N + j]); + } + } +} + +// TODO(sboukortt): re-enable the FloatIDCT tests once we find out why they fail +// in ASAN mode in the CI runners and seemingly not locally. + +HWY_NOINLINE void TestFastTranspose8x8() { TestFastTranspose<8, 8>(); } +HWY_NOINLINE void TestFloatTranspose8x8() { TestFloatTranspose<8, 8>(); } +HWY_NOINLINE void TestFastIDCT8x8() { TestFastIDCT<8, 8>(); } +HWY_NOINLINE void TestFloatIDCT8x8() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<8, 8>(); +#endif +} +HWY_NOINLINE void TestFastTranspose8x16() { TestFastTranspose<8, 16>(); } +HWY_NOINLINE void TestFloatTranspose8x16() { TestFloatTranspose<8, 16>(); } +HWY_NOINLINE void TestFastIDCT8x16() { TestFastIDCT<8, 16>(); } +HWY_NOINLINE void TestFloatIDCT8x16() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<8, 16>(); +#endif +} +HWY_NOINLINE void TestFastTranspose8x32() { TestFastTranspose<8, 32>(); } +HWY_NOINLINE void TestFloatTranspose8x32() { TestFloatTranspose<8, 32>(); } +HWY_NOINLINE void TestFastIDCT8x32() { TestFastIDCT<8, 32>(); } +HWY_NOINLINE void TestFloatIDCT8x32() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<8, 32>(); +#endif +} +HWY_NOINLINE void TestFastTranspose16x8() { TestFastTranspose<16, 8>(); } +HWY_NOINLINE void TestFloatTranspose16x8() { TestFloatTranspose<16, 8>(); } +HWY_NOINLINE void TestFastIDCT16x8() { TestFastIDCT<16, 8>(); } +HWY_NOINLINE void TestFloatIDCT16x8() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<16, 8>(); +#endif +} +HWY_NOINLINE void TestFastTranspose16x16() { TestFastTranspose<16, 16>(); } +HWY_NOINLINE void TestFloatTranspose16x16() { TestFloatTranspose<16, 16>(); } +HWY_NOINLINE void TestFastIDCT16x16() { TestFastIDCT<16, 16>(); } +HWY_NOINLINE void TestFloatIDCT16x16() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<16, 16>(); +#endif +} +HWY_NOINLINE void TestFastTranspose16x32() { TestFastTranspose<16, 32>(); } +HWY_NOINLINE void TestFloatTranspose16x32() { TestFloatTranspose<16, 32>(); } +HWY_NOINLINE void TestFastIDCT16x32() { TestFastIDCT<16, 32>(); } +HWY_NOINLINE void TestFloatIDCT16x32() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<16, 32>(); +#endif +} +HWY_NOINLINE void TestFastTranspose32x8() { TestFastTranspose<32, 8>(); } +HWY_NOINLINE void TestFloatTranspose32x8() { TestFloatTranspose<32, 8>(); } +HWY_NOINLINE void TestFastIDCT32x8() { TestFastIDCT<32, 8>(); } +HWY_NOINLINE void TestFloatIDCT32x8() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<32, 8>(); +#endif +} +HWY_NOINLINE void TestFastTranspose32x16() { TestFastTranspose<32, 16>(); } +HWY_NOINLINE void TestFloatTranspose32x16() { TestFloatTranspose<32, 16>(); } +HWY_NOINLINE void TestFastIDCT32x16() { TestFastIDCT<32, 16>(); } +HWY_NOINLINE void TestFloatIDCT32x16() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<32, 16>(); +#endif +} +HWY_NOINLINE void TestFastTranspose32x32() { TestFastTranspose<32, 32>(); } +HWY_NOINLINE void TestFloatTranspose32x32() { TestFloatTranspose<32, 32>(); } +HWY_NOINLINE void TestFastIDCT32x32() { TestFastIDCT<32, 32>(); } +HWY_NOINLINE void TestFloatIDCT32x32() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<32, 32>(); +#endif +} +HWY_NOINLINE void TestFastTranspose32x64() { TestFastTranspose<32, 64>(); } +HWY_NOINLINE void TestFloatTranspose32x64() { TestFloatTranspose<32, 64>(); } +HWY_NOINLINE void TestFastIDCT32x64() { TestFastIDCT<32, 64>(); } +HWY_NOINLINE void TestFloatIDCT32x64() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<32, 64>(); +#endif +} +HWY_NOINLINE void TestFastTranspose64x32() { TestFastTranspose<64, 32>(); } +HWY_NOINLINE void TestFloatTranspose64x32() { TestFloatTranspose<64, 32>(); } +HWY_NOINLINE void TestFastIDCT64x32() { TestFastIDCT<64, 32>(); } +HWY_NOINLINE void TestFloatIDCT64x32() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<64, 32>(); +#endif +} +HWY_NOINLINE void TestFastTranspose64x64() { TestFastTranspose<64, 64>(); } +HWY_NOINLINE void TestFloatTranspose64x64() { TestFloatTranspose<64, 64>(); } +HWY_NOINLINE void TestFastIDCT64x64() { TestFastIDCT<64, 64>(); } +HWY_NOINLINE void TestFloatIDCT64x64() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<64, 64>(); +#endif +} +HWY_NOINLINE void TestFastTranspose64x128() { TestFastTranspose<64, 128>(); } +HWY_NOINLINE void TestFloatTranspose64x128() { TestFloatTranspose<64, 128>(); } +/* +HWY_NOINLINE void TestFastIDCT64x128() { TestFastIDCT<64, 128>(); } +HWY_NOINLINE void TestFloatIDCT64x128() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<64, 128>(); +#endif +} +*/ +HWY_NOINLINE void TestFastTranspose128x64() { TestFastTranspose<128, 64>(); } +HWY_NOINLINE void TestFloatTranspose128x64() { TestFloatTranspose<128, 64>(); } +/* +HWY_NOINLINE void TestFastIDCT128x64() { TestFastIDCT<128, 64>(); } +HWY_NOINLINE void TestFloatIDCT128x64() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<128, 64>(); +#endif +} +*/ +HWY_NOINLINE void TestFastTranspose128x128() { TestFastTranspose<128, 128>(); } +HWY_NOINLINE void TestFloatTranspose128x128() { + TestFloatTranspose<128, 128>(); +} +/* +HWY_NOINLINE void TestFastIDCT128x128() { TestFastIDCT<128, 128>(); } +HWY_NOINLINE void TestFloatIDCT128x128() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<128, 128>(); +#endif +} +*/ +HWY_NOINLINE void TestFastTranspose128x256() { TestFastTranspose<128, 256>(); } +HWY_NOINLINE void TestFloatTranspose128x256() { + TestFloatTranspose<128, 256>(); +} +/* +HWY_NOINLINE void TestFastIDCT128x256() { TestFastIDCT<128, 256>(); } +HWY_NOINLINE void TestFloatIDCT128x256() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<128, 256>(); +#endif +} +*/ +HWY_NOINLINE void TestFastTranspose256x128() { TestFastTranspose<256, 128>(); } +HWY_NOINLINE void TestFloatTranspose256x128() { + TestFloatTranspose<256, 128>(); +} +/* +HWY_NOINLINE void TestFastIDCT256x128() { TestFastIDCT<256, 128>(); } +HWY_NOINLINE void TestFloatIDCT256x128() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<256, 128>(); +#endif +} +*/ +HWY_NOINLINE void TestFastTranspose256x256() { TestFastTranspose<256, 256>(); } +HWY_NOINLINE void TestFloatTranspose256x256() { + TestFloatTranspose<256, 256>(); +} +/* +HWY_NOINLINE void TestFastIDCT256x256() { TestFastIDCT<256, 256>(); } +HWY_NOINLINE void TestFloatIDCT256x256() { +#if HWY_TARGET == HWY_SCALAR && \ + (defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER)) + GTEST_SKIP(); +#else + TestFloatIDCT<256, 256>(); +#endif +} +*/ + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class FastDCTTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(FastDCTTargetTest); + +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose8x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose8x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose8x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose8x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose8x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose8x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose16x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose16x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose16x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose16x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose16x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose16x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose32x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose32x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose32x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose32x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose32x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose32x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose32x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose32x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose64x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose64x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose64x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose64x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose64x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose64x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose128x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose128x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose128x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose128x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose128x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose128x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose256x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose256x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatTranspose256x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastTranspose256x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT8x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT8x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT8x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT8x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT8x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT8x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT16x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT16x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT16x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT16x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT16x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT16x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT32x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT32x8); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT32x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT32x16); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT32x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT32x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT32x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT32x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT64x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT64x32); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT64x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT64x64); +/* + * DCT-128 and above have very large errors just by rounding inputs. +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT64x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT64x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT128x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT128x64); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT128x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT128x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT128x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT128x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT256x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT256x128); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFloatIDCT256x256); +HWY_EXPORT_AND_TEST_P(FastDCTTargetTest, TestFastIDCT256x256); +*/ + +TEST(FastDCTTest, TestWrapperFloat) { BenchmarkFloatIDCT32x32(); } +TEST(FastDCTTest, TestWrapperFast) { BenchmarkFastIDCT32x32(); } + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/fast_math_test.cc b/third_party/jpeg-xl/lib/jxl/fast_math_test.cc new file mode 100644 index 0000000000..868e1b72f4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fast_math_test.cc @@ -0,0 +1,237 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/fast_math_test.cc" +#include <jxl/cms.h> + +#include <hwy/foreach_target.h> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/cms/transfer_functions-inl.h" +#include "lib/jxl/dec_xyb-inl.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/testing.h" + +// Test utils +#include <hwy/highway.h> +#include <hwy/tests/hwy_gtest.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +HWY_NOINLINE void TestFastLog2() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(1e-7f, 1e3f); + const auto actual_v = FastLog2f(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float abs_err = std::abs(std::log2(f) - actual); + EXPECT_LT(abs_err, 3.1E-6) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +HWY_NOINLINE void TestFastPow2() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_rel_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(-100, 100); + const auto actual_v = FastPow2f(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float expected = std::pow(2, f); + const float rel_err = std::abs(expected - actual) / expected; + EXPECT_LT(rel_err, 3.1E-6) << "f = " << f; + max_rel_err = std::max(max_rel_err, rel_err); + } + printf("max rel err %e\n", static_cast<double>(max_rel_err)); +} + +HWY_NOINLINE void TestFastPow() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_rel_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float b = rng.UniformF(1e-3f, 1e3f); + const float e = rng.UniformF(-10, 10); + const auto actual_v = FastPowf(d, Set(d, b), Set(d, e)); + const float actual = GetLane(actual_v); + const float expected = std::pow(b, e); + const float rel_err = std::abs(expected - actual) / expected; + EXPECT_LT(rel_err, 3E-5) << "b = " << b << " e = " << e; + max_rel_err = std::max(max_rel_err, rel_err); + } + printf("max rel err %e\n", static_cast<double>(max_rel_err)); +} + +HWY_NOINLINE void TestFastCos() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(-1e3f, 1e3f); + const auto actual_v = FastCosf(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float abs_err = std::abs(std::cos(f) - actual); + EXPECT_LT(abs_err, 7E-5) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +HWY_NOINLINE void TestFastErf() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(-5.f, 5.f); + const auto actual_v = FastErff(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float abs_err = std::abs(std::erf(f) - actual); + EXPECT_LT(abs_err, 7E-4) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +HWY_NOINLINE void TestCubeRoot() { + const HWY_FULL(float) d; + for (uint64_t x5 = 0; x5 < 2000000; x5++) { + const float x = x5 * 1E-5f; + const float expected = cbrtf(x); + HWY_ALIGN float approx[MaxLanes(d)]; + Store(CubeRootAndAdd(Set(d, x), Zero(d)), d, approx); + + // All lanes are same + for (size_t i = 1; i < Lanes(d); ++i) { + EXPECT_NEAR(approx[0], approx[i], 5E-7f); + } + EXPECT_NEAR(approx[0], expected, 8E-7f); + } +} + +HWY_NOINLINE void TestFastSRGB() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(0.0f, 1.0f); + const auto actual_v = FastLinearToSRGB(d, Set(d, f)); + const float actual = GetLane(actual_v); + const float expected = GetLane(TF_SRGB().EncodedFromDisplay(d, Set(d, f))); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 1.2E-4) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +HWY_NOINLINE void TestFast709EFD() { + constexpr size_t kNumTrials = 1 << 23; + Rng rng(1); + float max_abs_err = 0; + HWY_FULL(float) d; + for (size_t i = 0; i < kNumTrials; i++) { + const float f = rng.UniformF(0.0f, 1.0f); + const float actual = GetLane(TF_709().EncodedFromDisplay(d, Set(d, f))); + const float expected = TF_709().EncodedFromDisplay(f); + const float abs_err = std::abs(expected - actual); + EXPECT_LT(abs_err, 2e-6) << "f = " << f; + max_abs_err = std::max(max_abs_err, abs_err); + } + printf("max abs err %e\n", static_cast<double>(max_abs_err)); +} + +HWY_NOINLINE void TestFastXYB() { + if (!HasFastXYBTosRGB8()) return; + ImageMetadata metadata; + ImageBundle ib(&metadata); + int scaling = 1; + int n = 256 * scaling; + float inv_scaling = 1.0f / scaling; + int kChunk = 32; + // The image is divided in chunks to reduce total memory usage. + for (int cr = 0; cr < n; cr += kChunk) { + for (int cg = 0; cg < n; cg += kChunk) { + for (int cb = 0; cb < n; cb += kChunk) { + Image3F chunk(kChunk * kChunk, kChunk); + for (int ir = 0; ir < kChunk; ir++) { + for (int ig = 0; ig < kChunk; ig++) { + for (int ib = 0; ib < kChunk; ib++) { + float r = (cr + ir) * inv_scaling; + float g = (cg + ig) * inv_scaling; + float b = (cb + ib) * inv_scaling; + chunk.PlaneRow(0, ir)[ig * kChunk + ib] = r * (1.0f / 255); + chunk.PlaneRow(1, ir)[ig * kChunk + ib] = g * (1.0f / 255); + chunk.PlaneRow(2, ir)[ig * kChunk + ib] = b * (1.0f / 255); + } + } + } + ib.SetFromImage(std::move(chunk), ColorEncoding::SRGB()); + Image3F xyb(kChunk * kChunk, kChunk); + std::vector<uint8_t> roundtrip(kChunk * kChunk * kChunk * 3); + ToXYB(ib, nullptr, &xyb, *JxlGetDefaultCms()); + for (int y = 0; y < kChunk; y++) { + const float* xyba[4] = {xyb.PlaneRow(0, y), xyb.PlaneRow(1, y), + xyb.PlaneRow(2, y), nullptr}; + jxl::HWY_NAMESPACE::FastXYBTosRGB8( + xyba, roundtrip.data() + 3 * xyb.xsize() * y, false, xyb.xsize()); + } + for (int ir = 0; ir < kChunk; ir++) { + for (int ig = 0; ig < kChunk; ig++) { + for (int ib = 0; ib < kChunk; ib++) { + float r = (cr + ir) * inv_scaling; + float g = (cg + ig) * inv_scaling; + float b = (cb + ib) * inv_scaling; + size_t idx = ir * kChunk * kChunk + ig * kChunk + ib; + int rr = roundtrip[3 * idx]; + int rg = roundtrip[3 * idx + 1]; + int rb = roundtrip[3 * idx + 2]; + EXPECT_LT(abs(r - rr), 2) << "expected " << r << " got " << rr; + EXPECT_LT(abs(g - rg), 2) << "expected " << g << " got " << rg; + EXPECT_LT(abs(b - rb), 2) << "expected " << b << " got " << rb; + } + } + } + } + } + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class FastMathTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(FastMathTargetTest); + +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastLog2); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPow2); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPow); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastCos); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastErf); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestCubeRoot); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastSRGB); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFast709EFD); +HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastXYB); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/field_encodings.h b/third_party/jpeg-xl/lib/jxl/field_encodings.h new file mode 100644 index 0000000000..613e8fad33 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/field_encodings.h @@ -0,0 +1,134 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_FIELD_ENCODINGS_H_ +#define LIB_JXL_FIELD_ENCODINGS_H_ + +// Constants needed to encode/decode fields; avoids including the full fields.h. + +#include <stddef.h> +#include <stdint.h> + +#include <hwy/base.h> +#include <vector> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Macro to define the Fields' derived class Name when compiling with debug +// names. +#if JXL_IS_DEBUG_BUILD +#define JXL_FIELDS_NAME(X) \ + const char* Name() const override { return #X; } +#else +#define JXL_FIELDS_NAME(X) +#endif // JXL_IS_DEBUG_BUILD + +class Visitor; +class Fields { + public: + virtual ~Fields() = default; +#if JXL_IS_DEBUG_BUILD + virtual const char* Name() const = 0; +#endif // JXL_IS_DEBUG_BUILD + virtual Status VisitFields(Visitor* JXL_RESTRICT visitor) = 0; +}; + +// Distribution of U32 values for one particular selector. Represents either a +// power of two-sized range, or a single value. A separate type ensures this is +// only passed to the U32Enc ctor. +struct U32Distr { + // No need to validate - all `d` are legitimate. + constexpr explicit U32Distr(uint32_t d) : d(d) {} + + static constexpr uint32_t kDirect = 0x80000000u; + + constexpr bool IsDirect() const { return (d & kDirect) != 0; } + + // Only call if IsDirect(). + constexpr uint32_t Direct() const { return d & (kDirect - 1); } + + // Only call if !IsDirect(). + constexpr size_t ExtraBits() const { return (d & 0x1F) + 1; } + uint32_t Offset() const { return (d >> 5) & 0x3FFFFFF; } + + uint32_t d; +}; + +// A direct-coded 31-bit value occupying 2 bits in the bitstream. +constexpr U32Distr Val(uint32_t value) { + return U32Distr(value | U32Distr::kDirect); +} + +// Value - `offset` will be signaled in `bits` extra bits. +constexpr U32Distr BitsOffset(uint32_t bits, uint32_t offset) { + return U32Distr(((bits - 1) & 0x1F) + ((offset & 0x3FFFFFF) << 5)); +} + +// Value will be signaled in `bits` extra bits. +constexpr U32Distr Bits(uint32_t bits) { return BitsOffset(bits, 0); } + +// See U32Coder documentation in fields.h. +class U32Enc { + public: + constexpr U32Enc(const U32Distr d0, const U32Distr d1, const U32Distr d2, + const U32Distr d3) + : d_{d0, d1, d2, d3} {} + + // Returns the U32Distr at `selector` = 0..3, least-significant first. + U32Distr GetDistr(const uint32_t selector) const { + JXL_ASSERT(selector < 4); + return d_[selector]; + } + + private: + U32Distr d_[4]; +}; + +// Returns bit with the given `index` (0 = least significant). +template <typename T> +static inline constexpr uint64_t MakeBit(T index) { + return 1ULL << static_cast<uint32_t>(index); +} + +// Returns vector of all possible values of an Enum type. Relies on each Enum +// providing an overload of EnumBits() that returns a bit array of its values, +// which implies values must be in [0, 64). +template <typename Enum> +std::vector<Enum> Values() { + uint64_t bits = EnumBits(Enum()); + + std::vector<Enum> values; + values.reserve(hwy::PopCount(bits)); + + // For each 1-bit in bits: add its index as value + while (bits != 0) { + const int index = Num0BitsBelowLS1Bit_Nonzero(bits); + values.push_back(static_cast<Enum>(index)); + bits &= bits - 1; // clear least-significant bit + } + return values; +} + +// Returns true if value is one of Values<Enum>(). +template <class Enum> +Status EnumValid(const Enum value) { + if (static_cast<uint32_t>(value) >= 64) { + return JXL_FAILURE("Value %u too large for %s\n", + static_cast<uint32_t>(value), EnumName(Enum())); + } + const uint64_t bit = MakeBit(value); + if ((EnumBits(Enum()) & bit) == 0) { + return JXL_FAILURE("Invalid value %u for %s\n", + static_cast<uint32_t>(value), EnumName(Enum())); + } + return true; +} + +} // namespace jxl + +#endif // LIB_JXL_FIELD_ENCODINGS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/fields.cc b/third_party/jpeg-xl/lib/jxl/fields.cc new file mode 100644 index 0000000000..746d7e4d30 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fields.cc @@ -0,0 +1,656 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/fields.h" + +#include <algorithm> +#include <cinttypes> +#include <cmath> +#include <cstddef> +#include <hwy/base.h> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/printf_macros.h" + +namespace jxl { + +namespace { + +using ::jxl::fields_internal::VisitorBase; + +struct InitVisitor : public VisitorBase { + Status Bits(const size_t /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status U32(const U32Enc /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status U64(const uint64_t default_value, + uint64_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status Bool(bool default_value, bool* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status F16(const float default_value, float* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + // Always visit conditional fields to ensure they are initialized. + Status Conditional(bool /*condition*/) override { return true; } + + Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT all_default) override { + // Just initialize this field and don't skip initializing others. + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return false; + } + + Status VisitNested(Fields* /*fields*/) override { + // Avoid re-initializing nested bundles (their ctors already called + // Bundle::Init for their fields). + return true; + } +}; + +// Similar to InitVisitor, but also initializes nested fields. +struct SetDefaultVisitor : public VisitorBase { + Status Bits(const size_t /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status U32(const U32Enc /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status U64(const uint64_t default_value, + uint64_t* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status Bool(bool default_value, bool* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + Status F16(const float default_value, float* JXL_RESTRICT value) override { + *value = default_value; + return true; + } + + // Always visit conditional fields to ensure they are initialized. + Status Conditional(bool /*condition*/) override { return true; } + + Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT all_default) override { + // Just initialize this field and don't skip initializing others. + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return false; + } +}; + +class AllDefaultVisitor : public VisitorBase { + public: + explicit AllDefaultVisitor() : VisitorBase() {} + + Status Bits(const size_t bits, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + all_default_ &= *value == default_value; + return true; + } + + Status U32(const U32Enc /*unused*/, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) override { + all_default_ &= *value == default_value; + return true; + } + + Status U64(const uint64_t default_value, + uint64_t* JXL_RESTRICT value) override { + all_default_ &= *value == default_value; + return true; + } + + Status F16(const float default_value, float* JXL_RESTRICT value) override { + all_default_ &= std::abs(*value - default_value) < 1E-6f; + return true; + } + + Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT /*all_default*/) override { + // Visit all fields so we can compute the actual all_default_ value. + return false; + } + + bool AllDefault() const { return all_default_; } + + private: + bool all_default_ = true; +}; + +class ReadVisitor : public VisitorBase { + public: + explicit ReadVisitor(BitReader* reader) : VisitorBase(), reader_(reader) {} + + Status Bits(const size_t bits, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + *value = BitsCoder::Read(bits, reader_); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + return true; + } + + Status U32(const U32Enc dist, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + *value = U32Coder::Read(dist, reader_); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + return true; + } + + Status U64(const uint64_t /*default_value*/, + uint64_t* JXL_RESTRICT value) override { + *value = U64Coder::Read(reader_); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + return true; + } + + Status F16(const float /*default_value*/, + float* JXL_RESTRICT value) override { + ok_ &= F16Coder::Read(reader_, value); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + return true; + } + + void SetDefault(Fields* fields) override { Bundle::SetDefault(fields); } + + bool IsReading() const override { return true; } + + // This never fails because visitors are expected to keep reading until + // EndExtensions, see comment there. + Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { + JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions)); + if (*extensions == 0) return true; + + // For each nonzero bit, i.e. extension that is present: + for (uint64_t remaining_extensions = *extensions; remaining_extensions != 0; + remaining_extensions &= remaining_extensions - 1) { + const size_t idx_extension = + Num0BitsBelowLS1Bit_Nonzero(remaining_extensions); + // Read additional U64 (one per extension) indicating the number of bits + // (allows skipping individual extensions). + JXL_RETURN_IF_ERROR(U64(0, &extension_bits_[idx_extension])); + if (!SafeAdd(total_extension_bits_, extension_bits_[idx_extension], + total_extension_bits_)) { + return JXL_FAILURE("Extension bits overflowed, invalid codestream"); + } + } + // Used by EndExtensions to skip past any _remaining_ extensions. + pos_after_ext_size_ = reader_->TotalBitsConsumed(); + JXL_ASSERT(pos_after_ext_size_ != 0); + return true; + } + + Status EndExtensions() override { + JXL_QUIET_RETURN_IF_ERROR(VisitorBase::EndExtensions()); + // Happens if extensions == 0: don't read size, done. + if (pos_after_ext_size_ == 0) return true; + + // Not enough bytes as set by BeginExtensions or earlier. Do not return + // this as a JXL_FAILURE or false (which can also propagate to error + // through e.g. JXL_RETURN_IF_ERROR), since this may be used while + // silently checking whether there are enough bytes. If this case must be + // treated as an error, reader_>Close() will do this, just like is already + // done for non-extension fields. + if (!enough_bytes_) return true; + + // Skip new fields this (old?) decoder didn't know about, if any. + const size_t bits_read = reader_->TotalBitsConsumed(); + uint64_t end; + if (!SafeAdd(pos_after_ext_size_, total_extension_bits_, end)) { + return JXL_FAILURE("Invalid extension size, caused overflow"); + } + if (bits_read > end) { + return JXL_FAILURE("Read more extension bits than budgeted"); + } + const size_t remaining_bits = end - bits_read; + if (remaining_bits != 0) { + JXL_WARNING("Skipping %" PRIuS "-bit extension(s)", remaining_bits); + reader_->SkipBits(remaining_bits); + if (!reader_->AllReadsWithinBounds()) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + } + return true; + } + + Status OK() const { return ok_; } + + private: + // Whether any error other than not enough bytes occurred. + bool ok_ = true; + + // Whether there are enough input bytes to read from. + bool enough_bytes_ = true; + BitReader* const reader_; + // May be 0 even if the corresponding extension is present. + uint64_t extension_bits_[Bundle::kMaxExtensions] = {0}; + uint64_t total_extension_bits_ = 0; + size_t pos_after_ext_size_ = 0; // 0 iff extensions == 0. + + friend Status jxl::CheckHasEnoughBits(Visitor*, size_t); +}; + +class MaxBitsVisitor : public VisitorBase { + public: + Status Bits(const size_t bits, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT /*value*/) override { + max_bits_ += BitsCoder::MaxEncodedBits(bits); + return true; + } + + Status U32(const U32Enc enc, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT /*value*/) override { + max_bits_ += U32Coder::MaxEncodedBits(enc); + return true; + } + + Status U64(const uint64_t /*default_value*/, + uint64_t* JXL_RESTRICT /*value*/) override { + max_bits_ += U64Coder::MaxEncodedBits(); + return true; + } + + Status F16(const float /*default_value*/, + float* JXL_RESTRICT /*value*/) override { + max_bits_ += F16Coder::MaxEncodedBits(); + return true; + } + + Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT all_default) override { + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return false; // For max bits, assume nothing is default + } + + // Always visit conditional fields to get a (loose) upper bound. + Status Conditional(bool /*condition*/) override { return true; } + + Status BeginExtensions(uint64_t* JXL_RESTRICT /*extensions*/) override { + // Skip - extensions are not included in "MaxBits" because their length + // is potentially unbounded. + return true; + } + + Status EndExtensions() override { return true; } + + size_t MaxBits() const { return max_bits_; } + + private: + size_t max_bits_ = 0; +}; + +class CanEncodeVisitor : public VisitorBase { + public: + explicit CanEncodeVisitor() : VisitorBase() {} + + Status Bits(const size_t bits, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + size_t encoded_bits = 0; + ok_ &= BitsCoder::CanEncode(bits, *value, &encoded_bits); + encoded_bits_ += encoded_bits; + return true; + } + + Status U32(const U32Enc enc, const uint32_t /*default_value*/, + uint32_t* JXL_RESTRICT value) override { + size_t encoded_bits = 0; + ok_ &= U32Coder::CanEncode(enc, *value, &encoded_bits); + encoded_bits_ += encoded_bits; + return true; + } + + Status U64(const uint64_t /*default_value*/, + uint64_t* JXL_RESTRICT value) override { + size_t encoded_bits = 0; + ok_ &= U64Coder::CanEncode(*value, &encoded_bits); + encoded_bits_ += encoded_bits; + return true; + } + + Status F16(const float /*default_value*/, + float* JXL_RESTRICT value) override { + size_t encoded_bits = 0; + ok_ &= F16Coder::CanEncode(*value, &encoded_bits); + encoded_bits_ += encoded_bits; + return true; + } + + Status AllDefault(const Fields& fields, + bool* JXL_RESTRICT all_default) override { + *all_default = Bundle::AllDefault(fields); + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return *all_default; + } + + Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { + JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions)); + extensions_ = *extensions; + if (*extensions != 0) { + JXL_ASSERT(pos_after_ext_ == 0); + pos_after_ext_ = encoded_bits_; + JXL_ASSERT(pos_after_ext_ != 0); // visited "extensions" + } + return true; + } + // EndExtensions = default. + + Status GetSizes(size_t* JXL_RESTRICT extension_bits, + size_t* JXL_RESTRICT total_bits) { + JXL_RETURN_IF_ERROR(ok_); + *extension_bits = 0; + *total_bits = encoded_bits_; + // Only if extension field was nonzero will we encode their sizes. + if (pos_after_ext_ != 0) { + JXL_ASSERT(encoded_bits_ >= pos_after_ext_); + *extension_bits = encoded_bits_ - pos_after_ext_; + // Also need to encode *extension_bits and bill it to *total_bits. + size_t encoded_bits = 0; + ok_ &= U64Coder::CanEncode(*extension_bits, &encoded_bits); + *total_bits += encoded_bits; + + // TODO(janwas): support encoding individual extension sizes. We + // currently ascribe all bits to the first and send zeros for the + // others. + for (size_t i = 1; i < hwy::PopCount(extensions_); ++i) { + encoded_bits = 0; + ok_ &= U64Coder::CanEncode(0, &encoded_bits); + *total_bits += encoded_bits; + } + } + return true; + } + + private: + bool ok_ = true; + size_t encoded_bits_ = 0; + uint64_t extensions_ = 0; + // Snapshot of encoded_bits_ after visiting the extension field, but NOT + // including the hidden extension sizes. + uint64_t pos_after_ext_ = 0; +}; +} // namespace + +void Bundle::Init(Fields* fields) { + InitVisitor visitor; + if (!visitor.Visit(fields)) { + JXL_UNREACHABLE("Init should never fail"); + } +} +void Bundle::SetDefault(Fields* fields) { + SetDefaultVisitor visitor; + if (!visitor.Visit(fields)) { + JXL_UNREACHABLE("SetDefault should never fail"); + } +} +bool Bundle::AllDefault(const Fields& fields) { + AllDefaultVisitor visitor; + if (!visitor.VisitConst(fields)) { + JXL_UNREACHABLE("AllDefault should never fail"); + } + return visitor.AllDefault(); +} +size_t Bundle::MaxBits(const Fields& fields) { + MaxBitsVisitor visitor; +#if JXL_ENABLE_ASSERT + Status ret = +#else + (void) +#endif // JXL_ENABLE_ASSERT + visitor.VisitConst(fields); + JXL_ASSERT(ret); + return visitor.MaxBits(); +} +Status Bundle::CanEncode(const Fields& fields, size_t* extension_bits, + size_t* total_bits) { + CanEncodeVisitor visitor; + JXL_QUIET_RETURN_IF_ERROR(visitor.VisitConst(fields)); + JXL_QUIET_RETURN_IF_ERROR(visitor.GetSizes(extension_bits, total_bits)); + return true; +} +Status Bundle::Read(BitReader* reader, Fields* fields) { + ReadVisitor visitor(reader); + JXL_RETURN_IF_ERROR(visitor.Visit(fields)); + return visitor.OK(); +} +bool Bundle::CanRead(BitReader* reader, Fields* fields) { + ReadVisitor visitor(reader); + Status status = visitor.Visit(fields); + // We are only checking here whether there are enough bytes. We still return + // true for other errors because it means there are enough bytes to determine + // there's an error. Use Read() to determine which error it is. + return status.code() != StatusCode::kNotEnoughBytes; +} + +size_t BitsCoder::MaxEncodedBits(const size_t bits) { return bits; } + +Status BitsCoder::CanEncode(const size_t bits, const uint32_t value, + size_t* JXL_RESTRICT encoded_bits) { + *encoded_bits = bits; + if (value >= (1ULL << bits)) { + return JXL_FAILURE("Value %u too large for %" PRIu64 " bits", value, + static_cast<uint64_t>(bits)); + } + return true; +} + +uint32_t BitsCoder::Read(const size_t bits, BitReader* JXL_RESTRICT reader) { + return reader->ReadBits(bits); +} + +size_t U32Coder::MaxEncodedBits(const U32Enc enc) { + size_t extra_bits = 0; + for (uint32_t selector = 0; selector < 4; ++selector) { + const U32Distr d = enc.GetDistr(selector); + if (d.IsDirect()) { + continue; + } else { + extra_bits = std::max<size_t>(extra_bits, d.ExtraBits()); + } + } + return 2 + extra_bits; +} + +Status U32Coder::CanEncode(const U32Enc enc, const uint32_t value, + size_t* JXL_RESTRICT encoded_bits) { + uint32_t selector; + size_t total_bits; + const Status ok = ChooseSelector(enc, value, &selector, &total_bits); + *encoded_bits = ok ? total_bits : 0; + return ok; +} + +uint32_t U32Coder::Read(const U32Enc enc, BitReader* JXL_RESTRICT reader) { + const uint32_t selector = reader->ReadFixedBits<2>(); + const U32Distr d = enc.GetDistr(selector); + if (d.IsDirect()) { + return d.Direct(); + } else { + return reader->ReadBits(d.ExtraBits()) + d.Offset(); + } +} + +Status U32Coder::ChooseSelector(const U32Enc enc, const uint32_t value, + uint32_t* JXL_RESTRICT selector, + size_t* JXL_RESTRICT total_bits) { +#if JXL_ENABLE_ASSERT + const size_t bits_required = 32 - Num0BitsAboveMS1Bit(value); +#endif // JXL_ENABLE_ASSERT + JXL_ASSERT(bits_required <= 32); + + *selector = 0; + *total_bits = 0; + + // It is difficult to verify whether Dist32Byte are sorted, so check all + // selectors and keep the one with the fewest total_bits. + *total_bits = 64; // more than any valid encoding + for (uint32_t s = 0; s < 4; ++s) { + const U32Distr d = enc.GetDistr(s); + if (d.IsDirect()) { + if (d.Direct() == value) { + *selector = s; + *total_bits = 2; + return true; // Done, direct is always the best possible. + } + continue; + } + const size_t extra_bits = d.ExtraBits(); + const uint32_t offset = d.Offset(); + if (value < offset || value >= offset + (1ULL << extra_bits)) continue; + + // Better than prior encoding, remember it: + if (2 + extra_bits < *total_bits) { + *selector = s; + *total_bits = 2 + extra_bits; + } + } + + if (*total_bits == 64) { + return JXL_FAILURE("No feasible selector for %u", value); + } + + return true; +} + +uint64_t U64Coder::Read(BitReader* JXL_RESTRICT reader) { + uint64_t selector = reader->ReadFixedBits<2>(); + if (selector == 0) { + return 0; + } + if (selector == 1) { + return 1 + reader->ReadFixedBits<4>(); + } + if (selector == 2) { + return 17 + reader->ReadFixedBits<8>(); + } + + // selector 3, varint, groups have first 12, then 8, and last 4 bits. + uint64_t result = reader->ReadFixedBits<12>(); + + uint64_t shift = 12; + while (reader->ReadFixedBits<1>()) { + if (shift == 60) { + result |= static_cast<uint64_t>(reader->ReadFixedBits<4>()) << shift; + break; + } + result |= static_cast<uint64_t>(reader->ReadFixedBits<8>()) << shift; + shift += 8; + } + + return result; +} + +// Can always encode, but useful because it also returns bit size. +Status U64Coder::CanEncode(uint64_t value, size_t* JXL_RESTRICT encoded_bits) { + if (value == 0) { + *encoded_bits = 2; // 2 selector bits + } else if (value <= 16) { + *encoded_bits = 2 + 4; // 2 selector bits + 4 payload bits + } else if (value <= 272) { + *encoded_bits = 2 + 8; // 2 selector bits + 8 payload bits + } else { + *encoded_bits = 2 + 12; // 2 selector bits + 12 payload bits + value >>= 12; + int shift = 12; + while (value > 0 && shift < 60) { + *encoded_bits += 1 + 8; // 1 continuation bit + 8 payload bits + value >>= 8; + shift += 8; + } + if (value > 0) { + // This only could happen if shift == N - 4. + *encoded_bits += 1 + 4; // 1 continuation bit + 4 payload bits + } else { + *encoded_bits += 1; // 1 stop bit + } + } + + return true; +} + +Status F16Coder::Read(BitReader* JXL_RESTRICT reader, + float* JXL_RESTRICT value) { + const uint32_t bits16 = reader->ReadFixedBits<16>(); + const uint32_t sign = bits16 >> 15; + const uint32_t biased_exp = (bits16 >> 10) & 0x1F; + const uint32_t mantissa = bits16 & 0x3FF; + + if (JXL_UNLIKELY(biased_exp == 31)) { + return JXL_FAILURE("F16 infinity or NaN are not supported"); + } + + // Subnormal or zero + if (JXL_UNLIKELY(biased_exp == 0)) { + *value = (1.0f / 16384) * (mantissa * (1.0f / 1024)); + if (sign) *value = -*value; + return true; + } + + // Normalized: convert the representation directly (faster than ldexp/tables). + const uint32_t biased_exp32 = biased_exp + (127 - 15); + const uint32_t mantissa32 = mantissa << (23 - 10); + const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; + memcpy(value, &bits32, sizeof(bits32)); + return true; +} + +Status F16Coder::CanEncode(float value, size_t* JXL_RESTRICT encoded_bits) { + *encoded_bits = MaxEncodedBits(); + if (std::isnan(value) || std::isinf(value)) { + return JXL_FAILURE("Should not attempt to store NaN and infinity"); + } + return std::abs(value) <= 65504.0f; +} + +Status CheckHasEnoughBits(Visitor* visitor, size_t bits) { + if (!visitor->IsReading()) return false; + ReadVisitor* rv = static_cast<ReadVisitor*>(visitor); + size_t have_bits = rv->reader_->TotalBytes() * kBitsPerByte; + size_t want_bits = bits + rv->reader_->TotalBitsConsumed(); + if (have_bits < want_bits) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for header"); + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/fields.h b/third_party/jpeg-xl/lib/jxl/fields.h new file mode 100644 index 0000000000..d05fe4517e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fields.h @@ -0,0 +1,374 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_FIELDS_H_ +#define LIB_JXL_FIELDS_H_ + +// Forward/backward-compatible 'bundles' with auto-serialized 'fields'. + +#include <cmath> // abs +#include <cstdarg> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <cstring> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +struct AuxOut; +struct BitWriter; + +// Integer coders: BitsCoder (raw), U32Coder (table), U64Coder (varint). + +// Reads/writes a given (fixed) number of bits <= 32. +namespace BitsCoder { +size_t MaxEncodedBits(size_t bits); + +Status CanEncode(size_t bits, uint32_t value, + size_t* JXL_RESTRICT encoded_bits); + +uint32_t Read(size_t bits, BitReader* JXL_RESTRICT reader); + +// Returns false if the value is too large to encode. +Status Write(size_t bits, uint32_t value, BitWriter* JXL_RESTRICT writer); +} // namespace BitsCoder + +// Encodes u32 using a lookup table and/or extra bits, governed by a per-field +// encoding `enc` which consists of four distributions `d` chosen via a 2-bit +// selector (least significant = 0). Each d may have two modes: +// - direct: if d.IsDirect(), the value is d.Direct(); +// - offset: the value is derived from d.ExtraBits() extra bits plus d.Offset(); +// This encoding is denser than Exp-Golomb or Gamma codes when both small and +// large values occur. +// +// Examples: +// Direct: U32Enc(Val(8), Val(16), Val(32), Bits(6)), value 32 => 10b. +// Offset: U32Enc(Val(0), BitsOffset(1, 1), BitsOffset(2, 3), BitsOffset(8, 8)) +// defines the following prefix code: +// 00 -> 0 +// 01x -> 1..2 +// 10xx -> 3..7 +// 11xxxxxxxx -> 8..263 +namespace U32Coder { +size_t MaxEncodedBits(U32Enc enc); +Status CanEncode(U32Enc enc, uint32_t value, size_t* JXL_RESTRICT encoded_bits); +uint32_t Read(U32Enc enc, BitReader* JXL_RESTRICT reader); + +// Returns false if the value is too large to encode. +Status Write(U32Enc enc, uint32_t value, BitWriter* JXL_RESTRICT writer); + +// "private" +Status ChooseSelector(U32Enc enc, uint32_t value, + uint32_t* JXL_RESTRICT selector, + size_t* JXL_RESTRICT total_bits); +} // namespace U32Coder + +// Encodes 64-bit unsigned integers with a fixed distribution, taking 2 bits +// to encode 0, 6 bits to encode 1 to 16, 10 bits to encode 17 to 272, 15 bits +// to encode up to 4095, and on the order of log2(value) * 1.125 bits for +// larger values. +namespace U64Coder { +constexpr size_t MaxEncodedBits() { return 2 + 12 + 6 * (8 + 1) + (4 + 1); } + +uint64_t Read(BitReader* JXL_RESTRICT reader); + +// Returns false if the value is too large to encode. +Status Write(uint64_t value, BitWriter* JXL_RESTRICT writer); + +// Can always encode, but useful because it also returns bit size. +Status CanEncode(uint64_t value, size_t* JXL_RESTRICT encoded_bits); +} // namespace U64Coder + +// IEEE 754 half-precision (binary16). Refuses to read/write NaN/Inf. +namespace F16Coder { +constexpr size_t MaxEncodedBits() { return 16; } + +// Returns false if the bit representation is NaN or infinity +Status Read(BitReader* JXL_RESTRICT reader, float* JXL_RESTRICT value); + +// Returns false if the value is too large to encode. +Status Write(float value, BitWriter* JXL_RESTRICT writer); +Status CanEncode(float value, size_t* JXL_RESTRICT encoded_bits); +} // namespace F16Coder + +// A "bundle" is a forward- and backward compatible collection of fields. +// They are used for SizeHeader/FrameHeader/GroupHeader. Bundles can be +// extended by appending(!) fields. Optional fields may be omitted from the +// bitstream by conditionally visiting them. When reading new bitstreams with +// old code, we skip unknown fields at the end of the bundle. This requires +// storing the amount of extra appended bits, and that fields are visited in +// chronological order of being added to the format, because old decoders +// cannot skip some future fields and resume reading old fields. Similarly, +// new readers query bits in an "extensions" field to skip (groups of) fields +// not present in old bitstreams. Note that each bundle must include an +// "extensions" field prior to freezing the format, otherwise it cannot be +// extended. +// +// To ensure interoperability, there will be no opaque fields. +// +// HOWTO: +// - basic usage: define a struct with member variables ("fields") and a +// VisitFields(v) member function that calls v->U32/Bool etc. for each +// field, specifying their default values. The ctor must call +// Bundle::Init(this). +// +// - print a trace of visitors: ensure each bundle has a static Name() member +// function, and change Bundle::Print* to return true. +// +// - optional fields: in VisitFields, add if (v->Conditional(your_condition)) +// { v->Bool(default, &field); }. This prevents reading/writing field +// if !your_condition, which is typically computed from a prior field. +// WARNING: to ensure all fields are initialized, do not add an else branch; +// instead add another if (v->Conditional(!your_condition)). +// +// - repeated fields: for dynamic sizes, use e.g. std::vector and in +// VisitFields, if (v->IsReading()) field.resize(size) before accessing field. +// For static or bounded sizes, use an array or std::array. In all cases, +// simply visit each array element as if it were a normal field. +// +// - nested bundles: add a bundle as a normal field and in VisitFields call +// JXL_RETURN_IF_ERROR(v->VisitNested(&nested)); +// +// - allow future extensions: define a "uint64_t extensions" field and call +// v->BeginExtensions(&extensions) after visiting all non-extension fields, +// and `return v->EndExtensions();` after the last extension field. +// +// - encode an entire bundle in one bit if ALL its fields equal their default +// values: add a "mutable bool all_default" field and as the first visitor: +// if (v->AllDefault(*this, &all_default)) { +// // Overwrite all serialized fields, but not any nonserialized_*. +// v->SetDefault(this); +// return true; +// } +// Note: if extensions are present, AllDefault() == false. + +namespace Bundle { +constexpr size_t kMaxExtensions = 64; // bits in u64 + +// Initializes fields to the default values. It is not recursive to nested +// fields, this function is intended to be called in the constructors so +// each nested field will already Init itself. +void Init(Fields* JXL_RESTRICT fields); + +// Similar to Init, but recursive to nested fields. +void SetDefault(Fields* JXL_RESTRICT fields); + +// Returns whether ALL fields (including `extensions`, if present) are equal +// to their default value. +bool AllDefault(const Fields& fields); + +// Returns max number of bits required to encode a T. +size_t MaxBits(const Fields& fields); + +// Returns whether a header's fields can all be encoded, i.e. they have a +// valid representation. If so, "*total_bits" is the exact number of bits +// required. Called by Write. +Status CanEncode(const Fields& fields, size_t* JXL_RESTRICT extension_bits, + size_t* JXL_RESTRICT total_bits); + +Status Read(BitReader* reader, Fields* JXL_RESTRICT fields); + +// Returns whether enough bits are available to fully read this bundle using +// Read. Also returns true in case of a codestream error (other than not being +// large enough): that means enough bits are available to determine there's an +// error, use Read to get such error status. +// NOTE: this advances the BitReader, a different one pointing back at the +// original bit position in the codestream must be created to use Read after +// this. +bool CanRead(BitReader* reader, Fields* JXL_RESTRICT fields); + +Status Write(const Fields& fields, BitWriter* JXL_RESTRICT writer, size_t layer, + AuxOut* aux_out); +} // namespace Bundle + +// Different subclasses of Visitor are passed to implementations of Fields +// throughout their lifetime. Templates used to be used for this but dynamic +// polymorphism produces more compact executables than template reification did. +class Visitor { + public: + virtual ~Visitor() = default; + virtual Status Visit(Fields* fields) = 0; + + virtual Status Bool(bool default_value, bool* JXL_RESTRICT value) = 0; + virtual Status U32(U32Enc, uint32_t, uint32_t*) = 0; + + // Helper to construct U32Enc from U32Distr. + Status U32(const U32Distr d0, const U32Distr d1, const U32Distr d2, + const U32Distr d3, const uint32_t default_value, + uint32_t* JXL_RESTRICT value) { + return U32(U32Enc(d0, d1, d2, d3), default_value, value); + } + + template <typename EnumT> + Status Enum(const EnumT default_value, EnumT* JXL_RESTRICT value) { + uint32_t u32 = static_cast<uint32_t>(*value); + // 00 -> 0 + // 01 -> 1 + // 10xxxx -> 2..17 + // 11yyyyyy -> 18..81 + JXL_RETURN_IF_ERROR(U32(Val(0), Val(1), BitsOffset(4, 2), BitsOffset(6, 18), + static_cast<uint32_t>(default_value), &u32)); + *value = static_cast<EnumT>(u32); + return EnumValid(*value); + } + + virtual Status Bits(size_t bits, uint32_t default_value, + uint32_t* JXL_RESTRICT value) = 0; + virtual Status U64(uint64_t default_value, uint64_t* JXL_RESTRICT value) = 0; + virtual Status F16(float default_value, float* JXL_RESTRICT value) = 0; + + // Returns whether VisitFields should visit some subsequent fields. + // "condition" is typically from prior fields, e.g. flags. + // Overridden by InitVisitor and MaxBitsVisitor. + virtual Status Conditional(bool condition) { return condition; } + + // Overridden by InitVisitor, AllDefaultVisitor and CanEncodeVisitor. + virtual Status AllDefault(const Fields& /*fields*/, + bool* JXL_RESTRICT all_default) { + JXL_RETURN_IF_ERROR(Bool(true, all_default)); + return *all_default; + } + + virtual void SetDefault(Fields* /*fields*/) { + // Do nothing by default, this is overridden by ReadVisitor. + } + + // Returns the result of visiting a nested Bundle. + // Overridden by InitVisitor. + virtual Status VisitNested(Fields* fields) { return Visit(fields); } + + // Overridden by ReadVisitor. Enables dynamically-sized fields. + virtual bool IsReading() const { return false; } + + virtual Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) = 0; + virtual Status EndExtensions() = 0; +}; + +namespace fields_internal { +// A bundle can be in one of three states concerning extensions: not-begun, +// active, ended. Bundles may be nested, so we need a stack of states. +class ExtensionStates { + public: + void Push() { + // Initial state = not-begun. + begun_ <<= 1; + ended_ <<= 1; + } + + // Clears current state; caller must check IsEnded beforehand. + void Pop() { + begun_ >>= 1; + ended_ >>= 1; + } + + // Returns true if state == active || state == ended. + Status IsBegun() const { return (begun_ & 1) != 0; } + // Returns true if state != not-begun && state != active. + Status IsEnded() const { return (ended_ & 1) != 0; } + + void Begin() { + JXL_ASSERT(!IsBegun()); + JXL_ASSERT(!IsEnded()); + begun_ += 1; + } + + void End() { + JXL_ASSERT(IsBegun()); + JXL_ASSERT(!IsEnded()); + ended_ += 1; + } + + private: + // Current state := least-significant bit of begun_ and ended_. + uint64_t begun_ = 0; + uint64_t ended_ = 0; +}; + +// Visitors generate Init/AllDefault/Read/Write logic for all fields. Each +// bundle's VisitFields member function calls visitor->U32 etc. We do not +// overload operator() because a function name is easier to search for. + +class VisitorBase : public Visitor { + public: + explicit VisitorBase() {} + ~VisitorBase() override { JXL_ASSERT(depth_ == 0); } + + // This is the only call site of Fields::VisitFields. + // Ensures EndExtensions was called. + Status Visit(Fields* fields) override { + depth_ += 1; + JXL_ASSERT(depth_ <= Bundle::kMaxExtensions); + extension_states_.Push(); + + const Status ok = fields->VisitFields(this); + + if (ok) { + // If VisitFields called BeginExtensions, must also call + // EndExtensions. + JXL_ASSERT(!extension_states_.IsBegun() || extension_states_.IsEnded()); + } else { + // Failed, undefined state: don't care whether EndExtensions was + // called. + } + + extension_states_.Pop(); + JXL_ASSERT(depth_ != 0); + depth_ -= 1; + + return ok; + } + + // For visitors accepting a const Visitor, need to const-cast so we can call + // the non-const Visitor::VisitFields. NOTE: C is not modified except the + // `all_default` field by CanEncodeVisitor. + Status VisitConst(const Fields& t) { return Visit(const_cast<Fields*>(&t)); } + + // Derived types (overridden by InitVisitor because it is unsafe to read + // from *value there) + + Status Bool(bool default_value, bool* JXL_RESTRICT value) override { + uint32_t bits = *value ? 1 : 0; + JXL_RETURN_IF_ERROR(Bits(1, static_cast<uint32_t>(default_value), &bits)); + JXL_DASSERT(bits <= 1); + *value = bits == 1; + return true; + } + + // Overridden by ReadVisitor and WriteVisitor. + // Called before any conditional visit based on "extensions". + // Overridden by ReadVisitor, CanEncodeVisitor and WriteVisitor. + Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override { + JXL_RETURN_IF_ERROR(U64(0, extensions)); + + extension_states_.Begin(); + return true; + } + + // Called after all extension fields (if any). Although non-extension + // fields could be visited afterward, we prefer the convention that + // extension fields are always the last to be visited. Overridden by + // ReadVisitor. + Status EndExtensions() override { + extension_states_.End(); + return true; + } + + private: + size_t depth_ = 0; // to check nesting + ExtensionStates extension_states_; +}; +} // namespace fields_internal + +Status CheckHasEnoughBits(Visitor* visitor, size_t bits); + +} // namespace jxl + +#endif // LIB_JXL_FIELDS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/fields_test.cc b/third_party/jpeg-xl/lib/jxl/fields_test.cc new file mode 100644 index 0000000000..b178a6bd6a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/fields_test.cc @@ -0,0 +1,429 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/fields.h" + +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <utility> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +// Ensures `value` round-trips and in exactly `expected_bits_written`. +void TestU32Coder(const uint32_t value, const size_t expected_bits_written) { + const U32Enc enc(Val(0), Bits(4), Val(0x7FFFFFFF), Bits(32)); + + BitWriter writer; + BitWriter::Allotment allotment( + &writer, RoundUpBitsToByteMultiple(U32Coder::MaxEncodedBits(enc))); + + size_t precheck_pos; + EXPECT_TRUE(U32Coder::CanEncode(enc, value, &precheck_pos)); + EXPECT_EQ(expected_bits_written, precheck_pos); + + EXPECT_TRUE(U32Coder::Write(enc, value, &writer)); + EXPECT_EQ(expected_bits_written, writer.BitsWritten()); + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + + BitReader reader(writer.GetSpan()); + const uint32_t decoded_value = U32Coder::Read(enc, &reader); + EXPECT_EQ(value, decoded_value); + EXPECT_TRUE(reader.Close()); +} + +TEST(FieldsTest, U32CoderTest) { + TestU32Coder(0, 2); + TestU32Coder(1, 6); + TestU32Coder(15, 6); + TestU32Coder(0x7FFFFFFF, 2); + TestU32Coder(128, 34); + TestU32Coder(0x7FFFFFFEu, 34); + TestU32Coder(0x80000000u, 34); + TestU32Coder(0xFFFFFFFFu, 34); +} + +void TestU64Coder(const uint64_t value, const size_t expected_bits_written) { + BitWriter writer; + BitWriter::Allotment allotment( + &writer, RoundUpBitsToByteMultiple(U64Coder::MaxEncodedBits())); + + size_t precheck_pos; + EXPECT_TRUE(U64Coder::CanEncode(value, &precheck_pos)); + EXPECT_EQ(expected_bits_written, precheck_pos); + + EXPECT_TRUE(U64Coder::Write(value, &writer)); + EXPECT_EQ(expected_bits_written, writer.BitsWritten()); + + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + + BitReader reader(writer.GetSpan()); + const uint64_t decoded_value = U64Coder::Read(&reader); + EXPECT_EQ(value, decoded_value); + EXPECT_TRUE(reader.Close()); +} + +TEST(FieldsTest, U64CoderTest) { + // Values that should take 2 bits (selector 00): 0 + TestU64Coder(0, 2); + + // Values that should take 6 bits (2 for selector, 4 for value): 1..16 + TestU64Coder(1, 6); + TestU64Coder(2, 6); + TestU64Coder(8, 6); + TestU64Coder(15, 6); + TestU64Coder(16, 6); + + // Values that should take 10 bits (2 for selector, 8 for value): 17..272 + TestU64Coder(17, 10); + TestU64Coder(18, 10); + TestU64Coder(100, 10); + TestU64Coder(271, 10); + TestU64Coder(272, 10); + + // Values that should take 15 bits (2 for selector, 12 for value, 1 for varint + // end): (0)..273..4095 + TestU64Coder(273, 15); + TestU64Coder(274, 15); + TestU64Coder(1000, 15); + TestU64Coder(4094, 15); + TestU64Coder(4095, 15); + + // Take 24 bits (of which 20 actual value): (0)..4096..1048575 + TestU64Coder(4096, 24); + TestU64Coder(4097, 24); + TestU64Coder(10000, 24); + TestU64Coder(1048574, 24); + TestU64Coder(1048575, 24); + + // Take 33 bits (of which 28 actual value): (0)..1048576..268435455 + TestU64Coder(1048576, 33); + TestU64Coder(1048577, 33); + TestU64Coder(10000000, 33); + TestU64Coder(268435454, 33); + TestU64Coder(268435455, 33); + + // Take 42 bits (of which 36 actual value): (0)..268435456..68719476735 + TestU64Coder(268435456ull, 42); + TestU64Coder(268435457ull, 42); + TestU64Coder(1000000000ull, 42); + TestU64Coder(68719476734ull, 42); + TestU64Coder(68719476735ull, 42); + + // Take 51 bits (of which 44 actual value): (0)..68719476736..17592186044415 + TestU64Coder(68719476736ull, 51); + TestU64Coder(68719476737ull, 51); + TestU64Coder(1000000000000ull, 51); + TestU64Coder(17592186044414ull, 51); + TestU64Coder(17592186044415ull, 51); + + // Take 60 bits (of which 52 actual value): + // (0)..17592186044416..4503599627370495 + TestU64Coder(17592186044416ull, 60); + TestU64Coder(17592186044417ull, 60); + TestU64Coder(100000000000000ull, 60); + TestU64Coder(4503599627370494ull, 60); + TestU64Coder(4503599627370495ull, 60); + + // Take 69 bits (of which 60 actual value): + // (0)..4503599627370496..1152921504606846975 + TestU64Coder(4503599627370496ull, 69); + TestU64Coder(4503599627370497ull, 69); + TestU64Coder(10000000000000000ull, 69); + TestU64Coder(1152921504606846974ull, 69); + TestU64Coder(1152921504606846975ull, 69); + + // Take 73 bits (of which 64 actual value): + // (0)..1152921504606846976..18446744073709551615 + TestU64Coder(1152921504606846976ull, 73); + TestU64Coder(1152921504606846977ull, 73); + TestU64Coder(10000000000000000000ull, 73); + TestU64Coder(18446744073709551614ull, 73); + TestU64Coder(18446744073709551615ull, 73); +} + +Status TestF16Coder(const float value) { + size_t max_encoded_bits; + // It is not a fatal error if it can't be encoded. + if (!F16Coder::CanEncode(value, &max_encoded_bits)) return false; + EXPECT_EQ(F16Coder::MaxEncodedBits(), max_encoded_bits); + + BitWriter writer; + BitWriter::Allotment allotment(&writer, + RoundUpBitsToByteMultiple(max_encoded_bits)); + + EXPECT_TRUE(F16Coder::Write(value, &writer)); + EXPECT_EQ(F16Coder::MaxEncodedBits(), writer.BitsWritten()); + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, 0, nullptr); + + BitReader reader(writer.GetSpan()); + float decoded_value; + EXPECT_TRUE(F16Coder::Read(&reader, &decoded_value)); + // All values we test can be represented exactly. + EXPECT_EQ(value, decoded_value); + EXPECT_TRUE(reader.Close()); + return true; +} + +TEST(FieldsTest, F16CoderTest) { + for (float sign : {-1.0f, 1.0f}) { + // (anything less than 1E-3 are subnormals) + for (float mag : {0.0f, 0.5f, 1.0f, 2.0f, 2.5f, 16.015625f, 1.0f / 4096, + 1.0f / 16384, 65504.0f}) { + EXPECT_TRUE(TestF16Coder(sign * mag)); + } + } + + // Out of range + EXPECT_FALSE(TestF16Coder(65504.01f)); + EXPECT_FALSE(TestF16Coder(-65505.0f)); +} + +// Ensures Read(Write()) returns the same fields. +TEST(FieldsTest, TestRoundtripSize) { + for (int i = 0; i < 8; i++) { + SizeHeader size; + ASSERT_TRUE(size.Set(123 + 77 * i, 7 + i)); + + size_t extension_bits = 999, total_bits = 999; // Initialize as garbage. + ASSERT_TRUE(Bundle::CanEncode(size, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + + BitWriter writer; + ASSERT_TRUE(WriteSizeHeader(size, &writer, 0, nullptr)); + EXPECT_EQ(total_bits, writer.BitsWritten()); + writer.ZeroPadToByte(); + + SizeHeader size2; + BitReader reader(writer.GetSpan()); + ASSERT_TRUE(ReadSizeHeader(&reader, &size2)); + EXPECT_EQ(total_bits, reader.TotalBitsConsumed()); + EXPECT_TRUE(reader.Close()); + + EXPECT_EQ(size.xsize(), size2.xsize()); + EXPECT_EQ(size.ysize(), size2.ysize()); + } +} + +// Ensure all values can be reached by the encoding. +TEST(FieldsTest, TestCropRect) { + CodecMetadata metadata; + for (int32_t i = -999; i < 19000; ++i) { + FrameHeader f(&metadata); + f.custom_size_or_origin = true; + f.frame_origin.x0 = i; + f.frame_origin.y0 = i; + f.frame_size.xsize = 1000 + i; + f.frame_size.ysize = 1000 + i; + size_t extension_bits = 0, total_bits = 0; + ASSERT_TRUE(Bundle::CanEncode(f, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + EXPECT_GE(total_bits, 9u); + } +} +TEST(FieldsTest, TestPreview) { + // (div8 cannot represent 4360, but !div8 can go a little higher) + for (uint32_t i = 1; i < 4360; ++i) { + PreviewHeader p; + ASSERT_TRUE(p.Set(i, i)); + size_t extension_bits = 0, total_bits = 0; + ASSERT_TRUE(Bundle::CanEncode(p, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + EXPECT_GE(total_bits, 6u); + } +} + +// Ensures Read(Write()) returns the same fields. +TEST(FieldsTest, TestRoundtripFrame) { + CodecMetadata metadata; + FrameHeader h(&metadata); + h.extensions = 0x800; + + size_t extension_bits = 999, total_bits = 999; // Initialize as garbage. + ASSERT_TRUE(Bundle::CanEncode(h, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + BitWriter writer; + ASSERT_TRUE(WriteFrameHeader(h, &writer, nullptr)); + EXPECT_EQ(total_bits, writer.BitsWritten()); + writer.ZeroPadToByte(); + + FrameHeader h2(&metadata); + BitReader reader(writer.GetSpan()); + ASSERT_TRUE(ReadFrameHeader(&reader, &h2)); + EXPECT_EQ(total_bits, reader.TotalBitsConsumed()); + EXPECT_TRUE(reader.Close()); + + EXPECT_EQ(h.extensions, h2.extensions); + EXPECT_EQ(h.flags, h2.flags); +} + +#ifndef JXL_CRASH_ON_ERROR +// Ensure out-of-bounds values cause an error. +TEST(FieldsTest, TestOutOfRange) { + SizeHeader h; + ASSERT_TRUE(h.Set(0xFFFFFFFFull, 0xFFFFFFFFull)); + size_t extension_bits = 999, total_bits = 999; // Initialize as garbage. + ASSERT_FALSE(Bundle::CanEncode(h, &extension_bits, &total_bits)); +} +#endif + +struct OldBundle : public Fields { + OldBundle() { Bundle::Init(this); } + JXL_FIELDS_NAME(OldBundle) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Bits(2), Bits(3), Bits(4), 1, &old_small)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(1.125f, &old_f)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(7), Bits(12), Bits(16), Bits(32), 0, &old_large)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + return visitor->EndExtensions(); + } + + uint32_t old_small; + float old_f; + uint32_t old_large; + uint64_t extensions; +}; + +struct NewBundle : public Fields { + NewBundle() { Bundle::Init(this); } + JXL_FIELDS_NAME(NewBundle) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Bits(2), Bits(3), Bits(4), 1, &old_small)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(1.125f, &old_f)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(7), Bits(12), Bits(16), Bits(32), 0, &old_large)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + if (visitor->Conditional(extensions & 1)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(2), Bits(2), Bits(3), Bits(4), 2, &new_small)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(-2.0f, &new_f)); + } + if (visitor->Conditional(extensions & 2)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(9), Bits(12), Bits(16), Bits(32), 0, &new_large)); + } + return visitor->EndExtensions(); + } + + uint32_t old_small; + float old_f; + uint32_t old_large; + uint64_t extensions; + + // If extensions & 1 + uint32_t new_small = 2; + float new_f = -2.0f; + // If extensions & 2 + uint32_t new_large = 0; +}; + +TEST(FieldsTest, TestNewDecoderOldData) { + OldBundle old_bundle; + old_bundle.old_large = 123; + old_bundle.old_f = 3.75f; + old_bundle.extensions = 0; + + // Write to bit stream + const size_t kMaxOutBytes = 999; + BitWriter writer; + // Make sure values are initialized by code under test. + size_t extension_bits = 12345, total_bits = 12345; + ASSERT_TRUE(Bundle::CanEncode(old_bundle, &extension_bits, &total_bits)); + ASSERT_LE(total_bits, kMaxOutBytes * kBitsPerByte); + EXPECT_EQ(0u, extension_bits); + AuxOut aux_out; + ASSERT_TRUE(Bundle::Write(old_bundle, &writer, kLayerHeader, &aux_out)); + + BitWriter::Allotment allotment(&writer, + kMaxOutBytes * kBitsPerByte - total_bits); + writer.Write(20, 0xA55A); // sentinel + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, kLayerHeader, nullptr); + + ASSERT_LE(writer.GetSpan().size(), kMaxOutBytes); + BitReader reader(writer.GetSpan()); + NewBundle new_bundle; + ASSERT_TRUE(Bundle::Read(&reader, &new_bundle)); + EXPECT_EQ(reader.TotalBitsConsumed(), + aux_out.layers[kLayerHeader].total_bits); + EXPECT_EQ(reader.ReadBits(20), 0xA55Au); + EXPECT_TRUE(reader.Close()); + + // Old fields are the same in both + EXPECT_EQ(old_bundle.extensions, new_bundle.extensions); + EXPECT_EQ(old_bundle.old_small, new_bundle.old_small); + EXPECT_EQ(old_bundle.old_f, new_bundle.old_f); + EXPECT_EQ(old_bundle.old_large, new_bundle.old_large); + // New fields match their defaults + EXPECT_EQ(2u, new_bundle.new_small); + EXPECT_EQ(-2.0f, new_bundle.new_f); + EXPECT_EQ(0u, new_bundle.new_large); +} + +TEST(FieldsTest, TestOldDecoderNewData) { + NewBundle new_bundle; + new_bundle.old_large = 123; + new_bundle.extensions = 3; + new_bundle.new_f = 999.0f; + new_bundle.new_large = 456; + + // Write to bit stream + constexpr size_t kMaxOutBytes = 999; + BitWriter writer; + // Make sure values are initialized by code under test. + size_t extension_bits = 12345, total_bits = 12345; + ASSERT_TRUE(Bundle::CanEncode(new_bundle, &extension_bits, &total_bits)); + EXPECT_NE(0u, extension_bits); + AuxOut aux_out; + ASSERT_TRUE(Bundle::Write(new_bundle, &writer, kLayerHeader, &aux_out)); + ASSERT_LE(aux_out.layers[kLayerHeader].total_bits, + kMaxOutBytes * kBitsPerByte); + + BitWriter::Allotment allotment( + &writer, + kMaxOutBytes * kBitsPerByte - aux_out.layers[kLayerHeader].total_bits); + // Ensure Read skips the additional fields + writer.Write(20, 0xA55A); // sentinel + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, kLayerHeader, nullptr); + + BitReader reader(writer.GetSpan()); + OldBundle old_bundle; + ASSERT_TRUE(Bundle::Read(&reader, &old_bundle)); + EXPECT_EQ(reader.TotalBitsConsumed(), + aux_out.layers[kLayerHeader].total_bits); + EXPECT_EQ(reader.ReadBits(20), 0xA55Au); + EXPECT_TRUE(reader.Close()); + + // Old fields are the same in both + EXPECT_EQ(new_bundle.extensions, old_bundle.extensions); + EXPECT_EQ(new_bundle.old_small, old_bundle.old_small); + EXPECT_EQ(new_bundle.old_f, old_bundle.old_f); + EXPECT_EQ(new_bundle.old_large, old_bundle.old_large); + // (Can't check new fields because old decoder doesn't know about them) +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/frame_dimensions.h b/third_party/jpeg-xl/lib/jxl/frame_dimensions.h new file mode 100644 index 0000000000..8440a95463 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/frame_dimensions.h @@ -0,0 +1,117 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_FRAME_DIMENSIONS_H_ +#define LIB_JXL_FRAME_DIMENSIONS_H_ + +// FrameDimensions struct, block and group dimensions constants. + +#include <cstddef> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/image.h" + +namespace jxl { +// Some enums and typedefs used by more than one header file. + +// Block is the square grid of pixels to which an "energy compaction" +// transformation (e.g. DCT) is applied. Each block has its own AC quantizer. +constexpr size_t kBlockDim = 8; + +constexpr size_t kDCTBlockSize = kBlockDim * kBlockDim; + +constexpr size_t kGroupDim = 256; +static_assert(kGroupDim % kBlockDim == 0, + "Group dim should be divisible by block dim"); +constexpr size_t kGroupDimInBlocks = kGroupDim / kBlockDim; + +// Dimensions of a frame, in pixels, and other derived dimensions. +// Computed from FrameHeader. +// TODO(veluca): add extra channels. +struct FrameDimensions { + void Set(size_t xsize, size_t ysize, size_t group_size_shift, + size_t max_hshift, size_t max_vshift, bool modular_mode, + size_t upsampling) { + group_dim = (kGroupDim >> 1) << group_size_shift; + dc_group_dim = group_dim * kBlockDim; + xsize_upsampled = xsize; + ysize_upsampled = ysize; + this->xsize = DivCeil(xsize, upsampling); + this->ysize = DivCeil(ysize, upsampling); + xsize_blocks = DivCeil(this->xsize, kBlockDim << max_hshift) << max_hshift; + ysize_blocks = DivCeil(this->ysize, kBlockDim << max_vshift) << max_vshift; + xsize_padded = xsize_blocks * kBlockDim; + ysize_padded = ysize_blocks * kBlockDim; + if (modular_mode) { + // Modular mode doesn't have any padding. + xsize_padded = this->xsize; + ysize_padded = this->ysize; + } + xsize_upsampled_padded = xsize_padded * upsampling; + ysize_upsampled_padded = ysize_padded * upsampling; + xsize_groups = DivCeil(this->xsize, group_dim); + ysize_groups = DivCeil(this->ysize, group_dim); + xsize_dc_groups = DivCeil(xsize_blocks, group_dim); + ysize_dc_groups = DivCeil(ysize_blocks, group_dim); + num_groups = xsize_groups * ysize_groups; + num_dc_groups = xsize_dc_groups * ysize_dc_groups; + } + + Rect GroupRect(size_t group_index) const { + const size_t gx = group_index % xsize_groups; + const size_t gy = group_index / xsize_groups; + const Rect rect(gx * group_dim, gy * group_dim, group_dim, group_dim, xsize, + ysize); + return rect; + } + + Rect BlockGroupRect(size_t group_index) const { + const size_t gx = group_index % xsize_groups; + const size_t gy = group_index / xsize_groups; + const Rect rect(gx * (group_dim >> 3), gy * (group_dim >> 3), + group_dim >> 3, group_dim >> 3, xsize_blocks, ysize_blocks); + return rect; + } + + Rect DCGroupRect(size_t group_index) const { + const size_t gx = group_index % xsize_dc_groups; + const size_t gy = group_index / xsize_dc_groups; + const Rect rect(gx * group_dim, gy * group_dim, group_dim, group_dim, + xsize_blocks, ysize_blocks); + return rect; + } + + // Image size without any upsampling, i.e. original_size / upsampling. + size_t xsize; + size_t ysize; + // Original image size. + size_t xsize_upsampled; + size_t ysize_upsampled; + // Image size after upsampling the padded image. + size_t xsize_upsampled_padded; + size_t ysize_upsampled_padded; + // Image size after padding to a multiple of kBlockDim (if VarDCT mode). + size_t xsize_padded; + size_t ysize_padded; + // Image size in kBlockDim blocks. + size_t xsize_blocks; + size_t ysize_blocks; + // Image size in number of groups. + size_t xsize_groups; + size_t ysize_groups; + // Image size in number of DC groups. + size_t xsize_dc_groups; + size_t ysize_dc_groups; + // Number of AC or DC groups. + size_t num_groups; + size_t num_dc_groups; + // Size of a group. + size_t group_dim; + size_t dc_group_dim; +}; + +} // namespace jxl + +#endif // LIB_JXL_FRAME_DIMENSIONS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/frame_header.cc b/third_party/jpeg-xl/lib/jxl/frame_header.cc new file mode 100644 index 0000000000..a9e79ff1b8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/frame_header.cc @@ -0,0 +1,501 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/frame_header.h" + +#include <sstream> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" // kMaxNumPasses +#include "lib/jxl/fields.h" +#include "lib/jxl/pack_signed.h" + +namespace jxl { + +constexpr uint8_t YCbCrChromaSubsampling::kHShift[] = {0, 1, 1, 0}; +constexpr uint8_t YCbCrChromaSubsampling::kVShift[] = {0, 1, 0, 1}; + +static Status VisitBlendMode(Visitor* JXL_RESTRICT visitor, + BlendMode default_value, BlendMode* blend_mode) { + uint32_t encoded = static_cast<uint32_t>(*blend_mode); + + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(static_cast<uint32_t>(BlendMode::kReplace)), + Val(static_cast<uint32_t>(BlendMode::kAdd)), + Val(static_cast<uint32_t>(BlendMode::kBlend)), BitsOffset(2, 3), + static_cast<uint32_t>(default_value), &encoded)); + if (encoded > 4) { + return JXL_FAILURE("Invalid blend_mode"); + } + *blend_mode = static_cast<BlendMode>(encoded); + return true; +} + +static Status VisitFrameType(Visitor* JXL_RESTRICT visitor, + FrameType default_value, FrameType* frame_type) { + uint32_t encoded = static_cast<uint32_t>(*frame_type); + + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(static_cast<uint32_t>(FrameType::kRegularFrame)), + Val(static_cast<uint32_t>(FrameType::kDCFrame)), + Val(static_cast<uint32_t>(FrameType::kReferenceOnly)), + Val(static_cast<uint32_t>(FrameType::kSkipProgressive)), + static_cast<uint32_t>(default_value), &encoded)); + *frame_type = static_cast<FrameType>(encoded); + return true; +} + +BlendingInfo::BlendingInfo() { Bundle::Init(this); } + +Status BlendingInfo::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR( + VisitBlendMode(visitor, BlendMode::kReplace, &mode)); + if (visitor->Conditional(nonserialized_num_extra_channels > 0 && + (mode == BlendMode::kBlend || + mode == BlendMode::kAlphaWeightedAdd))) { + // Up to 11 alpha channels for blending. + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(0), Val(1), Val(2), BitsOffset(3, 3), 0, &alpha_channel)); + if (visitor->IsReading() && + alpha_channel >= nonserialized_num_extra_channels) { + return JXL_FAILURE("Invalid alpha channel for blending"); + } + } + if (visitor->Conditional((nonserialized_num_extra_channels > 0 && + (mode == BlendMode::kBlend || + mode == BlendMode::kAlphaWeightedAdd)) || + mode == BlendMode::kMul)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &clamp)); + } + // 'old' frame for blending. Only necessary if this is not a full frame, or + // blending is not kReplace. + if (visitor->Conditional(mode != BlendMode::kReplace || + nonserialized_is_partial_frame)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), Val(2), Val(3), 0, &source)); + } + return true; +} + +#if JXL_DEBUG_V_LEVEL >= 1 +std::string BlendingInfo::DebugString() const { + std::ostringstream os; + os << (mode == BlendMode::kReplace ? "Replace" + : mode == BlendMode::kAdd ? "Add" + : mode == BlendMode::kBlend ? "Blend" + : mode == BlendMode::kAlphaWeightedAdd ? "AlphaWeightedAdd" + : "Mul"); + if (nonserialized_num_extra_channels > 0 && + (mode == BlendMode::kBlend || mode == BlendMode::kAlphaWeightedAdd)) { + os << ",alpha=" << alpha_channel << ",clamp=" << clamp; + } else if (mode == BlendMode::kMul) { + os << ",clamp=" << clamp; + } + if (mode != BlendMode::kReplace || nonserialized_is_partial_frame) { + os << ",source=" << source; + } + return os.str(); +} +#endif + +AnimationFrame::AnimationFrame(const CodecMetadata* metadata) + : nonserialized_metadata(metadata) { + Bundle::Init(this); +} +Status AnimationFrame::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->Conditional(nonserialized_metadata != nullptr && + nonserialized_metadata->m.have_animation)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), Bits(8), Bits(32), 0, &duration)); + } + + if (visitor->Conditional( + nonserialized_metadata != nullptr && + nonserialized_metadata->m.animation.have_timecodes)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(32, 0, &timecode)); + } + return true; +} + +YCbCrChromaSubsampling::YCbCrChromaSubsampling() { Bundle::Init(this); } +Passes::Passes() { Bundle::Init(this); } +Status Passes::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), BitsOffset(3, 4), 1, &num_passes)); + JXL_ASSERT(num_passes <= kMaxNumPasses); // Cannot happen when reading + + if (visitor->Conditional(num_passes != 1)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(0), Val(1), Val(2), BitsOffset(1, 3), 0, &num_downsample)); + JXL_ASSERT(num_downsample <= 4); // 1,2,4,8 + if (num_downsample > num_passes) { + return JXL_FAILURE("num_downsample %u > num_passes %u", num_downsample, + num_passes); + } + + for (uint32_t i = 0; i < num_passes - 1; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 0, &shift[i])); + } + shift[num_passes - 1] = 0; + + for (uint32_t i = 0; i < num_downsample; ++i) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(4), Val(8), 1, &downsample[i])); + if (i > 0 && downsample[i] >= downsample[i - 1]) { + return JXL_FAILURE("downsample sequence should be decreasing"); + } + } + for (uint32_t i = 0; i < num_downsample; ++i) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), Val(2), Bits(3), 0, &last_pass[i])); + if (i > 0 && last_pass[i] <= last_pass[i - 1]) { + return JXL_FAILURE("last_pass sequence should be increasing"); + } + if (last_pass[i] >= num_passes) { + return JXL_FAILURE("last_pass %u >= num_passes %u", last_pass[i], + num_passes); + } + } + } + + return true; +} + +#if JXL_DEBUG_V_LEVEL >= 1 +std::string Passes::DebugString() const { + std::ostringstream os; + os << "p=" << num_passes; + if (num_downsample) { + os << ",ds="; + for (uint32_t i = 0; i < num_downsample; ++i) { + os << last_pass[i] << ":" << downsample[i]; + if (i + 1 < num_downsample) os << ";"; + } + } + bool have_shifts = false; + for (uint32_t i = 0; i < num_passes; ++i) { + if (shift[i]) have_shifts = true; + } + if (have_shifts) { + os << ",shifts="; + for (uint32_t i = 0; i < num_passes; ++i) { + os << shift[i]; + if (i + 1 < num_passes) os << ";"; + } + } + return os.str(); +} +#endif + +FrameHeader::FrameHeader(const CodecMetadata* metadata) + : animation_frame(metadata), nonserialized_metadata(metadata) { + Bundle::Init(this); +} + +Status ReadFrameHeader(BitReader* JXL_RESTRICT reader, + FrameHeader* JXL_RESTRICT frame) { + return Bundle::Read(reader, frame); +} + +Status FrameHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + JXL_QUIET_RETURN_IF_ERROR( + VisitFrameType(visitor, FrameType::kRegularFrame, &frame_type)); + if (visitor->IsReading() && nonserialized_is_preview && + frame_type != kRegularFrame) { + return JXL_FAILURE("Only regular frame could be a preview"); + } + + // FrameEncoding. + bool is_modular = (encoding == FrameEncoding::kModular); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &is_modular)); + encoding = (is_modular ? FrameEncoding::kModular : FrameEncoding::kVarDCT); + + // Flags + JXL_QUIET_RETURN_IF_ERROR(visitor->U64(0, &flags)); + + // Color transform + bool xyb_encoded = nonserialized_metadata == nullptr || + nonserialized_metadata->m.xyb_encoded; + + if (xyb_encoded) { + color_transform = ColorTransform::kXYB; + } else { + // Alternate if kYCbCr. + bool alternate = color_transform == ColorTransform::kYCbCr; + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &alternate)); + color_transform = + (alternate ? ColorTransform::kYCbCr : ColorTransform::kNone); + } + + // Chroma subsampling for YCbCr, if no DC frame is used. + if (visitor->Conditional(color_transform == ColorTransform::kYCbCr && + ((flags & kUseDcFrame) == 0))) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&chroma_subsampling)); + } + + size_t num_extra_channels = + nonserialized_metadata != nullptr + ? nonserialized_metadata->m.extra_channel_info.size() + : 0; + + // Upsampling + if (visitor->Conditional((flags & kUseDcFrame) == 0)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(4), Val(8), 1, &upsampling)); + if (nonserialized_metadata != nullptr && + visitor->Conditional(num_extra_channels != 0)) { + const std::vector<ExtraChannelInfo>& extra_channels = + nonserialized_metadata->m.extra_channel_info; + extra_channel_upsampling.resize(extra_channels.size(), 1); + for (size_t i = 0; i < extra_channels.size(); ++i) { + uint32_t dim_shift = + nonserialized_metadata->m.extra_channel_info[i].dim_shift; + uint32_t& ec_upsampling = extra_channel_upsampling[i]; + ec_upsampling >>= dim_shift; + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(4), Val(8), 1, &ec_upsampling)); + ec_upsampling <<= dim_shift; + if (ec_upsampling < upsampling) { + return JXL_FAILURE( + "EC upsampling (%u) < color upsampling (%u), which is invalid.", + ec_upsampling, upsampling); + } + if (ec_upsampling > 8) { + return JXL_FAILURE("EC upsampling too large (%u)", ec_upsampling); + } + } + } else { + extra_channel_upsampling.clear(); + } + } + + // Modular- or VarDCT-specific data. + if (visitor->Conditional(encoding == FrameEncoding::kModular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 1, &group_size_shift)); + } + if (visitor->Conditional(encoding == FrameEncoding::kVarDCT && + color_transform == ColorTransform::kXYB)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 3, &x_qm_scale)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 2, &b_qm_scale)); + } else { + x_qm_scale = b_qm_scale = 2; // noop + } + + // Not useful for kPatchSource + if (visitor->Conditional(frame_type != FrameType::kReferenceOnly)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&passes)); + } + + if (visitor->Conditional(frame_type == FrameType::kDCFrame)) { + // Up to 4 pyramid levels - for up to 16384x downsampling. + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 1, &dc_level)); + } + if (frame_type != FrameType::kDCFrame) { + dc_level = 0; + } + + bool is_partial_frame = false; + if (visitor->Conditional(frame_type != FrameType::kDCFrame)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &custom_size_or_origin)); + if (visitor->Conditional(custom_size_or_origin)) { + const U32Enc enc(Bits(8), BitsOffset(11, 256), BitsOffset(14, 2304), + BitsOffset(30, 18688)); + // Frame offset, only if kRegularFrame or kSkipProgressive. + if (visitor->Conditional(frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive)) { + uint32_t ux0 = PackSigned(frame_origin.x0); + uint32_t uy0 = PackSigned(frame_origin.y0); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &ux0)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &uy0)); + frame_origin.x0 = UnpackSigned(ux0); + frame_origin.y0 = UnpackSigned(uy0); + } + // Frame size + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &frame_size.xsize)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(enc, 0, &frame_size.ysize)); + if (custom_size_or_origin && + (frame_size.xsize == 0 || frame_size.ysize == 0)) { + return JXL_FAILURE( + "Invalid crop dimensions for frame: zero width or height"); + } + int32_t image_xsize = default_xsize(); + int32_t image_ysize = default_ysize(); + if (frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive) { + is_partial_frame |= frame_origin.x0 > 0; + is_partial_frame |= frame_origin.y0 > 0; + is_partial_frame |= (static_cast<int32_t>(frame_size.xsize) + + frame_origin.x0) < image_xsize; + is_partial_frame |= (static_cast<int32_t>(frame_size.ysize) + + frame_origin.y0) < image_ysize; + } + } + } + + // Blending info, animation info and whether this is the last frame or not. + if (visitor->Conditional(frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive)) { + blending_info.nonserialized_num_extra_channels = num_extra_channels; + blending_info.nonserialized_is_partial_frame = is_partial_frame; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&blending_info)); + bool replace_all = (blending_info.mode == BlendMode::kReplace); + extra_channel_blending_info.resize(num_extra_channels); + for (size_t i = 0; i < num_extra_channels; i++) { + auto& ec_blending_info = extra_channel_blending_info[i]; + ec_blending_info.nonserialized_is_partial_frame = is_partial_frame; + ec_blending_info.nonserialized_num_extra_channels = num_extra_channels; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&ec_blending_info)); + replace_all &= (ec_blending_info.mode == BlendMode::kReplace); + } + if (visitor->IsReading() && nonserialized_is_preview) { + if (!replace_all || custom_size_or_origin) { + return JXL_FAILURE("Preview is not compatible with blending"); + } + } + if (visitor->Conditional(nonserialized_metadata != nullptr && + nonserialized_metadata->m.have_animation)) { + animation_frame.nonserialized_metadata = nonserialized_metadata; + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&animation_frame)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(true, &is_last)); + } else { + is_last = false; + } + + // ID of that can be used to refer to this frame. 0 for a non-zero-duration + // frame means that it will not be referenced. Not necessary for the last + // frame. + if (visitor->Conditional(frame_type != kDCFrame && !is_last)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), Val(2), Val(3), 0, &save_as_reference)); + } + + // If this frame is not blended on another frame post-color-transform, it may + // be stored for being referenced either before or after the color transform. + // If it is blended post-color-transform, it must be blended after. It must + // also be blended after if this is a kRegular frame that does not cover the + // full frame, as samples outside the partial region are from a + // post-color-transform frame. + if (frame_type != FrameType::kDCFrame) { + if (visitor->Conditional(CanBeReferenced() && + blending_info.mode == BlendMode::kReplace && + !is_partial_frame && + (frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive))) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bool(false, &save_before_color_transform)); + } else if (visitor->Conditional(frame_type == FrameType::kReferenceOnly)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bool(true, &save_before_color_transform)); + if (!save_before_color_transform && + (frame_size.xsize < nonserialized_metadata->xsize() || + frame_size.ysize < nonserialized_metadata->ysize() || + frame_origin.x0 != 0 || frame_origin.y0 != 0)) { + return JXL_FAILURE( + "non-patch reference frame with invalid crop: %" PRIuS "x%" PRIuS + "%+d%+d", + static_cast<size_t>(frame_size.xsize), + static_cast<size_t>(frame_size.ysize), + static_cast<int>(frame_origin.x0), + static_cast<int>(frame_origin.y0)); + } + } + } else { + save_before_color_transform = true; + } + + JXL_QUIET_RETURN_IF_ERROR(VisitNameString(visitor, &name)); + + loop_filter.nonserialized_is_modular = is_modular; + JXL_RETURN_IF_ERROR(visitor->VisitNested(&loop_filter)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + // Extensions: in chronological order of being added to the format. + return visitor->EndExtensions(); +} + +#if JXL_DEBUG_V_LEVEL >= 1 +std::string FrameHeader::DebugString() const { + std::ostringstream os; + os << (encoding == FrameEncoding::kVarDCT ? "VarDCT" : "Modular"); + os << ","; + os << (frame_type == FrameType::kRegularFrame ? "Regular" + : frame_type == FrameType::kDCFrame ? "DC" + : frame_type == FrameType::kReferenceOnly ? "Reference" + : "SkipProgressive"); + if (frame_type == FrameType::kDCFrame) { + os << "(lv" << dc_level << ")"; + } + + if (flags) { + os << ","; + uint32_t remaining = flags; + +#define TEST_FLAG(name) \ + if (flags & Flags::k##name) { \ + remaining &= ~Flags::k##name; \ + os << #name; \ + if (remaining) os << "|"; \ + } + TEST_FLAG(Noise); + TEST_FLAG(Patches); + TEST_FLAG(Splines); + TEST_FLAG(UseDcFrame); + TEST_FLAG(SkipAdaptiveDCSmoothing); +#undef TEST_FLAG + } + + os << ","; + os << (color_transform == ColorTransform::kXYB ? "XYB" + : color_transform == ColorTransform::kYCbCr ? "YCbCr" + : "None"); + + if (encoding == FrameEncoding::kModular) { + os << ",shift=" << group_size_shift; + } else if (color_transform == ColorTransform::kXYB) { + os << ",qm=" << x_qm_scale << ";" << b_qm_scale; + } + if (frame_type != FrameType::kReferenceOnly) { + os << "," << passes.DebugString(); + } + if (custom_size_or_origin) { + os << ",xs=" << frame_size.xsize; + os << ",ys=" << frame_size.ysize; + if (frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive) { + os << ",x0=" << frame_origin.x0; + os << ",y0=" << frame_origin.y0; + } + } + if (upsampling > 1) os << ",up=" << upsampling; + if (loop_filter.gab) os << ",Gaborish"; + if (loop_filter.epf_iters > 0) os << ",epf=" << loop_filter.epf_iters; + if (animation_frame.duration > 0) os << ",dur=" << animation_frame.duration; + if (frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive) { + os << ","; + os << blending_info.DebugString(); + for (size_t i = 0; i < extra_channel_blending_info.size(); ++i) { + os << (i == 0 ? "[" : ";"); + os << extra_channel_blending_info[i].DebugString(); + if (i + 1 == extra_channel_blending_info.size()) os << "]"; + } + } + if (save_as_reference > 0) os << ",ref=" << save_as_reference; + os << "," << (save_before_color_transform ? "before" : "after") << "_ct"; + if (is_last) os << ",last"; + return os.str(); +} +#endif + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/frame_header.h b/third_party/jpeg-xl/lib/jxl/frame_header.h new file mode 100644 index 0000000000..b246bf813e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/frame_header.h @@ -0,0 +1,504 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_FRAME_HEADER_H_ +#define LIB_JXL_FRAME_HEADER_H_ + +// Frame header with backward and forward-compatible extension capability and +// compressed integer fields. + +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <string> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/common.h" // kMaxNumPasses +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/loop_filter.h" + +namespace jxl { + +// TODO(eustas): move to proper place? +// Also used by extra channel names. +static inline Status VisitNameString(Visitor* JXL_RESTRICT visitor, + std::string* name) { + uint32_t name_length = static_cast<uint32_t>(name->length()); + // Allows layer name lengths up to 1071 bytes + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Bits(4), BitsOffset(5, 16), + BitsOffset(10, 48), 0, &name_length)); + if (visitor->IsReading()) { + name->resize(name_length); + } + for (size_t i = 0; i < name_length; i++) { + uint32_t c = static_cast<uint8_t>((*name)[i]); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(8, 0, &c)); + (*name)[i] = static_cast<char>(c); + } + return true; +} + +enum class FrameEncoding : uint32_t { + kVarDCT, + kModular, +}; + +enum class ColorTransform : uint32_t { + kXYB, // Values are encoded with XYB. May only be used if + // ImageBundle::xyb_encoded. + kNone, // Values are encoded according to the attached color profile. May + // only be used if !ImageBundle::xyb_encoded. + kYCbCr, // Values are encoded according to the attached color profile, but + // transformed to YCbCr. May only be used if + // !ImageBundle::xyb_encoded. +}; + +inline std::array<int, 3> JpegOrder(ColorTransform ct, bool is_gray) { + if (is_gray) { + return {{0, 0, 0}}; + } + JXL_ASSERT(ct != ColorTransform::kXYB); + if (ct == ColorTransform::kYCbCr) { + return {{1, 0, 2}}; + } else { + return {{0, 1, 2}}; + } +} + +struct YCbCrChromaSubsampling : public Fields { + YCbCrChromaSubsampling(); + JXL_FIELDS_NAME(YCbCrChromaSubsampling) + size_t HShift(size_t c) const { return maxhs_ - kHShift[channel_mode_[c]]; } + size_t VShift(size_t c) const { return maxvs_ - kVShift[channel_mode_[c]]; } + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override { + // TODO(veluca): consider allowing 4x downsamples + for (size_t i = 0; i < 3; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 0, &channel_mode_[i])); + } + Recompute(); + return true; + } + + uint8_t MaxHShift() const { return maxhs_; } + uint8_t MaxVShift() const { return maxvs_; } + + uint8_t RawHShift(size_t c) const { return kHShift[channel_mode_[c]]; } + uint8_t RawVShift(size_t c) const { return kVShift[channel_mode_[c]]; } + + // Uses JPEG channel order (Y, Cb, Cr). + Status Set(const uint8_t* hsample, const uint8_t* vsample) { + for (size_t c = 0; c < 3; c++) { + size_t cjpeg = c < 2 ? c ^ 1 : c; + size_t i = 0; + for (; i < 4; i++) { + if (1 << kHShift[i] == hsample[cjpeg] && + 1 << kVShift[i] == vsample[cjpeg]) { + channel_mode_[c] = i; + break; + } + } + if (i == 4) { + return JXL_FAILURE("Invalid subsample mode"); + } + } + Recompute(); + return true; + } + + bool Is444() const { + return HShift(0) == 0 && VShift(0) == 0 && // Cb + HShift(2) == 0 && VShift(2) == 0 && // Cr + HShift(1) == 0 && VShift(1) == 0; // Y + } + + bool Is420() const { + return HShift(0) == 1 && VShift(0) == 1 && // Cb + HShift(2) == 1 && VShift(2) == 1 && // Cr + HShift(1) == 0 && VShift(1) == 0; // Y + } + + bool Is422() const { + return HShift(0) == 1 && VShift(0) == 0 && // Cb + HShift(2) == 1 && VShift(2) == 0 && // Cr + HShift(1) == 0 && VShift(1) == 0; // Y + } + + bool Is440() const { + return HShift(0) == 0 && VShift(0) == 1 && // Cb + HShift(2) == 0 && VShift(2) == 1 && // Cr + HShift(1) == 0 && VShift(1) == 0; // Y + } + + std::string DebugString() const { + if (Is444()) return "444"; + if (Is420()) return "420"; + if (Is422()) return "422"; + if (Is440()) return "440"; + return "cs" + std::to_string(channel_mode_[0]) + + std::to_string(channel_mode_[1]) + std::to_string(channel_mode_[2]); + } + + private: + void Recompute() { + maxhs_ = 0; + maxvs_ = 0; + for (size_t i = 0; i < 3; i++) { + maxhs_ = std::max(maxhs_, kHShift[channel_mode_[i]]); + maxvs_ = std::max(maxvs_, kVShift[channel_mode_[i]]); + } + } + static const uint8_t kHShift[4]; + static const uint8_t kVShift[4]; + uint32_t channel_mode_[3]; + uint8_t maxhs_; + uint8_t maxvs_; +}; + +// Indicates how to combine the current frame with a previously-saved one. Can +// be independently controlled for color and extra channels. Formulas are +// indicative and treat alpha as if it is in range 0.0-1.0. In descriptions +// below, alpha channel is the extra channel of type alpha used for blending +// according to the blend_channel, or fully opaque if there is no alpha channel. +// The blending specified here is used for performing blending *after* color +// transforms - in linear sRGB if blending a XYB-encoded frame on another +// XYB-encoded frame, in sRGB if blending a frame with kColorSpace == kSRGB, or +// in the original colorspace otherwise. Blending in XYB or YCbCr is done by +// using patches. +enum class BlendMode { + // The new values (in the crop) replace the old ones: sample = new + kReplace = 0, + // The new values (in the crop) get added to the old ones: sample = old + new + kAdd = 1, + // The new values (in the crop) replace the old ones if alpha>0: + // For the alpha channel that is used as source: + // alpha = old + new * (1 - old) + // For other channels if !alpha_associated: + // sample = ((1 - new_alpha) * old * old_alpha + new_alpha * new) / alpha + // For other channels if alpha_associated: + // sample = (1 - new_alpha) * old + new + // The alpha formula applies to the alpha used for the division in the other + // channels formula, and applies to the alpha channel itself if its + // blend_channel value matches itself. + kBlend = 2, + // The new values (in the crop) are added to the old ones if alpha>0: + // For the alpha channel that is used as source: + // sample = sample = old + new * (1 - old) + // For other channels: sample = old + alpha * new + kAlphaWeightedAdd = 3, + // The new values (in the crop) get multiplied by the old ones: + // sample = old * new + // The range of the new value matters for multiplication purposes, and its + // nominal range of 0..1 is computed the same way as this is done for the + // alpha values in kBlend and kAlphaWeightedAdd. + // If using kMul as a blend mode for color channels, no color transform is + // performed on the current frame. + kMul = 4, +}; + +struct BlendingInfo : public Fields { + BlendingInfo(); + JXL_FIELDS_NAME(BlendingInfo) + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + BlendMode mode; + // Which extra channel to use as alpha channel for blending, only encoded + // for blend modes that involve alpha and if there are more than 1 extra + // channels. + uint32_t alpha_channel; + // Clamp alpha or channel values to 0-1 range. + bool clamp; + // Frame ID to copy from (0-3). Only encoded if blend_mode is not kReplace. + uint32_t source; + + std::string DebugString() const; + + size_t nonserialized_num_extra_channels = 0; + bool nonserialized_is_partial_frame = false; +}; + +// Origin of the current frame. Not present for frames of type +// kOnlyPatches. +struct FrameOrigin { + int32_t x0, y0; // can be negative. +}; + +// Size of the current frame. +struct FrameSize { + uint32_t xsize, ysize; +}; + +// AnimationFrame defines duration of animation frames. +struct AnimationFrame : public Fields { + explicit AnimationFrame(const CodecMetadata* metadata); + JXL_FIELDS_NAME(AnimationFrame) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // How long to wait [in ticks, see Animation{}] after rendering. + // May be 0 if the current frame serves as a foundation for another frame. + uint32_t duration; + + uint32_t timecode; // 0xHHMMSSFF + + // Must be set to the one ImageMetadata acting as the full codestream header, + // with correct xyb_encoded, list of extra channels, etc... + const CodecMetadata* nonserialized_metadata = nullptr; +}; + +// For decoding to lower resolutions. Only used for kRegular frames. +struct Passes : public Fields { + Passes(); + JXL_FIELDS_NAME(Passes) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + void GetDownsamplingBracket(size_t pass, int& minShift, int& maxShift) const { + maxShift = 2; + minShift = 3; + for (size_t i = 0;; i++) { + for (uint32_t j = 0; j < num_downsample; ++j) { + if (i == last_pass[j]) { + if (downsample[j] == 8) minShift = 3; + if (downsample[j] == 4) minShift = 2; + if (downsample[j] == 2) minShift = 1; + if (downsample[j] == 1) minShift = 0; + } + } + if (i == num_passes - 1) minShift = 0; + if (i == pass) return; + maxShift = minShift - 1; + } + } + + uint32_t GetDownsamplingTargetForCompletedPasses(uint32_t num_p) const { + if (num_p >= num_passes) return 1; + uint32_t retval = 8; + for (uint32_t i = 0; i < num_downsample; ++i) { + if (num_p > last_pass[i]) { + retval = std::min(retval, downsample[i]); + } + } + return retval; + } + + std::string DebugString() const; + + uint32_t num_passes; // <= kMaxNumPasses + uint32_t num_downsample; // <= num_passes + + // Array of num_downsample pairs. downsample=1/last_pass=num_passes-1 and + // downsample=8/last_pass=0 need not be specified; they are implicit. + uint32_t downsample[kMaxNumPasses]; + uint32_t last_pass[kMaxNumPasses]; + // Array of shift values for each pass. It is implicitly assumed to be 0 for + // the last pass. + uint32_t shift[kMaxNumPasses]; +}; + +enum FrameType { + // A "regular" frame: might be a crop, and will be blended on a previous + // frame, if any, and displayed or blended in future frames. + kRegularFrame = 0, + // A DC frame: this frame is downsampled and will be *only* used as the DC of + // a future frame and, possibly, for previews. Cannot be cropped, blended, or + // referenced by patches or blending modes. Frames that *use* a DC frame + // cannot have non-default sizes either. + kDCFrame = 1, + // A PatchesSource frame: this frame will be only used as a source frame for + // taking patches. Can be cropped, but cannot have non-(0, 0) x0 and y0. + kReferenceOnly = 2, + // Same as kRegularFrame, but not used for progressive rendering. This also + // implies no early display of DC. + kSkipProgressive = 3, +}; + +// Image/frame := one of more of these, where the last has is_last = true. +// Starts at a byte-aligned address "a"; the next pass starts at "a + size". +struct FrameHeader : public Fields { + // Optional postprocessing steps. These flags are the source of truth; + // Override must set/clear them rather than change their meaning. Values + // chosen such that typical flags == 0 (encoded in only two bits). + enum Flags { + // Often but not always off => low bit value: + + // Inject noise into decoded output. + kNoise = 1, + + // Overlay patches. + kPatches = 2, + + // 4, 8 = reserved for future sometimes-off + + // Overlay splines. + kSplines = 16, + + kUseDcFrame = 32, // Implies kSkipAdaptiveDCSmoothing. + + // 64 = reserved for future often-off + + // Almost always on => negated: + + kSkipAdaptiveDCSmoothing = 128, + }; + + explicit FrameHeader(const CodecMetadata* metadata); + JXL_FIELDS_NAME(FrameHeader) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Sets/clears `flag` based upon `condition`. + void UpdateFlag(const bool condition, const uint64_t flag) { + if (condition) { + flags |= flag; + } else { + flags &= ~flag; + } + } + + // Returns true if this frame is supposed to be saved for future usage by + // other frames. + bool CanBeReferenced() const { + // DC frames cannot be referenced. The last frame cannot be referenced. A + // duration 0 frame makes little sense if it is not referenced. A + // non-duration 0 frame may or may not be referenced. + return !is_last && frame_type != FrameType::kDCFrame && + (animation_frame.duration == 0 || save_as_reference != 0); + } + + mutable bool all_default; + + // Always present + FrameEncoding encoding; + // Some versions of UBSAN complain in VisitFrameType if not initialized. + FrameType frame_type = FrameType::kRegularFrame; + + uint64_t flags; + + ColorTransform color_transform; + YCbCrChromaSubsampling chroma_subsampling; + + uint32_t group_size_shift; // only if encoding == kModular; + + uint32_t x_qm_scale; // only if VarDCT and color_transform == kXYB + uint32_t b_qm_scale; // only if VarDCT and color_transform == kXYB + + std::string name; + + // Skipped for kReferenceOnly. + Passes passes; + + // Skipped for kDCFrame + bool custom_size_or_origin; + FrameSize frame_size; + + // upsampling factors for color and extra channels. + // Upsampling is always performed before applying any inverse color transform. + // Skipped (1) if kUseDCFrame + uint32_t upsampling; + std::vector<uint32_t> extra_channel_upsampling; + + // Only for kRegular frames. + FrameOrigin frame_origin; + + BlendingInfo blending_info; + std::vector<BlendingInfo> extra_channel_blending_info; + + // Animation info for this frame. + AnimationFrame animation_frame; + + // This is the last frame. + bool is_last; + + // ID to refer to this frame with. 0-3, not present if kDCFrame. + // 0 has a special meaning for kRegular frames of nonzero duration: it defines + // a frame that will not be referenced in the future. + uint32_t save_as_reference; + + // Whether to save this frame before or after the color transform. A frame + // that is saved before the color tansform can only be used for blending + // through patches. On the contrary, a frame that is saved after the color + // transform can only be used for blending through blending modes. + // Irrelevant for extra channel blending. Can only be true if + // blending_info.mode == kReplace and this is not a partial kRegularFrame; if + // this is a DC frame, it is always true. + bool save_before_color_transform; + + uint32_t dc_level; // 1-4 if kDCFrame (0 otherwise). + + // Must be set to the one ImageMetadata acting as the full codestream header, + // with correct xyb_encoded, list of extra channels, etc... + const CodecMetadata* nonserialized_metadata = nullptr; + + // NOTE: This is ignored by AllDefault. + LoopFilter loop_filter; + + bool nonserialized_is_preview = false; + + size_t default_xsize() const { + if (!nonserialized_metadata) return 0; + if (nonserialized_is_preview) { + return nonserialized_metadata->m.preview_size.xsize(); + } + return nonserialized_metadata->xsize(); + } + + size_t default_ysize() const { + if (!nonserialized_metadata) return 0; + if (nonserialized_is_preview) { + return nonserialized_metadata->m.preview_size.ysize(); + } + return nonserialized_metadata->ysize(); + } + + FrameDimensions ToFrameDimensions() const { + size_t xsize = default_xsize(); + size_t ysize = default_ysize(); + + xsize = frame_size.xsize ? frame_size.xsize : xsize; + ysize = frame_size.ysize ? frame_size.ysize : ysize; + + if (dc_level != 0) { + xsize = DivCeil(xsize, 1 << (3 * dc_level)); + ysize = DivCeil(ysize, 1 << (3 * dc_level)); + } + + FrameDimensions frame_dim; + frame_dim.Set(xsize, ysize, group_size_shift, + chroma_subsampling.MaxHShift(), + chroma_subsampling.MaxVShift(), + encoding == FrameEncoding::kModular, upsampling); + return frame_dim; + } + + // True if a color transform should be applied to this frame. + bool needs_color_transform() const { + return !save_before_color_transform || + frame_type == FrameType::kRegularFrame || + frame_type == FrameType::kSkipProgressive; + } + + std::string DebugString() const; + + uint64_t extensions; +}; + +Status ReadFrameHeader(BitReader* JXL_RESTRICT reader, + FrameHeader* JXL_RESTRICT frame); + +// Shared by enc/dec. 5F and 13 are by far the most common for d1/2/4/8, 0 +// ensures low overhead for small images. +static constexpr U32Enc kOrderEnc = + U32Enc(Val(0x5F), Val(0x13), Val(0), Bits(kNumOrders)); + +} // namespace jxl + +#endif // LIB_JXL_FRAME_HEADER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/gamma_correct_test.cc b/third_party/jpeg-xl/lib/jxl/gamma_correct_test.cc new file mode 100644 index 0000000000..131ec4fa83 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/gamma_correct_test.cc @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <stdlib.h> + +#include <algorithm> + +#include "lib/jxl/enc_gamma_correct.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +TEST(GammaCorrectTest, TestLinearToSrgbEdgeCases) { + EXPECT_EQ(0, LinearToSrgb8Direct(0.0)); + EXPECT_NEAR(0, LinearToSrgb8Direct(1E-6f), 2E-5); + EXPECT_EQ(0, LinearToSrgb8Direct(-1E-6f)); + EXPECT_EQ(0, LinearToSrgb8Direct(-1E6)); + EXPECT_NEAR(1, LinearToSrgb8Direct(1 - 1E-6f), 1E-5); + EXPECT_EQ(1, LinearToSrgb8Direct(1 + 1E-6f)); + EXPECT_EQ(1, LinearToSrgb8Direct(1E6)); +} + +TEST(GammaCorrectTest, TestRoundTrip) { + // NOLINTNEXTLINE(clang-analyzer-security.FloatLoopCounter) + for (double linear = 0.0; linear <= 1.0; linear += 1E-7) { + const double srgb = LinearToSrgb8Direct(linear); + const double linear2 = Srgb8ToLinearDirect(srgb); + ASSERT_LT(std::abs(linear - linear2), 2E-13) + << "linear = " << linear << ", linear2 = " << linear2; + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/gradient_test.cc b/third_party/jpeg-xl/lib/jxl/gradient_test.cc new file mode 100644 index 0000000000..055a419f5b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/gradient_test.cc @@ -0,0 +1,201 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> +#include <math.h> +#include <stddef.h> +#include <stdint.h> + +#include <algorithm> +#include <array> +#include <cmath> +#include <utility> +#include <vector> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { + +struct AuxOut; + +namespace { + +// Returns distance of point p to line p0..p1, the result is signed and is not +// normalized. +double PointLineDist(double x0, double y0, double x1, double y1, double x, + double y) { + return (y1 - y0) * x - (x1 - x0) * y + x1 * y0 - y1 * x0; +} + +// Generates a test image with a gradient from one color to another. +// Angle in degrees, colors can be given in hex as 0xRRGGBB. The angle is the +// angle in which the change direction happens. +Image3F GenerateTestGradient(uint32_t color0, uint32_t color1, double angle, + size_t xsize, size_t ysize) { + Image3F image(xsize, ysize); + + double x0 = xsize / 2; + double y0 = ysize / 2; + double x1 = x0 + std::sin(angle / 360.0 * 2.0 * kPi); + double y1 = y0 + std::cos(angle / 360.0 * 2.0 * kPi); + + double maxdist = + std::max<double>(fabs(PointLineDist(x0, y0, x1, y1, 0, 0)), + fabs(PointLineDist(x0, y0, x1, y1, xsize, 0))); + + for (size_t c = 0; c < 3; ++c) { + float c0 = ((color0 >> (8 * (2 - c))) & 255); + float c1 = ((color1 >> (8 * (2 - c))) & 255); + for (size_t y = 0; y < ysize; ++y) { + float* row = image.PlaneRow(c, y); + for (size_t x = 0; x < xsize; ++x) { + double dist = PointLineDist(x0, y0, x1, y1, x, y); + double v = ((dist / maxdist) + 1.0) / 2.0; + float color = c0 * (1.0 - v) + c1 * v; + row[x] = color; + } + } + } + + return image; +} + +// Computes the max of the horizontal and vertical second derivative for each +// pixel, where second derivative means absolute value of difference of left +// delta and right delta (top/bottom for vertical direction). +// The radius over which the derivative is computed is only 1 pixel and it only +// checks two angles (hor and ver), but this approximation works well enough. +static ImageF Gradient2(const ImageF& image) { + size_t xsize = image.xsize(); + size_t ysize = image.ysize(); + ImageF image2(image.xsize(), image.ysize()); + for (size_t y = 1; y + 1 < ysize; y++) { + const auto* JXL_RESTRICT row0 = image.Row(y - 1); + const auto* JXL_RESTRICT row1 = image.Row(y); + const auto* JXL_RESTRICT row2 = image.Row(y + 1); + auto* row_out = image2.Row(y); + for (size_t x = 1; x + 1 < xsize; x++) { + float ddx = (row1[x] - row1[x - 1]) - (row1[x + 1] - row1[x]); + float ddy = (row1[x] - row0[x]) - (row2[x] - row1[x]); + row_out[x] = std::max(fabsf(ddx), fabsf(ddy)); + } + } + // Copy to the borders + if (ysize > 2) { + auto* JXL_RESTRICT row0 = image2.Row(0); + const auto* JXL_RESTRICT row1 = image2.Row(1); + const auto* JXL_RESTRICT row2 = image2.Row(ysize - 2); + auto* JXL_RESTRICT row3 = image2.Row(ysize - 1); + for (size_t x = 1; x + 1 < xsize; x++) { + row0[x] = row1[x]; + row3[x] = row2[x]; + } + } else { + const auto* row0_in = image.Row(0); + const auto* row1_in = image.Row(ysize - 1); + auto* row0_out = image2.Row(0); + auto* row1_out = image2.Row(ysize - 1); + for (size_t x = 1; x + 1 < xsize; x++) { + // Image too narrow, take first derivative instead + row0_out[x] = row1_out[x] = fabsf(row0_in[x] - row1_in[x]); + } + } + if (xsize > 2) { + for (size_t y = 0; y < ysize; y++) { + auto* row = image2.Row(y); + row[0] = row[1]; + row[xsize - 1] = row[xsize - 2]; + } + } else { + for (size_t y = 0; y < ysize; y++) { + const auto* JXL_RESTRICT row_in = image.Row(y); + auto* row_out = image2.Row(y); + // Image too narrow, take first derivative instead + row_out[0] = row_out[xsize - 1] = fabsf(row_in[0] - row_in[xsize - 1]); + } + } + return image2; +} + +static Image3F Gradient2(const Image3F& image) { + return Image3F(Gradient2(image.Plane(0)), Gradient2(image.Plane(1)), + Gradient2(image.Plane(2))); +} + +/* +Tests if roundtrip with jxl on a gradient image doesn't cause banding. +Only tests if use_gradient is true. Set to false for debugging to see the +distance values. +Angle in degrees, colors can be given in hex as 0xRRGGBB. +*/ +void TestGradient(ThreadPool* pool, uint32_t color0, uint32_t color1, + size_t xsize, size_t ysize, float angle, bool fast_mode, + float butteraugli_distance, bool use_gradient = true) { + CompressParams cparams; + cparams.butteraugli_distance = butteraugli_distance; + if (fast_mode) { + cparams.speed_tier = SpeedTier::kSquirrel; + } + Image3F gradient = GenerateTestGradient(color0, color1, angle, xsize, ysize); + + CodecInOut io; + io.metadata.m.SetUintSamples(8); + io.metadata.m.color_encoding = ColorEncoding::SRGB(); + io.SetFromImage(std::move(gradient), io.metadata.m.color_encoding); + + CodecInOut io2; + + std::vector<uint8_t> compressed; + EXPECT_TRUE(test::EncodeFile(cparams, &io, &compressed, pool)); + EXPECT_TRUE(test::DecodeFile({}, Bytes(compressed), &io2, pool)); + EXPECT_TRUE(io2.Main().TransformTo(io2.metadata.m.color_encoding, + *JxlGetDefaultCms(), pool)); + + if (use_gradient) { + // Test that the gradient map worked. For that, we take a second derivative + // of the image with Gradient2 to measure how linear the change is in x and + // y direction. For a well handled gradient, we expect max values around + // 0.1, while if there is noticeable banding, which means the gradient map + // failed, the values are around 0.5-1.0 (regardless of + // butteraugli_distance). + Image3F gradient2 = Gradient2(*io2.Main().color()); + + std::array<float, 3> image_max; + Image3Max(gradient2, &image_max); + + // TODO(jyrki): These values used to work with 0.2, 0.2, 0.2. + EXPECT_LE(image_max[0], 3.15); + EXPECT_LE(image_max[1], 1.72); + EXPECT_LE(image_max[2], 5.05); + } +} + +static constexpr bool fast_mode = true; + +TEST(GradientTest, SteepGradient) { + test::ThreadPoolForTests pool(8); + // Relatively steep gradients, colors from the sky of stp.png + TestGradient(&pool, 0xd99d58, 0x889ab1, 512, 512, 90, fast_mode, 3.0); +} + +TEST(GradientTest, SubtleGradient) { + test::ThreadPoolForTests pool(8); + // Very subtle gradient + TestGradient(&pool, 0xb89b7b, 0xa89b8d, 512, 512, 90, fast_mode, 4.0); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/headers.cc b/third_party/jpeg-xl/lib/jxl/headers.cc new file mode 100644 index 0000000000..db88147687 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/headers.cc @@ -0,0 +1,194 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/headers.h" + +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_dimensions.h" + +namespace jxl { +namespace { + +struct Rational { + constexpr explicit Rational(uint32_t num, uint32_t den) + : num(num), den(den) {} + + // Returns floor(multiplicand * rational). + constexpr uint32_t MulTruncate(uint32_t multiplicand) const { + return uint64_t(multiplicand) * num / den; + } + + uint32_t num; + uint32_t den; +}; + +Rational FixedAspectRatios(uint32_t ratio) { + JXL_ASSERT(0 != ratio && ratio < 8); + // Other candidates: 5/4, 7/5, 14/9, 16/10, 5/3, 21/9, 12/5 + constexpr Rational kRatios[7] = {Rational(1, 1), // square + Rational(12, 10), // + Rational(4, 3), // camera + Rational(3, 2), // mobile camera + Rational(16, 9), // camera/display + Rational(5, 4), // + Rational(2, 1)}; // + return kRatios[ratio - 1]; +} + +uint32_t FindAspectRatio(uint32_t xsize, uint32_t ysize) { + for (uint32_t r = 1; r < 8; ++r) { + if (xsize == FixedAspectRatios(r).MulTruncate(ysize)) { + return r; + } + } + return 0; // Must send xsize instead +} + +} // namespace + +size_t SizeHeader::xsize() const { + if (ratio_ != 0) { + return FixedAspectRatios(ratio_).MulTruncate( + static_cast<uint32_t>(ysize())); + } + return small_ ? ((xsize_div8_minus_1_ + 1) * 8) : xsize_; +} + +Status SizeHeader::Set(size_t xsize64, size_t ysize64) { + if (xsize64 > 0xFFFFFFFFull || ysize64 > 0xFFFFFFFFull) { + return JXL_FAILURE("Image too large"); + } + const uint32_t xsize32 = static_cast<uint32_t>(xsize64); + const uint32_t ysize32 = static_cast<uint32_t>(ysize64); + if (xsize64 == 0 || ysize64 == 0) return JXL_FAILURE("Empty image"); + ratio_ = FindAspectRatio(xsize32, ysize32); + small_ = ysize64 <= 256 && (ysize64 % kBlockDim) == 0 && + (ratio_ != 0 || (xsize64 <= 256 && (xsize64 % kBlockDim) == 0)); + if (small_) { + ysize_div8_minus_1_ = ysize32 / 8 - 1; + } else { + ysize_ = ysize32; + } + + if (ratio_ == 0) { + if (small_) { + xsize_div8_minus_1_ = xsize32 / 8 - 1; + } else { + xsize_ = xsize32; + } + } + JXL_ASSERT(xsize() == xsize64); + JXL_ASSERT(ysize() == ysize64); + return true; +} + +Status PreviewHeader::Set(size_t xsize64, size_t ysize64) { + const uint32_t xsize32 = static_cast<uint32_t>(xsize64); + const uint32_t ysize32 = static_cast<uint32_t>(ysize64); + if (xsize64 == 0 || ysize64 == 0) return JXL_FAILURE("Empty preview"); + div8_ = (xsize64 % kBlockDim) == 0 && (ysize64 % kBlockDim) == 0; + if (div8_) { + ysize_div8_ = ysize32 / 8; + } else { + ysize_ = ysize32; + } + + ratio_ = FindAspectRatio(xsize32, ysize32); + if (ratio_ == 0) { + if (div8_) { + xsize_div8_ = xsize32 / 8; + } else { + xsize_ = xsize32; + } + } + JXL_ASSERT(xsize() == xsize64); + JXL_ASSERT(ysize() == ysize64); + return true; +} + +size_t PreviewHeader::xsize() const { + if (ratio_ != 0) { + return FixedAspectRatios(ratio_).MulTruncate( + static_cast<uint32_t>(ysize())); + } + return div8_ ? (xsize_div8_ * 8) : xsize_; +} + +SizeHeader::SizeHeader() { Bundle::Init(this); } +Status SizeHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &small_)); + + if (visitor->Conditional(small_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, 0, &ysize_div8_minus_1_)); + } + if (visitor->Conditional(!small_)) { + // (Could still be small, but non-multiple of 8.) + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(9, 1), BitsOffset(13, 1), + BitsOffset(18, 1), BitsOffset(30, 1), + 1, &ysize_)); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &ratio_)); + if (visitor->Conditional(ratio_ == 0 && small_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, 0, &xsize_div8_minus_1_)); + } + if (visitor->Conditional(ratio_ == 0 && !small_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(9, 1), BitsOffset(13, 1), + BitsOffset(18, 1), BitsOffset(30, 1), + 1, &xsize_)); + } + + return true; +} + +PreviewHeader::PreviewHeader() { Bundle::Init(this); } +Status PreviewHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &div8_)); + + if (visitor->Conditional(div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(16), Val(32), BitsOffset(5, 1), + BitsOffset(9, 33), 1, &ysize_div8_)); + } + if (visitor->Conditional(!div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(6, 1), BitsOffset(8, 65), + BitsOffset(10, 321), + BitsOffset(12, 1345), 1, &ysize_)); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &ratio_)); + if (visitor->Conditional(ratio_ == 0 && div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(16), Val(32), BitsOffset(5, 1), + BitsOffset(9, 33), 1, &xsize_div8_)); + } + if (visitor->Conditional(ratio_ == 0 && !div8_)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(BitsOffset(6, 1), BitsOffset(8, 65), + BitsOffset(10, 321), + BitsOffset(12, 1345), 1, &xsize_)); + } + + return true; +} + +AnimationHeader::AnimationHeader() { Bundle::Init(this); } +Status AnimationHeader::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(100), Val(1000), BitsOffset(10, 1), + BitsOffset(30, 1), 1, &tps_numerator)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(1), Val(1001), BitsOffset(8, 1), + BitsOffset(10, 1), 1, + &tps_denominator)); + + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Bits(3), Bits(16), Bits(32), 0, &num_loops)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_timecodes)); + return true; +} + +Status ReadSizeHeader(BitReader* JXL_RESTRICT reader, + SizeHeader* JXL_RESTRICT size) { + return Bundle::Read(reader, size); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/headers.h b/third_party/jpeg-xl/lib/jxl/headers.h new file mode 100644 index 0000000000..3cce84dabc --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/headers.h @@ -0,0 +1,97 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_HEADERS_H_ +#define LIB_JXL_HEADERS_H_ + +// Codestream headers. + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +// Reserved by ISO/IEC 10918-1. LF causes files opened in text mode to be +// rejected because the marker changes to 0x0D instead. The 0xFF prefix also +// ensures there were no 7-bit transmission limitations. +static constexpr uint8_t kCodestreamMarker = 0x0A; + +// Compact representation of image dimensions (best case: 9 bits) so decoders +// can preallocate early. +class SizeHeader : public Fields { + public: + SizeHeader(); + JXL_FIELDS_NAME(SizeHeader) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + Status Set(size_t xsize, size_t ysize); + + size_t xsize() const; + size_t ysize() const { + return small_ ? ((ysize_div8_minus_1_ + 1) * 8) : ysize_; + } + + private: + bool small_; // xsize and ysize <= 256 and divisible by 8. + + uint32_t ysize_div8_minus_1_; + uint32_t ysize_; + + uint32_t ratio_; + uint32_t xsize_div8_minus_1_; + uint32_t xsize_; +}; + +// (Similar to SizeHeader but different encoding because previews are smaller) +class PreviewHeader : public Fields { + public: + PreviewHeader(); + JXL_FIELDS_NAME(PreviewHeader) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + Status Set(size_t xsize, size_t ysize); + + size_t xsize() const; + size_t ysize() const { return div8_ ? (ysize_div8_ * 8) : ysize_; } + + private: + bool div8_; // xsize and ysize divisible by 8. + + uint32_t ysize_div8_; + uint32_t ysize_; + + uint32_t ratio_; + uint32_t xsize_div8_; + uint32_t xsize_; +}; + +struct AnimationHeader : public Fields { + AnimationHeader(); + JXL_FIELDS_NAME(AnimationHeader) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Ticks per second (expressed as rational number to support NTSC) + uint32_t tps_numerator; + uint32_t tps_denominator; + + uint32_t num_loops; // 0 means to repeat infinitely. + + bool have_timecodes; +}; + +Status ReadSizeHeader(BitReader* JXL_RESTRICT reader, + SizeHeader* JXL_RESTRICT size); + +} // namespace jxl + +#endif // LIB_JXL_HEADERS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/huffman_table.cc b/third_party/jpeg-xl/lib/jxl/huffman_table.cc new file mode 100644 index 0000000000..9ae7865af6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/huffman_table.cc @@ -0,0 +1,161 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/huffman_table.h" + +#include <cstring> /* for memcpy */ +#include <vector> + +#include "lib/jxl/ans_params.h" +#include "lib/jxl/dec_huffman.h" + +namespace jxl { + +/* Returns reverse(reverse(key, len) + 1, len), where reverse(key, len) is the + bit-wise reversal of the len least significant bits of key. */ +static inline int GetNextKey(int key, int len) { + int step = 1u << (len - 1); + while (key & step) { + step >>= 1; + } + return (key & (step - 1)) + step; +} + +/* Stores code in table[0], table[step], table[2*step], ..., table[end] */ +/* Assumes that end is an integer multiple of step */ +static inline void ReplicateValue(HuffmanCode* table, int step, int end, + HuffmanCode code) { + do { + end -= step; + table[end] = code; + } while (end > 0); +} + +/* Returns the table width of the next 2nd level table. count is the histogram + of bit lengths for the remaining symbols, len is the code length of the next + processed symbol */ +static inline size_t NextTableBitSize(const uint16_t* const count, size_t len, + int root_bits) { + size_t left = 1u << (len - root_bits); + while (len < PREFIX_MAX_BITS) { + if (left <= count[len]) break; + left -= count[len]; + ++len; + left <<= 1; + } + return len - root_bits; +} + +uint32_t BuildHuffmanTable(HuffmanCode* root_table, int root_bits, + const uint8_t* const code_lengths, + size_t code_lengths_size, uint16_t* count) { + HuffmanCode code; /* current table entry */ + HuffmanCode* table; /* next available space in table */ + size_t len; /* current code length */ + size_t symbol; /* symbol index in original or sorted table */ + int key; /* reversed prefix code */ + int step; /* step size to replicate values in current table */ + int low; /* low bits for current root entry */ + int mask; /* mask for low bits */ + size_t table_bits; /* key length of current table */ + int table_size; /* size of current table */ + int total_size; /* sum of root table size and 2nd level table sizes */ + /* offsets in sorted table for each length */ + uint16_t offset[PREFIX_MAX_BITS + 1]; + size_t max_length = 1; + + if (code_lengths_size > 1u << PREFIX_MAX_BITS) return 0; + + /* symbols sorted by code length */ + std::vector<uint16_t> sorted_storage(code_lengths_size); + uint16_t* sorted = sorted_storage.data(); + + /* generate offsets into sorted symbol table by code length */ + { + uint16_t sum = 0; + for (len = 1; len <= PREFIX_MAX_BITS; len++) { + offset[len] = sum; + if (count[len]) { + sum = static_cast<uint16_t>(sum + count[len]); + max_length = len; + } + } + } + + /* sort symbols by length, by symbol order within each length */ + for (symbol = 0; symbol < code_lengths_size; symbol++) { + if (code_lengths[symbol] != 0) { + sorted[offset[code_lengths[symbol]]++] = symbol; + } + } + + table = root_table; + table_bits = root_bits; + table_size = 1u << table_bits; + total_size = table_size; + + /* special case code with only one value */ + if (offset[PREFIX_MAX_BITS] == 1) { + code.bits = 0; + code.value = static_cast<uint16_t>(sorted[0]); + for (key = 0; key < total_size; ++key) { + table[key] = code; + } + return total_size; + } + + /* fill in root table */ + /* let's reduce the table size to a smaller size if possible, and */ + /* create the repetitions by memcpy if possible in the coming loop */ + if (table_bits > max_length) { + table_bits = max_length; + table_size = 1u << table_bits; + } + key = 0; + symbol = 0; + code.bits = 1; + step = 2; + do { + for (; count[code.bits] != 0; --count[code.bits]) { + code.value = static_cast<uint16_t>(sorted[symbol++]); + ReplicateValue(&table[key], step, table_size, code); + key = GetNextKey(key, code.bits); + } + step <<= 1; + } while (++code.bits <= table_bits); + + /* if root_bits != table_bits we only created one fraction of the */ + /* table, and we need to replicate it now. */ + while (total_size != table_size) { + memcpy(&table[table_size], &table[0], table_size * sizeof(table[0])); + table_size <<= 1; + } + + /* fill in 2nd level tables and add pointers to root table */ + mask = total_size - 1; + low = -1; + for (len = root_bits + 1, step = 2; len <= max_length; ++len, step <<= 1) { + for (; count[len] != 0; --count[len]) { + if ((key & mask) != low) { + table += table_size; + table_bits = NextTableBitSize(count, len, root_bits); + table_size = 1u << table_bits; + total_size += table_size; + low = key & mask; + root_table[low].bits = static_cast<uint8_t>(table_bits + root_bits); + root_table[low].value = + static_cast<uint16_t>((table - root_table) - low); + } + code.bits = static_cast<uint8_t>(len - root_bits); + code.value = static_cast<uint16_t>(sorted[symbol++]); + ReplicateValue(&table[key >> root_bits], step, table_size, code); + key = GetNextKey(key, len); + } + } + + return total_size; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/huffman_table.h b/third_party/jpeg-xl/lib/jxl/huffman_table.h new file mode 100644 index 0000000000..11cdb2fc45 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/huffman_table.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_HUFFMAN_TABLE_H_ +#define LIB_JXL_HUFFMAN_TABLE_H_ + +#include <stdint.h> +#include <stdlib.h> + +namespace jxl { + +struct HuffmanCode { + uint8_t bits; /* number of bits used for this symbol */ + uint16_t value; /* symbol value or table offset */ +}; + +/* Builds Huffman lookup table assuming code lengths are in symbol order. */ +/* Returns 0 in case of error (invalid tree or memory error), otherwise + populated size of table. */ +uint32_t BuildHuffmanTable(HuffmanCode* root_table, int root_bits, + const uint8_t* code_lengths, + size_t code_lengths_size, uint16_t* count); + +} // namespace jxl + +#endif // LIB_JXL_HUFFMAN_TABLE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/iaca_test.cc b/third_party/jpeg-xl/lib/jxl/iaca_test.cc new file mode 100644 index 0000000000..e25d9316d5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/iaca_test.cc @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/base/iaca.h" + +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +TEST(IacaTest, MarkersDefaultToDisabledAndDoNotCrash) { + BeginIACA(); + EndIACA(); +} + +TEST(IacaTest, ScopeDefaultToDisabledAndDoNotCrash) { ScopeIACA iaca; } + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec.cc b/third_party/jpeg-xl/lib/jxl/icc_codec.cc new file mode 100644 index 0000000000..a1f118ebfb --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec.cc @@ -0,0 +1,395 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/icc_codec.h" + +#include <stdint.h> + +#include <map> +#include <string> +#include <vector> + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/icc_codec_common.h" +#include "lib/jxl/padded_bytes.h" + +namespace jxl { +namespace { + +// Shuffles or interleaves bytes, for example with width 2, turns "ABCDabcd" +// into "AaBbCcDc". Transposes a matrix of ceil(size / width) columns and +// width rows. There are size elements, size may be < width * height, if so the +// last elements of the rightmost column are missing, the missing spots are +// transposed along with the filled spots, and the result has the missing +// elements at the end of the bottom row. The input is the input matrix in +// scanline order but with missing elements skipped (which may occur in multiple +// locations), the output is the result matrix in scanline order (with +// no need to skip missing elements as they are past the end of the data). +void Shuffle(uint8_t* data, size_t size, size_t width) { + size_t height = (size + width - 1) / width; // amount of rows of output + PaddedBytes result(size); + // i = output index, j input index + size_t s = 0, j = 0; + for (size_t i = 0; i < size; i++) { + result[i] = data[j]; + j += height; + if (j >= size) j = ++s; + } + + for (size_t i = 0; i < size; i++) { + data[i] = result[i]; + } +} + +// TODO(eustas): should be 20, or even 18, once DecodeVarInt is improved; +// currently DecodeVarInt does not signal the errors, and marks +// 11 bytes as used even if only 10 are used (and 9 is enough for +// 63-bit values). +constexpr const size_t kPreambleSize = 22; // enough for reading 2 VarInts + +uint64_t DecodeVarInt(const uint8_t* input, size_t inputSize, size_t* pos) { + size_t i; + uint64_t ret = 0; + for (i = 0; *pos + i < inputSize && i < 10; ++i) { + ret |= uint64_t(input[*pos + i] & 127) << uint64_t(7 * i); + // If the next-byte flag is not set, stop + if ((input[*pos + i] & 128) == 0) break; + } + // TODO(user): Return a decoding error if i == 10. + *pos += i + 1; + return ret; +} + +} // namespace + +// Mimics the beginning of UnpredictICC for quick validity check. +// At least kPreambleSize bytes of data should be valid at invocation time. +Status CheckPreamble(const PaddedBytes& data, size_t enc_size, + size_t output_limit) { + const uint8_t* enc = data.data(); + size_t size = data.size(); + size_t pos = 0; + uint64_t osize = DecodeVarInt(enc, size, &pos); + JXL_RETURN_IF_ERROR(CheckIs32Bit(osize)); + if (pos >= size) return JXL_FAILURE("Out of bounds"); + uint64_t csize = DecodeVarInt(enc, size, &pos); + JXL_RETURN_IF_ERROR(CheckIs32Bit(csize)); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size)); + // We expect that UnpredictICC inflates input, not the other way round. + if (osize + 65536 < enc_size) return JXL_FAILURE("Malformed ICC"); + if (output_limit && osize > output_limit) { + return JXL_FAILURE("Decoded ICC is too large"); + } + return true; +} + +// Decodes the result of PredictICC back to a valid ICC profile. +Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) { + if (!result->empty()) return JXL_FAILURE("result must be empty initially"); + size_t pos = 0; + // TODO(lode): technically speaking we need to check that the entire varint + // decoding never goes out of bounds, not just the first byte. This requires + // a DecodeVarInt function that returns an error code. It is safe to use + // DecodeVarInt with out of bounds values, it silently returns, but the + // specification requires an error. Idem for all DecodeVarInt below. + if (pos >= size) return JXL_FAILURE("Out of bounds"); + uint64_t osize = DecodeVarInt(enc, size, &pos); // Output size + JXL_RETURN_IF_ERROR(CheckIs32Bit(osize)); + if (pos >= size) return JXL_FAILURE("Out of bounds"); + uint64_t csize = DecodeVarInt(enc, size, &pos); // Commands size + // Every command is translated to at least on byte. + JXL_RETURN_IF_ERROR(CheckIs32Bit(csize)); + size_t cpos = pos; // pos in commands stream + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size)); + size_t commands_end = cpos + csize; + pos = commands_end; // pos in data stream + + // Header + PaddedBytes header; + header.append(ICCInitialHeaderPrediction()); + EncodeUint32(0, osize, &header); + for (size_t i = 0; i <= kICCHeaderSize; i++) { + if (result->size() == osize) { + if (cpos != commands_end) return JXL_FAILURE("Not all commands used"); + if (pos != size) return JXL_FAILURE("Not all data used"); + return true; // Valid end + } + if (i == kICCHeaderSize) break; // Done + ICCPredictHeader(result->data(), result->size(), header.data(), i); + if (pos >= size) return JXL_FAILURE("Out of bounds"); + result->push_back(enc[pos++] + header[i]); + } + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + + // Tag list + uint64_t numtags = DecodeVarInt(enc, size, &cpos); + + if (numtags != 0) { + numtags--; + JXL_RETURN_IF_ERROR(CheckIs32Bit(numtags)); + AppendUint32(numtags, result); + uint64_t prevtagstart = kICCHeaderSize + numtags * 12; + uint64_t prevtagsize = 0; + for (;;) { + if (result->size() > osize) return JXL_FAILURE("Invalid result size"); + if (cpos > commands_end) return JXL_FAILURE("Out of bounds"); + if (cpos == commands_end) break; // Valid end + uint8_t command = enc[cpos++]; + uint8_t tagcode = command & 63; + Tag tag; + if (tagcode == 0) { + break; + } else if (tagcode == kCommandTagUnknown) { + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 4, size)); + tag = DecodeKeyword(enc, size, pos); + pos += 4; + } else if (tagcode == kCommandTagTRC) { + tag = kRtrcTag; + } else if (tagcode == kCommandTagXYZ) { + tag = kRxyzTag; + } else { + if (tagcode - kCommandTagStringFirst >= kNumTagStrings) { + return JXL_FAILURE("Unknown tagcode"); + } + tag = *kTagStrings[tagcode - kCommandTagStringFirst]; + } + AppendKeyword(tag, result); + + uint64_t tagstart; + uint64_t tagsize = prevtagsize; + if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag || + tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag || + tag == kLumiTag) { + tagsize = 20; + } + + if (command & kFlagBitOffset) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + tagstart = DecodeVarInt(enc, size, &cpos); + } else { + JXL_RETURN_IF_ERROR(CheckIs32Bit(prevtagstart)); + tagstart = prevtagstart + prevtagsize; + } + JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart)); + AppendUint32(tagstart, result); + if (command & kFlagBitSize) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + tagsize = DecodeVarInt(enc, size, &cpos); + } + JXL_RETURN_IF_ERROR(CheckIs32Bit(tagsize)); + AppendUint32(tagsize, result); + prevtagstart = tagstart; + prevtagsize = tagsize; + + if (tagcode == kCommandTagTRC) { + AppendKeyword(kGtrcTag, result); + AppendUint32(tagstart, result); + AppendUint32(tagsize, result); + AppendKeyword(kBtrcTag, result); + AppendUint32(tagstart, result); + AppendUint32(tagsize, result); + } + + if (tagcode == kCommandTagXYZ) { + JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart + tagsize * 2)); + AppendKeyword(kGxyzTag, result); + AppendUint32(tagstart + tagsize, result); + AppendUint32(tagsize, result); + AppendKeyword(kBxyzTag, result); + AppendUint32(tagstart + tagsize * 2, result); + AppendUint32(tagsize, result); + } + } + } + + // Main Content + for (;;) { + if (result->size() > osize) return JXL_FAILURE("Invalid result size"); + if (cpos > commands_end) return JXL_FAILURE("Out of bounds"); + if (cpos == commands_end) break; // Valid end + uint8_t command = enc[cpos++]; + if (command == kCommandInsert) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + uint64_t num = DecodeVarInt(enc, size, &cpos); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); + for (size_t i = 0; i < num; i++) { + result->push_back(enc[pos++]); + } + } else if (command == kCommandShuffle2 || command == kCommandShuffle4) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + uint64_t num = DecodeVarInt(enc, size, &cpos); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); + PaddedBytes shuffled(num); + for (size_t i = 0; i < num; i++) { + shuffled[i] = enc[pos + i]; + } + if (command == kCommandShuffle2) { + Shuffle(shuffled.data(), num, 2); + } else if (command == kCommandShuffle4) { + Shuffle(shuffled.data(), num, 4); + } + for (size_t i = 0; i < num; i++) { + result->push_back(shuffled[i]); + pos++; + } + } else if (command == kCommandPredict) { + JXL_RETURN_IF_ERROR(CheckOutOfBounds(cpos, 2, commands_end)); + uint8_t flags = enc[cpos++]; + + size_t width = (flags & 3) + 1; + if (width == 3) return JXL_FAILURE("Invalid width"); + + int order = (flags & 12) >> 2; + if (order == 3) return JXL_FAILURE("Invalid order"); + + uint64_t stride = width; + if (flags & 16) { + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + stride = DecodeVarInt(enc, size, &cpos); + if (stride < width) { + return JXL_FAILURE("Invalid stride"); + } + } + // If stride * 4 >= result->size(), return failure. The check + // "size == 0 || ((size - 1) >> 2) < stride" corresponds to + // "stride * 4 >= size", but does not suffer from integer overflow. + // This check is more strict than necessary but follows the specification + // and the encoder should ensure this is followed. + if (result->empty() || ((result->size() - 1u) >> 2u) < stride) { + return JXL_FAILURE("Invalid stride"); + } + + if (cpos >= commands_end) return JXL_FAILURE("Out of bounds"); + uint64_t num = DecodeVarInt(enc, size, &cpos); // in bytes + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size)); + + PaddedBytes shuffled(num); + for (size_t i = 0; i < num; i++) { + shuffled[i] = enc[pos + i]; + } + if (width > 1) Shuffle(shuffled.data(), num, width); + + size_t start = result->size(); + for (size_t i = 0; i < num; i++) { + uint8_t predicted = LinearPredictICCValue(result->data(), start, i, + stride, width, order); + result->push_back(predicted + shuffled[i]); + } + pos += num; + } else if (command == kCommandXYZ) { + AppendKeyword(kXyz_Tag, result); + for (int i = 0; i < 4; i++) result->push_back(0); + JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 12, size)); + for (size_t i = 0; i < 12; i++) { + result->push_back(enc[pos++]); + } + } else if (command >= kCommandTypeStartFirst && + command < kCommandTypeStartFirst + kNumTypeStrings) { + AppendKeyword(*kTypeStrings[command - kCommandTypeStartFirst], result); + for (size_t i = 0; i < 4; i++) { + result->push_back(0); + } + } else { + return JXL_FAILURE("Unknown command"); + } + } + + if (pos != size) return JXL_FAILURE("Not all data used"); + if (result->size() != osize) return JXL_FAILURE("Invalid result size"); + + return true; +} + +Status ICCReader::Init(BitReader* reader, size_t output_limit) { + JXL_RETURN_IF_ERROR(CheckEOI(reader)); + used_bits_base_ = reader->TotalBitsConsumed(); + if (bits_to_skip_ == 0) { + enc_size_ = U64Coder::Read(reader); + if (enc_size_ > 268435456) { + // Avoid too large memory allocation for invalid file. + return JXL_FAILURE("Too large encoded profile"); + } + JXL_RETURN_IF_ERROR( + DecodeHistograms(reader, kNumICCContexts, &code_, &context_map_)); + ans_reader_ = ANSSymbolReader(&code_, reader); + i_ = 0; + decompressed_.resize(std::min<size_t>(i_ + 0x400, enc_size_)); + for (; i_ < std::min<size_t>(2, enc_size_); i_++) { + decompressed_[i_] = ans_reader_.ReadHybridUint( + ICCANSContext(i_, i_ > 0 ? decompressed_[i_ - 1] : 0, + i_ > 1 ? decompressed_[i_ - 2] : 0), + reader, context_map_); + } + if (enc_size_ > kPreambleSize) { + for (; i_ < kPreambleSize; i_++) { + decompressed_[i_] = ans_reader_.ReadHybridUint( + ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]), + reader, context_map_); + } + JXL_RETURN_IF_ERROR(CheckEOI(reader)); + JXL_RETURN_IF_ERROR( + CheckPreamble(decompressed_, enc_size_, output_limit)); + } + bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_; + } else { + reader->SkipBits(bits_to_skip_); + } + return true; +} + +Status ICCReader::Process(BitReader* reader, PaddedBytes* icc) { + ANSSymbolReader::Checkpoint checkpoint; + size_t saved_i = 0; + auto save = [&]() { + ans_reader_.Save(&checkpoint); + bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_; + saved_i = i_; + }; + save(); + auto check_and_restore = [&]() { + Status status = CheckEOI(reader); + if (!status) { + // not enough bytes. + ans_reader_.Restore(checkpoint); + i_ = saved_i; + return status; + } + return Status(true); + }; + for (; i_ < enc_size_; i_++) { + if (i_ % ANSSymbolReader::kMaxCheckpointInterval == 0 && i_ > 0) { + JXL_RETURN_IF_ERROR(check_and_restore()); + save(); + if ((i_ > 0) && (((i_ & 0xFFFF) == 0))) { + float used_bytes = + (reader->TotalBitsConsumed() - used_bits_base_) / 8.0f; + if (i_ > used_bytes * 256) return JXL_FAILURE("Corrupted stream"); + } + decompressed_.resize(std::min<size_t>(i_ + 0x400, enc_size_)); + } + JXL_DASSERT(i_ >= 2); + decompressed_[i_] = ans_reader_.ReadHybridUint( + ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]), reader, + context_map_); + } + JXL_RETURN_IF_ERROR(check_and_restore()); + bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_; + if (!ans_reader_.CheckANSFinalState()) { + return JXL_FAILURE("Corrupted ICC profile"); + } + + icc->clear(); + return UnpredictICC(decompressed_.data(), decompressed_.size(), icc); +} + +Status ICCReader::CheckEOI(BitReader* reader) { + if (reader->AllReadsWithinBounds()) return true; + return JXL_STATUS(StatusCode::kNotEnoughBytes, + "Not enough bytes for reading ICC profile"); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec.h b/third_party/jpeg-xl/lib/jxl/icc_codec.h new file mode 100644 index 0000000000..87e523a575 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec.h @@ -0,0 +1,50 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ICC_CODEC_H_ +#define LIB_JXL_ICC_CODEC_H_ + +// Compressed representation of ICC profiles. + +#include <cstddef> +#include <cstdint> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/padded_bytes.h" + +namespace jxl { + +struct ICCReader { + Status Init(BitReader* reader, size_t output_limit); + Status Process(BitReader* reader, PaddedBytes* icc); + void Reset() { + bits_to_skip_ = 0; + decompressed_.clear(); + } + + private: + Status CheckEOI(BitReader* reader); + size_t i_ = 0; + size_t bits_to_skip_ = 0; + size_t used_bits_base_ = 0; + uint64_t enc_size_ = 0; + std::vector<uint8_t> context_map_; + ANSCode code_; + ANSSymbolReader ans_reader_; + PaddedBytes decompressed_; +}; + +// Exposed only for testing +Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result); + +// Exposed only for testing +Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result); + +} // namespace jxl + +#endif // LIB_JXL_ICC_CODEC_H_ diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec_common.cc b/third_party/jpeg-xl/lib/jxl/icc_codec_common.cc new file mode 100644 index 0000000000..d420567b6f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec_common.cc @@ -0,0 +1,180 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/icc_codec_common.h" + +#include <stdint.h> + +#include <map> +#include <string> +#include <vector> + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/padded_bytes.h" + +namespace jxl { +namespace { +static uint8_t ByteKind1(uint8_t b) { + if ('a' <= b && b <= 'z') return 0; + if ('A' <= b && b <= 'Z') return 0; + if ('0' <= b && b <= '9') return 1; + if (b == '.' || b == ',') return 1; + if (b == 0) return 2; + if (b == 1) return 3; + if (b < 16) return 4; + if (b == 255) return 6; + if (b > 240) return 5; + return 7; +} + +static uint8_t ByteKind2(uint8_t b) { + if ('a' <= b && b <= 'z') return 0; + if ('A' <= b && b <= 'Z') return 0; + if ('0' <= b && b <= '9') return 1; + if (b == '.' || b == ',') return 1; + if (b < 16) return 2; + if (b > 240) return 3; + return 4; +} + +template <typename T> +T PredictValue(T p1, T p2, T p3, int order) { + if (order == 0) return p1; + if (order == 1) return 2 * p1 - p2; + if (order == 2) return 3 * p1 - 3 * p2 + p3; + return 0; +} +} // namespace + +uint32_t DecodeUint32(const uint8_t* data, size_t size, size_t pos) { + return pos + 4 > size ? 0 : LoadBE32(data + pos); +} + +void EncodeUint32(size_t pos, uint32_t value, PaddedBytes* data) { + if (pos + 4 > data->size()) return; + StoreBE32(value, data->data() + pos); +} + +void AppendUint32(uint32_t value, PaddedBytes* data) { + data->resize(data->size() + 4); + EncodeUint32(data->size() - 4, value, data); +} + +typedef std::array<uint8_t, 4> Tag; + +Tag DecodeKeyword(const uint8_t* data, size_t size, size_t pos) { + if (pos + 4 > size) return {{' ', ' ', ' ', ' '}}; + return {{data[pos], data[pos + 1], data[pos + 2], data[pos + 3]}}; +} + +void EncodeKeyword(const Tag& keyword, uint8_t* data, size_t size, size_t pos) { + if (keyword.size() != 4 || pos + 3 >= size) return; + for (size_t i = 0; i < 4; ++i) data[pos + i] = keyword[i]; +} + +void AppendKeyword(const Tag& keyword, PaddedBytes* data) { + JXL_ASSERT(keyword.size() == 4); + data->append(keyword); +} + +// Checks if a + b > size, taking possible integer overflow into account. +Status CheckOutOfBounds(uint64_t a, uint64_t b, uint64_t size) { + uint64_t pos = a + b; + if (pos > size) return JXL_FAILURE("Out of bounds"); + if (pos < a) return JXL_FAILURE("Out of bounds"); // overflow happened + return true; +} + +Status CheckIs32Bit(uint64_t v) { + static constexpr const uint64_t kUpper32 = ~static_cast<uint64_t>(0xFFFFFFFF); + if ((v & kUpper32) != 0) return JXL_FAILURE("32-bit value expected"); + return true; +} + +const uint8_t kIccInitialHeaderPrediction[kICCHeaderSize] = { + 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 'm', 'n', 't', 'r', + 'R', 'G', 'B', ' ', 'X', 'Y', 'Z', ' ', 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 'a', 'c', 's', 'p', 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 246, 214, 0, 1, 0, 0, 0, 0, 211, 45, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + +const Span<const uint8_t> ICCInitialHeaderPrediction() { + return Bytes(kIccInitialHeaderPrediction); +} + +void ICCPredictHeader(const uint8_t* icc, size_t size, uint8_t* header, + size_t pos) { + if (pos == 8 && size >= 8) { + header[80] = icc[4]; + header[81] = icc[5]; + header[82] = icc[6]; + header[83] = icc[7]; + } + if (pos == 41 && size >= 41) { + if (icc[40] == 'A') { + header[41] = 'P'; + header[42] = 'P'; + header[43] = 'L'; + } + if (icc[40] == 'M') { + header[41] = 'S'; + header[42] = 'F'; + header[43] = 'T'; + } + } + if (pos == 42 && size >= 42) { + if (icc[40] == 'S' && icc[41] == 'G') { + header[42] = 'I'; + header[43] = ' '; + } + if (icc[40] == 'S' && icc[41] == 'U') { + header[42] = 'N'; + header[43] = 'W'; + } + } +} + +// Predicts a value with linear prediction of given order (0-2), for integers +// with width bytes and given stride in bytes between values. +// The start position is at start + i, and the relevant modulus of i describes +// which byte of the multi-byte integer is being handled. +// The value start + i must be at least stride * 4. +uint8_t LinearPredictICCValue(const uint8_t* data, size_t start, size_t i, + size_t stride, size_t width, int order) { + size_t pos = start + i; + if (width == 1) { + uint8_t p1 = data[pos - stride]; + uint8_t p2 = data[pos - stride * 2]; + uint8_t p3 = data[pos - stride * 3]; + return PredictValue(p1, p2, p3, order); + } else if (width == 2) { + size_t p = start + (i & ~1); + uint16_t p1 = (data[p - stride * 1] << 8) + data[p - stride * 1 + 1]; + uint16_t p2 = (data[p - stride * 2] << 8) + data[p - stride * 2 + 1]; + uint16_t p3 = (data[p - stride * 3] << 8) + data[p - stride * 3 + 1]; + uint16_t pred = PredictValue(p1, p2, p3, order); + return (i & 1) ? (pred & 255) : ((pred >> 8) & 255); + } else { + size_t p = start + (i & ~3); + uint32_t p1 = DecodeUint32(data, pos, p - stride); + uint32_t p2 = DecodeUint32(data, pos, p - stride * 2); + uint32_t p3 = DecodeUint32(data, pos, p - stride * 3); + uint32_t pred = PredictValue(p1, p2, p3, order); + unsigned shiftbytes = 3 - (i & 3); + return (pred >> (shiftbytes * 8)) & 255; + } +} + +size_t ICCANSContext(size_t i, size_t b1, size_t b2) { + if (i <= 128) return 0; + return 1 + ByteKind1(b1) + ByteKind2(b2) * 8; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec_common.h b/third_party/jpeg-xl/lib/jxl/icc_codec_common.h new file mode 100644 index 0000000000..702a0e7b1f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec_common.h @@ -0,0 +1,107 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_ICC_CODEC_COMMON_H_ +#define LIB_JXL_ICC_CODEC_COMMON_H_ + +// Compressed representation of ICC profiles. + +#include <array> +#include <cstddef> +#include <cstdint> + +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +class PaddedBytes; + +static constexpr size_t kICCHeaderSize = 128; + +typedef std::array<uint8_t, 4> Tag; + +static const Tag kAcspTag = {{'a', 'c', 's', 'p'}}; +static const Tag kBkptTag = {{'b', 'k', 'p', 't'}}; +static const Tag kBtrcTag = {{'b', 'T', 'R', 'C'}}; +static const Tag kBxyzTag = {{'b', 'X', 'Y', 'Z'}}; +static const Tag kChadTag = {{'c', 'h', 'a', 'd'}}; +static const Tag kChrmTag = {{'c', 'h', 'r', 'm'}}; +static const Tag kCprtTag = {{'c', 'p', 'r', 't'}}; +static const Tag kCurvTag = {{'c', 'u', 'r', 'v'}}; +static const Tag kDescTag = {{'d', 'e', 's', 'c'}}; +static const Tag kDmddTag = {{'d', 'm', 'd', 'd'}}; +static const Tag kDmndTag = {{'d', 'm', 'n', 'd'}}; +static const Tag kGbd_Tag = {{'g', 'b', 'd', ' '}}; +static const Tag kGtrcTag = {{'g', 'T', 'R', 'C'}}; +static const Tag kGxyzTag = {{'g', 'X', 'Y', 'Z'}}; +static const Tag kKtrcTag = {{'k', 'T', 'R', 'C'}}; +static const Tag kKxyzTag = {{'k', 'X', 'Y', 'Z'}}; +static const Tag kLumiTag = {{'l', 'u', 'm', 'i'}}; +static const Tag kMab_Tag = {{'m', 'A', 'B', ' '}}; +static const Tag kMba_Tag = {{'m', 'B', 'A', ' '}}; +static const Tag kMlucTag = {{'m', 'l', 'u', 'c'}}; +static const Tag kMntrTag = {{'m', 'n', 't', 'r'}}; +static const Tag kParaTag = {{'p', 'a', 'r', 'a'}}; +static const Tag kRgb_Tag = {{'R', 'G', 'B', ' '}}; +static const Tag kRtrcTag = {{'r', 'T', 'R', 'C'}}; +static const Tag kRxyzTag = {{'r', 'X', 'Y', 'Z'}}; +static const Tag kSf32Tag = {{'s', 'f', '3', '2'}}; +static const Tag kTextTag = {{'t', 'e', 'x', 't'}}; +static const Tag kVcgtTag = {{'v', 'c', 'g', 't'}}; +static const Tag kWtptTag = {{'w', 't', 'p', 't'}}; +static const Tag kXyz_Tag = {{'X', 'Y', 'Z', ' '}}; + +// Tag names focused on RGB and GRAY monitor profiles +static constexpr size_t kNumTagStrings = 17; +static constexpr const Tag* kTagStrings[kNumTagStrings] = { + &kCprtTag, &kWtptTag, &kBkptTag, &kRxyzTag, &kGxyzTag, &kBxyzTag, + &kKxyzTag, &kRtrcTag, &kGtrcTag, &kBtrcTag, &kKtrcTag, &kChadTag, + &kDescTag, &kChrmTag, &kDmndTag, &kDmddTag, &kLumiTag}; + +static constexpr size_t kCommandTagUnknown = 1; +static constexpr size_t kCommandTagTRC = 2; +static constexpr size_t kCommandTagXYZ = 3; +static constexpr size_t kCommandTagStringFirst = 4; + +// Tag types focused on RGB and GRAY monitor profiles +static constexpr size_t kNumTypeStrings = 8; +static constexpr const Tag* kTypeStrings[kNumTypeStrings] = { + &kXyz_Tag, &kDescTag, &kTextTag, &kMlucTag, + &kParaTag, &kCurvTag, &kSf32Tag, &kGbd_Tag}; + +static constexpr size_t kCommandInsert = 1; +static constexpr size_t kCommandShuffle2 = 2; +static constexpr size_t kCommandShuffle4 = 3; +static constexpr size_t kCommandPredict = 4; +static constexpr size_t kCommandXYZ = 10; +static constexpr size_t kCommandTypeStartFirst = 16; + +static constexpr size_t kFlagBitOffset = 64; +static constexpr size_t kFlagBitSize = 128; + +static constexpr size_t kNumICCContexts = 41; + +uint32_t DecodeUint32(const uint8_t* data, size_t size, size_t pos); +void EncodeUint32(size_t pos, uint32_t value, PaddedBytes* data); +void AppendUint32(uint32_t value, PaddedBytes* data); +Tag DecodeKeyword(const uint8_t* data, size_t size, size_t pos); +void EncodeKeyword(const Tag& keyword, uint8_t* data, size_t size, size_t pos); +void AppendKeyword(const Tag& keyword, PaddedBytes* data); + +// Checks if a + b > size, taking possible integer overflow into account. +Status CheckOutOfBounds(uint64_t a, uint64_t b, uint64_t size); +Status CheckIs32Bit(uint64_t v); + +const Span<const uint8_t> ICCInitialHeaderPrediction(); +void ICCPredictHeader(const uint8_t* icc, size_t size, uint8_t* header, + size_t pos); +uint8_t LinearPredictICCValue(const uint8_t* data, size_t start, size_t i, + size_t stride, size_t width, int order); +size_t ICCANSContext(size_t i, size_t b1, size_t b2); + +} // namespace jxl + +#endif // LIB_JXL_ICC_CODEC_COMMON_H_ diff --git a/third_party/jpeg-xl/lib/jxl/icc_codec_test.cc b/third_party/jpeg-xl/lib/jxl/icc_codec_test.cc new file mode 100644 index 0000000000..743aa9a30e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/icc_codec_test.cc @@ -0,0 +1,209 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/icc_codec.h" + +#include <cstdint> +#include <string> +#include <vector> + +#include "lib/jxl/base/span.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/enc_icc_codec.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +void TestProfile(const IccBytes& icc) { + BitWriter writer; + ASSERT_TRUE(WriteICC(icc, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + std::vector<uint8_t> dec; + BitReader reader(writer.GetSpan()); + ASSERT_TRUE(test::ReadICC(&reader, &dec)); + ASSERT_TRUE(reader.Close()); + EXPECT_EQ(icc.size(), dec.size()); + if (icc.size() == dec.size()) { + for (size_t i = 0; i < icc.size(); i++) { + EXPECT_EQ(icc[i], dec[i]); + if (icc[i] != dec[i]) break; // One output is enough + } + } +} + +void TestProfile(const std::string& icc) { + IccBytes data; + Bytes(icc).AppendTo(&data); + TestProfile(data); +} + +// Valid profile from one of the images output by the decoder. +static const unsigned char kTestProfile[] = { + 0x00, 0x00, 0x03, 0x80, 0x6c, 0x63, 0x6d, 0x73, 0x04, 0x30, 0x00, 0x00, + 0x6d, 0x6e, 0x74, 0x72, 0x52, 0x47, 0x42, 0x20, 0x58, 0x59, 0x5a, 0x20, + 0x07, 0xe3, 0x00, 0x04, 0x00, 0x1d, 0x00, 0x0f, 0x00, 0x32, 0x00, 0x2e, + 0x61, 0x63, 0x73, 0x70, 0x41, 0x50, 0x50, 0x4c, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0xf6, 0xd6, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xd3, 0x2d, 0x6c, 0x63, 0x6d, 0x73, + 0x5f, 0x07, 0x0d, 0x3e, 0x4d, 0x32, 0xf2, 0x6e, 0x5d, 0x77, 0x26, 0xcc, + 0x23, 0xb0, 0x6a, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, + 0x64, 0x65, 0x73, 0x63, 0x00, 0x00, 0x01, 0x20, 0x00, 0x00, 0x00, 0x42, + 0x63, 0x70, 0x72, 0x74, 0x00, 0x00, 0x01, 0x64, 0x00, 0x00, 0x01, 0x00, + 0x77, 0x74, 0x70, 0x74, 0x00, 0x00, 0x02, 0x64, 0x00, 0x00, 0x00, 0x14, + 0x63, 0x68, 0x61, 0x64, 0x00, 0x00, 0x02, 0x78, 0x00, 0x00, 0x00, 0x2c, + 0x72, 0x58, 0x59, 0x5a, 0x00, 0x00, 0x02, 0xa4, 0x00, 0x00, 0x00, 0x14, + 0x62, 0x58, 0x59, 0x5a, 0x00, 0x00, 0x02, 0xb8, 0x00, 0x00, 0x00, 0x14, + 0x67, 0x58, 0x59, 0x5a, 0x00, 0x00, 0x02, 0xcc, 0x00, 0x00, 0x00, 0x14, + 0x72, 0x54, 0x52, 0x43, 0x00, 0x00, 0x02, 0xe0, 0x00, 0x00, 0x00, 0x20, + 0x67, 0x54, 0x52, 0x43, 0x00, 0x00, 0x02, 0xe0, 0x00, 0x00, 0x00, 0x20, + 0x62, 0x54, 0x52, 0x43, 0x00, 0x00, 0x02, 0xe0, 0x00, 0x00, 0x00, 0x20, + 0x63, 0x68, 0x72, 0x6d, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x24, + 0x64, 0x6d, 0x6e, 0x64, 0x00, 0x00, 0x03, 0x24, 0x00, 0x00, 0x00, 0x28, + 0x64, 0x6d, 0x64, 0x64, 0x00, 0x00, 0x03, 0x4c, 0x00, 0x00, 0x00, 0x32, + 0x6d, 0x6c, 0x75, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x0c, 0x65, 0x6e, 0x55, 0x53, 0x00, 0x00, 0x00, 0x26, + 0x00, 0x00, 0x00, 0x1c, 0x00, 0x52, 0x00, 0x47, 0x00, 0x42, 0x00, 0x5f, + 0x00, 0x44, 0x00, 0x36, 0x00, 0x35, 0x00, 0x5f, 0x00, 0x53, 0x00, 0x52, + 0x00, 0x47, 0x00, 0x5f, 0x00, 0x52, 0x00, 0x65, 0x00, 0x6c, 0x00, 0x5f, + 0x00, 0x37, 0x00, 0x30, 0x00, 0x39, 0x00, 0x00, 0x6d, 0x6c, 0x75, 0x63, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, + 0x65, 0x6e, 0x55, 0x53, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x00, 0x00, 0x1c, + 0x00, 0x43, 0x00, 0x6f, 0x00, 0x70, 0x00, 0x79, 0x00, 0x72, 0x00, 0x69, + 0x00, 0x67, 0x00, 0x68, 0x00, 0x74, 0x00, 0x20, 0x00, 0x32, 0x00, 0x30, + 0x00, 0x31, 0x00, 0x38, 0x00, 0x20, 0x00, 0x47, 0x00, 0x6f, 0x00, 0x6f, + 0x00, 0x67, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x20, 0x00, 0x4c, 0x00, 0x4c, + 0x00, 0x43, 0x00, 0x2c, 0x00, 0x20, 0x00, 0x43, 0x00, 0x43, 0x00, 0x2d, + 0x00, 0x42, 0x00, 0x59, 0x00, 0x2d, 0x00, 0x53, 0x00, 0x41, 0x00, 0x20, + 0x00, 0x33, 0x00, 0x2e, 0x00, 0x30, 0x00, 0x20, 0x00, 0x55, 0x00, 0x6e, + 0x00, 0x70, 0x00, 0x6f, 0x00, 0x72, 0x00, 0x74, 0x00, 0x65, 0x00, 0x64, + 0x00, 0x20, 0x00, 0x6c, 0x00, 0x69, 0x00, 0x63, 0x00, 0x65, 0x00, 0x6e, + 0x00, 0x73, 0x00, 0x65, 0x00, 0x28, 0x00, 0x68, 0x00, 0x74, 0x00, 0x74, + 0x00, 0x70, 0x00, 0x73, 0x00, 0x3a, 0x00, 0x2f, 0x00, 0x2f, 0x00, 0x63, + 0x00, 0x72, 0x00, 0x65, 0x00, 0x61, 0x00, 0x74, 0x00, 0x69, 0x00, 0x76, + 0x00, 0x65, 0x00, 0x63, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x6d, 0x00, 0x6f, + 0x00, 0x6e, 0x00, 0x73, 0x00, 0x2e, 0x00, 0x6f, 0x00, 0x72, 0x00, 0x67, + 0x00, 0x2f, 0x00, 0x6c, 0x00, 0x69, 0x00, 0x63, 0x00, 0x65, 0x00, 0x6e, + 0x00, 0x73, 0x00, 0x65, 0x00, 0x73, 0x00, 0x2f, 0x00, 0x62, 0x00, 0x79, + 0x00, 0x2d, 0x00, 0x73, 0x00, 0x61, 0x00, 0x2f, 0x00, 0x33, 0x00, 0x2e, + 0x00, 0x30, 0x00, 0x2f, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x67, 0x00, 0x61, + 0x00, 0x6c, 0x00, 0x63, 0x00, 0x6f, 0x00, 0x64, 0x00, 0x65, 0x00, 0x29, + 0x58, 0x59, 0x5a, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf6, 0xd6, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0xd3, 0x2d, 0x73, 0x66, 0x33, 0x32, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x0c, 0x42, 0x00, 0x00, 0x05, 0xde, + 0xff, 0xff, 0xf3, 0x25, 0x00, 0x00, 0x07, 0x93, 0x00, 0x00, 0xfd, 0x90, + 0xff, 0xff, 0xfb, 0xa1, 0xff, 0xff, 0xfd, 0xa2, 0x00, 0x00, 0x03, 0xdc, + 0x00, 0x00, 0xc0, 0x6e, 0x58, 0x59, 0x5a, 0x20, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x6f, 0xa0, 0x00, 0x00, 0x38, 0xf5, 0x00, 0x00, 0x03, 0x90, + 0x58, 0x59, 0x5a, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x24, 0x9f, + 0x00, 0x00, 0x0f, 0x84, 0x00, 0x00, 0xb6, 0xc4, 0x58, 0x59, 0x5a, 0x20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x62, 0x97, 0x00, 0x00, 0xb7, 0x87, + 0x00, 0x00, 0x18, 0xd9, 0x70, 0x61, 0x72, 0x61, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x38, 0xe4, 0x00, 0x00, 0xe8, 0xf0, + 0x00, 0x00, 0x17, 0x10, 0x00, 0x00, 0x38, 0xe4, 0x00, 0x00, 0x14, 0xbc, + 0x63, 0x68, 0x72, 0x6d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, + 0x00, 0x00, 0xa3, 0xd7, 0x00, 0x00, 0x54, 0x7c, 0x00, 0x00, 0x4c, 0xcd, + 0x00, 0x00, 0x99, 0x9a, 0x00, 0x00, 0x26, 0x67, 0x00, 0x00, 0x0f, 0x5c, + 0x6d, 0x6c, 0x75, 0x63, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x0c, 0x65, 0x6e, 0x55, 0x53, 0x00, 0x00, 0x00, 0x0c, + 0x00, 0x00, 0x00, 0x1c, 0x00, 0x47, 0x00, 0x6f, 0x00, 0x6f, 0x00, 0x67, + 0x00, 0x6c, 0x00, 0x65, 0x6d, 0x6c, 0x75, 0x63, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x65, 0x6e, 0x55, 0x53, + 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x49, 0x00, 0x6d, + 0x00, 0x61, 0x00, 0x67, 0x00, 0x65, 0x00, 0x20, 0x00, 0x63, 0x00, 0x6f, + 0x00, 0x64, 0x00, 0x65, 0x00, 0x63, 0x00, 0x00, +}; + +} // namespace + +TEST(IccCodecTest, Icc) { + // Empty string cannot be tested, encoder checks against writing it. + TestProfile("a"); + TestProfile("ab"); + TestProfile("aaaa"); + + { + // Exactly the ICC header size + IccBytes profile(128); + for (size_t i = 0; i < 128; i++) { + profile[i] = 0; + } + TestProfile(profile); + } + + { + IccBytes profile; + Bytes(kTestProfile, sizeof(kTestProfile)).AppendTo(&profile); + TestProfile(profile); + } + + // Test substrings of full profile + { + IccBytes profile; + for (size_t i = 0; i <= 256; i++) { + profile.push_back(kTestProfile[i]); + TestProfile(profile); + } + } +} + +// kTestProfile after encoding with the ICC codec +static const unsigned char kEncodedTestProfile[] = { + 0x1f, 0x8b, 0x1, 0x13, 0x10, 0x0, 0x0, 0x0, 0x20, 0x4c, 0xcc, 0x3, + 0xe7, 0xa0, 0xa5, 0xa2, 0x90, 0xa4, 0x27, 0xe8, 0x79, 0x1d, 0xe3, 0x26, + 0x57, 0x54, 0xef, 0x0, 0xe8, 0x97, 0x2, 0xce, 0xa1, 0xd7, 0x85, 0x16, + 0xb4, 0x29, 0x94, 0x58, 0xf2, 0x56, 0xc0, 0x76, 0xea, 0x23, 0xec, 0x7c, + 0x73, 0x51, 0x41, 0x40, 0x23, 0x21, 0x95, 0x4, 0x75, 0x12, 0xc9, 0xcc, + 0x16, 0xbd, 0xb6, 0x99, 0xad, 0xf8, 0x75, 0x35, 0xb6, 0x42, 0xae, 0xae, + 0xae, 0x86, 0x56, 0xf8, 0xcc, 0x16, 0x30, 0xb3, 0x45, 0xad, 0xd, 0x40, + 0xd6, 0xd1, 0xd6, 0x99, 0x40, 0xbe, 0xe2, 0xdc, 0x31, 0x7, 0xa6, 0xb9, + 0x27, 0x92, 0x38, 0x0, 0x3, 0x5e, 0x2c, 0xbe, 0xe6, 0xfb, 0x19, 0xbf, + 0xf3, 0x6d, 0xbc, 0x4d, 0x64, 0xe5, 0xba, 0x76, 0xde, 0x31, 0x65, 0x66, + 0x14, 0xa6, 0x3a, 0xc5, 0x8f, 0xb1, 0xb4, 0xba, 0x1f, 0xb1, 0xb8, 0xd4, + 0x75, 0xba, 0x18, 0x86, 0x95, 0x3c, 0x26, 0xf6, 0x25, 0x62, 0x53, 0xfd, + 0x9c, 0x94, 0x76, 0xf6, 0x95, 0x2c, 0xb1, 0xfd, 0xdc, 0xc0, 0xe4, 0x3f, + 0xb3, 0xff, 0x67, 0xde, 0xd5, 0x94, 0xcc, 0xb0, 0x83, 0x2f, 0x28, 0x93, + 0x92, 0x3, 0xa1, 0x41, 0x64, 0x60, 0x62, 0x70, 0x80, 0x87, 0xaf, 0xe7, + 0x60, 0x4a, 0x20, 0x23, 0xb3, 0x11, 0x7, 0x38, 0x38, 0xd4, 0xa, 0x66, + 0xb5, 0x93, 0x41, 0x90, 0x19, 0x17, 0x18, 0x60, 0xa5, 0xb, 0x7a, 0x24, + 0xaa, 0x20, 0x81, 0xac, 0xa9, 0xa1, 0x70, 0xa6, 0x12, 0x8a, 0x4a, 0xa3, + 0xa0, 0xf9, 0x9a, 0x97, 0xe7, 0xa8, 0xac, 0x8, 0xa8, 0xc4, 0x2a, 0x86, + 0xa7, 0x69, 0x1e, 0x67, 0xe6, 0xbe, 0xa4, 0xd3, 0xff, 0x91, 0x61, 0xf6, + 0x8a, 0xe6, 0xb5, 0xb3, 0x61, 0x9f, 0x19, 0x17, 0x98, 0x27, 0x6b, 0xe9, + 0x8, 0x98, 0xe1, 0x21, 0x4a, 0x9, 0xb5, 0xd7, 0xca, 0xfa, 0x94, 0xd0, + 0x69, 0x1a, 0xeb, 0x52, 0x1, 0x4e, 0xf5, 0xf6, 0xdf, 0x7f, 0xe7, 0x29, + 0x70, 0xee, 0x4, 0xda, 0x2f, 0xa4, 0xff, 0xfe, 0xbb, 0x6f, 0xa8, 0xff, + 0xfe, 0xdb, 0xaf, 0x8, 0xf6, 0x72, 0xa1, 0x40, 0x5d, 0xf0, 0x2d, 0x8, + 0x82, 0x5b, 0x87, 0xbd, 0x10, 0x8, 0xe9, 0x7, 0xee, 0x4b, 0x80, 0xda, + 0x4a, 0x4, 0xc5, 0x5e, 0xa0, 0xb7, 0x1e, 0x60, 0xb0, 0x59, 0x76, 0x60, + 0xb, 0x2e, 0x19, 0x8a, 0x2e, 0x1c, 0xe6, 0x6, 0x20, 0xb8, 0x64, 0x18, + 0x2a, 0xcf, 0x51, 0x94, 0xd4, 0xee, 0xc3, 0xfe, 0x39, 0x74, 0xd4, 0x2b, + 0x48, 0xc9, 0x83, 0x4c, 0x9b, 0xd0, 0x4c, 0x35, 0x10, 0xe3, 0x9, 0xf7, + 0x72, 0xf0, 0x7a, 0xe, 0xbf, 0x7d, 0x36, 0x2e, 0x19, 0x7e, 0x3f, 0xc, + 0xf7, 0x93, 0xe7, 0xf4, 0x1d, 0x32, 0xc6, 0xb0, 0x89, 0xad, 0xe0, 0x28, + 0xc1, 0xa7, 0x59, 0xe3, 0x0, +}; + +// Tests that the decoded kEncodedTestProfile matches kTestProfile. +TEST(IccCodecTest, EncodedIccProfile) { + jxl::BitReader reader( + jxl::Bytes(kEncodedTestProfile, sizeof(kEncodedTestProfile))); + std::vector<uint8_t> dec; + ASSERT_TRUE(test::ReadICC(&reader, &dec)); + ASSERT_TRUE(reader.Close()); + EXPECT_EQ(sizeof(kTestProfile), dec.size()); + if (sizeof(kTestProfile) == dec.size()) { + for (size_t i = 0; i < dec.size(); i++) { + EXPECT_EQ(kTestProfile[i], dec[i]); + if (kTestProfile[i] != dec[i]) break; // One output is enough + } + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/image.cc b/third_party/jpeg-xl/lib/jxl/image.cc new file mode 100644 index 0000000000..382c957799 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image.cc @@ -0,0 +1,205 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image.h" + +#include <algorithm> // swap + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/image.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { + +namespace HWY_NAMESPACE { +size_t GetVectorSize() { return HWY_LANES(uint8_t); } +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE + +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +namespace { + +HWY_EXPORT(GetVectorSize); // Local function. + +// Returns distance [bytes] between the start of two consecutive rows, a +// multiple of vector/cache line size but NOT CacheAligned::kAlias - see below. +size_t BytesPerRow(const size_t xsize, const size_t sizeof_t) { + const size_t vec_size = VectorSize(); + size_t valid_bytes = xsize * sizeof_t; + + // Allow unaligned accesses starting at the last valid value - this may raise + // msan errors unless the user calls InitializePaddingForUnalignedAccesses. + // Skip for the scalar case because no extra lanes will be loaded. + if (vec_size != 0) { + valid_bytes += vec_size - sizeof_t; + } + + // Round up to vector and cache line size. + const size_t align = std::max(vec_size, CacheAligned::kAlignment); + size_t bytes_per_row = RoundUpTo(valid_bytes, align); + + // During the lengthy window before writes are committed to memory, CPUs + // guard against read after write hazards by checking the address, but + // only the lower 11 bits. We avoid a false dependency between writes to + // consecutive rows by ensuring their sizes are not multiples of 2 KiB. + // Avoid2K prevents the same problem for the planes of an Image3. + if (bytes_per_row % CacheAligned::kAlias == 0) { + bytes_per_row += align; + } + + JXL_ASSERT(bytes_per_row % align == 0); + return bytes_per_row; +} + +} // namespace + +size_t VectorSize() { + static size_t bytes = HWY_DYNAMIC_DISPATCH(GetVectorSize)(); + return bytes; +} + +PlaneBase::PlaneBase(const size_t xsize, const size_t ysize, + const size_t sizeof_t) + : xsize_(static_cast<uint32_t>(xsize)), + ysize_(static_cast<uint32_t>(ysize)), + orig_xsize_(static_cast<uint32_t>(xsize)), + orig_ysize_(static_cast<uint32_t>(ysize)) { + JXL_CHECK(xsize == xsize_); + JXL_CHECK(ysize == ysize_); + + JXL_ASSERT(sizeof_t == 1 || sizeof_t == 2 || sizeof_t == 4 || sizeof_t == 8); + + bytes_per_row_ = 0; + // Dimensions can be zero, e.g. for lazily-allocated images. Only allocate + // if nonzero, because "zero" bytes still have padding/bookkeeping overhead. + if (xsize != 0 && ysize != 0) { + bytes_per_row_ = BytesPerRow(xsize, sizeof_t); + bytes_ = AllocateArray(bytes_per_row_ * ysize); + JXL_CHECK(bytes_.get()); + InitializePadding(sizeof_t, Padding::kRoundUp); + } +} + +void PlaneBase::InitializePadding(const size_t sizeof_t, Padding padding) { +#if defined(MEMORY_SANITIZER) || HWY_IDE + if (xsize_ == 0 || ysize_ == 0) return; + + const size_t vec_size = VectorSize(); + if (vec_size == 0) return; // Scalar mode: no padding needed + + const size_t valid_size = xsize_ * sizeof_t; + const size_t initialize_size = padding == Padding::kRoundUp + ? RoundUpTo(valid_size, vec_size) + : valid_size + vec_size - sizeof_t; + if (valid_size == initialize_size) return; + + for (size_t y = 0; y < ysize_; ++y) { + uint8_t* JXL_RESTRICT row = static_cast<uint8_t*>(VoidRow(y)); +#if defined(__clang__) && \ + ((!defined(__apple_build_version__) && __clang_major__ <= 6) || \ + (defined(__apple_build_version__) && \ + __apple_build_version__ <= 10001145)) + // There's a bug in msan in clang-6 when handling AVX2 operations. This + // workaround allows tests to pass on msan, although it is slower and + // prevents msan warnings from uninitialized images. + std::fill(row, msan::kSanitizerSentinelByte, initialize_size); +#else + memset(row + valid_size, msan::kSanitizerSentinelByte, + initialize_size - valid_size); +#endif // clang6 + } +#endif // MEMORY_SANITIZER +} + +void PlaneBase::Swap(PlaneBase& other) { + std::swap(xsize_, other.xsize_); + std::swap(ysize_, other.ysize_); + std::swap(orig_xsize_, other.orig_xsize_); + std::swap(orig_ysize_, other.orig_ysize_); + std::swap(bytes_per_row_, other.bytes_per_row_); + std::swap(bytes_, other.bytes_); +} + +void PadImageToBlockMultipleInPlace(Image3F* JXL_RESTRICT in, + size_t block_dim) { + const size_t xsize_orig = in->xsize(); + const size_t ysize_orig = in->ysize(); + const size_t xsize = RoundUpTo(xsize_orig, block_dim); + const size_t ysize = RoundUpTo(ysize_orig, block_dim); + // Expands image size to the originally-allocated size. + in->ShrinkTo(xsize, ysize); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ysize_orig; y++) { + float* JXL_RESTRICT row = in->PlaneRow(c, y); + for (size_t x = xsize_orig; x < xsize; x++) { + row[x] = row[xsize_orig - 1]; + } + } + const float* JXL_RESTRICT row_src = in->ConstPlaneRow(c, ysize_orig - 1); + for (size_t y = ysize_orig; y < ysize; y++) { + memcpy(in->PlaneRow(c, y), row_src, xsize * sizeof(float)); + } + } +} + +static void DownsampleImage(const ImageF& input, size_t factor, + ImageF* output) { + JXL_ASSERT(factor != 1); + output->ShrinkTo(DivCeil(input.xsize(), factor), + DivCeil(input.ysize(), factor)); + size_t in_stride = input.PixelsPerRow(); + for (size_t y = 0; y < output->ysize(); y++) { + float* row_out = output->Row(y); + const float* row_in = input.Row(factor * y); + for (size_t x = 0; x < output->xsize(); x++) { + size_t cnt = 0; + float sum = 0; + for (size_t iy = 0; iy < factor && iy + factor * y < input.ysize(); + iy++) { + for (size_t ix = 0; ix < factor && ix + factor * x < input.xsize(); + ix++) { + sum += row_in[iy * in_stride + x * factor + ix]; + cnt++; + } + } + row_out[x] = sum / cnt; + } + } +} + +void DownsampleImage(ImageF* image, size_t factor) { + // Allocate extra space to avoid a reallocation when padding. + ImageF downsampled(DivCeil(image->xsize(), factor) + kBlockDim, + DivCeil(image->ysize(), factor) + kBlockDim); + DownsampleImage(*image, factor, &downsampled); + *image = std::move(downsampled); +} + +void DownsampleImage(Image3F* opsin, size_t factor) { + JXL_ASSERT(factor != 1); + // Allocate extra space to avoid a reallocation when padding. + Image3F downsampled(DivCeil(opsin->xsize(), factor) + kBlockDim, + DivCeil(opsin->ysize(), factor) + kBlockDim); + downsampled.ShrinkTo(downsampled.xsize() - kBlockDim, + downsampled.ysize() - kBlockDim); + for (size_t c = 0; c < 3; c++) { + DownsampleImage(opsin->Plane(c), factor, &downsampled.Plane(c)); + } + *opsin = std::move(downsampled); +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/image.h b/third_party/jpeg-xl/lib/jxl/image.h new file mode 100644 index 0000000000..98c387bb77 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image.h @@ -0,0 +1,509 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_IMAGE_H_ +#define LIB_JXL_IMAGE_H_ + +// SIMD/multicore-friendly planar image representation with row accessors. + +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) +#include <inttypes.h> +#endif + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include <algorithm> +#include <sstream> +#include <string> +#include <utility> // std::move + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/cache_aligned.h" + +namespace jxl { + +// Helper function to create rows that are multiples of SIMD vector size. +size_t VectorSize(); + +// Type-independent parts of Plane<> - reduces code duplication and facilitates +// moving member function implementations to cc file. +struct PlaneBase { + PlaneBase() + : xsize_(0), + ysize_(0), + orig_xsize_(0), + orig_ysize_(0), + bytes_per_row_(0), + bytes_(nullptr) {} + PlaneBase(size_t xsize, size_t ysize, size_t sizeof_t); + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo() instead. + PlaneBase(const PlaneBase& other) = delete; + PlaneBase& operator=(const PlaneBase& other) = delete; + + // Move constructor (required for returning Image from function) + PlaneBase(PlaneBase&& other) noexcept = default; + + // Move assignment (required for std::vector) + PlaneBase& operator=(PlaneBase&& other) noexcept = default; + + void Swap(PlaneBase& other); + + // Useful for pre-allocating image with some padding for alignment purposes + // and later reporting the actual valid dimensions. May also be used to + // un-shrink the image. Caller is responsible for ensuring xsize/ysize are <= + // the original dimensions. + void ShrinkTo(const size_t xsize, const size_t ysize) { + JXL_CHECK(xsize <= orig_xsize_); + JXL_CHECK(ysize <= orig_ysize_); + xsize_ = static_cast<uint32_t>(xsize); + ysize_ = static_cast<uint32_t>(ysize); + // NOTE: we can't recompute bytes_per_row for more compact storage and + // better locality because that would invalidate the image contents. + } + + // How many pixels. + JXL_INLINE size_t xsize() const { return xsize_; } + JXL_INLINE size_t ysize() const { return ysize_; } + + // NOTE: do not use this for copying rows - the valid xsize may be much less. + JXL_INLINE size_t bytes_per_row() const { return bytes_per_row_; } + + // Raw access to byte contents, for interfacing with other libraries. + // Unsigned char instead of char to avoid surprises (sign extension). + JXL_INLINE uint8_t* bytes() { + void* p = bytes_.get(); + return static_cast<uint8_t * JXL_RESTRICT>(JXL_ASSUME_ALIGNED(p, 64)); + } + JXL_INLINE const uint8_t* bytes() const { + const void* p = bytes_.get(); + return static_cast<const uint8_t * JXL_RESTRICT>(JXL_ASSUME_ALIGNED(p, 64)); + } + + protected: + // Returns pointer to the start of a row. + JXL_INLINE void* VoidRow(const size_t y) const { +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) + if (y >= ysize_) { + JXL_ABORT("Row(%" PRIu64 ") in (%u x %u) image\n", (uint64_t)y, xsize_, + ysize_); + } +#endif + + void* row = bytes_.get() + y * bytes_per_row_; + return JXL_ASSUME_ALIGNED(row, 64); + } + + enum class Padding { + // Allow Load(d, row + x) for x = 0; x < xsize(); x += Lanes(d). Default. + kRoundUp, + // Allow LoadU(d, row + x) for x = xsize() - 1. This requires an extra + // vector to be initialized. If done by default, this would suppress + // legitimate msan warnings. We therefore require users to explicitly call + // InitializePadding before using unaligned loads (e.g. convolution). + kUnaligned + }; + + // Initializes the minimum bytes required to suppress msan warnings from + // legitimate (according to Padding mode) vector loads/stores on the right + // border, where some lanes are uninitialized and assumed to be unused. + void InitializePadding(size_t sizeof_t, Padding padding); + + // (Members are non-const to enable assignment during move-assignment.) + uint32_t xsize_; // In valid pixels, not including any padding. + uint32_t ysize_; + uint32_t orig_xsize_; + uint32_t orig_ysize_; + size_t bytes_per_row_; // Includes padding. + CacheAlignedUniquePtr bytes_; +}; + +// Single channel, aligned rows separated by padding. T must be POD. +// +// 'Single channel' (one 2D array per channel) simplifies vectorization +// (repeating the same operation on multiple adjacent components) without the +// complexity of a hybrid layout (8 R, 8 G, 8 B, ...). In particular, clients +// can easily iterate over all components in a row and Image requires no +// knowledge of the pixel format beyond the component type "T". +// +// 'Aligned' means each row is aligned to the L1 cache line size. This prevents +// false sharing between two threads operating on adjacent rows. +// +// 'Padding' is still relevant because vectors could potentially be larger than +// a cache line. By rounding up row sizes to the vector size, we allow +// reading/writing ALIGNED vectors whose first lane is a valid sample. This +// avoids needing a separate loop to handle remaining unaligned lanes. +// +// This image layout could also be achieved with a vector and a row accessor +// function, but a class wrapper with support for "deleter" allows wrapping +// existing memory allocated by clients without copying the pixels. It also +// provides convenient accessors for xsize/ysize, which shortens function +// argument lists. Supports move-construction so it can be stored in containers. +template <typename ComponentType> +class Plane : public PlaneBase { + public: + using T = ComponentType; + static constexpr size_t kNumPlanes = 1; + + Plane() = default; + Plane(const size_t xsize, const size_t ysize) + : PlaneBase(xsize, ysize, sizeof(T)) {} + + void InitializePaddingForUnalignedAccesses() { + InitializePadding(sizeof(T), Padding::kUnaligned); + } + + JXL_INLINE T* Row(const size_t y) { return static_cast<T*>(VoidRow(y)); } + + // Returns pointer to const (see above). + JXL_INLINE const T* Row(const size_t y) const { + return static_cast<const T*>(VoidRow(y)); + } + + // Documents that the access is const. + JXL_INLINE const T* ConstRow(const size_t y) const { + return static_cast<const T*>(VoidRow(y)); + } + + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must + // NOT be used to determine xsize. + JXL_INLINE intptr_t PixelsPerRow() const { + return static_cast<intptr_t>(bytes_per_row_ / sizeof(T)); + } +}; + +using ImageSB = Plane<int8_t>; +using ImageB = Plane<uint8_t>; +using ImageS = Plane<int16_t>; // signed integer or half-float +using ImageU = Plane<uint16_t>; +using ImageI = Plane<int32_t>; +using ImageF = Plane<float>; +using ImageD = Plane<double>; + +// Also works for Image3 and mixed argument types. +template <class Image1, class Image2> +bool SameSize(const Image1& image1, const Image2& image2) { + return image1.xsize() == image2.xsize() && image1.ysize() == image2.ysize(); +} + +template <typename T> +class Image3; + +// Rectangular region in image(s). Factoring this out of Image instead of +// shifting the pointer by x0/y0 allows this to apply to multiple images with +// different resolutions (e.g. color transform and quantization field). +// Can compare using SameSize(rect1, rect2). +template <typename T> +class RectT { + public: + // Most windows are xsize_max * ysize_max, except those on the borders where + // begin + size_max > end. + constexpr RectT(T xbegin, T ybegin, size_t xsize_max, size_t ysize_max, + T xend, T yend) + : x0_(xbegin), + y0_(ybegin), + xsize_(ClampedSize(xbegin, xsize_max, xend)), + ysize_(ClampedSize(ybegin, ysize_max, yend)) {} + + // Construct with origin and known size (typically from another Rect). + constexpr RectT(T xbegin, T ybegin, size_t xsize, size_t ysize) + : x0_(xbegin), y0_(ybegin), xsize_(xsize), ysize_(ysize) {} + + // Construct a rect that covers a whole image/plane/ImageBundle etc. + template <typename ImageT> + explicit RectT(const ImageT& image) + : RectT(0, 0, image.xsize(), image.ysize()) {} + + RectT() : RectT(0, 0, 0, 0) {} + + RectT(const RectT&) = default; + RectT& operator=(const RectT&) = default; + + // Construct a subrect that resides in an image/plane/ImageBundle etc. + template <typename ImageT> + RectT Crop(const ImageT& image) const { + return Intersection(RectT(image)); + } + + // Construct a subrect that resides in the [0, ysize) x [0, xsize) region of + // the current rect. + RectT Crop(size_t area_xsize, size_t area_ysize) const { + return Intersection(RectT(0, 0, area_xsize, area_ysize)); + } + + // Returns a rect that only contains `num` lines with offset `y` from `y0()`. + RectT Lines(size_t y, size_t num) const { + JXL_DASSERT(y + num <= ysize_); + return RectT(x0_, y0_ + y, xsize_, num); + } + + RectT Line(size_t y) const { return Lines(y, 1); } + + JXL_MUST_USE_RESULT RectT Intersection(const RectT& other) const { + return RectT(std::max(x0_, other.x0_), std::max(y0_, other.y0_), xsize_, + ysize_, std::min(x1(), other.x1()), + std::min(y1(), other.y1())); + } + + JXL_MUST_USE_RESULT RectT Translate(int64_t x_offset, + int64_t y_offset) const { + return RectT(x0_ + x_offset, y0_ + y_offset, xsize_, ysize_); + } + + template <typename V> + V* Row(Plane<V>* image, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image->Row(y + y0_) + x0_; + } + + template <typename V> + const V* Row(const Plane<V>* image, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image->Row(y + y0_) + x0_; + } + + template <typename V> + V* PlaneRow(Image3<V>* image, const size_t c, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image->PlaneRow(c, y + y0_) + x0_; + } + + template <typename V> + const V* ConstRow(const Plane<V>& image, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image.ConstRow(y + y0_) + x0_; + } + + template <typename V> + const V* ConstPlaneRow(const Image3<V>& image, size_t c, size_t y) const { + JXL_DASSERT(y + y0_ >= 0); + return image.ConstPlaneRow(c, y + y0_) + x0_; + } + + bool IsInside(const RectT& other) const { + return x0_ >= other.x0() && x1() <= other.x1() && y0_ >= other.y0() && + y1() <= other.y1(); + } + + // Returns true if this Rect fully resides in the given image. ImageT could be + // Plane<T> or Image3<T>; however if ImageT is Rect, results are nonsensical. + template <class ImageT> + bool IsInside(const ImageT& image) const { + return IsInside(RectT(image)); + } + + T x0() const { return x0_; } + T y0() const { return y0_; } + size_t xsize() const { return xsize_; } + size_t ysize() const { return ysize_; } + T x1() const { return x0_ + xsize_; } + T y1() const { return y0_ + ysize_; } + + RectT<T> ShiftLeft(size_t shiftx, size_t shifty) const { + return RectT<T>(x0_ * (1 << shiftx), y0_ * (1 << shifty), xsize_ << shiftx, + ysize_ << shifty); + } + RectT<T> ShiftLeft(size_t shift) const { return ShiftLeft(shift, shift); } + + // Requires x0(), y0() to be multiples of 1<<shiftx, 1<<shifty. + RectT<T> CeilShiftRight(size_t shiftx, size_t shifty) const { + JXL_ASSERT(x0_ % (1 << shiftx) == 0); + JXL_ASSERT(y0_ % (1 << shifty) == 0); + return RectT<T>(x0_ / (1 << shiftx), y0_ / (1 << shifty), + DivCeil(xsize_, T{1} << shiftx), + DivCeil(ysize_, T{1} << shifty)); + } + RectT<T> CeilShiftRight(std::pair<size_t, size_t> shift) const { + return CeilShiftRight(shift.first, shift.second); + } + RectT<T> CeilShiftRight(size_t shift) const { + return CeilShiftRight(shift, shift); + } + + RectT<T> Extend(T border, RectT<T> parent) const { + T new_x0 = x0() > parent.x0() + border ? x0() - border : parent.x0(); + T new_y0 = y0() > parent.y0() + border ? y0() - border : parent.y0(); + T new_x1 = x1() + border > parent.x1() ? parent.x1() : x1() + border; + T new_y1 = y1() + border > parent.y1() ? parent.y1() : y1() + border; + return RectT<T>(new_x0, new_y0, new_x1 - new_x0, new_y1 - new_y0); + } + + template <typename U> + RectT<U> As() const { + return RectT<U>(U(x0_), U(y0_), U(xsize_), U(ysize_)); + } + + private: + // Returns size_max, or whatever is left in [begin, end). + static constexpr size_t ClampedSize(T begin, size_t size_max, T end) { + return (static_cast<T>(begin + size_max) <= end) + ? size_max + : (end > begin ? end - begin : 0); + } + + T x0_; + T y0_; + + size_t xsize_; + size_t ysize_; +}; + +template <typename T> +std::string Description(RectT<T> r) { + std::ostringstream os; + os << "[" << r.x0() << ".." << r.x1() << ")x" + << "[" << r.y0() << ".." << r.y1() << ")"; + return os.str(); +} + +using Rect = RectT<size_t>; + +// Currently, we abuse Image to either refer to an image that owns its storage +// or one that doesn't. In similar vein, we abuse Image* function parameters to +// either mean "assign to me" or "fill the provided image with data". +// Hopefully, the "assign to me" meaning will go away and most images in the +// codebase will not be backed by own storage. When this happens we can redesign +// Image to be a non-storage-holding view class and introduce BackedImage in +// those places that actually need it. + +// NOTE: we can't use Image as a view because invariants are violated +// (alignment and the presence of padding before/after each "row"). + +// A bundle of 3 same-sized images. Typically constructed by moving from three +// rvalue references to Image. To overwrite an existing Image3 using +// single-channel producers, we also need access to Image*. Constructing +// temporary non-owning Image pointing to one plane of an existing Image3 risks +// dangling references, especially if the wrapper is moved. Therefore, we +// store an array of Image (which are compact enough that size is not a concern) +// and provide Plane+Row accessors. +template <typename ComponentType> +class Image3 { + public: + using T = ComponentType; + using PlaneT = jxl::Plane<T>; + static constexpr size_t kNumPlanes = 3; + + Image3() : planes_{PlaneT(), PlaneT(), PlaneT()} {} + + Image3(const size_t xsize, const size_t ysize) + : planes_{PlaneT(xsize, ysize), PlaneT(xsize, ysize), + PlaneT(xsize, ysize)} {} + + Image3(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + } + + Image3(PlaneT&& plane0, PlaneT&& plane1, PlaneT&& plane2) { + JXL_CHECK(SameSize(plane0, plane1)); + JXL_CHECK(SameSize(plane0, plane2)); + planes_[0] = std::move(plane0); + planes_[1] = std::move(plane1); + planes_[2] = std::move(plane2); + } + + // Copy construction/assignment is forbidden to avoid inadvertent copies, + // which can be very expensive. Use CopyImageTo instead. + Image3(const Image3& other) = delete; + Image3& operator=(const Image3& other) = delete; + + Image3& operator=(Image3&& other) noexcept { + for (size_t i = 0; i < kNumPlanes; i++) { + planes_[i] = std::move(other.planes_[i]); + } + return *this; + } + + // Returns row pointer; usage: PlaneRow(idx_plane, y)[x] = val. + JXL_INLINE T* PlaneRow(const size_t c, const size_t y) { + // Custom implementation instead of calling planes_[c].Row ensures only a + // single multiplication is needed for PlaneRow(0..2, y). + PlaneRowBoundsCheck(c, y); + const size_t row_offset = y * planes_[0].bytes_per_row(); + void* row = planes_[c].bytes() + row_offset; + return static_cast<T * JXL_RESTRICT>(JXL_ASSUME_ALIGNED(row, 64)); + } + + // Returns const row pointer; usage: val = PlaneRow(idx_plane, y)[x]. + JXL_INLINE const T* PlaneRow(const size_t c, const size_t y) const { + PlaneRowBoundsCheck(c, y); + const size_t row_offset = y * planes_[0].bytes_per_row(); + const void* row = planes_[c].bytes() + row_offset; + return static_cast<const T * JXL_RESTRICT>(JXL_ASSUME_ALIGNED(row, 64)); + } + + // Returns const row pointer, even if called from a non-const Image3. + JXL_INLINE const T* ConstPlaneRow(const size_t c, const size_t y) const { + PlaneRowBoundsCheck(c, y); + return PlaneRow(c, y); + } + + JXL_INLINE const PlaneT& Plane(size_t idx) const { return planes_[idx]; } + + JXL_INLINE PlaneT& Plane(size_t idx) { return planes_[idx]; } + + void Swap(Image3& other) { + for (size_t c = 0; c < 3; ++c) { + other.planes_[c].Swap(planes_[c]); + } + } + + // Useful for pre-allocating image with some padding for alignment purposes + // and later reporting the actual valid dimensions. May also be used to + // un-shrink the image. Caller is responsible for ensuring xsize/ysize are <= + // the original dimensions. + void ShrinkTo(const size_t xsize, const size_t ysize) { + for (PlaneT& plane : planes_) { + plane.ShrinkTo(xsize, ysize); + } + } + + // Sizes of all three images are guaranteed to be equal. + JXL_INLINE size_t xsize() const { return planes_[0].xsize(); } + JXL_INLINE size_t ysize() const { return planes_[0].ysize(); } + // Returns offset [bytes] from one row to the next row of the same plane. + // WARNING: this must NOT be used to determine xsize, nor for copying rows - + // the valid xsize may be much less. + JXL_INLINE size_t bytes_per_row() const { return planes_[0].bytes_per_row(); } + // Returns number of pixels (some of which are padding) per row. Useful for + // computing other rows via pointer arithmetic. WARNING: this must NOT be used + // to determine xsize. + JXL_INLINE intptr_t PixelsPerRow() const { return planes_[0].PixelsPerRow(); } + + private: + void PlaneRowBoundsCheck(const size_t c, const size_t y) const { +#if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ + defined(THREAD_SANITIZER) + if (c >= kNumPlanes || y >= ysize()) { + JXL_ABORT("PlaneRow(%" PRIu64 ", %" PRIu64 ") in (%" PRIu64 " x %" PRIu64 + ") image\n", + static_cast<uint64_t>(c), static_cast<uint64_t>(y), + static_cast<uint64_t>(xsize()), static_cast<uint64_t>(ysize())); + } +#endif + } + + private: + PlaneT planes_[kNumPlanes]; +}; + +using Image3B = Image3<uint8_t>; +using Image3S = Image3<int16_t>; +using Image3U = Image3<uint16_t>; +using Image3I = Image3<int32_t>; +using Image3F = Image3<float>; +using Image3D = Image3<double>; + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/image_bundle.cc b/third_party/jpeg-xl/lib/jxl/image_bundle.cc new file mode 100644 index 0000000000..fc6d153e6a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image_bundle.cc @@ -0,0 +1,123 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image_bundle.h" + +#include <limits> +#include <utility> + +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/fields.h" + +namespace jxl { + +void ImageBundle::ShrinkTo(size_t xsize, size_t ysize) { + if (HasColor()) color_.ShrinkTo(xsize, ysize); + for (ImageF& ec : extra_channels_) { + ec.ShrinkTo(xsize, ysize); + } +} + +// Called by all other SetFrom*. +void ImageBundle::SetFromImage(Image3F&& color, + const ColorEncoding& c_current) { + JXL_CHECK(color.xsize() != 0 && color.ysize() != 0); + JXL_CHECK(metadata_->color_encoding.IsGray() == c_current.IsGray()); + color_ = std::move(color); + c_current_ = c_current; + VerifySizes(); +} + +void ImageBundle::VerifyMetadata() const { + JXL_CHECK(!c_current_.ICC().empty()); + JXL_CHECK(metadata_->color_encoding.IsGray() == IsGray()); + + if (metadata_->HasAlpha() && alpha().xsize() == 0) { + JXL_UNREACHABLE("MD alpha_bits %u IB alpha %" PRIuS " x %" PRIuS "\n", + metadata_->GetAlphaBits(), alpha().xsize(), + alpha().ysize()); + } + const uint32_t alpha_bits = metadata_->GetAlphaBits(); + JXL_CHECK(alpha_bits <= 32); + + // metadata_->num_extra_channels may temporarily differ from + // extra_channels_.size(), e.g. after SetAlpha. They are synced by the next + // call to VisitFields. +} + +void ImageBundle::VerifySizes() const { + const size_t xs = xsize(); + const size_t ys = ysize(); + + if (HasExtraChannels()) { + JXL_CHECK(xs != 0 && ys != 0); + for (const ImageF& ec : extra_channels_) { + JXL_CHECK(ec.xsize() == xs); + JXL_CHECK(ec.ysize() == ys); + } + } +} + +size_t ImageBundle::DetectRealBitdepth() const { + return metadata_->bit_depth.bits_per_sample; + + // TODO(lode): let this function return lower bit depth if possible, e.g. + // return 8 bits in case the original image came from a 16-bit PNG that + // was in fact representable as 8-bit PNG. Ensure that the implementation + // returns 16 if e.g. two consecutive 16-bit values appeared in the original + // image (such as 32768 and 32769), take into account that e.g. the values + // 3-bit can represent is not a superset of the values 2-bit can represent, + // and there may be slight imprecisions in the floating point image. +} + +const ImageF& ImageBundle::black() const { + JXL_ASSERT(HasBlack()); + const size_t ec = metadata_->Find(ExtraChannel::kBlack) - + metadata_->extra_channel_info.data(); + JXL_ASSERT(ec < extra_channels_.size()); + return extra_channels_[ec]; +} +const ImageF& ImageBundle::alpha() const { + JXL_ASSERT(HasAlpha()); + const size_t ec = metadata_->Find(ExtraChannel::kAlpha) - + metadata_->extra_channel_info.data(); + JXL_ASSERT(ec < extra_channels_.size()); + return extra_channels_[ec]; +} +ImageF* ImageBundle::alpha() { + JXL_ASSERT(HasAlpha()); + const size_t ec = metadata_->Find(ExtraChannel::kAlpha) - + metadata_->extra_channel_info.data(); + JXL_ASSERT(ec < extra_channels_.size()); + return &extra_channels_[ec]; +} + +void ImageBundle::SetAlpha(ImageF&& alpha) { + const ExtraChannelInfo* eci = metadata_->Find(ExtraChannel::kAlpha); + // Must call SetAlphaBits first, otherwise we don't know which channel index + JXL_CHECK(eci != nullptr); + JXL_CHECK(alpha.xsize() != 0 && alpha.ysize() != 0); + if (extra_channels_.size() < metadata_->extra_channel_info.size()) { + // TODO(jon): get rid of this case + extra_channels_.insert( + extra_channels_.begin() + (eci - metadata_->extra_channel_info.data()), + std::move(alpha)); + } else { + extra_channels_[eci - metadata_->extra_channel_info.data()] = + std::move(alpha); + } + // num_extra_channels is automatically set in visitor + VerifySizes(); +} + +void ImageBundle::SetExtraChannels(std::vector<ImageF>&& extra_channels) { + for (const ImageF& plane : extra_channels) { + JXL_CHECK(plane.xsize() != 0 && plane.ysize() != 0); + } + extra_channels_ = std::move(extra_channels); + VerifySizes(); +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/image_bundle.h b/third_party/jpeg-xl/lib/jxl/image_bundle.h new file mode 100644 index 0000000000..2eea496d5e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image_bundle.h @@ -0,0 +1,252 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_IMAGE_BUNDLE_H_ +#define LIB_JXL_IMAGE_BUNDLE_H_ + +// The main image or frame consists of a bundle of associated images. + +#include <jxl/cms_interface.h> +#include <stddef.h> +#include <stdint.h> + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" // JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { + +// A bundle of color/alpha/depth/plane images. +class ImageBundle { + public: + // Uninitialized state for use as output parameter. + ImageBundle() : metadata_(nullptr) {} + // Caller is responsible for setting metadata before calling Set*. + explicit ImageBundle(const ImageMetadata* metadata) : metadata_(metadata) {} + + // Move-only (allows storing in std::vector). + ImageBundle(ImageBundle&&) = default; + ImageBundle& operator=(ImageBundle&&) = default; + + ImageBundle Copy() const { + ImageBundle copy(metadata_); + copy.color_ = Image3F(color_.xsize(), color_.ysize()); + CopyImageTo(color_, ©.color_); + copy.c_current_ = c_current_; + copy.extra_channels_.reserve(extra_channels_.size()); + for (const ImageF& plane : extra_channels_) { + ImageF ec(plane.xsize(), plane.ysize()); + CopyImageTo(plane, &ec); + copy.extra_channels_.emplace_back(std::move(ec)); + } + + copy.jpeg_data = + jpeg_data ? make_unique<jpeg::JPEGData>(*jpeg_data) : nullptr; + copy.color_transform = color_transform; + copy.chroma_subsampling = chroma_subsampling; + + return copy; + } + + // -- SIZE + + size_t xsize() const { + if (IsJPEG()) return jpeg_data->width; + if (color_.xsize() != 0) return color_.xsize(); + return extra_channels_.empty() ? 0 : extra_channels_[0].xsize(); + } + size_t ysize() const { + if (IsJPEG()) return jpeg_data->height; + if (color_.ysize() != 0) return color_.ysize(); + return extra_channels_.empty() ? 0 : extra_channels_[0].ysize(); + } + void ShrinkTo(size_t xsize, size_t ysize); + + // sizes taking orientation into account + size_t oriented_xsize() const { + if (static_cast<uint32_t>(metadata_->GetOrientation()) > 4) { + return ysize(); + } else { + return xsize(); + } + } + size_t oriented_ysize() const { + if (static_cast<uint32_t>(metadata_->GetOrientation()) > 4) { + return xsize(); + } else { + return ysize(); + } + } + + // -- COLOR + + // Whether color() is valid/usable. Returns true in most cases. Even images + // with spot colors (one example of when !planes().empty()) typically have a + // part that can be converted to RGB. + bool HasColor() const { return color_.xsize() != 0; } + + // For resetting the size when switching from a reference to main frame. + void RemoveColor() { color_ = Image3F(); } + + // Do not use if !HasColor(). + const Image3F& color() const { + // If this fails, Set* was not called - perhaps because decoding failed? + JXL_DASSERT(HasColor()); + return color_; + } + + // Do not use if !HasColor(). + Image3F* color() { + JXL_DASSERT(HasColor()); + return &color_; + } + + // If c_current.IsGray(), all planes must be identical. NOTE: c_current is + // independent of metadata()->color_encoding, which is the original, whereas + // a decoder might return pixels in a different c_current. + // This only sets the color channels, you must also make extra channels + // match the amount that is in the metadata. + void SetFromImage(Image3F&& color, const ColorEncoding& c_current); + + // -- COLOR ENCODING + + const ColorEncoding& c_current() const { return c_current_; } + + // Returns whether the color image has identical planes. Once established by + // Set*, remains unchanged until a subsequent Set* or TransformTo. + bool IsGray() const { return c_current_.IsGray(); } + + bool IsSRGB() const { return c_current_.IsSRGB(); } + bool IsLinearSRGB() const { return c_current_.IsLinearSRGB(); } + + // Set the c_current profile without doing any transformation, e.g. if the + // transformation was already applied. + void OverrideProfile(const ColorEncoding& new_c_current) { + c_current_ = new_c_current; + } + + // TODO(lode): TransformTo and CopyTo are implemented in enc_image_bundle.cc, + // move these functions out of this header file and class, to + // enc_image_bundle.h. + + // Transforms color to c_desired and sets c_current to c_desired. Alpha and + // metadata remains unchanged. + Status TransformTo(const ColorEncoding& c_desired, const JxlCmsInterface& cms, + ThreadPool* pool = nullptr); + // Copies this:rect, converts to c_desired, and allocates+fills out. + Status CopyTo(const Rect& rect, const ColorEncoding& c_desired, + const JxlCmsInterface& cms, Image3F* out, + ThreadPool* pool = nullptr) const; + + // Detect 'real' bit depth, which can be lower than nominal bit depth + // (this is common in PNG), returns 'real' bit depth + size_t DetectRealBitdepth() const; + + // -- ALPHA + + void SetAlpha(ImageF&& alpha); + bool HasAlpha() const { + return metadata_->Find(ExtraChannel::kAlpha) != nullptr; + } + bool AlphaIsPremultiplied() const { + const ExtraChannelInfo* eci = metadata_->Find(ExtraChannel::kAlpha); + return (eci == nullptr) ? false : eci->alpha_associated; + } + const ImageF& alpha() const; + ImageF* alpha(); + + // -- EXTRA CHANNELS + bool HasBlack() const { + return metadata_->Find(ExtraChannel::kBlack) != nullptr; + } + const ImageF& black() const; + + // Extra channels of unknown interpretation (e.g. spot colors). + void SetExtraChannels(std::vector<ImageF>&& extra_channels); + void ClearExtraChannels() { extra_channels_.clear(); } + bool HasExtraChannels() const { return !extra_channels_.empty(); } + const std::vector<ImageF>& extra_channels() const { return extra_channels_; } + std::vector<ImageF>& extra_channels() { return extra_channels_; } + + const ImageMetadata* metadata() const { return metadata_; } + + void VerifyMetadata() const; + + void SetDecodedBytes(size_t decoded_bytes) { decoded_bytes_ = decoded_bytes; } + size_t decoded_bytes() const { return decoded_bytes_; } + + // -- JPEG transcoding: + + // Returns true if image does or will represent quantized DCT-8 coefficients, + // stored in 8x8 pixel regions. + bool IsJPEG() const { +#if JPEGXL_ENABLE_TRANSCODE_JPEG + return jpeg_data != nullptr; +#else // JPEGXL_ENABLE_TRANSCODE_JPEG + return false; +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + } + + std::unique_ptr<jpeg::JPEGData> jpeg_data; + // these fields are used to signal the input JPEG color space + // NOTE: JPEG doesn't actually provide a way to determine whether YCbCr was + // applied or not. + ColorTransform color_transform = ColorTransform::kNone; + YCbCrChromaSubsampling chroma_subsampling; + + FrameOrigin origin{0, 0}; + + // Animation-related information, corresponding to the timecode and duration + // fields of the jxl::AnimationFrame of the jxl::FrameHeader. + // TODO(lode): ImageBundle is used here to carry the information from + // jxl::FrameHeader, consider instead passing a jxl::FrameHeader directly to + // EncodeFrame or having a field of that type here. + uint32_t duration = 0; + uint32_t timecode = 0; + + // TODO(lode): these fields do not match the JXL frame header, it should be + // possible to specify up to 4 (3 if nonzero duration) slots to save this + // frame as reference (see save_as_reference). + bool use_for_next_frame = false; + bool blend = false; + BlendMode blendmode = BlendMode::kBlend; + + std::string name; + + private: + // Called after any Set* to ensure their sizes are compatible. + void VerifySizes() const; + + // Required for TransformTo so that an ImageBundle is self-sufficient. Always + // points to the same thing, but cannot be const-pointer because that prevents + // the compiler from generating a move ctor. + const ImageMetadata* metadata_; + + // Initialized by Set*: + Image3F color_; // If empty, planes_ is not; all planes equal if IsGray(). + ColorEncoding c_current_; // of color_ + + // Initialized by SetPlanes; size = ImageMetadata.num_extra_channels + std::vector<ImageF> extra_channels_; + + // How many bytes of the input were actually read. + size_t decoded_bytes_ = 0; +}; + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_BUNDLE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/image_bundle_test.cc b/third_party/jpeg-xl/lib/jxl/image_bundle_test.cc new file mode 100644 index 0000000000..1a10598fe2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image_bundle_test.cc @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image_bundle.h" + +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +TEST(ImageBundleTest, ExtraChannelName) { + AuxOut aux_out; + BitWriter writer; + BitWriter::Allotment allotment(&writer, 99); + + ImageMetadata metadata; + ExtraChannelInfo eci; + eci.type = ExtraChannel::kBlack; + eci.name = "testK"; + metadata.extra_channel_info.push_back(std::move(eci)); + ASSERT_TRUE(WriteImageMetadata(metadata, &writer, /*layer=*/0, &aux_out)); + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, /*layer=*/0, &aux_out); + + BitReader reader(writer.GetSpan()); + ImageMetadata metadata_out; + ASSERT_TRUE(ReadImageMetadata(&reader, &metadata_out)); + EXPECT_TRUE(reader.Close()); + EXPECT_EQ("testK", metadata_out.Find(ExtraChannel::kBlack)->name); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/image_metadata.cc b/third_party/jpeg-xl/lib/jxl/image_metadata.cc new file mode 100644 index 0000000000..4cca910753 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image_metadata.cc @@ -0,0 +1,477 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image_metadata.h" + +#include <limits> +#include <utility> + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/quantizer.h" + +namespace jxl { +BitDepth::BitDepth() { Bundle::Init(this); } +Status BitDepth::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &floating_point_sample)); + // The same fields (bits_per_sample and exponent_bits_per_sample) are read + // in a different way depending on floating_point_sample's value. It's still + // default-initialized correctly so using visitor->Conditional is not + // required. + if (!floating_point_sample) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(8), Val(10), Val(12), BitsOffset(6, 1), 8, &bits_per_sample)); + exponent_bits_per_sample = 0; + } else { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val(32), Val(16), Val(24), BitsOffset(6, 1), 32, &bits_per_sample)); + // The encoded value is exponent_bits_per_sample - 1, encoded in 3 bits + // so the value can be in range [1, 8]. + const uint32_t offset = 1; + exponent_bits_per_sample -= offset; + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bits(4, 8 - offset, &exponent_bits_per_sample)); + exponent_bits_per_sample += offset; + } + + // Error-checking for floating point ranges. + if (floating_point_sample) { + if (exponent_bits_per_sample < 2 || exponent_bits_per_sample > 8) { + return JXL_FAILURE("Invalid exponent_bits_per_sample: %u", + exponent_bits_per_sample); + } + int mantissa_bits = + static_cast<int>(bits_per_sample) - exponent_bits_per_sample - 1; + if (mantissa_bits < 2 || mantissa_bits > 23) { + return JXL_FAILURE("Invalid bits_per_sample: %u", bits_per_sample); + } + } else { + if (bits_per_sample > 31) { + return JXL_FAILURE("Invalid bits_per_sample: %u", bits_per_sample); + } + } + return true; +} + +#if JXL_DEBUG_V_LEVEL >= 1 +std::string BitDepth::DebugString() const { + std::ostringstream os; + os << (floating_point_sample ? "F" : "U"); + os << bits_per_sample; + if (floating_point_sample) os << "." << exponent_bits_per_sample; + return os.str(); +} +#endif + +CustomTransformData::CustomTransformData() { Bundle::Init(this); } +Status CustomTransformData::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + if (visitor->Conditional(nonserialized_xyb_encoded)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&opsin_inverse_matrix)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &custom_weights_mask)); + if (visitor->Conditional((custom_weights_mask & 0x1) != 0)) { + // 4 5x5 kernels, but all of them can be obtained by symmetry from one, + // which is symmetric along its main diagonal. The top-left kernel is + // defined by + // + // 0 1 2 3 4 + // 1 5 6 7 8 + // 2 6 9 10 11 + // 3 7 10 12 13 + // 4 8 11 13 14 + float constexpr kWeights2[15] = { + -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f, + 0.14111091f, 0.28896755f, 0.00278718f, -0.01610267f, 0.56661550f, + 0.03777607f, -0.01986694f, -0.03144731f, -0.01185068f, -0.00213539f}; + for (size_t i = 0; i < 15; i++) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kWeights2[i], &upsampling2_weights[i])); + } + } + if (visitor->Conditional((custom_weights_mask & 0x2) != 0)) { + // 16 5x5 kernels, but all of them can be obtained by symmetry from + // three, two of which are symmetric along their main diagonals. The top + // left 4 kernels are defined by + // + // 0 1 2 3 4 5 6 7 8 9 + // 1 10 11 12 13 14 15 16 17 18 + // 2 11 19 20 21 22 23 24 25 26 + // 3 12 20 27 28 29 30 31 32 33 + // 4 13 21 28 34 35 36 37 38 39 + // + // 5 14 22 29 35 40 41 42 43 44 + // 6 15 23 30 36 41 45 46 47 48 + // 7 16 24 31 37 42 46 49 50 51 + // 8 17 25 32 38 43 47 50 52 53 + // 9 18 26 33 39 44 48 51 53 54 + constexpr float kWeights4[55] = { + -0.02419067f, -0.03491987f, -0.03693351f, -0.03094285f, -0.00529785f, + -0.01663432f, -0.03556863f, -0.03888905f, -0.03516850f, -0.00989469f, + 0.23651958f, 0.33392945f, -0.01073543f, -0.01313181f, -0.03556694f, + 0.13048175f, 0.40103025f, 0.03951150f, -0.02077584f, 0.46914198f, + -0.00209270f, -0.01484589f, -0.04064806f, 0.18942530f, 0.56279892f, + 0.06674400f, -0.02335494f, -0.03551682f, -0.00754830f, -0.02267919f, + -0.02363578f, 0.00315804f, -0.03399098f, -0.01359519f, -0.00091653f, + -0.00335467f, -0.01163294f, -0.01610294f, -0.00974088f, -0.00191622f, + -0.01095446f, -0.03198464f, -0.04455121f, -0.02799790f, -0.00645912f, + 0.06390599f, 0.22963888f, 0.00630981f, -0.01897349f, 0.67537268f, + 0.08483369f, -0.02534994f, -0.02205197f, -0.01667999f, -0.00384443f}; + for (size_t i = 0; i < 55; i++) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kWeights4[i], &upsampling4_weights[i])); + } + } + if (visitor->Conditional((custom_weights_mask & 0x4) != 0)) { + // 64 5x5 kernels, all of them can be obtained by symmetry from + // 10, 4 of which are symmetric along their main diagonals. The top + // left 16 kernels are defined by + // 0 1 2 3 4 5 6 7 8 9 a b c d e f 10 11 12 13 + // 1 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 21 22 23 24 25 26 + // 2 15 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 + // 3 16 28 39 3a 3b 3c 3d 3e 3f 40 41 42 43 44 45 46 47 48 49 + // 4 17 29 3a 4a 4b 4c 4d 4e 4f 50 51 52 53 54 55 56 57 58 59 + + // 5 18 2a 3b 4b 5a 5b 5c 5d 5e 5f 60 61 62 63 64 65 66 67 68 + // 6 19 2b 3c 4c 5b 69 6a 6b 6c 6d 6e 6f 70 71 72 73 74 75 76 + // 7 1a 2c 3d 4d 5c 6a 77 78 79 7a 7b 7c 7d 7e 7f 80 81 82 83 + // 8 1b 2d 3e 4e 5d 6b 78 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f + // 9 1c 2e 3f 4f 5e 6c 79 85 90 91 92 93 94 95 96 97 98 99 9a + + // a 1d 2f 40 50 5f 6d 7a 86 91 9b 9c 9d 9e 9f a0 a1 a2 a3 a4 + // b 1e 30 41 51 60 6e 7b 87 92 9c a5 a6 a7 a8 a9 aa ab ac ad + // c 1f 31 42 52 61 6f 7c 88 93 9d a6 ae af b0 b1 b2 b3 b4 b5 + // d 20 32 43 53 62 70 7d 89 94 9e a7 af b6 b7 b8 b9 ba bb bc + // e 21 33 44 54 63 71 7e 8a 95 9f a8 b0 b7 bd be bf c0 c1 c2 + + // f 22 34 45 55 64 72 7f 8b 96 a0 a9 b1 b8 be c3 c4 c5 c6 c7 + // 10 23 35 46 56 65 73 80 8c 97 a1 aa b2 b9 bf c4 c8 c9 ca cb + // 11 24 36 47 57 66 74 81 8d 98 a2 ab b3 ba c0 c5 c9 cc cd ce + // 12 25 37 48 58 67 75 82 8e 99 a3 ac b4 bb c1 c6 ca cd cf d0 + // 13 26 38 49 59 68 76 83 8f 9a a4 ad b5 bc c2 c7 cb ce d0 d1 + constexpr float kWeights8[210] = { + -0.02928613f, -0.03706353f, -0.03783812f, -0.03324558f, -0.00447632f, + -0.02519406f, -0.03752601f, -0.03901508f, -0.03663285f, -0.00646649f, + -0.02066407f, -0.03838633f, -0.04002101f, -0.03900035f, -0.00901973f, + -0.01626393f, -0.03954148f, -0.04046620f, -0.03979621f, -0.01224485f, + 0.29895328f, 0.35757708f, -0.02447552f, -0.01081748f, -0.04314594f, + 0.23903219f, 0.41119301f, -0.00573046f, -0.01450239f, -0.04246845f, + 0.17567618f, 0.45220643f, 0.02287757f, -0.01936783f, -0.03583255f, + 0.11572472f, 0.47416733f, 0.06284440f, -0.02685066f, 0.42720050f, + -0.02248939f, -0.01155273f, -0.04562755f, 0.28689496f, 0.49093869f, + -0.00007891f, -0.01545926f, -0.04562659f, 0.21238920f, 0.53980934f, + 0.03369474f, -0.02070211f, -0.03866988f, 0.14229550f, 0.56593398f, + 0.08045181f, -0.02888298f, -0.03680918f, -0.00542229f, -0.02920477f, + -0.02788574f, -0.02118180f, -0.03942402f, -0.00775547f, -0.02433614f, + -0.03193943f, -0.02030828f, -0.04044014f, -0.01074016f, -0.01930822f, + -0.03620399f, -0.01974125f, -0.03919545f, -0.01456093f, -0.00045072f, + -0.00360110f, -0.01020207f, -0.01231907f, -0.00638988f, -0.00071592f, + -0.00279122f, -0.00957115f, -0.01288327f, -0.00730937f, -0.00107783f, + -0.00210156f, -0.00890705f, -0.01317668f, -0.00813895f, -0.00153491f, + -0.02128481f, -0.04173044f, -0.04831487f, -0.03293190f, -0.00525260f, + -0.01720322f, -0.04052736f, -0.05045706f, -0.03607317f, -0.00738030f, + -0.01341764f, -0.03965629f, -0.05151616f, -0.03814886f, -0.01005819f, + 0.18968273f, 0.33063684f, -0.01300105f, -0.01372950f, -0.04017465f, + 0.13727832f, 0.36402234f, 0.01027890f, -0.01832107f, -0.03365072f, + 0.08734506f, 0.38194295f, 0.04338228f, -0.02525993f, 0.56408126f, + 0.00458352f, -0.01648227f, -0.04887868f, 0.24585519f, 0.62026135f, + 0.04314807f, -0.02213737f, -0.04158014f, 0.16637289f, 0.65027023f, + 0.09621636f, -0.03101388f, -0.04082742f, -0.00904519f, -0.02790922f, + -0.02117818f, 0.00798662f, -0.03995711f, -0.01243427f, -0.02231705f, + -0.02946266f, 0.00992055f, -0.03600283f, -0.01684920f, -0.00111684f, + -0.00411204f, -0.01297130f, -0.01723725f, -0.01022545f, -0.00165306f, + -0.00313110f, -0.01218016f, -0.01763266f, -0.01125620f, -0.00231663f, + -0.01374149f, -0.03797620f, -0.05142937f, -0.03117307f, -0.00581914f, + -0.01064003f, -0.03608089f, -0.05272168f, -0.03375670f, -0.00795586f, + 0.09628104f, 0.27129991f, -0.00353779f, -0.01734151f, -0.03153981f, + 0.05686230f, 0.28500998f, 0.02230594f, -0.02374955f, 0.68214326f, + 0.05018048f, -0.02320852f, -0.04383616f, 0.18459474f, 0.71517975f, + 0.10805613f, -0.03263677f, -0.03637639f, -0.01394373f, -0.02511203f, + -0.01728636f, 0.05407331f, -0.02867568f, -0.01893131f, -0.00240854f, + -0.00446511f, -0.01636187f, -0.02377053f, -0.01522848f, -0.00333334f, + -0.00819975f, -0.02964169f, -0.04499287f, -0.02745350f, -0.00612408f, + 0.02727416f, 0.19446600f, 0.00159832f, -0.02232473f, 0.74982506f, + 0.11452620f, -0.03348048f, -0.01605681f, -0.02070339f, -0.00458223f}; + for (size_t i = 0; i < 210; i++) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kWeights8[i], &upsampling8_weights[i])); + } + } + return true; +} + +ExtraChannelInfo::ExtraChannelInfo() { Bundle::Init(this); } +Status ExtraChannelInfo::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + // General + JXL_QUIET_RETURN_IF_ERROR(visitor->Enum(ExtraChannel::kAlpha, &type)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&bit_depth)); + + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(3), Val(4), BitsOffset(3, 1), 0, &dim_shift)); + if ((1U << dim_shift) > 8) { + return JXL_FAILURE("dim_shift %u too large", dim_shift); + } + + JXL_QUIET_RETURN_IF_ERROR(VisitNameString(visitor, &name)); + + // Conditional + if (visitor->Conditional(type == ExtraChannel::kAlpha)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &alpha_associated)); + } + if (visitor->Conditional(type == ExtraChannel::kSpotColor)) { + for (float& c : spot_color) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0, &c)); + } + } + if (visitor->Conditional(type == ExtraChannel::kCFA)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(1), Bits(2), BitsOffset(4, 3), + BitsOffset(8, 19), 1, &cfa_channel)); + } + + if (type == ExtraChannel::kUnknown || + (int(ExtraChannel::kReserved0) <= int(type) && + int(type) <= int(ExtraChannel::kReserved7))) { + return JXL_FAILURE("Unknown extra channel (bits %u, shift %u, name '%s')\n", + bit_depth.bits_per_sample, dim_shift, name.c_str()); + } + return true; +} + +#if JXL_DEBUG_V_LEVEL >= 1 +std::string ExtraChannelInfo::DebugString() const { + std::ostringstream os; + os << (type == ExtraChannel::kAlpha ? "Alpha" + : type == ExtraChannel::kDepth ? "Depth" + : type == ExtraChannel::kSpotColor ? "Spot" + : type == ExtraChannel::kSelectionMask ? "Mask" + : type == ExtraChannel::kBlack ? "Black" + : type == ExtraChannel::kCFA ? "CFA" + : type == ExtraChannel::kThermal ? "Thermal" + : "Unknown"); + if (type == ExtraChannel::kAlpha && alpha_associated) os << "(premul)"; + os << " " << bit_depth.DebugString(); + os << " shift: " << dim_shift; + return os.str(); +} +#endif + +ImageMetadata::ImageMetadata() { Bundle::Init(this); } +Status ImageMetadata::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + // Bundle::AllDefault does not allow usage when reading (it may abort the + // program when a codestream has invalid values), but when reading we + // overwrite the extra_fields value, so do not need to call AllDefault. + bool tone_mapping_default = + visitor->IsReading() ? false : Bundle::AllDefault(tone_mapping); + + bool extra_fields = (orientation != 1 || have_preview || have_animation || + have_intrinsic_size || !tone_mapping_default); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &extra_fields)); + if (visitor->Conditional(extra_fields)) { + orientation--; + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(3, 0, &orientation)); + orientation++; + // (No need for bounds checking because we read exactly 3 bits) + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_intrinsic_size)); + if (visitor->Conditional(have_intrinsic_size)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&intrinsic_size)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_preview)); + if (visitor->Conditional(have_preview)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&preview_size)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &have_animation)); + if (visitor->Conditional(have_animation)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&animation)); + } + } else { + orientation = 1; // identity + have_intrinsic_size = false; + have_preview = false; + have_animation = false; + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&bit_depth)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bool(true, &modular_16_bit_buffer_sufficient)); + + num_extra_channels = extra_channel_info.size(); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(4, 2), + BitsOffset(12, 1), 0, + &num_extra_channels)); + + if (visitor->Conditional(num_extra_channels != 0)) { + if (visitor->IsReading()) { + extra_channel_info.resize(num_extra_channels); + } + for (ExtraChannelInfo& eci : extra_channel_info) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&eci)); + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(true, &xyb_encoded)); + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&color_encoding)); + if (visitor->Conditional(extra_fields)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&tone_mapping)); + } + + // Treat as if only the fields up to extra channels exist. + if (visitor->IsReading() && nonserialized_only_parse_basic_info) { + return true; + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + // Extensions: in chronological order of being added to the format. + return visitor->EndExtensions(); +} + +OpsinInverseMatrix::OpsinInverseMatrix() { Bundle::Init(this); } +Status OpsinInverseMatrix::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + for (int i = 0; i < 9; ++i) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(jxl::cms::DefaultInverseOpsinAbsorbanceMatrix()[i], + &inverse_matrix[i])); + } + for (int i = 0; i < 3; ++i) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16( + jxl::cms::kNegOpsinAbsorbanceBiasRGB[i], &opsin_biases[i])); + } + for (int i = 0; i < 4; ++i) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kDefaultQuantBias[i], &quant_biases[i])); + } + return true; +} + +ToneMapping::ToneMapping() { Bundle::Init(this); } +Status ToneMapping::VisitFields(Visitor* JXL_RESTRICT visitor) { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(kDefaultIntensityTarget, &intensity_target)); + if (intensity_target <= 0.f) { + return JXL_FAILURE("invalid intensity target"); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.0f, &min_nits)); + if (min_nits < 0.f || min_nits > intensity_target) { + return JXL_FAILURE("invalid min %f vs max %f", min_nits, intensity_target); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &relative_to_max_display)); + + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.0f, &linear_below)); + if (linear_below < 0 || (relative_to_max_display && linear_below > 1.0f)) { + return JXL_FAILURE("invalid linear_below %f (%s)", linear_below, + relative_to_max_display ? "relative" : "absolute"); + } + + return true; +} + +Status ReadImageMetadata(BitReader* JXL_RESTRICT reader, + ImageMetadata* JXL_RESTRICT metadata) { + return Bundle::Read(reader, metadata); +} + +void ImageMetadata::SetAlphaBits(uint32_t bits, bool alpha_is_premultiplied) { + std::vector<ExtraChannelInfo>& eciv = extra_channel_info; + ExtraChannelInfo* alpha = Find(ExtraChannel::kAlpha); + if (bits == 0) { + if (alpha != nullptr) { + // Remove the alpha channel from the extra channel info. It's + // theoretically possible that there are multiple, remove all in that + // case. This ensure a next HasAlpha() will return false. + const auto is_alpha = [](const ExtraChannelInfo& eci) { + return eci.type == ExtraChannel::kAlpha; + }; + eciv.erase(std::remove_if(eciv.begin(), eciv.end(), is_alpha), + eciv.end()); + } + } else { + if (alpha == nullptr) { + ExtraChannelInfo info; + info.type = ExtraChannel::kAlpha; + info.bit_depth.bits_per_sample = bits; + info.dim_shift = 0; + info.alpha_associated = alpha_is_premultiplied; + // Prepend rather than append: in case there already are other extra + // channels, prefer alpha channel to be listed first. + eciv.insert(eciv.begin(), info); + } else { + // Ignores potential extra alpha channels, only sets to first one. + alpha->bit_depth.bits_per_sample = bits; + alpha->bit_depth.floating_point_sample = false; + alpha->bit_depth.exponent_bits_per_sample = 0; + alpha->alpha_associated = alpha_is_premultiplied; + } + } + num_extra_channels = extra_channel_info.size(); + if (bits > 12) modular_16_bit_buffer_sufficient = false; +} + +#if JXL_DEBUG_V_LEVEL >= 1 +std::string ImageMetadata::DebugString() const { + std::ostringstream os; + os << bit_depth.DebugString(); + if (modular_16_bit_buffer_sufficient) { + os << " (modular 16)"; + } + os << (xyb_encoded ? " xyb encoded" : " orig profile"); + os << " " << Description(color_encoding); + if (num_extra_channels > 0) { + os << " extra channels:"; + for (size_t i = 0; i < num_extra_channels; ++i) { + os << " (" << extra_channel_info[i].DebugString() << ")"; + if (i + 1 < num_extra_channels) os << ","; + } + } + if (have_preview) { + os << " preview: " << preview_size.xsize() << "x" << preview_size.ysize(); + } + if (orientation != 1) { + os << " orientation: " << orientation; + } + return os.str(); +} + +std::string CodecMetadata::DebugString() const { + std::ostringstream os; + os << size.xsize() << "x" << size.ysize(); + os << " " << m.DebugString(); + return os.str(); +} +#endif + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/image_metadata.h b/third_party/jpeg-xl/lib/jxl/image_metadata.h new file mode 100644 index 0000000000..be603a49f3 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image_metadata.h @@ -0,0 +1,427 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Main codestream header bundles, the metadata that applies to all frames. +// Enums must align with the C API definitions in codestream_header.h. + +#ifndef LIB_JXL_IMAGE_METADATA_H_ +#define LIB_JXL_IMAGE_METADATA_H_ + +#include <jxl/codestream_header.h> +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/headers.h" + +namespace jxl { + +struct AuxOut; + +// EXIF orientation of the image. This field overrides any field present in +// actual EXIF metadata. The value tells which transformation the decoder must +// apply after decoding to display the image with the correct orientation. +enum class Orientation : uint32_t { + // Values 1..8 match the EXIF definitions. + kIdentity = JXL_ORIENT_IDENTITY, + kFlipHorizontal = JXL_ORIENT_FLIP_HORIZONTAL, + kRotate180 = JXL_ORIENT_ROTATE_180, + kFlipVertical = JXL_ORIENT_FLIP_VERTICAL, + kTranspose = JXL_ORIENT_TRANSPOSE, + kRotate90 = JXL_ORIENT_ROTATE_90_CW, + kAntiTranspose = JXL_ORIENT_ANTI_TRANSPOSE, + kRotate270 = JXL_ORIENT_ROTATE_90_CCW, +}; +// Don't need an EnumBits because Orientation is not read via Enum(). + +enum class ExtraChannel : uint32_t { + // First two enumerators (most common) are cheaper to encode + kAlpha = JXL_CHANNEL_ALPHA, + kDepth = JXL_CHANNEL_DEPTH, + + kSpotColor = JXL_CHANNEL_SPOT_COLOR, + kSelectionMask = JXL_CHANNEL_SELECTION_MASK, + kBlack = JXL_CHANNEL_BLACK, // for CMYK + kCFA = JXL_CHANNEL_CFA, // Bayer channel + kThermal = JXL_CHANNEL_THERMAL, + kReserved0 = JXL_CHANNEL_RESERVED0, + kReserved1 = JXL_CHANNEL_RESERVED1, + kReserved2 = JXL_CHANNEL_RESERVED2, + kReserved3 = JXL_CHANNEL_RESERVED3, + kReserved4 = JXL_CHANNEL_RESERVED4, + kReserved5 = JXL_CHANNEL_RESERVED5, + kReserved6 = JXL_CHANNEL_RESERVED6, + kReserved7 = JXL_CHANNEL_RESERVED7, + // disambiguated via name string, raise warning if unsupported + kUnknown = JXL_CHANNEL_UNKNOWN, + // like kUnknown but can silently be ignored + kOptional = JXL_CHANNEL_OPTIONAL +}; +static inline const char* EnumName(ExtraChannel /*unused*/) { + return "ExtraChannel"; +} +static inline constexpr uint64_t EnumBits(ExtraChannel /*unused*/) { + using EC = ExtraChannel; + return MakeBit(EC::kAlpha) | MakeBit(EC::kDepth) | MakeBit(EC::kSpotColor) | + MakeBit(EC::kSelectionMask) | MakeBit(EC::kBlack) | MakeBit(EC::kCFA) | + MakeBit(EC::kThermal) | MakeBit(EC::kUnknown) | MakeBit(EC::kOptional); +} + +// Used in ImageMetadata and ExtraChannelInfo. +struct BitDepth : public Fields { + BitDepth(); + JXL_FIELDS_NAME(BitDepth) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + std::string DebugString() const; + + // Whether the original (uncompressed) samples are floating point or + // unsigned integer. + bool floating_point_sample; + + // Bit depth of the original (uncompressed) image samples. Must be in the + // range [1, 32]. + uint32_t bits_per_sample; + + // Floating point exponent bits of the original (uncompressed) image samples, + // only used if floating_point_sample is true. + // If used, the samples are floating point with: + // - 1 sign bit + // - exponent_bits_per_sample exponent bits + // - (bits_per_sample - exponent_bits_per_sample - 1) mantissa bits + // If used, exponent_bits_per_sample must be in the range + // [2, 8] and amount of mantissa bits must be in the range [2, 23]. + // NOTE: exponent_bits_per_sample is 8 for single precision binary32 + // point, 5 for half precision binary16, 7 for fp24. + uint32_t exponent_bits_per_sample; +}; + +// Describes one extra channel. +struct ExtraChannelInfo : public Fields { + ExtraChannelInfo(); + JXL_FIELDS_NAME(ExtraChannelInfo) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + std::string DebugString() const; + + mutable bool all_default; + + ExtraChannel type; + BitDepth bit_depth; + uint32_t dim_shift; // downsampled by 2^dim_shift on each axis + + std::string name; // UTF-8 + + // Conditional: + bool alpha_associated; // i.e. premultiplied + float spot_color[4]; // spot color in linear RGBA + uint32_t cfa_channel; +}; + +struct OpsinInverseMatrix : public Fields { + OpsinInverseMatrix(); + JXL_FIELDS_NAME(OpsinInverseMatrix) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + mutable bool all_default; + + float inverse_matrix[9]; + float opsin_biases[3]; + float quant_biases[4]; +}; + +// Information useful for mapping HDR images to lower dynamic range displays. +struct ToneMapping : public Fields { + ToneMapping(); + JXL_FIELDS_NAME(ToneMapping) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + mutable bool all_default; + + // Upper bound on the intensity level present in the image. For unsigned + // integer pixel encodings, this is the brightness of the largest + // representable value. The image does not necessarily contain a pixel + // actually this bright. An encoder is allowed to set 255 for SDR images + // without computing a histogram. + float intensity_target; // [nits] + + // Lower bound on the intensity level present in the image. This may be + // loose, i.e. lower than the actual darkest pixel. When tone mapping, a + // decoder will map [min_nits, intensity_target] to the display range. + float min_nits; + + bool relative_to_max_display; // see below + // The tone mapping will leave unchanged (linear mapping) any pixels whose + // brightness is strictly below this. The interpretation depends on + // relative_to_max_display. If true, this is a ratio [0, 1] of the maximum + // display brightness [nits], otherwise an absolute brightness [nits]. + float linear_below; +}; + +// Contains weights to customize some transforms - in particular, XYB and +// upsampling. +struct CustomTransformData : public Fields { + CustomTransformData(); + JXL_FIELDS_NAME(CustomTransformData) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Must be set before calling VisitFields. Must equal xyb_encoded of + // ImageMetadata, should be set by ImageMetadata during VisitFields. + bool nonserialized_xyb_encoded = false; + + mutable bool all_default; + + OpsinInverseMatrix opsin_inverse_matrix; + + uint32_t custom_weights_mask; + float upsampling2_weights[15]; + float upsampling4_weights[55]; + float upsampling8_weights[210]; +}; + +// Properties of the original image bundle. This enables Encode(Decode()) to +// re-create an equivalent image without user input. +struct ImageMetadata : public Fields { + ImageMetadata(); + JXL_FIELDS_NAME(ImageMetadata) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + // Returns bit depth of the JPEG XL compressed alpha channel, or 0 if no alpha + // channel present. In the theoretical case that there are multiple alpha + // channels, returns the bit depth of the first. + uint32_t GetAlphaBits() const { + const ExtraChannelInfo* alpha = Find(ExtraChannel::kAlpha); + if (alpha == nullptr) return 0; + JXL_ASSERT(alpha->bit_depth.bits_per_sample != 0); + return alpha->bit_depth.bits_per_sample; + } + + // Sets bit depth of alpha channel, adding extra channel if needed, or + // removing all alpha channels if bits is 0. + // Assumes integer alpha channel and not designed to support multiple + // alpha channels (it's possible to use those features by manipulating + // extra_channel_info directly). + // + // Callers must insert the actual channel image at the same index before any + // further modifications to extra_channel_info. + void SetAlphaBits(uint32_t bits, bool alpha_is_premultiplied = false); + + bool HasAlpha() const { return GetAlphaBits() != 0; } + + // Sets the original bit depth fields to indicate unsigned integer of the + // given bit depth. + // TODO(lode): move function to BitDepth + void SetUintSamples(uint32_t bits) { + bit_depth.bits_per_sample = bits; + bit_depth.exponent_bits_per_sample = 0; + bit_depth.floating_point_sample = false; + // RCT / Squeeze may add one bit each, and this is about int16_t, + // so uint13 should still be OK but limiting it to 12 seems safer. + // TODO(jon): figure out a better way to set this header field. + // (in particular, if modular mode is not used it doesn't matter, + // and if transforms are restricted, up to 15-bit could be done) + if (bits > 12) modular_16_bit_buffer_sufficient = false; + } + // Sets the original bit depth fields to indicate single precision floating + // point. + // TODO(lode): move function to BitDepth + void SetFloat32Samples() { + bit_depth.bits_per_sample = 32; + bit_depth.exponent_bits_per_sample = 8; + bit_depth.floating_point_sample = true; + modular_16_bit_buffer_sufficient = false; + } + + void SetFloat16Samples() { + bit_depth.bits_per_sample = 16; + bit_depth.exponent_bits_per_sample = 5; + bit_depth.floating_point_sample = true; + modular_16_bit_buffer_sufficient = false; + } + + void SetIntensityTarget(float intensity_target) { + tone_mapping.intensity_target = intensity_target; + } + float IntensityTarget() const { + JXL_ASSERT(tone_mapping.intensity_target != 0); + return tone_mapping.intensity_target; + } + + // Returns first ExtraChannelInfo of the given type, or nullptr if none. + const ExtraChannelInfo* Find(ExtraChannel type) const { + for (const ExtraChannelInfo& eci : extra_channel_info) { + if (eci.type == type) return &eci; + } + return nullptr; + } + + // Returns first ExtraChannelInfo of the given type, or nullptr if none. + ExtraChannelInfo* Find(ExtraChannel type) { + for (ExtraChannelInfo& eci : extra_channel_info) { + if (eci.type == type) return &eci; + } + return nullptr; + } + + Orientation GetOrientation() const { + return static_cast<Orientation>(orientation); + } + + bool ExtraFieldsDefault() const; + + std::string DebugString() const; + + mutable bool all_default; + + BitDepth bit_depth; + bool modular_16_bit_buffer_sufficient; // otherwise 32 is. + + // Whether the colors values of the pixels of frames are encoded in the + // codestream using the absolute XYB color space, or the using values that + // follow the color space defined by the ColorEncoding or ICC profile. This + // determines when or whether a CMS (Color Management System) is needed to get + // the pixels in a desired color space. In one case, the pixels have one known + // color space and a CMS is needed to convert them to the original image's + // color space, in the other case the pixels have the color space of the + // original image and a CMS is required if a different display space, or a + // single known consistent color space for multiple decoded images, is + // desired. In all cases, the color space of all frames from a single image is + // the same, both VarDCT and modular frames. + // + // If true: then frames can be decoded to XYB (which can also be converted to + // linear and non-linear sRGB with the built in conversion without CMS). The + // attached ColorEncoding or ICC profile has no effect on the meaning of the + // pixel's color values, but instead indicates what the color profile of the + // original image was, and what color profile one should convert to when + // decoding to integers to prevent clipping and precision loss. To do that + // conversion requires a CMS. + // + // If false: then the color values of decoded frames are in the space defined + // by the attached ColorEncoding or ICC profile. To instead get the pixels in + // a chosen known color space, such as sRGB, requires a CMS, since the + // attached ColorEncoding or ICC profile could be any arbitrary color space. + // This mode is typically used for lossless images encoded as integers. + // Frames can also use YCbCr encoding, some frames may and some may not, but + // this is not a different color space but a certain encoding of the RGB + // values. + // + // Note: if !xyb_encoded, but the attached color profile indicates XYB (which + // can happen either if it's a ColorEncoding with color_space_ == + // ColorSpace::kXYB, or if it's an ICC Profile that has been crafted to + // represent XYB), then the frames still may not use ColorEncoding kXYB, they + // must still use kNone (or kYCbCr, which would mean applying the YCbCr + // transform to the 3-channel XYB data), since with !xyb_encoded, the 3 + // channels are stored as-is, no matter what meaning the color profile assigns + // to them. To use ColorSpace::kXYB, xyb_encoded must be true. + // + // This value is defined in image metadata because this is the global + // codestream header. This value does not affect the image itself, so is not + // image metadata per se, it only affects the encoding, and what color space + // the decoder can receive the pixels in without needing a CMS. + bool xyb_encoded; + + ColorEncoding color_encoding; + + // These values are initialized to defaults such that the 'extra_fields' + // condition in VisitFields uses correctly initialized values. + uint32_t orientation = 1; + bool have_preview = false; + bool have_animation = false; + bool have_intrinsic_size = false; + + // If present, the stored image has the dimensions of the first SizeHeader, + // but decoders are advised to resample or display per `intrinsic_size`. + SizeHeader intrinsic_size; // only if have_intrinsic_size + + ToneMapping tone_mapping; + + // When reading: deserialized. When writing: automatically set from vector. + uint32_t num_extra_channels; + std::vector<ExtraChannelInfo> extra_channel_info; + + // Only present if m.have_preview. + PreviewHeader preview_size; + // Only present if m.have_animation. + AnimationHeader animation; + + uint64_t extensions; + + // Option to stop parsing after basic info, and treat as if the later + // fields do not participate. Use to parse only basic image information + // excluding the final larger or variable sized data. + bool nonserialized_only_parse_basic_info = false; +}; + +Status ReadImageMetadata(BitReader* JXL_RESTRICT reader, + ImageMetadata* JXL_RESTRICT metadata); + +Status WriteImageMetadata(const ImageMetadata& metadata, + BitWriter* JXL_RESTRICT writer, size_t layer, + AuxOut* aux_out); + +// All metadata applicable to the entire codestream (dimensions, extra channels, +// ...) +struct CodecMetadata { + // TODO(lode): use the preview and animation fields too, in place of the + // nonserialized_ ones in ImageMetadata. + ImageMetadata m; + // The size of the codestream: this is the nominal size applicable to all + // frames, although some frames can have a different effective size through + // crop, dc_level or representing a the preview. + SizeHeader size; + // Often default. + CustomTransformData transform_data; + + size_t xsize() const { return size.xsize(); } + size_t ysize() const { return size.ysize(); } + size_t oriented_xsize(bool keep_orientation) const { + if (static_cast<uint32_t>(m.GetOrientation()) > 4 && !keep_orientation) { + return ysize(); + } else { + return xsize(); + } + } + size_t oriented_preview_xsize(bool keep_orientation) const { + if (static_cast<uint32_t>(m.GetOrientation()) > 4 && !keep_orientation) { + return m.preview_size.ysize(); + } else { + return m.preview_size.xsize(); + } + } + size_t oriented_ysize(bool keep_orientation) const { + if (static_cast<uint32_t>(m.GetOrientation()) > 4 && !keep_orientation) { + return xsize(); + } else { + return ysize(); + } + } + size_t oriented_preview_ysize(bool keep_orientation) const { + if (static_cast<uint32_t>(m.GetOrientation()) > 4 && !keep_orientation) { + return m.preview_size.xsize(); + } else { + return m.preview_size.ysize(); + } + } + + std::string DebugString() const; +}; + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_METADATA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/image_ops.h b/third_party/jpeg-xl/lib/jxl/image_ops.h new file mode 100644 index 0000000000..b2ce23f13d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image_ops.h @@ -0,0 +1,454 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_IMAGE_OPS_H_ +#define LIB_JXL_IMAGE_OPS_H_ + +// Operations on images. + +#include <algorithm> +#include <array> +#include <limits> +#include <vector> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/image.h" + +namespace jxl { + +template <typename T> +void CopyImageTo(const Plane<T>& from, Plane<T>* JXL_RESTRICT to) { + JXL_ASSERT(SameSize(from, *to)); + if (from.ysize() == 0 || from.xsize() == 0) return; + for (size_t y = 0; y < from.ysize(); ++y) { + const T* JXL_RESTRICT row_from = from.ConstRow(y); + T* JXL_RESTRICT row_to = to->Row(y); + memcpy(row_to, row_from, from.xsize() * sizeof(T)); + } +} + +// Copies `from:rect_from` to `to:rect_to`. +template <typename T> +void CopyImageTo(const Rect& rect_from, const Plane<T>& from, + const Rect& rect_to, Plane<T>* JXL_RESTRICT to) { + JXL_DASSERT(SameSize(rect_from, rect_to)); + JXL_DASSERT(rect_from.IsInside(from)); + JXL_DASSERT(rect_to.IsInside(*to)); + if (rect_from.xsize() == 0) return; + for (size_t y = 0; y < rect_from.ysize(); ++y) { + const T* JXL_RESTRICT row_from = rect_from.ConstRow(from, y); + T* JXL_RESTRICT row_to = rect_to.Row(to, y); + memcpy(row_to, row_from, rect_from.xsize() * sizeof(T)); + } +} + +// Copies `from:rect_from` to `to:rect_to`. +template <typename T> +void CopyImageTo(const Rect& rect_from, const Image3<T>& from, + const Rect& rect_to, Image3<T>* JXL_RESTRICT to) { + JXL_ASSERT(SameSize(rect_from, rect_to)); + for (size_t c = 0; c < 3; c++) { + CopyImageTo(rect_from, from.Plane(c), rect_to, &to->Plane(c)); + } +} + +template <typename T, typename U> +void ConvertPlaneAndClamp(const Rect& rect_from, const Plane<T>& from, + const Rect& rect_to, Plane<U>* JXL_RESTRICT to) { + JXL_ASSERT(SameSize(rect_from, rect_to)); + using M = decltype(T() + U()); + for (size_t y = 0; y < rect_to.ysize(); ++y) { + const T* JXL_RESTRICT row_from = rect_from.ConstRow(from, y); + U* JXL_RESTRICT row_to = rect_to.Row(to, y); + for (size_t x = 0; x < rect_to.xsize(); ++x) { + row_to[x] = + std::min<M>(std::max<M>(row_from[x], std::numeric_limits<U>::min()), + std::numeric_limits<U>::max()); + } + } +} + +// Copies `from` to `to`. +template <typename T> +void CopyImageTo(const T& from, T* JXL_RESTRICT to) { + return CopyImageTo(Rect(from), from, Rect(*to), to); +} + +// Copies `from:rect_from` to `to:rect_to`; also copies `padding` pixels of +// border around `from:rect_from`, in all directions, whenever they are inside +// the first image. +template <typename T> +void CopyImageToWithPadding(const Rect& from_rect, const T& from, + size_t padding, const Rect& to_rect, T* to) { + size_t xextra0 = std::min(padding, from_rect.x0()); + size_t xextra1 = + std::min(padding, from.xsize() - from_rect.x0() - from_rect.xsize()); + size_t yextra0 = std::min(padding, from_rect.y0()); + size_t yextra1 = + std::min(padding, from.ysize() - from_rect.y0() - from_rect.ysize()); + JXL_DASSERT(to_rect.x0() >= xextra0); + JXL_DASSERT(to_rect.y0() >= yextra0); + + return CopyImageTo(Rect(from_rect.x0() - xextra0, from_rect.y0() - yextra0, + from_rect.xsize() + xextra0 + xextra1, + from_rect.ysize() + yextra0 + yextra1), + from, + Rect(to_rect.x0() - xextra0, to_rect.y0() - yextra0, + to_rect.xsize() + xextra0 + xextra1, + to_rect.ysize() + yextra0 + yextra1), + to); +} + +template <class ImageIn, class ImageOut> +void Subtract(const ImageIn& image1, const ImageIn& image2, ImageOut* out) { + using T = typename ImageIn::T; + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = row1[x] - row2[x]; + } + } +} + +// In-place. +template <typename Tin, typename Tout> +void SubtractFrom(const Plane<Tin>& what, Plane<Tout>* to) { + const size_t xsize = what.xsize(); + const size_t ysize = what.ysize(); + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = what.ConstRow(y); + Tout* JXL_RESTRICT row_to = to->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] -= row_what[x]; + } + } +} + +// In-place. +template <typename Tin, typename Tout> +void AddTo(const Plane<Tin>& what, Plane<Tout>* to) { + const size_t xsize = what.xsize(); + const size_t ysize = what.ysize(); + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = what.ConstRow(y); + Tout* JXL_RESTRICT row_to = to->Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] += row_what[x]; + } + } +} + +template <typename Tin, typename Tout> +void AddTo(Rect rectFrom, const Plane<Tin>& what, Rect rectTo, + Plane<Tout>* to) { + JXL_ASSERT(SameSize(rectFrom, rectTo)); + const size_t xsize = rectTo.xsize(); + const size_t ysize = rectTo.ysize(); + for (size_t y = 0; y < ysize; ++y) { + const Tin* JXL_RESTRICT row_what = rectFrom.ConstRow(what, y); + Tout* JXL_RESTRICT row_to = rectTo.Row(to, y); + for (size_t x = 0; x < xsize; ++x) { + row_to[x] += row_what[x]; + } + } +} + +// Returns linear combination of two grayscale images. +template <typename T> +Plane<T> LinComb(const T lambda1, const Plane<T>& image1, const T lambda2, + const Plane<T>& image2) { + const size_t xsize = image1.xsize(); + const size_t ysize = image1.ysize(); + JXL_CHECK(xsize == image2.xsize()); + JXL_CHECK(ysize == image2.ysize()); + Plane<T> out(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + T* const JXL_RESTRICT row_out = out.Row(y); + for (size_t x = 0; x < xsize; ++x) { + row_out[x] = lambda1 * row1[x] + lambda2 * row2[x]; + } + } + return out; +} + +// Multiplies image by lambda in-place +template <typename T> +void ScaleImage(const T lambda, Plane<T>* image) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + for (size_t x = 0; x < image->xsize(); ++x) { + row[x] = lambda * row[x]; + } + } +} + +// Multiplies image by lambda in-place +template <typename T> +void ScaleImage(const T lambda, Image3<T>* image) { + for (size_t c = 0; c < 3; ++c) { + ScaleImage(lambda, &image->Plane(c)); + } +} + +template <typename T> +void FillImage(const T value, Plane<T>* image) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + for (size_t x = 0; x < image->xsize(); ++x) { + row[x] = value; + } + } +} + +template <typename T> +void ZeroFillImage(Plane<T>* image) { + if (image->xsize() == 0) return; + for (size_t y = 0; y < image->ysize(); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + memset(row, 0, image->xsize() * sizeof(T)); + } +} + +// Mirrors out of bounds coordinates and returns valid coordinates unchanged. +// We assume the radius (distance outside the image) is small compared to the +// image size, otherwise this might not terminate. +// The mirror is outside the last column (border pixel is also replicated). +static inline int64_t Mirror(int64_t x, const int64_t xsize) { + JXL_DASSERT(xsize != 0); + + // TODO(janwas): replace with branchless version + while (x < 0 || x >= xsize) { + if (x < 0) { + x = -x - 1; + } else { + x = 2 * xsize - 1 - x; + } + } + return x; +} + +// Wrap modes for ensuring X/Y coordinates are in the valid range [0, size): + +// Mirrors (repeating the edge pixel once). Useful for convolutions. +struct WrapMirror { + JXL_INLINE int64_t operator()(const int64_t coord, const int64_t size) const { + return Mirror(coord, size); + } +}; + +// Returns the same coordinate: required for TFNode with Border(), or useful +// when we know "coord" is already valid (e.g. interior of an image). +struct WrapUnchanged { + JXL_INLINE int64_t operator()(const int64_t coord, int64_t /*size*/) const { + return coord; + } +}; + +// Similar to Wrap* but for row pointers (reduces Row() multiplications). + +class WrapRowMirror { + public: + template <class ImageOrView> + WrapRowMirror(const ImageOrView& image, size_t ysize) + : first_row_(image.ConstRow(0)), last_row_(image.ConstRow(ysize - 1)) {} + + const float* operator()(const float* const JXL_RESTRICT row, + const int64_t stride) const { + if (row < first_row_) { + const int64_t num_before = first_row_ - row; + // Mirrored; one row before => row 0, two before = row 1, ... + return first_row_ + num_before - stride; + } + if (row > last_row_) { + const int64_t num_after = row - last_row_; + // Mirrored; one row after => last row, two after = last - 1, ... + return last_row_ - num_after + stride; + } + return row; + } + + private: + const float* const JXL_RESTRICT first_row_; + const float* const JXL_RESTRICT last_row_; +}; + +struct WrapRowUnchanged { + JXL_INLINE const float* operator()(const float* const JXL_RESTRICT row, + int64_t /*stride*/) const { + return row; + } +}; + +// Sets "thickness" pixels on each border to "value". This is faster than +// initializing the entire image and overwriting valid/interior pixels. +template <typename T> +void SetBorder(const size_t thickness, const T value, Plane<T>* image) { + const size_t xsize = image->xsize(); + const size_t ysize = image->ysize(); + // Top: fill entire row + for (size_t y = 0; y < std::min(thickness, ysize); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + std::fill(row, row + xsize, value); + } + + // Bottom: fill entire row + for (size_t y = ysize - thickness; y < ysize; ++y) { + T* const JXL_RESTRICT row = image->Row(y); + std::fill(row, row + xsize, value); + } + + // Left/right: fill the 'columns' on either side, but only if the image is + // big enough that they don't already belong to the top/bottom rows. + if (ysize >= 2 * thickness) { + for (size_t y = thickness; y < ysize - thickness; ++y) { + T* const JXL_RESTRICT row = image->Row(y); + std::fill(row, row + thickness, value); + std::fill(row + xsize - thickness, row + xsize, value); + } + } +} + +// Computes the minimum and maximum pixel value. +template <typename T> +void ImageMinMax(const Plane<T>& image, T* const JXL_RESTRICT min, + T* const JXL_RESTRICT max) { + *min = std::numeric_limits<T>::max(); + *max = std::numeric_limits<T>::lowest(); + for (size_t y = 0; y < image.ysize(); ++y) { + const T* const JXL_RESTRICT row = image.Row(y); + for (size_t x = 0; x < image.xsize(); ++x) { + *min = std::min(*min, row[x]); + *max = std::max(*max, row[x]); + } + } +} + +template <typename T> +Plane<T> ImageFromPacked(const std::vector<T>& packed, const size_t xsize, + const size_t ysize) { + Plane<T> out(xsize, ysize); + for (size_t y = 0; y < ysize; ++y) { + T* const JXL_RESTRICT row = out.Row(y); + const T* const JXL_RESTRICT packed_row = &packed[y * xsize]; + memcpy(row, packed_row, xsize * sizeof(T)); + } + return out; +} + +template <typename T> +void Image3Max(const Image3<T>& image, std::array<T, 3>* out_max) { + for (size_t c = 0; c < 3; ++c) { + T max = std::numeric_limits<T>::min(); + for (size_t y = 0; y < image.ysize(); ++y) { + const T* JXL_RESTRICT row = image.ConstPlaneRow(c, y); + for (size_t x = 0; x < image.xsize(); ++x) { + max = std::max(max, row[x]); + } + } + (*out_max)[c] = max; + } +} + +template <typename T> +std::vector<T> PackedFromImage(const Plane<T>& image, const Rect& rect) { + const size_t xsize = rect.xsize(); + const size_t ysize = rect.ysize(); + std::vector<T> packed(xsize * ysize); + for (size_t y = 0; y < rect.ysize(); ++y) { + memcpy(&packed[y * xsize], rect.ConstRow(image, y), xsize * sizeof(T)); + } + return packed; +} + +template <typename T> +std::vector<T> PackedFromImage(const Plane<T>& image) { + return PackedFromImage(image, Rect(image)); +} + +// Initializes all planes to the same "value". +template <typename T> +void FillImage(const T value, Image3<T>* image) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + for (size_t x = 0; x < image->xsize(); ++x) { + row[x] = value; + } + } + } +} + +template <typename T> +void FillPlane(const T value, Plane<T>* image) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* JXL_RESTRICT row = image->Row(y); + for (size_t x = 0; x < image->xsize(); ++x) { + row[x] = value; + } + } +} + +template <typename T> +void FillImage(const T value, Image3<T>* image, Rect rect) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < rect.ysize(); ++y) { + T* JXL_RESTRICT row = rect.PlaneRow(image, c, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + row[x] = value; + } + } + } +} + +template <typename T> +void FillPlane(const T value, Plane<T>* image, Rect rect) { + for (size_t y = 0; y < rect.ysize(); ++y) { + T* JXL_RESTRICT row = rect.Row(image, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + row[x] = value; + } + } +} + +template <typename T> +void ZeroFillImage(Image3<T>* image) { + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* JXL_RESTRICT row = image->PlaneRow(c, y); + if (image->xsize() != 0) memset(row, 0, image->xsize() * sizeof(T)); + } + } +} + +template <typename T> +void ZeroFillPlane(Plane<T>* image, Rect rect) { + for (size_t y = 0; y < rect.ysize(); ++y) { + T* JXL_RESTRICT row = rect.Row(image, y); + memset(row, 0, rect.xsize() * sizeof(T)); + } +} + +// Same as above, but operates in-place. Assumes that the `in` image was +// allocated large enough. +void PadImageToBlockMultipleInPlace(Image3F* JXL_RESTRICT in, + size_t block_dim = kBlockDim); + +// Downsamples an image by a given factor. +void DownsampleImage(Image3F* opsin, size_t factor); +void DownsampleImage(ImageF* image, size_t factor); + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_OPS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/image_ops_test.cc b/third_party/jpeg-xl/lib/jxl/image_ops_test.cc new file mode 100644 index 0000000000..dfcb2292c5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image_ops_test.cc @@ -0,0 +1,163 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/image_ops.h" + +#include <stdint.h> +#include <stdlib.h> + +#include <utility> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +template <typename T> +void TestPacked(const size_t xsize, const size_t ysize) { + Plane<T> image1(xsize, ysize); + RandomFillImage(&image1); + const std::vector<T>& packed = PackedFromImage(image1); + const Plane<T>& image2 = ImageFromPacked(packed, xsize, ysize); + JXL_EXPECT_OK(SamePixels(image1, image2, _)); +} + +TEST(ImageTest, TestPacked) { + TestPacked<uint8_t>(1, 1); + TestPacked<uint8_t>(7, 1); + TestPacked<uint8_t>(1, 7); + + TestPacked<int16_t>(1, 1); + TestPacked<int16_t>(7, 1); + TestPacked<int16_t>(1, 7); + + TestPacked<uint16_t>(1, 1); + TestPacked<uint16_t>(7, 1); + TestPacked<uint16_t>(1, 7); + + TestPacked<float>(1, 1); + TestPacked<float>(7, 1); + TestPacked<float>(1, 7); +} + +// Ensure entire payload is readable/writable for various size/offset combos. +TEST(ImageTest, TestAllocator) { + Rng rng(0); + const size_t k32 = 32; + const size_t kAlign = CacheAligned::kAlignment; + for (size_t size : {k32 * 1, k32 * 2, k32 * 3, k32 * 4, k32 * 5, + CacheAligned::kAlias, 2 * CacheAligned::kAlias + 4}) { + for (size_t offset = 0; offset <= CacheAligned::kAlias; offset += kAlign) { + uint8_t* bytes = + static_cast<uint8_t*>(CacheAligned::Allocate(size, offset)); + JXL_CHECK(reinterpret_cast<uintptr_t>(bytes) % kAlign == 0); + // Ensure we can write/read the last byte. Use RNG to fool the compiler + // into thinking the write is necessary. + memset(bytes, 0, size); + bytes[size - 1] = 1; // greatest element + uint32_t pos = rng.UniformU(0, size - 1); // random but != greatest + JXL_CHECK(bytes[pos] < bytes[size - 1]); + + CacheAligned::Free(bytes); + } + } +} + +template <typename T> +void TestFillImpl(Image3<T>* img, const char* layout) { + FillImage(T(1), img); + for (size_t y = 0; y < img->ysize(); ++y) { + for (size_t c = 0; c < 3; ++c) { + T* JXL_RESTRICT row = img->PlaneRow(c, y); + for (size_t x = 0; x < img->xsize(); ++x) { + if (row[x] != T(1)) { + printf("Not 1 at c=%" PRIuS " %" PRIuS ", %" PRIuS " (%" PRIuS + " x %" PRIuS ") (%s)\n", + c, x, y, img->xsize(), img->ysize(), layout); + abort(); + } + row[x] = T(2); + } + } + } + + // Same for ZeroFillImage and swapped c/y loop ordering. + ZeroFillImage(img); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < img->ysize(); ++y) { + T* JXL_RESTRICT row = img->PlaneRow(c, y); + for (size_t x = 0; x < img->xsize(); ++x) { + if (row[x] != T(0)) { + printf("Not 0 at c=%" PRIuS " %" PRIuS ", %" PRIuS " (%" PRIuS + " x %" PRIuS ") (%s)\n", + c, x, y, img->xsize(), img->ysize(), layout); + abort(); + } + row[x] = T(3); + } + } + } +} + +template <typename T> +void TestFillT() { + for (uint32_t xsize : {0, 1, 15, 16, 31, 32}) { + for (uint32_t ysize : {0, 1, 15, 16, 31, 32}) { + Image3<T> image(xsize, ysize); + TestFillImpl(&image, "size ctor"); + + Image3<T> planar(Plane<T>(xsize, ysize), Plane<T>(xsize, ysize), + Plane<T>(xsize, ysize)); + TestFillImpl(&planar, "planar"); + } + } +} + +// Ensure y/c/x and c/y/x loops visit pixels no more than once. +TEST(ImageTest, TestFill) { + TestFillT<uint8_t>(); + TestFillT<int16_t>(); + TestFillT<float>(); + TestFillT<double>(); +} + +TEST(ImageTest, CopyImageToWithPaddingTest) { + Plane<uint32_t> src(100, 61); + for (size_t y = 0; y < src.ysize(); y++) { + for (size_t x = 0; x < src.xsize(); x++) { + src.Row(y)[x] = x * 1000 + y; + } + } + Rect src_rect(10, 20, 30, 40); + EXPECT_TRUE(src_rect.IsInside(src)); + + Plane<uint32_t> dst(60, 50); + FillImage(0u, &dst); + Rect dst_rect(20, 5, 30, 40); + EXPECT_TRUE(dst_rect.IsInside(dst)); + + CopyImageToWithPadding(src_rect, src, /*padding=*/2, dst_rect, &dst); + + // ysize is + 3 instead of + 4 because we are at the y image boundary on the + // source image. + Rect padded_dst_rect(20 - 2, 5 - 2, 30 + 4, 40 + 3); + for (size_t y = 0; y < dst.ysize(); y++) { + for (size_t x = 0; x < dst.xsize(); x++) { + if (Rect(x, y, 1, 1).IsInside(padded_dst_rect)) { + EXPECT_EQ((x - dst_rect.x0() + src_rect.x0()) * 1000 + + (y - dst_rect.y0() + src_rect.y0()), + dst.Row(y)[x]); + } else { + EXPECT_EQ(0u, dst.Row(y)[x]); + } + } + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/image_test_utils.h b/third_party/jpeg-xl/lib/jxl/image_test_utils.h new file mode 100644 index 0000000000..e7d72285e6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/image_test_utils.h @@ -0,0 +1,257 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_IMAGE_TEST_UTILS_H_ +#define LIB_JXL_IMAGE_TEST_UTILS_H_ + +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif + +#include <inttypes.h> +#include <stddef.h> +#include <stdint.h> + +#include <cmath> +#include <limits> +#include <sstream> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/image.h" + +namespace jxl { + +template <typename T> +bool SamePixels(const Plane<T>& image1, const Plane<T>& image2, + std::stringstream& failures) { + const Rect rect(image1); + JXL_CHECK(SameSize(image1, image2)); + size_t mismatches = 0; + for (size_t y = rect.y0(); y < rect.ysize(); ++y) { + const T* const JXL_RESTRICT row1 = image1.Row(y); + const T* const JXL_RESTRICT row2 = image2.Row(y); + for (size_t x = rect.x0(); x < rect.xsize(); ++x) { + if (row1[x] != row2[x]) { + failures << "pixel mismatch" << x << ", " << y << ": " + << double(row1[x]) << " != " << double(row2[x]) << "\n"; + if (++mismatches > 4) { + return false; + } + } + } + } + return mismatches == 0; +} + +template <typename T> +bool SamePixels(const Image3<T>& image1, const Image3<T>& image2, + std::stringstream& failures) { + JXL_CHECK(SameSize(image1, image2)); + for (size_t c = 0; c < 3; ++c) { + if (!SamePixels(image1.Plane(c), image2.Plane(c), failures)) { + return false; + } + } + return true; +} + +// Use for floating-point images with fairly large numbers; tolerates small +// absolute errors and/or small relative errors. +template <typename T> +bool VerifyRelativeError(const Plane<T>& expected, const Plane<T>& actual, + const double threshold_l1, + const double threshold_relative, + std::stringstream& failures, const intptr_t border = 0, + const size_t c = 0) { + JXL_CHECK(SameSize(expected, actual)); + const intptr_t xsize = expected.xsize(); + const intptr_t ysize = expected.ysize(); + + // Max over current scanline to give a better idea whether there are + // systematic errors or just one outlier. Invalid if negative. + double max_l1 = -1; + double max_relative = -1; + bool any_bad = false; + for (intptr_t y = border; y < ysize - border; ++y) { + const T* const JXL_RESTRICT row_expected = expected.Row(y); + const T* const JXL_RESTRICT row_actual = actual.Row(y); + for (intptr_t x = border; x < xsize - border; ++x) { + const double l1 = std::abs(row_expected[x] - row_actual[x]); + + // Cannot compute relative, only check/update L1. + if (std::abs(row_expected[x]) < 1E-10) { + if (l1 > threshold_l1) { + any_bad = true; + max_l1 = std::max(max_l1, l1); + } + } else { + const double relative = l1 / std::abs(double(row_expected[x])); + if (l1 > threshold_l1 && relative > threshold_relative) { + // Fails both tolerances => will exit below, update max_*. + any_bad = true; + max_l1 = std::max(max_l1, l1); + max_relative = std::max(max_relative, relative); + } + } + } + } + if (!any_bad) { + return true; + } + // Never had a valid relative value, don't print it. + if (max_relative < 0) { + fprintf(stderr, "c=%" PRIu64 ": max +/- %E exceeds +/- %.2E\n", + static_cast<uint64_t>(c), max_l1, threshold_l1); + } else { + fprintf(stderr, + "c=%" PRIu64 ": max +/- %E, x %E exceeds +/- %.2E, x %.2E\n", + static_cast<uint64_t>(c), max_l1, max_relative, threshold_l1, + threshold_relative); + } + // Dump the expected image and actual image if the region is small enough. + const intptr_t kMaxTestDumpSize = 16; + if (xsize <= kMaxTestDumpSize + 2 * border && + ysize <= kMaxTestDumpSize + 2 * border) { + fprintf(stderr, "Expected image:\n"); + for (intptr_t y = border; y < ysize - border; ++y) { + const T* const JXL_RESTRICT row_expected = expected.Row(y); + for (intptr_t x = border; x < xsize - border; ++x) { + fprintf(stderr, "%10lf ", static_cast<double>(row_expected[x])); + } + fprintf(stderr, "\n"); + } + + fprintf(stderr, "Actual image:\n"); + for (intptr_t y = border; y < ysize - border; ++y) { + const T* const JXL_RESTRICT row_expected = expected.Row(y); + const T* const JXL_RESTRICT row_actual = actual.Row(y); + for (intptr_t x = border; x < xsize - border; ++x) { + const double l1 = std::abs(row_expected[x] - row_actual[x]); + + bool bad = l1 > threshold_l1; + if (row_expected[x] > 1E-10) { + const double relative = l1 / std::abs(double(row_expected[x])); + bad &= relative > threshold_relative; + } + if (bad) { + fprintf(stderr, "%10lf ", static_cast<double>(row_actual[x])); + } else { + fprintf(stderr, "%10s ", "=="); + } + } + fprintf(stderr, "\n"); + } + } + + // Find first failing x for further debugging. + for (intptr_t y = border; y < ysize - border; ++y) { + const T* const JXL_RESTRICT row_expected = expected.Row(y); + const T* const JXL_RESTRICT row_actual = actual.Row(y); + + for (intptr_t x = border; x < xsize - border; ++x) { + const double l1 = std::abs(row_expected[x] - row_actual[x]); + + bool bad = l1 > threshold_l1; + if (row_expected[x] > 1E-10) { + const double relative = l1 / std::abs(double(row_expected[x])); + bad &= relative > threshold_relative; + } + if (bad) { + failures << x << ", " << y << " (" << expected.xsize() << " x " + << expected.ysize() << ") expected " + << static_cast<double>(row_expected[x]) << " actual " + << static_cast<double>(row_actual[x]); + return false; + } + } + } + return false; +} + +template <typename T> +bool VerifyRelativeError(const Image3<T>& expected, const Image3<T>& actual, + const float threshold_l1, + const float threshold_relative, + std::stringstream& failures, + const intptr_t border = 0) { + for (size_t c = 0; c < 3; ++c) { + bool ok = + VerifyRelativeError(expected.Plane(c), actual.Plane(c), threshold_l1, + threshold_relative, failures, border, c); + if (!ok) { + return false; + } + } + return true; +} + +template <typename T, typename U = T> +void GenerateImage(Rng& rng, Plane<T>* image, U begin, U end) { + for (size_t y = 0; y < image->ysize(); ++y) { + T* const JXL_RESTRICT row = image->Row(y); + for (size_t x = 0; x < image->xsize(); ++x) { + if (std::is_same<T, float>::value || std::is_same<T, double>::value) { + row[x] = rng.UniformF(begin, end); + } else if (std::is_signed<T>::value) { + row[x] = rng.UniformI(begin, end); + } else { + row[x] = rng.UniformU(begin, end); + } + } + } +} + +template <typename T> +void RandomFillImage(Plane<T>* image, const T begin, const T end, + const int seed = 129) { + Rng rng(seed); + GenerateImage(rng, image, begin, end); +} + +template <typename T> +typename std::enable_if<std::is_integral<T>::value>::type RandomFillImage( + Plane<T>* image) { + Rng rng(129); + GenerateImage(rng, image, int64_t(0), + int64_t(std::numeric_limits<T>::max()) + 1); +} + +JXL_INLINE void RandomFillImage(Plane<float>* image) { + Rng rng(129); + GenerateImage(rng, image, 0.0f, std::numeric_limits<float>::max()); +} + +template <typename T, typename U> +void GenerateImage(Rng& rng, Image3<T>* image, U begin, U end) { + for (size_t c = 0; c < 3; ++c) { + GenerateImage(rng, &image->Plane(c), begin, end); + } +} + +template <typename T> +typename std::enable_if<std::is_integral<T>::value>::type RandomFillImage( + Image3<T>* image) { + Rng rng(129); + GenerateImage(rng, image, int64_t(0), + int64_t(std::numeric_limits<T>::max()) + 1); +} + +JXL_INLINE void RandomFillImage(Image3F* image) { + Rng rng(129); + GenerateImage(rng, image, 0.0f, std::numeric_limits<float>::max()); +} + +template <typename T, typename U> +void RandomFillImage(Image3<T>* image, const U begin, const U end, + const int seed = 129) { + Rng rng(seed); + GenerateImage(rng, image, begin, end); +} + +} // namespace jxl + +#endif // LIB_JXL_IMAGE_TEST_UTILS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/inverse_mtf-inl.h b/third_party/jpeg-xl/lib/jxl/inverse_mtf-inl.h new file mode 100644 index 0000000000..fcb01d7396 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/inverse_mtf-inl.h @@ -0,0 +1,90 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// SIMDified inverse-move-to-front transform. + +#if defined(LIB_JXL_INVERSE_MTF_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_INVERSE_MTF_INL_H_ +#undef LIB_JXL_INVERSE_MTF_INL_H_ +#else +#define LIB_JXL_INVERSE_MTF_INL_H_ +#endif + +#include <hwy/highway.h> + +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::FirstN; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Load; +using hwy::HWY_NAMESPACE::LoadU; +using hwy::HWY_NAMESPACE::StoreU; + +inline void MoveToFront(uint8_t* v, uint8_t index) { + uint8_t value = v[index]; + uint8_t i = index; + if (i < 4) { + for (; i; --i) v[i] = v[i - 1]; + } else { + const HWY_CAPPED(uint8_t, 64) d; + int tail = i & (Lanes(d) - 1); + if (tail) { + i -= tail; + const auto vec = Load(d, v + i); + const auto prev = LoadU(d, v + i + 1); + StoreU(IfThenElse(FirstN(d, tail), vec, prev), d, v + i + 1); + } + while (i) { + i -= Lanes(d); + const auto vec = Load(d, v + i); + StoreU(vec, d, v + i + 1); + } + } + v[0] = value; +} + +inline void InverseMoveToFrontTransform(uint8_t* v, int v_len) { + HWY_ALIGN uint8_t mtf[256 + 64]; + int i; + for (i = 0; i < 256; ++i) { + mtf[i] = static_cast<uint8_t>(i); + } +#if JXL_MEMORY_SANITIZER + const HWY_CAPPED(uint8_t, 64) d; + for (size_t j = 0; j < Lanes(d); ++j) { + mtf[256 + j] = 0; + } +#endif // JXL_MEMORY_SANITIZER + for (i = 0; i < v_len; ++i) { + uint8_t index = v[i]; + v[i] = mtf[index]; + if (index) MoveToFront(mtf, index); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_INVERSE_MTF_INL_H_ + +#if HWY_ONCE +#ifndef INVERSE_MTF_ONCE +#define INVERSE_MTF_ONCE + +namespace jxl { +inline void InverseMoveToFrontTransform(uint8_t* v, int v_len) { + return HWY_STATIC_DISPATCH(InverseMoveToFrontTransform)(v, v_len); +} +} // namespace jxl + +#endif // INVERSE_MTF_ONCE +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.cc b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.cc new file mode 100644 index 0000000000..9763786453 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.cc @@ -0,0 +1,145 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/dec_jpeg_data.h" + +#include <brotli/decode.h> + +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { +namespace jpeg { +Status DecodeJPEGData(Span<const uint8_t> encoded, JPEGData* jpeg_data) { + Status ret = true; + const uint8_t* in = encoded.data(); + size_t available_in = encoded.size(); + { + BitReader br(encoded); + BitReaderScopedCloser br_closer(&br, &ret); + JXL_RETURN_IF_ERROR(Bundle::Read(&br, jpeg_data)); + JXL_RETURN_IF_ERROR(br.JumpToByteBoundary()); + in += br.TotalBitsConsumed() / 8; + available_in -= br.TotalBitsConsumed() / 8; + } + JXL_RETURN_IF_ERROR(ret); + + BrotliDecoderState* brotli_dec = + BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + + struct BrotliDecDeleter { + BrotliDecoderState* brotli_dec; + ~BrotliDecDeleter() { BrotliDecoderDestroyInstance(brotli_dec); } + } brotli_dec_deleter{brotli_dec}; + + BrotliDecoderResult result = + BrotliDecoderResult::BROTLI_DECODER_RESULT_SUCCESS; + + auto br_read = [&](std::vector<uint8_t>& data) -> Status { + size_t available_out = data.size(); + uint8_t* out = data.data(); + while (available_out != 0) { + if (BrotliDecoderIsFinished(brotli_dec)) { + return JXL_FAILURE("Not enough decompressed output"); + } + uint8_t* next_out_before = out; + size_t avail_out_before = available_out; + msan::MemoryIsInitialized(in, available_in); + result = BrotliDecoderDecompressStream(brotli_dec, &available_in, &in, + &available_out, &out, nullptr); + if (result != + BrotliDecoderResult::BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT && + result != BrotliDecoderResult::BROTLI_DECODER_RESULT_SUCCESS) { + return JXL_FAILURE( + "Brotli decoding error: %s\n", + BrotliDecoderErrorString(BrotliDecoderGetErrorCode(brotli_dec))); + } + msan::UnpoisonMemory(next_out_before, avail_out_before - available_out); + } + return true; + }; + size_t num_icc = 0; + for (size_t i = 0; i < jpeg_data->app_data.size(); i++) { + auto& marker = jpeg_data->app_data[i]; + if (jpeg_data->app_marker_type[i] != AppMarkerType::kUnknown) { + // Set the size of the marker. + size_t size_minus_1 = marker.size() - 1; + marker[1] = size_minus_1 >> 8; + marker[2] = size_minus_1 & 0xFF; + if (jpeg_data->app_marker_type[i] == AppMarkerType::kICC) { + if (marker.size() < 17) { + return JXL_FAILURE("ICC markers must be at least 17 bytes"); + } + marker[0] = 0xE2; + memcpy(&marker[3], kIccProfileTag, sizeof kIccProfileTag); + marker[15] = ++num_icc; + } + } else { + JXL_RETURN_IF_ERROR(br_read(marker)); + if (marker[1] * 256u + marker[2] + 1u != marker.size()) { + return JXL_FAILURE("Incorrect marker size"); + } + } + } + for (size_t i = 0; i < jpeg_data->app_data.size(); i++) { + auto& marker = jpeg_data->app_data[i]; + if (jpeg_data->app_marker_type[i] == AppMarkerType::kICC) { + marker[16] = num_icc; + } + if (jpeg_data->app_marker_type[i] == AppMarkerType::kExif) { + marker[0] = 0xE1; + if (marker.size() < 3 + sizeof kExifTag) { + return JXL_FAILURE("Incorrect Exif marker size"); + } + memcpy(&marker[3], kExifTag, sizeof kExifTag); + } + if (jpeg_data->app_marker_type[i] == AppMarkerType::kXMP) { + marker[0] = 0xE1; + if (marker.size() < 3 + sizeof kXMPTag) { + return JXL_FAILURE("Incorrect XMP marker size"); + } + memcpy(&marker[3], kXMPTag, sizeof kXMPTag); + } + } + // TODO(eustas): actually inject ICC profile and check it fits perfectly. + for (size_t i = 0; i < jpeg_data->com_data.size(); i++) { + auto& marker = jpeg_data->com_data[i]; + JXL_RETURN_IF_ERROR(br_read(marker)); + if (marker[1] * 256u + marker[2] + 1u != marker.size()) { + return JXL_FAILURE("Incorrect marker size"); + } + } + for (size_t i = 0; i < jpeg_data->inter_marker_data.size(); i++) { + JXL_RETURN_IF_ERROR(br_read(jpeg_data->inter_marker_data[i])); + } + JXL_RETURN_IF_ERROR(br_read(jpeg_data->tail_data)); + + // Check if there is more decompressed output. + size_t available_out = 1; + uint64_t sink; + uint8_t* next_out = reinterpret_cast<uint8_t*>(&sink); + result = BrotliDecoderDecompressStream(brotli_dec, &available_in, &in, + &available_out, &next_out, nullptr); + if (available_out == 0 || + result == BrotliDecoderResult::BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + return JXL_FAILURE("Excess data in compressed stream"); + } + if (result == BrotliDecoderResult::BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT) { + return JXL_FAILURE("Incomplete brotli-stream"); + } + if (!BrotliDecoderIsFinished(brotli_dec) || + result != BrotliDecoderResult::BROTLI_DECODER_RESULT_SUCCESS) { + return JXL_FAILURE("Corrupted brotli-stream"); + } + if (available_in != 0) { + return JXL_FAILURE("Unused data after brotli stream"); + } + + return true; +} +} // namespace jpeg +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.h b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.h new file mode 100644 index 0000000000..b9d50bf9f8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data.h @@ -0,0 +1,19 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_JPEG_DEC_JPEG_DATA_H_ +#define LIB_JXL_JPEG_DEC_JPEG_DATA_H_ + +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { +Status DecodeJPEGData(Span<const uint8_t> encoded, JPEGData* jpeg_data); +} +} // namespace jxl + +#endif // LIB_JXL_JPEG_DEC_JPEG_DATA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc new file mode 100644 index 0000000000..64560d9ab0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.cc @@ -0,0 +1,1063 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/dec_jpeg_data_writer.h" + +#include <stdlib.h> +#include <string.h> /* for memset, memcpy */ + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <deque> +#include <string> +#include <vector> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/jpeg/dec_jpeg_serialization_state.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +namespace { + +enum struct SerializationStatus { + NEEDS_MORE_INPUT, + NEEDS_MORE_OUTPUT, + ERROR, + DONE +}; + +const int kJpegPrecision = 8; + +// JpegBitWriter: buffer size +const size_t kJpegBitWriterChunkSize = 16384; + +// Returns non-zero if and only if x has a zero byte, i.e. one of +// x & 0xff, x & 0xff00, ..., x & 0xff00000000000000 is zero. +static JXL_INLINE uint64_t HasZeroByte(uint64_t x) { + return (x - 0x0101010101010101ULL) & ~x & 0x8080808080808080ULL; +} + +void JpegBitWriterInit(JpegBitWriter* bw, + std::deque<OutputChunk>* output_queue) { + bw->output = output_queue; + bw->chunk = OutputChunk(kJpegBitWriterChunkSize); + bw->pos = 0; + bw->put_buffer = 0; + bw->put_bits = 64; + bw->healthy = true; + bw->data = bw->chunk.buffer->data(); +} + +static JXL_NOINLINE void SwapBuffer(JpegBitWriter* bw) { + bw->chunk.len = bw->pos; + bw->output->emplace_back(std::move(bw->chunk)); + bw->chunk = OutputChunk(kJpegBitWriterChunkSize); + bw->data = bw->chunk.buffer->data(); + bw->pos = 0; +} + +static JXL_INLINE void Reserve(JpegBitWriter* bw, size_t n_bytes) { + if (JXL_UNLIKELY((bw->pos + n_bytes) > kJpegBitWriterChunkSize)) { + SwapBuffer(bw); + } +} + +/** + * Writes the given byte to the output, writes an extra zero if byte is 0xFF. + * + * This method is "careless" - caller must make sure that there is enough + * space in the output buffer. Emits up to 2 bytes to buffer. + */ +static JXL_INLINE void EmitByte(JpegBitWriter* bw, int byte) { + bw->data[bw->pos] = byte; + bw->data[bw->pos + 1] = 0; + bw->pos += (byte != 0xFF ? 1 : 2); +} + +static JXL_INLINE void DischargeBitBuffer(JpegBitWriter* bw, int nbits, + uint64_t bits) { + // At this point we are ready to emit the put_buffer to the output. + // The JPEG format requires that after every 0xff byte in the entropy + // coded section, there is a zero byte, therefore we first check if any of + // the 8 bytes of put_buffer is 0xFF. + bw->put_buffer |= (bits >> -bw->put_bits); + if (JXL_UNLIKELY(HasZeroByte(~bw->put_buffer))) { + // We have a 0xFF byte somewhere, examine each byte and append a zero + // byte if necessary. + EmitByte(bw, (bw->put_buffer >> 56) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 48) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 40) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 32) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 24) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 16) & 0xFF); + EmitByte(bw, (bw->put_buffer >> 8) & 0xFF); + EmitByte(bw, (bw->put_buffer) & 0xFF); + } else { + // We don't have any 0xFF bytes, output all 8 bytes without checking. + StoreBE64(bw->put_buffer, bw->data + bw->pos); + bw->pos += 8; + } + + bw->put_bits += 64; + bw->put_buffer = bits << bw->put_bits; +} + +static JXL_INLINE void WriteBits(JpegBitWriter* bw, int nbits, uint64_t bits) { + JXL_DASSERT(nbits > 0); + bw->put_bits -= nbits; + if (JXL_UNLIKELY(bw->put_bits < 0)) { + if (JXL_UNLIKELY(nbits > 64)) { + bw->put_bits += nbits; + bw->healthy = false; + } else { + DischargeBitBuffer(bw, nbits, bits); + } + } else { + bw->put_buffer |= (bits << bw->put_bits); + } +} + +void EmitMarker(JpegBitWriter* bw, int marker) { + Reserve(bw, 2); + JXL_DASSERT(marker != 0xFF); + bw->data[bw->pos++] = 0xFF; + bw->data[bw->pos++] = marker; +} + +bool JumpToByteBoundary(JpegBitWriter* bw, const uint8_t** pad_bits, + const uint8_t* pad_bits_end) { + size_t n_bits = bw->put_bits & 7u; + uint8_t pad_pattern; + if (*pad_bits == nullptr) { + pad_pattern = (1u << n_bits) - 1; + } else { + pad_pattern = 0; + const uint8_t* src = *pad_bits; + // TODO(eustas): bitwise reading looks insanely ineffective... + while (n_bits--) { + pad_pattern <<= 1; + if (src >= pad_bits_end) return false; + // TODO(eustas): DCHECK *src == {0, 1} + pad_pattern |= !!*(src++); + } + *pad_bits = src; + } + + Reserve(bw, 16); + + while (bw->put_bits <= 56) { + int c = (bw->put_buffer >> 56) & 0xFF; + EmitByte(bw, c); + bw->put_buffer <<= 8; + bw->put_bits += 8; + } + if (bw->put_bits < 64) { + int pad_mask = 0xFFu >> (64 - bw->put_bits); + int c = ((bw->put_buffer >> 56) & ~pad_mask) | pad_pattern; + EmitByte(bw, c); + } + bw->put_buffer = 0; + bw->put_bits = 64; + + return true; +} + +void JpegBitWriterFinish(JpegBitWriter* bw) { + if (bw->pos == 0) return; + bw->chunk.len = bw->pos; + bw->output->emplace_back(std::move(bw->chunk)); + bw->chunk = OutputChunk(nullptr, 0); + bw->data = nullptr; + bw->pos = 0; +} + +void DCTCodingStateInit(DCTCodingState* s) { + s->eob_run_ = 0; + s->cur_ac_huff_ = nullptr; + s->refinement_bits_.clear(); + s->refinement_bits_.reserve(64); +} + +static JXL_INLINE void WriteSymbol(int symbol, HuffmanCodeTable* table, + JpegBitWriter* bw) { + WriteBits(bw, table->depth[symbol], table->code[symbol]); +} + +static JXL_INLINE void WriteSymbolBits(int symbol, HuffmanCodeTable* table, + JpegBitWriter* bw, int nbits, + uint64_t bits) { + WriteBits(bw, nbits + table->depth[symbol], + bits | (table->code[symbol] << nbits)); +} + +// Emit all buffered data to the bit stream using the given Huffman code and +// bit writer. +static JXL_INLINE void Flush(DCTCodingState* s, JpegBitWriter* bw) { + if (s->eob_run_ > 0) { + Reserve(bw, 16); + int nbits = FloorLog2Nonzero<uint32_t>(s->eob_run_); + int symbol = nbits << 4u; + WriteSymbol(symbol, s->cur_ac_huff_, bw); + if (nbits > 0) { + WriteBits(bw, nbits, s->eob_run_ & ((1 << nbits) - 1)); + } + s->eob_run_ = 0; + } + const size_t kStride = 124; // (515 - 16) / 2 / 2 + size_t num_words = s->refinement_bits_count_ >> 4; + size_t i = 0; + while (i < num_words) { + size_t limit = std::min(i + kStride, num_words); + Reserve(bw, 512); + for (; i < limit; ++i) { + WriteBits(bw, 16, s->refinement_bits_[i]); + } + } + Reserve(bw, 16); + size_t tail = s->refinement_bits_count_ & 0xF; + if (tail) { + WriteBits(bw, tail, s->refinement_bits_.back()); + } + s->refinement_bits_.clear(); + s->refinement_bits_count_ = 0; +} + +// Buffer some more data at the end-of-band (the last non-zero or newly +// non-zero coefficient within the [Ss, Se] spectral band). +static JXL_INLINE void BufferEndOfBand(DCTCodingState* s, + HuffmanCodeTable* ac_huff, + const int* new_bits_array, + size_t new_bits_count, + JpegBitWriter* bw) { + if (s->eob_run_ == 0) { + s->cur_ac_huff_ = ac_huff; + } + ++s->eob_run_; + if (new_bits_count) { + uint64_t new_bits = 0; + for (size_t i = 0; i < new_bits_count; ++i) { + new_bits = (new_bits << 1) | new_bits_array[i]; + } + size_t tail = s->refinement_bits_count_ & 0xF; + if (tail) { // First stuff the tail item + size_t stuff_bits_count = std::min(16 - tail, new_bits_count); + uint16_t stuff_bits = new_bits >> (new_bits_count - stuff_bits_count); + stuff_bits &= ((1u << stuff_bits_count) - 1); + s->refinement_bits_.back() = + (s->refinement_bits_.back() << stuff_bits_count) | stuff_bits; + new_bits_count -= stuff_bits_count; + s->refinement_bits_count_ += stuff_bits_count; + } + while (new_bits_count >= 16) { + s->refinement_bits_.push_back(new_bits >> (new_bits_count - 16)); + new_bits_count -= 16; + s->refinement_bits_count_ += 16; + } + if (new_bits_count) { + s->refinement_bits_.push_back(new_bits & ((1u << new_bits_count) - 1)); + s->refinement_bits_count_ += new_bits_count; + } + } + + if (s->eob_run_ == 0x7FFF) { + Flush(s, bw); + } +} + +bool BuildHuffmanCodeTable(const JPEGHuffmanCode& huff, + HuffmanCodeTable* table) { + int huff_code[kJpegHuffmanAlphabetSize]; + // +1 for a sentinel element. + uint32_t huff_size[kJpegHuffmanAlphabetSize + 1]; + int p = 0; + for (size_t l = 1; l <= kJpegHuffmanMaxBitLength; ++l) { + int i = huff.counts[l]; + if (p + i > kJpegHuffmanAlphabetSize + 1) { + return false; + } + while (i--) huff_size[p++] = l; + } + + if (p == 0) { + return true; + } + + // Reuse sentinel element. + int last_p = p - 1; + huff_size[last_p] = 0; + + int code = 0; + uint32_t si = huff_size[0]; + p = 0; + while (huff_size[p]) { + while ((huff_size[p]) == si) { + huff_code[p++] = code; + code++; + } + code <<= 1; + si++; + } + for (p = 0; p < last_p; p++) { + int i = huff.values[p]; + table->depth[i] = huff_size[p]; + table->code[i] = huff_code[p]; + } + return true; +} + +bool EncodeSOI(SerializationState* state) { + state->output_queue.push_back(OutputChunk({0xFF, 0xD8})); + return true; +} + +bool EncodeEOI(const JPEGData& jpg, SerializationState* state) { + state->output_queue.push_back(OutputChunk({0xFF, 0xD9})); + state->output_queue.emplace_back(jpg.tail_data); + return true; +} + +bool EncodeSOF(const JPEGData& jpg, uint8_t marker, SerializationState* state) { + if (marker <= 0xC2) state->is_progressive = (marker == 0xC2); + + const size_t n_comps = jpg.components.size(); + const size_t marker_len = 8 + 3 * n_comps; + state->output_queue.emplace_back(marker_len + 2); + uint8_t* data = state->output_queue.back().buffer->data(); + size_t pos = 0; + data[pos++] = 0xFF; + data[pos++] = marker; + data[pos++] = marker_len >> 8u; + data[pos++] = marker_len & 0xFFu; + data[pos++] = kJpegPrecision; + data[pos++] = jpg.height >> 8u; + data[pos++] = jpg.height & 0xFFu; + data[pos++] = jpg.width >> 8u; + data[pos++] = jpg.width & 0xFFu; + data[pos++] = n_comps; + for (size_t i = 0; i < n_comps; ++i) { + data[pos++] = jpg.components[i].id; + data[pos++] = ((jpg.components[i].h_samp_factor << 4u) | + (jpg.components[i].v_samp_factor)); + const size_t quant_idx = jpg.components[i].quant_idx; + if (quant_idx >= jpg.quant.size()) return false; + data[pos++] = jpg.quant[quant_idx].index; + } + return true; +} + +bool EncodeSOS(const JPEGData& jpg, const JPEGScanInfo& scan_info, + SerializationState* state) { + const size_t n_scans = scan_info.num_components; + const size_t marker_len = 6 + 2 * n_scans; + state->output_queue.emplace_back(marker_len + 2); + uint8_t* data = state->output_queue.back().buffer->data(); + size_t pos = 0; + data[pos++] = 0xFF; + data[pos++] = 0xDA; + data[pos++] = marker_len >> 8u; + data[pos++] = marker_len & 0xFFu; + data[pos++] = n_scans; + for (size_t i = 0; i < n_scans; ++i) { + const JPEGComponentScanInfo& si = scan_info.components[i]; + if (si.comp_idx >= jpg.components.size()) return false; + data[pos++] = jpg.components[si.comp_idx].id; + data[pos++] = (si.dc_tbl_idx << 4u) + si.ac_tbl_idx; + } + data[pos++] = scan_info.Ss; + data[pos++] = scan_info.Se; + data[pos++] = ((scan_info.Ah << 4u) | (scan_info.Al)); + return true; +} + +bool EncodeDHT(const JPEGData& jpg, SerializationState* state) { + const std::vector<JPEGHuffmanCode>& huffman_code = jpg.huffman_code; + + size_t marker_len = 2; + for (size_t i = state->dht_index; i < huffman_code.size(); ++i) { + const JPEGHuffmanCode& huff = huffman_code[i]; + marker_len += kJpegHuffmanMaxBitLength; + for (size_t j = 0; j < huff.counts.size(); ++j) { + marker_len += huff.counts[j]; + } + if (huff.is_last) break; + } + state->output_queue.emplace_back(marker_len + 2); + uint8_t* data = state->output_queue.back().buffer->data(); + size_t pos = 0; + data[pos++] = 0xFF; + data[pos++] = 0xC4; + data[pos++] = marker_len >> 8u; + data[pos++] = marker_len & 0xFFu; + while (true) { + const size_t huffman_code_index = state->dht_index++; + if (huffman_code_index >= huffman_code.size()) { + return false; + } + const JPEGHuffmanCode& huff = huffman_code[huffman_code_index]; + size_t index = huff.slot_id; + HuffmanCodeTable* huff_table; + if (index & 0x10) { + index -= 0x10; + huff_table = &state->ac_huff_table[index]; + } else { + huff_table = &state->dc_huff_table[index]; + } + // TODO(eustas): cache + huff_table->InitDepths(127); + if (!BuildHuffmanCodeTable(huff, huff_table)) { + return false; + } + huff_table->initialized = true; + size_t total_count = 0; + size_t max_length = 0; + for (size_t i = 0; i < huff.counts.size(); ++i) { + if (huff.counts[i] != 0) { + max_length = i; + } + total_count += huff.counts[i]; + } + --total_count; + data[pos++] = huff.slot_id; + for (size_t i = 1; i <= kJpegHuffmanMaxBitLength; ++i) { + data[pos++] = (i == max_length ? huff.counts[i] - 1 : huff.counts[i]); + } + for (size_t i = 0; i < total_count; ++i) { + data[pos++] = huff.values[i]; + } + if (huff.is_last) break; + } + return true; +} + +bool EncodeDQT(const JPEGData& jpg, SerializationState* state) { + int marker_len = 2; + for (size_t i = state->dqt_index; i < jpg.quant.size(); ++i) { + const JPEGQuantTable& table = jpg.quant[i]; + marker_len += 1 + (table.precision ? 2 : 1) * kDCTBlockSize; + if (table.is_last) break; + } + state->output_queue.emplace_back(marker_len + 2); + uint8_t* data = state->output_queue.back().buffer->data(); + size_t pos = 0; + data[pos++] = 0xFF; + data[pos++] = 0xDB; + data[pos++] = marker_len >> 8u; + data[pos++] = marker_len & 0xFFu; + while (true) { + const size_t idx = state->dqt_index++; + if (idx >= jpg.quant.size()) { + return false; // corrupt input + } + const JPEGQuantTable& table = jpg.quant[idx]; + data[pos++] = (table.precision << 4u) + table.index; + for (size_t i = 0; i < kDCTBlockSize; ++i) { + int val_idx = kJPEGNaturalOrder[i]; + int val = table.values[val_idx]; + if (table.precision) { + data[pos++] = val >> 8u; + } + data[pos++] = val & 0xFFu; + } + if (table.is_last) break; + } + return true; +} + +bool EncodeDRI(const JPEGData& jpg, SerializationState* state) { + state->seen_dri_marker = true; + OutputChunk dri_marker = {0xFF, + 0xDD, + 0, + 4, + static_cast<uint8_t>(jpg.restart_interval >> 8), + static_cast<uint8_t>(jpg.restart_interval & 0xFF)}; + state->output_queue.push_back(std::move(dri_marker)); + return true; +} + +bool EncodeRestart(uint8_t marker, SerializationState* state) { + state->output_queue.push_back(OutputChunk({0xFF, marker})); + return true; +} + +bool EncodeAPP(const JPEGData& jpg, uint8_t marker, SerializationState* state) { + // TODO(eustas): check that marker corresponds to payload? + (void)marker; + + size_t app_index = state->app_index++; + if (app_index >= jpg.app_data.size()) return false; + state->output_queue.push_back(OutputChunk({0xFF})); + state->output_queue.emplace_back(jpg.app_data[app_index]); + return true; +} + +bool EncodeCOM(const JPEGData& jpg, SerializationState* state) { + size_t com_index = state->com_index++; + if (com_index >= jpg.com_data.size()) return false; + state->output_queue.push_back(OutputChunk({0xFF})); + state->output_queue.emplace_back(jpg.com_data[com_index]); + return true; +} + +bool EncodeInterMarkerData(const JPEGData& jpg, SerializationState* state) { + size_t index = state->data_index++; + if (index >= jpg.inter_marker_data.size()) return false; + state->output_queue.emplace_back(jpg.inter_marker_data[index]); + return true; +} + +bool EncodeDCTBlockSequential(const coeff_t* coeffs, HuffmanCodeTable* dc_huff, + HuffmanCodeTable* ac_huff, int num_zero_runs, + coeff_t* last_dc_coeff, JpegBitWriter* bw) { + coeff_t temp2; + coeff_t temp; + coeff_t litmus = 0; + temp2 = coeffs[0]; + temp = temp2 - *last_dc_coeff; + *last_dc_coeff = temp2; + temp2 = temp >> (8 * sizeof(coeff_t) - 1); + temp += temp2; + temp2 ^= temp; + + int dc_nbits = (temp2 == 0) ? 0 : (FloorLog2Nonzero<uint32_t>(temp2) + 1); + WriteSymbol(dc_nbits, dc_huff, bw); +#if false + // If the input is corrupt, this could be triggered. Checking is + // costly though, so it makes more sense to avoid this branch. + // (producing a corrupt JPEG when the input is corrupt, instead + // of catching it and returning error) + if (dc_nbits >= 12) return false; +#endif + if (dc_nbits) { + WriteBits(bw, dc_nbits, temp & ((1u << dc_nbits) - 1)); + } + int16_t r = 0; + + for (size_t i = 1; i < 64; i++) { + if ((temp = coeffs[kJPEGNaturalOrder[i]]) == 0) { + r++; + } else { + temp2 = temp >> (8 * sizeof(coeff_t) - 1); + temp += temp2; + temp2 ^= temp; + if (JXL_UNLIKELY(r > 15)) { + WriteSymbol(0xf0, ac_huff, bw); + r -= 16; + if (r > 15) { + WriteSymbol(0xf0, ac_huff, bw); + r -= 16; + } + if (r > 15) { + WriteSymbol(0xf0, ac_huff, bw); + r -= 16; + } + } + litmus |= temp2; + int ac_nbits = + FloorLog2Nonzero<uint32_t>(static_cast<uint16_t>(temp2)) + 1; + int symbol = (r << 4u) + ac_nbits; + WriteSymbolBits(symbol, ac_huff, bw, ac_nbits, + temp & ((1 << ac_nbits) - 1)); + r = 0; + } + } + + for (int i = 0; i < num_zero_runs; ++i) { + WriteSymbol(0xf0, ac_huff, bw); + r -= 16; + } + if (r > 0) { + WriteSymbol(0, ac_huff, bw); + } + return (litmus >= 0); +} + +bool EncodeDCTBlockProgressive(const coeff_t* coeffs, HuffmanCodeTable* dc_huff, + HuffmanCodeTable* ac_huff, int Ss, int Se, + int Al, int num_zero_runs, + DCTCodingState* coding_state, + coeff_t* last_dc_coeff, JpegBitWriter* bw) { + bool eob_run_allowed = Ss > 0; + coeff_t temp2; + coeff_t temp; + if (Ss == 0) { + temp2 = coeffs[0] >> Al; + temp = temp2 - *last_dc_coeff; + *last_dc_coeff = temp2; + temp2 = temp; + if (temp < 0) { + temp = -temp; + if (temp < 0) return false; + temp2--; + } + int nbits = (temp == 0) ? 0 : (FloorLog2Nonzero<uint32_t>(temp) + 1); + WriteSymbol(nbits, dc_huff, bw); + if (nbits) { + WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1)); + } + ++Ss; + } + if (Ss > Se) { + return true; + } + int r = 0; + for (int k = Ss; k <= Se; ++k) { + if ((temp = coeffs[kJPEGNaturalOrder[k]]) == 0) { + r++; + continue; + } + if (temp < 0) { + temp = -temp; + if (temp < 0) return false; + temp >>= Al; + temp2 = ~temp; + } else { + temp >>= Al; + temp2 = temp; + } + if (temp == 0) { + r++; + continue; + } + Flush(coding_state, bw); + while (r > 15) { + WriteSymbol(0xf0, ac_huff, bw); + r -= 16; + } + int nbits = FloorLog2Nonzero<uint32_t>(temp) + 1; + int symbol = (r << 4u) + nbits; + WriteSymbol(symbol, ac_huff, bw); + WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1)); + r = 0; + } + if (num_zero_runs > 0) { + Flush(coding_state, bw); + for (int i = 0; i < num_zero_runs; ++i) { + WriteSymbol(0xf0, ac_huff, bw); + r -= 16; + } + } + if (r > 0) { + BufferEndOfBand(coding_state, ac_huff, nullptr, 0, bw); + if (!eob_run_allowed) { + Flush(coding_state, bw); + } + } + return true; +} + +bool EncodeRefinementBits(const coeff_t* coeffs, HuffmanCodeTable* ac_huff, + int Ss, int Se, int Al, DCTCodingState* coding_state, + JpegBitWriter* bw) { + bool eob_run_allowed = Ss > 0; + if (Ss == 0) { + // Emit next bit of DC component. + WriteBits(bw, 1, (coeffs[0] >> Al) & 1); + ++Ss; + } + if (Ss > Se) { + return true; + } + int abs_values[kDCTBlockSize]; + int eob = 0; + for (int k = Ss; k <= Se; k++) { + const coeff_t abs_val = std::abs(coeffs[kJPEGNaturalOrder[k]]); + abs_values[k] = abs_val >> Al; + if (abs_values[k] == 1) { + eob = k; + } + } + int r = 0; + int refinement_bits[kDCTBlockSize]; + size_t refinement_bits_count = 0; + for (int k = Ss; k <= Se; k++) { + if (abs_values[k] == 0) { + r++; + continue; + } + while (r > 15 && k <= eob) { + Flush(coding_state, bw); + WriteSymbol(0xf0, ac_huff, bw); + r -= 16; + for (size_t i = 0; i < refinement_bits_count; ++i) { + WriteBits(bw, 1, refinement_bits[i]); + } + refinement_bits_count = 0; + } + if (abs_values[k] > 1) { + refinement_bits[refinement_bits_count++] = abs_values[k] & 1u; + continue; + } + Flush(coding_state, bw); + int symbol = (r << 4u) + 1; + int new_non_zero_bit = (coeffs[kJPEGNaturalOrder[k]] < 0) ? 0 : 1; + WriteSymbol(symbol, ac_huff, bw); + WriteBits(bw, 1, new_non_zero_bit); + for (size_t i = 0; i < refinement_bits_count; ++i) { + WriteBits(bw, 1, refinement_bits[i]); + } + refinement_bits_count = 0; + r = 0; + } + if (r > 0 || refinement_bits_count) { + BufferEndOfBand(coding_state, ac_huff, refinement_bits, + refinement_bits_count, bw); + if (!eob_run_allowed) { + Flush(coding_state, bw); + } + } + return true; +} + +template <int kMode> +SerializationStatus JXL_NOINLINE DoEncodeScan(const JPEGData& jpg, + SerializationState* state) { + const JPEGScanInfo& scan_info = jpg.scan_info[state->scan_index]; + EncodeScanState& ss = state->scan_state; + + const int restart_interval = + state->seen_dri_marker ? jpg.restart_interval : 0; + + const auto get_next_extra_zero_run_index = [&ss, &scan_info]() -> int { + if (ss.extra_zero_runs_pos < scan_info.extra_zero_runs.size()) { + return scan_info.extra_zero_runs[ss.extra_zero_runs_pos].block_idx; + } else { + return -1; + } + }; + + const auto get_next_reset_point = [&ss, &scan_info]() -> int { + if (ss.next_reset_point_pos < scan_info.reset_points.size()) { + return scan_info.reset_points[ss.next_reset_point_pos++]; + } else { + return -1; + } + }; + + if (ss.stage == EncodeScanState::HEAD) { + if (!EncodeSOS(jpg, scan_info, state)) return SerializationStatus::ERROR; + JpegBitWriterInit(&ss.bw, &state->output_queue); + DCTCodingStateInit(&ss.coding_state); + ss.restarts_to_go = restart_interval; + ss.next_restart_marker = 0; + ss.block_scan_index = 0; + ss.extra_zero_runs_pos = 0; + ss.next_extra_zero_run_index = get_next_extra_zero_run_index(); + ss.next_reset_point_pos = 0; + ss.next_reset_point = get_next_reset_point(); + ss.mcu_y = 0; + memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff)); + ss.stage = EncodeScanState::BODY; + } + JpegBitWriter* bw = &ss.bw; + DCTCodingState* coding_state = &ss.coding_state; + + JXL_DASSERT(ss.stage == EncodeScanState::BODY); + + // "Non-interleaved" means color data comes in separate scans, in other words + // each scan can contain only one color component. + const bool is_interleaved = (scan_info.num_components > 1); + int MCUs_per_row = 0; + int MCU_rows = 0; + jpg.CalculateMcuSize(scan_info, &MCUs_per_row, &MCU_rows); + const bool is_progressive = state->is_progressive; + const int Al = is_progressive ? scan_info.Al : 0; + const int Ss = is_progressive ? scan_info.Ss : 0; + const int Se = is_progressive ? scan_info.Se : 63; + + // DC-only is defined by [0..0] spectral range. + const bool want_ac = ((Ss != 0) || (Se != 0)); + const bool want_dc = (Ss == 0); + // TODO(user): support streaming decoding again. + const bool complete_ac = true; + const bool has_ac = true; + if (want_ac && !has_ac) return SerializationStatus::NEEDS_MORE_INPUT; + + // |has_ac| implies |complete_dc| but not vice versa; for the sake of + // simplicity we pretend they are equal, because they are separated by just a + // few bytes of input. + const bool complete_dc = has_ac; + const bool complete = want_ac ? complete_ac : complete_dc; + // When "incomplete" |ac_dc| tracks information about current ("incomplete") + // band parsing progress. + + // FIXME: Is this always complete? + // const int last_mcu_y = + // complete ? MCU_rows : parsing_state.internal->ac_dc.next_mcu_y * + // v_group; + (void)complete; + const int last_mcu_y = complete ? MCU_rows : 0; + + for (; ss.mcu_y < last_mcu_y; ++ss.mcu_y) { + for (int mcu_x = 0; mcu_x < MCUs_per_row; ++mcu_x) { + // Possibly emit a restart marker. + if (restart_interval > 0 && ss.restarts_to_go == 0) { + Flush(coding_state, bw); + if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) { + return SerializationStatus::ERROR; + } + EmitMarker(bw, 0xD0 + ss.next_restart_marker); + ss.next_restart_marker += 1; + ss.next_restart_marker &= 0x7; + ss.restarts_to_go = restart_interval; + memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff)); + } + + // Encode one MCU + for (size_t i = 0; i < scan_info.num_components; ++i) { + const JPEGComponentScanInfo& si = scan_info.components[i]; + const JPEGComponent& c = jpg.components[si.comp_idx]; + size_t dc_tbl_idx = si.dc_tbl_idx; + size_t ac_tbl_idx = si.ac_tbl_idx; + HuffmanCodeTable* dc_huff = &state->dc_huff_table[dc_tbl_idx]; + HuffmanCodeTable* ac_huff = &state->ac_huff_table[ac_tbl_idx]; + if (want_dc && !dc_huff->initialized) { + return SerializationStatus::ERROR; + } + if (want_ac && !ac_huff->initialized) { + return SerializationStatus::ERROR; + } + int n_blocks_y = is_interleaved ? c.v_samp_factor : 1; + int n_blocks_x = is_interleaved ? c.h_samp_factor : 1; + for (int iy = 0; iy < n_blocks_y; ++iy) { + for (int ix = 0; ix < n_blocks_x; ++ix) { + int block_y = ss.mcu_y * n_blocks_y + iy; + int block_x = mcu_x * n_blocks_x + ix; + int block_idx = block_y * c.width_in_blocks + block_x; + if (ss.block_scan_index == ss.next_reset_point) { + Flush(coding_state, bw); + ss.next_reset_point = get_next_reset_point(); + } + int num_zero_runs = 0; + if (ss.block_scan_index == ss.next_extra_zero_run_index) { + num_zero_runs = scan_info.extra_zero_runs[ss.extra_zero_runs_pos] + .num_extra_zero_runs; + ++ss.extra_zero_runs_pos; + ss.next_extra_zero_run_index = get_next_extra_zero_run_index(); + } + const coeff_t* coeffs = &c.coeffs[block_idx << 6]; + bool ok; + // compressed size per block cannot be more than 512 bytes + Reserve(bw, 512); + if (kMode == 0) { + ok = EncodeDCTBlockSequential(coeffs, dc_huff, ac_huff, + num_zero_runs, + ss.last_dc_coeff + si.comp_idx, bw); + } else if (kMode == 1) { + ok = EncodeDCTBlockProgressive( + coeffs, dc_huff, ac_huff, Ss, Se, Al, num_zero_runs, + coding_state, ss.last_dc_coeff + si.comp_idx, bw); + } else { + ok = EncodeRefinementBits(coeffs, ac_huff, Ss, Se, Al, + coding_state, bw); + } + if (!ok) return SerializationStatus::ERROR; + ++ss.block_scan_index; + } + } + } + --ss.restarts_to_go; + } + } + if (ss.mcu_y < MCU_rows) { + if (!bw->healthy) return SerializationStatus::ERROR; + return SerializationStatus::NEEDS_MORE_INPUT; + } + Flush(coding_state, bw); + if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) { + return SerializationStatus::ERROR; + } + JpegBitWriterFinish(bw); + ss.stage = EncodeScanState::HEAD; + state->scan_index++; + if (!bw->healthy) return SerializationStatus::ERROR; + + return SerializationStatus::DONE; +} + +static SerializationStatus JXL_INLINE EncodeScan(const JPEGData& jpg, + SerializationState* state) { + const JPEGScanInfo& scan_info = jpg.scan_info[state->scan_index]; + const bool is_progressive = state->is_progressive; + const int Al = is_progressive ? scan_info.Al : 0; + const int Ah = is_progressive ? scan_info.Ah : 0; + const int Ss = is_progressive ? scan_info.Ss : 0; + const int Se = is_progressive ? scan_info.Se : 63; + const bool need_sequential = + !is_progressive || (Ah == 0 && Al == 0 && Ss == 0 && Se == 63); + if (need_sequential) { + return DoEncodeScan<0>(jpg, state); + } else if (Ah == 0) { + return DoEncodeScan<1>(jpg, state); + } else { + return DoEncodeScan<2>(jpg, state); + } +} + +SerializationStatus SerializeSection(uint8_t marker, SerializationState* state, + const JPEGData& jpg) { + const auto to_status = [](bool result) { + return result ? SerializationStatus::DONE : SerializationStatus::ERROR; + }; + // TODO(eustas): add and use marker enum + switch (marker) { + case 0xC0: + case 0xC1: + case 0xC2: + case 0xC9: + case 0xCA: + return to_status(EncodeSOF(jpg, marker, state)); + + case 0xC4: + return to_status(EncodeDHT(jpg, state)); + + case 0xD0: + case 0xD1: + case 0xD2: + case 0xD3: + case 0xD4: + case 0xD5: + case 0xD6: + case 0xD7: + return to_status(EncodeRestart(marker, state)); + + case 0xD9: + return to_status(EncodeEOI(jpg, state)); + + case 0xDA: + return EncodeScan(jpg, state); + + case 0xDB: + return to_status(EncodeDQT(jpg, state)); + + case 0xDD: + return to_status(EncodeDRI(jpg, state)); + + case 0xE0: + case 0xE1: + case 0xE2: + case 0xE3: + case 0xE4: + case 0xE5: + case 0xE6: + case 0xE7: + case 0xE8: + case 0xE9: + case 0xEA: + case 0xEB: + case 0xEC: + case 0xED: + case 0xEE: + case 0xEF: + return to_status(EncodeAPP(jpg, marker, state)); + + case 0xFE: + return to_status(EncodeCOM(jpg, state)); + + case 0xFF: + return to_status(EncodeInterMarkerData(jpg, state)); + + default: + return SerializationStatus::ERROR; + } +} + +// TODO(veluca): add streaming support again. +Status WriteJpegInternal(const JPEGData& jpg, const JPEGOutput& out, + SerializationState* ss) { + const auto maybe_push_output = [&]() -> Status { + if (ss->stage != SerializationState::STAGE_ERROR) { + while (!ss->output_queue.empty()) { + auto& chunk = ss->output_queue.front(); + size_t num_written = out(chunk.next, chunk.len); + if (num_written == 0 && chunk.len > 0) { + return StatusMessage(Status(StatusCode::kNotEnoughBytes), + "Failed to write output"); + } + chunk.len -= num_written; + if (chunk.len == 0) { + ss->output_queue.pop_front(); + } + } + } + return true; + }; + + while (true) { + switch (ss->stage) { + case SerializationState::STAGE_INIT: { + // Valid Brunsli requires, at least, 0xD9 marker. + // This might happen on corrupted stream, or on unconditioned JPEGData. + // TODO(eustas): check D9 in the only one and is the last one. + if (jpg.marker_order.empty()) { + ss->stage = SerializationState::STAGE_ERROR; + break; + } + ss->dc_huff_table.resize(kMaxHuffmanTables); + ss->ac_huff_table.resize(kMaxHuffmanTables); + if (jpg.has_zero_padding_bit) { + ss->pad_bits = jpg.padding_bits.data(); + ss->pad_bits_end = ss->pad_bits + jpg.padding_bits.size(); + } + + EncodeSOI(ss); + JXL_QUIET_RETURN_IF_ERROR(maybe_push_output()); + ss->stage = SerializationState::STAGE_SERIALIZE_SECTION; + break; + } + + case SerializationState::STAGE_SERIALIZE_SECTION: { + if (ss->section_index >= jpg.marker_order.size()) { + ss->stage = SerializationState::STAGE_DONE; + break; + } + uint8_t marker = jpg.marker_order[ss->section_index]; + SerializationStatus status = SerializeSection(marker, ss, jpg); + if (status == SerializationStatus::ERROR) { + JXL_WARNING("Failed to encode marker 0x%.2x", marker); + ss->stage = SerializationState::STAGE_ERROR; + break; + } + JXL_QUIET_RETURN_IF_ERROR(maybe_push_output()); + if (status == SerializationStatus::NEEDS_MORE_INPUT) { + return JXL_FAILURE("Incomplete serialization data"); + } else if (status != SerializationStatus::DONE) { + JXL_DASSERT(false); + ss->stage = SerializationState::STAGE_ERROR; + break; + } + ++ss->section_index; + break; + } + + case SerializationState::STAGE_DONE: + JXL_ASSERT(ss->output_queue.empty()); + if (ss->pad_bits != nullptr && ss->pad_bits != ss->pad_bits_end) { + return JXL_FAILURE("Invalid number of padding bits."); + } + return true; + + case SerializationState::STAGE_ERROR: + return JXL_FAILURE("JPEG serialization error"); + } + } +} + +} // namespace + +Status WriteJpeg(const JPEGData& jpg, const JPEGOutput& out) { + auto ss = jxl::make_unique<SerializationState>(); + return WriteJpegInternal(jpg, out, ss.get()); +} + +} // namespace jpeg +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.h b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.h new file mode 100644 index 0000000000..c6f70ff8b1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_data_writer.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Functions for writing a JPEGData object into a jpeg byte stream. + +#ifndef LIB_JXL_JPEG_DEC_JPEG_DATA_WRITER_H_ +#define LIB_JXL_JPEG_DEC_JPEG_DATA_WRITER_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <functional> + +#include "lib/jxl/jpeg/dec_jpeg_serialization_state.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +// Function type used to write len bytes into buf. Returns the number of bytes +// written. +using JPEGOutput = std::function<size_t(const uint8_t* buf, size_t len)>; + +Status WriteJpeg(const JPEGData& jpg, const JPEGOutput& out); + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_DEC_JPEG_DATA_WRITER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_output_chunk.h b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_output_chunk.h new file mode 100644 index 0000000000..e003c04952 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_output_chunk.h @@ -0,0 +1,72 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_JPEG_DEC_JPEG_OUTPUT_CHUNK_H_ +#define LIB_JXL_JPEG_DEC_JPEG_OUTPUT_CHUNK_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <initializer_list> +#include <memory> +#include <vector> + +namespace jxl { +namespace jpeg { + +/** + * A chunk of output data. + * + * Data producer creates OutputChunks and adds them to the end output queue. + * Once control flow leaves the producer code, it is considered that chunk of + * data is final and can not be changed; to underline this fact |next| is a + * const-pointer. + * + * Data consumer removes OutputChunks from the beginning of the output queue. + * It is possible to consume OutputChunks partially, by updating |next| and + * |len|. + * + * There are 2 types of output chunks: + * - owning: actual data is stored in |buffer| field; producer fills data after + * the instance it created; it is legal to reduce |len| to show that not all + * the capacity of |buffer| is used + * - non-owning: represents the data stored (owned) somewhere else + */ +struct OutputChunk { + // Non-owning + template <typename Bytes> + explicit OutputChunk(Bytes& bytes) : len(bytes.size()) { + // Deal both with const qualifier and data type. + const void* src = bytes.data(); + next = reinterpret_cast<const uint8_t*>(src); + } + + // Non-owning + OutputChunk(const uint8_t* data, size_t size) : next(data), len(size) {} + + // Owning + explicit OutputChunk(size_t size = 0) { + buffer.reset(new std::vector<uint8_t>(size)); + next = buffer->data(); + len = size; + } + + // Owning + OutputChunk(std::initializer_list<uint8_t> bytes) { + buffer.reset(new std::vector<uint8_t>(bytes)); + next = buffer->data(); + len = bytes.size(); + } + + const uint8_t* next; + size_t len; + // TODO(veluca): consider removing the unique_ptr. + std::unique_ptr<std::vector<uint8_t>> buffer; +}; + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_DEC_JPEG_OUTPUT_CHUNK_H_ diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_serialization_state.h b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_serialization_state.h new file mode 100644 index 0000000000..9950dc1744 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/dec_jpeg_serialization_state.h @@ -0,0 +1,101 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_JPEG_DEC_JPEG_SERIALIZATION_STATE_H_ +#define LIB_JXL_JPEG_DEC_JPEG_SERIALIZATION_STATE_H_ + +#include <algorithm> +#include <deque> +#include <vector> + +#include "lib/jxl/jpeg/dec_jpeg_output_chunk.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +struct HuffmanCodeTable { + int8_t depth[256]; + uint16_t code[256]; + bool initialized = false; + void InitDepths(int8_t value = 0) { + std::fill(std::begin(depth), std::end(depth), value); + } +}; + +// Handles the packing of bits into output bytes. +struct JpegBitWriter { + bool healthy; + std::deque<OutputChunk>* output; + OutputChunk chunk; + uint8_t* data; + size_t pos; + uint64_t put_buffer; + int put_bits; +}; + +// Holds data that is buffered between 8x8 blocks in progressive mode. +struct DCTCodingState { + // The run length of end-of-band symbols in a progressive scan. + int eob_run_; + // The huffman table to be used when flushing the state. + HuffmanCodeTable* cur_ac_huff_; + // The sequence of currently buffered refinement bits for a successive + // approximation scan (one where Ah > 0). + std::vector<uint16_t> refinement_bits_; + size_t refinement_bits_count_ = 0; +}; + +struct EncodeScanState { + enum Stage { HEAD, BODY }; + + Stage stage = HEAD; + + int mcu_y; + JpegBitWriter bw; + coeff_t last_dc_coeff[kMaxComponents] = {0}; + int restarts_to_go; + int next_restart_marker; + int block_scan_index; + DCTCodingState coding_state; + size_t extra_zero_runs_pos; + int next_extra_zero_run_index; + size_t next_reset_point_pos; + int next_reset_point; +}; + +struct SerializationState { + enum Stage { + STAGE_INIT, + STAGE_SERIALIZE_SECTION, + STAGE_DONE, + STAGE_ERROR, + }; + + Stage stage = STAGE_INIT; + + std::deque<OutputChunk> output_queue; + + size_t section_index = 0; + int dht_index = 0; + int dqt_index = 0; + int app_index = 0; + int com_index = 0; + int data_index = 0; + int scan_index = 0; + std::vector<HuffmanCodeTable> dc_huff_table; + std::vector<HuffmanCodeTable> ac_huff_table; + const uint8_t* pad_bits = nullptr; + const uint8_t* pad_bits_end = nullptr; + bool seen_dri_marker = false; + bool is_progressive = false; + + EncodeScanState scan_state; +}; + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_DEC_JPEG_SERIALIZATION_STATE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.cc b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.cc new file mode 100644 index 0000000000..97342553e5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.cc @@ -0,0 +1,404 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/enc_jpeg_data.h" + +#include <brotli/encode.h> + +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/jpeg/enc_jpeg_data_reader.h" +#include "lib/jxl/luminance.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { +namespace jpeg { + +namespace { + +constexpr int BITS_IN_JSAMPLE = 8; +using ByteSpan = Span<const uint8_t>; + +// TODO(eustas): move to jpeg_data, to use from codec_jpg as well. +// See if there is a canonically chunked ICC profile and mark corresponding +// app-tags with AppMarkerType::kICC. +Status DetectIccProfile(JPEGData& jpeg_data) { + JXL_DASSERT(jpeg_data.app_data.size() == jpeg_data.app_marker_type.size()); + size_t num_icc = 0; + size_t num_icc_jpeg = 0; + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + const auto& app = jpeg_data.app_data[i]; + size_t pos = 0; + if (app[pos++] != 0xE2) continue; + // At least APPn + size; otherwise it should be intermarker-data. + JXL_DASSERT(app.size() >= 3); + size_t tag_length = (app[pos] << 8) + app[pos + 1]; + pos += 2; + JXL_DASSERT(app.size() == tag_length + 1); + // Empty payload is 2 bytes for tag length itself + signature + if (tag_length < 2 + sizeof kIccProfileTag) continue; + + if (memcmp(&app[pos], kIccProfileTag, sizeof kIccProfileTag) != 0) continue; + pos += sizeof kIccProfileTag; + uint8_t chunk_id = app[pos++]; + uint8_t num_chunks = app[pos++]; + if (chunk_id != num_icc + 1) continue; + if (num_icc_jpeg == 0) num_icc_jpeg = num_chunks; + if (num_icc_jpeg != num_chunks) continue; + num_icc++; + jpeg_data.app_marker_type[i] = AppMarkerType::kICC; + } + if (num_icc != num_icc_jpeg) { + return JXL_FAILURE("Invalid ICC chunks"); + } + return true; +} + +bool GetMarkerPayload(const uint8_t* data, size_t size, ByteSpan* payload) { + if (size < 3) { + return false; + } + size_t hi = data[1]; + size_t lo = data[2]; + size_t internal_size = (hi << 8u) | lo; + // Second byte of marker is not counted towards size. + if (internal_size != size - 1) { + return false; + } + // cut second marker byte and "length" from payload. + *payload = ByteSpan(data, size); + payload->remove_prefix(3); + return true; +} + +Status DetectBlobs(jpeg::JPEGData& jpeg_data) { + JXL_DASSERT(jpeg_data.app_data.size() == jpeg_data.app_marker_type.size()); + bool have_exif = false, have_xmp = false; + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + auto& marker = jpeg_data.app_data[i]; + if (marker.empty() || marker[0] != kApp1) { + continue; + } + ByteSpan payload; + if (!GetMarkerPayload(marker.data(), marker.size(), &payload)) { + // Something is wrong with this marker; does not care. + continue; + } + if (!have_exif && payload.size() >= sizeof kExifTag && + !memcmp(payload.data(), kExifTag, sizeof kExifTag)) { + jpeg_data.app_marker_type[i] = AppMarkerType::kExif; + have_exif = true; + } + if (!have_xmp && payload.size() >= sizeof kXMPTag && + !memcmp(payload.data(), kXMPTag, sizeof kXMPTag)) { + jpeg_data.app_marker_type[i] = AppMarkerType::kXMP; + have_xmp = true; + } + } + return true; +} + +Status ParseChunkedMarker(const jpeg::JPEGData& src, uint8_t marker_type, + const ByteSpan& tag, IccBytes* output, + bool allow_permutations = false) { + output->clear(); + + std::vector<ByteSpan> chunks; + std::vector<bool> presence; + size_t expected_number_of_parts = 0; + bool is_first_chunk = true; + size_t ordinal = 0; + for (const auto& marker : src.app_data) { + if (marker.empty() || marker[0] != marker_type) { + continue; + } + ByteSpan payload; + if (!GetMarkerPayload(marker.data(), marker.size(), &payload)) { + // Something is wrong with this marker; does not care. + continue; + } + if ((payload.size() < tag.size()) || + memcmp(payload.data(), tag.data(), tag.size()) != 0) { + continue; + } + payload.remove_prefix(tag.size()); + if (payload.size() < 2) { + return JXL_FAILURE("Chunk is too small."); + } + uint8_t index = payload[0]; + uint8_t total = payload[1]; + ordinal++; + if (!allow_permutations) { + if (index != ordinal) return JXL_FAILURE("Invalid chunk order."); + } + + payload.remove_prefix(2); + + JXL_RETURN_IF_ERROR(total != 0); + if (is_first_chunk) { + is_first_chunk = false; + expected_number_of_parts = total; + // 1-based indices; 0-th element is added for convenience. + chunks.resize(total + 1); + presence.resize(total + 1); + } else { + JXL_RETURN_IF_ERROR(expected_number_of_parts == total); + } + + if (index == 0 || index > total) { + return JXL_FAILURE("Invalid chunk index."); + } + + if (presence[index]) { + return JXL_FAILURE("Duplicate chunk."); + } + presence[index] = true; + chunks[index] = payload; + } + + for (size_t i = 0; i < expected_number_of_parts; ++i) { + // 0-th element is not used. + size_t index = i + 1; + if (!presence[index]) { + return JXL_FAILURE("Missing chunk."); + } + chunks[index].AppendTo(output); + } + + return true; +} + +Status SetBlobsFromJpegData(const jpeg::JPEGData& jpeg_data, Blobs* blobs) { + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + auto& marker = jpeg_data.app_data[i]; + if (marker.empty() || marker[0] != kApp1) { + continue; + } + ByteSpan payload; + if (!GetMarkerPayload(marker.data(), marker.size(), &payload)) { + // Something is wrong with this marker; does not care. + continue; + } + if (payload.size() >= sizeof kExifTag && + !memcmp(payload.data(), kExifTag, sizeof kExifTag)) { + if (blobs->exif.empty()) { + blobs->exif.resize(payload.size() - sizeof kExifTag); + memcpy(blobs->exif.data(), payload.data() + sizeof kExifTag, + payload.size() - sizeof kExifTag); + } else { + JXL_WARNING( + "ReJPEG: multiple Exif blobs, storing only first one in the JPEG " + "XL container\n"); + } + } + if (payload.size() >= sizeof kXMPTag && + !memcmp(payload.data(), kXMPTag, sizeof kXMPTag)) { + if (blobs->xmp.empty()) { + blobs->xmp.resize(payload.size() - sizeof kXMPTag); + memcpy(blobs->xmp.data(), payload.data() + sizeof kXMPTag, + payload.size() - sizeof kXMPTag); + } else { + JXL_WARNING( + "ReJPEG: multiple XMP blobs, storing only first one in the JPEG " + "XL container\n"); + } + } + } + return true; +} + +static inline bool IsJPG(const Span<const uint8_t> bytes) { + return bytes.size() >= 2 && bytes[0] == 0xFF && bytes[1] == 0xD8; +} + +} // namespace + +void SetColorEncodingFromJpegData(const jpeg::JPEGData& jpg, + ColorEncoding* color_encoding) { + IccBytes icc_profile; + if (!ParseChunkedMarker(jpg, kApp2, ByteSpan(kIccProfileTag), &icc_profile)) { + JXL_WARNING("ReJPEG: corrupted ICC profile\n"); + icc_profile.clear(); + } + + if (icc_profile.empty()) { + bool is_gray = (jpg.components.size() == 1); + *color_encoding = ColorEncoding::SRGB(is_gray); + } else { + color_encoding->SetICCRaw(std::move(icc_profile)); + } +} + +Status SetChromaSubsamplingFromJpegData(const JPEGData& jpg, + YCbCrChromaSubsampling* cs) { + size_t nbcomp = jpg.components.size(); + if (nbcomp != 1 && nbcomp != 3) { + return JXL_FAILURE("Cannot recompress JPEGs with neither 1 nor 3 channels"); + } + if (nbcomp == 3) { + uint8_t hsample[3], vsample[3]; + for (size_t i = 0; i < nbcomp; i++) { + hsample[i] = jpg.components[i].h_samp_factor; + vsample[i] = jpg.components[i].v_samp_factor; + } + JXL_RETURN_IF_ERROR(cs->Set(hsample, vsample)); + } else if (nbcomp == 1) { + uint8_t hsample[3], vsample[3]; + for (size_t i = 0; i < 3; i++) { + hsample[i] = jpg.components[0].h_samp_factor; + vsample[i] = jpg.components[0].v_samp_factor; + } + JXL_RETURN_IF_ERROR(cs->Set(hsample, vsample)); + } + return true; +} + +Status SetColorTransformFromJpegData(const JPEGData& jpg, + ColorTransform* color_transform) { + size_t nbcomp = jpg.components.size(); + if (nbcomp != 1 && nbcomp != 3) { + return JXL_FAILURE("Cannot recompress JPEGs with neither 1 nor 3 channels"); + } + bool is_rgb = false; + { + const auto& markers = jpg.marker_order; + // If there is a JFIF marker, this is YCbCr. Otherwise... + if (std::find(markers.begin(), markers.end(), 0xE0) == markers.end()) { + // Try to find an 'Adobe' marker. + size_t app_markers = 0; + size_t i = 0; + for (; i < markers.size(); i++) { + // This is an APP marker. + if ((markers[i] & 0xF0) == 0xE0) { + JXL_CHECK(app_markers < jpg.app_data.size()); + // APP14 marker + if (markers[i] == 0xEE) { + const auto& data = jpg.app_data[app_markers]; + if (data.size() == 15 && data[3] == 'A' && data[4] == 'd' && + data[5] == 'o' && data[6] == 'b' && data[7] == 'e') { + // 'Adobe' marker. + is_rgb = data[14] == 0; + break; + } + } + app_markers++; + } + } + + if (i == markers.size()) { + // No 'Adobe' marker, guess from component IDs. + is_rgb = nbcomp == 3 && jpg.components[0].id == 'R' && + jpg.components[1].id == 'G' && jpg.components[2].id == 'B'; + } + } + } + *color_transform = + (!is_rgb || nbcomp == 1) ? ColorTransform::kYCbCr : ColorTransform::kNone; + return true; +} + +Status EncodeJPEGData(JPEGData& jpeg_data, std::vector<uint8_t>* bytes, + const CompressParams& cparams) { + bytes->clear(); + jpeg_data.app_marker_type.resize(jpeg_data.app_data.size(), + AppMarkerType::kUnknown); + JXL_RETURN_IF_ERROR(DetectIccProfile(jpeg_data)); + JXL_RETURN_IF_ERROR(DetectBlobs(jpeg_data)); + + size_t total_data = 0; + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + if (jpeg_data.app_marker_type[i] != AppMarkerType::kUnknown) { + continue; + } + total_data += jpeg_data.app_data[i].size(); + } + for (size_t i = 0; i < jpeg_data.com_data.size(); i++) { + total_data += jpeg_data.com_data[i].size(); + } + for (size_t i = 0; i < jpeg_data.inter_marker_data.size(); i++) { + total_data += jpeg_data.inter_marker_data[i].size(); + } + total_data += jpeg_data.tail_data.size(); + size_t brotli_capacity = BrotliEncoderMaxCompressedSize(total_data); + + BitWriter writer; + JXL_RETURN_IF_ERROR(Bundle::Write(jpeg_data, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + { + PaddedBytes serialized_jpeg_data = std::move(writer).TakeBytes(); + bytes->reserve(serialized_jpeg_data.size() + brotli_capacity); + Bytes(serialized_jpeg_data).AppendTo(bytes); + } + + BrotliEncoderState* brotli_enc = + BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); + int effort = cparams.brotli_effort; + if (effort < 0) effort = 11 - static_cast<int>(cparams.speed_tier); + BrotliEncoderSetParameter(brotli_enc, BROTLI_PARAM_QUALITY, effort); + size_t initial_size = bytes->size(); + BrotliEncoderSetParameter(brotli_enc, BROTLI_PARAM_SIZE_HINT, total_data); + bytes->resize(initial_size + brotli_capacity); + size_t enc_size = 0; + auto br_append = [&](const std::vector<uint8_t>& data, bool last) { + size_t available_in = data.size(); + const uint8_t* in = data.data(); + uint8_t* out = &(*bytes)[initial_size + enc_size]; + do { + uint8_t* out_before = out; + msan::MemoryIsInitialized(in, available_in); + JXL_CHECK(BrotliEncoderCompressStream( + brotli_enc, last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS, + &available_in, &in, &brotli_capacity, &out, &enc_size)); + msan::UnpoisonMemory(out_before, out - out_before); + } while (BrotliEncoderHasMoreOutput(brotli_enc) || available_in > 0); + }; + + for (size_t i = 0; i < jpeg_data.app_data.size(); i++) { + if (jpeg_data.app_marker_type[i] != AppMarkerType::kUnknown) { + continue; + } + br_append(jpeg_data.app_data[i], /*last=*/false); + } + for (size_t i = 0; i < jpeg_data.com_data.size(); i++) { + br_append(jpeg_data.com_data[i], /*last=*/false); + } + for (size_t i = 0; i < jpeg_data.inter_marker_data.size(); i++) { + br_append(jpeg_data.inter_marker_data[i], /*last=*/false); + } + br_append(jpeg_data.tail_data, /*last=*/true); + BrotliEncoderDestroyInstance(brotli_enc); + bytes->resize(initial_size + enc_size); + return true; +} + +Status DecodeImageJPG(const Span<const uint8_t> bytes, CodecInOut* io) { + if (!IsJPG(bytes)) return false; + io->frames.clear(); + io->frames.reserve(1); + io->frames.emplace_back(&io->metadata.m); + io->Main().jpeg_data = make_unique<jpeg::JPEGData>(); + jpeg::JPEGData* jpeg_data = io->Main().jpeg_data.get(); + if (!jpeg::ReadJpeg(bytes.data(), bytes.size(), jpeg::JpegReadMode::kReadAll, + jpeg_data)) { + return JXL_FAILURE("Error reading JPEG"); + } + SetColorEncodingFromJpegData(*jpeg_data, &io->metadata.m.color_encoding); + JXL_RETURN_IF_ERROR(SetBlobsFromJpegData(*jpeg_data, &io->blobs)); + JXL_RETURN_IF_ERROR(SetChromaSubsamplingFromJpegData( + *jpeg_data, &io->Main().chroma_subsampling)); + JXL_RETURN_IF_ERROR( + SetColorTransformFromJpegData(*jpeg_data, &io->Main().color_transform)); + + io->metadata.m.SetIntensityTarget(kDefaultIntensityTarget); + io->metadata.m.SetUintSamples(BITS_IN_JSAMPLE); + io->SetFromImage(Image3F(jpeg_data->width, jpeg_data->height), + io->metadata.m.color_encoding); + SetIntensityTarget(&io->metadata.m); + return true; +} + +} // namespace jpeg +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.h b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.h new file mode 100644 index 0000000000..f9a3a95e23 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data.h @@ -0,0 +1,42 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_JPEG_ENC_JPEG_DATA_H_ +#define LIB_JXL_JPEG_ENC_JPEG_DATA_H_ + +#include <cstdint> +#include <vector> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { + +class CodecInOut; + +namespace jpeg { +Status EncodeJPEGData(JPEGData& jpeg_data, std::vector<uint8_t>* bytes, + const CompressParams& cparams); + +void SetColorEncodingFromJpegData(const jpeg::JPEGData& jpg, + ColorEncoding* color_encoding); +Status SetChromaSubsamplingFromJpegData(const JPEGData& jpg, + YCbCrChromaSubsampling* cs); +Status SetColorTransformFromJpegData(const JPEGData& jpg, + ColorTransform* color_transform); + +/** + * Decodes bytes containing JPEG codestream into a CodecInOut as coefficients + * only, for lossless JPEG transcoding. + */ +Status DecodeImageJPG(Span<const uint8_t> bytes, CodecInOut* io); + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_ENC_JPEG_DATA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.cc b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.cc new file mode 100644 index 0000000000..ce64dae47b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.cc @@ -0,0 +1,1054 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/enc_jpeg_data_reader.h" + +#include <inttypes.h> +#include <string.h> + +#include <algorithm> +#include <string> +#include <vector> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/jpeg/enc_jpeg_huffman_decode.h" +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +namespace { +static const int kBrunsliMaxSampling = 15; + +// Macros for commonly used error conditions. + +#define JXL_JPEG_VERIFY_LEN(n) \ + if (*pos + (n) > len) { \ + return JXL_FAILURE("Unexpected end of input: pos=%" PRIuS \ + " need=%d len=%" PRIuS, \ + *pos, static_cast<int>(n), len); \ + } + +#define JXL_JPEG_VERIFY_INPUT(var, low, high, code) \ + if ((var) < (low) || (var) > (high)) { \ + return JXL_FAILURE("Invalid " #var ": %d", static_cast<int>(var)); \ + } + +#define JXL_JPEG_VERIFY_MARKER_END() \ + if (start_pos + marker_len != *pos) { \ + return JXL_FAILURE("Invalid marker length: declared=%" PRIuS \ + " actual=%" PRIuS, \ + marker_len, (*pos - start_pos)); \ + } + +#define JXL_JPEG_EXPECT_MARKER() \ + if (pos + 2 > len || data[pos] != 0xff) { \ + return JXL_FAILURE( \ + "Marker byte (0xff) expected, found: 0x%.2x pos=%" PRIuS \ + " len=%" PRIuS, \ + (pos < len ? data[pos] : 0), pos, len); \ + } + +inline int ReadUint8(const uint8_t* data, size_t* pos) { + return data[(*pos)++]; +} + +inline int ReadUint16(const uint8_t* data, size_t* pos) { + int v = (data[*pos] << 8) + data[*pos + 1]; + *pos += 2; + return v; +} + +// Reads the Start of Frame (SOF) marker segment and fills in *jpg with the +// parsed data. +bool ProcessSOF(const uint8_t* data, const size_t len, JpegReadMode mode, + size_t* pos, JPEGData* jpg) { + if (jpg->width != 0) { + return JXL_FAILURE("Duplicate SOF marker."); + } + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(8); + size_t marker_len = ReadUint16(data, pos); + int precision = ReadUint8(data, pos); + int height = ReadUint16(data, pos); + int width = ReadUint16(data, pos); + int num_components = ReadUint8(data, pos); + // 'jbrd' is hardcoded for 8bits: + JXL_JPEG_VERIFY_INPUT(precision, 8, 8, PRECISION); + JXL_JPEG_VERIFY_INPUT(height, 1, kMaxDimPixels, HEIGHT); + JXL_JPEG_VERIFY_INPUT(width, 1, kMaxDimPixels, WIDTH); + JXL_JPEG_VERIFY_INPUT(num_components, 1, kMaxComponents, NUMCOMP); + JXL_JPEG_VERIFY_LEN(3 * num_components); + jpg->height = height; + jpg->width = width; + jpg->components.resize(num_components); + + // Read sampling factors and quant table index for each component. + std::vector<bool> ids_seen(256, false); + int max_h_samp_factor = 1; + int max_v_samp_factor = 1; + for (size_t i = 0; i < jpg->components.size(); ++i) { + const int id = ReadUint8(data, pos); + if (ids_seen[id]) { // (cf. section B.2.2, syntax of Ci) + return JXL_FAILURE("Duplicate ID %d in SOF.", id); + } + ids_seen[id] = true; + jpg->components[i].id = id; + int factor = ReadUint8(data, pos); + int h_samp_factor = factor >> 4; + int v_samp_factor = factor & 0xf; + JXL_JPEG_VERIFY_INPUT(h_samp_factor, 1, kBrunsliMaxSampling, SAMP_FACTOR); + JXL_JPEG_VERIFY_INPUT(v_samp_factor, 1, kBrunsliMaxSampling, SAMP_FACTOR); + jpg->components[i].h_samp_factor = h_samp_factor; + jpg->components[i].v_samp_factor = v_samp_factor; + jpg->components[i].quant_idx = ReadUint8(data, pos); + max_h_samp_factor = std::max(max_h_samp_factor, h_samp_factor); + max_v_samp_factor = std::max(max_v_samp_factor, v_samp_factor); + } + + // We have checked above that none of the sampling factors are 0, so the max + // sampling factors can not be 0. + int MCU_rows = DivCeil(jpg->height, max_v_samp_factor * 8); + int MCU_cols = DivCeil(jpg->width, max_h_samp_factor * 8); + // Compute the block dimensions for each component. + for (size_t i = 0; i < jpg->components.size(); ++i) { + JPEGComponent* c = &jpg->components[i]; + if (max_h_samp_factor % c->h_samp_factor != 0 || + max_v_samp_factor % c->v_samp_factor != 0) { + return JXL_FAILURE("Non-integral subsampling ratios."); + } + c->width_in_blocks = MCU_cols * c->h_samp_factor; + c->height_in_blocks = MCU_rows * c->v_samp_factor; + const uint64_t num_blocks = + static_cast<uint64_t>(c->width_in_blocks) * c->height_in_blocks; + if (mode == JpegReadMode::kReadAll) { + c->coeffs.resize(num_blocks * kDCTBlockSize); + } + } + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Reads the Start of Scan (SOS) marker segment and fills in *scan_info with the +// parsed data. +bool ProcessSOS(const uint8_t* data, const size_t len, size_t* pos, + JPEGData* jpg) { + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(3); + size_t marker_len = ReadUint16(data, pos); + size_t comps_in_scan = ReadUint8(data, pos); + JXL_JPEG_VERIFY_INPUT(comps_in_scan, 1, jpg->components.size(), + COMPS_IN_SCAN); + + JPEGScanInfo scan_info; + scan_info.num_components = comps_in_scan; + JXL_JPEG_VERIFY_LEN(2 * comps_in_scan); + std::vector<bool> ids_seen(256, false); + for (size_t i = 0; i < comps_in_scan; ++i) { + uint32_t id = ReadUint8(data, pos); + if (ids_seen[id]) { // (cf. section B.2.3, regarding CSj) + return JXL_FAILURE("Duplicate ID %d in SOS.", id); + } + ids_seen[id] = true; + bool found_index = false; + for (size_t j = 0; j < jpg->components.size(); ++j) { + if (jpg->components[j].id == id) { + scan_info.components[i].comp_idx = j; + found_index = true; + } + } + if (!found_index) { + return JXL_FAILURE("SOS marker: Could not find component with id %d", id); + } + int c = ReadUint8(data, pos); + int dc_tbl_idx = c >> 4; + int ac_tbl_idx = c & 0xf; + JXL_JPEG_VERIFY_INPUT(dc_tbl_idx, 0, 3, HUFFMAN_INDEX); + JXL_JPEG_VERIFY_INPUT(ac_tbl_idx, 0, 3, HUFFMAN_INDEX); + scan_info.components[i].dc_tbl_idx = dc_tbl_idx; + scan_info.components[i].ac_tbl_idx = ac_tbl_idx; + } + JXL_JPEG_VERIFY_LEN(3); + scan_info.Ss = ReadUint8(data, pos); + scan_info.Se = ReadUint8(data, pos); + JXL_JPEG_VERIFY_INPUT(static_cast<int>(scan_info.Ss), 0, 63, START_OF_SCAN); + JXL_JPEG_VERIFY_INPUT(scan_info.Se, scan_info.Ss, 63, END_OF_SCAN); + int c = ReadUint8(data, pos); + scan_info.Ah = c >> 4; + scan_info.Al = c & 0xf; + if (scan_info.Ah != 0 && scan_info.Al != scan_info.Ah - 1) { + // section G.1.1.1.2 : Successive approximation control only improves + // by one bit at a time. But it's not always respected, so we just issue + // a warning. + JXL_WARNING("Invalid progressive parameters: Al=%d Ah=%d", scan_info.Al, + scan_info.Ah); + } + // Check that all the Huffman tables needed for this scan are defined. + for (size_t i = 0; i < comps_in_scan; ++i) { + bool found_dc_table = false; + bool found_ac_table = false; + for (size_t j = 0; j < jpg->huffman_code.size(); ++j) { + uint32_t slot_id = jpg->huffman_code[j].slot_id; + if (slot_id == scan_info.components[i].dc_tbl_idx) { + found_dc_table = true; + } else if (slot_id == scan_info.components[i].ac_tbl_idx + 16) { + found_ac_table = true; + } + } + if (scan_info.Ss == 0 && !found_dc_table) { + return JXL_FAILURE( + "SOS marker: Could not find DC Huffman table with index %d", + scan_info.components[i].dc_tbl_idx); + } + if (scan_info.Se > 0 && !found_ac_table) { + return JXL_FAILURE( + "SOS marker: Could not find AC Huffman table with index %d", + scan_info.components[i].ac_tbl_idx); + } + } + jpg->scan_info.push_back(scan_info); + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Reads the Define Huffman Table (DHT) marker segment and fills in *jpg with +// the parsed data. Builds the Huffman decoding table in either dc_huff_lut or +// ac_huff_lut, depending on the type and solt_id of Huffman code being read. +bool ProcessDHT(const uint8_t* data, const size_t len, JpegReadMode mode, + std::vector<HuffmanTableEntry>* dc_huff_lut, + std::vector<HuffmanTableEntry>* ac_huff_lut, size_t* pos, + JPEGData* jpg) { + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(2); + size_t marker_len = ReadUint16(data, pos); + if (marker_len == 2) { + return JXL_FAILURE("DHT marker: no Huffman table found"); + } + while (*pos < start_pos + marker_len) { + JXL_JPEG_VERIFY_LEN(1 + kJpegHuffmanMaxBitLength); + JPEGHuffmanCode huff; + huff.slot_id = ReadUint8(data, pos); + int huffman_index = huff.slot_id; + int is_ac_table = (huff.slot_id & 0x10) != 0; + HuffmanTableEntry* huff_lut; + if (is_ac_table) { + huffman_index -= 0x10; + JXL_JPEG_VERIFY_INPUT(huffman_index, 0, 3, HUFFMAN_INDEX); + huff_lut = &(*ac_huff_lut)[huffman_index * kJpegHuffmanLutSize]; + } else { + JXL_JPEG_VERIFY_INPUT(huffman_index, 0, 3, HUFFMAN_INDEX); + huff_lut = &(*dc_huff_lut)[huffman_index * kJpegHuffmanLutSize]; + } + huff.counts[0] = 0; + int total_count = 0; + int space = 1 << kJpegHuffmanMaxBitLength; + int max_depth = 1; + for (size_t i = 1; i <= kJpegHuffmanMaxBitLength; ++i) { + int count = ReadUint8(data, pos); + if (count != 0) { + max_depth = i; + } + huff.counts[i] = count; + total_count += count; + space -= count * (1 << (kJpegHuffmanMaxBitLength - i)); + } + if (is_ac_table) { + JXL_JPEG_VERIFY_INPUT(total_count, 0, kJpegHuffmanAlphabetSize, + HUFFMAN_CODE); + } else { + JXL_JPEG_VERIFY_INPUT(total_count, 0, kJpegDCAlphabetSize, HUFFMAN_CODE); + } + JXL_JPEG_VERIFY_LEN(total_count); + std::vector<bool> values_seen(256, false); + for (int i = 0; i < total_count; ++i) { + int value = ReadUint8(data, pos); + if (!is_ac_table) { + JXL_JPEG_VERIFY_INPUT(value, 0, kJpegDCAlphabetSize - 1, HUFFMAN_CODE); + } + if (values_seen[value]) { + return JXL_FAILURE("Duplicate Huffman code value %d", value); + } + values_seen[value] = true; + huff.values[i] = value; + } + // Add an invalid symbol that will have the all 1 code. + ++huff.counts[max_depth]; + huff.values[total_count] = kJpegHuffmanAlphabetSize; + space -= (1 << (kJpegHuffmanMaxBitLength - max_depth)); + if (space < 0) { + return JXL_FAILURE("Invalid Huffman code lengths."); + } else if (space > 0 && huff_lut[0].value != 0xffff) { + // Re-initialize the values to an invalid symbol so that we can recognize + // it when reading the bit stream using a Huffman code with space > 0. + for (int i = 0; i < kJpegHuffmanLutSize; ++i) { + huff_lut[i].bits = 0; + huff_lut[i].value = 0xffff; + } + } + huff.is_last = (*pos == start_pos + marker_len); + if (mode == JpegReadMode::kReadAll) { + BuildJpegHuffmanTable(&huff.counts[0], &huff.values[0], huff_lut); + } + jpg->huffman_code.push_back(huff); + } + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Reads the Define Quantization Table (DQT) marker segment and fills in *jpg +// with the parsed data. +bool ProcessDQT(const uint8_t* data, const size_t len, size_t* pos, + JPEGData* jpg) { + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(2); + size_t marker_len = ReadUint16(data, pos); + if (marker_len == 2) { + return JXL_FAILURE("DQT marker: no quantization table found"); + } + while (*pos < start_pos + marker_len && jpg->quant.size() < kMaxQuantTables) { + JXL_JPEG_VERIFY_LEN(1); + int quant_table_index = ReadUint8(data, pos); + int quant_table_precision = quant_table_index >> 4; + JXL_JPEG_VERIFY_INPUT(quant_table_precision, 0, 1, QUANT_TBL_PRECISION); + quant_table_index &= 0xf; + JXL_JPEG_VERIFY_INPUT(quant_table_index, 0, 3, QUANT_TBL_INDEX); + JXL_JPEG_VERIFY_LEN((quant_table_precision + 1) * kDCTBlockSize); + JPEGQuantTable table; + table.index = quant_table_index; + table.precision = quant_table_precision; + for (size_t i = 0; i < kDCTBlockSize; ++i) { + int quant_val = + quant_table_precision ? ReadUint16(data, pos) : ReadUint8(data, pos); + JXL_JPEG_VERIFY_INPUT(quant_val, 1, 65535, QUANT_VAL); + table.values[kJPEGNaturalOrder[i]] = quant_val; + } + table.is_last = (*pos == start_pos + marker_len); + jpg->quant.push_back(table); + } + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Reads the DRI marker and saves the restart interval into *jpg. +bool ProcessDRI(const uint8_t* data, const size_t len, size_t* pos, + bool* found_dri, JPEGData* jpg) { + if (*found_dri) { + return JXL_FAILURE("Duplicate DRI marker."); + } + *found_dri = true; + const size_t start_pos = *pos; + JXL_JPEG_VERIFY_LEN(4); + size_t marker_len = ReadUint16(data, pos); + int restart_interval = ReadUint16(data, pos); + jpg->restart_interval = restart_interval; + JXL_JPEG_VERIFY_MARKER_END(); + return true; +} + +// Saves the APP marker segment as a string to *jpg. +bool ProcessAPP(const uint8_t* data, const size_t len, size_t* pos, + JPEGData* jpg) { + JXL_JPEG_VERIFY_LEN(2); + size_t marker_len = ReadUint16(data, pos); + JXL_JPEG_VERIFY_INPUT(marker_len, 2, 65535, MARKER_LEN); + JXL_JPEG_VERIFY_LEN(marker_len - 2); + JXL_DASSERT(*pos >= 3); + // Save the marker type together with the app data. + const uint8_t* app_str_start = data + *pos - 3; + std::vector<uint8_t> app_str(app_str_start, app_str_start + marker_len + 1); + *pos += marker_len - 2; + jpg->app_data.push_back(app_str); + return true; +} + +// Saves the COM marker segment as a string to *jpg. +bool ProcessCOM(const uint8_t* data, const size_t len, size_t* pos, + JPEGData* jpg) { + JXL_JPEG_VERIFY_LEN(2); + size_t marker_len = ReadUint16(data, pos); + JXL_JPEG_VERIFY_INPUT(marker_len, 2, 65535, MARKER_LEN); + JXL_JPEG_VERIFY_LEN(marker_len - 2); + const uint8_t* com_str_start = data + *pos - 3; + std::vector<uint8_t> com_str(com_str_start, com_str_start + marker_len + 1); + *pos += marker_len - 2; + jpg->com_data.push_back(com_str); + return true; +} + +// Helper structure to read bits from the entropy coded data segment. +struct BitReaderState { + BitReaderState(const uint8_t* data, const size_t len, size_t pos) + : data_(data), len_(len) { + Reset(pos); + } + + void Reset(size_t pos) { + pos_ = pos; + val_ = 0; + bits_left_ = 0; + next_marker_pos_ = len_ - 2; + FillBitWindow(); + } + + // Returns the next byte and skips the 0xff/0x00 escape sequences. + uint8_t GetNextByte() { + if (pos_ >= next_marker_pos_) { + ++pos_; + return 0; + } + uint8_t c = data_[pos_++]; + if (c == 0xff) { + uint8_t escape = data_[pos_]; + if (escape == 0) { + ++pos_; + } else { + // 0xff was followed by a non-zero byte, which means that we found the + // start of the next marker segment. + next_marker_pos_ = pos_ - 1; + } + } + return c; + } + + void FillBitWindow() { + if (bits_left_ <= 16) { + while (bits_left_ <= 56) { + val_ <<= 8; + val_ |= (uint64_t)GetNextByte(); + bits_left_ += 8; + } + } + } + + int ReadBits(int nbits) { + FillBitWindow(); + uint64_t val = (val_ >> (bits_left_ - nbits)) & ((1ULL << nbits) - 1); + bits_left_ -= nbits; + return val; + } + + // Sets *pos to the next stream position where parsing should continue. + // Enqueue the padding bits seen (0 or 1). + // Returns false if there is inconsistent or invalid padding or the stream + // ended too early. + bool FinishStream(JPEGData* jpg, size_t* pos) { + int npadbits = bits_left_ & 7; + if (npadbits > 0) { + uint64_t padmask = (1ULL << npadbits) - 1; + uint64_t padbits = (val_ >> (bits_left_ - npadbits)) & padmask; + if (padbits != padmask) { + jpg->has_zero_padding_bit = true; + } + for (int i = npadbits - 1; i >= 0; --i) { + jpg->padding_bits.push_back((padbits >> i) & 1); + } + } + // Give back some bytes that we did not use. + int unused_bytes_left = bits_left_ >> 3; + while (unused_bytes_left-- > 0) { + --pos_; + // If we give back a 0 byte, we need to check if it was a 0xff/0x00 escape + // sequence, and if yes, we need to give back one more byte. + if (pos_ < next_marker_pos_ && data_[pos_] == 0 && + data_[pos_ - 1] == 0xff) { + --pos_; + } + } + if (pos_ > next_marker_pos_) { + // Data ran out before the scan was complete. + return JXL_FAILURE("Unexpected end of scan."); + } + *pos = pos_; + return true; + } + + const uint8_t* data_; + const size_t len_; + size_t pos_; + uint64_t val_; + int bits_left_; + size_t next_marker_pos_; +}; + +// Returns the next Huffman-coded symbol. +int ReadSymbol(const HuffmanTableEntry* table, BitReaderState* br) { + int nbits; + br->FillBitWindow(); + int val = (br->val_ >> (br->bits_left_ - 8)) & 0xff; + table += val; + nbits = table->bits - 8; + if (nbits > 0) { + br->bits_left_ -= 8; + table += table->value; + val = (br->val_ >> (br->bits_left_ - nbits)) & ((1 << nbits) - 1); + table += val; + } + br->bits_left_ -= table->bits; + return table->value; +} + +/** + * Returns the DC diff or AC value for extra bits value x and prefix code s. + * + * CCITT Rec. T.81 (1992 E) + * Table F.1 – Difference magnitude categories for DC coding + * SSSS | DIFF values + * ------+-------------------------- + * 0 | 0 + * 1 | –1, 1 + * 2 | –3, –2, 2, 3 + * 3 | –7..–4, 4..7 + * ......|.......................... + * 11 | –2047..–1024, 1024..2047 + * + * CCITT Rec. T.81 (1992 E) + * Table F.2 – Categories assigned to coefficient values + * [ Same as Table F.1, but does not include SSSS equal to 0 and 11] + * + * + * CCITT Rec. T.81 (1992 E) + * F.1.2.1.1 Structure of DC code table + * For each category,... additional bits... appended... to uniquely identify + * which difference... occurred... When DIFF is positive... SSSS... bits of DIFF + * are appended. When DIFF is negative... SSSS... bits of (DIFF – 1) are + * appended... Most significant bit... is 0 for negative differences and 1 for + * positive differences. + * + * In other words the upper half of extra bits range represents DIFF as is. + * The lower half represents the negative DIFFs with an offset. + */ +int HuffExtend(int x, int s) { + JXL_DASSERT(s >= 1); + int half = 1 << (s - 1); + if (x >= half) { + JXL_DASSERT(x < (1 << s)); + return x; + } else { + return x - (1 << s) + 1; + } +} + +// Decodes one 8x8 block of DCT coefficients from the bit stream. +bool DecodeDCTBlock(const HuffmanTableEntry* dc_huff, + const HuffmanTableEntry* ac_huff, int Ss, int Se, int Al, + int* eobrun, bool* reset_state, int* num_zero_runs, + BitReaderState* br, JPEGData* jpg, coeff_t* last_dc_coeff, + coeff_t* coeffs) { + // Nowadays multiplication is even faster than variable shift. + int Am = 1 << Al; + bool eobrun_allowed = Ss > 0; + if (Ss == 0) { + int s = ReadSymbol(dc_huff, br); + if (s >= kJpegDCAlphabetSize) { + return JXL_FAILURE("Invalid Huffman symbol %d for DC coefficient.", s); + } + int diff = 0; + if (s > 0) { + int bits = br->ReadBits(s); + diff = HuffExtend(bits, s); + } + int coeff = diff + *last_dc_coeff; + const int dc_coeff = coeff * Am; + coeffs[0] = dc_coeff; + // TODO(eustas): is there a more elegant / explicit way to check this? + if (dc_coeff != coeffs[0]) { + return JXL_FAILURE("Invalid DC coefficient %d", dc_coeff); + } + *last_dc_coeff = coeff; + ++Ss; + } + if (Ss > Se) { + return true; + } + if (*eobrun > 0) { + --(*eobrun); + return true; + } + *num_zero_runs = 0; + for (int k = Ss; k <= Se; k++) { + int sr = ReadSymbol(ac_huff, br); + if (sr >= kJpegHuffmanAlphabetSize) { + return JXL_FAILURE("Invalid Huffman symbol %d for AC coefficient %d", sr, + k); + } + int r = sr >> 4; + int s = sr & 15; + if (s > 0) { + k += r; + if (k > Se) { + return JXL_FAILURE("Out-of-band coefficient %d band was %d-%d", k, Ss, + Se); + } + if (s + Al >= kJpegDCAlphabetSize) { + return JXL_FAILURE( + "Out of range AC coefficient value: s = %d Al = %d k = %d", s, Al, + k); + } + int bits = br->ReadBits(s); + int coeff = HuffExtend(bits, s); + coeffs[kJPEGNaturalOrder[k]] = coeff * Am; + *num_zero_runs = 0; + } else if (r == 15) { + k += 15; + ++(*num_zero_runs); + } else { + if (eobrun_allowed && k == Ss && *eobrun == 0) { + // We have two end-of-block runs right after each other, so we signal + // the jpeg encoder to force a state reset at this point. + *reset_state = true; + } + *eobrun = 1 << r; + if (r > 0) { + if (!eobrun_allowed) { + return JXL_FAILURE("End-of-block run crossing DC coeff."); + } + *eobrun += br->ReadBits(r); + } + break; + } + } + --(*eobrun); + return true; +} + +bool RefineDCTBlock(const HuffmanTableEntry* ac_huff, int Ss, int Se, int Al, + int* eobrun, bool* reset_state, BitReaderState* br, + JPEGData* jpg, coeff_t* coeffs) { + // Nowadays multiplication is even faster than variable shift. + int Am = 1 << Al; + bool eobrun_allowed = Ss > 0; + if (Ss == 0) { + int s = br->ReadBits(1); + coeff_t dc_coeff = coeffs[0]; + dc_coeff |= s * Am; + coeffs[0] = dc_coeff; + ++Ss; + } + if (Ss > Se) { + return true; + } + int p1 = Am; + int m1 = -Am; + int k = Ss; + int r; + int s; + bool in_zero_run = false; + if (*eobrun <= 0) { + for (; k <= Se; k++) { + s = ReadSymbol(ac_huff, br); + if (s >= kJpegHuffmanAlphabetSize) { + return JXL_FAILURE("Invalid Huffman symbol %d for AC coefficient %d", s, + k); + } + r = s >> 4; + s &= 15; + if (s) { + if (s != 1) { + return JXL_FAILURE("Invalid Huffman symbol %d for AC coefficient %d", + s, k); + } + s = br->ReadBits(1) ? p1 : m1; + in_zero_run = false; + } else { + if (r != 15) { + if (eobrun_allowed && k == Ss && *eobrun == 0) { + // We have two end-of-block runs right after each other, so we + // signal the jpeg encoder to force a state reset at this point. + *reset_state = true; + } + *eobrun = 1 << r; + if (r > 0) { + if (!eobrun_allowed) { + return JXL_FAILURE("End-of-block run crossing DC coeff."); + } + *eobrun += br->ReadBits(r); + } + break; + } + in_zero_run = true; + } + do { + coeff_t thiscoef = coeffs[kJPEGNaturalOrder[k]]; + if (thiscoef != 0) { + if (br->ReadBits(1)) { + if ((thiscoef & p1) == 0) { + if (thiscoef >= 0) { + thiscoef += p1; + } else { + thiscoef += m1; + } + } + } + coeffs[kJPEGNaturalOrder[k]] = thiscoef; + } else { + if (--r < 0) { + break; + } + } + k++; + } while (k <= Se); + if (s) { + if (k > Se) { + return JXL_FAILURE("Out-of-band coefficient %d band was %d-%d", k, Ss, + Se); + } + coeffs[kJPEGNaturalOrder[k]] = s; + } + } + } + if (in_zero_run) { + return JXL_FAILURE("Extra zero run before end-of-block."); + } + if (*eobrun > 0) { + for (; k <= Se; k++) { + coeff_t thiscoef = coeffs[kJPEGNaturalOrder[k]]; + if (thiscoef != 0) { + if (br->ReadBits(1)) { + if ((thiscoef & p1) == 0) { + if (thiscoef >= 0) { + thiscoef += p1; + } else { + thiscoef += m1; + } + } + } + coeffs[kJPEGNaturalOrder[k]] = thiscoef; + } + } + } + --(*eobrun); + return true; +} + +bool ProcessRestart(const uint8_t* data, const size_t len, + int* next_restart_marker, BitReaderState* br, + JPEGData* jpg) { + size_t pos = 0; + if (!br->FinishStream(jpg, &pos)) { + return JXL_FAILURE("Invalid scan"); + } + int expected_marker = 0xd0 + *next_restart_marker; + JXL_JPEG_EXPECT_MARKER(); + int marker = data[pos + 1]; + if (marker != expected_marker) { + return JXL_FAILURE("Did not find expected restart marker %d actual %d", + expected_marker, marker); + } + br->Reset(pos + 2); + *next_restart_marker += 1; + *next_restart_marker &= 0x7; + return true; +} + +bool ProcessScan(const uint8_t* data, const size_t len, + const std::vector<HuffmanTableEntry>& dc_huff_lut, + const std::vector<HuffmanTableEntry>& ac_huff_lut, + uint16_t scan_progression[kMaxComponents][kDCTBlockSize], + bool is_progressive, size_t* pos, JPEGData* jpg) { + if (!ProcessSOS(data, len, pos, jpg)) { + return false; + } + JPEGScanInfo* scan_info = &jpg->scan_info.back(); + bool is_interleaved = (scan_info->num_components > 1); + int max_h_samp_factor = 1; + int max_v_samp_factor = 1; + for (size_t i = 0; i < jpg->components.size(); ++i) { + max_h_samp_factor = + std::max(max_h_samp_factor, jpg->components[i].h_samp_factor); + max_v_samp_factor = + std::max(max_v_samp_factor, jpg->components[i].v_samp_factor); + } + + int MCU_rows = DivCeil(jpg->height, max_v_samp_factor * 8); + int MCUs_per_row = DivCeil(jpg->width, max_h_samp_factor * 8); + if (!is_interleaved) { + const JPEGComponent& c = jpg->components[scan_info->components[0].comp_idx]; + MCUs_per_row = DivCeil(jpg->width * c.h_samp_factor, 8 * max_h_samp_factor); + MCU_rows = DivCeil(jpg->height * c.v_samp_factor, 8 * max_v_samp_factor); + } + coeff_t last_dc_coeff[kMaxComponents] = {0}; + BitReaderState br(data, len, *pos); + int restarts_to_go = jpg->restart_interval; + int next_restart_marker = 0; + int eobrun = -1; + int block_scan_index = 0; + const int Al = is_progressive ? scan_info->Al : 0; + const int Ah = is_progressive ? scan_info->Ah : 0; + const int Ss = is_progressive ? scan_info->Ss : 0; + const int Se = is_progressive ? scan_info->Se : 63; + const uint16_t scan_bitmask = Ah == 0 ? (0xffff << Al) : (1u << Al); + const uint16_t refinement_bitmask = (1 << Al) - 1; + for (size_t i = 0; i < scan_info->num_components; ++i) { + int comp_idx = scan_info->components[i].comp_idx; + for (int k = Ss; k <= Se; ++k) { + if (scan_progression[comp_idx][k] & scan_bitmask) { + return JXL_FAILURE( + "Overlapping scans: component=%d k=%d prev_mask: %u cur_mask %u", + comp_idx, k, scan_progression[i][k], scan_bitmask); + } + if (scan_progression[comp_idx][k] & refinement_bitmask) { + return JXL_FAILURE( + "Invalid scan order, a more refined scan was already done: " + "component=%d k=%d prev_mask=%u cur_mask=%u", + comp_idx, k, scan_progression[i][k], scan_bitmask); + } + scan_progression[comp_idx][k] |= scan_bitmask; + } + } + if (Al > 10) { + return JXL_FAILURE("Scan parameter Al=%d is not supported.", Al); + } + for (int mcu_y = 0; mcu_y < MCU_rows; ++mcu_y) { + for (int mcu_x = 0; mcu_x < MCUs_per_row; ++mcu_x) { + // Handle the restart intervals. + if (jpg->restart_interval > 0) { + if (restarts_to_go == 0) { + if (ProcessRestart(data, len, &next_restart_marker, &br, jpg)) { + restarts_to_go = jpg->restart_interval; + memset(static_cast<void*>(last_dc_coeff), 0, sizeof(last_dc_coeff)); + if (eobrun > 0) { + return JXL_FAILURE("End-of-block run too long."); + } + eobrun = -1; // fresh start + } else { + return JXL_FAILURE("Could not process restart."); + } + } + --restarts_to_go; + } + // Decode one MCU. + for (size_t i = 0; i < scan_info->num_components; ++i) { + JPEGComponentScanInfo* si = &scan_info->components[i]; + JPEGComponent* c = &jpg->components[si->comp_idx]; + const HuffmanTableEntry* dc_lut = + &dc_huff_lut[si->dc_tbl_idx * kJpegHuffmanLutSize]; + const HuffmanTableEntry* ac_lut = + &ac_huff_lut[si->ac_tbl_idx * kJpegHuffmanLutSize]; + int nblocks_y = is_interleaved ? c->v_samp_factor : 1; + int nblocks_x = is_interleaved ? c->h_samp_factor : 1; + for (int iy = 0; iy < nblocks_y; ++iy) { + for (int ix = 0; ix < nblocks_x; ++ix) { + int block_y = mcu_y * nblocks_y + iy; + int block_x = mcu_x * nblocks_x + ix; + int block_idx = block_y * c->width_in_blocks + block_x; + bool reset_state = false; + int num_zero_runs = 0; + coeff_t* coeffs = &c->coeffs[block_idx * kDCTBlockSize]; + if (Ah == 0) { + if (!DecodeDCTBlock(dc_lut, ac_lut, Ss, Se, Al, &eobrun, + &reset_state, &num_zero_runs, &br, jpg, + &last_dc_coeff[si->comp_idx], coeffs)) { + return false; + } + } else { + if (!RefineDCTBlock(ac_lut, Ss, Se, Al, &eobrun, &reset_state, + &br, jpg, coeffs)) { + return false; + } + } + if (reset_state) { + scan_info->reset_points.emplace_back(block_scan_index); + } + if (num_zero_runs > 0) { + JPEGScanInfo::ExtraZeroRunInfo info; + info.block_idx = block_scan_index; + info.num_extra_zero_runs = num_zero_runs; + scan_info->extra_zero_runs.push_back(info); + } + ++block_scan_index; + } + } + } + } + } + if (eobrun > 0) { + return JXL_FAILURE("End-of-block run too long."); + } + if (!br.FinishStream(jpg, pos)) { + return JXL_FAILURE("Invalid scan."); + } + if (*pos > len) { + return JXL_FAILURE("Unexpected end of file during scan. pos=%" PRIuS + " len=%" PRIuS, + *pos, len); + } + return true; +} + +// Changes the quant_idx field of the components to refer to the index of the +// quant table in the jpg->quant array. +bool FixupIndexes(JPEGData* jpg) { + for (size_t i = 0; i < jpg->components.size(); ++i) { + JPEGComponent* c = &jpg->components[i]; + bool found_index = false; + for (size_t j = 0; j < jpg->quant.size(); ++j) { + if (jpg->quant[j].index == c->quant_idx) { + c->quant_idx = j; + found_index = true; + break; + } + } + if (!found_index) { + return JXL_FAILURE("Quantization table with index %u not found", + c->quant_idx); + } + } + return true; +} + +size_t FindNextMarker(const uint8_t* data, const size_t len, size_t pos) { + // kIsValidMarker[i] == 1 means (0xc0 + i) is a valid marker. + static const uint8_t kIsValidMarker[] = { + 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + }; + size_t num_skipped = 0; + while (pos + 1 < len && (data[pos] != 0xff || data[pos + 1] < 0xc0 || + !kIsValidMarker[data[pos + 1] - 0xc0])) { + ++pos; + ++num_skipped; + } + return num_skipped; +} + +} // namespace + +bool ReadJpeg(const uint8_t* data, const size_t len, JpegReadMode mode, + JPEGData* jpg) { + size_t pos = 0; + // Check SOI marker. + JXL_JPEG_EXPECT_MARKER(); + int marker = data[pos + 1]; + pos += 2; + if (marker != 0xd8) { + return JXL_FAILURE("Did not find expected SOI marker, actual=%d", marker); + } + int lut_size = kMaxHuffmanTables * kJpegHuffmanLutSize; + std::vector<HuffmanTableEntry> dc_huff_lut(lut_size); + std::vector<HuffmanTableEntry> ac_huff_lut(lut_size); + bool found_sof = false; + bool found_dri = false; + uint16_t scan_progression[kMaxComponents][kDCTBlockSize] = {{0}}; + + jpg->padding_bits.resize(0); + bool is_progressive = false; // default + do { + // Read next marker. + size_t num_skipped = FindNextMarker(data, len, pos); + if (num_skipped > 0) { + // Add a fake marker to indicate arbitrary in-between-markers data. + jpg->marker_order.push_back(0xff); + jpg->inter_marker_data.emplace_back(data + pos, data + pos + num_skipped); + pos += num_skipped; + } + JXL_JPEG_EXPECT_MARKER(); + marker = data[pos + 1]; + pos += 2; + bool ok = true; + switch (marker) { + case 0xc0: + case 0xc1: + case 0xc2: + is_progressive = (marker == 0xc2); + ok = ProcessSOF(data, len, mode, &pos, jpg); + found_sof = true; + break; + case 0xc4: + ok = ProcessDHT(data, len, mode, &dc_huff_lut, &ac_huff_lut, &pos, jpg); + break; + case 0xd0: + case 0xd1: + case 0xd2: + case 0xd3: + case 0xd4: + case 0xd5: + case 0xd6: + case 0xd7: + // RST markers do not have any data. + break; + case 0xd9: + // Found end marker. + break; + case 0xda: + if (mode == JpegReadMode::kReadAll) { + ok = ProcessScan(data, len, dc_huff_lut, ac_huff_lut, + scan_progression, is_progressive, &pos, jpg); + } + break; + case 0xdb: + ok = ProcessDQT(data, len, &pos, jpg); + break; + case 0xdd: + ok = ProcessDRI(data, len, &pos, &found_dri, jpg); + break; + case 0xe0: + case 0xe1: + case 0xe2: + case 0xe3: + case 0xe4: + case 0xe5: + case 0xe6: + case 0xe7: + case 0xe8: + case 0xe9: + case 0xea: + case 0xeb: + case 0xec: + case 0xed: + case 0xee: + case 0xef: + if (mode != JpegReadMode::kReadTables) { + ok = ProcessAPP(data, len, &pos, jpg); + } + break; + case 0xfe: + if (mode != JpegReadMode::kReadTables) { + ok = ProcessCOM(data, len, &pos, jpg); + } + break; + default: + return JXL_FAILURE("Unsupported marker: %d pos=%" PRIuS " len=%" PRIuS, + marker, pos, len); + } + if (!ok) { + return false; + } + jpg->marker_order.push_back(marker); + if (mode == JpegReadMode::kReadHeader && found_sof) { + break; + } + } while (marker != 0xd9); + + if (!found_sof) { + return JXL_FAILURE("Missing SOF marker."); + } + + // Supplemental checks. + if (mode == JpegReadMode::kReadAll) { + if (pos < len) { + jpg->tail_data = std::vector<uint8_t>(data + pos, data + len); + } + if (!FixupIndexes(jpg)) { + return false; + } + if (jpg->huffman_code.empty()) { + // Section B.2.4.2: "If a table has never been defined for a particular + // destination, then when this destination is specified in a scan header, + // the results are unpredictable." + return JXL_FAILURE("Need at least one Huffman code table."); + } + if (jpg->huffman_code.size() >= kMaxDHTMarkers) { + return JXL_FAILURE("Too many Huffman tables."); + } + } + return true; +} + +} // namespace jpeg +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.h b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.h new file mode 100644 index 0000000000..3fad820e9d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_data_reader.h @@ -0,0 +1,36 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Functions for reading a jpeg byte stream into a JPEGData object. + +#ifndef LIB_JXL_JPEG_ENC_JPEG_DATA_READER_H_ +#define LIB_JXL_JPEG_ENC_JPEG_DATA_READER_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +enum class JpegReadMode { + kReadHeader, // only basic headers + kReadTables, // headers and tables (quant, Huffman, ...) + kReadAll, // everything +}; + +// Parses the JPEG stream contained in data[*pos ... len) and fills in *jpg with +// the parsed information. +// If mode is kReadHeader, it fills in only the image dimensions in *jpg. +// Returns false if the data is not valid JPEG, or if it contains an unsupported +// JPEG feature. +bool ReadJpeg(const uint8_t* data, const size_t len, JpegReadMode mode, + JPEGData* jpg); + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_ENC_JPEG_DATA_READER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.cc b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.cc new file mode 100644 index 0000000000..38282e640a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.cc @@ -0,0 +1,103 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/enc_jpeg_huffman_decode.h" + +#include "lib/jxl/jpeg/jpeg_data.h" + +namespace jxl { +namespace jpeg { + +// Returns the table width of the next 2nd level table, count is the histogram +// of bit lengths for the remaining symbols, len is the code length of the next +// processed symbol. +static inline int NextTableBitSize(const int* count, int len) { + int left = 1 << (len - kJpegHuffmanRootTableBits); + while (len < static_cast<int>(kJpegHuffmanMaxBitLength)) { + left -= count[len]; + if (left <= 0) break; + ++len; + left <<= 1; + } + return len - kJpegHuffmanRootTableBits; +} + +void BuildJpegHuffmanTable(const uint32_t* count, const uint32_t* symbols, + HuffmanTableEntry* lut) { + HuffmanTableEntry code; // current table entry + HuffmanTableEntry* table; // next available space in table + int len; // current code length + int idx; // symbol index + int key; // prefix code + int reps; // number of replicate key values in current table + int low; // low bits for current root entry + int table_bits; // key length of current table + int table_size; // size of current table + + // Make a local copy of the input bit length histogram. + int tmp_count[kJpegHuffmanMaxBitLength + 1] = {0}; + int total_count = 0; + for (len = 1; len <= static_cast<int>(kJpegHuffmanMaxBitLength); ++len) { + tmp_count[len] = count[len]; + total_count += tmp_count[len]; + } + + table = lut; + table_bits = kJpegHuffmanRootTableBits; + table_size = 1 << table_bits; + + // Special case code with only one value. + if (total_count == 1) { + code.bits = 0; + code.value = symbols[0]; + for (key = 0; key < table_size; ++key) { + table[key] = code; + } + return; + } + + // Fill in root table. + key = 0; + idx = 0; + for (len = 1; len <= kJpegHuffmanRootTableBits; ++len) { + for (; tmp_count[len] > 0; --tmp_count[len]) { + code.bits = len; + code.value = symbols[idx++]; + reps = 1 << (kJpegHuffmanRootTableBits - len); + while (reps--) { + table[key++] = code; + } + } + } + + // Fill in 2nd level tables and add pointers to root table. + table += table_size; + table_size = 0; + low = 0; + for (len = kJpegHuffmanRootTableBits + 1; + len <= static_cast<int>(kJpegHuffmanMaxBitLength); ++len) { + for (; tmp_count[len] > 0; --tmp_count[len]) { + // Start a new sub-table if the previous one is full. + if (low >= table_size) { + table += table_size; + table_bits = NextTableBitSize(tmp_count, len); + table_size = 1 << table_bits; + low = 0; + lut[key].bits = table_bits + kJpegHuffmanRootTableBits; + lut[key].value = (table - lut) - key; + ++key; + } + code.bits = len - kJpegHuffmanRootTableBits; + code.value = symbols[idx++]; + reps = 1 << (table_bits - code.bits); + while (reps--) { + table[low++] = code; + } + } + } +} + +} // namespace jpeg +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.h b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.h new file mode 100644 index 0000000000..b8a60e4107 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/enc_jpeg_huffman_decode.h @@ -0,0 +1,41 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Utility function for building a Huffman lookup table for the jpeg decoder. + +#ifndef LIB_JXL_JPEG_ENC_JPEG_HUFFMAN_DECODE_H_ +#define LIB_JXL_JPEG_ENC_JPEG_HUFFMAN_DECODE_H_ + +#include <stdint.h> + +namespace jxl { +namespace jpeg { + +constexpr int kJpegHuffmanRootTableBits = 8; +// Maximum huffman lookup table size. +// According to zlib/examples/enough.c, 758 entries are always enough for +// an alphabet of 257 symbols (256 + 1 special symbol for the all 1s code) and +// max bit length 16 if the root table has 8 bits. +constexpr int kJpegHuffmanLutSize = 758; + +struct HuffmanTableEntry { + // Initialize the value to an invalid symbol so that we can recognize it + // when reading the bit stream using a Huffman code with space > 0. + HuffmanTableEntry() : bits(0), value(0xffff) {} + + uint8_t bits; // number of bits used for this symbol + uint16_t value; // symbol value or table offset +}; + +// Builds jpeg-style Huffman lookup table from the given symbols. +// The symbols are in order of increasing bit lengths. The number of symbols +// with bit length n is given in counts[n] for each n >= 1. +void BuildJpegHuffmanTable(const uint32_t* counts, const uint32_t* symbols, + HuffmanTableEntry* lut); + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_ENC_JPEG_HUFFMAN_DECODE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.cc b/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.cc new file mode 100644 index 0000000000..6744e6935a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.cc @@ -0,0 +1,480 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/jpeg/jpeg_data.h" + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/common.h" // kMaxNumPasses, JPEGXL_ENABLE_TRANSCODE_JPEG + +namespace jxl { +namespace jpeg { + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + +namespace { +enum JPEGComponentType : uint32_t { + kGray = 0, + kYCbCr = 1, + kRGB = 2, + kCustom = 3, +}; + +struct JPEGInfo { + size_t num_app_markers = 0; + size_t num_com_markers = 0; + size_t num_scans = 0; + size_t num_intermarker = 0; + bool has_dri = false; +}; + +Status VisitMarker(uint8_t* marker, Visitor* visitor, JPEGInfo* info) { + uint32_t marker32 = *marker - 0xc0; + JXL_RETURN_IF_ERROR(visitor->Bits(6, 0x00, &marker32)); + *marker = marker32 + 0xc0; + if ((*marker & 0xf0) == 0xe0) { + info->num_app_markers++; + } + if (*marker == 0xfe) { + info->num_com_markers++; + } + if (*marker == 0xda) { + info->num_scans++; + } + // We use a fake 0xff marker to signal intermarker data. + if (*marker == 0xff) { + info->num_intermarker++; + } + if (*marker == 0xdd) { + info->has_dri = true; + } + return true; +} + +} // namespace + +Status JPEGData::VisitFields(Visitor* visitor) { + bool is_gray = components.size() == 1; + JXL_RETURN_IF_ERROR(visitor->Bool(false, &is_gray)); + if (visitor->IsReading()) { + components.resize(is_gray ? 1 : 3); + } + JPEGInfo info; + if (visitor->IsReading()) { + uint8_t marker = 0xc0; + do { + JXL_RETURN_IF_ERROR(VisitMarker(&marker, visitor, &info)); + marker_order.push_back(marker); + if (marker_order.size() > 16384) { + return JXL_FAILURE("Too many markers: %" PRIuS "\n", + marker_order.size()); + } + } while (marker != 0xd9); + } else { + if (marker_order.size() > 16384) { + return JXL_FAILURE("Too many markers: %" PRIuS "\n", marker_order.size()); + } + for (size_t i = 0; i < marker_order.size(); i++) { + JXL_RETURN_IF_ERROR(VisitMarker(&marker_order[i], visitor, &info)); + } + if (!marker_order.empty()) { + // Last marker should always be EOI marker. + JXL_CHECK(marker_order.back() == 0xd9); + } + } + + // Size of the APP and COM markers. + if (visitor->IsReading()) { + app_data.resize(info.num_app_markers); + app_marker_type.resize(info.num_app_markers); + com_data.resize(info.num_com_markers); + scan_info.resize(info.num_scans); + } + JXL_ASSERT(app_data.size() == info.num_app_markers); + JXL_ASSERT(app_marker_type.size() == info.num_app_markers); + JXL_ASSERT(com_data.size() == info.num_com_markers); + JXL_ASSERT(scan_info.size() == info.num_scans); + for (size_t i = 0; i < app_data.size(); i++) { + auto& app = app_data[i]; + // Encodes up to 8 different values. + JXL_RETURN_IF_ERROR( + visitor->U32(Val(0), Val(1), BitsOffset(1, 2), BitsOffset(2, 4), 0, + reinterpret_cast<uint32_t*>(&app_marker_type[i]))); + if (app_marker_type[i] != AppMarkerType::kUnknown && + app_marker_type[i] != AppMarkerType::kICC && + app_marker_type[i] != AppMarkerType::kExif && + app_marker_type[i] != AppMarkerType::kXMP) { + return JXL_FAILURE("Unknown app marker type %u", + static_cast<uint32_t>(app_marker_type[i])); + } + uint32_t len = app.size() - 1; + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &len)); + if (visitor->IsReading()) app.resize(len + 1); + if (app.size() < 3) { + return JXL_FAILURE("Invalid marker size: %" PRIuS "\n", app.size()); + } + } + for (auto& com : com_data) { + uint32_t len = com.size() - 1; + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &len)); + if (visitor->IsReading()) com.resize(len + 1); + if (com.size() < 3) { + return JXL_FAILURE("Invalid marker size: %" PRIuS "\n", com.size()); + } + } + + uint32_t num_quant_tables = quant.size(); + JXL_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 2, &num_quant_tables)); + if (num_quant_tables == 4) { + return JXL_FAILURE("Invalid number of quant tables"); + } + if (visitor->IsReading()) { + quant.resize(num_quant_tables); + } + for (size_t i = 0; i < num_quant_tables; i++) { + if (quant[i].precision > 1) { + return JXL_FAILURE( + "Quant tables with more than 16 bits are not supported"); + } + JXL_RETURN_IF_ERROR(visitor->Bits(1, 0, &quant[i].precision)); + JXL_RETURN_IF_ERROR(visitor->Bits(2, i, &quant[i].index)); + JXL_RETURN_IF_ERROR(visitor->Bool(true, &quant[i].is_last)); + } + + JPEGComponentType component_type = + components.size() == 1 && components[0].id == 1 ? JPEGComponentType::kGray + : components.size() == 3 && components[0].id == 1 && + components[1].id == 2 && components[2].id == 3 + ? JPEGComponentType::kYCbCr + : components.size() == 3 && components[0].id == 'R' && + components[1].id == 'G' && components[2].id == 'B' + ? JPEGComponentType::kRGB + : JPEGComponentType::kCustom; + JXL_RETURN_IF_ERROR( + visitor->Bits(2, JPEGComponentType::kYCbCr, + reinterpret_cast<uint32_t*>(&component_type))); + uint32_t num_components; + if (component_type == JPEGComponentType::kGray) { + num_components = 1; + } else if (component_type != JPEGComponentType::kCustom) { + num_components = 3; + } else { + num_components = components.size(); + JXL_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 3, &num_components)); + if (num_components != 1 && num_components != 3) { + return JXL_FAILURE("Invalid number of components: %u", num_components); + } + } + if (visitor->IsReading()) { + components.resize(num_components); + } + if (component_type == JPEGComponentType::kCustom) { + for (size_t i = 0; i < components.size(); i++) { + JXL_RETURN_IF_ERROR(visitor->Bits(8, 0, &components[i].id)); + } + } else if (component_type == JPEGComponentType::kGray) { + components[0].id = 1; + } else if (component_type == JPEGComponentType::kRGB) { + components[0].id = 'R'; + components[1].id = 'G'; + components[2].id = 'B'; + } else { + components[0].id = 1; + components[1].id = 2; + components[2].id = 3; + } + size_t used_tables = 0; + for (size_t i = 0; i < components.size(); i++) { + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &components[i].quant_idx)); + if (components[i].quant_idx >= quant.size()) { + return JXL_FAILURE("Invalid quant table for component %" PRIuS ": %u\n", + i, components[i].quant_idx); + } + used_tables |= 1U << components[i].quant_idx; + } + for (size_t i = 0; i < quant.size(); i++) { + if (used_tables & (1 << i)) continue; + if (i == 0) return JXL_FAILURE("First quant table unused."); + // Unused quant table has to be set to copy of previous quant table + for (size_t j = 0; j < 64; j++) { + if (quant[i].values[j] != quant[i - 1].values[j]) { + return JXL_FAILURE("Non-trivial unused quant table"); + } + } + } + + uint32_t num_huff = huffman_code.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(4), BitsOffset(3, 2), BitsOffset(4, 10), + BitsOffset(6, 26), 4, &num_huff)); + if (visitor->IsReading()) { + huffman_code.resize(num_huff); + } + for (JPEGHuffmanCode& hc : huffman_code) { + bool is_ac = hc.slot_id >> 4; + uint32_t id = hc.slot_id & 0xF; + JXL_RETURN_IF_ERROR(visitor->Bool(false, &is_ac)); + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &id)); + hc.slot_id = (static_cast<uint32_t>(is_ac) << 4) | id; + JXL_RETURN_IF_ERROR(visitor->Bool(true, &hc.is_last)); + size_t num_symbols = 0; + for (size_t i = 0; i <= 16; i++) { + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(3, 2), + Bits(8), 0, &hc.counts[i])); + num_symbols += hc.counts[i]; + } + if (num_symbols < 1) { + // Actually, at least 2 symbols are required, since one of them is EOI. + return JXL_FAILURE("Empty Huffman table"); + } + if (num_symbols > hc.values.size()) { + return JXL_FAILURE("Huffman code too large (%" PRIuS ")", num_symbols); + } + // Presence flags for 4 * 64 + 1 values. + uint64_t value_slots[5] = {}; + for (size_t i = 0; i < num_symbols; i++) { + // Goes up to 256, included. Might have the same symbol appear twice... + JXL_RETURN_IF_ERROR(visitor->U32(Bits(2), BitsOffset(2, 4), + BitsOffset(4, 8), BitsOffset(8, 1), 0, + &hc.values[i])); + value_slots[hc.values[i] >> 6] |= (uint64_t)1 << (hc.values[i] & 0x3F); + } + if (hc.values[num_symbols - 1] != kJpegHuffmanAlphabetSize) { + return JXL_FAILURE("Missing EOI symbol"); + } + // Last element, denoting EOI, have to be 1 after the loop. + JXL_ASSERT(value_slots[4] == 1); + size_t num_values = 1; + for (size_t i = 0; i < 4; ++i) num_values += hwy::PopCount(value_slots[i]); + if (num_values != num_symbols) { + return JXL_FAILURE("Duplicate Huffman symbols"); + } + if (!is_ac) { + bool only_dc = ((value_slots[0] >> kJpegDCAlphabetSize) | value_slots[1] | + value_slots[2] | value_slots[3]) == 0; + if (!only_dc) return JXL_FAILURE("Huffman symbols out of DC range"); + } + } + + for (auto& scan : scan_info) { + JXL_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), Val(4), 1, &scan.num_components)); + if (scan.num_components >= 4) { + return JXL_FAILURE("Invalid number of components in SOS marker"); + } + JXL_RETURN_IF_ERROR(visitor->Bits(6, 0, &scan.Ss)); + JXL_RETURN_IF_ERROR(visitor->Bits(6, 63, &scan.Se)); + JXL_RETURN_IF_ERROR(visitor->Bits(4, 0, &scan.Al)); + JXL_RETURN_IF_ERROR(visitor->Bits(4, 0, &scan.Ah)); + for (size_t i = 0; i < scan.num_components; i++) { + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &scan.components[i].comp_idx)); + if (scan.components[i].comp_idx >= components.size()) { + return JXL_FAILURE("Invalid component idx in SOS marker"); + } + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &scan.components[i].ac_tbl_idx)); + JXL_RETURN_IF_ERROR(visitor->Bits(2, 0, &scan.components[i].dc_tbl_idx)); + } + // TODO(veluca): actually set and use this value. + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), Val(2), BitsOffset(3, 3), + kMaxNumPasses - 1, + &scan.last_needed_pass)); + } + + // From here on, this is data that is not strictly necessary to get a valid + // JPEG, but necessary for bit-exact JPEG reconstruction. + if (info.has_dri) { + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &restart_interval)); + } + + for (auto& scan : scan_info) { + uint32_t num_reset_points = scan.reset_points.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(2, 1), BitsOffset(4, 4), + BitsOffset(16, 20), 0, &num_reset_points)); + if (visitor->IsReading()) { + scan.reset_points.resize(num_reset_points); + } + int last_block_idx = -1; + for (auto& block_idx : scan.reset_points) { + block_idx -= last_block_idx + 1; + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(3, 1), + BitsOffset(5, 9), BitsOffset(28, 41), 0, + &block_idx)); + block_idx += last_block_idx + 1; + if (block_idx >= (3u << 26)) { + // At most 8K x 8K x num_channels blocks are possible in a JPEG. + // So valid block indices are below 3 * 2^26. + return JXL_FAILURE("Invalid block ID: %u", block_idx); + } + last_block_idx = block_idx; + } + + uint32_t num_extra_zero_runs = scan.extra_zero_runs.size(); + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(2, 1), BitsOffset(4, 4), + BitsOffset(16, 20), 0, + &num_extra_zero_runs)); + if (visitor->IsReading()) { + scan.extra_zero_runs.resize(num_extra_zero_runs); + } + last_block_idx = -1; + for (size_t i = 0; i < scan.extra_zero_runs.size(); ++i) { + uint32_t& block_idx = scan.extra_zero_runs[i].block_idx; + JXL_RETURN_IF_ERROR(visitor->U32( + Val(1), BitsOffset(2, 2), BitsOffset(4, 5), BitsOffset(8, 20), 1, + &scan.extra_zero_runs[i].num_extra_zero_runs)); + block_idx -= last_block_idx + 1; + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(3, 1), + BitsOffset(5, 9), BitsOffset(28, 41), 0, + &block_idx)); + block_idx += last_block_idx + 1; + if (block_idx > (3u << 26)) { + return JXL_FAILURE("Invalid block ID: %u", block_idx); + } + last_block_idx = block_idx; + } + } + std::vector<uint32_t> inter_marker_data_sizes; + inter_marker_data_sizes.reserve(info.num_intermarker); + for (size_t i = 0; i < info.num_intermarker; ++i) { + uint32_t len = visitor->IsReading() ? 0 : inter_marker_data[i].size(); + JXL_RETURN_IF_ERROR(visitor->Bits(16, 0, &len)); + if (visitor->IsReading()) inter_marker_data_sizes.emplace_back(len); + } + uint32_t tail_data_len = tail_data.size(); + if (!visitor->IsReading() && tail_data_len > 4260096) { + return JXL_FAILURE("Tail data too large (max size = 4260096, size = %u)", + tail_data_len); + } + JXL_RETURN_IF_ERROR(visitor->U32(Val(0), BitsOffset(8, 1), + BitsOffset(16, 257), BitsOffset(22, 65793), + 0, &tail_data_len)); + + JXL_RETURN_IF_ERROR(visitor->Bool(false, &has_zero_padding_bit)); + if (has_zero_padding_bit) { + uint32_t nbit = padding_bits.size(); + JXL_RETURN_IF_ERROR(visitor->Bits(24, 0, &nbit)); + if (visitor->IsReading()) { + JXL_RETURN_IF_ERROR(CheckHasEnoughBits(visitor, nbit)); + padding_bits.reserve(std::min<uint32_t>(1024u, nbit)); + for (uint32_t i = 0; i < nbit; i++) { + bool bbit = false; + JXL_RETURN_IF_ERROR(visitor->Bool(false, &bbit)); + padding_bits.push_back(bbit); + } + } else { + for (uint8_t& bit : padding_bits) { + bool bbit = bit; + JXL_RETURN_IF_ERROR(visitor->Bool(false, &bbit)); + bit = bbit; + } + } + } + + { + size_t dht_index = 0; + size_t scan_index = 0; + bool is_progressive = false; + bool ac_ok[kMaxHuffmanTables] = {false}; + bool dc_ok[kMaxHuffmanTables] = {false}; + for (uint8_t marker : marker_order) { + if (marker == 0xC2) { + is_progressive = true; + } else if (marker == 0xC4) { + for (; dht_index < huffman_code.size();) { + const JPEGHuffmanCode& huff = huffman_code[dht_index++]; + size_t index = huff.slot_id; + if (index & 0x10) { + index -= 0x10; + ac_ok[index] = true; + } else { + dc_ok[index] = true; + } + if (huff.is_last) break; + } + } else if (marker == 0xDA) { + const JPEGScanInfo& si = scan_info[scan_index++]; + for (size_t i = 0; i < si.num_components; ++i) { + const JPEGComponentScanInfo& csi = si.components[i]; + size_t dc_tbl_idx = csi.dc_tbl_idx; + size_t ac_tbl_idx = csi.ac_tbl_idx; + bool want_dc = !is_progressive || (si.Ss == 0); + if (want_dc && !dc_ok[dc_tbl_idx]) { + return JXL_FAILURE("DC Huffman table used before defined"); + } + bool want_ac = !is_progressive || (si.Ss != 0) || (si.Se != 0); + if (want_ac && !ac_ok[ac_tbl_idx]) { + return JXL_FAILURE("AC Huffman table used before defined"); + } + } + } + } + } + + // Apply postponed actions. + if (visitor->IsReading()) { + tail_data.resize(tail_data_len); + JXL_ASSERT(inter_marker_data_sizes.size() == info.num_intermarker); + inter_marker_data.reserve(info.num_intermarker); + for (size_t i = 0; i < info.num_intermarker; ++i) { + inter_marker_data.emplace_back(inter_marker_data_sizes[i]); + } + } + + return true; +} + +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +void JPEGData::CalculateMcuSize(const JPEGScanInfo& scan, int* MCUs_per_row, + int* MCU_rows) const { + const bool is_interleaved = (scan.num_components > 1); + const JPEGComponent& base_component = components[scan.components[0].comp_idx]; + // h_group / v_group act as numerators for converting number of blocks to + // number of MCU. In interleaved mode it is 1, so MCU is represented with + // max_*_samp_factor blocks. In non-interleaved mode we choose numerator to + // be the samping factor, consequently MCU is always represented with single + // block. + const int h_group = is_interleaved ? 1 : base_component.h_samp_factor; + const int v_group = is_interleaved ? 1 : base_component.v_samp_factor; + int max_h_samp_factor = 1; + int max_v_samp_factor = 1; + for (const auto& c : components) { + max_h_samp_factor = std::max(c.h_samp_factor, max_h_samp_factor); + max_v_samp_factor = std::max(c.v_samp_factor, max_v_samp_factor); + } + *MCUs_per_row = DivCeil(width * h_group, 8 * max_h_samp_factor); + *MCU_rows = DivCeil(height * v_group, 8 * max_v_samp_factor); +} + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + +Status SetJPEGDataFromICC(const std::vector<uint8_t>& icc, + jpeg::JPEGData* jpeg_data) { + size_t icc_pos = 0; + for (size_t i = 0; i < jpeg_data->app_data.size(); i++) { + if (jpeg_data->app_marker_type[i] != jpeg::AppMarkerType::kICC) { + continue; + } + size_t len = jpeg_data->app_data[i].size() - 17; + if (icc_pos + len > icc.size()) { + return JXL_FAILURE( + "ICC length is less than APP markers: requested %" PRIuS + " more bytes, " + "%" PRIuS " available", + len, icc.size() - icc_pos); + } + memcpy(&jpeg_data->app_data[i][17], icc.data() + icc_pos, len); + icc_pos += len; + } + if (icc_pos != icc.size() && icc_pos != 0) { + return JXL_FAILURE("ICC length is more than APP markers"); + } + return true; +} + +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +} // namespace jpeg +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.h b/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.h new file mode 100644 index 0000000000..4387d20066 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jpeg/jpeg_data.h @@ -0,0 +1,218 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Data structures that represent the non-pixel contents of a jpeg file. + +#ifndef LIB_JXL_JPEG_JPEG_DATA_H_ +#define LIB_JXL_JPEG_JPEG_DATA_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <array> +#include <vector> + +#include "lib/jxl/common.h" // JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_dimensions.h" + +namespace jxl { +namespace jpeg { + +constexpr int kMaxComponents = 4; +constexpr int kMaxQuantTables = 4; +constexpr int kMaxHuffmanTables = 4; +constexpr size_t kJpegHuffmanMaxBitLength = 16; +constexpr int kJpegHuffmanAlphabetSize = 256; +constexpr int kJpegDCAlphabetSize = 12; +constexpr int kMaxDHTMarkers = 512; +constexpr int kMaxDimPixels = 65535; +constexpr uint8_t kApp1 = 0xE1; +constexpr uint8_t kApp2 = 0xE2; +const uint8_t kIccProfileTag[12] = "ICC_PROFILE"; +const uint8_t kExifTag[6] = "Exif\0"; +const uint8_t kXMPTag[29] = "http://ns.adobe.com/xap/1.0/"; + +/* clang-format off */ +constexpr uint32_t kJPEGNaturalOrder[80] = { + 0, 1, 8, 16, 9, 2, 3, 10, + 17, 24, 32, 25, 18, 11, 4, 5, + 12, 19, 26, 33, 40, 48, 41, 34, + 27, 20, 13, 6, 7, 14, 21, 28, + 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, + 58, 59, 52, 45, 38, 31, 39, 46, + 53, 60, 61, 54, 47, 55, 62, 63, + // extra entries for safety in decoder + 63, 63, 63, 63, 63, 63, 63, 63, + 63, 63, 63, 63, 63, 63, 63, 63 +}; + +constexpr uint32_t kJPEGZigZagOrder[64] = { + 0, 1, 5, 6, 14, 15, 27, 28, + 2, 4, 7, 13, 16, 26, 29, 42, + 3, 8, 12, 17, 25, 30, 41, 43, + 9, 11, 18, 24, 31, 40, 44, 53, + 10, 19, 23, 32, 39, 45, 52, 54, + 20, 22, 33, 38, 46, 51, 55, 60, + 21, 34, 37, 47, 50, 56, 59, 61, + 35, 36, 48, 49, 57, 58, 62, 63 +}; +/* clang-format on */ + +// Quantization values for an 8x8 pixel block. +struct JPEGQuantTable { + std::array<int32_t, kDCTBlockSize> values; + uint32_t precision = 0; + // The index of this quantization table as it was parsed from the input JPEG. + // Each DQT marker segment contains an 'index' field, and we save this index + // here. Valid values are 0 to 3. + uint32_t index = 0; + // Set to true if this table is the last one within its marker segment. + bool is_last = true; +}; + +// Huffman code and decoding lookup table used for DC and AC coefficients. +struct JPEGHuffmanCode { + // Bit length histogram. + std::array<uint32_t, kJpegHuffmanMaxBitLength + 1> counts = {}; + // Symbol values sorted by increasing bit lengths. + std::array<uint32_t, kJpegHuffmanAlphabetSize + 1> values = {}; + // The index of the Huffman code in the current set of Huffman codes. For AC + // component Huffman codes, 0x10 is added to the index. + int slot_id = 0; + // Set to true if this Huffman code is the last one within its marker segment. + bool is_last = true; +}; + +// Huffman table indexes used for one component of one scan. +struct JPEGComponentScanInfo { + uint32_t comp_idx; + uint32_t dc_tbl_idx; + uint32_t ac_tbl_idx; +}; + +// Contains information that is used in one scan. +struct JPEGScanInfo { + // Parameters used for progressive scans (named the same way as in the spec): + // Ss : Start of spectral band in zig-zag sequence. + // Se : End of spectral band in zig-zag sequence. + // Ah : Successive approximation bit position, high. + // Al : Successive approximation bit position, low. + uint32_t Ss; + uint32_t Se; + uint32_t Ah; + uint32_t Al; + uint32_t num_components = 0; + std::array<JPEGComponentScanInfo, 4> components; + // Last codestream pass that is needed to write this scan. + uint32_t last_needed_pass = 0; + + // Extra information required for bit-precise JPEG file reconstruction. + + // Set of block indexes where the JPEG encoder has to flush the end-of-block + // runs and refinement bits. + std::vector<uint32_t> reset_points; + // The number of extra zero runs (Huffman symbol 0xf0) before the end of + // block (if nonzero), indexed by block index. + // All of these symbols can be omitted without changing the pixel values, but + // some jpeg encoders put these at the end of blocks. + typedef struct { + uint32_t block_idx; + uint32_t num_extra_zero_runs; + } ExtraZeroRunInfo; + std::vector<ExtraZeroRunInfo> extra_zero_runs; +}; + +typedef int16_t coeff_t; + +// Represents one component of a jpeg file. +struct JPEGComponent { + JPEGComponent() + : id(0), + h_samp_factor(1), + v_samp_factor(1), + quant_idx(0), + width_in_blocks(0), + height_in_blocks(0) {} + + // One-byte id of the component. + uint32_t id; + // Horizontal and vertical sampling factors. + // In interleaved mode, each minimal coded unit (MCU) has + // h_samp_factor x v_samp_factor DCT blocks from this component. + int h_samp_factor; + int v_samp_factor; + // The index of the quantization table used for this component. + uint32_t quant_idx; + // The dimensions of the component measured in 8x8 blocks. + uint32_t width_in_blocks; + uint32_t height_in_blocks; + // The DCT coefficients of this component, laid out block-by-block, divided + // through the quantization matrix values. + std::vector<coeff_t> coeffs; +}; + +enum class AppMarkerType : uint32_t { + kUnknown = 0, + kICC = 1, + kExif = 2, + kXMP = 3, +}; + +// Represents a parsed jpeg file. +struct JPEGData : public Fields { + JPEGData() + : width(0), height(0), restart_interval(0), has_zero_padding_bit(false) {} + + JXL_FIELDS_NAME(JPEGData) +#if JPEGXL_ENABLE_TRANSCODE_JPEG + // Doesn't serialize everything - skips brotli-encoded data and what is + // already encoded in the codestream. + Status VisitFields(Visitor* visitor) override; +#else + Status VisitFields(Visitor* /* visitor */) override { + JXL_UNREACHABLE("JPEG transcoding support not enabled"); + } +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + + void CalculateMcuSize(const JPEGScanInfo& scan, int* MCUs_per_row, + int* MCU_rows) const; + + int width; + int height; + uint32_t restart_interval; + std::vector<std::vector<uint8_t>> app_data; + std::vector<AppMarkerType> app_marker_type; + std::vector<std::vector<uint8_t>> com_data; + std::vector<JPEGQuantTable> quant; + std::vector<JPEGHuffmanCode> huffman_code; + std::vector<JPEGComponent> components; + std::vector<JPEGScanInfo> scan_info; + std::vector<uint8_t> marker_order; + std::vector<std::vector<uint8_t>> inter_marker_data; + std::vector<uint8_t> tail_data; + + // Extra information required for bit-precise JPEG file reconstruction. + + bool has_zero_padding_bit; + std::vector<uint8_t> padding_bits; +}; + +#if JPEGXL_ENABLE_TRANSCODE_JPEG +// Set ICC profile in jpeg_data. +Status SetJPEGDataFromICC(const std::vector<uint8_t>& icc, + jpeg::JPEGData* jpeg_data); +#else +static JXL_INLINE Status SetJPEGDataFromICC( + const std::vector<uint8_t>& /* icc */, jpeg::JPEGData* /* jpeg_data */) { + JXL_UNREACHABLE("JPEG transcoding support not enabled"); +} +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +} // namespace jpeg +} // namespace jxl + +#endif // LIB_JXL_JPEG_JPEG_DATA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/jxl.syms b/third_party/jpeg-xl/lib/jxl/jxl.syms new file mode 100644 index 0000000000..0f398d7151 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jxl.syms @@ -0,0 +1,5 @@ +{ + extern "C" { + jpegxl_*; + }; +}; diff --git a/third_party/jpeg-xl/lib/jxl/jxl.version b/third_party/jpeg-xl/lib/jxl/jxl.version new file mode 100644 index 0000000000..26b0e9e54d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jxl.version @@ -0,0 +1,17 @@ +JXL_0 { + global: + Jxl*; + + local: + # Hide all the std namespace symbols. std namespace is explicitly marked + # as visibility(default) and header-only functions or methods (such as those + # from templates) should be exposed in shared libraries as weak symbols but + # this is only needed when we expose those types in the shared library API + # in any way. We don't use C++ std types in the API and we also don't + # support exceptions in the library. + # See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=36022 for a discussion + # about this. + extern "C++" { + *std::*; + }; +}; diff --git a/third_party/jpeg-xl/lib/jxl/jxl_osx.syms b/third_party/jpeg-xl/lib/jxl/jxl_osx.syms new file mode 100644 index 0000000000..96bc568025 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jxl_osx.syms @@ -0,0 +1 @@ +_Jxl* diff --git a/third_party/jpeg-xl/lib/jxl/jxl_test.cc b/third_party/jpeg-xl/lib/jxl/jxl_test.cc new file mode 100644 index 0000000000..a91dbd0672 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/jxl_test.cc @@ -0,0 +1,1677 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/extras/dec/jxl.h" + +#include <jxl/cms.h> +#include <jxl/color_encoding.h> +#include <jxl/encode.h> +#include <jxl/types.h> + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <cstring> +#include <future> +#include <ostream> +#include <string> +#include <tuple> +#include <vector> + +#include "lib/extras/codec.h" +#include "lib/extras/dec/decode.h" +#include "lib/extras/enc/encode.h" +#include "lib/extras/enc/jxl.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" // JXL_HIGH_PRECISION +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/fake_parallel_runner_testonly.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/test_image.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { + +struct AuxOut; + +namespace { +using extras::JXLCompressParams; +using extras::JXLDecompressParams; +using extras::PackedPixelFile; +using test::ButteraugliDistance; +using test::ComputeDistance2; +using test::ReadTestData; +using test::Roundtrip; +using test::TestImage; +using test::ThreadPoolForTests; + +#define JXL_TEST_NL 0 // Disabled in code + +TEST(JxlTest, RoundtripSinglePixel) { + TestImage t; + t.SetDimensions(1, 1).AddFrame().ZeroFill(); + PackedPixelFile ppf_out; + EXPECT_EQ(Roundtrip(t.ppf(), {}, {}, nullptr, &ppf_out), 55); +} + +TEST(JxlTest, RoundtripSinglePixelWithAlpha) { + TestImage t; + t.SetDimensions(1, 1).SetChannels(4).AddFrame().ZeroFill(); + PackedPixelFile ppf_out; + EXPECT_EQ(Roundtrip(t.ppf(), {}, {}, nullptr, &ppf_out), 59); +} + +// Changing serialized signature causes Decode to fail. +#ifndef JXL_CRASH_ON_ERROR +TEST(JxlTest, RoundtripMarker) { + TestImage t; + t.SetDimensions(1, 1).AddFrame().ZeroFill(); + for (size_t i = 0; i < 2; ++i) { + std::vector<uint8_t> compressed; + EXPECT_TRUE(extras::EncodeImageJXL({}, t.ppf(), /*jpeg_bytes=*/nullptr, + &compressed)); + compressed[i] ^= 0xFF; + PackedPixelFile ppf_out; + EXPECT_FALSE(extras::DecodeImageJXL(compressed.data(), compressed.size(), + {}, /*decodec_bytes=*/nullptr, + &ppf_out)); + } +} +#endif + +TEST(JxlTest, RoundtripTinyFast) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata().SetDimensions(32, 32); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 7); + cparams.distance = 4.0f; + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 181, 15); +} + +TEST(JxlTest, RoundtripSmallD1) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + size_t xsize = t.ppf().info.xsize / 8; + size_t ysize = t.ppf().info.ysize / 8; + t.SetDimensions(xsize, ysize); + + { + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), {}, {}, pool, &ppf_out), 916, 40); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(0.888)); + } + + // With a lower intensity target than the default, the bitrate should be + // smaller. + t.ppf().info.intensity_target = 100.0f; + + { + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), {}, {}, pool, &ppf_out), 745, 20); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(1.3)); + EXPECT_EQ(ppf_out.info.intensity_target, t.ppf().info.intensity_target); + } +} +TEST(JxlTest, RoundtripResample2) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_RESAMPLING, 2); + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 3); // kFalcon + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 18500, 200); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(90)); +} + +TEST(JxlTest, RoundtripResample2Slow) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_RESAMPLING, 2); + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 9); // kTortoise + cparams.distance = 10.0; + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 3888, 200); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(250)); +} + +TEST(JxlTest, RoundtripResample2MT) { + ThreadPoolForTests pool(4); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + // image has to be large enough to have multiple groups after downsampling + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_RESAMPLING, 2); + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 3); // kFalcon + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, &pool, &ppf_out), 223310, 2000); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(340)); +} + +// Roundtrip the image using a parallel runner that executes single-threaded but +// in random order. +TEST(JxlTest, RoundtripOutOfOrderProcessing) { + FakeParallelRunner fake_pool(/*order_seed=*/123, /*num_threads=*/8); + ThreadPool pool(&JxlFakeParallelRunner, &fake_pool); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + // Image size is selected so that the block border needed is larger than the + // amount of pixels available on the next block. + t.SetDimensions(513, 515); + + JXLCompressParams cparams; + // Force epf so we end up needing a lot of border. + cparams.AddOption(JXL_ENC_FRAME_SETTING_EPF, 3); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, &pool, &ppf_out), 27444, 400); + EXPECT_LE(ButteraugliDistance(t.ppf(), ppf_out), 1.35); +} + +TEST(JxlTest, RoundtripOutOfOrderProcessingBorder) { + FakeParallelRunner fake_pool(/*order_seed=*/47, /*num_threads=*/8); + ThreadPool pool(&JxlFakeParallelRunner, &fake_pool); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + // Image size is selected so that the block border needed is larger than the + // amount of pixels available on the next block. + t.SetDimensions(513, 515); + + JXLCompressParams cparams; + // Force epf so we end up needing a lot of border. + cparams.AddOption(JXL_ENC_FRAME_SETTING_EPF, 3); + cparams.AddOption(JXL_ENC_FRAME_SETTING_RESAMPLING, 2); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, &pool, &ppf_out), 10065, 200); + EXPECT_LE(ButteraugliDistance(t.ppf(), ppf_out), 2.9); +} + +TEST(JxlTest, RoundtripResample4) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_RESAMPLING, 4); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 5758, 100); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(22)); +} + +TEST(JxlTest, RoundtripResample8) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_RESAMPLING, 8); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 2036, 50); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(50)); +} + +TEST(JxlTest, RoundtripUnalignedD2) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + size_t xsize = t.ppf().info.xsize / 12; + size_t ysize = t.ppf().info.ysize / 7; + t.SetDimensions(xsize, ysize); + + JXLCompressParams cparams; + cparams.distance = 2.0; + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 506, 30); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(1.72)); +} + +TEST(JxlTest, RoundtripMultiGroup) { + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata().SetDimensions(600, 1024); + + auto test = [&](jxl::SpeedTier speed_tier, float target_distance, + size_t expected_size, float expected_distance) { + ThreadPoolForTests pool(4); + JXLCompressParams cparams; + int64_t effort = 10 - static_cast<int>(speed_tier); + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, effort); + cparams.distance = target_distance; + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, &pool, &ppf_out), expected_size, + 700); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), + IsSlightlyBelow(expected_distance)); + }; + + auto run_kitten = std::async(std::launch::async, test, SpeedTier::kKitten, + 1.0f, 63624u, 8.5); + auto run_wombat = std::async(std::launch::async, test, SpeedTier::kWombat, + 2.0f, 39620u, 15.5); +} + +TEST(JxlTest, RoundtripRGBToGrayscale) { + ThreadPoolForTests pool(4); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io, &pool)); + io.ShrinkTo(600, 1024); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0f; + cparams.speed_tier = SpeedTier::kFalcon; + + JXLDecompressParams dparams; + dparams.color_space = "Gra_D65_Rel_SRG"; + + CodecInOut io2; + EXPECT_FALSE(io.Main().IsGray()); + size_t compressed_size; + JXL_EXPECT_OK( + Roundtrip(&io, cparams, dparams, &io2, _, &compressed_size, &pool)); + EXPECT_LE(compressed_size, 65000u); + EXPECT_TRUE(io2.Main().IsGray()); + + // Convert original to grayscale here, because TransformTo refuses to + // convert between grayscale and RGB. + ColorEncoding srgb_lin = ColorEncoding::LinearSRGB(/*is_gray=*/false); + ASSERT_TRUE(io.frames[0].TransformTo(srgb_lin, *JxlGetDefaultCms())); + Image3F* color = io.Main().color(); + for (size_t y = 0; y < color->ysize(); ++y) { + float* row_r = color->PlaneRow(0, y); + float* row_g = color->PlaneRow(1, y); + float* row_b = color->PlaneRow(2, y); + for (size_t x = 0; x < color->xsize(); ++x) { + float luma = 0.2126 * row_r[x] + 0.7152 * row_g[x] + 0.0722 * row_b[x]; + row_r[x] = row_g[x] = row_b[x] = luma; + } + } + ColorEncoding srgb_gamma = ColorEncoding::SRGB(/*is_gray=*/false); + ASSERT_TRUE(io.frames[0].TransformTo(srgb_gamma, *JxlGetDefaultCms())); + io.metadata.m.color_encoding = io2.Main().c_current(); + io.Main().OverrideProfile(io2.Main().c_current()); + EXPECT_THAT(ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr, &pool), + IsSlightlyBelow(1.36)); +} + +TEST(JxlTest, RoundtripLargeFast) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 7); // kSquirrel + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, &pool, &ppf_out), 492867, 5000); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(78)); +} + +TEST(JxlTest, RoundtripDotsForceEpf) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/cvo9xd_keong_macan_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 7); // kSquirrel + cparams.AddOption(JXL_ENC_FRAME_SETTING_EPF, 2); + cparams.AddOption(JXL_ENC_FRAME_SETTING_DOTS, 1); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, &pool, &ppf_out), 41355, 300); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(18)); +} + +// Checks for differing size/distance in two consecutive runs of distance 2, +// which involves additional processing including adaptive reconstruction. +// Failing this may be a sign of race conditions or invalid memory accesses. +TEST(JxlTest, RoundtripD2Consistent) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 7); // kSquirrel + cparams.distance = 2.0; + + // Try each xsize mod kBlockDim to verify right border handling. + for (size_t xsize = 48; xsize > 40; --xsize) { + t.SetDimensions(xsize, 15); + + PackedPixelFile ppf2; + const size_t size2 = Roundtrip(t.ppf(), cparams, {}, &pool, &ppf2); + + PackedPixelFile ppf3; + const size_t size3 = Roundtrip(t.ppf(), cparams, {}, &pool, &ppf3); + + // Exact same compressed size. + EXPECT_EQ(size2, size3); + + // Exact same distance. + const float dist2 = ComputeDistance2(t.ppf(), ppf2); + const float dist3 = ComputeDistance2(t.ppf(), ppf3); + EXPECT_EQ(dist2, dist3); + } +} + +// Same as above, but for full image, testing multiple groups. +TEST(JxlTest, RoundtripLargeConsistent) { + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 7); // kSquirrel + cparams.distance = 2.0; + + auto roundtrip_and_compare = [&]() { + ThreadPoolForTests pool(8); + PackedPixelFile ppf2; + size_t size = Roundtrip(t.ppf(), cparams, {}, &pool, &ppf2); + double dist = ComputeDistance2(t.ppf(), ppf2); + return std::tuple<size_t, double>(size, dist); + }; + + // Try each xsize mod kBlockDim to verify right border handling. + auto future2 = std::async(std::launch::async, roundtrip_and_compare); + auto future3 = std::async(std::launch::async, roundtrip_and_compare); + + const auto result2 = future2.get(); + const auto result3 = future3.get(); + + // Exact same compressed size. + EXPECT_EQ(std::get<0>(result2), std::get<0>(result3)); + + // Exact same distance. + EXPECT_EQ(std::get<1>(result2), std::get<1>(result3)); +} + +TEST(JxlTest, RoundtripSmallNL) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + size_t xsize = t.ppf().info.xsize / 8; + size_t ysize = t.ppf().info.ysize / 8; + t.SetDimensions(xsize, ysize); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), {}, {}, pool, &ppf_out), 916, 45); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(0.82)); +} + +TEST(JxlTest, RoundtripNoGaborishNoAR) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EPF, 0); + cparams.AddOption(JXL_ENC_FRAME_SETTING_GABORISH, 0); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 41142, 400); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(1.8)); +} + +TEST(JxlTest, RoundtripSmallNoGaborish) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + size_t xsize = t.ppf().info.xsize / 8; + size_t ysize = t.ppf().info.ysize / 8; + t.SetDimensions(xsize, ysize); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_GABORISH, 0); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 1006, 20); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(1.1)); +} + +TEST(JxlTest, RoundtripSmallPatchesAlpha) { + ThreadPool* pool = nullptr; + TestImage t; + t.SetDimensions(256, 256).SetChannels(4); + t.SetColorEncoding("RGB_D65_SRG_Rel_Lin"); + TestImage::Frame frame = t.AddFrame(); + frame.ZeroFill(); + // This pattern should be picked up by the patch detection heuristics. + for (size_t y = 0; y < t.ppf().info.ysize; ++y) { + for (size_t x = 0; x < t.ppf().info.xsize; ++x) { + if (x % 4 == 0 && (y / 32) % 4 == 0) { + frame.SetValue(y, x, 1, 127.0f / 255.0f); + } + frame.SetValue(y, x, 3, 1.0f); + } + } + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 7); // kSquirrel + cparams.distance = 0.1f; + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 597, 100); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(0.018f)); +} + +TEST(JxlTest, RoundtripSmallPatches) { + ThreadPool* pool = nullptr; + TestImage t; + t.SetDimensions(256, 256); + t.SetColorEncoding("RGB_D65_SRG_Rel_Lin"); + TestImage::Frame frame = t.AddFrame(); + frame.ZeroFill(); + // This pattern should be picked up by the patch detection heuristics. + for (size_t y = 0; y < t.ppf().info.ysize; ++y) { + for (size_t x = 0; x < t.ppf().info.xsize; ++x) { + if (x % 4 == 0 && (y / 32) % 4 == 0) { + frame.SetValue(y, x, 1, 127.0f / 255.0f); + } + } + } + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 7); // kSquirrel + cparams.distance = 0.1f; + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 486, 100); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(0.018f)); +} + +// TODO(szabadka) Add encoder and decoder API functions that accept frame +// buffers in arbitrary unsigned and floating point formats, and then roundtrip +// test the lossless codepath to make sure the exact binary representations +// are preserved. +#if 0 +TEST(JxlTest, RoundtripImageBundleOriginalBits) { + // Image does not matter, only io.metadata.m and io2.metadata.m are tested. + Image3F image(1, 1); + ZeroFillImage(&image); + CodecInOut io; + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + io.SetFromImage(std::move(image), ColorEncoding::LinearSRGB()); + + CompressParams cparams; + + // Test unsigned integers from 1 to 32 bits + for (uint32_t bit_depth = 1; bit_depth <= 32; bit_depth++) { + if (bit_depth == 32) { + // TODO(lode): allow testing 32, however the code below ends up in + // enc_modular which does not support 32. We only want to test the header + // encoding though, so try without modular. + break; + } + + io.metadata.m.SetUintSamples(bit_depth); + CodecInOut io2; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _)); + + EXPECT_EQ(bit_depth, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io2.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_EQ(0u, io2.metadata.m.GetAlphaBits()); + } + + // Test various existing and non-existing floating point formats + for (uint32_t bit_depth = 8; bit_depth <= 32; bit_depth++) { + if (bit_depth != 32) { + // TODO(user): test other float types once they work + break; + } + + uint32_t exponent_bit_depth; + if (bit_depth < 10) { + exponent_bit_depth = 2; + } else if (bit_depth < 12) { + exponent_bit_depth = 3; + } else if (bit_depth < 16) { + exponent_bit_depth = 4; + } else if (bit_depth < 20) { + exponent_bit_depth = 5; + } else if (bit_depth < 24) { + exponent_bit_depth = 6; + } else if (bit_depth < 28) { + exponent_bit_depth = 7; + } else { + exponent_bit_depth = 8; + } + + io.metadata.m.bit_depth.bits_per_sample = bit_depth; + io.metadata.m.bit_depth.floating_point_sample = true; + io.metadata.m.bit_depth.exponent_bits_per_sample = exponent_bit_depth; + + CodecInOut io2; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2)); + + EXPECT_EQ(bit_depth, io2.metadata.m.bit_depth.bits_per_sample); + EXPECT_TRUE(io2.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(exponent_bit_depth, + io2.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_EQ(0u, io2.metadata.m.GetAlphaBits()); + } +} +#endif + +TEST(JxlTest, RoundtripGrayscale) { + const std::vector<uint8_t> orig = ReadTestData( + "external/wesaturate/500px/cvo9xd_keong_macan_grayscale.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + ASSERT_NE(io.xsize(), 0u); + io.ShrinkTo(128, 128); + EXPECT_TRUE(io.Main().IsGray()); + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.Tf().IsSRGB()); + + { + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + std::vector<uint8_t> compressed; + EXPECT_TRUE(test::EncodeFile(cparams, &io, &compressed)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, Bytes(compressed), &io2)); + EXPECT_TRUE(io2.Main().IsGray()); + + EXPECT_LE(compressed.size(), 7000u); + EXPECT_THAT(ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(1.6)); + } + + // Test with larger butteraugli distance and other settings enabled so + // different jxl codepaths trigger. + { + CompressParams cparams; + cparams.butteraugli_distance = 8.0; + + std::vector<uint8_t> compressed; + EXPECT_TRUE(test::EncodeFile(cparams, &io, &compressed)); + CodecInOut io2; + EXPECT_TRUE(test::DecodeFile({}, Bytes(compressed), &io2)); + EXPECT_TRUE(io2.Main().IsGray()); + + EXPECT_LE(compressed.size(), 1300u); + EXPECT_THAT(ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(6.7)); + } + + { + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + std::vector<uint8_t> compressed; + EXPECT_TRUE(test::EncodeFile(cparams, &io, &compressed)); + + CodecInOut io2; + JXLDecompressParams dparams; + dparams.color_space = "RGB_D65_SRG_Rel_SRG"; + EXPECT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &io2)); + EXPECT_FALSE(io2.Main().IsGray()); + + EXPECT_LE(compressed.size(), 7000u); + EXPECT_THAT(ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(1.6)); + } +} + +TEST(JxlTest, RoundtripAlpha) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + + ASSERT_NE(io.xsize(), 0u); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + io.ShrinkTo(300, 300); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + + EXPECT_EQ(8u, io.metadata.m.bit_depth.bits_per_sample); + EXPECT_FALSE(io.metadata.m.bit_depth.floating_point_sample); + EXPECT_EQ(0u, io.metadata.m.bit_depth.exponent_bits_per_sample); + EXPECT_TRUE(io.metadata.m.color_encoding.Tf().IsSRGB()); + std::vector<uint8_t> compressed; + EXPECT_TRUE(test::EncodeFile(cparams, &io, &compressed)); + + EXPECT_LE(compressed.size(), 10077u); + + for (bool use_image_callback : {false, true}) { + for (bool unpremul_alpha : {false, true}) { + CodecInOut io2; + JXLDecompressParams dparams; + dparams.use_image_callback = use_image_callback; + dparams.unpremultiply_alpha = unpremul_alpha; + EXPECT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &io2)); + EXPECT_THAT(ButteraugliDistance(io.frames, io2.frames, + ButteraugliParams(), *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(1.15)); + } + } +} + +namespace { +// Performs "PremultiplyAlpha" for each ImageBundle (preview/frames). +bool PremultiplyAlpha(CodecInOut& io) { + const auto doPremultiplyAlpha = [](ImageBundle& bundle) { + if (!bundle.HasAlpha()) return; + if (!bundle.HasColor()) return; + auto* color = bundle.color(); + const auto* alpha = bundle.alpha(); + JXL_CHECK(color->ysize() == alpha->ysize()); + JXL_CHECK(color->xsize() == alpha->xsize()); + for (size_t y = 0; y < color->ysize(); y++) { + ::jxl::PremultiplyAlpha(color->PlaneRow(0, y), color->PlaneRow(1, y), + color->PlaneRow(2, y), alpha->Row(y), + color->xsize()); + } + }; + ExtraChannelInfo* eci = io.metadata.m.Find(ExtraChannel::kAlpha); + if (eci == nullptr || eci->alpha_associated) return false; + if (io.metadata.m.have_preview) { + doPremultiplyAlpha(io.preview_frame); + } + for (ImageBundle& ib : io.frames) { + doPremultiplyAlpha(ib); + } + eci->alpha_associated = true; + return true; +} + +bool UnpremultiplyAlpha(CodecInOut& io) { + const auto doUnpremultiplyAlpha = [](ImageBundle& bundle) { + if (!bundle.HasAlpha()) return; + if (!bundle.HasColor()) return; + auto* color = bundle.color(); + const auto* alpha = bundle.alpha(); + JXL_CHECK(color->ysize() == alpha->ysize()); + JXL_CHECK(color->xsize() == alpha->xsize()); + for (size_t y = 0; y < color->ysize(); y++) { + ::jxl::UnpremultiplyAlpha(color->PlaneRow(0, y), color->PlaneRow(1, y), + color->PlaneRow(2, y), alpha->Row(y), + color->xsize()); + } + }; + ExtraChannelInfo* eci = io.metadata.m.Find(ExtraChannel::kAlpha); + if (eci == nullptr || !eci->alpha_associated) return false; + if (io.metadata.m.have_preview) { + doUnpremultiplyAlpha(io.preview_frame); + } + for (ImageBundle& ib : io.frames) { + doUnpremultiplyAlpha(ib); + } + eci->alpha_associated = false; + return true; +} +} // namespace + +TEST(JxlTest, RoundtripAlphaPremultiplied) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + CodecInOut io, io_nopremul; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io_nopremul)); + + ASSERT_NE(io.xsize(), 0u); + ASSERT_TRUE(io.metadata.m.HasAlpha()); + ASSERT_TRUE(io.Main().HasAlpha()); + io.ShrinkTo(300, 300); + io_nopremul.ShrinkTo(300, 300); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + cparams.SetCms(*JxlGetDefaultCms()); + + EXPECT_FALSE(io.Main().AlphaIsPremultiplied()); + EXPECT_TRUE(PremultiplyAlpha(io)); + EXPECT_TRUE(io.Main().AlphaIsPremultiplied()); + + EXPECT_FALSE(io_nopremul.Main().AlphaIsPremultiplied()); + + std::vector<uint8_t> compressed; + EXPECT_TRUE(test::EncodeFile(cparams, &io, &compressed)); + EXPECT_LE(compressed.size(), 10000u); + + for (bool use_image_callback : {false, true}) { + for (bool unpremul_alpha : {false, true}) { + for (bool use_uint8 : {false, true}) { + printf( + "Testing premultiplied alpha using %s %s requesting " + "%spremultiplied output.\n", + use_uint8 ? "uint8" : "float", + use_image_callback ? "image callback" : "image_buffer", + unpremul_alpha ? "un" : ""); + CodecInOut io2; + JXLDecompressParams dparams; + dparams.use_image_callback = use_image_callback; + dparams.unpremultiply_alpha = unpremul_alpha; + if (use_uint8) { + dparams.accepted_formats = { + {4, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}}; + } + EXPECT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &io2)); + + EXPECT_EQ(unpremul_alpha, !io2.Main().AlphaIsPremultiplied()); + if (!unpremul_alpha) { + EXPECT_THAT( + ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(1.111)); + EXPECT_TRUE(UnpremultiplyAlpha(io2)); + EXPECT_FALSE(io2.Main().AlphaIsPremultiplied()); + } + EXPECT_THAT( + ButteraugliDistance(io_nopremul.frames, io2.frames, + ButteraugliParams(), *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(1.55)); + } + } + } +} + +TEST(JxlTest, RoundtripAlphaResampling) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + ASSERT_NE(t.ppf().info.xsize, 0); + ASSERT_TRUE(t.ppf().info.alpha_bits > 0); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 5); // kHare + cparams.AddOption(JXL_ENC_FRAME_SETTING_RESAMPLING, 2); + cparams.AddOption(JXL_ENC_FRAME_SETTING_EXTRA_CHANNEL_RESAMPLING, 2); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 13507, 130); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(5.2)); +} + +TEST(JxlTest, RoundtripAlphaResamplingOnlyAlpha) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + ASSERT_NE(t.ppf().info.xsize, 0); + ASSERT_TRUE(t.ppf().info.alpha_bits > 0); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 3); // kFalcon + cparams.AddOption(JXL_ENC_FRAME_SETTING_EXTRA_CHANNEL_RESAMPLING, 2); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 33571, 400); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(1.49)); +} + +TEST(JxlTest, RoundtripAlphaNonMultipleOf8) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata().SetDimensions(12, 12); + ASSERT_NE(t.ppf().info.xsize, 0); + ASSERT_TRUE(t.ppf().info.alpha_bits > 0); + EXPECT_EQ(t.ppf().frames[0].color.format.data_type, JXL_TYPE_UINT8); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), {}, {}, pool, &ppf_out), 107, 10); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(0.95)); +} + +TEST(JxlTest, RoundtripAlpha16) { + ThreadPoolForTests pool(4); + // The image is wider than 512 pixels to ensure multiple groups are tested. + size_t xsize = 1200, ysize = 160; + TestImage t; + t.SetDimensions(xsize, ysize).SetChannels(4).SetAllBitDepths(16); + TestImage::Frame frame = t.AddFrame(); + // Generate 16-bit pattern that uses various colors and alpha values. + const float mul = 1.0f / 65535; + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + uint16_t r = y * 65535 / ysize; + uint16_t g = x * 65535 / xsize; + uint16_t b = (y + x) * 65535 / (xsize + ysize); + frame.SetValue(y, x, 0, r * mul); + frame.SetValue(y, x, 1, g * mul); + frame.SetValue(y, x, 2, b * mul); + frame.SetValue(y, x, 3, g * mul); + } + } + + ASSERT_NE(t.ppf().info.xsize, 0); + ASSERT_EQ(t.ppf().info.alpha_bits, 16); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 6); // kWombat + cparams.distance = 0.5; + + PackedPixelFile ppf_out; + // TODO(szabadka) Investigate big size difference on i686 + // This still keeps happening (2023-04-18). + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, &pool, &ppf_out), 3666, 120); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(0.65)); +} + +namespace { +JXLCompressParams CompressParamsForLossless() { + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_MODULAR, 1); + cparams.AddOption(JXL_ENC_FRAME_SETTING_COLOR_TRANSFORM, 1); + cparams.AddOption(JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR, 6); // Weighted + cparams.distance = 0; + return cparams; +} +} // namespace + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLossless8)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams = CompressParamsForLossless(); + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + EXPECT_EQ(Roundtrip(t.ppf(), cparams, dparams, &pool, &ppf_out), 223058); + EXPECT_EQ(ComputeDistance2(t.ppf(), ppf_out), 0.0); +} + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLossless8ThunderGradient)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams = CompressParamsForLossless(); + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 2); // kThunder + cparams.AddOption(JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR, 5); // Gradient + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + EXPECT_EQ(Roundtrip(t.ppf(), cparams, dparams, &pool, &ppf_out), 261684); + EXPECT_EQ(ComputeDistance2(t.ppf(), ppf_out), 0.0); +} + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLossless8LightningGradient)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams = CompressParamsForLossless(); + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 1); // kLightning + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + // Lax comparison because different SIMD will cause different compression. + EXPECT_THAT(Roundtrip(t.ppf(), cparams, dparams, &pool, &ppf_out), + IsSlightlyBelow(286848u)); + EXPECT_EQ(ComputeDistance2(t.ppf(), ppf_out), 0.0); +} + +TEST(JxlTest, JXL_SLOW_TEST(RoundtripLossless8Falcon)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams = CompressParamsForLossless(); + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 3); // kFalcon + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + EXPECT_EQ(Roundtrip(t.ppf(), cparams, dparams, &pool, &ppf_out), 230766); + EXPECT_EQ(ComputeDistance2(t.ppf(), ppf_out), 0.0); +} + +TEST(JxlTest, RoundtripLossless8Alpha) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/tmshre_riaphotographs_alpha.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + ASSERT_EQ(t.ppf().info.alpha_bits, 8); + EXPECT_EQ(t.ppf().frames[0].color.format.data_type, JXL_TYPE_UINT8); + + JXLCompressParams cparams = CompressParamsForLossless(); + + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + EXPECT_EQ(Roundtrip(t.ppf(), cparams, dparams, pool, &ppf_out), 251470); + EXPECT_EQ(ComputeDistance2(t.ppf(), ppf_out), 0.0); + EXPECT_EQ(ppf_out.info.alpha_bits, 8); + EXPECT_TRUE(test::SameAlpha(t.ppf(), ppf_out)); +} + +TEST(JxlTest, RoundtripLossless16Alpha) { + ThreadPool* pool = nullptr; + size_t xsize = 1200, ysize = 160; + TestImage t; + t.SetDimensions(xsize, ysize).SetChannels(4).SetAllBitDepths(16); + TestImage::Frame frame = t.AddFrame(); + // Generate 16-bit pattern that uses various colors and alpha values. + const float mul = 1.0f / 65535; + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + uint16_t r = y * 65535 / ysize; + uint16_t g = x * 65535 / xsize + 37; + uint16_t b = (y + x) * 65535 / (xsize + ysize); + frame.SetValue(y, x, 0, r * mul); + frame.SetValue(y, x, 1, g * mul); + frame.SetValue(y, x, 2, b * mul); + frame.SetValue(y, x, 3, g * mul); + } + } + ASSERT_EQ(t.ppf().info.bits_per_sample, 16); + ASSERT_EQ(t.ppf().info.alpha_bits, 16); + + JXLCompressParams cparams = CompressParamsForLossless(); + + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + // TODO(szabadka) Investigate big size difference on i686 + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, dparams, pool, &ppf_out), 4884, 100); + EXPECT_EQ(ComputeDistance2(t.ppf(), ppf_out), 0.0); + EXPECT_EQ(ppf_out.info.alpha_bits, 16); + EXPECT_TRUE(test::SameAlpha(t.ppf(), ppf_out)); +} + +TEST(JxlTest, RoundtripLossless16AlphaNotMisdetectedAs8Bit) { + ThreadPool* pool = nullptr; + size_t xsize = 128, ysize = 128; + TestImage t; + t.SetDimensions(xsize, ysize).SetChannels(4).SetAllBitDepths(16); + TestImage::Frame frame = t.AddFrame(); + // All 16-bit values, both color and alpha, of this image are below 64. + // This allows testing if a code path wrongly concludes it's an 8-bit instead + // of 16-bit image (or even 6-bit). + const float mul = 1.0f / 65535; + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + uint16_t r = y * 64 / ysize; + uint16_t g = x * 64 / xsize + 37; + uint16_t b = (y + x) * 64 / (xsize + ysize); + frame.SetValue(y, x, 0, r * mul); + frame.SetValue(y, x, 1, g * mul); + frame.SetValue(y, x, 2, b * mul); + frame.SetValue(y, x, 3, g * mul); + } + } + ASSERT_EQ(t.ppf().info.bits_per_sample, 16); + ASSERT_EQ(t.ppf().info.alpha_bits, 16); + + JXLCompressParams cparams = CompressParamsForLossless(); + + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, dparams, pool, &ppf_out), 591, 50); + EXPECT_EQ(ComputeDistance2(t.ppf(), ppf_out), 0.0); + EXPECT_EQ(ppf_out.info.bits_per_sample, 16); + EXPECT_EQ(ppf_out.info.alpha_bits, 16); + EXPECT_TRUE(test::SameAlpha(t.ppf(), ppf_out)); +} + +TEST(JxlTest, RoundtripDots) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/cvo9xd_keong_macan_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + ASSERT_NE(t.ppf().info.xsize, 0); + EXPECT_EQ(t.ppf().info.bits_per_sample, 8); + EXPECT_EQ(t.ppf().color_encoding.transfer_function, + JXL_TRANSFER_FUNCTION_SRGB); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 7); // kSkirrel + cparams.AddOption(JXL_ENC_FRAME_SETTING_DOTS, 1); + cparams.distance = 0.04; + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 280333, 4000); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(0.35)); +} + +TEST(JxlTest, RoundtripNoise) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + ASSERT_NE(t.ppf().info.xsize, 0); + EXPECT_EQ(t.ppf().info.bits_per_sample, 8); + EXPECT_EQ(t.ppf().color_encoding.transfer_function, + JXL_TRANSFER_FUNCTION_SRGB); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 7); // kSkirrel + cparams.AddOption(JXL_ENC_FRAME_SETTING_NOISE, 1); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, pool, &ppf_out), 41009, 750); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(1.42)); +} + +TEST(JxlTest, RoundtripLossless8Gray) { + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = ReadTestData( + "external/wesaturate/500px/cvo9xd_keong_macan_grayscale.png"); + TestImage t; + t.SetColorEncoding("Gra_D65_Rel_SRG").DecodeFromBytes(orig).ClearMetadata(); + EXPECT_EQ(t.ppf().color_encoding.color_space, JXL_COLOR_SPACE_GRAY); + EXPECT_EQ(t.ppf().info.bits_per_sample, 8); + + JXLCompressParams cparams = CompressParamsForLossless(); + + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + EXPECT_EQ(Roundtrip(t.ppf(), cparams, dparams, pool, &ppf_out), 92185); + EXPECT_EQ(ComputeDistance2(t.ppf(), ppf_out), 0.0); + EXPECT_EQ(ppf_out.color_encoding.color_space, JXL_COLOR_SPACE_GRAY); + EXPECT_EQ(ppf_out.info.bits_per_sample, 8); +} + +TEST(JxlTest, RoundtripAnimation) { + if (!jxl::extras::CanDecode(jxl::extras::Codec::kGIF)) { + fprintf(stderr, "Skipping test because of missing GIF decoder.\n"); + return; + } + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = ReadTestData("jxl/traffic_light.gif"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + EXPECT_EQ(4, t.ppf().frames.size()); + + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + EXPECT_THAT(Roundtrip(t.ppf(), {}, dparams, pool, &ppf_out), + IsSlightlyBelow(2888)); + + t.CoalesceGIFAnimationWithAlpha(); + ASSERT_EQ(ppf_out.frames.size(), t.ppf().frames.size()); + EXPECT_LE(ButteraugliDistance(t.ppf(), ppf_out), +#if JXL_HIGH_PRECISION + 1.55); +#else + 1.75); +#endif +} + +TEST(JxlTest, RoundtripLosslessAnimation) { + if (!jxl::extras::CanDecode(jxl::extras::Codec::kGIF)) { + fprintf(stderr, "Skipping test because of missing GIF decoder.\n"); + return; + } + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = ReadTestData("jxl/traffic_light.gif"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + EXPECT_EQ(4, t.ppf().frames.size()); + + JXLCompressParams cparams = CompressParamsForLossless(); + + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + EXPECT_THAT(Roundtrip(t.ppf(), cparams, dparams, pool, &ppf_out), + IsSlightlyBelow(958)); + + t.CoalesceGIFAnimationWithAlpha(); + ASSERT_EQ(ppf_out.frames.size(), t.ppf().frames.size()); + EXPECT_LE(ButteraugliDistance(t.ppf(), ppf_out), 5e-4); +} + +TEST(JxlTest, RoundtripAnimationPatches) { + if (!jxl::extras::CanDecode(jxl::extras::Codec::kGIF)) { + fprintf(stderr, "Skipping test because of missing GIF decoder.\n"); + return; + } + ThreadPool* pool = nullptr; + const std::vector<uint8_t> orig = ReadTestData("jxl/animation_patches.gif"); + + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + ASSERT_EQ(2u, t.ppf().frames.size()); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_PATCHES, 1); + + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + // 40k with no patches, 27k with patch frames encoded multiple times. + EXPECT_THAT(Roundtrip(t.ppf(), cparams, dparams, pool, &ppf_out), + IsSlightlyBelow(19300)); + EXPECT_EQ(ppf_out.frames.size(), t.ppf().frames.size()); + // >10 with broken patches; not all patches are detected on borders. + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(1.9)); +} + +size_t RoundtripJpeg(const std::vector<uint8_t>& jpeg_in, ThreadPool* pool) { + std::vector<uint8_t> compressed; + EXPECT_TRUE(extras::EncodeImageJXL({}, extras::PackedPixelFile(), &jpeg_in, + &compressed)); + + jxl::JXLDecompressParams dparams; + test::DefaultAcceptedFormats(dparams); + test::SetThreadParallelRunner(dparams, pool); + std::vector<uint8_t> out; + jxl::PackedPixelFile ppf; + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), dparams, + nullptr, &ppf, &out)); + EXPECT_EQ(out.size(), jpeg_in.size()); + size_t failures = 0; + for (size_t i = 0; i < std::min(out.size(), jpeg_in.size()); i++) { + if (out[i] != jpeg_in[i]) { + EXPECT_EQ(out[i], jpeg_in[i]) + << "byte mismatch " << i << " " << out[i] << " != " << jpeg_in[i]; + if (++failures > 4) { + return compressed.size(); + } + } + } + return compressed.size(); +} + +void RoundtripJpegToPixels(const std::vector<uint8_t>& jpeg_in, + JXLDecompressParams dparams, ThreadPool* pool, + PackedPixelFile* ppf_out) { + std::vector<uint8_t> jpeg_bytes(jpeg_in.data(), + jpeg_in.data() + jpeg_in.size()); + std::vector<uint8_t> compressed; + EXPECT_TRUE(extras::EncodeImageJXL({}, extras::PackedPixelFile(), &jpeg_bytes, + &compressed)); + + test::DefaultAcceptedFormats(dparams); + test::SetThreadParallelRunner(dparams, pool); + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), dparams, + nullptr, ppf_out, nullptr)); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression444)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_444.jpg"); + // JPEG size is 696,659 bytes. + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 568940u, 20); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels)) { + TEST_LIBJPEG_SUPPORT(); + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_444.jpg"); + TestImage t; + t.DecodeFromBytes(orig); + + PackedPixelFile ppf_out; + RoundtripJpegToPixels(orig, {}, &pool, &ppf_out); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(12)); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels420)) { + TEST_LIBJPEG_SUPPORT(); + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_420.jpg"); + TestImage t; + t.DecodeFromBytes(orig); + + PackedPixelFile ppf_out; + RoundtripJpegToPixels(orig, {}, &pool, &ppf_out); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(11)); +} + +TEST(JxlTest, + JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels420EarlyFlush)) { + TEST_LIBJPEG_SUPPORT(); + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_420.jpg"); + TestImage t; + t.DecodeFromBytes(orig); + + JXLDecompressParams dparams; + dparams.max_downsampling = 8; + + PackedPixelFile ppf_out; + RoundtripJpegToPixels(orig, dparams, &pool, &ppf_out); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(4410)); +} + +TEST(JxlTest, + JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels420Mul16)) { + TEST_LIBJPEG_SUPPORT(); + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower_cropped.jpg"); + TestImage t; + t.DecodeFromBytes(orig); + + PackedPixelFile ppf_out; + RoundtripJpegToPixels(orig, {}, &pool, &ppf_out); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(4)); +} + +TEST(JxlTest, + JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionToPixels_asymmetric)) { + TEST_LIBJPEG_SUPPORT(); + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_asymmetric.jpg"); + TestImage t; + t.DecodeFromBytes(orig); + + PackedPixelFile ppf_out; + RoundtripJpegToPixels(orig, {}, &pool, &ppf_out); + EXPECT_THAT(ComputeDistance2(t.ppf(), ppf_out), IsSlightlyBelow(10)); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionGray)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_gray.jpg"); + // JPEG size is 456,528 bytes. + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 387496u, 200); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression420)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_420.jpg"); + // JPEG size is 546,797 bytes. + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 455560u, 10); +} + +TEST(JxlTest, + JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression_luma_subsample)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_luma_subsample.jpg"); + // JPEG size is 400,724 bytes. + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 325354u, 15); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression444_12)) { + // 444 JPEG that has an interesting sampling-factor (1x2, 1x2, 1x2). + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_444_1x2.jpg"); + // JPEG size is 703,874 bytes. + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 569679u, 10); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression422)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_422.jpg"); + // JPEG size is 522,057 bytes. + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 499282u, 10); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression440)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_440.jpg"); + // JPEG size is 603,623 bytes. + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 501151u, 10); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression_asymmetric)) { + // 2x vertical downsample of one chroma channel, 2x horizontal downsample of + // the other. + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_asymmetric.jpg"); + // JPEG size is 604,601 bytes. + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 500602u, 10); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompression420Progr)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/flower/flower.png.im_q85_420_progr.jpg"); + // JPEG size is 522,057 bytes. + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 455499u, 10); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionMetadata)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/jpeg_reconstruction/1x1_exif_xmp.jpg"); + // JPEG size is 4290 bytes + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 1400u, 30); +} + +TEST(JxlTest, JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionRestarts)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/jpeg_reconstruction/bicycles_restarts.jpg"); + // JPEG size is 87478 bytes + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 76125u, 30); +} + +TEST(JxlTest, + JXL_TRANSCODE_JPEG_TEST(RoundtripJpegRecompressionOrientationICC)) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("jxl/jpeg_reconstruction/sideways_bench.jpg"); + // JPEG size is 15252 bytes + EXPECT_NEAR(RoundtripJpeg(orig, &pool), 12000u, 470); + // TODO(jon): investigate why 'Cross-compiling i686-linux-gnu' produces a + // larger result +} + +TEST(JxlTest, RoundtripProgressive) { + ThreadPoolForTests pool(4); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata().SetDimensions(600, 1024); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC, 1); + cparams.AddOption(JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC, 1); + cparams.AddOption(JXL_ENC_FRAME_SETTING_RESPONSIVE, 1); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, &pool, &ppf_out), 70544, 750); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(1.4)); +} + +TEST(JxlTest, RoundtripProgressiveLevel2Slow) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata().SetDimensions(600, 1024); + + JXLCompressParams cparams; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 9); // kTortoise + cparams.AddOption(JXL_ENC_FRAME_SETTING_PROGRESSIVE_DC, 2); + cparams.AddOption(JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC, 1); + cparams.AddOption(JXL_ENC_FRAME_SETTING_RESPONSIVE, 1); + + PackedPixelFile ppf_out; + EXPECT_NEAR(Roundtrip(t.ppf(), cparams, {}, &pool, &ppf_out), 76666, 1000); + EXPECT_THAT(ButteraugliDistance(t.ppf(), ppf_out), IsSlightlyBelow(1.17)); +} + +TEST(JxlTest, RoundtripUnsignedCustomBitdepthLossless) { + ThreadPool* pool = nullptr; + for (uint32_t num_channels = 1; num_channels < 6; ++num_channels) { + for (JxlEndianness endianness : {JXL_LITTLE_ENDIAN, JXL_BIG_ENDIAN}) { + for (uint32_t bitdepth = 3; bitdepth <= 16; ++bitdepth) { + if (bitdepth <= 8 && endianness == JXL_BIG_ENDIAN) continue; + printf("Testing %u channel unsigned %u bit %s endian lossless.\n", + num_channels, bitdepth, + endianness == JXL_LITTLE_ENDIAN ? "little" : "big"); + TestImage t; + t.SetDimensions(256, 256).SetChannels(num_channels); + t.SetAllBitDepths(bitdepth).SetEndianness(endianness); + TestImage::Frame frame = t.AddFrame(); + frame.RandomFill(); + + JXLCompressParams cparams = CompressParamsForLossless(); + cparams.input_bitdepth.type = JXL_BIT_DEPTH_FROM_CODESTREAM; + + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + dparams.output_bitdepth.type = JXL_BIT_DEPTH_FROM_CODESTREAM; + + PackedPixelFile ppf_out; + Roundtrip(t.ppf(), cparams, dparams, pool, &ppf_out); + + ASSERT_TRUE(test::SamePixels(t.ppf(), ppf_out)); + } + } + } +} + +TEST(JxlTest, LosslessPNMRoundtrip) { + static const char* kChannels[] = {"", "g", "ga", "rgb", "rgba"}; + static const char* kExtension[] = {"", ".pgm", ".pam", ".ppm", ".pam"}; + for (size_t bit_depth = 1; bit_depth <= 16; ++bit_depth) { + for (size_t channels = 1; channels <= 4; ++channels) { + if (bit_depth == 1 && (channels == 2 || channels == 4)) continue; + std::string extension(kExtension[channels]); + std::string filename = "jxl/flower/flower_small." + + std::string(kChannels[channels]) + ".depth" + + std::to_string(bit_depth) + extension; + const std::vector<uint8_t> orig = ReadTestData(filename); + test::TestImage t; + if (channels < 3) t.SetColorEncoding("Gra_D65_Rel_SRG"); + t.DecodeFromBytes(orig); + + JXLCompressParams cparams = CompressParamsForLossless(); + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 1); // kLightning + cparams.input_bitdepth.type = JXL_BIT_DEPTH_FROM_CODESTREAM; + + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + dparams.output_bitdepth.type = JXL_BIT_DEPTH_FROM_CODESTREAM; + + PackedPixelFile ppf_out; + Roundtrip(t.ppf(), cparams, dparams, nullptr, &ppf_out); + + extras::EncodedImage encoded; + auto encoder = extras::Encoder::FromExtension(extension); + ASSERT_TRUE(encoder.get()); + ASSERT_TRUE(encoder->Encode(ppf_out, &encoded, nullptr)); + ASSERT_EQ(encoded.bitstreams.size(), 1); + ASSERT_EQ(orig.size(), encoded.bitstreams[0].size()); + EXPECT_EQ(0, + memcmp(orig.data(), encoded.bitstreams[0].data(), orig.size())); + } + } +} + +class JxlTest : public ::testing::TestWithParam<const char*> {}; + +TEST_P(JxlTest, LosslessSmallFewColors) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = ReadTestData(GetParam()); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + + JXLCompressParams cparams; + cparams.distance = 0; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 1); + JXLDecompressParams dparams; + dparams.accepted_formats.push_back(t.ppf().frames[0].color.format); + + PackedPixelFile ppf_out; + Roundtrip(t.ppf(), cparams, dparams, &pool, &ppf_out); + EXPECT_EQ(ComputeDistance2(t.ppf(), ppf_out), 0.0); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + ImageTests, JxlTest, + ::testing::Values("jxl/blending/cropped_traffic_light_frame-0.png", + "palette/358colors.png")); + +struct StreamingTestParam { + size_t xsize; + size_t ysize; + bool is_grey; + int effort; + bool progressive; + + size_t num_channels() const { return is_grey ? 1 : 3; } + + float max_psnr() const { return is_grey ? 90 : 50; } + + static std::vector<StreamingTestParam> All() { + std::vector<StreamingTestParam> params; + for (int e : {1, 3, 4, 7}) { + for (bool g : {false, true}) { + params.push_back(StreamingTestParam{357, 517, g, e, false}); + params.push_back(StreamingTestParam{2247, 2357, g, e, false}); + } + } + params.push_back(StreamingTestParam{2247, 2357, false, 1, true}); + return params; + } +}; + +std::ostream& operator<<(std::ostream& out, StreamingTestParam p) { + out << (p.is_grey ? "Grey" : "RGB"); + out << p.xsize << "x" << p.ysize; + out << "e" << p.effort; + if (p.progressive) { + out << "Progressive"; + } + return out; +} + +class JxlStreamingTest : public ::testing::TestWithParam<StreamingTestParam> {}; + +TEST_P(JxlStreamingTest, Roundtrip) { + const StreamingTestParam& p = GetParam(); + + jxl::test::TestImage image; + image.SetDimensions(p.xsize, p.ysize) + .SetDataType(JXL_TYPE_UINT8) + .SetChannels(p.num_channels()) + .SetAllBitDepths(8); + image.AddFrame().RandomFill(); + JXLCompressParams cparams; + cparams.distance = 0.1; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, p.effort); + cparams.AddOption(JXL_ENC_FRAME_SETTING_BUFFERING, 3); + if (p.progressive) { + cparams.AddOption(JXL_ENC_FRAME_SETTING_PROGRESSIVE_AC, 1); + } + + ThreadPoolForTests pool(8); + PackedPixelFile ppf_out; + Roundtrip(image.ppf(), cparams, {}, &pool, &ppf_out); + EXPECT_GT(jxl::test::ComputePSNR(image.ppf(), ppf_out), p.max_psnr()); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + JxlStreamingTest, JxlStreamingTest, + testing::ValuesIn(StreamingTestParam::All())); + +// This is broken on mingw32, so we only enable it for x86_64 now. +TEST(JxlTest, JXL_X86_64_TEST(StreamingSamePixels)) { + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + + jxl::test::TestImage image; + image.DecodeFromBytes(orig); + JXLCompressParams cparams; + cparams.distance = 1.0; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 6); + cparams.AddOption(JXL_ENC_FRAME_SETTING_USE_FULL_IMAGE_HEURISTICS, 0); + + ThreadPoolForTests pool(8); + PackedPixelFile ppf_out; + Roundtrip(image.ppf(), cparams, {}, &pool, &ppf_out); + + cparams.AddOption(JXL_ENC_FRAME_SETTING_BUFFERING, 3); + PackedPixelFile ppf_out_streaming; + Roundtrip(image.ppf(), cparams, {}, &pool, &ppf_out_streaming); + + EXPECT_TRUE(jxl::test::SamePixels(ppf_out, ppf_out_streaming)); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/lehmer_code.h b/third_party/jpeg-xl/lib/jxl/lehmer_code.h new file mode 100644 index 0000000000..dd1d21c6f7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/lehmer_code.h @@ -0,0 +1,102 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_LEHMER_CODE_H_ +#define LIB_JXL_LEHMER_CODE_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Permutation <=> factorial base representation (Lehmer code). + +using LehmerT = uint32_t; + +template <typename T> +constexpr T ValueOfLowest1Bit(T t) { + return t & -t; +} + +// Computes the Lehmer (factorial basis) code of permutation, an array of n +// unique indices in [0..n), and stores it in code[0..len). N*logN time. +// temp must have n + 1 elements but need not be initialized. +template <typename PermutationT> +void ComputeLehmerCode(const PermutationT* JXL_RESTRICT permutation, + uint32_t* JXL_RESTRICT temp, const size_t n, + LehmerT* JXL_RESTRICT code) { + for (size_t idx = 0; idx < n + 1; ++idx) temp[idx] = 0; + + for (size_t idx = 0; idx < n; ++idx) { + const PermutationT s = permutation[idx]; + + // Compute sum in Fenwick tree + uint32_t penalty = 0; + uint32_t i = s + 1; + while (i != 0) { + penalty += temp[i]; + i &= i - 1; // clear lowest bit + } + JXL_DASSERT(s >= penalty); + code[idx] = s - penalty; + i = s + 1; + // Add operation in Fenwick tree + while (i < n + 1) { + temp[i] += 1; + i += ValueOfLowest1Bit(i); + } + } +} + +// Decodes the Lehmer code in code[0..n) into permutation[0..n). +// temp must have 1 << CeilLog2(n) elements but need not be initialized. +template <typename PermutationT> +void DecodeLehmerCode(const LehmerT* JXL_RESTRICT code, + uint32_t* JXL_RESTRICT temp, size_t n, + PermutationT* JXL_RESTRICT permutation) { + JXL_DASSERT(n != 0); + const size_t log2n = CeilLog2Nonzero(n); + const size_t padded_n = 1ull << log2n; + + for (size_t i = 0; i < padded_n; i++) { + const int32_t i1 = static_cast<int32_t>(i + 1); + temp[i] = static_cast<uint32_t>(ValueOfLowest1Bit(i1)); + } + + for (size_t i = 0; i < n; i++) { + JXL_DASSERT(code[i] + i < n); + uint32_t rank = code[i] + 1; + + // Extract i-th unused element via implicit order-statistics tree. + size_t bit = padded_n; + size_t next = 0; + for (size_t i = 0; i <= log2n; i++) { + const size_t cand = next + bit; + JXL_DASSERT(cand >= 1); + bit >>= 1; + if (temp[cand - 1] < rank) { + next = cand; + rank -= temp[cand - 1]; + } + } + + permutation[i] = next; + + // Mark as used + next += 1; + while (next <= padded_n) { + temp[next - 1] -= 1; + next += ValueOfLowest1Bit(next); + } + } +} + +} // namespace jxl + +#endif // LIB_JXL_LEHMER_CODE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/lehmer_code_test.cc b/third_party/jpeg-xl/lib/jxl/lehmer_code_test.cc new file mode 100644 index 0000000000..acda762545 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/lehmer_code_test.cc @@ -0,0 +1,98 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/lehmer_code.h" + +#include <stdint.h> +#include <string.h> + +#include <algorithm> +#include <numeric> +#include <vector> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +template <typename PermutationT> +struct WorkingSet { + explicit WorkingSet(size_t max_n) + : padded_n(1ull << CeilLog2Nonzero(max_n + 1)), + permutation(max_n), + temp(padded_n), + lehmer(max_n), + decoded(max_n) {} + + size_t padded_n; + std::vector<PermutationT> permutation; + std::vector<uint32_t> temp; + std::vector<LehmerT> lehmer; + std::vector<PermutationT> decoded; +}; + +template <typename PermutationT> +void Roundtrip(size_t n, WorkingSet<PermutationT>* ws) { + JXL_ASSERT(n != 0); + const size_t padded_n = 1ull << CeilLog2Nonzero(n); + + Rng rng(n * 65537 + 13); + + // Ensure indices fit into PermutationT + EXPECT_LE(n, 1ULL << (sizeof(PermutationT) * 8)); + + std::iota(ws->permutation.begin(), ws->permutation.begin() + n, 0); + + // For various random permutations: + for (size_t rep = 0; rep < 3; ++rep) { + rng.Shuffle(ws->permutation.data(), n); + + // Must decode to the same permutation + ComputeLehmerCode(ws->permutation.data(), ws->temp.data(), n, + ws->lehmer.data()); + memset(ws->temp.data(), 0, padded_n * 4); + DecodeLehmerCode(ws->lehmer.data(), ws->temp.data(), n, ws->decoded.data()); + + for (size_t i = 0; i < n; ++i) { + EXPECT_EQ(ws->permutation[i], ws->decoded[i]); + } + } +} + +// Preallocates arrays and tests n = [begin, end). +template <typename PermutationT> +void RoundtripSizeRange(ThreadPool* pool, uint32_t begin, uint32_t end) { + ASSERT_NE(0u, begin); // n = 0 not allowed. + std::vector<WorkingSet<PermutationT>> working_sets; + + JXL_CHECK(RunOnPool( + pool, begin, end, + [&working_sets, end](const size_t num_threads) { + for (size_t i = 0; i < num_threads; i++) { + working_sets.emplace_back(end - 1); + } + return true; + }, + [&working_sets](const uint32_t n, const size_t thread) { + Roundtrip(n, &working_sets[thread]); + }, + "lehmer test")); +} + +TEST(LehmerCodeTest, TestRoundtrips) { + test::ThreadPoolForTests pool(8); + + RoundtripSizeRange<uint16_t>(&pool, 1, 1026); + + // Ensures PermutationT can fit > 16 bit values. + RoundtripSizeRange<uint32_t>(&pool, 65536, 65540); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/libjxl.pc.in b/third_party/jpeg-xl/lib/jxl/libjxl.pc.in new file mode 100644 index 0000000000..58b6941305 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/libjxl.pc.in @@ -0,0 +1,13 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=@PKGCONFIG_TARGET_LIBS@ +includedir=@PKGCONFIG_TARGET_INCLUDES@ + +Name: libjxl +Description: Loads and saves JPEG XL files +Version: @JPEGXL_LIBRARY_VERSION@ +@JPEGXL_REQUIRES_TYPE@: @JPEGXL_LIBRARY_REQUIRES@ +Libs: -L${libdir} -ljxl +Libs.private: -lm +Cflags: -I${includedir} +Cflags.private: -DJXL_STATIC_DEFINE diff --git a/third_party/jpeg-xl/lib/jxl/libjxl_cms.pc.in b/third_party/jpeg-xl/lib/jxl/libjxl_cms.pc.in new file mode 100644 index 0000000000..9aaa3f4dbe --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/libjxl_cms.pc.in @@ -0,0 +1,13 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=@PKGCONFIG_TARGET_LIBS@ +includedir=@PKGCONFIG_TARGET_INCLUDES@ + +Name: libjxl_cms +Description: CMS support library for libjxl +Version: @JPEGXL_LIBRARY_VERSION@ +@JPEGXL_REQUIRES_TYPE@: @JPEGXL_CMS_LIBRARY_REQUIRES@ +Libs: -L${libdir} -ljxl_cms +Libs.private: -lm +Cflags: -I${includedir} +Cflags.private: -DJXL_CMS_STATIC_DEFINE diff --git a/third_party/jpeg-xl/lib/jxl/loop_filter.cc b/third_party/jpeg-xl/lib/jxl/loop_filter.cc new file mode 100644 index 0000000000..5afe87617d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/loop_filter.cc @@ -0,0 +1,98 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/loop_filter.h" + +#include "lib/jxl/base/status.h" +#include "lib/jxl/fields.h" + +namespace jxl { + +LoopFilter::LoopFilter() { Bundle::Init(this); } +Status LoopFilter::VisitFields(Visitor* JXL_RESTRICT visitor) { + // Must come before AllDefault. + + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(true, &gab)); + if (visitor->Conditional(gab)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &gab_custom)); + if (visitor->Conditional(gab_custom)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.104699568f, &gab_x_weight1)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.055680538f, &gab_x_weight2)); + if (std::abs(1.0f + (gab_x_weight1 + gab_x_weight2) * 4) < 1e-8) { + return JXL_FAILURE( + "Gaborish x weights lead to near 0 unnormalized kernel"); + } + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.104699568f, &gab_y_weight1)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.055680538f, &gab_y_weight2)); + if (std::abs(1.0f + (gab_y_weight1 + gab_y_weight2) * 4) < 1e-8) { + return JXL_FAILURE( + "Gaborish y weights lead to near 0 unnormalized kernel"); + } + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.104699568f, &gab_b_weight1)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(1.1 * 0.055680538f, &gab_b_weight2)); + if (std::abs(1.0f + (gab_b_weight1 + gab_b_weight2) * 4) < 1e-8) { + return JXL_FAILURE( + "Gaborish b weights lead to near 0 unnormalized kernel"); + } + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(2, 2, &epf_iters)); + if (visitor->Conditional(epf_iters > 0)) { + if (visitor->Conditional(!nonserialized_is_modular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &epf_sharp_custom)); + if (visitor->Conditional(epf_sharp_custom)) { + for (size_t i = 0; i < kEpfSharpEntries; ++i) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16( + float(i) / float(kEpfSharpEntries - 1), &epf_sharp_lut[i])); + } + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &epf_weight_custom)); + if (visitor->Conditional(epf_weight_custom)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(40.0f, &epf_channel_scale[0])); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(5.0f, &epf_channel_scale[1])); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(3.5f, &epf_channel_scale[2])); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.45f, &epf_pass1_zeroflush)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.6f, &epf_pass2_zeroflush)); + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &epf_sigma_custom)); + if (visitor->Conditional(epf_sigma_custom)) { + if (visitor->Conditional(!nonserialized_is_modular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.46f, &epf_quant_mul)); + } + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(0.9f, &epf_pass0_sigma_scale)); + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(6.5f, &epf_pass2_sigma_scale)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->F16(0.6666666666666666f, &epf_border_sad_mul)); + } + if (visitor->Conditional(nonserialized_is_modular)) { + JXL_QUIET_RETURN_IF_ERROR(visitor->F16(1.0f, &epf_sigma_for_modular)); + if (epf_sigma_for_modular < 1e-8) { + return JXL_FAILURE("EPF: sigma for modular is too small"); + } + } + } + + JXL_QUIET_RETURN_IF_ERROR(visitor->BeginExtensions(&extensions)); + // Extensions: in chronological order of being added to the format. + return visitor->EndExtensions(); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/loop_filter.h b/third_party/jpeg-xl/lib/jxl/loop_filter.h new file mode 100644 index 0000000000..e4b418ba2b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/loop_filter.h @@ -0,0 +1,76 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_LOOP_FILTER_H_ +#define LIB_JXL_LOOP_FILTER_H_ + +// Parameters for loop filter(s), stored in each frame. + +#include <stddef.h> +#include <stdint.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +struct LoopFilter : public Fields { + LoopFilter(); + JXL_FIELDS_NAME(LoopFilter) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + size_t Padding() const { + static const size_t padding_per_epf_iter[4] = {0, 2, 3, 6}; + return padding_per_epf_iter[epf_iters] + (gab ? 1 : 0); + } + + mutable bool all_default; + + // --- Gaborish convolution + bool gab; + + bool gab_custom; + float gab_x_weight1; + float gab_x_weight2; + float gab_y_weight1; + float gab_y_weight2; + float gab_b_weight1; + float gab_b_weight2; + + // --- Edge-preserving filter + + // Number of EPF stages to apply. 0 means EPF disabled. 1 applies only the + // first stage, 2 applies both stages and 3 applies the first stage twice and + // the second stage once. + uint32_t epf_iters; + + bool epf_sharp_custom; + enum { kEpfSharpEntries = 8 }; + float epf_sharp_lut[kEpfSharpEntries]; + + bool epf_weight_custom; // Custom weight params + float epf_channel_scale[3]; // Relative weight of each channel + float epf_pass1_zeroflush; // Minimum weight for first pass + float epf_pass2_zeroflush; // Minimum weight for second pass + + bool epf_sigma_custom; // Custom sigma parameters + float epf_quant_mul; // Sigma is ~ this * quant + float epf_pass0_sigma_scale; // Multiplier for sigma in pass 0 + float epf_pass2_sigma_scale; // Multiplier for sigma in the second pass + float epf_border_sad_mul; // (inverse) multiplier for sigma on borders + + float epf_sigma_for_modular; + + uint64_t extensions; + + bool nonserialized_is_modular = false; +}; + +} // namespace jxl + +#endif // LIB_JXL_LOOP_FILTER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/luminance.cc b/third_party/jpeg-xl/lib/jxl/luminance.cc new file mode 100644 index 0000000000..7af4b2f9a9 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/luminance.cc @@ -0,0 +1,26 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/luminance.h" + +#include "lib/jxl/image_metadata.h" + +namespace jxl { + +void SetIntensityTarget(ImageMetadata* m) { + if (m->color_encoding.Tf().IsPQ()) { + // Peak luminance of PQ as defined by SMPTE ST 2084:2014. + m->SetIntensityTarget(10000); + } else if (m->color_encoding.Tf().IsHLG()) { + // Nominal display peak luminance used as a reference by + // Rec. ITU-R BT.2100-2. + m->SetIntensityTarget(1000); + } else { + // SDR + m->SetIntensityTarget(kDefaultIntensityTarget); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/luminance.h b/third_party/jpeg-xl/lib/jxl/luminance.h new file mode 100644 index 0000000000..3181576823 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/luminance.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_LUMINANCE_H_ +#define LIB_JXL_LUMINANCE_H_ + +namespace jxl { + +// Chooses a default intensity target based on the transfer function of the +// image, if known. For SDR images or images not known to be HDR, returns +// kDefaultIntensityTarget, for images known to have PQ or HLG transfer function +// returns a higher value. + +struct ImageMetadata; +// TODO(eustas): rename +void SetIntensityTarget(ImageMetadata* m); + +} // namespace jxl + +#endif // LIB_JXL_LUMINANCE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/memory_manager_internal.cc b/third_party/jpeg-xl/lib/jxl/memory_manager_internal.cc new file mode 100644 index 0000000000..87727e75cd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/memory_manager_internal.cc @@ -0,0 +1,18 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/memory_manager_internal.h" + +#include <stdlib.h> + +namespace jxl { + +void* MemoryManagerDefaultAlloc(void* opaque, size_t size) { + return malloc(size); +} + +void MemoryManagerDefaultFree(void* opaque, void* address) { free(address); } + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/memory_manager_internal.h b/third_party/jpeg-xl/lib/jxl/memory_manager_internal.h new file mode 100644 index 0000000000..c728d62e35 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/memory_manager_internal.h @@ -0,0 +1,99 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MEMORY_MANAGER_INTERNAL_H_ +#define LIB_JXL_MEMORY_MANAGER_INTERNAL_H_ + +// Memory allocator with support for alignment + misalignment. + +#include <jxl/memory_manager.h> +#include <stddef.h> +#include <stdlib.h> +#include <string.h> // memcpy + +#include <memory> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" + +namespace jxl { + +// Default alloc and free functions. +void* MemoryManagerDefaultAlloc(void* opaque, size_t size); +void MemoryManagerDefaultFree(void* opaque, void* address); + +// Initializes the memory manager instance with the passed one. The +// MemoryManager passed in |memory_manager| may be NULL or contain NULL +// functions which will be initialized with the default ones. If either alloc +// or free are NULL, then both must be NULL, otherwise this function returns an +// error. +static JXL_INLINE Status MemoryManagerInit( + JxlMemoryManager* self, const JxlMemoryManager* memory_manager) { + if (memory_manager) { + *self = *memory_manager; + } else { + memset(self, 0, sizeof(*self)); + } + if (!self->alloc != !self->free) { + return false; + } + if (!self->alloc) self->alloc = jxl::MemoryManagerDefaultAlloc; + if (!self->free) self->free = jxl::MemoryManagerDefaultFree; + + return true; +} + +static JXL_INLINE void* MemoryManagerAlloc( + const JxlMemoryManager* memory_manager, size_t size) { + return memory_manager->alloc(memory_manager->opaque, size); +} + +static JXL_INLINE void MemoryManagerFree(const JxlMemoryManager* memory_manager, + void* address) { + return memory_manager->free(memory_manager->opaque, address); +} + +// Helper class to be used as a deleter in a unique_ptr<T> call. +class MemoryManagerDeleteHelper { + public: + explicit MemoryManagerDeleteHelper(const JxlMemoryManager* memory_manager) + : memory_manager_(memory_manager) {} + + // Delete and free the passed pointer using the memory_manager. + template <typename T> + void operator()(T* address) const { + if (!address) { + return; + } + address->~T(); + return memory_manager_->free(memory_manager_->opaque, address); + } + + private: + const JxlMemoryManager* memory_manager_; +}; + +template <typename T> +using MemoryManagerUniquePtr = std::unique_ptr<T, MemoryManagerDeleteHelper>; + +// Creates a new object T allocating it with the memory allocator into a +// unique_ptr. +template <typename T, typename... Args> +JXL_INLINE MemoryManagerUniquePtr<T> MemoryManagerMakeUnique( + const JxlMemoryManager* memory_manager, Args&&... args) { + T* mem = + static_cast<T*>(memory_manager->alloc(memory_manager->opaque, sizeof(T))); + if (!mem) { + // Allocation error case. + return MemoryManagerUniquePtr<T>(nullptr, + MemoryManagerDeleteHelper(memory_manager)); + } + return MemoryManagerUniquePtr<T>(new (mem) T(std::forward<Args>(args)...), + MemoryManagerDeleteHelper(memory_manager)); +} + +} // namespace jxl + +#endif // LIB_JXL_MEMORY_MANAGER_INTERNAL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h new file mode 100644 index 0000000000..4c3a33a52a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/context_predict.h @@ -0,0 +1,672 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ +#define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ + +#include <utility> +#include <vector> + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +namespace weighted { +constexpr static size_t kNumPredictors = 4; +constexpr static int64_t kPredExtraBits = 3; +constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1; +constexpr static size_t kNumProperties = 1; + +struct Header : public Fields { + JXL_FIELDS_NAME(WeightedPredictorHeader) + // TODO(janwas): move to cc file, avoid including fields.h. + Header() { Bundle::Init(this); } + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + if (visitor->AllDefault(*this, &all_default)) { + // Overwrite all serialized fields, but not any nonserialized_*. + visitor->SetDefault(this); + return true; + } + auto visit_p = [visitor](pixel_type val, pixel_type *p) { + uint32_t up = *p; + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up)); + *p = up; + return Status(true); + }; + JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd)); + JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2])); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3])); + return true; + } + + bool all_default; + pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0; + uint32_t w[kNumPredictors] = {}; +}; + +struct State { + pixel_type_w prediction[kNumPredictors] = {}; + pixel_type_w pred = 0; // *before* removing the added bits. + std::vector<uint32_t> pred_errors[kNumPredictors]; + std::vector<int32_t> error; + const Header header; + + // Allows to approximate division by a number from 1 to 64. + // for (int i = 0; i < 64; i++) divlookup[i] = (1 << 24) / (i + 1); + + const uint32_t divlookup[64] = { + 16777216, 8388608, 5592405, 4194304, 3355443, 2796202, 2396745, 2097152, + 1864135, 1677721, 1525201, 1398101, 1290555, 1198372, 1118481, 1048576, + 986895, 932067, 883011, 838860, 798915, 762600, 729444, 699050, + 671088, 645277, 621378, 599186, 578524, 559240, 541200, 524288, + 508400, 493447, 479349, 466033, 453438, 441505, 430185, 419430, + 409200, 399457, 390167, 381300, 372827, 364722, 356962, 349525, + 342392, 335544, 328965, 322638, 316551, 310689, 305040, 299593, + 294337, 289262, 284359, 279620, 275036, 270600, 266305, 262144}; + + constexpr static pixel_type_w AddBits(pixel_type_w x) { + return uint64_t(x) << kPredExtraBits; + } + + State(Header header, size_t xsize, size_t ysize) : header(header) { + // Extra margin to avoid out-of-bounds writes. + // All have space for two rows of data. + for (size_t i = 0; i < 4; i++) { + pred_errors[i].resize((xsize + 2) * 2); + } + error.resize((xsize + 2) * 2); + } + + // Approximates 4+(maxweight<<24)/(x+1), avoiding division + JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const { + int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5; + if (shift < 0) shift = 0; + return 4 + ((maxweight * divlookup[x >> shift]) >> shift); + } + + // Approximates the weighted average of the input values with the given + // weights, avoiding division. Weights must sum to at least 16. + JXL_INLINE pixel_type_w + WeightedAverage(const pixel_type_w *JXL_RESTRICT p, + std::array<uint32_t, kNumPredictors> w) const { + uint32_t weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + weight_sum += w[i]; + } + JXL_DASSERT(weight_sum > 15); + uint32_t log_weight = FloorLog2Nonzero(weight_sum); // at least 4. + weight_sum = 0; + for (size_t i = 0; i < kNumPredictors; i++) { + w[i] >>= log_weight - 4; + weight_sum += w[i]; + } + // for rounding. + pixel_type_w sum = (weight_sum >> 1) - 1; + for (size_t i = 0; i < kNumPredictors; i++) { + sum += p[i] * w[i]; + } + return (sum * divlookup[weight_sum - 1]) >> 24; + } + + template <bool compute_properties> + JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize, + pixel_type_w N, pixel_type_w W, + pixel_type_w NE, pixel_type_w NW, + pixel_type_w NN, Properties *properties, + size_t offset) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + size_t pos_N = prev_row + x; + size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N; + size_t pos_NW = x > 0 ? pos_N - 1 : pos_N; + std::array<uint32_t, kNumPredictors> weights; + for (size_t i = 0; i < kNumPredictors; i++) { + // pred_errors[pos_N] also contains the error of pixel W. + // pred_errors[pos_NW] also contains the error of pixel WW. + weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] + + pred_errors[i][pos_NW]; + weights[i] = ErrorWeight(weights[i], header.w[i]); + } + + N = AddBits(N); + W = AddBits(W); + NE = AddBits(NE); + NW = AddBits(NW); + NN = AddBits(NN); + + pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1]; + pixel_type_w teN = error[pos_N]; + pixel_type_w teNW = error[pos_NW]; + pixel_type_w sumWN = teN + teW; + pixel_type_w teNE = error[pos_NE]; + + if (compute_properties) { + pixel_type_w p = teW; + if (std::abs(teN) > std::abs(p)) p = teN; + if (std::abs(teNW) > std::abs(p)) p = teNW; + if (std::abs(teNE) > std::abs(p)) p = teNE; + (*properties)[offset++] = p; + } + + prediction[0] = W + NE - N; + prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5); + prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5); + prediction[3] = + N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc + + (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >> + 5); + + pred = WeightedAverage(prediction, weights); + + // If all three have the same sign, skip clamping. + if (((teN ^ teW) | (teN ^ teNW)) > 0) { + return (pred + kPredictionRound) >> kPredExtraBits; + } + + // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N). + pixel_type_w mx = std::max(W, std::max(NE, N)); + pixel_type_w mn = std::min(W, std::min(NE, N)); + pred = std::max(mn, std::min(mx, pred)); + return (pred + kPredictionRound) >> kPredExtraBits; + } + + JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y, + size_t xsize) { + size_t cur_row = y & 1 ? 0 : (xsize + 2); + size_t prev_row = y & 1 ? (xsize + 2) : 0; + val = AddBits(val); + error[cur_row + x] = pred - val; + for (size_t i = 0; i < kNumPredictors; i++) { + pixel_type_w err = + (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits; + // For predicting in the next row. + pred_errors[i][cur_row + x] = err; + // Add the error on this pixel to the error on the NE pixel. This has the + // effect of adding the error on this pixel to the E and EE pixels. + pred_errors[i][prev_row + x + 1] += err; + } + } +}; + +// Encoder helper function to set the parameters to some presets. +inline void PredictorMode(int i, Header *header) { + switch (i) { + case 0: + // ~ lossless16 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 10; + header->p3Ca = 7; + header->p3Cb = 7; + header->p3Cc = 7; + header->p3Cd = 0; + header->p3Ce = 0; + break; + case 1: + // ~ default lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xb; + header->p1C = 8; + header->p2C = 8; + header->p3Ca = 4; + header->p3Cb = 0; + header->p3Cc = 3; + header->p3Cd = 23; + header->p3Ce = 2; + break; + case 2: + // ~ west lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xd; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 9; + header->p3Ca = 7; + header->p3Cb = 0; + header->p3Cc = 0; + header->p3Cd = 16; + header->p3Ce = 9; + break; + case 3: + // ~ north lossless8 predictor + header->w[0] = 0xd; + header->w[1] = 0xd; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 16; + header->p2C = 8; + header->p3Ca = 0; + header->p3Cb = 16; + header->p3Cc = 0; + header->p3Cd = 23; + header->p3Ce = 0; + break; + case 4: + default: + // something else, because why not + header->w[0] = 0xd; + header->w[1] = 0xc; + header->w[2] = 0xc; + header->w[3] = 0xc; + header->p1C = 10; + header->p2C = 10; + header->p3Ca = 5; + header->p3Cb = 5; + header->p3Cc = 5; + header->p3Cd = 12; + header->p3Ce = 4; + break; + } +} +} // namespace weighted + +// Stores a node and its two children at the same time. This significantly +// reduces the number of branches needed during decoding. +struct FlatDecisionNode { + // Property + splitval of the top node. + int32_t property0; // -1 if leaf. + union { + PropertyVal splitval0; + Predictor predictor; + }; + // Property+splitval of the two child nodes. + union { + PropertyVal splitvals[2]; + int32_t multiplier; + }; + uint32_t childID; // childID is ctx id if leaf. + union { + int16_t properties[2]; + int32_t predictor_offset; + }; +}; +using FlatTree = std::vector<FlatDecisionNode>; + +class MATreeLookup { + public: + explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {} + struct LookupResult { + uint32_t context; + Predictor predictor; + int32_t offset; + int32_t multiplier; + }; + JXL_INLINE LookupResult Lookup(const Properties &properties) const { + uint32_t pos = 0; + while (true) { +#define TRAVERSE_THE_TREE \ + { \ + const FlatDecisionNode &node = nodes_[pos]; \ + if (node.property0 < 0) { \ + return {node.childID, node.predictor, node.predictor_offset, \ + node.multiplier}; \ + } \ + bool p0 = properties[node.property0] <= node.splitval0; \ + uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0]; \ + uint32_t off1 = 2 | (properties[node.properties[1]] <= node.splitvals[1]); \ + pos = node.childID + (p0 ? off1 : off0); \ + } + + TRAVERSE_THE_TREE; + TRAVERSE_THE_TREE; + } + } + + private: + const FlatTree &nodes_; +}; + +static constexpr size_t kExtraPropsPerChannel = 4; +static constexpr size_t kNumNonrefProperties = + kNumStaticProperties + 13 + weighted::kNumProperties; + +constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties; +constexpr size_t kGradientProp = 9; + +// Clamps gradient to the min/max of n, w (and l, implicitly). +static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w, + const int32_t l) { + const int32_t m = std::min(n, w); + const int32_t M = std::max(n, w); + // The end result of this operation doesn't overflow or underflow if the + // result is between m and M, but the intermediate value may overflow, so we + // do the intermediate operations in uint32_t and check later if we had an + // overflow or underflow condition comparing m, M and l directly. + // grad = M + m - l = n + w - l + const int32_t grad = + static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) - + static_cast<uint32_t>(l)); + // We use two sets of ternary operators to force the evaluation of them in + // any case, allowing the compiler to avoid branches and use cmovl/cmovg in + // x86. + const int32_t grad_clamp_M = (l < m) ? M : grad; + return (l > M) ? m : grad_clamp_M; +} + +inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) { + pixel_type_w p = a + b - c; + pixel_type_w pa = std::abs(p - a); + pixel_type_w pb = std::abs(p - b); + return pa < pb ? a : b; +} + +inline void PrecomputeReferences(const Channel &ch, size_t y, + const Image &image, uint32_t i, + Channel *references) { + ZeroFillImage(&references->plane); + uint32_t offset = 0; + size_t num_extra_props = references->w; + intptr_t onerow = references->plane.PixelsPerRow(); + for (int32_t j = static_cast<int32_t>(i) - 1; + j >= 0 && offset < num_extra_props; j--) { + if (image.channel[j].w != image.channel[i].w || + image.channel[j].h != image.channel[i].h) { + continue; + } + if (image.channel[j].hshift != image.channel[i].hshift) continue; + if (image.channel[j].vshift != image.channel[i].vshift) continue; + pixel_type *JXL_RESTRICT rp = references->Row(0) + offset; + const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y); + const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0); + for (size_t x = 0; x < ch.w; x++, rp += onerow) { + pixel_type_w v = rpp[x]; + rp[0] = std::abs(v); + rp[1] = v; + pixel_type_w vleft = (x ? rpp[x - 1] : 0); + pixel_type_w vtop = (y ? rpprev[x] : vleft); + pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft); + pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft); + rp[2] = std::abs(v - vpredicted); + rp[3] = v - vpredicted; + } + + offset += kExtraPropsPerChannel; + } +} + +struct PredictionResult { + int context = 0; + pixel_type_w guess = 0; + Predictor predictor; + int32_t multiplier; +}; + +inline void InitPropsRow( + Properties *p, + const std::array<pixel_type, kNumStaticProperties> &static_props, + const int y) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + (*p)[i] = static_props[i]; + } + (*p)[2] = y; + (*p)[9] = 0; // local gradient. +} + +namespace detail { +enum PredictorMode { + kUseTree = 1, + kUseWP = 2, + kForceComputeProperties = 4, + kAllPredictions = 8, + kNoEdgeCases = 16 +}; + +JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left, + pixel_type_w top, pixel_type_w toptop, + pixel_type_w topleft, pixel_type_w topright, + pixel_type_w leftleft, + pixel_type_w toprightright, + pixel_type_w wp_pred) { + switch (p) { + case Predictor::Zero: + return pixel_type_w{0}; + case Predictor::Left: + return left; + case Predictor::Top: + return top; + case Predictor::Select: + return Select(left, top, topleft); + case Predictor::Weighted: + return wp_pred; + case Predictor::Gradient: + return pixel_type_w{ClampedGradient(left, top, topleft)}; + case Predictor::TopLeft: + return topleft; + case Predictor::TopRight: + return topright; + case Predictor::LeftLeft: + return leftleft; + case Predictor::Average0: + return (left + top) / 2; + case Predictor::Average1: + return (left + topleft) / 2; + case Predictor::Average2: + return (topleft + top) / 2; + case Predictor::Average3: + return (top + topright) / 2; + case Predictor::Average4: + return (6 * top - 2 * toptop + 7 * left + 1 * leftleft + + 1 * toprightright + 3 * topright + 8) / + 16; + default: + return pixel_type_w{0}; + } +} + +template <int mode> +JXL_INLINE PredictionResult Predict( + Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const size_t x, const size_t y, Predictor predictor, + const MATreeLookup *lookup, const Channel *references, + weighted::State *wp_state, pixel_type_w *predictions) { + // We start in position 3 because of 2 static properties + y. + size_t offset = 3; + constexpr bool compute_properties = + mode & kUseTree || mode & kForceComputeProperties; + constexpr bool nec = mode & kNoEdgeCases; + pixel_type_w left = (nec || x ? pp[-1] : (y ? pp[-onerow] : 0)); + pixel_type_w top = (nec || y ? pp[-onerow] : left); + pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left); + pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top); + pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left); + pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top); + pixel_type_w toprightright = + (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright); + + if (compute_properties) { + // location + (*p)[offset++] = x; + // neighbors + (*p)[offset++] = top > 0 ? top : -top; + (*p)[offset++] = left > 0 ? left : -left; + (*p)[offset++] = top; + (*p)[offset++] = left; + + // local gradient + (*p)[offset] = left - (*p)[offset + 1]; + offset++; + // local gradient + (*p)[offset++] = left + top - topleft; + + // FFV1 context properties + (*p)[offset++] = left - topleft; + (*p)[offset++] = topleft - top; + (*p)[offset++] = top - topright; + (*p)[offset++] = top - toptop; + (*p)[offset++] = left - leftleft; + } + + pixel_type_w wp_pred = 0; + if (mode & kUseWP) { + wp_pred = wp_state->Predict<compute_properties>( + x, y, w, top, left, topright, topleft, toptop, p, offset); + } + if (!nec && compute_properties) { + offset += weighted::kNumProperties; + // Extra properties. + const pixel_type *JXL_RESTRICT rp = references->Row(x); + for (size_t i = 0; i < references->w; i++) { + (*p)[offset++] = rp[i]; + } + } + PredictionResult result; + if (mode & kUseTree) { + MATreeLookup::LookupResult lr = lookup->Lookup(*p); + result.context = lr.context; + result.guess = lr.offset; + result.multiplier = lr.multiplier; + predictor = lr.predictor; + } + if (mode & kAllPredictions) { + for (size_t i = 0; i < kNumModularPredictors; i++) { + predictions[i] = PredictOne((Predictor)i, left, top, toptop, topleft, + topright, leftleft, toprightright, wp_pred); + } + } + result.guess += PredictOne(predictor, left, top, toptop, topleft, topright, + leftleft, toprightright, wp_pred); + result.predictor = predictor; + + return result; +} +} // namespace detail + +inline PredictionResult PredictNoTreeNoWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor) { + return detail::Predict</*mode=*/0>( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictNoTreeWP(size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + weighted::State *wp_state) { + return detail::Predict<detail::kUseWP>( + /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, + /*references=*/nullptr, wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeNoWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references) { + return detail::Predict<detail::kUseTree>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + /*wp_state=*/nullptr, /*predictions=*/nullptr); +} +// Only use for y > 1, x > 1, x < w-2, and empty references +JXL_INLINE PredictionResult +PredictTreeNoWPNEC(Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const MATreeLookup &tree_lookup, const Channel &references) { + return detail::Predict<detail::kUseTree | detail::kNoEdgeCases>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + /*wp_state=*/nullptr, /*predictions=*/nullptr); +} + +inline PredictionResult PredictTreeWP(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict<detail::kUseTree | detail::kUseWP>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + wp_state, /*predictions=*/nullptr); +} +JXL_INLINE PredictionResult PredictTreeWPNEC(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, + const MATreeLookup &tree_lookup, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict<detail::kUseTree | detail::kUseWP | + detail::kNoEdgeCases>( + p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, + wp_state, /*predictions=*/nullptr); +} + +inline PredictionResult PredictLearn(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>( + p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references, + wp_state, /*predictions=*/nullptr); +} + +inline void PredictLearnAll(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const Channel &references, + weighted::State *wp_state, + pixel_type_w *predictions) { + detail::Predict<detail::kForceComputeProperties | detail::kUseWP | + detail::kAllPredictions>( + p, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, &references, wp_state, predictions); +} +inline PredictionResult PredictLearnNEC(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, + const int y, Predictor predictor, + const Channel &references, + weighted::State *wp_state) { + return detail::Predict<detail::kForceComputeProperties | detail::kUseWP | + detail::kNoEdgeCases>( + p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references, + wp_state, /*predictions=*/nullptr); +} + +inline void PredictLearnAllNEC(Properties *p, size_t w, + const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + const Channel &references, + weighted::State *wp_state, + pixel_type_w *predictions) { + detail::Predict<detail::kForceComputeProperties | detail::kUseWP | + detail::kAllPredictions | detail::kNoEdgeCases>( + p, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, &references, wp_state, predictions); +} + +inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp, + const intptr_t onerow, const int x, const int y, + pixel_type_w *predictions) { + detail::Predict<detail::kAllPredictions>( + /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero, + /*lookup=*/nullptr, + /*references=*/nullptr, /*wp_state=*/nullptr, predictions); +} +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc new file mode 100644 index 0000000000..ee7177bcd6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.cc @@ -0,0 +1,108 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/dec_ma.h" + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/pack_signed.h" + +namespace jxl { + +namespace { + +Status ValidateTree( + const Tree &tree, + const std::vector<std::pair<pixel_type, pixel_type>> &prop_bounds, + size_t root) { + if (tree[root].property == -1) return true; + size_t p = tree[root].property; + int val = tree[root].splitval; + if (prop_bounds[p].first > val) return JXL_FAILURE("Invalid tree"); + // Splitting at max value makes no sense: left range will be exactly same + // as parent, right range will be invalid (min > max). + if (prop_bounds[p].second <= val) return JXL_FAILURE("Invalid tree"); + auto new_bounds = prop_bounds; + new_bounds[p].first = val + 1; + JXL_RETURN_IF_ERROR(ValidateTree(tree, new_bounds, tree[root].lchild)); + new_bounds[p] = prop_bounds[p]; + new_bounds[p].second = val; + return ValidateTree(tree, new_bounds, tree[root].rchild); +} + +Status DecodeTree(BitReader *br, ANSSymbolReader *reader, + const std::vector<uint8_t> &context_map, Tree *tree, + size_t tree_size_limit) { + size_t leaf_id = 0; + size_t to_decode = 1; + tree->clear(); + while (to_decode > 0) { + JXL_RETURN_IF_ERROR(br->AllReadsWithinBounds()); + if (tree->size() > tree_size_limit) { + return JXL_FAILURE("Tree is too large: %" PRIuS " nodes vs %" PRIuS + " max nodes", + tree->size(), tree_size_limit); + } + to_decode--; + uint32_t prop1 = reader->ReadHybridUint(kPropertyContext, br, context_map); + if (prop1 > 256) return JXL_FAILURE("Invalid tree property value"); + int property = prop1 - 1; + if (property == -1) { + size_t predictor = + reader->ReadHybridUint(kPredictorContext, br, context_map); + if (predictor >= kNumModularPredictors) { + return JXL_FAILURE("Invalid predictor"); + } + int64_t predictor_offset = + UnpackSigned(reader->ReadHybridUint(kOffsetContext, br, context_map)); + uint32_t mul_log = + reader->ReadHybridUint(kMultiplierLogContext, br, context_map); + if (mul_log >= 31) { + return JXL_FAILURE("Invalid multiplier logarithm"); + } + uint32_t mul_bits = + reader->ReadHybridUint(kMultiplierBitsContext, br, context_map); + if (mul_bits >= (1u << (31u - mul_log)) - 1u) { + return JXL_FAILURE("Invalid multiplier"); + } + uint32_t multiplier = (mul_bits + 1U) << mul_log; + tree->emplace_back(-1, 0, leaf_id++, 0, static_cast<Predictor>(predictor), + predictor_offset, multiplier); + continue; + } + int splitval = + UnpackSigned(reader->ReadHybridUint(kSplitValContext, br, context_map)); + tree->emplace_back(property, splitval, tree->size() + to_decode + 1, + tree->size() + to_decode + 2, Predictor::Zero, 0, 1); + to_decode += 2; + } + std::vector<std::pair<pixel_type, pixel_type>> prop_bounds; + prop_bounds.resize(256, {std::numeric_limits<pixel_type>::min(), + std::numeric_limits<pixel_type>::max()}); + return ValidateTree(*tree, prop_bounds, 0); +} +} // namespace + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit) { + std::vector<uint8_t> tree_context_map; + ANSCode tree_code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kNumTreeContexts, &tree_code, &tree_context_map)); + // TODO(eustas): investigate more infinite tree cases. + if (tree_code.degenerate_symbols[tree_context_map[kPropertyContext]] > 0) { + return JXL_FAILURE("Infinite tree"); + } + ANSSymbolReader reader(&tree_code, br); + JXL_RETURN_IF_ERROR(DecodeTree(br, &reader, tree_context_map, tree, + std::min(tree_size_limit, kMaxTreeSize))); + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("ANS decode final state failed"); + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h new file mode 100644 index 0000000000..a910c4deb1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/dec_ma.h @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ +#define LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// inner nodes +struct PropertyDecisionNode { + PropertyVal splitval; + int16_t property; // -1: leaf node, lchild points to leaf node + uint32_t lchild; + uint32_t rchild; + Predictor predictor; + int64_t predictor_offset; + uint32_t multiplier; + + PropertyDecisionNode(int p, int split_val, int lchild, int rchild, + Predictor predictor, int64_t predictor_offset, + uint32_t multiplier) + : splitval(split_val), + property(p), + lchild(lchild), + rchild(rchild), + predictor(predictor), + predictor_offset(predictor_offset), + multiplier(multiplier) {} + PropertyDecisionNode() + : splitval(0), + property(-1), + lchild(0), + rchild(0), + predictor(Predictor::Zero), + predictor_offset(0), + multiplier(1) {} + static PropertyDecisionNode Leaf(Predictor predictor, int64_t offset = 0, + uint32_t multiplier = 1) { + return PropertyDecisionNode(-1, 0, 0, 0, predictor, offset, multiplier); + } + static PropertyDecisionNode Split(int p, int split_val, int lchild, + int rchild = -1) { + if (rchild == -1) rchild = lchild + 1; + return PropertyDecisionNode(p, split_val, lchild, rchild, Predictor::Zero, + 0, 1); + } +}; + +using Tree = std::vector<PropertyDecisionNode>; + +Status DecodeTree(BitReader *br, Tree *tree, size_t tree_size_limit); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_DEC_MA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc new file mode 100644 index 0000000000..bd27f28458 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.cc @@ -0,0 +1,125 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/enc_debug_tree.h" + +#include <cinttypes> +#include <cstdint> +#include <cstdlib> + +#include "lib/jxl/base/os_macros.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/options.h" + +#if JXL_OS_IOS +#define JXL_ENABLE_DOT 0 +#else +#define JXL_ENABLE_DOT 1 // iOS lacks C89 system() +#endif + +namespace jxl { + +const char *PredictorName(Predictor p) { + switch (p) { + case Predictor::Zero: + return "Zero"; + case Predictor::Left: + return "Left"; + case Predictor::Top: + return "Top"; + case Predictor::Average0: + return "Avg0"; + case Predictor::Average1: + return "Avg1"; + case Predictor::Average2: + return "Avg2"; + case Predictor::Average3: + return "Avg3"; + case Predictor::Average4: + return "Avg4"; + case Predictor::Select: + return "Sel"; + case Predictor::Gradient: + return "Grd"; + case Predictor::Weighted: + return "Wgh"; + case Predictor::TopLeft: + return "TopL"; + case Predictor::TopRight: + return "TopR"; + case Predictor::LeftLeft: + return "LL"; + default: + return "INVALID"; + }; +} + +std::string PropertyName(size_t i) { + static_assert(kNumNonrefProperties == 16, "Update this function"); + switch (i) { + case 0: + return "c"; + case 1: + return "g"; + case 2: + return "y"; + case 3: + return "x"; + case 4: + return "|N|"; + case 5: + return "|W|"; + case 6: + return "N"; + case 7: + return "W"; + case 8: + return "W-WW-NW+NWW"; + case 9: + return "W+N-NW"; + case 10: + return "W-NW"; + case 11: + return "NW-N"; + case 12: + return "N-NE"; + case 13: + return "N-NN"; + case 14: + return "W-WW"; + case 15: + return "WGH"; + default: + return "ch[" + ToString(15 - (int)i) + "]"; + } +} + +void PrintTree(const Tree &tree, const std::string &path) { + FILE *f = fopen((path + ".dot").c_str(), "w"); + fprintf(f, "graph{\n"); + for (size_t cur = 0; cur < tree.size(); cur++) { + if (tree[cur].property < 0) { + fprintf(f, "n%05" PRIuS " [label=\"%s%+" PRId64 " (x%u)\"];\n", cur, + PredictorName(tree[cur].predictor), tree[cur].predictor_offset, + tree[cur].multiplier); + } else { + fprintf(f, "n%05" PRIuS " [label=\"%s>%d\"];\n", cur, + PropertyName(tree[cur].property).c_str(), tree[cur].splitval); + fprintf(f, "n%05" PRIuS " -- n%05d;\n", cur, tree[cur].lchild); + fprintf(f, "n%05" PRIuS " -- n%05d;\n", cur, tree[cur].rchild); + } + } + fprintf(f, "}\n"); + fclose(f); +#if JXL_ENABLE_DOT + JXL_ASSERT( + system(("dot " + path + ".dot -T svg -o " + path + ".svg").c_str()) == 0); +#endif +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h new file mode 100644 index 0000000000..78deaab1b8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_debug_tree.h @@ -0,0 +1,27 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +const char *PredictorName(Predictor p); +std::string PropertyName(size_t i); + +void PrintTree(const Tree &tree, const std::string &path); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENC_DEBUG_TREE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc new file mode 100644 index 0000000000..fc2e69e4a6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.cc @@ -0,0 +1,714 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <stdint.h> +#include <stdlib.h> + +#include <cinttypes> +#include <limits> +#include <numeric> +#include <queue> +#include <set> +#include <unordered_map> +#include <unordered_set> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/enc_debug_tree.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/encoding/ma_common.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" +#include "lib/jxl/pack_signed.h" +#include "lib/jxl/toc.h" + +namespace jxl { + +namespace { +// Plot tree (if enabled) and predictor usage map. +constexpr bool kWantDebug = true; +// constexpr bool kPrintTree = false; + +inline std::array<uint8_t, 3> PredictorColor(Predictor p) { + switch (p) { + case Predictor::Zero: + return {{0, 0, 0}}; + case Predictor::Left: + return {{255, 0, 0}}; + case Predictor::Top: + return {{0, 255, 0}}; + case Predictor::Average0: + return {{0, 0, 255}}; + case Predictor::Average4: + return {{192, 128, 128}}; + case Predictor::Select: + return {{255, 255, 0}}; + case Predictor::Gradient: + return {{255, 0, 255}}; + case Predictor::Weighted: + return {{0, 255, 255}}; + // TODO + default: + return {{255, 255, 255}}; + }; +} + +// `cutoffs` must be sorted. +Tree MakeFixedTree(int property, const std::vector<int32_t> &cutoffs, + Predictor pred, size_t num_pixels) { + size_t log_px = CeilLog2Nonzero(num_pixels); + size_t min_gap = 0; + // Reduce fixed tree height when encoding small images. + if (log_px < 14) { + min_gap = 8 * (14 - log_px); + } + Tree tree; + struct NodeInfo { + size_t begin, end, pos; + }; + std::queue<NodeInfo> q; + // Leaf IDs will be set by roundtrip decoding the tree. + tree.push_back(PropertyDecisionNode::Leaf(pred)); + q.push(NodeInfo{0, cutoffs.size(), 0}); + while (!q.empty()) { + NodeInfo info = q.front(); + q.pop(); + if (info.begin + min_gap >= info.end) continue; + uint32_t split = (info.begin + info.end) / 2; + tree[info.pos] = + PropertyDecisionNode::Split(property, cutoffs[split], tree.size()); + q.push(NodeInfo{split + 1, info.end, tree.size()}); + tree.push_back(PropertyDecisionNode::Leaf(pred)); + q.push(NodeInfo{info.begin, split, tree.size()}); + tree.push_back(PropertyDecisionNode::Leaf(pred)); + } + return tree; +} + +} // namespace + +void GatherTreeData(const Image &image, pixel_type chan, size_t group_id, + const weighted::Header &wp_header, + const ModularOptions &options, TreeSamples &tree_samples, + size_t *total_pixels) { + const Channel &channel = image.channel[chan]; + + JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w, + channel.h, chan); + + std::array<pixel_type, kNumStaticProperties> static_props = { + {chan, (int)group_id}}; + Properties properties(kNumNonrefProperties + + kExtraPropsPerChannel * options.max_properties); + double pixel_fraction = std::min(1.0f, options.nb_repeats); + // a fraction of 0 is used to disable learning entirely. + if (pixel_fraction > 0) { + pixel_fraction = std::max(pixel_fraction, + std::min(1.0, 1024.0 / (channel.w * channel.h))); + } + uint64_t threshold = + (std::numeric_limits<uint64_t>::max() >> 32) * pixel_fraction; + uint64_t s[2] = {static_cast<uint64_t>(0x94D049BB133111EBull), + static_cast<uint64_t>(0xBF58476D1CE4E5B9ull)}; + // Xorshift128+ adapted from xorshift128+-inl.h + auto use_sample = [&]() { + auto s1 = s[0]; + const auto s0 = s[1]; + const auto bits = s1 + s0; // b, c + s[0] = s0; + s1 ^= s1 << 23; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s[1] = s1; + return (bits >> 32) <= threshold; + }; + + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64); + const bool multiple_predictors = tree_samples.NumPredictors() != 1; + auto compute_sample = [&](const pixel_type *p, size_t x, size_t y) { + pixel_type_w pred[kNumModularPredictors]; + if (multiple_predictors) { + PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references, + &wp_state, pred); + } else { + pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] = + PredictLearn(&properties, channel.w, p + x, onerow, x, y, + tree_samples.PredictorFromIndex(0), references, + &wp_state) + .guess; + } + (*total_pixels)++; + if (use_sample()) { + tree_samples.AddSample(p[x], properties, pred); + } + wp_state.UpdateErrors(p[x], x, y, channel.w); + }; + + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + InitPropsRow(&properties, static_props, y); + + // TODO(veluca): avoid computing WP if we don't use its property or + // predictions. + if (y > 1 && channel.w > 8 && references.w == 0) { + for (size_t x = 0; x < 2; x++) { + compute_sample(p, x, y); + } + for (size_t x = 2; x < channel.w - 2; x++) { + pixel_type_w pred[kNumModularPredictors]; + if (multiple_predictors) { + PredictLearnAllNEC(&properties, channel.w, p + x, onerow, x, y, + references, &wp_state, pred); + } else { + pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] = + PredictLearnNEC(&properties, channel.w, p + x, onerow, x, y, + tree_samples.PredictorFromIndex(0), references, + &wp_state) + .guess; + } + (*total_pixels)++; + if (use_sample()) { + tree_samples.AddSample(p[x], properties, pred); + } + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + for (size_t x = channel.w - 2; x < channel.w; x++) { + compute_sample(p, x, y); + } + } else { + for (size_t x = 0; x < channel.w; x++) { + compute_sample(p, x, y); + } + } + } +} + +Tree PredefinedTree(ModularOptions::TreeKind tree_kind, size_t total_pixels) { + if (tree_kind == ModularOptions::TreeKind::kJpegTranscodeACMeta || + tree_kind == ModularOptions::TreeKind::kTrivialTreeNoPredictor) { + // All the data is 0, so no need for a fancy tree. + return {PropertyDecisionNode::Leaf(Predictor::Zero)}; + } + if (tree_kind == ModularOptions::TreeKind::kFalconACMeta) { + // All the data is 0 except the quant field. TODO(veluca): make that 0 too. + return {PropertyDecisionNode::Leaf(Predictor::Left)}; + } + if (tree_kind == ModularOptions::TreeKind::kACMeta) { + // Small image. + if (total_pixels < 1024) { + return {PropertyDecisionNode::Leaf(Predictor::Left)}; + } + Tree tree; + // 0: c > 1 + tree.push_back(PropertyDecisionNode::Split(0, 1, 1)); + // 1: c > 2 + tree.push_back(PropertyDecisionNode::Split(0, 2, 3)); + // 2: c > 0 + tree.push_back(PropertyDecisionNode::Split(0, 0, 5)); + // 3: EPF control field (all 0 or 4), top > 0 + tree.push_back(PropertyDecisionNode::Split(6, 0, 21)); + // 4: ACS+QF, y > 0 + tree.push_back(PropertyDecisionNode::Split(2, 0, 7)); + // 5: CfL x + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient)); + // 6: CfL b + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient)); + // 7: QF: split according to the left quant value. + tree.push_back(PropertyDecisionNode::Split(7, 5, 9)); + // 8: ACS: split in 4 segments (8x8 from 0 to 3, large square 4-5, large + // rectangular 6-11, 8x8 12+), according to previous ACS value. + tree.push_back(PropertyDecisionNode::Split(7, 5, 15)); + // QF + tree.push_back(PropertyDecisionNode::Split(7, 11, 11)); + tree.push_back(PropertyDecisionNode::Split(7, 3, 13)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left)); + // ACS + tree.push_back(PropertyDecisionNode::Split(7, 11, 17)); + tree.push_back(PropertyDecisionNode::Split(7, 3, 19)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + // EPF, left > 0 + tree.push_back(PropertyDecisionNode::Split(7, 0, 23)); + tree.push_back(PropertyDecisionNode::Split(7, 0, 25)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero)); + return tree; + } + if (tree_kind == ModularOptions::TreeKind::kWPFixedDC) { + std::vector<int32_t> cutoffs = { + -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15, + -11, -7, -4, -3, -1, 0, 1, 3, 5, 7, 11, + 15, 23, 31, 47, 63, 95, 127, 191, 255, 392, 500}; + return MakeFixedTree(kWPProp, cutoffs, Predictor::Weighted, total_pixels); + } + if (tree_kind == ModularOptions::TreeKind::kGradientFixedDC) { + std::vector<int32_t> cutoffs = { + -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15, + -11, -7, -4, -3, -1, 0, 1, 3, 5, 7, 11, + 15, 23, 31, 47, 63, 95, 127, 191, 255, 392, 500}; + return MakeFixedTree(kGradientProp, cutoffs, Predictor::Gradient, + total_pixels); + } + JXL_UNREACHABLE("Unreachable"); + return {}; +} + +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector<ModularMultiplierInfo> &multiplier_info = {}, + StaticPropRange static_prop_range = {}) { + for (size_t i = 0; i < kNumStaticProperties; i++) { + if (static_prop_range[i][1] == 0) { + static_prop_range[i][1] = std::numeric_limits<uint32_t>::max(); + } + } + if (!tree_samples.HasSamples()) { + Tree tree; + tree.emplace_back(); + tree.back().predictor = tree_samples.PredictorFromIndex(0); + tree.back().property = -1; + tree.back().predictor_offset = 0; + tree.back().multiplier = 1; + return tree; + } + float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels; + float required_cost = pixel_fraction * 0.9 + 0.1; + tree_samples.AllSamplesDone(); + Tree tree; + ComputeBestTree(tree_samples, + options.splitting_heuristics_node_threshold * required_cost, + multiplier_info, static_prop_range, + options.fast_decode_multiplier, &tree); + return tree; +} + +Status EncodeModularChannelMAANS(const Image &image, pixel_type chan, + const weighted::Header &wp_header, + const Tree &global_tree, Token **tokenpp, + AuxOut *aux_out, size_t group_id, + bool skip_encoder_fast_path) { + const Channel &channel = image.channel[chan]; + Token *tokenp = *tokenpp; + JXL_ASSERT(channel.w != 0 && channel.h != 0); + + Image3F predictor_img; + if (kWantDebug) predictor_img = Image3F(channel.w, channel.h); + + JXL_DEBUG_V(6, + "Encoding %" PRIuS "x%" PRIuS + " channel %d, " + "(shift=%i,%i)", + channel.w, channel.h, chan, channel.hshift, channel.vshift); + + std::array<pixel_type, kNumStaticProperties> static_props = { + {chan, (int)group_id}}; + bool use_wp, is_wp_only; + bool is_gradient_only; + size_t num_props; + FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp, + &is_wp_only, &is_gradient_only); + Properties properties(num_props); + MATreeLookup tree_lookup(tree); + JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size()); + + // Check if this tree is a WP-only tree with a small enough property value + // range. + // Initialized to avoid clang-tidy complaining. + auto tree_lut = jxl::make_unique<TreeLut<uint16_t, false>>(); + if (is_wp_only) { + is_wp_only = TreeToLookupTable(tree, *tree_lut); + } + if (is_gradient_only) { + is_gradient_only = TreeToLookupTable(tree, *tree_lut); + } + + if (is_wp_only && !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Weighted)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + weighted::State wp_state(wp_header, channel.w, channel.h); + Properties properties(1); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + size_t offset = 0; + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type_w topright = + (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top); + pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top); + int32_t guess = wp_state.Predict</*compute_properties=*/true>( + x, y, channel.w, top, left, topright, topleft, toptop, &properties, + offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = tree_lut->context_lookup[pos]; + int32_t residual = r[x] - guess - tree_lut->offsets[pos]; + *tokenp++ = Token(ctx_id, PackSigned(residual)); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient && + tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && + !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + int32_t residual = r[x] - guess; + *tokenp++ = Token(tree[0].childID, PackSigned(residual)); + } + } + } else if (is_gradient_only && !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]), + &predictor_img.Plane(c)); + } + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + uint32_t pos = + kPropRangeFast + + std::min<pixel_type_w>( + std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft), + kPropRangeFast - 1); + uint32_t ctx_id = tree_lut->context_lookup[pos]; + int32_t residual = r[x] - guess - tree_lut->offsets[pos]; + *tokenp++ = Token(ctx_id, PackSigned(residual)); + } + } + } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero && + tree[0].multiplier == 1 && tree[0].predictor_offset == 0 && + !skip_encoder_fast_path) { + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(Predictor::Zero)[c]), + &predictor_img.Plane(c)); + } + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + *tokenp++ = Token(tree[0].childID, PackSigned(p[x])); + } + } + } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted && + (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 && + tree[0].predictor_offset == 0 && !skip_encoder_fast_path) { + // multiplier is a power of 2. + for (size_t c = 0; c < 3; c++) { + FillImage(static_cast<float>(PredictorColor(tree[0].predictor)[c]), + &predictor_img.Plane(c)); + } + uint32_t mul_shift = FloorLog2Nonzero((uint32_t)tree[0].multiplier); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x, + y, tree[0].predictor); + pixel_type_w residual = r[x] - pred.guess; + JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual); + *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift)); + } + } + + } else if (!use_wp && !skip_encoder_fast_path) { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_DASSERT(residual % res.multiplier == 0); + *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); + } + } + } else { + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + const pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, image, chan, &references); + float *pred_img_row[3]; + if (kWantDebug) { + for (size_t c = 0; c < 3; c++) { + pred_img_row[c] = predictor_img.PlaneRow(c, y); + } + } + InitPropsRow(&properties, static_props, y); + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + if (kWantDebug) { + for (size_t i = 0; i < 3; i++) { + pred_img_row[i][x] = PredictorColor(res.predictor)[i]; + } + } + pixel_type_w residual = p[x] - res.guess; + JXL_DASSERT(residual % res.multiplier == 0); + *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier)); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + } + /* TODO(szabadka): Add cparams to the call stack here. + if (kWantDebug && WantDebugOutput(cparams)) { + DumpImage( + cparams, + ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(), + predictor_img); + } + */ + *tokenpp = tokenp; + return true; +} + +Status ModularEncode(const Image &image, const ModularOptions &options, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector<Token> *tokens, + size_t *width) { + if (image.error) return JXL_FAILURE("Invalid image"); + size_t nb_channels = image.channel.size(); + JXL_DEBUG_V( + 2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.", + nb_channels, image.bitdepth, image.w, image.h); + + if (nb_channels < 1) { + return true; // is there any use for a zero-channel image? + } + + // encode transforms + GroupHeader header_storage; + if (header == nullptr) header = &header_storage; + Bundle::Init(header); + if (options.predictor == Predictor::Weighted) { + weighted::PredictorMode(options.wp_mode, &header->wp_header); + } + header->transforms = image.transform; + // This doesn't actually work + if (tree != nullptr) { + header->use_global_tree = true; + } + if (tree_samples == nullptr && tree == nullptr) { + JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out)); + } + + TreeSamples tree_samples_storage; + size_t total_pixels_storage = 0; + if (!total_pixels) total_pixels = &total_pixels_storage; + // If there's no tree, compute one (or gather data to). + if (tree == nullptr) { + bool gather_data = tree_samples != nullptr; + if (tree_samples == nullptr) { + JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor( + options.predictor, options.wp_tree_mode)); + JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties( + options.splitting_heuristics_properties, options.wp_tree_mode)); + std::vector<pixel_type> pixel_samples; + std::vector<pixel_type> diff_samples; + std::vector<uint32_t> group_pixel_count; + std::vector<uint32_t> channel_pixel_count; + CollectPixelSamples(image, options, 0, group_pixel_count, + channel_pixel_count, pixel_samples, diff_samples); + std::vector<ModularMultiplierInfo> placeholder_multiplier_info; + StaticPropRange range; + tree_samples_storage.PreQuantizeProperties( + range, placeholder_multiplier_info, group_pixel_count, + channel_pixel_count, pixel_samples, diff_samples, + options.max_property_values); + } + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + GatherTreeData(image, i, group_id, header->wp_header, options, + gather_data ? *tree_samples : tree_samples_storage, + total_pixels); + } + if (gather_data) return true; + } + + JXL_ASSERT((tree == nullptr) == (tokens == nullptr)); + + Tree tree_storage; + std::vector<std::vector<Token>> tokens_storage(1); + // Compute tree. + if (tree == nullptr) { + EntropyEncodingData code; + std::vector<uint8_t> context_map; + + std::vector<std::vector<Token>> tree_tokens(1); + + tree_storage = + options.tree_kind == ModularOptions::TreeKind::kLearn + ? LearnTree(std::move(tree_samples_storage), *total_pixels, options) + : PredefinedTree(options.tree_kind, *total_pixels); + tree = &tree_storage; + tokens = &tokens_storage[0]; + + Tree decoded_tree; + TokenizeTree(*tree, &tree_tokens[0], &decoded_tree); + JXL_ASSERT(tree->size() == decoded_tree.size()); + tree_storage = std::move(decoded_tree); + + /* TODO(szabadka) Add text output callback + if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) { + PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id)); + } */ + + // Write tree + BuildAndEncodeHistograms(HistogramParams(), kNumTreeContexts, tree_tokens, + &code, &context_map, writer, kLayerModularTree, + aux_out); + WriteTokens(tree_tokens[0], code, context_map, 0, writer, kLayerModularTree, + aux_out); + } + + size_t image_width = 0; + size_t total_tokens = 0; + for (size_t i = 0; i < nb_channels; i++) { + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + if (image.channel[i].w > image_width) image_width = image.channel[i].w; + total_tokens += image.channel[i].w * image.channel[i].h; + } + if (options.zero_tokens) { + tokens->resize(tokens->size() + total_tokens, {0, 0}); + } else { + // Do one big allocation for all the tokens we'll need, + // to avoid reallocs that might require copying. + size_t pos = tokens->size(); + tokens->resize(pos + total_tokens); + Token *tokenp = tokens->data() + pos; + for (size_t i = 0; i < nb_channels; i++) { + if (!image.channel[i].w || !image.channel[i].h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS( + image, i, header->wp_header, *tree, &tokenp, aux_out, group_id, + options.skip_encoder_fast_path)); + } + // Make sure we actually wrote all tokens + JXL_CHECK(tokenp == tokens->data() + tokens->size()); + } + + // Write data if not using a global tree/ANS stream. + if (!header->use_global_tree) { + EntropyEncodingData code; + std::vector<uint8_t> context_map; + HistogramParams histo_params; + histo_params.image_widths.push_back(image_width); + BuildAndEncodeHistograms(histo_params, (tree->size() + 1) / 2, + tokens_storage, &code, &context_map, writer, layer, + aux_out); + WriteTokens(tokens_storage[0], code, context_map, 0, writer, layer, + aux_out); + } else { + *width = image_width; + } + return true; +} + +Status ModularGenericCompress(Image &image, const ModularOptions &opts, + BitWriter *writer, AuxOut *aux_out, size_t layer, + size_t group_id, TreeSamples *tree_samples, + size_t *total_pixels, const Tree *tree, + GroupHeader *header, std::vector<Token> *tokens, + size_t *width) { + if (image.w == 0 || image.h == 0) return true; + ModularOptions options = opts; // Make a copy to modify it. + + if (options.predictor == static_cast<Predictor>(-1)) { + options.predictor = Predictor::Gradient; + } + + size_t bits = writer ? writer->BitsWritten() : 0; + JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer, + group_id, tree_samples, total_pixels, tree, + header, tokens, width)); + bits = writer ? writer->BitsWritten() - bits : 0; + if (writer) { + JXL_DEBUG_V(4, + "Modular-encoded a %" PRIuS "x%" PRIuS + " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes", + image.w, image.h, image.bitdepth, image.channel.size(), + bits / 8); + } + (void)bits; + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h new file mode 100644 index 0000000000..d610edaf0b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_encoding.h @@ -0,0 +1,44 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ + +#include <cstddef> +#include <vector> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/encoding/enc_ma.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +struct AuxOut; +struct GroupHeader; + +Tree PredefinedTree(ModularOptions::TreeKind tree_kind, size_t total_pixels); + +Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels, + const ModularOptions &options, + const std::vector<ModularMultiplierInfo> &multiplier_info = {}, + StaticPropRange static_prop_range = {}); + +// TODO(veluca): make cleaner interfaces. + +Status ModularGenericCompress( + Image &image, const ModularOptions &opts, BitWriter *writer, + AuxOut *aux_out = nullptr, size_t layer = 0, size_t group_id = 0, + // For gathering data for producing a global tree. + TreeSamples *tree_samples = nullptr, size_t *total_pixels = nullptr, + // For encoding with global tree. + const Tree *tree = nullptr, GroupHeader *header = nullptr, + std::vector<Token> *tokens = nullptr, size_t *widths = nullptr); +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENC_ENCODING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc new file mode 100644 index 0000000000..ef72b2477b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.cc @@ -0,0 +1,1012 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/enc_ma.h" + +#include <algorithm> +#include <limits> +#include <numeric> +#include <queue> +#include <unordered_map> +#include <unordered_set> + +#include "lib/jxl/modular/encoding/ma_common.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/fast_math-inl.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/pack_signed.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Eq; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::Max; + +const HWY_FULL(float) df; +const HWY_FULL(int32_t) di; +size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); } + +// Compute entropy of the histogram, taking into account the minimum probability +// for symbols with non-zero counts. +float EstimateBits(const int32_t *counts, size_t num_symbols) { + int32_t total = std::accumulate(counts, counts + num_symbols, 0); + const auto zero = Zero(df); + const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE); + const auto inv_total = Set(df, 1.0f / total); + auto bits_lanes = Zero(df); + auto total_v = Set(di, total); + for (size_t i = 0; i < num_symbols; i += Lanes(df)) { + const auto counts_iv = LoadU(di, &counts[i]); + const auto counts_fv = ConvertTo(df, counts_iv); + const auto probs = Mul(counts_fv, inv_total); + const auto mprobs = Max(probs, minprob); + const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero), + BitCast(di, FastLog2f(df, mprobs))); + bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps))); + } + return GetLane(SumOfLanes(df, bits_lanes)); +} + +void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred, + int64_t loff, Predictor rpred, int64_t roff, Tree *tree) { + // Note that the tree splits on *strictly greater*. + (*tree)[pos].lchild = tree->size(); + (*tree)[pos].rchild = tree->size() + 1; + (*tree)[pos].splitval = splitval; + (*tree)[pos].property = property; + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = rpred; + tree->back().predictor_offset = roff; + tree->back().multiplier = 1; + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = lpred; + tree->back().predictor_offset = loff; + tree->back().multiplier = 1; +} + +enum class IntersectionType { kNone, kPartial, kInside }; +IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack, + uint32_t &partial_axis, uint32_t &partial_val) { + bool partial = false; + for (size_t i = 0; i < kNumStaticProperties; i++) { + if (haystack[i][0] >= needle[i][1]) { + return IntersectionType::kNone; + } + if (haystack[i][1] <= needle[i][0]) { + return IntersectionType::kNone; + } + if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { + continue; + } + partial = true; + partial_axis = i; + if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { + partial_val = haystack[i][0] - 1; + } else { + JXL_DASSERT(haystack[i][1] > needle[i][0] && + haystack[i][1] < needle[i][1]); + partial_val = haystack[i][1] - 1; + } + } + return partial ? IntersectionType::kPartial : IntersectionType::kInside; +} + +void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos, + size_t end, size_t prop) { + auto cmp = [&](size_t a, size_t b) { + return int32_t(tree_samples.Property(prop, a)) - + int32_t(tree_samples.Property(prop, b)); + }; + Rng rng(0); + while (end > begin + 1) { + { + size_t pivot = rng.UniformU(begin, end); + tree_samples.Swap(begin, pivot); + } + size_t pivot_begin = begin; + size_t pivot_end = pivot_begin + 1; + for (size_t i = begin + 1; i < end; i++) { + JXL_DASSERT(i >= pivot_end); + JXL_DASSERT(pivot_end > pivot_begin); + int32_t cmp_result = cmp(i, pivot_begin); + if (cmp_result < 0) { // i < pivot, move pivot forward and put i before + // the pivot. + tree_samples.ThreeShuffle(pivot_begin, pivot_end, i); + pivot_begin++; + pivot_end++; + } else if (cmp_result == 0) { + tree_samples.Swap(pivot_end, i); + pivot_end++; + } + } + JXL_DASSERT(pivot_begin >= begin); + JXL_DASSERT(pivot_end > pivot_begin); + JXL_DASSERT(pivot_end <= end); + for (size_t i = begin; i < pivot_begin; i++) { + JXL_DASSERT(cmp(i, pivot_begin) < 0); + } + for (size_t i = pivot_end; i < end; i++) { + JXL_DASSERT(cmp(i, pivot_begin) > 0); + } + for (size_t i = pivot_begin; i < pivot_end; i++) { + JXL_DASSERT(cmp(i, pivot_begin) == 0); + } + // We now have that [begin, pivot_begin) is < pivot, [pivot_begin, + // pivot_end) is = pivot, and [pivot_end, end) is > pivot. + // If pos falls in the first or the last interval, we continue in that + // interval; otherwise, we are done. + if (pivot_begin > pos) { + end = pivot_begin; + } else if (pivot_end < pos) { + begin = pivot_end; + } else { + break; + } + } +} + +void FindBestSplit(TreeSamples &tree_samples, float threshold, + const std::vector<ModularMultiplierInfo> &mul_info, + StaticPropRange initial_static_prop_range, + float fast_decode_multiplier, Tree *tree) { + struct NodeInfo { + size_t pos; + size_t begin; + size_t end; + uint64_t used_properties; + StaticPropRange static_prop_range; + }; + std::vector<NodeInfo> nodes; + nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0, + initial_static_prop_range}); + + size_t num_predictors = tree_samples.NumPredictors(); + size_t num_properties = tree_samples.NumProperties(); + + // TODO(veluca): consider parallelizing the search (processing multiple nodes + // at a time). + while (!nodes.empty()) { + size_t pos = nodes.back().pos; + size_t begin = nodes.back().begin; + size_t end = nodes.back().end; + uint64_t used_properties = nodes.back().used_properties; + StaticPropRange static_prop_range = nodes.back().static_prop_range; + nodes.pop_back(); + if (begin == end) continue; + + struct SplitInfo { + size_t prop = 0; + uint32_t val = 0; + size_t pos = 0; + float lcost = std::numeric_limits<float>::max(); + float rcost = std::numeric_limits<float>::max(); + Predictor lpred = Predictor::Zero; + Predictor rpred = Predictor::Zero; + float Cost() { return lcost + rcost; } + }; + + SplitInfo best_split_static_constant; + SplitInfo best_split_static; + SplitInfo best_split_nonstatic; + SplitInfo best_split_nowp; + + JXL_DASSERT(begin <= end); + JXL_DASSERT(end <= tree_samples.NumDistinctSamples()); + + // Compute the maximum token in the range. + size_t max_symbols = 0; + for (size_t pred = 0; pred < num_predictors; pred++) { + for (size_t i = begin; i < end; i++) { + uint32_t tok = tree_samples.Token(pred, i); + max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1; + } + } + max_symbols = Padded(max_symbols); + std::vector<int32_t> counts(max_symbols * num_predictors); + std::vector<uint32_t> tot_extra_bits(num_predictors); + for (size_t pred = 0; pred < num_predictors; pred++) { + for (size_t i = begin; i < end; i++) { + counts[pred * max_symbols + tree_samples.Token(pred, i)] += + tree_samples.Count(i); + tot_extra_bits[pred] += + tree_samples.NBits(pred, i) * tree_samples.Count(i); + } + } + + float base_bits; + { + size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor); + base_bits = + EstimateBits(counts.data() + pred * max_symbols, max_symbols) + + tot_extra_bits[pred]; + } + + SplitInfo *best = &best_split_nonstatic; + + SplitInfo forced_split; + // The multiplier ranges cut halfway through the current ranges of static + // properties. We do this even if the current node is not a leaf, to + // minimize the number of nodes in the resulting tree. + for (size_t i = 0; i < mul_info.size(); i++) { + uint32_t axis, val; + IntersectionType t = + BoxIntersects(static_prop_range, mul_info[i].range, axis, val); + if (t == IntersectionType::kNone) continue; + if (t == IntersectionType::kInside) { + (*tree)[pos].multiplier = mul_info[i].multiplier; + break; + } + if (t == IntersectionType::kPartial) { + forced_split.val = tree_samples.QuantizeProperty(axis, val); + forced_split.prop = axis; + forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold; + forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor; + best = &forced_split; + best->pos = begin; + JXL_ASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop)); + for (size_t x = begin; x < end; x++) { + if (tree_samples.Property(best->prop, x) <= best->val) { + best->pos++; + } + } + break; + } + } + + if (best != &forced_split) { + std::vector<int> prop_value_used_count; + std::vector<int> count_increase; + std::vector<size_t> extra_bits_increase; + // For each property, compute which of its values are used, and what + // tokens correspond to those usages. Then, iterate through the values, + // and compute the entropy of each side of the split (of the form `prop > + // threshold`). Finally, find the split that minimizes the cost. + struct CostInfo { + float cost = std::numeric_limits<float>::max(); + float extra_cost = 0; + float Cost() const { return cost + extra_cost; } + Predictor pred; // will be uninitialized in some cases, but never used. + }; + std::vector<CostInfo> costs_l; + std::vector<CostInfo> costs_r; + + std::vector<int32_t> counts_above(max_symbols); + std::vector<int32_t> counts_below(max_symbols); + + // The lower the threshold, the higher the expected noisiness of the + // estimate. Thus, discourage changing predictors. + float change_pred_penalty = 800.0f / (100.0f + threshold); + for (size_t prop = 0; prop < num_properties && base_bits > threshold; + prop++) { + costs_l.clear(); + costs_r.clear(); + size_t prop_size = tree_samples.NumPropertyValues(prop); + if (extra_bits_increase.size() < prop_size) { + count_increase.resize(prop_size * max_symbols); + extra_bits_increase.resize(prop_size); + } + // Clear prop_value_used_count (which cannot be cleared "on the go") + prop_value_used_count.clear(); + prop_value_used_count.resize(prop_size); + + size_t first_used = prop_size; + size_t last_used = 0; + + // TODO(veluca): consider finding multiple splits along a single + // property at the same time, possibly with a bottom-up approach. + for (size_t i = begin; i < end; i++) { + size_t p = tree_samples.Property(prop, i); + prop_value_used_count[p]++; + last_used = std::max(last_used, p); + first_used = std::min(first_used, p); + } + costs_l.resize(last_used - first_used); + costs_r.resize(last_used - first_used); + // For all predictors, compute the right and left costs of each split. + for (size_t pred = 0; pred < num_predictors; pred++) { + // Compute cost and histogram increments for each property value. + for (size_t i = begin; i < end; i++) { + size_t p = tree_samples.Property(prop, i); + size_t cnt = tree_samples.Count(i); + size_t sym = tree_samples.Token(pred, i); + count_increase[p * max_symbols + sym] += cnt; + extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt; + } + memcpy(counts_above.data(), counts.data() + pred * max_symbols, + max_symbols * sizeof counts_above[0]); + memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]); + size_t extra_bits_below = 0; + // Exclude last used: this ensures neither counts_above nor + // counts_below is empty. + for (size_t i = first_used; i < last_used; i++) { + if (!prop_value_used_count[i]) continue; + extra_bits_below += extra_bits_increase[i]; + // The increase for this property value has been used, and will not + // be used again: clear it. Also below. + extra_bits_increase[i] = 0; + for (size_t sym = 0; sym < max_symbols; sym++) { + counts_above[sym] -= count_increase[i * max_symbols + sym]; + counts_below[sym] += count_increase[i * max_symbols + sym]; + count_increase[i * max_symbols + sym] = 0; + } + float rcost = EstimateBits(counts_above.data(), max_symbols) + + tot_extra_bits[pred] - extra_bits_below; + float lcost = EstimateBits(counts_below.data(), max_symbols) + + extra_bits_below; + JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]); + float penalty = 0; + // Never discourage moving away from the Weighted predictor. + if (tree_samples.PredictorFromIndex(pred) != + (*tree)[pos].predictor && + (*tree)[pos].predictor != Predictor::Weighted) { + penalty = change_pred_penalty; + } + // If everything else is equal, disfavour Weighted (slower) and + // favour Zero (faster if it's the only predictor used in a + // group+channel combination) + if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) { + penalty += 1e-8; + } + if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) { + penalty -= 1e-8; + } + if (rcost + penalty < costs_r[i - first_used].Cost()) { + costs_r[i - first_used].cost = rcost; + costs_r[i - first_used].extra_cost = penalty; + costs_r[i - first_used].pred = + tree_samples.PredictorFromIndex(pred); + } + if (lcost + penalty < costs_l[i - first_used].Cost()) { + costs_l[i - first_used].cost = lcost; + costs_l[i - first_used].extra_cost = penalty; + costs_l[i - first_used].pred = + tree_samples.PredictorFromIndex(pred); + } + } + } + // Iterate through the possible splits and find the one with minimum sum + // of costs of the two sides. + size_t split = begin; + for (size_t i = first_used; i < last_used; i++) { + if (!prop_value_used_count[i]) continue; + split += prop_value_used_count[i]; + float rcost = costs_r[i - first_used].cost; + float lcost = costs_l[i - first_used].cost; + // WP was not used + we would use the WP property or predictor + bool adds_wp = + (tree_samples.PropertyFromIndex(prop) == kWPProp && + (used_properties & (1LU << prop)) == 0) || + ((costs_l[i - first_used].pred == Predictor::Weighted || + costs_r[i - first_used].pred == Predictor::Weighted) && + (*tree)[pos].predictor != Predictor::Weighted); + bool zero_entropy_side = rcost == 0 || lcost == 0; + + SplitInfo &best = + prop < kNumStaticProperties + ? (zero_entropy_side ? best_split_static_constant + : best_split_static) + : (adds_wp ? best_split_nonstatic : best_split_nowp); + if (lcost + rcost < best.Cost()) { + best.prop = prop; + best.val = i; + best.pos = split; + best.lcost = lcost; + best.lpred = costs_l[i - first_used].pred; + best.rcost = rcost; + best.rpred = costs_r[i - first_used].pred; + } + } + // Clear extra_bits_increase and cost_increase for last_used. + extra_bits_increase[last_used] = 0; + for (size_t sym = 0; sym < max_symbols; sym++) { + count_increase[last_used * max_symbols + sym] = 0; + } + } + + // Try to avoid introducing WP. + if (best_split_nowp.Cost() + threshold < base_bits && + best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) { + best = &best_split_nowp; + } + // Split along static props if possible and not significantly more + // expensive. + if (best_split_static.Cost() + threshold < base_bits && + best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) { + best = &best_split_static; + } + // Split along static props to create constant nodes if possible. + if (best_split_static_constant.Cost() + threshold < base_bits) { + best = &best_split_static_constant; + } + } + + if (best->Cost() + threshold < base_bits) { + uint32_t p = tree_samples.PropertyFromIndex(best->prop); + pixel_type dequant = + tree_samples.UnquantizeProperty(best->prop, best->val); + // Split node and try to split children. + MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree); + // "Sort" according to winning property + SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop); + if (p >= kNumStaticProperties) { + used_properties |= 1 << best->prop; + } + auto new_sp_range = static_prop_range; + if (p < kNumStaticProperties) { + JXL_ASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]); + new_sp_range[p][1] = dequant + 1; + JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]); + } + nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos, + used_properties, new_sp_range}); + new_sp_range = static_prop_range; + if (p < kNumStaticProperties) { + JXL_ASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1)); + new_sp_range[p][0] = dequant + 1; + JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]); + } + nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end, + used_properties, new_sp_range}); + } + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(FindBestSplit); // Local function. + +void ComputeBestTree(TreeSamples &tree_samples, float threshold, + const std::vector<ModularMultiplierInfo> &mul_info, + StaticPropRange static_prop_range, + float fast_decode_multiplier, Tree *tree) { + // TODO(veluca): take into account that different contexts can have different + // uint configs. + // + // Initialize tree. + tree->emplace_back(); + tree->back().property = -1; + tree->back().predictor = tree_samples.PredictorFromIndex(0); + tree->back().predictor_offset = 0; + tree->back().multiplier = 1; + JXL_ASSERT(tree_samples.NumProperties() < 64); + + JXL_ASSERT(tree_samples.NumDistinctSamples() <= + std::numeric_limits<uint32_t>::max()); + HWY_DYNAMIC_DISPATCH(FindBestSplit) + (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier, + tree); +} + +constexpr int32_t TreeSamples::kPropertyRange; +constexpr uint32_t TreeSamples::kDedupEntryUnused; + +Status TreeSamples::SetPredictor(Predictor predictor, + ModularOptions::TreeMode wp_tree_mode) { + if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { + predictors = {Predictor::Weighted}; + residuals.resize(1); + return true; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP && + predictor == Predictor::Weighted) { + return JXL_FAILURE("Invalid predictor settings"); + } + if (predictor == Predictor::Variable) { + for (size_t i = 0; i < kNumModularPredictors; i++) { + predictors.push_back(static_cast<Predictor>(i)); + } + std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]); + std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]); + } else if (predictor == Predictor::Best) { + predictors = {Predictor::Weighted, Predictor::Gradient}; + } else { + predictors = {predictor}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { + auto wp_it = + std::find(predictors.begin(), predictors.end(), Predictor::Weighted); + if (wp_it != predictors.end()) { + predictors.erase(wp_it); + } + } + residuals.resize(predictors.size()); + return true; +} + +Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties, + ModularOptions::TreeMode wp_tree_mode) { + props_to_use = properties; + if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) { + props_to_use = {static_cast<uint32_t>(kWPProp)}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) { + props_to_use = {static_cast<uint32_t>(kGradientProp)}; + } + if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) { + auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp); + if (it != props_to_use.end()) { + props_to_use.erase(it); + } + } + if (props_to_use.empty()) { + return JXL_FAILURE("Invalid property set configuration"); + } + props.resize(props_to_use.size()); + return true; +} + +void TreeSamples::InitTable(size_t size) { + JXL_DASSERT((size & (size - 1)) == 0); + if (dedup_table_.size() == size) return; + dedup_table_.resize(size, kDedupEntryUnused); + for (size_t i = 0; i < NumDistinctSamples(); i++) { + if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) { + AddToTable(i); + } + } +} + +bool TreeSamples::AddToTableAndMerge(size_t a) { + size_t pos1 = Hash1(a); + size_t pos2 = Hash2(a); + if (dedup_table_[pos1] != kDedupEntryUnused && + IsSameSample(a, dedup_table_[pos1])) { + JXL_DASSERT(sample_counts[a] == 1); + sample_counts[dedup_table_[pos1]]++; + // Remove from hash table samples that are saturated. + if (sample_counts[dedup_table_[pos1]] == + std::numeric_limits<uint16_t>::max()) { + dedup_table_[pos1] = kDedupEntryUnused; + } + return true; + } + if (dedup_table_[pos2] != kDedupEntryUnused && + IsSameSample(a, dedup_table_[pos2])) { + JXL_DASSERT(sample_counts[a] == 1); + sample_counts[dedup_table_[pos2]]++; + // Remove from hash table samples that are saturated. + if (sample_counts[dedup_table_[pos2]] == + std::numeric_limits<uint16_t>::max()) { + dedup_table_[pos2] = kDedupEntryUnused; + } + return true; + } + AddToTable(a); + return false; +} + +void TreeSamples::AddToTable(size_t a) { + size_t pos1 = Hash1(a); + size_t pos2 = Hash2(a); + if (dedup_table_[pos1] == kDedupEntryUnused) { + dedup_table_[pos1] = a; + } else if (dedup_table_[pos2] == kDedupEntryUnused) { + dedup_table_[pos2] = a; + } +} + +void TreeSamples::PrepareForSamples(size_t num_samples) { + for (auto &res : residuals) { + res.reserve(res.size() + num_samples); + } + for (auto &p : props) { + p.reserve(p.size() + num_samples); + } + size_t total_num_samples = num_samples + sample_counts.size(); + size_t next_pow2 = 1LLU << CeilLog2Nonzero(total_num_samples * 3 / 2); + InitTable(next_pow2); +} + +size_t TreeSamples::Hash1(size_t a) const { + constexpr uint64_t constant = 0x1e35a7bd; + uint64_t h = constant; + for (const auto &r : residuals) { + h = h * constant + r[a].tok; + h = h * constant + r[a].nbits; + } + for (const auto &p : props) { + h = h * constant + p[a]; + } + return (h >> 16) & (dedup_table_.size() - 1); +} +size_t TreeSamples::Hash2(size_t a) const { + constexpr uint64_t constant = 0x1e35a7bd1e35a7bd; + uint64_t h = constant; + for (const auto &p : props) { + h = h * constant ^ p[a]; + } + for (const auto &r : residuals) { + h = h * constant ^ r[a].tok; + h = h * constant ^ r[a].nbits; + } + return (h >> 16) & (dedup_table_.size() - 1); +} + +bool TreeSamples::IsSameSample(size_t a, size_t b) const { + bool ret = true; + for (const auto &r : residuals) { + if (r[a].tok != r[b].tok) { + ret = false; + } + if (r[a].nbits != r[b].nbits) { + ret = false; + } + } + for (const auto &p : props) { + if (p[a] != p[b]) { + ret = false; + } + } + return ret; +} + +void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties, + const pixel_type_w *predictions) { + for (size_t i = 0; i < predictors.size(); i++) { + pixel_type v = pixel - predictions[static_cast<int>(predictors[i])]; + uint32_t tok, nbits, bits; + HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits); + JXL_DASSERT(tok < 256); + JXL_DASSERT(nbits < 256); + residuals[i].emplace_back( + ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)}); + } + for (size_t i = 0; i < props_to_use.size(); i++) { + props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]])); + } + sample_counts.push_back(1); + num_samples++; + if (AddToTableAndMerge(sample_counts.size() - 1)) { + for (auto &r : residuals) r.pop_back(); + for (auto &p : props) p.pop_back(); + sample_counts.pop_back(); + } +} + +void TreeSamples::Swap(size_t a, size_t b) { + if (a == b) return; + for (auto &r : residuals) { + std::swap(r[a], r[b]); + } + for (auto &p : props) { + std::swap(p[a], p[b]); + } + std::swap(sample_counts[a], sample_counts[b]); +} + +void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) { + if (b == c) return Swap(a, b); + for (auto &r : residuals) { + auto tmp = r[a]; + r[a] = r[c]; + r[c] = r[b]; + r[b] = tmp; + } + for (auto &p : props) { + auto tmp = p[a]; + p[a] = p[c]; + p[c] = p[b]; + p[b] = tmp; + } + auto tmp = sample_counts[a]; + sample_counts[a] = sample_counts[c]; + sample_counts[c] = sample_counts[b]; + sample_counts[b] = tmp; +} + +namespace { +std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram, + size_t num_chunks) { + if (histogram.empty()) return {}; + // TODO(veluca): selecting distinct quantiles is likely not the best + // way to go about this. + std::vector<int32_t> thresholds; + uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU); + uint64_t cumsum = 0; + uint64_t threshold = 1; + for (size_t i = 0; i + 1 < histogram.size(); i++) { + cumsum += histogram[i]; + if (cumsum >= threshold * sum / num_chunks) { + thresholds.push_back(i); + while (cumsum > threshold * sum / num_chunks) threshold++; + } + } + return thresholds; +} + +std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples, + size_t num_chunks) { + if (samples.empty()) return {}; + int min = *std::min_element(samples.begin(), samples.end()); + constexpr int kRange = 512; + min = std::min(std::max(min, -kRange), kRange); + std::vector<uint32_t> counts(2 * kRange + 1); + for (int s : samples) { + uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min; + counts[sample_offset]++; + } + std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks); + for (auto &v : thresholds) v += min; + return thresholds; +} +} // namespace + +void TreeSamples::PreQuantizeProperties( + const StaticPropRange &range, + const std::vector<ModularMultiplierInfo> &multiplier_info, + const std::vector<uint32_t> &group_pixel_count, + const std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples, size_t max_property_values) { + // If we have forced splits because of multipliers, choose channel and group + // thresholds accordingly. + std::vector<int32_t> group_multiplier_thresholds; + std::vector<int32_t> channel_multiplier_thresholds; + for (const auto &v : multiplier_info) { + if (v.range[0][0] != range[0][0]) { + channel_multiplier_thresholds.push_back(v.range[0][0] - 1); + } + if (v.range[0][1] != range[0][1]) { + channel_multiplier_thresholds.push_back(v.range[0][1] - 1); + } + if (v.range[1][0] != range[1][0]) { + group_multiplier_thresholds.push_back(v.range[1][0] - 1); + } + if (v.range[1][1] != range[1][1]) { + group_multiplier_thresholds.push_back(v.range[1][1] - 1); + } + } + std::sort(channel_multiplier_thresholds.begin(), + channel_multiplier_thresholds.end()); + channel_multiplier_thresholds.resize( + std::unique(channel_multiplier_thresholds.begin(), + channel_multiplier_thresholds.end()) - + channel_multiplier_thresholds.begin()); + std::sort(group_multiplier_thresholds.begin(), + group_multiplier_thresholds.end()); + group_multiplier_thresholds.resize( + std::unique(group_multiplier_thresholds.begin(), + group_multiplier_thresholds.end()) - + group_multiplier_thresholds.begin()); + + compact_properties.resize(props_to_use.size()); + auto quantize_channel = [&]() { + if (!channel_multiplier_thresholds.empty()) { + return channel_multiplier_thresholds; + } + return QuantizeHistogram(channel_pixel_count, max_property_values); + }; + auto quantize_group_id = [&]() { + if (!group_multiplier_thresholds.empty()) { + return group_multiplier_thresholds; + } + return QuantizeHistogram(group_pixel_count, max_property_values); + }; + auto quantize_coordinate = [&]() { + std::vector<int32_t> quantized; + quantized.reserve(max_property_values - 1); + for (size_t i = 0; i + 1 < max_property_values; i++) { + quantized.push_back((i + 1) * 256 / max_property_values - 1); + } + return quantized; + }; + std::vector<int32_t> abs_pixel_thr; + std::vector<int32_t> pixel_thr; + auto quantize_pixel_property = [&]() { + if (pixel_thr.empty()) { + pixel_thr = QuantizeSamples(pixel_samples, max_property_values); + } + return pixel_thr; + }; + auto quantize_abs_pixel_property = [&]() { + if (abs_pixel_thr.empty()) { + quantize_pixel_property(); // Compute the non-abs thresholds. + for (auto &v : pixel_samples) v = std::abs(v); + abs_pixel_thr = QuantizeSamples(pixel_samples, max_property_values); + } + return abs_pixel_thr; + }; + std::vector<int32_t> abs_diff_thr; + std::vector<int32_t> diff_thr; + auto quantize_diff_property = [&]() { + if (diff_thr.empty()) { + diff_thr = QuantizeSamples(diff_samples, max_property_values); + } + return diff_thr; + }; + auto quantize_abs_diff_property = [&]() { + if (abs_diff_thr.empty()) { + quantize_diff_property(); // Compute the non-abs thresholds. + for (auto &v : diff_samples) v = std::abs(v); + abs_diff_thr = QuantizeSamples(diff_samples, max_property_values); + } + return abs_diff_thr; + }; + auto quantize_wp = [&]() { + if (max_property_values < 32) { + return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0, + 1, 3, 7, 15, 31, 63, 127}; + } + if (max_property_values < 64) { + return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23, + -15, -11, -7, -5, -3, -1, 0, 1, + 3, 5, 7, 11, 15, 23, 31, 47, + 63, 95, 127, 191, 255}; + } + return std::vector<int32_t>{ + -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47, + -39, -31, -27, -23, -19, -15, -13, -11, -9, -7, -6, + -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, + 6, 7, 9, 11, 13, 15, 19, 23, 27, 31, 39, + 47, 55, 63, 79, 95, 111, 127, 159, 191, 223, 255}; + }; + + property_mapping.resize(props_to_use.size()); + for (size_t i = 0; i < props_to_use.size(); i++) { + if (props_to_use[i] == 0) { + compact_properties[i] = quantize_channel(); + } else if (props_to_use[i] == 1) { + compact_properties[i] = quantize_group_id(); + } else if (props_to_use[i] == 2 || props_to_use[i] == 3) { + compact_properties[i] = quantize_coordinate(); + } else if (props_to_use[i] == 6 || props_to_use[i] == 7 || + props_to_use[i] == 8 || + (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) { + compact_properties[i] = quantize_pixel_property(); + } else if (props_to_use[i] == 4 || props_to_use[i] == 5 || + (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) { + compact_properties[i] = quantize_abs_pixel_property(); + } else if (props_to_use[i] >= kNumNonrefProperties && + (props_to_use[i] - kNumNonrefProperties) % 4 == 2) { + compact_properties[i] = quantize_abs_diff_property(); + } else if (props_to_use[i] == kWPProp) { + compact_properties[i] = quantize_wp(); + } else { + compact_properties[i] = quantize_diff_property(); + } + property_mapping[i].resize(kPropertyRange * 2 + 1); + size_t mapped = 0; + for (size_t j = 0; j < property_mapping[i].size(); j++) { + while (mapped < compact_properties[i].size() && + static_cast<int>(j) - kPropertyRange > + compact_properties[i][mapped]) { + mapped++; + } + // property_mapping[i] of a value V is `mapped` if + // compact_properties[i][mapped] <= j and + // compact_properties[i][mapped-1] > j + // This is because the decision node in the tree splits on (property) > j, + // hence everything that is not > of a threshold should be clustered + // together. + property_mapping[i][j] = mapped; + } + } +} + +void CollectPixelSamples(const Image &image, const ModularOptions &options, + size_t group_id, + std::vector<uint32_t> &group_pixel_count, + std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples) { + if (options.nb_repeats == 0) return; + if (group_pixel_count.size() <= group_id) { + group_pixel_count.resize(group_id + 1); + } + if (channel_pixel_count.size() < image.channel.size()) { + channel_pixel_count.resize(image.channel.size()); + } + Rng rng(group_id); + // Sample 10% of the final number of samples for property quantization. + float fraction = std::min(options.nb_repeats * 0.1, 0.99); + Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction); + size_t total_pixels = 0; + std::vector<size_t> channel_ids; + for (size_t i = 0; i < image.channel.size(); i++) { + if (image.channel[i].w <= 1 || image.channel[i].h == 0) { + continue; // skip empty or width-1 channels. + } + if (i >= image.nb_meta_channels && + (image.channel[i].w > options.max_chan_size || + image.channel[i].h > options.max_chan_size)) { + break; + } + channel_ids.push_back(i); + group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h; + channel_pixel_count[i] += image.channel[i].w * image.channel[i].h; + total_pixels += image.channel[i].w * image.channel[i].h; + } + if (channel_ids.empty()) return; + pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels); + diff_samples.reserve(diff_samples.size() + fraction * total_pixels); + size_t i = 0; + size_t y = 0; + size_t x = 0; + auto advance = [&](size_t amount) { + x += amount; + // Detect row overflow (rare). + while (x >= image.channel[channel_ids[i]].w) { + x -= image.channel[channel_ids[i]].w; + y++; + // Detect end-of-channel (even rarer). + if (y == image.channel[channel_ids[i]].h) { + i++; + y = 0; + if (i >= channel_ids.size()) { + return; + } + } + } + }; + advance(rng.Geometric(dist)); + for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) { + const pixel_type *row = image.channel[channel_ids[i]].Row(y); + pixel_samples.push_back(row[x]); + size_t xp = x == 0 ? 1 : x - 1; + diff_samples.push_back((int64_t)row[x] - row[xp]); + } +} + +// TODO(veluca): very simple encoding scheme. This should be improved. +void TokenizeTree(const Tree &tree, std::vector<Token> *tokens, + Tree *decoder_tree) { + JXL_ASSERT(tree.size() <= kMaxTreeSize); + std::queue<int> q; + q.push(0); + size_t leaf_id = 0; + decoder_tree->clear(); + while (!q.empty()) { + int cur = q.front(); + q.pop(); + JXL_ASSERT(tree[cur].property >= -1); + tokens->emplace_back(kPropertyContext, tree[cur].property + 1); + if (tree[cur].property == -1) { + tokens->emplace_back(kPredictorContext, + static_cast<int>(tree[cur].predictor)); + tokens->emplace_back(kOffsetContext, + PackSigned(tree[cur].predictor_offset)); + uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier); + uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1; + tokens->emplace_back(kMultiplierLogContext, mul_log); + tokens->emplace_back(kMultiplierBitsContext, mul_bits); + JXL_ASSERT(tree[cur].predictor < Predictor::Best); + decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor, + tree[cur].predictor_offset, + tree[cur].multiplier); + continue; + } + decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval, + decoder_tree->size() + q.size() + 1, + decoder_tree->size() + q.size() + 2, + Predictor::Zero, 0, 1); + q.push(tree[cur].lchild); + q.push(tree[cur].rchild); + tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval)); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h new file mode 100644 index 0000000000..ede37c8023 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/enc_ma.h @@ -0,0 +1,157 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ +#define LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ + +#include <numeric> + +#include "lib/jxl/enc_ans.h" +#include "lib/jxl/entropy_coder.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +// Struct to collect all the data needed to build a tree. +struct TreeSamples { + bool HasSamples() const { + return !residuals.empty() && !residuals[0].empty(); + } + size_t NumDistinctSamples() const { return sample_counts.size(); } + size_t NumSamples() const { return num_samples; } + // Set the predictor to use. Must be called before adding any samples. + Status SetPredictor(Predictor predictor, + ModularOptions::TreeMode wp_tree_mode); + // Set the properties to use. Must be called before adding any samples. + Status SetProperties(const std::vector<uint32_t> &properties, + ModularOptions::TreeMode wp_tree_mode); + + size_t Token(size_t pred, size_t i) const { return residuals[pred][i].tok; } + size_t NBits(size_t pred, size_t i) const { return residuals[pred][i].nbits; } + size_t Count(size_t i) const { return sample_counts[i]; } + size_t PredictorIndex(Predictor predictor) const { + const auto predictor_elem = + std::find(predictors.begin(), predictors.end(), predictor); + JXL_DASSERT(predictor_elem != predictors.end()); + return predictor_elem - predictors.begin(); + } + size_t PropertyIndex(size_t property) const { + const auto property_elem = + std::find(props_to_use.begin(), props_to_use.end(), property); + JXL_DASSERT(property_elem != props_to_use.end()); + return property_elem - props_to_use.begin(); + } + size_t NumPropertyValues(size_t property_index) const { + return compact_properties[property_index].size() + 1; + } + // Returns the *quantized* property value. + size_t Property(size_t property_index, size_t i) const { + return props[property_index][i]; + } + int UnquantizeProperty(size_t property_index, uint32_t quant) const { + JXL_ASSERT(quant < compact_properties[property_index].size()); + return compact_properties[property_index][quant]; + } + + Predictor PredictorFromIndex(size_t index) const { + JXL_DASSERT(index < predictors.size()); + return predictors[index]; + } + size_t PropertyFromIndex(size_t index) const { + JXL_DASSERT(index < props_to_use.size()); + return props_to_use[index]; + } + size_t NumPredictors() const { return predictors.size(); } + size_t NumProperties() const { return props_to_use.size(); } + + // Preallocate data for a given number of samples. MUST be called before + // adding any sample. + void PrepareForSamples(size_t num_samples); + // Add a sample. + void AddSample(pixel_type_w pixel, const Properties &properties, + const pixel_type_w *predictions); + // Pre-cluster property values. + void PreQuantizeProperties( + const StaticPropRange &range, + const std::vector<ModularMultiplierInfo> &multiplier_info, + const std::vector<uint32_t> &group_pixel_count, + const std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples, size_t max_property_values); + + void AllSamplesDone() { dedup_table_ = std::vector<uint32_t>(); } + + uint32_t QuantizeProperty(uint32_t prop, pixel_type v) const { + v = std::min(std::max(v, -kPropertyRange), kPropertyRange) + kPropertyRange; + return property_mapping[prop][v]; + } + + // Swaps samples in position a and b. Does nothing if a == b. + void Swap(size_t a, size_t b); + + // Cycles samples: a -> b -> c -> a. We assume a <= b <= c, so that we can + // just call Swap(a, b) if b==c. + void ThreeShuffle(size_t a, size_t b, size_t c); + + private: + // TODO(veluca): as the total number of properties and predictors are known + // before adding any samples, it might be better to interleave predictors, + // properties and counts in a single vector to improve locality. + // A first attempt at doing this actually results in much slower encoding, + // possibly because of the more complex addressing. + struct ResidualToken { + uint8_t tok; + uint8_t nbits; + }; + // Residual information: token and number of extra bits, per predictor. + std::vector<std::vector<ResidualToken>> residuals; + // Number of occurrences of each sample. + std::vector<uint16_t> sample_counts; + // Property values, quantized to at most 256 distinct values. + std::vector<std::vector<uint8_t>> props; + // Decompactification info for `props`. + std::vector<std::vector<int32_t>> compact_properties; + // List of properties to use. + std::vector<uint32_t> props_to_use; + // List of predictors to use. + std::vector<Predictor> predictors; + // Mapping property value -> quantized property value. + static constexpr int32_t kPropertyRange = 511; + std::vector<std::vector<uint8_t>> property_mapping; + // Number of samples seen. + size_t num_samples = 0; + // Table for deduplication. + static constexpr uint32_t kDedupEntryUnused{static_cast<uint32_t>(-1)}; + std::vector<uint32_t> dedup_table_; + + // Functions for sample deduplication. + bool IsSameSample(size_t a, size_t b) const; + size_t Hash1(size_t a) const; + size_t Hash2(size_t a) const; + void InitTable(size_t size); + // Returns true if `a` was already present in the table. + bool AddToTableAndMerge(size_t a); + void AddToTable(size_t a); +}; + +void TokenizeTree(const Tree &tree, std::vector<Token> *tokens, + Tree *decoder_tree); + +void CollectPixelSamples(const Image &image, const ModularOptions &options, + size_t group_id, + std::vector<uint32_t> &group_pixel_count, + std::vector<uint32_t> &channel_pixel_count, + std::vector<pixel_type> &pixel_samples, + std::vector<pixel_type> &diff_samples); + +void ComputeBestTree(TreeSamples &tree_samples, float threshold, + const std::vector<ModularMultiplierInfo> &mul_info, + StaticPropRange static_prop_range, + float fast_decode_multiplier, Tree *tree); + +} // namespace jxl +#endif // LIB_JXL_MODULAR_ENCODING_ENC_MA_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc new file mode 100644 index 0000000000..a6abdcfc91 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.cc @@ -0,0 +1,688 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/encoding/encoding.h" + +#include <stdint.h> +#include <stdlib.h> + +#include <queue> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/scope_guard.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/pack_signed.h" + +namespace jxl { + +// Removes all nodes that use a static property (i.e. channel or group ID) from +// the tree and collapses each node on even levels with its two children to +// produce a flatter tree. Also computes whether the resulting tree requires +// using the weighted predictor. +FlatTree FilterTree(const Tree &global_tree, + std::array<pixel_type, kNumStaticProperties> &static_props, + size_t *num_props, bool *use_wp, bool *wp_only, + bool *gradient_only) { + *num_props = 0; + bool has_wp = false; + bool has_non_wp = false; + *gradient_only = true; + const auto mark_property = [&](int32_t p) { + if (p == kWPProp) { + has_wp = true; + } else if (p >= kNumStaticProperties) { + has_non_wp = true; + } + if (p >= kNumStaticProperties && p != kGradientProp) { + *gradient_only = false; + } + }; + FlatTree output; + std::queue<size_t> nodes; + nodes.push(0); + // Produces a trimmed and flattened tree by doing a BFS visit of the original + // tree, ignoring branches that are known to be false and proceeding two + // levels at a time to collapse nodes in a flatter tree; if an inner parent + // node has a leaf as a child, the leaf is duplicated and an implicit fake + // node is added. This allows to reduce the number of branches when traversing + // the resulting flat tree. + while (!nodes.empty()) { + size_t cur = nodes.front(); + nodes.pop(); + // Skip nodes that we can decide now, by jumping directly to their children. + while (global_tree[cur].property < kNumStaticProperties && + global_tree[cur].property != -1) { + if (static_props[global_tree[cur].property] > global_tree[cur].splitval) { + cur = global_tree[cur].lchild; + } else { + cur = global_tree[cur].rchild; + } + } + FlatDecisionNode flat; + if (global_tree[cur].property == -1) { + flat.property0 = -1; + flat.childID = global_tree[cur].lchild; + flat.predictor = global_tree[cur].predictor; + flat.predictor_offset = global_tree[cur].predictor_offset; + flat.multiplier = global_tree[cur].multiplier; + *gradient_only &= flat.predictor == Predictor::Gradient; + has_wp |= flat.predictor == Predictor::Weighted; + has_non_wp |= flat.predictor != Predictor::Weighted; + output.push_back(flat); + continue; + } + flat.childID = output.size() + nodes.size() + 1; + + flat.property0 = global_tree[cur].property; + *num_props = std::max<size_t>(flat.property0 + 1, *num_props); + flat.splitval0 = global_tree[cur].splitval; + + for (size_t i = 0; i < 2; i++) { + size_t cur_child = + i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild; + // Skip nodes that we can decide now. + while (global_tree[cur_child].property < kNumStaticProperties && + global_tree[cur_child].property != -1) { + if (static_props[global_tree[cur_child].property] > + global_tree[cur_child].splitval) { + cur_child = global_tree[cur_child].lchild; + } else { + cur_child = global_tree[cur_child].rchild; + } + } + // We ended up in a leaf, add a placeholder decision and two copies of the + // leaf. + if (global_tree[cur_child].property == -1) { + flat.properties[i] = 0; + flat.splitvals[i] = 0; + nodes.push(cur_child); + nodes.push(cur_child); + } else { + flat.properties[i] = global_tree[cur_child].property; + flat.splitvals[i] = global_tree[cur_child].splitval; + nodes.push(global_tree[cur_child].lchild); + nodes.push(global_tree[cur_child].rchild); + *num_props = std::max<size_t>(flat.properties[i] + 1, *num_props); + } + } + + for (size_t j = 0; j < 2; j++) mark_property(flat.properties[j]); + mark_property(flat.property0); + output.push_back(flat); + } + if (*num_props > kNumNonrefProperties) { + *num_props = + DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) * + kExtraPropsPerChannel + + kNumNonrefProperties; + } else { + *num_props = kNumNonrefProperties; + } + *use_wp = has_wp; + *wp_only = has_wp && !has_non_wp; + + return output; +} + +namespace detail { +template <bool uses_lz77> +Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader, + const std::vector<uint8_t> &context_map, + const Tree &global_tree, + const weighted::Header &wp_header, + pixel_type chan, size_t group_id, + TreeLut<uint8_t, true> &tree_lut, + Image *image) { + Channel &channel = image->channel[chan]; + + std::array<pixel_type, kNumStaticProperties> static_props = { + {chan, (int)group_id}}; + // TODO(veluca): filter the tree according to static_props. + + // zero pixel channel? could happen + if (channel.w == 0 || channel.h == 0) return true; + + bool tree_has_wp_prop_or_pred = false; + bool is_wp_only = false; + bool is_gradient_only = false; + size_t num_props; + FlatTree tree = + FilterTree(global_tree, static_props, &num_props, + &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only); + + // From here on, tree lookup returns a *clustered* context ID. + // This avoids an extra memory lookup after tree traversal. + for (size_t i = 0; i < tree.size(); i++) { + if (tree[i].property0 == -1) { + tree[i].childID = context_map[tree[i].childID]; + } + } + + JXL_DEBUG_V(3, "Decoded MA tree with %" PRIuS " nodes", tree.size()); + + // MAANS decode + const auto make_pixel = [](uint64_t v, pixel_type multiplier, + pixel_type_w offset) -> pixel_type { + JXL_DASSERT((v & 0xFFFFFFFF) == v); + pixel_type_w val = UnpackSigned(v); + // if it overflows, it overflows, and we have a problem anyway + return val * multiplier + offset; + }; + + if (tree.size() == 1) { + // special optimized case: no meta-adaptation, so no need + // to compute properties. + Predictor predictor = tree[0].predictor; + int64_t offset = tree[0].predictor_offset; + int32_t multiplier = tree[0].multiplier; + size_t ctx_id = tree[0].childID; + if (predictor == Predictor::Zero) { + uint32_t value; + if (reader->IsSingleValueAndAdvance(ctx_id, &value, + channel.w * channel.h)) { + // Special-case: histogram has a single symbol, with no extra bits, and + // we use ANS mode. + JXL_DEBUG_V(8, "Fastest track."); + pixel_type v = make_pixel(value, multiplier, offset); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + std::fill(r, r + channel.w, v); + } + } else { + JXL_DEBUG_V(8, "Fast track."); + if (multiplier == 1 && offset == 0) { + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + uint32_t v = + reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br); + r[x] = UnpackSigned(v); + } + } + } else { + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + uint32_t v = + reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(ctx_id, + br); + r[x] = make_pixel(v, multiplier, offset); + } + } + } + } + return true; + } else if (uses_lz77 && predictor == Predictor::Gradient && offset == 0 && + multiplier == 1 && reader->HuffRleOnly()) { + JXL_DEBUG_V(8, "Gradient RLE (fjxl) very fast track."); + uint32_t run = 0; + uint32_t v = 0; + pixel_type_w sv = 0; + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1); + const pixel_type *JXL_RESTRICT rtopleft = + (y ? channel.Row(y - 1) - 1 : r - 1); + pixel_type_w guess = (y ? rtop[0] : 0); + if (run == 0) { + reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run); + sv = UnpackSigned(v); + } else { + run--; + } + r[0] = sv + guess; + for (size_t x = 1; x < channel.w; x++) { + pixel_type left = r[x - 1]; + pixel_type top = rtop[x]; + pixel_type topleft = rtopleft[x]; + pixel_type_w guess = ClampedGradient(top, left, topleft); + if (!run) { + reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &v, &run); + sv = UnpackSigned(v); + } else { + run--; + } + r[x] = sv + guess; + } + } + return true; + } else if (predictor == Predictor::Gradient && offset == 0 && + multiplier == 1) { + JXL_DEBUG_V(8, "Gradient very fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type top = (y ? *(r + x - onerow) : left); + pixel_type topleft = (x && y ? *(r + x - 1 - onerow) : left); + pixel_type guess = ClampedGradient(top, left, topleft); + uint64_t v = reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>( + ctx_id, br); + r[x] = make_pixel(v, 1, guess); + } + } + return true; + } + } + + // Check if this tree is a WP-only tree with a small enough property value + // range. + if (is_wp_only) { + is_wp_only = TreeToLookupTable(tree, tree_lut); + } + if (is_gradient_only) { + is_gradient_only = TreeToLookupTable(tree, tree_lut); + } + + if (is_gradient_only) { + JXL_DEBUG_V(8, "Gradient fast track."); + const intptr_t onerow = channel.plane.PixelsPerRow(); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + for (size_t x = 0; x < channel.w; x++) { + pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0); + pixel_type_w top = (y ? *(r + x - onerow) : left); + pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left); + int32_t guess = ClampedGradient(top, left, topleft); + uint32_t pos = + kPropRangeFast + + std::min<pixel_type_w>( + std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft), + kPropRangeFast - 1); + uint32_t ctx_id = tree_lut.context_lookup[pos]; + uint64_t v = + reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(ctx_id, br); + r[x] = make_pixel( + v, tree_lut.multipliers[pos], + static_cast<pixel_type_w>(tree_lut.offsets[pos]) + guess); + } + } + } else if (!uses_lz77 && is_wp_only && channel.w > 8) { + JXL_DEBUG_V(8, "WP fast track."); + weighted::State wp_state(wp_header, channel.w, channel.h); + Properties properties(1); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT r = channel.Row(y); + const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1); + const pixel_type *JXL_RESTRICT rtoptop = + (y > 1 ? channel.Row(y - 2) : rtop); + const pixel_type *JXL_RESTRICT rtopleft = + (y ? channel.Row(y - 1) - 1 : r - 1); + const pixel_type *JXL_RESTRICT rtopright = + (y ? channel.Row(y - 1) + 1 : r - 1); + size_t x = 0; + { + size_t offset = 0; + pixel_type_w left = y ? rtop[x] : 0; + pixel_type_w toptop = y ? rtoptop[x] : 0; + pixel_type_w topright = (x + 1 < channel.w && y ? rtop[x + 1] : left); + int32_t guess = wp_state.Predict</*compute_properties=*/true>( + x, y, channel.w, left, left, topright, left, toptop, &properties, + offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = tree_lut.context_lookup[pos]; + uint64_t v = + reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br); + r[x] = make_pixel( + v, tree_lut.multipliers[pos], + static_cast<pixel_type_w>(tree_lut.offsets[pos]) + guess); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + for (x = 1; x + 1 < channel.w; x++) { + size_t offset = 0; + int32_t guess = wp_state.Predict</*compute_properties=*/true>( + x, y, channel.w, rtop[x], r[x - 1], rtopright[x], rtopleft[x], + rtoptop[x], &properties, offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = tree_lut.context_lookup[pos]; + uint64_t v = + reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br); + r[x] = make_pixel( + v, tree_lut.multipliers[pos], + static_cast<pixel_type_w>(tree_lut.offsets[pos]) + guess); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + { + size_t offset = 0; + int32_t guess = wp_state.Predict</*compute_properties=*/true>( + x, y, channel.w, rtop[x], r[x - 1], rtop[x], rtopleft[x], + rtoptop[x], &properties, offset); + uint32_t pos = + kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]), + kPropRangeFast - 1); + uint32_t ctx_id = tree_lut.context_lookup[pos]; + uint64_t v = + reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br); + r[x] = make_pixel( + v, tree_lut.multipliers[pos], + static_cast<pixel_type_w>(tree_lut.offsets[pos]) + guess); + wp_state.UpdateErrors(r[x], x, y, channel.w); + } + } + } else if (!tree_has_wp_prop_or_pred) { + // special optimized case: the weighted predictor and its properties are not + // used, so no need to compute weights and properties. + JXL_DEBUG_V(8, "Slow track."); + MATreeLookup tree_lookup(tree); + Properties properties = Properties(num_props); + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + PrecomputeReferences(channel, y, *image, chan, &references); + InitPropsRow(&properties, static_props, y); + if (y > 1 && channel.w > 8 && references.w == 0) { + for (size_t x = 0; x < 2; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = + reader->ReadHybridUintClustered<uses_lz77>(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + for (size_t x = 2; x < channel.w - 2; x++) { + PredictionResult res = + PredictTreeNoWPNEC(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClusteredInlined<uses_lz77>( + res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + for (size_t x = channel.w - 2; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = + reader->ReadHybridUintClustered<uses_lz77>(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + } else { + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references); + uint64_t v = reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>( + res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + } + } + } + } else { + JXL_DEBUG_V(8, "Slowest track."); + MATreeLookup tree_lookup(tree); + Properties properties = Properties(num_props); + const intptr_t onerow = channel.plane.PixelsPerRow(); + Channel references(properties.size() - kNumNonrefProperties, channel.w); + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + InitPropsRow(&properties, static_props, y); + PrecomputeReferences(channel, y, *image, chan, &references); + if (!uses_lz77 && y > 1 && channel.w > 8 && references.w == 0) { + for (size_t x = 0; x < 2; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + uint64_t v = + reader->ReadHybridUintClustered<uses_lz77>(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + for (size_t x = 2; x < channel.w - 2; x++) { + PredictionResult res = + PredictTreeWPNEC(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + uint64_t v = reader->ReadHybridUintClusteredInlined<uses_lz77>( + res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + for (size_t x = channel.w - 2; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + uint64_t v = + reader->ReadHybridUintClustered<uses_lz77>(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } else { + for (size_t x = 0; x < channel.w; x++) { + PredictionResult res = + PredictTreeWP(&properties, channel.w, p + x, onerow, x, y, + tree_lookup, references, &wp_state); + uint64_t v = + reader->ReadHybridUintClustered<uses_lz77>(res.context, br); + p[x] = make_pixel(v, res.multiplier, res.guess); + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + } + } + return true; +} +} // namespace detail + +Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader, + const std::vector<uint8_t> &context_map, + const Tree &global_tree, + const weighted::Header &wp_header, + pixel_type chan, size_t group_id, + TreeLut<uint8_t, true> &tree_lut, + Image *image) { + if (reader->UsesLZ77()) { + return detail::DecodeModularChannelMAANS</*uses_lz77=*/true>( + br, reader, context_map, global_tree, wp_header, chan, group_id, + tree_lut, image); + } else { + return detail::DecodeModularChannelMAANS</*uses_lz77=*/false>( + br, reader, context_map, global_tree, wp_header, chan, group_id, + tree_lut, image); + } +} + +GroupHeader::GroupHeader() { Bundle::Init(this); } + +Status ValidateChannelDimensions(const Image &image, + const ModularOptions &options) { + size_t nb_channels = image.channel.size(); + for (bool is_dc : {true, false}) { + size_t group_dim = options.group_dim * (is_dc ? kBlockDim : 1); + size_t c = image.nb_meta_channels; + for (; c < nb_channels; c++) { + const Channel &ch = image.channel[c]; + if (ch.w > options.group_dim || ch.h > options.group_dim) break; + } + for (; c < nb_channels; c++) { + const Channel &ch = image.channel[c]; + if (ch.w == 0 || ch.h == 0) continue; // skip empty + bool is_dc_channel = std::min(ch.hshift, ch.vshift) >= 3; + if (is_dc_channel != is_dc) continue; + size_t tile_dim = group_dim >> std::max(ch.hshift, ch.vshift); + if (tile_dim == 0) { + return JXL_FAILURE("Inconsistent transforms"); + } + } + } + return true; +} + +Status ModularDecode(BitReader *br, Image &image, GroupHeader &header, + size_t group_id, ModularOptions *options, + const Tree *global_tree, const ANSCode *global_code, + const std::vector<uint8_t> *global_ctx_map, + const bool allow_truncated_group) { + if (image.channel.empty()) return true; + + // decode transforms + Status status = Bundle::Read(br, &header); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(status); + if (status.IsFatalError()) return status; + if (!br->AllReadsWithinBounds()) { + // Don't do/undo transforms if header is incomplete. + header.transforms.clear(); + image.transform = header.transforms; + for (size_t c = 0; c < image.channel.size(); c++) { + ZeroFillImage(&image.channel[c].plane); + } + return Status(StatusCode::kNotEnoughBytes); + } + + JXL_DEBUG_V(3, "Image data underwent %" PRIuS " transformations: ", + header.transforms.size()); + image.transform = header.transforms; + for (Transform &transform : image.transform) { + JXL_RETURN_IF_ERROR(transform.MetaApply(image)); + } + if (image.error) { + return JXL_FAILURE("Corrupt file. Aborting."); + } + JXL_RETURN_IF_ERROR(ValidateChannelDimensions(image, *options)); + + size_t nb_channels = image.channel.size(); + + size_t num_chans = 0; + size_t distance_multiplier = 0; + for (size_t i = 0; i < nb_channels; i++) { + Channel &channel = image.channel[i]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + if (channel.w > distance_multiplier) { + distance_multiplier = channel.w; + } + num_chans++; + } + if (num_chans == 0) return true; + + size_t next_channel = 0; + auto scope_guard = MakeScopeGuard([&]() { + for (size_t c = next_channel; c < image.channel.size(); c++) { + ZeroFillImage(&image.channel[c].plane); + } + }); + // Do not do anything if truncated groups are not allowed. + if (allow_truncated_group) scope_guard.Disarm(); + + // Read tree. + Tree tree_storage; + std::vector<uint8_t> context_map_storage; + ANSCode code_storage; + const Tree *tree = &tree_storage; + const ANSCode *code = &code_storage; + const std::vector<uint8_t> *context_map = &context_map_storage; + if (!header.use_global_tree) { + uint64_t max_tree_size = 1024; + for (size_t i = 0; i < nb_channels; i++) { + Channel &channel = image.channel[i]; + if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + uint64_t pixels = channel.w * channel.h; + max_tree_size += pixels; + } + max_tree_size = std::min(static_cast<uint64_t>(1 << 20), max_tree_size); + JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, max_tree_size)); + JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2, + &code_storage, &context_map_storage)); + } else { + if (!global_tree || !global_code || !global_ctx_map || + global_tree->empty()) { + return JXL_FAILURE("No global tree available but one was requested"); + } + tree = global_tree; + code = global_code; + context_map = global_ctx_map; + } + + // Read channels + ANSSymbolReader reader(code, br, distance_multiplier); + auto tree_lut = jxl::make_unique<TreeLut<uint8_t, true>>(); + for (; next_channel < nb_channels; next_channel++) { + Channel &channel = image.channel[next_channel]; + if (!channel.w || !channel.h) { + continue; // skip empty channels + } + if (next_channel >= image.nb_meta_channels && + (channel.w > options->max_chan_size || + channel.h > options->max_chan_size)) { + break; + } + JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS( + br, &reader, *context_map, *tree, header.wp_header, next_channel, + group_id, *tree_lut, &image)); + + // Truncated group. + if (!br->AllReadsWithinBounds()) { + if (!allow_truncated_group) return JXL_FAILURE("Truncated input"); + return Status(StatusCode::kNotEnoughBytes); + } + } + + // Make sure no zero-filling happens even if next_channel < nb_channels. + scope_guard.Disarm(); + + if (!reader.CheckANSFinalState()) { + return JXL_FAILURE("ANS decode final state failed"); + } + return true; +} + +Status ModularGenericDecompress(BitReader *br, Image &image, + GroupHeader *header, size_t group_id, + ModularOptions *options, bool undo_transforms, + const Tree *tree, const ANSCode *code, + const std::vector<uint8_t> *ctx_map, + bool allow_truncated_group) { +#ifdef JXL_ENABLE_ASSERT + std::vector<std::pair<uint32_t, uint32_t>> req_sizes(image.channel.size()); + for (size_t c = 0; c < req_sizes.size(); c++) { + req_sizes[c] = {image.channel[c].w, image.channel[c].h}; + } +#endif + GroupHeader local_header; + if (header == nullptr) header = &local_header; + size_t bit_pos = br->TotalBitsConsumed(); + auto dec_status = ModularDecode(br, image, *header, group_id, options, tree, + code, ctx_map, allow_truncated_group); + if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status); + if (dec_status.IsFatalError()) return dec_status; + if (undo_transforms) image.undo_transforms(header->wp_header); + if (image.error) return JXL_FAILURE("Corrupt file. Aborting."); + JXL_DEBUG_V(4, + "Modular-decoded a %" PRIuS "x%" PRIuS " nbchans=%" PRIuS + " image from %" PRIuS " bytes", + image.w, image.h, image.channel.size(), + (br->TotalBitsConsumed() - bit_pos) / 8); + JXL_DEBUG_V(5, "Modular image: %s", image.DebugString().c_str()); + (void)bit_pos; +#ifdef JXL_ENABLE_ASSERT + // Check that after applying all transforms we are back to the requested image + // sizes, otherwise there's a programming error with the transformations. + if (undo_transforms) { + JXL_ASSERT(image.channel.size() == req_sizes.size()); + for (size_t c = 0; c < req_sizes.size(); c++) { + JXL_ASSERT(req_sizes[c].first == image.channel[c].w); + JXL_ASSERT(req_sizes[c].second == image.channel[c].h); + } + } +#endif + return dec_status; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h new file mode 100644 index 0000000000..25007803bd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/encoding.h @@ -0,0 +1,146 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_ENCODING_H_ +#define LIB_JXL_MODULAR_ENCODING_ENCODING_H_ + +#include <array> +#include <cstddef> +#include <cstdint> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/image.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/encoding/dec_ma.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +struct ANSCode; +class BitReader; + +// Valid range of properties for using lookup tables instead of trees. +constexpr int32_t kPropRangeFast = 512; + +struct GroupHeader : public Fields { + GroupHeader(); + + JXL_FIELDS_NAME(GroupHeader) + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &use_global_tree)); + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&wp_header)); + uint32_t num_transforms = static_cast<uint32_t>(transforms.size()); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(0), Val(1), BitsOffset(4, 2), + BitsOffset(8, 18), 0, + &num_transforms)); + if (visitor->IsReading()) transforms.resize(num_transforms); + for (size_t i = 0; i < num_transforms; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&transforms[i])); + } + return true; + } + + bool use_global_tree; + weighted::Header wp_header; + + std::vector<Transform> transforms; +}; + +FlatTree FilterTree(const Tree &global_tree, + std::array<pixel_type, kNumStaticProperties> &static_props, + size_t *num_props, bool *use_wp, bool *wp_only, + bool *gradient_only); + +template <typename T, bool HAS_MULTIPLIERS> +struct TreeLut { + std::array<T, 2 * kPropRangeFast> context_lookup; + std::array<int8_t, 2 * kPropRangeFast> offsets; + std::array<int8_t, HAS_MULTIPLIERS ? (2 * kPropRangeFast) : 0> multipliers; +}; + +template <typename T, bool HAS_MULTIPLIERS> +bool TreeToLookupTable(const FlatTree &tree, TreeLut<T, HAS_MULTIPLIERS> &lut) { + struct TreeRange { + // Begin *excluded*, end *included*. This works best with > vs <= decision + // nodes. + int begin, end; + size_t pos; + }; + std::vector<TreeRange> ranges; + ranges.push_back(TreeRange{-kPropRangeFast - 1, kPropRangeFast - 1, 0}); + while (!ranges.empty()) { + TreeRange cur = ranges.back(); + ranges.pop_back(); + if (cur.begin < -kPropRangeFast - 1 || cur.begin >= kPropRangeFast - 1 || + cur.end > kPropRangeFast - 1) { + // Tree is outside the allowed range, exit. + return false; + } + auto &node = tree[cur.pos]; + // Leaf. + if (node.property0 == -1) { + if (node.predictor_offset < std::numeric_limits<int8_t>::min() || + node.predictor_offset > std::numeric_limits<int8_t>::max()) { + return false; + } + if (node.multiplier < std::numeric_limits<int8_t>::min() || + node.multiplier > std::numeric_limits<int8_t>::max()) { + return false; + } + if (!HAS_MULTIPLIERS && node.multiplier != 1) { + return false; + } + for (int i = cur.begin + 1; i < cur.end + 1; i++) { + lut.context_lookup[i + kPropRangeFast] = node.childID; + if (HAS_MULTIPLIERS) { + lut.multipliers[i + kPropRangeFast] = node.multiplier; + } + lut.offsets[i + kPropRangeFast] = node.predictor_offset; + } + continue; + } + // > side of top node. + if (node.properties[0] >= kNumStaticProperties) { + ranges.push_back(TreeRange({node.splitvals[0], cur.end, node.childID})); + ranges.push_back( + TreeRange({node.splitval0, node.splitvals[0], node.childID + 1})); + } else { + ranges.push_back(TreeRange({node.splitval0, cur.end, node.childID})); + } + // <= side + if (node.properties[1] >= kNumStaticProperties) { + ranges.push_back( + TreeRange({node.splitvals[1], node.splitval0, node.childID + 2})); + ranges.push_back( + TreeRange({cur.begin, node.splitvals[1], node.childID + 3})); + } else { + ranges.push_back( + TreeRange({cur.begin, node.splitval0, node.childID + 2})); + } + } + return true; +} +// TODO(veluca): make cleaner interfaces. + +Status ValidateChannelDimensions(const Image &image, + const ModularOptions &options); + +Status ModularGenericDecompress(BitReader *br, Image &image, + GroupHeader *header, size_t group_id, + ModularOptions *options, + bool undo_transforms = true, + const Tree *tree = nullptr, + const ANSCode *code = nullptr, + const std::vector<uint8_t> *ctx_map = nullptr, + bool allow_truncated_group = false); +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_ENCODING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h new file mode 100644 index 0000000000..71b7847321 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/encoding/ma_common.h @@ -0,0 +1,28 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ +#define LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ + +#include <stddef.h> + +namespace jxl { + +enum MATreeContext : size_t { + kSplitValContext = 0, + kPropertyContext = 1, + kPredictorContext = 2, + kOffsetContext = 3, + kMultiplierLogContext = 4, + kMultiplierBitsContext = 5, + + kNumTreeContexts = 6, +}; + +static constexpr size_t kMaxTreeSize = 1 << 22; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_ENCODING_MA_COMMON_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc b/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc new file mode 100644 index 0000000000..746d7c87fd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/modular_image.cc @@ -0,0 +1,78 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/modular_image.h" + +#include <sstream> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +void Image::undo_transforms(const weighted::Header &wp_header, + jxl::ThreadPool *pool) { + while (!transform.empty()) { + Transform t = transform.back(); + JXL_DEBUG_V(4, "Undoing transform"); + Status result = t.Inverse(*this, wp_header, pool); + if (result == false) { + JXL_NOTIFY_ERROR("Error while undoing transform."); + error = true; + return; + } + JXL_DEBUG_V(8, "Undoing transform: done"); + transform.pop_back(); + } +} + +Image::Image(size_t iw, size_t ih, int bitdepth, int nb_chans) + : w(iw), h(ih), bitdepth(bitdepth), nb_meta_channels(0), error(false) { + for (int i = 0; i < nb_chans; i++) channel.emplace_back(Channel(iw, ih)); +} + +Image::Image() : w(0), h(0), bitdepth(8), nb_meta_channels(0), error(true) {} + +Image &Image::operator=(Image &&other) noexcept { + w = other.w; + h = other.h; + bitdepth = other.bitdepth; + nb_meta_channels = other.nb_meta_channels; + error = other.error; + channel = std::move(other.channel); + transform = std::move(other.transform); + return *this; +} + +Image Image::clone() { + Image c(w, h, bitdepth, 0); + c.nb_meta_channels = nb_meta_channels; + c.error = error; + c.transform = transform; + for (Channel &ch : channel) { + Channel a(ch.w, ch.h, ch.hshift, ch.vshift); + CopyImageTo(ch.plane, &a.plane); + c.channel.push_back(std::move(a)); + } + return c; +} + +#if JXL_DEBUG_V_LEVEL >= 1 +std::string Image::DebugString() const { + std::ostringstream os; + os << w << "x" << h << ", depth: " << bitdepth; + if (!channel.empty()) { + os << ", channels:"; + for (size_t i = 0; i < channel.size(); ++i) { + os << " " << channel[i].w << "x" << channel[i].h + << "(shift: " << channel[i].hshift << "," << channel[i].vshift << ")"; + if (i < nb_meta_channels) os << "*"; + } + } + return os.str(); +} +#endif + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/modular_image.h b/third_party/jpeg-xl/lib/jxl/modular/modular_image.h new file mode 100644 index 0000000000..56e80d823a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/modular_image.h @@ -0,0 +1,117 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_MODULAR_IMAGE_H_ +#define LIB_JXL_MODULAR_MODULAR_IMAGE_H_ + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include <string> +#include <utility> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { + +typedef int32_t pixel_type; // can use int16_t if it's only for 8-bit images. + // Need some wiggle room for YCoCg / Squeeze etc + +typedef int64_t pixel_type_w; + +namespace weighted { +struct Header; +} + +class Channel { + public: + jxl::Plane<pixel_type> plane; + size_t w, h; + int hshift, vshift; // w ~= image.w >> hshift; h ~= image.h >> vshift + Channel(size_t iw, size_t ih, int hsh = 0, int vsh = 0) + : plane(iw, ih), w(iw), h(ih), hshift(hsh), vshift(vsh) {} + + Channel(const Channel& other) = delete; + Channel& operator=(const Channel& other) = delete; + + // Move assignment + Channel& operator=(Channel&& other) noexcept { + w = other.w; + h = other.h; + hshift = other.hshift; + vshift = other.vshift; + plane = std::move(other.plane); + return *this; + } + + // Move constructor + Channel(Channel&& other) noexcept = default; + + void shrink() { + if (plane.xsize() == w && plane.ysize() == h) return; + jxl::Plane<pixel_type> resizedplane(w, h); + plane = std::move(resizedplane); + } + void shrink(int nw, int nh) { + w = nw; + h = nh; + shrink(); + } + + JXL_INLINE pixel_type* Row(const size_t y) { return plane.Row(y); } + JXL_INLINE const pixel_type* Row(const size_t y) const { + return plane.Row(y); + } +}; + +class Transform; + +class Image { + public: + // image data, transforms can dramatically change the number of channels and + // their semantics + std::vector<Channel> channel; + // transforms that have been applied (and that have to be undone) + std::vector<Transform> transform; + + // image dimensions (channels may have different dimensions due to transforms) + size_t w, h; + int bitdepth; + size_t nb_meta_channels; // first few channels might contain palette(s) + bool error; // true if a fatal error occurred, false otherwise + + Image(size_t iw, size_t ih, int bitdepth, int nb_chans); + Image(); + + Image(const Image& other) = delete; + Image& operator=(const Image& other) = delete; + + Image& operator=(Image&& other) noexcept; + Image(Image&& other) noexcept = default; + + bool empty() const { + for (const auto& ch : channel) { + if (ch.w && ch.h) return false; + } + return true; + } + + Image clone(); + + void undo_transforms(const weighted::Header& wp_header, + jxl::ThreadPool* pool = nullptr); + + std::string DebugString() const; +}; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_MODULAR_IMAGE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/options.h b/third_party/jpeg-xl/lib/jxl/modular/options.h new file mode 100644 index 0000000000..ce6596b912 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/options.h @@ -0,0 +1,117 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_OPTIONS_H_ +#define LIB_JXL_MODULAR_OPTIONS_H_ + +#include <stdint.h> + +#include <array> +#include <vector> + +namespace jxl { + +using PropertyVal = int32_t; +using Properties = std::vector<PropertyVal>; + +enum class Predictor : uint32_t { + Zero = 0, + Left = 1, + Top = 2, + Average0 = 3, + Select = 4, + Gradient = 5, + Weighted = 6, + TopRight = 7, + TopLeft = 8, + LeftLeft = 9, + Average1 = 10, + Average2 = 11, + Average3 = 12, + Average4 = 13, + // The following predictors are encoder-only. + Best = 14, // Best of Gradient and Weighted + Variable = + 15, // Find the best decision tree for predictors/predictor per row +}; + +constexpr size_t kNumModularPredictors = + static_cast<size_t>(Predictor::Average4) + 1; +constexpr size_t kNumModularEncoderPredictors = + static_cast<size_t>(Predictor::Variable) + 1; + +static constexpr ssize_t kNumStaticProperties = 2; // channel, group_id. + +using StaticPropRange = + std::array<std::array<uint32_t, 2>, kNumStaticProperties>; + +struct ModularMultiplierInfo { + StaticPropRange range; + uint32_t multiplier; +}; + +struct ModularOptions { + /// Used in both encode and decode: + + // Stop encoding/decoding when reaching a (non-meta) channel that has a + // dimension bigger than max_chan_size. + size_t max_chan_size = 0xFFFFFF; + + // Used during decoding for validation of transforms (sqeeezing) scheme. + size_t group_dim = 0x1FFFFFFF; + + /// Encode options: + // Fraction of pixels to look at to learn a MA tree + // Number of iterations to do to learn a MA tree + // (if zero there is no MA context model) + float nb_repeats = .5f; + + // Maximum number of (previous channel) properties to use in the MA trees + int max_properties = 0; // no previous channels + + // Alternative heuristic tweaks. + // Properties default to channel, group, weighted, gradient residual, W-NW, + // NW-N, N-NE, N-NN + std::vector<uint32_t> splitting_heuristics_properties = {0, 1, 15, 9, + 10, 11, 12, 13}; + float splitting_heuristics_node_threshold = 96; + size_t max_property_values = 32; + + // Predictor to use for each channel. + Predictor predictor = static_cast<Predictor>(-1); + + int wp_mode = 0; + + float fast_decode_multiplier = 1.01f; + + // Forces the encoder to produce a tree that is compatible with the WP-only + // decode path (or with the no-wp path, or the gradient-only path). + enum class TreeMode { kGradientOnly, kWPOnly, kNoWP, kDefault }; + TreeMode wp_tree_mode = TreeMode::kDefault; + + // Skip fast paths in the encoder. + bool skip_encoder_fast_path = false; + + // Kind of tree to use. + // TODO(veluca): add tree kinds for JPEG recompression with CfL enabled, + // general AC metadata, different DC qualities, and others. + enum class TreeKind { + kTrivialTreeNoPredictor, + kLearn, + kJpegTranscodeACMeta, + kFalconACMeta, + kACMeta, + kWPFixedDC, + kGradientFixedDC, + }; + TreeKind tree_kind = TreeKind::kLearn; + + // Ignore the image and just pretend all tokens are zeroes + bool zero_tokens = false; +}; + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_OPTIONS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc new file mode 100644 index 0000000000..f5172aa126 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.cc @@ -0,0 +1,595 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_palette.h" + +#include <array> +#include <map> +#include <set> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/enc_transform.h" +#include "lib/jxl/modular/transform/palette.h" + +namespace jxl { + +namespace palette_internal { + +static constexpr bool kEncodeToHighQualityImplicitPalette = true; + +// Inclusive. +static constexpr int kMinImplicitPaletteIndex = -(2 * 72 - 1); + +float ColorDistance(const std::vector<float> &JXL_RESTRICT a, + const std::vector<pixel_type> &JXL_RESTRICT b) { + JXL_ASSERT(a.size() == b.size()); + float distance = 0; + float ave3 = 0; + if (a.size() >= 3) { + ave3 = (a[0] + b[0] + a[1] + b[1] + a[2] + b[2]) * (1.21f / 3.0f); + } + float sum_a = 0, sum_b = 0; + for (size_t c = 0; c < a.size(); ++c) { + const float difference = + static_cast<float>(a[c]) - static_cast<float>(b[c]); + float weight = c == 0 ? 3 : c == 1 ? 5 : 2; + if (c < 3 && (a[c] + b[c] >= ave3)) { + const float add_w[3] = { + 1.15, + 1.15, + 1.12, + }; + weight += add_w[c]; + if (c == 2 && ((a[2] + b[2]) < 1.22 * ave3)) { + weight -= 0.5; + } + } + distance += difference * difference * weight * weight; + const int sum_weight = c == 0 ? 3 : c == 1 ? 5 : 1; + sum_a += a[c] * sum_weight; + sum_b += b[c] * sum_weight; + } + distance *= 4; + float sum_difference = sum_a - sum_b; + distance += sum_difference * sum_difference; + return distance; +} + +static int QuantizeColorToImplicitPaletteIndex( + const std::vector<pixel_type> &color, const int palette_size, + const int bit_depth, bool high_quality) { + int index = 0; + if (high_quality) { + int multiplier = 1; + for (size_t c = 0; c < color.size(); c++) { + int quantized = ((kLargeCube - 1) * color[c] + (1 << (bit_depth - 1))) / + ((1 << bit_depth) - 1); + JXL_ASSERT((quantized % kLargeCube) == quantized); + index += quantized * multiplier; + multiplier *= kLargeCube; + } + return index + palette_size + kLargeCubeOffset; + } else { + int multiplier = 1; + for (size_t c = 0; c < color.size(); c++) { + int value = color[c]; + value -= 1 << (std::max(0, bit_depth - 3)); + value = std::max(0, value); + int quantized = ((kLargeCube - 1) * value + (1 << (bit_depth - 1))) / + ((1 << bit_depth) - 1); + JXL_ASSERT((quantized % kLargeCube) == quantized); + if (quantized > kSmallCube - 1) { + quantized = kSmallCube - 1; + } + index += quantized * multiplier; + multiplier *= kSmallCube; + } + return index + palette_size; + } +} + +} // namespace palette_internal + +int RoundInt(int value, int div) { // symmetric rounding around 0 + if (value < 0) return -RoundInt(-value, div); + return (value + div / 2) / div; +} + +struct PaletteIterationData { + static constexpr int kMaxDeltas = 128; + bool final_run = false; + std::vector<pixel_type> deltas[3]; + std::vector<double> delta_distances; + std::vector<pixel_type> frequent_deltas[3]; + + // Populates `frequent_deltas` with items from `deltas` based on frequencies + // and color distances. + void FindFrequentColorDeltas(int num_pixels, int bitdepth) { + using pixel_type_3d = std::array<pixel_type, 3>; + std::map<pixel_type_3d, double> delta_frequency_map; + pixel_type bucket_size = 3 << std::max(0, bitdepth - 8); + // Store frequency weighted by delta distance from quantized value. + for (size_t i = 0; i < deltas[0].size(); ++i) { + pixel_type_3d delta = { + {RoundInt(deltas[0][i], bucket_size), + RoundInt(deltas[1][i], bucket_size), + RoundInt(deltas[2][i], bucket_size)}}; // a basic form of clustering + if (delta[0] == 0 && delta[1] == 0 && delta[2] == 0) continue; + delta_frequency_map[delta] += sqrt(sqrt(delta_distances[i])); + } + + const float delta_distance_multiplier = 1.0f / num_pixels; + + // Weigh frequencies by magnitude and normalize. + for (auto &delta_frequency : delta_frequency_map) { + std::vector<pixel_type> current_delta = {delta_frequency.first[0], + delta_frequency.first[1], + delta_frequency.first[2]}; + float delta_distance = + sqrt(palette_internal::ColorDistance({0, 0, 0}, current_delta)) + 1; + delta_frequency.second *= delta_distance * delta_distance_multiplier; + } + + // Sort by weighted frequency. + using pixel_type_3d_frequency = std::pair<pixel_type_3d, double>; + std::vector<pixel_type_3d_frequency> sorted_delta_frequency_map( + delta_frequency_map.begin(), delta_frequency_map.end()); + std::sort( + sorted_delta_frequency_map.begin(), sorted_delta_frequency_map.end(), + [](const pixel_type_3d_frequency &a, const pixel_type_3d_frequency &b) { + return a.second > b.second; + }); + + // Store the top deltas. + for (auto &delta_frequency : sorted_delta_frequency_map) { + if (frequent_deltas[0].size() >= kMaxDeltas) break; + // Number obtained by optimizing on jyrki31 corpus: + if (delta_frequency.second < 17) break; + for (int c = 0; c < 3; ++c) { + frequent_deltas[c].push_back(delta_frequency.first[c] * bucket_size); + } + } + } +}; + +Status FwdPaletteIteration(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t &nb_colors, uint32_t &nb_deltas, + bool ordered, bool lossy, Predictor &predictor, + const weighted::Header &wp_header, + PaletteIterationData &palette_iteration_data) { + JXL_QUIET_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c)); + JXL_ASSERT(begin_c >= input.nb_meta_channels); + uint32_t nb = end_c - begin_c + 1; + + size_t w = input.channel[begin_c].w; + size_t h = input.channel[begin_c].h; + + if (!lossy && nb == 1) { + // Channel palette special case + if (nb_colors == 0) return false; + std::vector<pixel_type> lookup; + pixel_type minval, maxval; + compute_minmax(input.channel[begin_c], &minval, &maxval); + size_t lookup_table_size = + static_cast<int64_t>(maxval) - static_cast<int64_t>(minval) + 1; + if (lookup_table_size > palette_internal::kMaxPaletteLookupTableSize) { + // a lookup table would use too much memory, instead use a slower approach + // with std::set + std::set<pixel_type> chpalette; + pixel_type idx = 0; + for (size_t y = 0; y < h; y++) { + const pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + const bool new_color = chpalette.insert(p[x]).second; + if (new_color) { + idx++; + if (idx > (int)nb_colors) return false; + } + } + } + JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx); + Channel pch(idx, 1); + pch.hshift = -1; + pch.vshift = -1; + nb_colors = idx; + idx = 0; + pixel_type *JXL_RESTRICT p_palette = pch.Row(0); + for (pixel_type p : chpalette) { + p_palette[idx++] = p; + } + for (size_t y = 0; y < h; y++) { + pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + for (idx = 0; p[x] != p_palette[idx] && idx < (int)nb_colors; idx++) { + } + JXL_DASSERT(idx < (int)nb_colors); + p[x] = idx; + } + } + predictor = Predictor::Zero; + input.nb_meta_channels++; + input.channel.insert(input.channel.begin(), std::move(pch)); + + return true; + } + lookup.resize(lookup_table_size, 0); + pixel_type idx = 0; + for (size_t y = 0; y < h; y++) { + const pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + if (lookup[p[x] - minval] == 0) { + lookup[p[x] - minval] = 1; + idx++; + if (idx > (int)nb_colors) return false; + } + } + } + JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx); + Channel pch(idx, 1); + pch.hshift = -1; + pch.vshift = -1; + nb_colors = idx; + idx = 0; + pixel_type *JXL_RESTRICT p_palette = pch.Row(0); + for (size_t i = 0; i < lookup_table_size; i++) { + if (lookup[i]) { + p_palette[idx] = i + minval; + lookup[i] = idx; + idx++; + } + } + for (size_t y = 0; y < h; y++) { + pixel_type *p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) p[x] = lookup[p[x] - minval]; + } + predictor = Predictor::Zero; + input.nb_meta_channels++; + input.channel.insert(input.channel.begin(), std::move(pch)); + return true; + } + + Image quantized_input; + if (lossy) { + quantized_input = Image(w, h, input.bitdepth, nb); + for (size_t c = 0; c < nb; c++) { + CopyImageTo(input.channel[begin_c + c].plane, + &quantized_input.channel[c].plane); + } + } + + JXL_DEBUG_V( + 7, "Trying to represent channels %i-%i using at most a %i-color palette.", + begin_c, end_c, nb_colors); + nb_deltas = 0; + bool delta_used = false; + std::set<std::vector<pixel_type>> candidate_palette; + std::vector<std::vector<pixel_type>> candidate_palette_imageorder; + std::vector<pixel_type> color(nb); + std::vector<float> color_with_error(nb); + std::vector<const pixel_type *> p_in(nb); + std::map<std::vector<pixel_type>, size_t> inv_palette; + + if (lossy) { + palette_iteration_data.FindFrequentColorDeltas(w * h, input.bitdepth); + nb_deltas = palette_iteration_data.frequent_deltas[0].size(); + + // Count color frequency for colors that make a cross. + std::map<std::vector<pixel_type>, size_t> color_freq_map; + for (size_t y = 1; y + 1 < h; y++) { + for (uint32_t c = 0; c < nb; c++) { + p_in[c] = input.channel[begin_c + c].Row(y); + } + for (size_t x = 1; x + 1 < w; x++) { + for (uint32_t c = 0; c < nb; c++) { + color[c] = p_in[c][x]; + } + int offsets[4][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}}; + bool makes_cross = true; + for (int i = 0; i < 4 && makes_cross; ++i) { + int dx = offsets[i][0]; + int dy = offsets[i][1]; + for (uint32_t c = 0; c < nb && makes_cross; c++) { + if (input.channel[begin_c + c].Row(y + dy)[x + dx] != color[c]) { + makes_cross = false; + } + } + } + if (makes_cross) color_freq_map[color] += 1; + } + } + // Add colors satisfying frequency condition to the palette. + constexpr float kImageFraction = 0.01f; + size_t color_frequency_lower_bound = 5 + input.h * input.w * kImageFraction; + for (const auto &color_freq : color_freq_map) { + if (color_freq.second > color_frequency_lower_bound) { + candidate_palette.insert(color_freq.first); + candidate_palette_imageorder.push_back(color_freq.first); + } + } + } + + for (size_t y = 0; y < h; y++) { + for (uint32_t c = 0; c < nb; c++) { + p_in[c] = input.channel[begin_c + c].Row(y); + } + for (size_t x = 0; x < w; x++) { + if (lossy && candidate_palette.size() >= nb_colors) break; + for (uint32_t c = 0; c < nb; c++) { + color[c] = p_in[c][x]; + } + const bool new_color = candidate_palette.insert(color).second; + if (new_color) { + candidate_palette_imageorder.push_back(color); + } + if (candidate_palette.size() > nb_colors) { + return false; // too many colors + } + } + } + + nb_colors = nb_deltas + candidate_palette.size(); + JXL_DEBUG_V(6, "Channels %i-%i can be represented using a %i-color palette.", + begin_c, end_c, nb_colors); + + Channel pch(nb_colors, nb); + pch.hshift = -1; + pch.vshift = -1; + pixel_type *JXL_RESTRICT p_palette = pch.Row(0); + intptr_t onerow = pch.plane.PixelsPerRow(); + intptr_t onerow_image = input.channel[begin_c].plane.PixelsPerRow(); + const int bit_depth = std::min(input.bitdepth, 24); + + if (lossy) { + for (uint32_t i = 0; i < nb_deltas; i++) { + for (size_t c = 0; c < 3; c++) { + p_palette[c * onerow + i] = + palette_iteration_data.frequent_deltas[c][i]; + } + } + } + + int x = 0; + if (ordered && nb >= 3) { + JXL_DEBUG_V(7, "Palette of %i colors, using luma order", nb_colors); + // sort on luma (multiplied by alpha if available) + std::sort(candidate_palette_imageorder.begin(), + candidate_palette_imageorder.end(), + [](std::vector<pixel_type> ap, std::vector<pixel_type> bp) { + float ay, by; + ay = (0.299f * ap[0] + 0.587f * ap[1] + 0.114f * ap[2] + 0.1f); + if (ap.size() > 3) ay *= 1.f + ap[3]; + by = (0.299f * bp[0] + 0.587f * bp[1] + 0.114f * bp[2] + 0.1f); + if (bp.size() > 3) by *= 1.f + bp[3]; + return ay < by; + }); + } else { + JXL_DEBUG_V(7, "Palette of %i colors, using image order", nb_colors); + } + for (auto pcol : candidate_palette_imageorder) { + JXL_DEBUG_V(9, " Color %i : ", x); + for (size_t i = 0; i < nb; i++) { + p_palette[nb_deltas + i * onerow + x] = pcol[i]; + JXL_DEBUG_V(9, "%i ", pcol[i]); + } + inv_palette[pcol] = x; + x++; + } + std::vector<weighted::State> wp_states; + for (size_t c = 0; c < nb; c++) { + wp_states.emplace_back(wp_header, w, h); + } + std::vector<pixel_type *> p_quant(nb); + // Three rows of error for dithering: y to y + 2. + // Each row has two pixels of padding in the ends, which is + // beneficial for both precision and encoding speed. + std::vector<std::vector<float>> error_row[3]; + if (lossy) { + for (int i = 0; i < 3; ++i) { + error_row[i].resize(nb); + for (size_t c = 0; c < nb; ++c) { + error_row[i][c].resize(w + 4); + } + } + } + for (size_t y = 0; y < h; y++) { + for (size_t c = 0; c < nb; c++) { + p_in[c] = input.channel[begin_c + c].Row(y); + if (lossy) p_quant[c] = quantized_input.channel[c].Row(y); + } + pixel_type *JXL_RESTRICT p = input.channel[begin_c].Row(y); + for (size_t x = 0; x < w; x++) { + int index; + if (!lossy) { + for (size_t c = 0; c < nb; c++) color[c] = p_in[c][x]; + index = inv_palette[color]; + } else { + int best_index = 0; + bool best_is_delta = false; + float best_distance = std::numeric_limits<float>::infinity(); + std::vector<pixel_type> best_val(nb, 0); + std::vector<pixel_type> ideal_residual(nb, 0); + std::vector<pixel_type> quantized_val(nb); + std::vector<pixel_type> predictions(nb); + static const double kDiffusionMultiplier[] = {0.55, 0.75}; + for (int diffusion_index = 0; diffusion_index < 2; ++diffusion_index) { + for (size_t c = 0; c < nb; c++) { + color_with_error[c] = + p_in[c][x] + palette_iteration_data.final_run * + kDiffusionMultiplier[diffusion_index] * + error_row[0][c][x + 2]; + color[c] = Clamp1(lroundf(color_with_error[c]), 0l, + (1l << input.bitdepth) - 1); + } + + for (size_t c = 0; c < nb; ++c) { + predictions[c] = PredictNoTreeWP(w, p_quant[c] + x, onerow_image, x, + y, predictor, &wp_states[c]) + .guess; + } + const auto TryIndex = [&](const int index) { + for (size_t c = 0; c < nb; c++) { + quantized_val[c] = palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/nb_colors, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + if (index < static_cast<int>(nb_deltas)) { + quantized_val[c] += predictions[c]; + } + } + const float color_distance = + 32.0 / (1LL << std::max(0, 2 * (bit_depth - 8))) * + palette_internal::ColorDistance(color_with_error, + quantized_val); + float index_penalty = 0; + if (index == -1) { + index_penalty = -124; + } else if (index < 0) { + index_penalty = -2 * index; + } else if (index < static_cast<int>(nb_deltas)) { + index_penalty = 250; + } else if (index < static_cast<int>(nb_colors)) { + index_penalty = 150; + } else if (index < static_cast<int>(nb_colors) + + palette_internal::kLargeCubeOffset) { + index_penalty = 70; + } else { + index_penalty = 256; + } + const float distance = color_distance + index_penalty; + if (distance < best_distance) { + best_distance = distance; + best_index = index; + best_is_delta = index < static_cast<int>(nb_deltas); + best_val.swap(quantized_val); + for (size_t c = 0; c < nb; ++c) { + ideal_residual[c] = color_with_error[c] - predictions[c]; + } + } + }; + for (index = palette_internal::kMinImplicitPaletteIndex; + index < static_cast<int32_t>(nb_colors); index++) { + TryIndex(index); + } + TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex( + color, nb_colors, bit_depth, + /*high_quality=*/false)); + if (palette_internal::kEncodeToHighQualityImplicitPalette) { + TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex( + color, nb_colors, bit_depth, + /*high_quality=*/true)); + } + } + index = best_index; + delta_used |= best_is_delta; + if (!palette_iteration_data.final_run) { + for (size_t c = 0; c < 3; ++c) { + palette_iteration_data.deltas[c].push_back(ideal_residual[c]); + } + palette_iteration_data.delta_distances.push_back(best_distance); + } + + for (size_t c = 0; c < nb; ++c) { + wp_states[c].UpdateErrors(best_val[c], x, y, w); + p_quant[c][x] = best_val[c]; + } + float len_error = 0; + for (size_t c = 0; c < nb; ++c) { + float local_error = color_with_error[c] - best_val[c]; + len_error += local_error * local_error; + } + len_error = sqrt(len_error); + float modulate = 1.0; + int len_limit = 38 << std::max(0, bit_depth - 8); + if (len_error > len_limit) { + modulate *= len_limit / len_error; + } + for (size_t c = 0; c < nb; ++c) { + float total_error = (color_with_error[c] - best_val[c]); + + // If the neighboring pixels have some error in the opposite + // direction of total_error, cancel some or all of it out before + // spreading among them. + constexpr int offsets[12][2] = {{1, 2}, {0, 3}, {0, 4}, {1, 1}, + {1, 3}, {2, 2}, {1, 0}, {1, 4}, + {2, 1}, {2, 3}, {2, 0}, {2, 4}}; + float total_available = 0; + for (int i = 0; i < 11; ++i) { + const int row = offsets[i][0]; + const int col = offsets[i][1]; + if (std::signbit(error_row[row][c][x + col]) != + std::signbit(total_error)) { + total_available += error_row[row][c][x + col]; + } + } + float weight = + std::abs(total_error) / (std::abs(total_available) + 1e-3); + weight = std::min(weight, 1.0f); + for (int i = 0; i < 11; ++i) { + const int row = offsets[i][0]; + const int col = offsets[i][1]; + if (std::signbit(error_row[row][c][x + col]) != + std::signbit(total_error)) { + total_error += weight * error_row[row][c][x + col]; + error_row[row][c][x + col] *= (1 - weight); + } + } + total_error *= modulate; + const float remaining_error = (1.0f / 14.) * total_error; + error_row[0][c][x + 3] += 2 * remaining_error; + error_row[0][c][x + 4] += remaining_error; + error_row[1][c][x + 0] += remaining_error; + for (int i = 0; i < 5; ++i) { + error_row[1][c][x + i] += remaining_error; + error_row[2][c][x + i] += remaining_error; + } + } + } + if (palette_iteration_data.final_run) p[x] = index; + } + if (lossy) { + for (size_t c = 0; c < nb; ++c) { + error_row[0][c].swap(error_row[1][c]); + error_row[1][c].swap(error_row[2][c]); + std::fill(error_row[2][c].begin(), error_row[2][c].end(), 0.f); + } + } + } + if (!delta_used) { + predictor = Predictor::Zero; + } + if (palette_iteration_data.final_run) { + input.nb_meta_channels++; + input.channel.erase(input.channel.begin() + begin_c + 1, + input.channel.begin() + end_c + 1); + input.channel.insert(input.channel.begin(), std::move(pch)); + } + nb_colors -= nb_deltas; + return true; +} + +Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t &nb_colors, uint32_t &nb_deltas, bool ordered, + bool lossy, Predictor &predictor, + const weighted::Header &wp_header) { + PaletteIterationData palette_iteration_data; + uint32_t nb_colors_orig = nb_colors; + uint32_t nb_deltas_orig = nb_deltas; + // preprocessing pass in case of lossy palette + if (lossy && input.bitdepth >= 8) { + JXL_RETURN_IF_ERROR(FwdPaletteIteration( + input, begin_c, end_c, nb_colors_orig, nb_deltas_orig, ordered, lossy, + predictor, wp_header, palette_iteration_data)); + } + palette_iteration_data.final_run = true; + return FwdPaletteIteration(input, begin_c, end_c, nb_colors, nb_deltas, + ordered, lossy, predictor, wp_header, + palette_iteration_data); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h new file mode 100644 index 0000000000..0f3d66825b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_palette.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_ + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t &nb_colors, uint32_t &nb_deltas, bool ordered, + bool lossy, Predictor &predictor, + const weighted::Header &wp_header); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_PALETTE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc new file mode 100644 index 0000000000..64930272db --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.cc @@ -0,0 +1,72 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_rct.h" + +#include "lib/jxl/base/status.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels + +namespace jxl { + +Status FwdRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) { + JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, begin_c + 2)); + if (rct_type == 0) { // noop + return false; + } + // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR + int permutation = rct_type / 7; + // 0-5 values have the low bit corresponding to Third and the high bits + // corresponding to Second. 6 corresponds to YCoCg. + // + // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird + // + // Third: 0=nop, 1=SubtractFirst + int custom = rct_type % 7; + size_t m = begin_c; + size_t w = input.channel[m + 0].w; + size_t h = input.channel[m + 0].h; + int second = (custom % 7) >> 1; + int third = (custom % 7) & 1; + const auto do_rct = [&](const int y, const int thread) { + const pixel_type* in0 = input.channel[m + (permutation % 3)].Row(y); + const pixel_type* in1 = + input.channel[m + ((permutation + 1 + permutation / 3) % 3)].Row(y); + const pixel_type* in2 = + input.channel[m + ((permutation + 2 - permutation / 3) % 3)].Row(y); + pixel_type* out0 = input.channel[m].Row(y); + pixel_type* out1 = input.channel[m + 1].Row(y); + pixel_type* out2 = input.channel[m + 2].Row(y); + if (custom == 6) { + for (size_t x = 0; x < w; x++) { + pixel_type R = in0[x]; + pixel_type G = in1[x]; + pixel_type B = in2[x]; + out1[x] = R - B; + pixel_type tmp = B + (out1[x] >> 1); + out2[x] = G - tmp; + out0[x] = tmp + (out2[x] >> 1); + } + } else { + for (size_t x = 0; x < w; x++) { + pixel_type First = in0[x]; + pixel_type Second = in1[x]; + pixel_type Third = in2[x]; + if (second == 1) { + Second = Second - First; + } else if (second == 2) { + Second = Second - ((First + Third) >> 1); + } + if (third) Third = Third - First; + out0[x] = First; + out1[x] = Second; + out2[x] = Third; + } + } + }; + return RunOnPool(pool, 0, h, ThreadPool::NoInit, do_rct, "FwdRCT"); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h new file mode 100644 index 0000000000..cb5a193c8d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_rct.h @@ -0,0 +1,17 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_ + +#include "lib/jxl/modular/modular_image.h" + +namespace jxl { + +Status FwdRCT(Image &input, size_t begin_c, size_t rct_type, ThreadPool *pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_RCT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc new file mode 100644 index 0000000000..489f72a90d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.cc @@ -0,0 +1,140 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_squeeze.h" + +#include <stdlib.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/squeeze.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +void FwdHSqueeze(Image &input, int c, int rc) { + const Channel &chin = input.channel[c]; + + JXL_DEBUG_V(4, "Doing horizontal squeeze of channel %i to new channel %i", c, + rc); + + Channel chout((chin.w + 1) / 2, chin.h, chin.hshift + 1, chin.vshift); + Channel chout_residual(chin.w - chout.w, chout.h, chin.hshift + 1, + chin.vshift); + + for (size_t y = 0; y < chout.h; y++) { + const pixel_type *JXL_RESTRICT p_in = chin.Row(y); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + pixel_type *JXL_RESTRICT p_res = chout_residual.Row(y); + for (size_t x = 0; x < chout_residual.w; x++) { + pixel_type A = p_in[x * 2]; + pixel_type B = p_in[x * 2 + 1]; + pixel_type avg = (A + B + (A > B)) >> 1; + p_out[x] = avg; + + pixel_type diff = A - B; + + pixel_type next_avg = avg; + if (x + 1 < chout_residual.w) { + next_avg = (p_in[x * 2 + 2] + p_in[x * 2 + 3] + + (p_in[x * 2 + 2] > p_in[x * 2 + 3])) >> + 1; // which will be chout.value(y,x+1) + } else if (chin.w & 1) + next_avg = p_in[x * 2 + 2]; + pixel_type left = (x > 0 ? p_in[x * 2 - 1] : avg); + pixel_type tendency = SmoothTendency(left, avg, next_avg); + + p_res[x] = diff - tendency; + } + if (chin.w & 1) { + int x = chout.w - 1; + p_out[x] = p_in[x * 2]; + } + } + input.channel[c] = std::move(chout); + input.channel.insert(input.channel.begin() + rc, std::move(chout_residual)); +} + +void FwdVSqueeze(Image &input, int c, int rc) { + const Channel &chin = input.channel[c]; + + JXL_DEBUG_V(4, "Doing vertical squeeze of channel %i to new channel %i", c, + rc); + + Channel chout(chin.w, (chin.h + 1) / 2, chin.hshift, chin.vshift + 1); + Channel chout_residual(chin.w, chin.h - chout.h, chin.hshift, + chin.vshift + 1); + intptr_t onerow_in = chin.plane.PixelsPerRow(); + for (size_t y = 0; y < chout_residual.h; y++) { + const pixel_type *JXL_RESTRICT p_in = chin.Row(y * 2); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + pixel_type *JXL_RESTRICT p_res = chout_residual.Row(y); + for (size_t x = 0; x < chout.w; x++) { + pixel_type A = p_in[x]; + pixel_type B = p_in[x + onerow_in]; + pixel_type avg = (A + B + (A > B)) >> 1; + p_out[x] = avg; + + pixel_type diff = A - B; + + pixel_type next_avg = avg; + if (y + 1 < chout_residual.h) { + next_avg = (p_in[x + 2 * onerow_in] + p_in[x + 3 * onerow_in] + + (p_in[x + 2 * onerow_in] > p_in[x + 3 * onerow_in])) >> + 1; // which will be chout.value(y+1,x) + } else if (chin.h & 1) { + next_avg = p_in[x + 2 * onerow_in]; + } + pixel_type top = + (y > 0 ? p_in[static_cast<ssize_t>(x) - onerow_in] : avg); + pixel_type tendency = SmoothTendency(top, avg, next_avg); + + p_res[x] = diff - tendency; + } + } + if (chin.h & 1) { + size_t y = chout.h - 1; + const pixel_type *p_in = chin.Row(y * 2); + pixel_type *p_out = chout.Row(y); + for (size_t x = 0; x < chout.w; x++) { + p_out[x] = p_in[x]; + } + } + input.channel[c] = std::move(chout); + input.channel.insert(input.channel.begin() + rc, std::move(chout_residual)); +} + +Status FwdSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool) { + if (parameters.empty()) { + DefaultSqueezeParameters(¶meters, input); + } + // if nothing to do, don't do squeeze + if (parameters.empty()) return false; + for (size_t i = 0; i < parameters.size(); i++) { + JXL_RETURN_IF_ERROR( + CheckMetaSqueezeParams(parameters[i], input.channel.size())); + bool horizontal = parameters[i].horizontal; + bool in_place = parameters[i].in_place; + uint32_t beginc = parameters[i].begin_c; + uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1; + uint32_t offset; + if (in_place) { + offset = endc + 1; + } else { + offset = input.channel.size(); + } + for (uint32_t c = beginc; c <= endc; c++) { + if (horizontal) { + FwdHSqueeze(input, c, offset + c - beginc); + } else { + FwdVSqueeze(input, c, offset + c - beginc); + } + } + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h new file mode 100644 index 0000000000..39b001017b --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_squeeze.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_ + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +Status FwdSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_SQUEEZE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc new file mode 100644 index 0000000000..bdaaf9f87e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.cc @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/enc_transform.h" + +#include "lib/jxl/modular/transform/enc_palette.h" +#include "lib/jxl/modular/transform/enc_rct.h" +#include "lib/jxl/modular/transform/enc_squeeze.h" + +namespace jxl { + +Status TransformForward(Transform &t, Image &input, + const weighted::Header &wp_header, ThreadPool *pool) { + switch (t.id) { + case TransformId::kRCT: + return FwdRCT(input, t.begin_c, t.rct_type, pool); + case TransformId::kSqueeze: + return FwdSqueeze(input, t.squeezes, pool); + case TransformId::kPalette: + return FwdPalette(input, t.begin_c, t.begin_c + t.num_c - 1, t.nb_colors, + t.nb_deltas, t.ordered_palette, t.lossy_palette, + t.predictor, wp_header); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast<unsigned int>(t.id)); + } +} + +void compute_minmax(const Channel &ch, pixel_type *min, pixel_type *max) { + pixel_type realmin = std::numeric_limits<pixel_type>::max(); + pixel_type realmax = std::numeric_limits<pixel_type>::min(); + for (size_t y = 0; y < ch.h; y++) { + const pixel_type *JXL_RESTRICT p = ch.Row(y); + for (size_t x = 0; x < ch.w; x++) { + if (p[x] < realmin) realmin = p[x]; + if (p[x] > realmax) realmax = p[x]; + } + } + + if (min) *min = realmin; + if (max) *max = realmax; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h new file mode 100644 index 0000000000..07659e1b0a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/enc_transform.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_ +#define LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_ + +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +namespace jxl { + +Status TransformForward(Transform &t, Image &input, + const weighted::Header &wp_header, ThreadPool *pool); + +void compute_minmax(const Channel &ch, pixel_type *min, pixel_type *max); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_ENC_TRANSFORM_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc new file mode 100644 index 0000000000..bffbacf160 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.cc @@ -0,0 +1,177 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/palette.h" + +namespace jxl { + +Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors, + uint32_t nb_deltas, Predictor predictor, + const weighted::Header &wp_header, ThreadPool *pool) { + if (input.nb_meta_channels < 1) { + return JXL_FAILURE("Error: Palette transform without palette."); + } + std::atomic<int> num_errors{0}; + int nb = input.channel[0].h; + uint32_t c0 = begin_c + 1; + if (c0 >= input.channel.size()) { + return JXL_FAILURE("Channel is out of range."); + } + size_t w = input.channel[c0].w; + size_t h = input.channel[c0].h; + if (nb < 1) return JXL_FAILURE("Corrupted transforms"); + for (int i = 1; i < nb; i++) { + input.channel.insert( + input.channel.begin() + c0 + 1, + Channel(w, h, input.channel[c0].hshift, input.channel[c0].vshift)); + } + const Channel &palette = input.channel[0]; + const pixel_type *JXL_RESTRICT p_palette = input.channel[0].Row(0); + intptr_t onerow = input.channel[0].plane.PixelsPerRow(); + intptr_t onerow_image = input.channel[c0].plane.PixelsPerRow(); + const int bit_depth = std::min(input.bitdepth, 24); + + if (w == 0) { + // Nothing to do. + // Avoid touching "empty" channels with non-zero height. + } else if (nb_deltas == 0 && predictor == Predictor::Zero) { + if (nb == 1) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, h, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + pixel_type *p = input.channel[c0].Row(y); + for (size_t x = 0; x < w; x++) { + const int index = Clamp1<int>(p[x], 0, (pixel_type)palette.w - 1); + p[x] = palette_internal::GetPaletteValue( + p_palette, index, /*c=*/0, + /*palette_size=*/palette.w, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + } + }, + "UndoChannelPalette")); + } else { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, h, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + std::vector<pixel_type *> p_out(nb); + const pixel_type *p_index = input.channel[c0].Row(y); + for (int c = 0; c < nb; c++) + p_out[c] = input.channel[c0 + c].Row(y); + for (size_t x = 0; x < w; x++) { + const int index = p_index[x]; + for (int c = 0; c < nb; c++) { + p_out[c][x] = palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/palette.w, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + } + } + }, + "UndoPalette")); + } + } else { + // Parallelized per channel. + ImageI indices = std::move(input.channel[c0].plane); + input.channel[c0].plane = ImageI(indices.xsize(), indices.ysize()); + if (predictor == Predictor::Weighted) { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, nb, ThreadPool::NoInit, + [&](const uint32_t c, size_t /* thread */) { + Channel &channel = input.channel[c0 + c]; + weighted::State wp_state(wp_header, channel.w, channel.h); + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + const pixel_type *JXL_RESTRICT idx = indices.Row(y); + for (size_t x = 0; x < channel.w; x++) { + int index = idx[x]; + pixel_type_w val = 0; + const pixel_type palette_entry = + palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/palette.w, /*onerow=*/onerow, + /*bit_depth=*/bit_depth); + if (index < static_cast<int32_t>(nb_deltas)) { + PredictionResult pred = + PredictNoTreeWP(channel.w, p + x, onerow_image, x, y, + predictor, &wp_state); + val = pred.guess + palette_entry; + } else { + val = palette_entry; + } + p[x] = val; + wp_state.UpdateErrors(p[x], x, y, channel.w); + } + } + }, + "UndoDeltaPaletteWP")); + } else { + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, nb, ThreadPool::NoInit, + [&](const uint32_t c, size_t /* thread */) { + Channel &channel = input.channel[c0 + c]; + for (size_t y = 0; y < channel.h; y++) { + pixel_type *JXL_RESTRICT p = channel.Row(y); + const pixel_type *JXL_RESTRICT idx = indices.Row(y); + for (size_t x = 0; x < channel.w; x++) { + int index = idx[x]; + pixel_type_w val = 0; + const pixel_type palette_entry = + palette_internal::GetPaletteValue( + p_palette, index, /*c=*/c, + /*palette_size=*/palette.w, + /*onerow=*/onerow, /*bit_depth=*/bit_depth); + if (index < static_cast<int32_t>(nb_deltas)) { + PredictionResult pred = PredictNoTreeNoWP( + channel.w, p + x, onerow_image, x, y, predictor); + val = pred.guess + palette_entry; + } else { + val = palette_entry; + } + p[x] = val; + } + } + }, + "UndoDeltaPaletteNoWP")); + } + } + if (c0 >= input.nb_meta_channels) { + // Palette was done on normal channels + input.nb_meta_channels--; + } else { + // Palette was done on metachannels + JXL_ASSERT(static_cast<int>(input.nb_meta_channels) >= 2 - nb); + input.nb_meta_channels -= 2 - nb; + JXL_ASSERT(begin_c + nb - 1 < input.nb_meta_channels); + } + input.channel.erase(input.channel.begin(), input.channel.begin() + 1); + return num_errors.load(std::memory_order_relaxed) == 0; +} + +Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t nb_colors, uint32_t nb_deltas, bool lossy) { + JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c)); + + size_t nb = end_c - begin_c + 1; + if (begin_c >= input.nb_meta_channels) { + // Palette was done on normal channels + input.nb_meta_channels++; + } else { + // Palette was done on metachannels + JXL_ASSERT(end_c < input.nb_meta_channels); + // we remove nb-1 metachannels and add one + input.nb_meta_channels += 2 - nb; + } + input.channel.erase(input.channel.begin() + begin_c + 1, + input.channel.begin() + end_c + 1); + Channel pch(nb_colors + nb_deltas, nb); + pch.hshift = -1; + pch.vshift = -1; + input.channel.insert(input.channel.begin(), std::move(pch)); + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h new file mode 100644 index 0000000000..279ef04568 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/palette.h @@ -0,0 +1,128 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_ + +#include <atomic> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels + +namespace jxl { + +namespace palette_internal { + +static constexpr int kMaxPaletteLookupTableSize = 1 << 16; + +static constexpr int kRgbChannels = 3; + +// 5x5x5 color cube for the larger cube. +static constexpr int kLargeCube = 5; + +// Smaller interleaved color cube to fill the holes of the larger cube. +static constexpr int kSmallCube = 4; +static constexpr int kSmallCubeBits = 2; +// kSmallCube ** 3 +static constexpr int kLargeCubeOffset = kSmallCube * kSmallCube * kSmallCube; + +static inline pixel_type Scale(uint64_t value, uint64_t bit_depth, + uint64_t denom) { + // return (value * ((static_cast<pixel_type_w>(1) << bit_depth) - 1)) / denom; + // We only call this function with kSmallCube or kLargeCube - 1 as denom, + // allowing us to avoid a division here. + JXL_ASSERT(denom == 4); + return (value * ((static_cast<uint64_t>(1) << bit_depth) - 1)) >> 2; +} + +// The purpose of this function is solely to extend the interpretation of +// palette indices to implicit values. If index < nb_deltas, indicating that the +// result is a delta palette entry, it is the responsibility of the caller to +// treat it as such. +static JXL_MAYBE_UNUSED pixel_type +GetPaletteValue(const pixel_type *const palette, int index, const size_t c, + const int palette_size, const int onerow, const int bit_depth) { + if (index < 0) { + static constexpr std::array<std::array<pixel_type, 3>, 72> kDeltaPalette = { + { + {{0, 0, 0}}, {{4, 4, 4}}, {{11, 0, 0}}, + {{0, 0, -13}}, {{0, -12, 0}}, {{-10, -10, -10}}, + {{-18, -18, -18}}, {{-27, -27, -27}}, {{-18, -18, 0}}, + {{0, 0, -32}}, {{-32, 0, 0}}, {{-37, -37, -37}}, + {{0, -32, -32}}, {{24, 24, 45}}, {{50, 50, 50}}, + {{-45, -24, -24}}, {{-24, -45, -45}}, {{0, -24, -24}}, + {{-34, -34, 0}}, {{-24, 0, -24}}, {{-45, -45, -24}}, + {{64, 64, 64}}, {{-32, 0, -32}}, {{0, -32, 0}}, + {{-32, 0, 32}}, {{-24, -45, -24}}, {{45, 24, 45}}, + {{24, -24, -45}}, {{-45, -24, 24}}, {{80, 80, 80}}, + {{64, 0, 0}}, {{0, 0, -64}}, {{0, -64, -64}}, + {{-24, -24, 45}}, {{96, 96, 96}}, {{64, 64, 0}}, + {{45, -24, -24}}, {{34, -34, 0}}, {{112, 112, 112}}, + {{24, -45, -45}}, {{45, 45, -24}}, {{0, -32, 32}}, + {{24, -24, 45}}, {{0, 96, 96}}, {{45, -24, 24}}, + {{24, -45, -24}}, {{-24, -45, 24}}, {{0, -64, 0}}, + {{96, 0, 0}}, {{128, 128, 128}}, {{64, 0, 64}}, + {{144, 144, 144}}, {{96, 96, 0}}, {{-36, -36, 36}}, + {{45, -24, -45}}, {{45, -45, -24}}, {{0, 0, -96}}, + {{0, 128, 128}}, {{0, 96, 0}}, {{45, 24, -45}}, + {{-128, 0, 0}}, {{24, -45, 24}}, {{-45, 24, -45}}, + {{64, 0, -64}}, {{64, -64, -64}}, {{96, 0, 96}}, + {{45, -45, 24}}, {{24, 45, -45}}, {{64, 64, -64}}, + {{128, 128, 0}}, {{0, 0, -128}}, {{-24, 45, -45}}, + }}; + if (c >= kRgbChannels) { + return 0; + } + // Do not open the brackets, otherwise INT32_MIN negation could overflow. + index = -(index + 1); + index %= 1 + 2 * (kDeltaPalette.size() - 1); + static constexpr int kMultiplier[] = {-1, 1}; + pixel_type result = + kDeltaPalette[((index + 1) >> 1)][c] * kMultiplier[index & 1]; + if (bit_depth > 8) { + result *= static_cast<pixel_type>(1) << (bit_depth - 8); + } + return result; + } else if (palette_size <= index && index < palette_size + kLargeCubeOffset) { + if (c >= kRgbChannels) return 0; + index -= palette_size; + index >>= c * kSmallCubeBits; + return Scale(index % kSmallCube, bit_depth, kSmallCube) + + (1 << (std::max(0, bit_depth - 3))); + } else if (palette_size + kLargeCubeOffset <= index) { + if (c >= kRgbChannels) return 0; + index -= palette_size + kLargeCubeOffset; + // TODO(eustas): should we take care of ambiguity created by + // index >= kLargeCube ** 3 ? + switch (c) { + case 0: + break; + case 1: + index /= kLargeCube; + break; + case 2: + index /= kLargeCube * kLargeCube; + break; + } + return Scale(index % kLargeCube, bit_depth, kLargeCube - 1); + } + return palette[c * onerow + static_cast<size_t>(index)]; +} + +} // namespace palette_internal + +Status InvPalette(Image &input, uint32_t begin_c, uint32_t nb_colors, + uint32_t nb_deltas, Predictor predictor, + const weighted::Header &wp_header, ThreadPool *pool); + +Status MetaPalette(Image &input, uint32_t begin_c, uint32_t end_c, + uint32_t nb_colors, uint32_t nb_deltas, bool lossy); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_PALETTE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc new file mode 100644 index 0000000000..f3002a5ac3 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.cc @@ -0,0 +1,153 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/rct.h" +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/rct.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; + +template <int transform_type> +void InvRCTRow(const pixel_type* in0, const pixel_type* in1, + const pixel_type* in2, pixel_type* out0, pixel_type* out1, + pixel_type* out2, size_t w) { + static_assert(transform_type >= 0 && transform_type < 7, + "Invalid transform type"); + int second = transform_type >> 1; + int third = transform_type & 1; + + size_t x = 0; + const HWY_FULL(pixel_type) d; + const size_t N = Lanes(d); + for (; x + N - 1 < w; x += N) { + if (transform_type == 6) { + auto Y = Load(d, in0 + x); + auto Co = Load(d, in1 + x); + auto Cg = Load(d, in2 + x); + Y = Sub(Y, ShiftRight<1>(Cg)); + auto G = Add(Cg, Y); + Y = Sub(Y, ShiftRight<1>(Co)); + auto R = Add(Y, Co); + Store(R, d, out0 + x); + Store(G, d, out1 + x); + Store(Y, d, out2 + x); + } else { + auto First = Load(d, in0 + x); + auto Second = Load(d, in1 + x); + auto Third = Load(d, in2 + x); + if (third) Third = Add(Third, First); + if (second == 1) { + Second = Add(Second, First); + } else if (second == 2) { + Second = Add(Second, ShiftRight<1>(Add(First, Third))); + } + Store(First, d, out0 + x); + Store(Second, d, out1 + x); + Store(Third, d, out2 + x); + } + } + for (; x < w; x++) { + if (transform_type == 6) { + pixel_type Y = in0[x]; + pixel_type Co = in1[x]; + pixel_type Cg = in2[x]; + pixel_type tmp = PixelAdd(Y, -(Cg >> 1)); + pixel_type G = PixelAdd(Cg, tmp); + pixel_type B = PixelAdd(tmp, -(Co >> 1)); + pixel_type R = PixelAdd(B, Co); + out0[x] = R; + out1[x] = G; + out2[x] = B; + } else { + pixel_type First = in0[x]; + pixel_type Second = in1[x]; + pixel_type Third = in2[x]; + if (third) Third = PixelAdd(Third, First); + if (second == 1) { + Second = PixelAdd(Second, First); + } else if (second == 2) { + Second = PixelAdd(Second, (PixelAdd(First, Third) >> 1)); + } + out0[x] = First; + out1[x] = Second; + out2[x] = Third; + } + } +} + +Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) { + JXL_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, begin_c + 2)); + size_t m = begin_c; + Channel& c0 = input.channel[m + 0]; + size_t w = c0.w; + size_t h = c0.h; + if (rct_type == 0) { // noop + return true; + } + // Permutation: 0=RGB, 1=GBR, 2=BRG, 3=RBG, 4=GRB, 5=BGR + int permutation = rct_type / 7; + JXL_CHECK(permutation < 6); + // 0-5 values have the low bit corresponding to Third and the high bits + // corresponding to Second. 6 corresponds to YCoCg. + // + // Second: 0=nop, 1=SubtractFirst, 2=SubtractAvgFirstThird + // + // Third: 0=nop, 1=SubtractFirst + int custom = rct_type % 7; + // Special case: permute-only. Swap channels around. + if (custom == 0) { + Channel ch0 = std::move(input.channel[m]); + Channel ch1 = std::move(input.channel[m + 1]); + Channel ch2 = std::move(input.channel[m + 2]); + input.channel[m + (permutation % 3)] = std::move(ch0); + input.channel[m + ((permutation + 1 + permutation / 3) % 3)] = + std::move(ch1); + input.channel[m + ((permutation + 2 - permutation / 3) % 3)] = + std::move(ch2); + return true; + } + constexpr decltype(&InvRCTRow<0>) inv_rct_row[] = { + InvRCTRow<0>, InvRCTRow<1>, InvRCTRow<2>, InvRCTRow<3>, + InvRCTRow<4>, InvRCTRow<5>, InvRCTRow<6>}; + JXL_RETURN_IF_ERROR(RunOnPool( + pool, 0, h, ThreadPool::NoInit, + [&](const uint32_t task, size_t /* thread */) { + const size_t y = task; + const pixel_type* in0 = input.channel[m].Row(y); + const pixel_type* in1 = input.channel[m + 1].Row(y); + const pixel_type* in2 = input.channel[m + 2].Row(y); + pixel_type* out0 = input.channel[m + (permutation % 3)].Row(y); + pixel_type* out1 = + input.channel[m + ((permutation + 1 + permutation / 3) % 3)].Row(y); + pixel_type* out2 = + input.channel[m + ((permutation + 2 - permutation / 3) % 3)].Row(y); + inv_rct_row[custom](in0, in1, in2, out0, out1, out2, w); + }, + "InvRCT")); + return true; +} + +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(InvRCT); +Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool) { + return HWY_DYNAMIC_DISPATCH(InvRCT)(input, begin_c, rct_type, pool); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/rct.h b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.h new file mode 100644 index 0000000000..1ab57fec69 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/rct.h @@ -0,0 +1,19 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_RCT_H_ +#define LIB_JXL_MODULAR_TRANSFORM_RCT_H_ + +#include "lib/jxl/base/status.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" // CheckEqualChannels + +namespace jxl { + +Status InvRCT(Image& input, size_t begin_c, size_t rct_type, ThreadPool* pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_RCT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc new file mode 100644 index 0000000000..e9892ea48f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.cc @@ -0,0 +1,478 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/squeeze.h" + +#include <stdlib.h> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/squeeze.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/simd_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Abs; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::Gt; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenZeroElse; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::MulEven; +using hwy::HWY_NAMESPACE::Ne; +using hwy::HWY_NAMESPACE::Neg; +using hwy::HWY_NAMESPACE::OddEven; +using hwy::HWY_NAMESPACE::RebindToUnsigned; +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::Xor; + +#if HWY_TARGET != HWY_SCALAR + +JXL_INLINE void FastUnsqueeze(const pixel_type *JXL_RESTRICT p_residual, + const pixel_type *JXL_RESTRICT p_avg, + const pixel_type *JXL_RESTRICT p_navg, + const pixel_type *p_pout, + pixel_type *JXL_RESTRICT p_out, + pixel_type *p_nout) { + const HWY_CAPPED(pixel_type, 8) d; + const RebindToUnsigned<decltype(d)> du; + const size_t N = Lanes(d); + auto onethird = Set(d, 0x55555556); + for (size_t x = 0; x < 8; x += N) { + auto avg = Load(d, p_avg + x); + auto next_avg = Load(d, p_navg + x); + auto top = Load(d, p_pout + x); + // Equivalent to SmoothTendency(top,avg,next_avg), but without branches + auto Ba = Sub(top, avg); + auto an = Sub(avg, next_avg); + auto nonmono = Xor(Ba, an); + auto absBa = Abs(Ba); + auto absan = Abs(an); + auto absBn = Abs(Sub(top, next_avg)); + // Compute a3 = absBa / 3 + auto a3e = BitCast(d, ShiftRight<32>(MulEven(absBa, onethird))); + auto a3oi = MulEven(Reverse(d, absBa), onethird); + auto a3o = BitCast( + d, Reverse(hwy::HWY_NAMESPACE::Repartition<pixel_type_w, decltype(d)>(), + a3oi)); + auto a3 = OddEven(a3o, a3e); + a3 = Add(a3, Add(absBn, Set(d, 2))); + auto absdiff = ShiftRight<2>(a3); + auto skipdiff = Ne(Ba, Zero(d)); + skipdiff = And(skipdiff, Ne(an, Zero(d))); + skipdiff = And(skipdiff, Lt(nonmono, Zero(d))); + auto absBa2 = Add(ShiftLeft<1>(absBa), And(absdiff, Set(d, 1))); + absdiff = IfThenElse(Gt(absdiff, absBa2), + Add(ShiftLeft<1>(absBa), Set(d, 1)), absdiff); + auto absan2 = ShiftLeft<1>(absan); + absdiff = IfThenElse(Gt(Add(absdiff, And(absdiff, Set(d, 1))), absan2), + absan2, absdiff); + auto diff1 = IfThenElse(Lt(top, next_avg), Neg(absdiff), absdiff); + auto tendency = IfThenZeroElse(skipdiff, diff1); + + auto diff_minus_tendency = Load(d, p_residual + x); + auto diff = Add(diff_minus_tendency, tendency); + auto out = + Add(avg, ShiftRight<1>( + Add(diff, BitCast(d, ShiftRight<31>(BitCast(du, diff)))))); + Store(out, d, p_out + x); + Store(Sub(out, diff), d, p_nout + x); + } +} + +#endif + +Status InvHSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { + JXL_ASSERT(c < input.channel.size()); + JXL_ASSERT(rc < input.channel.size()); + Channel &chin = input.channel[c]; + const Channel &chin_residual = input.channel[rc]; + // These must be valid since we ran MetaApply already. + JXL_ASSERT(chin.w == DivCeil(chin.w + chin_residual.w, 2)); + JXL_ASSERT(chin.h == chin_residual.h); + + if (chin_residual.w == 0) { + // Short-circuit: output channel has same dimensions as input. + input.channel[c].hshift--; + return true; + } + + // Note: chin.w >= chin_residual.w and at most 1 different. + Channel chout(chin.w + chin_residual.w, chin.h, chin.hshift - 1, chin.vshift); + JXL_DEBUG_V(4, + "Undoing horizontal squeeze of channel %i using residuals in " + "channel %i (going from width %" PRIuS " to %" PRIuS ")", + c, rc, chin.w, chout.w); + + if (chin_residual.h == 0) { + // Short-circuit: channel with no pixels. + input.channel[c] = std::move(chout); + return true; + } + auto unsqueeze_row = [&](size_t y, size_t x0) { + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y); + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y); + pixel_type *JXL_RESTRICT p_out = chout.Row(y); + for (size_t x = x0; x < chin_residual.w; x++) { + pixel_type_w diff_minus_tendency = p_residual[x]; + pixel_type_w avg = p_avg[x]; + pixel_type_w next_avg = (x + 1 < chin.w ? p_avg[x + 1] : avg); + pixel_type_w left = (x ? p_out[(x << 1) - 1] : avg); + pixel_type_w tendency = SmoothTendency(left, avg, next_avg); + pixel_type_w diff = diff_minus_tendency + tendency; + pixel_type_w A = avg + (diff / 2); + p_out[(x << 1)] = A; + pixel_type_w B = A - diff; + p_out[(x << 1) + 1] = B; + } + if (chout.w & 1) p_out[chout.w - 1] = p_avg[chin.w - 1]; + }; + + // somewhat complicated trickery just to be able to SIMD this. + // Horizontal unsqueeze has horizontal data dependencies, so we do + // 8 rows at a time and treat it as a vertical unsqueeze of a + // transposed 8x8 block (or 9x8 for one input). + static constexpr const size_t kRowsPerThread = 8; + const auto unsqueeze_span = [&](const uint32_t task, size_t /* thread */) { + const size_t y0 = task * kRowsPerThread; + const size_t rows = std::min(kRowsPerThread, chin.h - y0); + size_t x = 0; + +#if HWY_TARGET != HWY_SCALAR + intptr_t onerow_in = chin.plane.PixelsPerRow(); + intptr_t onerow_inr = chin_residual.plane.PixelsPerRow(); + intptr_t onerow_out = chout.plane.PixelsPerRow(); + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y0); + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y0); + pixel_type *JXL_RESTRICT p_out = chout.Row(y0); + HWY_ALIGN pixel_type b_p_avg[9 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_residual[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_even[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_odd[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_evenT[8 * kRowsPerThread]; + HWY_ALIGN pixel_type b_p_out_oddT[8 * kRowsPerThread]; + const HWY_CAPPED(pixel_type, 8) d; + const size_t N = Lanes(d); + if (chin_residual.w > 16 && rows == kRowsPerThread) { + for (; x < chin_residual.w - 9; x += 8) { + Transpose8x8Block(p_residual + x, b_p_residual, onerow_inr); + Transpose8x8Block(p_avg + x, b_p_avg, onerow_in); + for (size_t y = 0; y < kRowsPerThread; y++) { + b_p_avg[8 * 8 + y] = p_avg[x + 8 + onerow_in * y]; + } + for (size_t i = 0; i < 8; i++) { + FastUnsqueeze( + b_p_residual + 8 * i, b_p_avg + 8 * i, b_p_avg + 8 * (i + 1), + (x + i ? b_p_out_odd + 8 * ((x + i - 1) & 7) : b_p_avg + 8 * i), + b_p_out_even + 8 * i, b_p_out_odd + 8 * i); + } + + Transpose8x8Block(b_p_out_even, b_p_out_evenT, 8); + Transpose8x8Block(b_p_out_odd, b_p_out_oddT, 8); + for (size_t y = 0; y < kRowsPerThread; y++) { + for (size_t i = 0; i < kRowsPerThread; i += N) { + auto even = Load(d, b_p_out_evenT + 8 * y + i); + auto odd = Load(d, b_p_out_oddT + 8 * y + i); + StoreInterleaved(d, even, odd, + p_out + ((x + i) << 1) + onerow_out * y); + } + } + } + } +#endif + for (size_t y = 0; y < rows; y++) { + unsqueeze_row(y0 + y, x); + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.h, kRowsPerThread), + ThreadPool::NoInit, unsqueeze_span, + "InvHorizontalSqueeze")); + input.channel[c] = std::move(chout); + return true; +} + +Status InvVSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { + JXL_ASSERT(c < input.channel.size()); + JXL_ASSERT(rc < input.channel.size()); + const Channel &chin = input.channel[c]; + const Channel &chin_residual = input.channel[rc]; + // These must be valid since we ran MetaApply already. + JXL_ASSERT(chin.h == DivCeil(chin.h + chin_residual.h, 2)); + JXL_ASSERT(chin.w == chin_residual.w); + + if (chin_residual.h == 0) { + // Short-circuit: output channel has same dimensions as input. + input.channel[c].vshift--; + return true; + } + + // Note: chin.h >= chin_residual.h and at most 1 different. + Channel chout(chin.w, chin.h + chin_residual.h, chin.hshift, chin.vshift - 1); + JXL_DEBUG_V( + 4, + "Undoing vertical squeeze of channel %i using residuals in channel " + "%i (going from height %" PRIuS " to %" PRIuS ")", + c, rc, chin.h, chout.h); + + if (chin_residual.w == 0) { + // Short-circuit: channel with no pixels. + input.channel[c] = std::move(chout); + return true; + } + + static constexpr const int kColsPerThread = 64; + const auto unsqueeze_slice = [&](const uint32_t task, size_t /* thread */) { + const size_t x0 = task * kColsPerThread; + const size_t x1 = std::min((size_t)(task + 1) * kColsPerThread, chin.w); + const size_t w = x1 - x0; + // We only iterate up to std::min(chin_residual.h, chin.h) which is + // always chin_residual.h. + for (size_t y = 0; y < chin_residual.h; y++) { + const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y) + x0; + const pixel_type *JXL_RESTRICT p_avg = chin.Row(y) + x0; + const pixel_type *JXL_RESTRICT p_navg = + chin.Row(y + 1 < chin.h ? y + 1 : y) + x0; + pixel_type *JXL_RESTRICT p_out = chout.Row(y << 1) + x0; + pixel_type *JXL_RESTRICT p_nout = chout.Row((y << 1) + 1) + x0; + const pixel_type *p_pout = y > 0 ? chout.Row((y << 1) - 1) + x0 : p_avg; + size_t x = 0; +#if HWY_TARGET != HWY_SCALAR + for (; x + 7 < w; x += 8) { + FastUnsqueeze(p_residual + x, p_avg + x, p_navg + x, p_pout + x, + p_out + x, p_nout + x); + } +#endif + for (; x < w; x++) { + pixel_type_w avg = p_avg[x]; + pixel_type_w next_avg = p_navg[x]; + pixel_type_w top = p_pout[x]; + pixel_type_w tendency = SmoothTendency(top, avg, next_avg); + pixel_type_w diff_minus_tendency = p_residual[x]; + pixel_type_w diff = diff_minus_tendency + tendency; + pixel_type_w out = avg + (diff / 2); + p_out[x] = out; + // If the chin_residual.h == chin.h, the output has an even number + // of rows so the next line is fine. Otherwise, this loop won't + // write to the last output row which is handled separately. + p_nout[x] = out - diff; + } + } + }; + JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.w, kColsPerThread), + ThreadPool::NoInit, unsqueeze_slice, + "InvVertSqueeze")); + + if (chout.h & 1) { + size_t y = chin.h - 1; + const pixel_type *p_avg = chin.Row(y); + pixel_type *p_out = chout.Row(y << 1); + for (size_t x = 0; x < chin.w; x++) { + p_out[x] = p_avg[x]; + } + } + input.channel[c] = std::move(chout); + return true; +} + +Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool) { + for (int i = parameters.size() - 1; i >= 0; i--) { + JXL_RETURN_IF_ERROR( + CheckMetaSqueezeParams(parameters[i], input.channel.size())); + bool horizontal = parameters[i].horizontal; + bool in_place = parameters[i].in_place; + uint32_t beginc = parameters[i].begin_c; + uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1; + uint32_t offset; + if (in_place) { + offset = endc + 1; + } else { + offset = input.channel.size() + beginc - endc - 1; + } + if (beginc < input.nb_meta_channels) { + // This is checked in MetaSqueeze. + JXL_ASSERT(input.nb_meta_channels > parameters[i].num_c); + input.nb_meta_channels -= parameters[i].num_c; + } + + for (uint32_t c = beginc; c <= endc; c++) { + uint32_t rc = offset + c - beginc; + // MetaApply should imply that `rc` is within range, otherwise there's a + // programming bug. + JXL_ASSERT(rc < input.channel.size()); + if ((input.channel[c].w < input.channel[rc].w) || + (input.channel[c].h < input.channel[rc].h)) { + return JXL_FAILURE("Corrupted squeeze transform"); + } + if (horizontal) { + JXL_RETURN_IF_ERROR(InvHSqueeze(input, c, rc, pool)); + } else { + JXL_RETURN_IF_ERROR(InvVSqueeze(input, c, rc, pool)); + } + } + input.channel.erase(input.channel.begin() + offset, + input.channel.begin() + offset + (endc - beginc + 1)); + } + return true; +} + +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { + +HWY_EXPORT(InvSqueeze); +Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool) { + return HWY_DYNAMIC_DISPATCH(InvSqueeze)(input, parameters, pool); +} + +void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters, + const Image &image) { + int nb_channels = image.channel.size() - image.nb_meta_channels; + + parameters->clear(); + size_t w = image.channel[image.nb_meta_channels].w; + size_t h = image.channel[image.nb_meta_channels].h; + JXL_DEBUG_V( + 7, "Default squeeze parameters for %" PRIuS "x%" PRIuS " image: ", w, h); + + // do horizontal first on wide images; vertical first on tall images + bool wide = (w > h); + + if (nb_channels > 2 && image.channel[image.nb_meta_channels + 1].w == w && + image.channel[image.nb_meta_channels + 1].h == h) { + // assume channels 1 and 2 are chroma, and can be squeezed first for 4:2:0 + // previews + JXL_DEBUG_V(7, "(4:2:0 chroma), %" PRIuS "x%" PRIuS " image", w, h); + SqueezeParams params; + // horizontal chroma squeeze + params.horizontal = true; + params.in_place = false; + params.begin_c = image.nb_meta_channels + 1; + params.num_c = 2; + parameters->push_back(params); + params.horizontal = false; + // vertical chroma squeeze + parameters->push_back(params); + } + SqueezeParams params; + params.begin_c = image.nb_meta_channels; + params.num_c = nb_channels; + params.in_place = true; + + if (!wide) { + if (h > JXL_MAX_FIRST_PREVIEW_SIZE) { + params.horizontal = false; + parameters->push_back(params); + h = (h + 1) / 2; + JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); + } + } + while (w > JXL_MAX_FIRST_PREVIEW_SIZE || h > JXL_MAX_FIRST_PREVIEW_SIZE) { + if (w > JXL_MAX_FIRST_PREVIEW_SIZE) { + params.horizontal = true; + parameters->push_back(params); + w = (w + 1) / 2; + JXL_DEBUG_V(7, "Horizontal (%" PRIuS "x%" PRIuS "), ", w, h); + } + if (h > JXL_MAX_FIRST_PREVIEW_SIZE) { + params.horizontal = false; + parameters->push_back(params); + h = (h + 1) / 2; + JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); + } + } + JXL_DEBUG_V(7, "that's it"); +} + +Status CheckMetaSqueezeParams(const SqueezeParams ¶meter, + int num_channels) { + int c1 = parameter.begin_c; + int c2 = parameter.begin_c + parameter.num_c - 1; + if (c1 < 0 || c1 >= num_channels || c2 < 0 || c2 >= num_channels || c2 < c1) { + return JXL_FAILURE("Invalid channel range"); + } + return true; +} + +Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters) { + if (parameters->empty()) { + DefaultSqueezeParameters(parameters, image); + } + + for (size_t i = 0; i < parameters->size(); i++) { + JXL_RETURN_IF_ERROR( + CheckMetaSqueezeParams((*parameters)[i], image.channel.size())); + bool horizontal = (*parameters)[i].horizontal; + bool in_place = (*parameters)[i].in_place; + uint32_t beginc = (*parameters)[i].begin_c; + uint32_t endc = (*parameters)[i].begin_c + (*parameters)[i].num_c - 1; + + uint32_t offset; + if (beginc < image.nb_meta_channels) { + if (endc >= image.nb_meta_channels) { + return JXL_FAILURE("Invalid squeeze: mix of meta and nonmeta channels"); + } + if (!in_place) { + return JXL_FAILURE( + "Invalid squeeze: meta channels require in-place residuals"); + } + image.nb_meta_channels += (*parameters)[i].num_c; + } + if (in_place) { + offset = endc + 1; + } else { + offset = image.channel.size(); + } + for (uint32_t c = beginc; c <= endc; c++) { + if (image.channel[c].hshift > 30 || image.channel[c].vshift > 30) { + return JXL_FAILURE("Too many squeezes: shift > 30"); + } + size_t w = image.channel[c].w; + size_t h = image.channel[c].h; + if (w == 0 || h == 0) return JXL_FAILURE("Squeezing empty channel"); + if (horizontal) { + image.channel[c].w = (w + 1) / 2; + if (image.channel[c].hshift >= 0) image.channel[c].hshift++; + w = w - (w + 1) / 2; + } else { + image.channel[c].h = (h + 1) / 2; + if (image.channel[c].vshift >= 0) image.channel[c].vshift++; + h = h - (h + 1) / 2; + } + image.channel[c].shrink(); + Channel placeholder(w, h); + placeholder.hshift = image.channel[c].hshift; + placeholder.vshift = image.channel[c].vshift; + + image.channel.insert(image.channel.begin() + offset + (c - beginc), + std::move(placeholder)); + JXL_DEBUG_V(8, "MetaSqueeze applied, current image: %s", + image.DebugString().c_str()); + } + } + return true; +} + +} // namespace jxl + +#endif diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h new file mode 100644 index 0000000000..305a0ca3ec --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/squeeze.h @@ -0,0 +1,89 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_ +#define LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_ + +// Haar-like transform: halves the resolution in one direction +// A B -> (A+B)>>1 in one channel (average) -> same range as +// original channel +// A-B - tendency in a new channel ('residual' needed to make +// the transform reversible) +// -> theoretically range could be 2.5 +// times larger (2 times without the +// 'tendency'), but there should be lots +// of zeroes +// Repeated application (alternating horizontal and vertical squeezes) results +// in downscaling +// +// The default coefficient ordering is low-frequency to high-frequency, as in +// M. Antonini, M. Barlaud, P. Mathieu and I. Daubechies, "Image coding using +// wavelet transform", IEEE Transactions on Image Processing, vol. 1, no. 2, pp. +// 205-220, April 1992, doi: 10.1109/83.136597. + +#include <stdlib.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/transform.h" + +#define JXL_MAX_FIRST_PREVIEW_SIZE 8 + +namespace jxl { + +/* + int avg=(A+B)>>1; + int diff=(A-B); + int rA=(diff+(avg<<1)+(diff&1))>>1; + int rB=rA-diff; + +*/ +// |A B|C D|E F| +// p a n p=avg(A,B), a=avg(C,D), n=avg(E,F) +// +// Goal: estimate C-D (avoiding ringing artifacts) +// (ensuring that in smooth areas, a zero residual corresponds to a smooth +// gradient) + +// best estimate for C: (B + 2*a)/3 +// best estimate for D: (n + 3*a)/4 +// best estimate for C-D: 4*B - 3*n - a /12 + +// avoid ringing by 1) only doing this if B <= a <= n or B >= a >= n +// (otherwise, this is not a smooth area and we cannot really estimate C-D) +// 2) making sure that B <= C <= D <= n or B >= C >= D >= n + +inline pixel_type_w SmoothTendency(pixel_type_w B, pixel_type_w a, + pixel_type_w n) { + pixel_type_w diff = 0; + if (B >= a && a >= n) { + diff = (4 * B - 3 * n - a + 6) / 12; + // 2C = a<<1 + diff - diff&1 <= 2B so diff - diff&1 <= 2B - 2a + // 2D = a<<1 - diff - diff&1 >= 2n so diff + diff&1 <= 2a - 2n + if (diff - (diff & 1) > 2 * (B - a)) diff = 2 * (B - a) + 1; + if (diff + (diff & 1) > 2 * (a - n)) diff = 2 * (a - n); + } else if (B <= a && a <= n) { + diff = (4 * B - 3 * n - a - 6) / 12; + // 2C = a<<1 + diff + diff&1 >= 2B so diff + diff&1 >= 2B - 2a + // 2D = a<<1 - diff + diff&1 <= 2n so diff - diff&1 >= 2a - 2n + if (diff + (diff & 1) < 2 * (B - a)) diff = 2 * (B - a) - 1; + if (diff - (diff & 1) < 2 * (a - n)) diff = 2 * (a - n); + } + return diff; +} + +void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters, + const Image &image); + +Status CheckMetaSqueezeParams(const SqueezeParams ¶meter, int num_channels); + +Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters); + +Status InvSqueeze(Image &input, std::vector<SqueezeParams> parameters, + ThreadPool *pool); + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_SQUEEZE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc new file mode 100644 index 0000000000..33f7a10cc9 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.cc @@ -0,0 +1,100 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/modular/transform/transform.h" + +#include <cinttypes> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/transform/palette.h" +#include "lib/jxl/modular/transform/rct.h" +#include "lib/jxl/modular/transform/squeeze.h" + +namespace jxl { + +SqueezeParams::SqueezeParams() { Bundle::Init(this); } +Transform::Transform(TransformId id) { + Bundle::Init(this); + this->id = id; +} + +Status Transform::Inverse(Image &input, const weighted::Header &wp_header, + ThreadPool *pool) { + JXL_DEBUG_V(6, "Input channels (%" PRIuS ", %" PRIuS " meta): ", + input.channel.size(), input.nb_meta_channels); + switch (id) { + case TransformId::kRCT: + return InvRCT(input, begin_c, rct_type, pool); + case TransformId::kSqueeze: + return InvSqueeze(input, squeezes, pool); + case TransformId::kPalette: + return InvPalette(input, begin_c, nb_colors, nb_deltas, predictor, + wp_header, pool); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast<unsigned int>(id)); + } +} + +Status Transform::MetaApply(Image &input) { + JXL_DEBUG_V(6, "MetaApply input: %s", input.DebugString().c_str()); + switch (id) { + case TransformId::kRCT: + JXL_DEBUG_V(2, "Transform: kRCT, rct_type=%" PRIu32, rct_type); + return CheckEqualChannels(input, begin_c, begin_c + 2); + case TransformId::kSqueeze: + JXL_DEBUG_V(2, "Transform: kSqueeze:"); +#if JXL_DEBUG_V_LEVEL >= 2 + { + auto squeezes_copy = squeezes; + if (squeezes_copy.empty()) { + DefaultSqueezeParameters(&squeezes_copy, input); + } + for (const auto ¶ms : squeezes_copy) { + JXL_DEBUG_V( + 2, + " squeeze params: horizontal=%d, in_place=%d, begin_c=%" PRIu32 + ", num_c=%" PRIu32, + params.horizontal, params.in_place, params.begin_c, params.num_c); + } + } +#endif + return MetaSqueeze(input, &squeezes); + case TransformId::kPalette: + JXL_DEBUG_V(2, + "Transform: kPalette, begin_c=%" PRIu32 ", num_c=%" PRIu32 + ", nb_colors=%" PRIu32 ", nb_deltas=%" PRIu32, + begin_c, num_c, nb_colors, nb_deltas); + return MetaPalette(input, begin_c, begin_c + num_c - 1, nb_colors, + nb_deltas, lossy_palette); + default: + return JXL_FAILURE("Unknown transformation (ID=%u)", + static_cast<unsigned int>(id)); + } +} + +Status CheckEqualChannels(const Image &image, uint32_t c1, uint32_t c2) { + if (c1 > image.channel.size() || c2 >= image.channel.size() || c2 < c1) { + return JXL_FAILURE("Invalid channel range: %u..%u (there are only %" PRIuS + " channels)", + c1, c2, image.channel.size()); + } + if (c1 < image.nb_meta_channels && c2 >= image.nb_meta_channels) { + return JXL_FAILURE("Invalid: transforming mix of meta and nonmeta"); + } + const auto &ch1 = image.channel[c1]; + for (size_t c = c1 + 1; c <= c2; c++) { + const auto &ch2 = image.channel[c]; + if (ch1.w != ch2.w || ch1.h != ch2.h || ch1.hshift != ch2.hshift || + ch1.vshift != ch2.vshift) { + return false; + } + } + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h new file mode 100644 index 0000000000..d5d3259f7a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular/transform/transform.h @@ -0,0 +1,148 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ +#define LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ + +#include <cstdint> +#include <string> +#include <vector> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/modular/encoding/context_predict.h" +#include "lib/jxl/modular/options.h" + +namespace jxl { + +enum class TransformId : uint32_t { + // G, R-G, B-G and variants (including YCoCg). + kRCT = 0, + + // Color palette. Parameters are: [begin_c] [end_c] [nb_colors] + kPalette = 1, + + // Squeezing (Haar-style) + kSqueeze = 2, + + // Invalid for now. + kInvalid = 3, +}; + +struct SqueezeParams : public Fields { + JXL_FIELDS_NAME(SqueezeParams) + bool horizontal; + bool in_place; + uint32_t begin_c; + uint32_t num_c; + SqueezeParams(); + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &horizontal)); + JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &in_place)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Bits(3), BitsOffset(6, 8), + BitsOffset(10, 72), + BitsOffset(13, 1096), 0, &begin_c)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(2), Val(3), BitsOffset(4, 4), 2, &num_c)); + return true; + } +}; + +class Transform : public Fields { + public: + TransformId id; + // for Palette and RCT. + uint32_t begin_c; + // for RCT. 42 possible values starting from 0. + uint32_t rct_type; + // Only for Palette and NearLossless. + uint32_t num_c; + // Only for Palette. + uint32_t nb_colors; + uint32_t nb_deltas; + // for Squeeze. Default squeeze if empty. + std::vector<SqueezeParams> squeezes; + // for NearLossless, not serialized. + int max_delta_error; + // Serialized for Palette. + Predictor predictor; + // for Palette, not serialized. + bool ordered_palette = true; + bool lossy_palette = false; + + explicit Transform(TransformId id); + // default constructor for bundles. + Transform() : Transform(TransformId::kInvalid) {} + + Status VisitFields(Visitor *JXL_RESTRICT visitor) override { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + Val((uint32_t)TransformId::kRCT), Val((uint32_t)TransformId::kPalette), + Val((uint32_t)TransformId::kSqueeze), + Val((uint32_t)TransformId::kInvalid), (uint32_t)TransformId::kRCT, + reinterpret_cast<uint32_t *>(&id))); + if (id == TransformId::kInvalid) { + return JXL_FAILURE("Invalid transform ID"); + } + if (visitor->Conditional(id == TransformId::kRCT || + id == TransformId::kPalette)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Bits(3), BitsOffset(6, 8), BitsOffset(10, 72), + BitsOffset(13, 1096), 0, &begin_c)); + } + if (visitor->Conditional(id == TransformId::kRCT)) { + // 0-41, default YCoCg. + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(6), Bits(2), BitsOffset(4, 2), + BitsOffset(6, 10), 6, &rct_type)); + if (rct_type >= 42) { + return JXL_FAILURE("Invalid transform RCT type"); + } + } + if (visitor->Conditional(id == TransformId::kPalette)) { + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(1), Val(3), Val(4), BitsOffset(13, 1), 3, &num_c)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + BitsOffset(8, 0), BitsOffset(10, 256), BitsOffset(12, 1280), + BitsOffset(16, 5376), 256, &nb_colors)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), BitsOffset(8, 1), BitsOffset(10, 257), + BitsOffset(16, 1281), 0, &nb_deltas)); + JXL_QUIET_RETURN_IF_ERROR( + visitor->Bits(4, (uint32_t)Predictor::Zero, + reinterpret_cast<uint32_t *>(&predictor))); + if (predictor >= Predictor::Best) { + return JXL_FAILURE("Invalid predictor"); + } + } + + if (visitor->Conditional(id == TransformId::kSqueeze)) { + uint32_t num_squeezes = static_cast<uint32_t>(squeezes.size()); + JXL_QUIET_RETURN_IF_ERROR( + visitor->U32(Val(0), BitsOffset(4, 1), BitsOffset(6, 9), + BitsOffset(8, 41), 0, &num_squeezes)); + if (visitor->IsReading()) squeezes.resize(num_squeezes); + for (size_t i = 0; i < num_squeezes; i++) { + JXL_QUIET_RETURN_IF_ERROR(visitor->VisitNested(&squeezes[i])); + } + } + return true; + } + + JXL_FIELDS_NAME(Transform) + + Status Inverse(Image &input, const weighted::Header &wp_header, + ThreadPool *pool = nullptr); + Status MetaApply(Image &input); +}; + +Status CheckEqualChannels(const Image &image, uint32_t c1, uint32_t c2); + +static inline pixel_type PixelAdd(pixel_type a, pixel_type b) { + return static_cast<pixel_type>(static_cast<uint32_t>(a) + + static_cast<uint32_t>(b)); +} + +} // namespace jxl + +#endif // LIB_JXL_MODULAR_TRANSFORM_TRANSFORM_H_ diff --git a/third_party/jpeg-xl/lib/jxl/modular_test.cc b/third_party/jpeg-xl/lib/jxl/modular_test.cc new file mode 100644 index 0000000000..689063ce95 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/modular_test.cc @@ -0,0 +1,525 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> +#include <jxl/encode.h> +#include <jxl/types.h> + +#include <cstddef> +#include <cstdint> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include "lib/extras/codec.h" +#include "lib/extras/dec/jxl.h" +#include "lib/extras/enc/jxl.h" +#include "lib/extras/metrics.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_toc.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/modular/encoding/enc_encoding.h" +#include "lib/jxl/modular/encoding/encoding.h" +#include "lib/jxl/modular/modular_image.h" +#include "lib/jxl/modular/options.h" +#include "lib/jxl/modular/transform/transform.h" +#include "lib/jxl/padded_bytes.h" +#include "lib/jxl/test_image.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +using test::ReadTestData; +using test::Roundtrip; +using test::TestImage; + +void TestLosslessGroups(size_t group_size_shift) { + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + t.SetDimensions(t.ppf().xsize() / 4, t.ppf().ysize() / 4); + + extras::JXLCompressParams cparams; + cparams.distance = 0.0f; + cparams.AddOption(JXL_ENC_FRAME_SETTING_MODULAR_GROUP_SIZE, group_size_shift); + extras::JXLDecompressParams dparams; + dparams.accepted_formats = {{3, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}}; + + extras::PackedPixelFile ppf_out; + size_t compressed_size = + Roundtrip(t.ppf(), cparams, dparams, nullptr, &ppf_out); + EXPECT_LE(compressed_size, 280000u); + EXPECT_EQ(0.0f, test::ComputeDistance2(t.ppf(), ppf_out)); +} + +TEST(ModularTest, RoundtripLosslessGroups128) { TestLosslessGroups(0); } + +TEST(ModularTest, JXL_TSAN_SLOW_TEST(RoundtripLosslessGroups512)) { + TestLosslessGroups(2); +} + +TEST(ModularTest, JXL_TSAN_SLOW_TEST(RoundtripLosslessGroups1024)) { + TestLosslessGroups(3); +} + +TEST(ModularTest, RoundtripLosslessCustomWP_PermuteRCT) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + t.SetDimensions(100, 100); + + extras::JXLCompressParams cparams; + cparams.distance = 0.0f; + // 9 = permute to GBR, to test the special case of permutation-only + cparams.AddOption(JXL_ENC_FRAME_SETTING_MODULAR_COLOR_SPACE, 9); + cparams.AddOption(JXL_ENC_FRAME_SETTING_MODULAR_PREDICTOR, + static_cast<int64_t>(Predictor::Weighted)); + // slowest speed so different WP modes are tried + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, 9); + extras::JXLDecompressParams dparams; + dparams.accepted_formats = {{3, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}}; + + extras::PackedPixelFile ppf_out; + size_t compressed_size = + Roundtrip(t.ppf(), cparams, dparams, nullptr, &ppf_out); + EXPECT_LE(compressed_size, 10169u); + EXPECT_EQ(0.0f, test::ComputeDistance2(t.ppf(), ppf_out)); +} + +TEST(ModularTest, RoundtripLossyDeltaPalette) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.lossy_palette = true; + cparams.palette_colors = 0; + + CodecInOut io_out; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + io.ShrinkTo(300, 100); + + size_t compressed_size; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io_out, _, &compressed_size)); + EXPECT_LE(compressed_size, 6800u); + EXPECT_THAT(ButteraugliDistance(io.frames, io_out.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(1.5)); +} +TEST(ModularTest, RoundtripLossyDeltaPaletteWP) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.SetLossless(); + cparams.lossy_palette = true; + cparams.palette_colors = 0; + cparams.options.predictor = jxl::Predictor::Weighted; + + CodecInOut io_out; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + io.ShrinkTo(300, 100); + + size_t compressed_size; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io_out, _, &compressed_size)); + EXPECT_LE(compressed_size, 7000u); + EXPECT_THAT(ButteraugliDistance(io.frames, io_out.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(10.1)); +} + +TEST(ModularTest, RoundtripLossy) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CompressParams cparams; + cparams.modular_mode = true; + cparams.butteraugli_distance = 2.f; + cparams.SetCms(*JxlGetDefaultCms()); + + CodecInOut io_out; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + + size_t compressed_size; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io_out, _, &compressed_size)); + EXPECT_LE(compressed_size, 30000u); + EXPECT_THAT(ButteraugliDistance(io.frames, io_out.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(2.3)); +} + +TEST(ModularTest, RoundtripLossy16) { + const std::vector<uint8_t> orig = + ReadTestData("external/raw.pixls/DJI-FC6310-16bit_709_v4_krita.png"); + CompressParams cparams; + cparams.modular_mode = true; + cparams.butteraugli_distance = 2.f; + + CodecInOut io_out; + + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + JXL_CHECK(!io.metadata.m.have_preview); + JXL_CHECK(io.frames.size() == 1); + JXL_CHECK( + io.frames[0].TransformTo(ColorEncoding::SRGB(), *JxlGetDefaultCms())); + io.metadata.m.color_encoding = ColorEncoding::SRGB(); + + size_t compressed_size; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io_out, _, &compressed_size)); + EXPECT_LE(compressed_size, 300u); + EXPECT_THAT(ButteraugliDistance(io.frames, io_out.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(1.6)); +} + +TEST(ModularTest, RoundtripExtraProperties) { + constexpr size_t kSize = 250; + Image image(kSize, kSize, /*bitdepth=*/8, 3); + ModularOptions options; + options.max_properties = 4; + options.predictor = Predictor::Zero; + Rng rng(0); + for (size_t y = 0; y < kSize; y++) { + for (size_t x = 0; x < kSize; x++) { + image.channel[0].plane.Row(y)[x] = image.channel[2].plane.Row(y)[x] = + rng.UniformU(0, 9); + } + } + ZeroFillImage(&image.channel[1].plane); + BitWriter writer; + ASSERT_TRUE(ModularGenericCompress(image, options, &writer)); + writer.ZeroPadToByte(); + Image decoded(kSize, kSize, /*bitdepth=*/8, image.channel.size()); + for (size_t i = 0; i < image.channel.size(); i++) { + const Channel& ch = image.channel[i]; + decoded.channel[i] = Channel(ch.w, ch.h, ch.hshift, ch.vshift); + } + Status status = true; + { + BitReader reader(writer.GetSpan()); + BitReaderScopedCloser closer(&reader, &status); + ASSERT_TRUE(ModularGenericDecompress(&reader, decoded, /*header=*/nullptr, + /*group_id=*/0, &options)); + } + ASSERT_TRUE(status); + ASSERT_EQ(image.channel.size(), decoded.channel.size()); + for (size_t c = 0; c < image.channel.size(); c++) { + for (size_t y = 0; y < image.channel[c].plane.ysize(); y++) { + for (size_t x = 0; x < image.channel[c].plane.xsize(); x++) { + EXPECT_EQ(image.channel[c].plane.Row(y)[x], + decoded.channel[c].plane.Row(y)[x]) + << "c = " << c << ", x = " << x << ", y = " << y; + } + } + } +} + +struct RoundtripLosslessConfig { + int bitdepth; + int responsive; +}; +class ModularTestParam + : public ::testing::TestWithParam<RoundtripLosslessConfig> {}; + +std::vector<RoundtripLosslessConfig> GenerateLosslessTests() { + std::vector<RoundtripLosslessConfig> all; + for (int responsive = 0; responsive <= 1; responsive++) { + for (int bitdepth = 1; bitdepth < 32; bitdepth++) { + if (responsive && bitdepth > 30) continue; + all.push_back({bitdepth, responsive}); + } + } + return all; +} +std::string LosslessTestDescription( + const testing::TestParamInfo<ModularTestParam::ParamType>& info) { + std::stringstream name; + name << info.param.bitdepth << "bit"; + if (info.param.responsive) name << "Squeeze"; + return name.str(); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(RoundtripLossless, ModularTestParam, + testing::ValuesIn(GenerateLosslessTests()), + LosslessTestDescription); + +TEST_P(ModularTestParam, RoundtripLossless) { + RoundtripLosslessConfig config = GetParam(); + int bitdepth = config.bitdepth; + int responsive = config.responsive; + + ThreadPool* pool = nullptr; + Rng generator(123); + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io1; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io1, pool)); + + // vary the dimensions a bit, in case of bugs related to + // even vs odd width or height. + size_t xsize = 423 + bitdepth; + size_t ysize = 467 + bitdepth; + + CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB(false); + io.metadata.m.SetUintSamples(bitdepth); + + double factor = ((1lu << bitdepth) - 1lu); + double ifactor = 1.0 / factor; + Image3F noise_added(xsize, ysize); + + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ysize; y++) { + const float* in = io1.Main().color()->PlaneRow(c, y); + float* out = noise_added.PlaneRow(c, y); + for (size_t x = 0; x < xsize; x++) { + // make the least significant bits random + float f = in[x] + generator.UniformF(0.0f, 1.f / 255.f); + if (f > 1.f) f = 1.f; + // quantize to the bitdepth we're testing + unsigned int u = f * factor + 0.5; + out[x] = u * ifactor; + } + } + } + io.SetFromImage(std::move(noise_added), jxl::ColorEncoding::SRGB(false)); + + CompressParams cparams; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.butteraugli_distance = 0.f; + cparams.options.predictor = {Predictor::Zero}; + cparams.speed_tier = SpeedTier::kThunder; + cparams.responsive = responsive; + CodecInOut io2; + size_t compressed_size; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _, &compressed_size)); + EXPECT_LE(compressed_size, bitdepth * xsize * ysize / 3); + EXPECT_LE(0, ComputeDistance2(io.Main(), io2.Main(), *JxlGetDefaultCms())); + size_t different = 0; + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ysize; y++) { + const float* in = io.Main().color()->PlaneRow(c, y); + const float* out = io2.Main().color()->PlaneRow(c, y); + for (size_t x = 0; x < xsize; x++) { + uint32_t uin = in[x] * factor + 0.5; + uint32_t uout = out[x] * factor + 0.5; + // check that the integer values are identical + if (uin != uout) different++; + } + } + } + EXPECT_EQ(different, 0); +} + +TEST(ModularTest, RoundtripLosslessCustomFloat) { + CodecInOut io; + size_t xsize = 100, ysize = 300; + io.SetSize(xsize, ysize); + io.metadata.m.bit_depth.bits_per_sample = 18; + io.metadata.m.bit_depth.exponent_bits_per_sample = 6; + io.metadata.m.bit_depth.floating_point_sample = true; + io.metadata.m.modular_16_bit_buffer_sufficient = false; + ColorEncoding color_encoding; + color_encoding.Tf().SetTransferFunction(TransferFunction::kLinear); + color_encoding.SetColorSpace(ColorSpace::kRGB); + Image3F testimage(xsize, ysize); + float factor = 1.f / (1 << 14); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ysize; y++) { + float* const JXL_RESTRICT row = testimage.PlaneRow(c, y); + for (size_t x = 0; x < xsize; x++) { + row[x] = factor * (x ^ y); + } + } + } + io.SetFromImage(std::move(testimage), color_encoding); + io.metadata.m.color_encoding = color_encoding; + io.metadata.m.SetIntensityTarget(255); + + CompressParams cparams; + cparams.modular_mode = true; + cparams.color_transform = jxl::ColorTransform::kNone; + cparams.butteraugli_distance = 0.f; + cparams.options.predictor = {Predictor::Zero}; + cparams.speed_tier = SpeedTier::kThunder; + cparams.decoding_speed_tier = 2; + + CodecInOut io2; + size_t compressed_size; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _, &compressed_size)); + EXPECT_LE(compressed_size, 23000u); + JXL_EXPECT_OK(SamePixels(*io.Main().color(), *io2.Main().color(), _)); +} + +void WriteHeaders(BitWriter* writer, size_t xsize, size_t ysize) { + BitWriter::Allotment allotment(writer, 16); + writer->Write(8, 0xFF); + writer->Write(8, kCodestreamMarker); + allotment.ReclaimAndCharge(writer, 0, nullptr); + CodecMetadata metadata; + EXPECT_TRUE(metadata.size.Set(xsize, ysize)); + EXPECT_TRUE(WriteSizeHeader(metadata.size, writer, 0, nullptr)); + metadata.m.color_encoding = ColorEncoding::LinearSRGB(/*is_gray=*/true); + metadata.m.xyb_encoded = false; + metadata.m.SetUintSamples(31); + EXPECT_TRUE(WriteImageMetadata(metadata.m, writer, 0, nullptr)); + metadata.transform_data.nonserialized_xyb_encoded = metadata.m.xyb_encoded; + EXPECT_TRUE(Bundle::Write(metadata.transform_data, writer, 0, nullptr)); + writer->ZeroPadToByte(); + FrameHeader frame_header(&metadata); + frame_header.encoding = FrameEncoding::kModular; + frame_header.loop_filter.gab = false; + frame_header.loop_filter.epf_iters = 0; + EXPECT_TRUE(WriteFrameHeader(frame_header, writer, nullptr)); +} + +// Tree with single node, zero predictor, offset is 1 and multiplier is 1, +// entropy code is prefix tree with alphabet size 256 and all bits lengths 8. +void WriteHistograms(BitWriter* writer) { + writer->Write(1, 1); // default DC quant + writer->Write(1, 1); // has_tree + // tree histograms + writer->Write(1, 0); // LZ77 disabled + writer->Write(3, 1); // simple context map + writer->Write(1, 1); // prefix code + writer->Write(7, 0x63); // UnintConfig(3, 2, 1) + writer->Write(12, 0xfef); // alphabet_size = 256 + writer->Write(32, 0x10003); // all bit lengths 8 + // tree tokens + writer->Write(8, 0); // tree leaf + writer->Write(8, 0); // zero predictor + writer->Write(8, 64); // offset = UnpackSigned(ReverseBits(64)) = 1 + writer->Write(16, 0); // multiplier = 1 + // histograms + writer->Write(1, 0); // LZ77 disabled + writer->Write(1, 1); // prefix code + writer->Write(7, 0x63); // UnintConfig(3, 2, 1) + writer->Write(12, 0xfef); // alphabet_size = 256 + writer->Write(32, 0x10003); // all bit lengths 8 +} + +TEST(ModularTest, PredictorIntegerOverflow) { + const size_t xsize = 1; + const size_t ysize = 1; + BitWriter writer; + WriteHeaders(&writer, xsize, ysize); + std::vector<BitWriter> group_codes(1); + { + BitWriter* bw = &group_codes[0]; + BitWriter::Allotment allotment(bw, 1 << 20); + WriteHistograms(bw); + GroupHeader header; + header.use_global_tree = true; + EXPECT_TRUE(Bundle::Write(header, bw, 0, nullptr)); + // After UnpackSigned this becomes (1 << 31) - 1, the largest pixel_type, + // and after adding the offset we get -(1 << 31). + bw->Write(8, 119); + bw->Write(28, 0xfffffff); + bw->ZeroPadToByte(); + allotment.ReclaimAndCharge(bw, 0, nullptr); + } + EXPECT_TRUE(WriteGroupOffsets(group_codes, {}, &writer, nullptr)); + writer.AppendByteAligned(group_codes); + + PaddedBytes compressed = std::move(writer).TakeBytes(); + extras::PackedPixelFile ppf; + extras::JXLDecompressParams params; + params.accepted_formats.push_back({1, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}); + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), params, + nullptr, &ppf)); + ASSERT_EQ(1, ppf.frames.size()); + const auto& img = ppf.frames[0].color; + const auto pixels = reinterpret_cast<const float*>(img.pixels()); + EXPECT_EQ(-1.0f, pixels[0]); +} + +TEST(ModularTest, UnsqueezeIntegerOverflow) { + // Image width is 9 so we can test both the SIMD and non-vector code paths. + const size_t xsize = 9; + const size_t ysize = 2; + BitWriter writer; + WriteHeaders(&writer, xsize, ysize); + std::vector<BitWriter> group_codes(1); + { + BitWriter* bw = &group_codes[0]; + BitWriter::Allotment allotment(bw, 1 << 20); + WriteHistograms(bw); + GroupHeader header; + header.use_global_tree = true; + header.transforms.emplace_back(); + header.transforms[0].id = TransformId::kSqueeze; + SqueezeParams params; + params.horizontal = false; + params.in_place = true; + params.begin_c = 0; + params.num_c = 1; + header.transforms[0].squeezes.emplace_back(params); + EXPECT_TRUE(Bundle::Write(header, bw, 0, nullptr)); + for (size_t i = 0; i < xsize * ysize; ++i) { + // After UnpackSigned and adding offset, this becomes (1 << 31) - 1, both + // in the image and in the residual channels, and unsqueeze makes them + // ~(3 << 30) and (1 << 30) (in pixel_type_w) and the first wraps around + // to about -(1 << 30). + bw->Write(8, 119); + bw->Write(28, 0xffffffe); + } + bw->ZeroPadToByte(); + allotment.ReclaimAndCharge(bw, 0, nullptr); + } + EXPECT_TRUE(WriteGroupOffsets(group_codes, {}, &writer, nullptr)); + writer.AppendByteAligned(group_codes); + + PaddedBytes compressed = std::move(writer).TakeBytes(); + extras::PackedPixelFile ppf; + extras::JXLDecompressParams params; + params.accepted_formats.push_back({1, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}); + EXPECT_TRUE(DecodeImageJXL(compressed.data(), compressed.size(), params, + nullptr, &ppf)); + ASSERT_EQ(1, ppf.frames.size()); + const auto& img = ppf.frames[0].color; + const auto pixels = reinterpret_cast<const float*>(img.pixels()); + for (size_t x = 0; x < xsize; ++x) { + EXPECT_NEAR(-0.5f, pixels[x], 1e-10); + EXPECT_NEAR(0.5f, pixels[xsize + x], 1e-10); + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/noise.h b/third_party/jpeg-xl/lib/jxl/noise.h new file mode 100644 index 0000000000..585fab0d42 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/noise.h @@ -0,0 +1,60 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_NOISE_H_ +#define LIB_JXL_NOISE_H_ + +// Noise parameters shared by encoder/decoder. + +#include <stddef.h> + +#include <algorithm> +#include <cmath> +#include <utility> + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +const float kNoisePrecision = 1 << 10; + +struct NoiseParams { + // LUT index is an intensity of pixel / mean intensity of patch + static constexpr size_t kNumNoisePoints = 8; + float lut[kNumNoisePoints]; + + void Clear() { + for (float& i : lut) i = 0.f; + } + bool HasAny() const { + for (float i : lut) { + if (std::abs(i) > 1e-3f) return true; + } + return false; + } +}; + +static inline std::pair<int, float> IndexAndFrac(float x) { + constexpr size_t kScaleNumerator = NoiseParams::kNumNoisePoints - 2; + // TODO(user): instead of 1, this should be a proper Y range. + constexpr float kScale = kScaleNumerator / 1; + float scaled_x = std::max(0.f, x * kScale); + float floor_x; + float frac_x = std::modf(scaled_x, &floor_x); + if (JXL_UNLIKELY(scaled_x >= kScaleNumerator + 1)) { + floor_x = kScaleNumerator; + frac_x = 1.f; + } + return std::make_pair(static_cast<int>(floor_x), frac_x); +} + +struct NoiseLevel { + float noise_level; + float intensity; +}; + +} // namespace jxl + +#endif // LIB_JXL_NOISE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/opsin_image_test.cc b/third_party/jpeg-xl/lib/jxl/opsin_image_test.cc new file mode 100644 index 0000000000..f7842c32e4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/opsin_image_test.cc @@ -0,0 +1,126 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> + +#include <cstddef> +#include <utility> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/matrix_ops.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +// Convert a single linear sRGB color to xyb, using the exact image conversion +// procedure that jpeg xl uses. +void LinearSrgbToOpsin(float rgb_r, float rgb_g, float rgb_b, + float* JXL_RESTRICT xyb_x, float* JXL_RESTRICT xyb_y, + float* JXL_RESTRICT xyb_b) { + Image3F linear(1, 1); + linear.PlaneRow(0, 0)[0] = rgb_r; + linear.PlaneRow(1, 0)[0] = rgb_g; + linear.PlaneRow(2, 0)[0] = rgb_b; + + ImageMetadata metadata; + metadata.SetFloat32Samples(); + metadata.color_encoding = ColorEncoding::LinearSRGB(); + ImageBundle ib(&metadata); + ib.SetFromImage(std::move(linear), metadata.color_encoding); + Image3F opsin(1, 1); + (void)ToXYB(ib, /*pool=*/nullptr, &opsin, *JxlGetDefaultCms()); + + *xyb_x = opsin.PlaneRow(0, 0)[0]; + *xyb_y = opsin.PlaneRow(1, 0)[0]; + *xyb_b = opsin.PlaneRow(2, 0)[0]; +} + +// Convert a single XYB color to linear sRGB, using the exact image conversion +// procedure that jpeg xl uses. +void OpsinToLinearSrgb(float xyb_x, float xyb_y, float xyb_b, + float* JXL_RESTRICT rgb_r, float* JXL_RESTRICT rgb_g, + float* JXL_RESTRICT rgb_b) { + Image3F opsin(1, 1); + opsin.PlaneRow(0, 0)[0] = xyb_x; + opsin.PlaneRow(1, 0)[0] = xyb_y; + opsin.PlaneRow(2, 0)[0] = xyb_b; + Image3F linear(1, 1); + OpsinParams opsin_params; + opsin_params.Init(/*intensity_target=*/255.0f); + OpsinToLinear(opsin, Rect(opsin), nullptr, &linear, opsin_params); + *rgb_r = linear.PlaneRow(0, 0)[0]; + *rgb_g = linear.PlaneRow(1, 0)[0]; + *rgb_b = linear.PlaneRow(2, 0)[0]; +} + +void OpsinRoundtripTestRGB(float r, float g, float b) { + float xyb_x, xyb_y, xyb_b; + LinearSrgbToOpsin(r, g, b, &xyb_x, &xyb_y, &xyb_b); + float r2, g2, b2; + OpsinToLinearSrgb(xyb_x, xyb_y, xyb_b, &r2, &g2, &b2); + EXPECT_NEAR(r, r2, 1e-3); + EXPECT_NEAR(g, g2, 1e-3); + EXPECT_NEAR(b, b2, 1e-3); +} + +TEST(OpsinImageTest, VerifyOpsinAbsorbanceInverseMatrix) { + float matrix[9]; // writable copy + for (int i = 0; i < 9; i++) { + matrix[i] = GetOpsinAbsorbanceInverseMatrix()[i]; + } + EXPECT_TRUE(Inv3x3Matrix(matrix)); + for (int i = 0; i < 9; i++) { + EXPECT_NEAR(matrix[i], jxl::cms::kOpsinAbsorbanceMatrix[i], 1e-6); + } +} + +TEST(OpsinImageTest, OpsinRoundtrip) { + OpsinRoundtripTestRGB(0, 0, 0); + OpsinRoundtripTestRGB(1. / 255, 1. / 255, 1. / 255); + OpsinRoundtripTestRGB(128. / 255, 128. / 255, 128. / 255); + OpsinRoundtripTestRGB(1, 1, 1); + + OpsinRoundtripTestRGB(0, 0, 1. / 255); + OpsinRoundtripTestRGB(0, 0, 128. / 255); + OpsinRoundtripTestRGB(0, 0, 1); + + OpsinRoundtripTestRGB(0, 1. / 255, 0); + OpsinRoundtripTestRGB(0, 128. / 255, 0); + OpsinRoundtripTestRGB(0, 1, 0); + + OpsinRoundtripTestRGB(1. / 255, 0, 0); + OpsinRoundtripTestRGB(128. / 255, 0, 0); + OpsinRoundtripTestRGB(1, 0, 0); +} + +TEST(OpsinImageTest, VerifyZero) { + // Test that black color (zero energy) is 0,0,0 in xyb. + float x, y, b; + LinearSrgbToOpsin(0, 0, 0, &x, &y, &b); + EXPECT_NEAR(0, x, 1e-9); + EXPECT_NEAR(0, y, 1e-7); + EXPECT_NEAR(0, b, 1e-7); +} + +TEST(OpsinImageTest, VerifyGray) { + // Test that grayscale colors have a fixed y/b ratio and x==0. + for (size_t i = 1; i < 255; i++) { + float x, y, b; + LinearSrgbToOpsin(i / 255., i / 255., i / 255., &x, &y, &b); + EXPECT_NEAR(0, x, 1e-6); + EXPECT_NEAR(jxl::cms::kYToBRatio, b / y, 3e-5); + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/opsin_inverse_test.cc b/third_party/jpeg-xl/lib/jxl/opsin_inverse_test.cc new file mode 100644 index 0000000000..b8c151fbea --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/opsin_inverse_test.cc @@ -0,0 +1,61 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> + +#include <utility> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/enc_xyb.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +TEST(OpsinInverseTest, LinearInverseInverts) { + Image3F linear(128, 128); + RandomFillImage(&linear, 0.0f, 1.0f); + + CodecInOut io; + io.metadata.m.SetFloat32Samples(); + io.metadata.m.color_encoding = ColorEncoding::LinearSRGB(); + Image3F linear2(128, 128); + CopyImageTo(linear, &linear2); + io.SetFromImage(std::move(linear2), io.metadata.m.color_encoding); + ThreadPool* null_pool = nullptr; + Image3F opsin(io.xsize(), io.ysize()); + (void)ToXYB(io.Main(), null_pool, &opsin, *JxlGetDefaultCms()); + + OpsinParams opsin_params; + opsin_params.Init(/*intensity_target=*/255.0f); + OpsinToLinearInplace(&opsin, /*pool=*/nullptr, opsin_params); + + JXL_ASSERT_OK(VerifyRelativeError(linear, opsin, 3E-3, 2E-4, _)); +} + +TEST(OpsinInverseTest, YcbCrInverts) { + Image3F rgb(128, 128); + RandomFillImage(&rgb, 0.0f, 1.0f); + + ThreadPool* null_pool = nullptr; + Image3F ycbcr(rgb.xsize(), rgb.ysize()); + EXPECT_TRUE(RgbToYcbcr(rgb.Plane(0), rgb.Plane(1), rgb.Plane(2), + &ycbcr.Plane(1), &ycbcr.Plane(0), &ycbcr.Plane(2), + null_pool)); + + Image3F rgb2(rgb.xsize(), rgb.ysize()); + YcbcrToRgb(ycbcr, &rgb2, Rect(rgb)); + + JXL_ASSERT_OK(VerifyRelativeError(rgb, rgb2, 4E-5, 4E-7, _)); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/opsin_params.cc b/third_party/jpeg-xl/lib/jxl/opsin_params.cc new file mode 100644 index 0000000000..e1fdda5322 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/opsin_params.cc @@ -0,0 +1,46 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/opsin_params.h" + +#include "lib/jxl/cms/opsin_params.h" + +#define INVERSE_OPSIN_FROM_SPEC 1 + +#if not(INVERSE_OPSIN_FROM_SPEC) +#include "lib/jxl/base/matrix_ops.h" +#endif + +namespace jxl { + +const float* GetOpsinAbsorbanceInverseMatrix() { +#if INVERSE_OPSIN_FROM_SPEC + return jxl::cms::DefaultInverseOpsinAbsorbanceMatrix(); +#else // INVERSE_OPSIN_FROM_SPEC + // Compute the inverse opsin matrix from the forward matrix. Less precise + // than taking the values from the specification, but must be used if the + // forward transform is changed and the spec will require updating. + static const float* const kInverse = [] { + static float inverse[9]; + for (int i = 0; i < 9; i++) { + inverse[i] = kOpsinAbsorbanceMatrix[i]; + } + Inv3x3Matrix(inverse); + return inverse; + }(); + return kInverse; +#endif // INVERSE_OPSIN_FROM_SPEC +} + +void InitSIMDInverseMatrix(const float* JXL_RESTRICT inverse, + float* JXL_RESTRICT simd_inverse, + float intensity_target) { + for (size_t i = 0; i < 9; ++i) { + simd_inverse[4 * i] = simd_inverse[4 * i + 1] = simd_inverse[4 * i + 2] = + simd_inverse[4 * i + 3] = inverse[i] * (255.0f / intensity_target); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/opsin_params.h b/third_party/jpeg-xl/lib/jxl/opsin_params.h new file mode 100644 index 0000000000..fc285ac208 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/opsin_params.h @@ -0,0 +1,25 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_OPSIN_PARAMS_H_ +#define LIB_JXL_OPSIN_PARAMS_H_ + +// Constants that define the XYB color space. + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { + +// Returns 3x3 row-major matrix inverse of kOpsinAbsorbanceMatrix. +// opsin_image_test verifies this is actually the inverse. +const float* GetOpsinAbsorbanceInverseMatrix(); + +void InitSIMDInverseMatrix(const float* JXL_RESTRICT inverse, + float* JXL_RESTRICT simd_inverse, + float intensity_target); + +} // namespace jxl + +#endif // LIB_JXL_OPSIN_PARAMS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/pack_signed.h b/third_party/jpeg-xl/lib/jxl/pack_signed.h new file mode 100644 index 0000000000..326f06e6f8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/pack_signed.h @@ -0,0 +1,34 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_PACK_H_ +#define LIB_JXL_PACK_H_ + +// Pack/UnpackSigned utilities. + +#include <cstddef> +#include <cstdint> + +#include "lib/jxl/base/compiler_specific.h" + +namespace jxl { +// Encodes non-negative (X) into (2 * X), negative (-X) into (2 * X - 1) +constexpr uint32_t PackSigned(int32_t value) + JXL_NO_SANITIZE("unsigned-integer-overflow") { + return (static_cast<uint32_t>(value) << 1) ^ + ((static_cast<uint32_t>(~value) >> 31) - 1); +} + +// Reverse to PackSigned, i.e. UnpackSigned(PackSigned(X)) == X. +// (((~value) & 1) - 1) is either 0 or 0xFF...FF and it will have an expected +// unsigned-integer-overflow. +constexpr intptr_t UnpackSigned(size_t value) + JXL_NO_SANITIZE("unsigned-integer-overflow") { + return static_cast<intptr_t>((value >> 1) ^ (((~value) & 1) - 1)); +} + +} // namespace jxl + +#endif // LIB_JXL_PACK_H_ diff --git a/third_party/jpeg-xl/lib/jxl/padded_bytes.h b/third_party/jpeg-xl/lib/jxl/padded_bytes.h new file mode 100644 index 0000000000..0d696475fa --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/padded_bytes.h @@ -0,0 +1,216 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_BASE_PADDED_BYTES_H_ +#define LIB_JXL_BASE_PADDED_BYTES_H_ + +// std::vector replacement with padding to reduce bounds checks in WriteBits + +#include <stddef.h> +#include <stdint.h> +#include <string.h> // memcpy + +#include <algorithm> // max +#include <initializer_list> +#include <utility> // swap + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/cache_aligned.h" + +namespace jxl { + +// Provides a subset of the std::vector interface with some differences: +// - allows BitWriter to write 64 bits at a time without bounds checking; +// - ONLY zero-initializes the first byte (required by BitWriter); +// - ensures cache-line alignment. +class PaddedBytes { + public: + // Required for output params. + PaddedBytes() : size_(0), capacity_(0) {} + + explicit PaddedBytes(size_t size) : size_(size), capacity_(0) { + reserve(size); + } + + PaddedBytes(size_t size, uint8_t value) : size_(size), capacity_(0) { + reserve(size); + if (size_ != 0) { + memset(data(), value, size); + } + } + + PaddedBytes(const PaddedBytes& other) : size_(other.size_), capacity_(0) { + reserve(size_); + if (data() != nullptr) memcpy(data(), other.data(), size_); + } + PaddedBytes& operator=(const PaddedBytes& other) { + // Self-assignment is safe. + resize(other.size()); + if (data() != nullptr) memmove(data(), other.data(), size_); + return *this; + } + + // default is not OK - need to set other.size_ to 0! + PaddedBytes(PaddedBytes&& other) noexcept + : size_(other.size_), + capacity_(other.capacity_), + data_(std::move(other.data_)) { + other.size_ = other.capacity_ = 0; + } + PaddedBytes& operator=(PaddedBytes&& other) noexcept { + size_ = other.size_; + capacity_ = other.capacity_; + data_ = std::move(other.data_); + + if (&other != this) { + other.size_ = other.capacity_ = 0; + } + return *this; + } + + void swap(PaddedBytes& other) { + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + std::swap(data_, other.data_); + } + + // If current capacity is greater than requested, then no-op. Otherwise + // copies existing data to newly allocated "data_". If allocation fails, + // data() == nullptr and size_ = capacity_ = 0. + // The new capacity will be at least 1.5 times the old capacity. This ensures + // that we avoid quadratic behaviour. + void reserve(size_t capacity) { + if (capacity <= capacity_) return; + + size_t new_capacity = std::max(capacity, 3 * capacity_ / 2); + new_capacity = std::max<size_t>(64, new_capacity); + + // BitWriter writes up to 7 bytes past the end. + CacheAlignedUniquePtr new_data = AllocateArray(new_capacity + 8); + if (new_data == nullptr) { + // Allocation failed, discard all data to ensure this is noticed. + size_ = capacity_ = 0; + return; + } + + if (data_ == nullptr) { + // First allocation: ensure first byte is initialized (won't be copied). + new_data[0] = 0; + } else { + // Subsequent resize: copy existing data to new location. + memcpy(new_data.get(), data_.get(), size_); + // Ensure that the first new byte is initialized, to allow write_bits to + // safely append to the newly-resized PaddedBytes. + new_data[size_] = 0; + } + + capacity_ = new_capacity; + std::swap(new_data, data_); + } + + // NOTE: unlike vector, this does not initialize the new data! + // However, we guarantee that write_bits can safely append after + // the resize, as we zero-initialize the first new byte of data. + // If size < capacity(), does not invalidate the memory. + void resize(size_t size) { + reserve(size); + size_ = (data() == nullptr) ? 0 : size; + } + + // resize(size) plus explicit initialization of the new data with `value`. + void resize(size_t size, uint8_t value) { + size_t old_size = size_; + resize(size); + if (size_ > old_size) { + memset(data() + old_size, value, size_ - old_size); + } + } + + // Amortized constant complexity due to exponential growth. + void push_back(uint8_t x) { + if (size_ == capacity_) { + reserve(capacity_ + 1); + if (data() == nullptr) return; + } + + data_[size_++] = x; + } + + size_t size() const { return size_; } + size_t capacity() const { return capacity_; } + + uint8_t* data() { return data_.get(); } + const uint8_t* data() const { return data_.get(); } + + // std::vector operations implemented in terms of the public interface above. + + void clear() { resize(0); } + bool empty() const { return size() == 0; } + + void assign(std::initializer_list<uint8_t> il) { + resize(il.size()); + memcpy(data(), il.begin(), il.size()); + } + + uint8_t* begin() { return data(); } + const uint8_t* begin() const { return data(); } + uint8_t* end() { return begin() + size(); } + const uint8_t* end() const { return begin() + size(); } + + uint8_t& operator[](const size_t i) { + BoundsCheck(i); + return data()[i]; + } + const uint8_t& operator[](const size_t i) const { + BoundsCheck(i); + return data()[i]; + } + + uint8_t& back() { + JXL_ASSERT(size() != 0); + return data()[size() - 1]; + } + const uint8_t& back() const { + JXL_ASSERT(size() != 0); + return data()[size() - 1]; + } + + template <typename T> + void append(const T& other) { + append(reinterpret_cast<const uint8_t*>(other.data()), + reinterpret_cast<const uint8_t*>(other.data()) + other.size()); + } + + void append(const uint8_t* begin, const uint8_t* end) { + if (end - begin > 0) { + size_t old_size = size(); + resize(size() + (end - begin)); + memcpy(data() + old_size, begin, end - begin); + } + } + + private: + void BoundsCheck(size_t i) const { + // <= is safe due to padding and required by BitWriter. + JXL_ASSERT(i <= size()); + } + + size_t size_; + size_t capacity_; + CacheAlignedUniquePtr data_; +}; + +template <typename T> +static inline void Append(const T& s, PaddedBytes* out, + size_t* JXL_RESTRICT byte_pos) { + memcpy(out->data() + *byte_pos, s.data(), s.size()); + *byte_pos += s.size(); + JXL_CHECK(*byte_pos <= out->size()); +} + +} // namespace jxl + +#endif // LIB_JXL_BASE_PADDED_BYTES_H_ diff --git a/third_party/jpeg-xl/lib/jxl/padded_bytes_test.cc b/third_party/jpeg-xl/lib/jxl/padded_bytes_test.cc new file mode 100644 index 0000000000..83d1da9c25 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/padded_bytes_test.cc @@ -0,0 +1,64 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/padded_bytes.h" + +#include <numeric> // iota +#include <vector> + +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +TEST(PaddedBytesTest, TestNonEmptyFirstByteZero) { + PaddedBytes pb(1); + EXPECT_EQ(0, pb[0]); + // Even after resizing.. + pb.resize(20); + EXPECT_EQ(0, pb[0]); + // And reserving. + pb.reserve(200); + EXPECT_EQ(0, pb[0]); +} + +TEST(PaddedBytesTest, TestEmptyFirstByteZero) { + PaddedBytes pb(0); + // After resizing - new zero is written despite there being nothing to copy. + pb.resize(20); + EXPECT_EQ(0, pb[0]); +} + +TEST(PaddedBytesTest, TestFillWithoutReserve) { + PaddedBytes pb; + for (size_t i = 0; i < 170u; ++i) { + pb.push_back(i); + } + EXPECT_EQ(170u, pb.size()); + EXPECT_GE(pb.capacity(), 170u); +} + +TEST(PaddedBytesTest, TestFillWithExactReserve) { + PaddedBytes pb; + pb.reserve(170); + for (size_t i = 0; i < 170u; ++i) { + pb.push_back(i); + } + EXPECT_EQ(170u, pb.size()); + EXPECT_EQ(pb.capacity(), 170u); +} + +TEST(PaddedBytesTest, TestFillWithMoreReserve) { + PaddedBytes pb; + pb.reserve(171); + for (size_t i = 0; i < 170u; ++i) { + pb.push_back(i); + } + EXPECT_EQ(170u, pb.size()); + EXPECT_GT(pb.capacity(), 170u); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/passes_state.cc b/third_party/jpeg-xl/lib/jxl/passes_state.cc new file mode 100644 index 0000000000..12cc6a0c93 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/passes_state.cc @@ -0,0 +1,69 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/passes_state.h" + +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/frame_dimensions.h" + +namespace jxl { + +Status InitializePassesSharedState(const FrameHeader& frame_header, + PassesSharedState* JXL_RESTRICT shared, + bool encoder) { + JXL_ASSERT(frame_header.nonserialized_metadata != nullptr); + shared->metadata = frame_header.nonserialized_metadata; + shared->frame_dim = frame_header.ToFrameDimensions(); + shared->image_features.patches.SetPassesSharedState(shared); + + const FrameDimensions& frame_dim = shared->frame_dim; + + shared->ac_strategy = + AcStrategyImage(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->raw_quant_field = + ImageI(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->epf_sharpness = + ImageB(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->cmap = ColorCorrelationMap(frame_dim.xsize, frame_dim.ysize); + + // In the decoder, we allocate coeff orders afterwards, when we know how many + // we will actually need. + shared->coeff_order_size = kCoeffOrderMaxSize; + if (encoder && + shared->coeff_orders.size() < + frame_header.passes.num_passes * kCoeffOrderMaxSize && + frame_header.encoding == FrameEncoding::kVarDCT) { + shared->coeff_orders.resize(frame_header.passes.num_passes * + kCoeffOrderMaxSize); + } + + shared->quant_dc = ImageB(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + + bool use_dc_frame = !!(frame_header.flags & FrameHeader::kUseDcFrame); + if (!encoder && use_dc_frame) { + if (frame_header.dc_level == 4) { + return JXL_FAILURE("Invalid DC level for kUseDcFrame: %u", + frame_header.dc_level); + } + shared->dc_storage = Image3F(); + shared->dc = &shared->dc_frames[frame_header.dc_level]; + if (shared->dc->xsize() == 0) { + return JXL_FAILURE( + "kUseDcFrame specified for dc_level %u, but no frame was decoded " + "with level %u", + frame_header.dc_level, frame_header.dc_level + 1); + } + ZeroFillImage(&shared->quant_dc); + } else { + shared->dc_storage = + Image3F(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + shared->dc = &shared->dc_storage; + } + + return true; +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/passes_state.h b/third_party/jpeg-xl/lib/jxl/passes_state.h new file mode 100644 index 0000000000..ffb213d4a4 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/passes_state.h @@ -0,0 +1,89 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_PASSES_STATE_H_ +#define LIB_JXL_PASSES_STATE_H_ + +#include "lib/jxl/ac_context.h" +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/noise.h" +#include "lib/jxl/quant_weights.h" +#include "lib/jxl/quantizer.h" +#include "lib/jxl/splines.h" + +// Structures that hold the (en/de)coder state for a JPEG XL kVarDCT +// (en/de)coder. + +namespace jxl { + +struct ImageFeatures { + NoiseParams noise_params; + PatchDictionary patches; + Splines splines; +}; + +// State common to both encoder and decoder. +// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) +struct PassesSharedState { + const CodecMetadata* metadata; + + FrameDimensions frame_dim; + + // Control fields and parameters. + AcStrategyImage ac_strategy; + + // Dequant matrices + quantizer. + DequantMatrices matrices; + Quantizer quantizer{&matrices}; + ImageI raw_quant_field; + + // Per-block side information for EPF detail preservation. + ImageB epf_sharpness; + + ColorCorrelationMap cmap; + + ImageFeatures image_features; + + // Memory area for storing coefficient orders. + // `coeff_order_size` is the size used by *one* set of coefficient orders (at + // most kMaxCoeffOrderSize). A set of coefficient orders is present for each + // pass. + size_t coeff_order_size = 0; + std::vector<coeff_order_t> coeff_orders; + + // Decoder-side DC and quantized DC. + ImageB quant_dc; + Image3F dc_storage; + const Image3F* JXL_RESTRICT dc = &dc_storage; + + BlockCtxMap block_ctx_map; + + Image3F dc_frames[4]; + + struct { + ImageBundle frame; + // ImageBundle doesn't yet have a simple way to state it is in XYB. + bool ib_is_in_xyb = false; + } reference_frames[4] = {}; + + // Number of pre-clustered set of histograms (with the same ctx map), per + // pass. Encoded as num_histograms_ - 1. + size_t num_histograms = 0; +}; + +// Initialized the state information that is shared between encoder and decoder. +Status InitializePassesSharedState(const FrameHeader& frame_header, + PassesSharedState* JXL_RESTRICT shared, + bool encoder = false); + +} // namespace jxl + +#endif // LIB_JXL_PASSES_STATE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/passes_test.cc b/third_party/jpeg-xl/lib/jxl/passes_test.cc new file mode 100644 index 0000000000..a47134cd00 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/passes_test.cc @@ -0,0 +1,385 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> +#include <stddef.h> + +#include <cstdint> +#include <future> +#include <string> +#include <utility> +#include <vector> + +#include "lib/extras/codec.h" +#include "lib/extras/dec/jxl.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { + +using test::ReadTestData; +using test::Roundtrip; +using test::ThreadPoolForTests; + +namespace { + +TEST(PassesTest, RoundtripSmallPasses) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + CompressParams cparams; + cparams.butteraugli_distance = 1.0; + cparams.progressive_mode = Override::kOn; + cparams.SetCms(*JxlGetDefaultCms()); + + CodecInOut io2; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _)); + EXPECT_THAT(ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(0.8222)); +} + +TEST(PassesTest, RoundtripUnalignedPasses) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + io.ShrinkTo(io.xsize() / 12, io.ysize() / 7); + + CompressParams cparams; + cparams.butteraugli_distance = 2.0; + cparams.progressive_mode = Override::kOn; + cparams.SetCms(*JxlGetDefaultCms()); + + CodecInOut io2; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _)); + EXPECT_THAT(ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(1.72)); +} + +TEST(PassesTest, RoundtripMultiGroupPasses) { + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + { + ThreadPoolForTests pool(4); + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io, &pool)); + } + io.ShrinkTo(600, 1024); // partial X, full Y group + + auto test = [&](float target_distance, float threshold) { + ThreadPoolForTests pool(4); + CompressParams cparams; + cparams.butteraugli_distance = target_distance; + cparams.progressive_mode = Override::kOn; + cparams.SetCms(*JxlGetDefaultCms()); + CodecInOut io2; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _, + /* compressed_size */ nullptr, &pool)); + EXPECT_THAT(ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr, &pool), + IsSlightlyBelow(target_distance + threshold)); + }; + + auto run1 = std::async(std::launch::async, test, 1.0f, 0.15f); + auto run2 = std::async(std::launch::async, test, 2.0f, 0.0f); +} + +TEST(PassesTest, RoundtripLargeFastPasses) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io, &pool)); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_mode = Override::kOn; + cparams.SetCms(*JxlGetDefaultCms()); + + CodecInOut io2; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _, + /* compressed_size */ nullptr, &pool)); +} + +// Checks for differing size/distance in two consecutive runs of distance 2, +// which involves additional processing including adaptive reconstruction. +// Failing this may be a sign of race conditions or invalid memory accesses. +TEST(PassesTest, RoundtripProgressiveConsistent) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io, &pool)); + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_mode = Override::kOn; + cparams.butteraugli_distance = 2.0; + cparams.SetCms(*JxlGetDefaultCms()); + + // Try each xsize mod kBlockDim to verify right border handling. + for (size_t xsize = 48; xsize > 40; --xsize) { + io.ShrinkTo(xsize, 15); + + CodecInOut io2; + size_t size2; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _, &size2, &pool)); + + CodecInOut io3; + size_t size3; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io3, _, &size3, &pool)); + + // Exact same compressed size. + EXPECT_EQ(size2, size3); + + // Exact same distance. + const float dist2 = ButteraugliDistance( + io.frames, io2.frames, ButteraugliParams(), *JxlGetDefaultCms(), + /*distmap=*/nullptr, &pool); + const float dist3 = ButteraugliDistance( + io.frames, io3.frames, ButteraugliParams(), *JxlGetDefaultCms(), + /*distmap=*/nullptr, &pool); + EXPECT_EQ(dist2, dist3); + } +} + +TEST(PassesTest, AllDownsampleFeasible) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io, &pool)); + + std::vector<uint8_t> compressed; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_mode = Override::kOn; + cparams.butteraugli_distance = 1.0; + ASSERT_TRUE(test::EncodeFile(cparams, &io, &compressed, &pool)); + + EXPECT_LE(compressed.size(), 240000u); + float target_butteraugli[9] = {}; + target_butteraugli[1] = 2.5f; + target_butteraugli[2] = 16.0f; + target_butteraugli[4] = 20.0f; + target_butteraugli[8] = 80.0f; + + // The default progressive encoding scheme should make all these downsampling + // factors achievable. + // TODO(veluca): re-enable downsampling 16. + std::vector<size_t> downsamplings = {1, 2, 4, 8}; //, 16}; + + auto check = [&](const uint32_t task, size_t /* thread */) -> void { + const size_t downsampling = downsamplings[task]; + extras::JXLDecompressParams dparams; + dparams.max_downsampling = downsampling; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &output)); + EXPECT_EQ(output.xsize(), io.xsize()) << "downsampling = " << downsampling; + EXPECT_EQ(output.ysize(), io.ysize()) << "downsampling = " << downsampling; + EXPECT_LE(ButteraugliDistance(io.frames, output.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr, nullptr), + target_butteraugli[downsampling]) + << "downsampling: " << downsampling; + }; + EXPECT_TRUE(RunOnPool(&pool, 0, downsamplings.size(), ThreadPool::NoInit, + check, "TestDownsampling")); +} + +TEST(PassesTest, AllDownsampleFeasibleQProgressive) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io, &pool)); + + std::vector<uint8_t> compressed; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.qprogressive_mode = Override::kOn; + cparams.butteraugli_distance = 1.0; + ASSERT_TRUE(test::EncodeFile(cparams, &io, &compressed, &pool)); + + EXPECT_LE(compressed.size(), 220000u); + + float target_butteraugli[9] = {}; + target_butteraugli[1] = 3.0f; + target_butteraugli[2] = 6.0f; + target_butteraugli[4] = 10.0f; + target_butteraugli[8] = 80.0f; + + // The default progressive encoding scheme should make all these downsampling + // factors achievable. + std::vector<size_t> downsamplings = {1, 2, 4, 8}; + + auto check = [&](const uint32_t task, size_t /* thread */) -> void { + const size_t downsampling = downsamplings[task]; + extras::JXLDecompressParams dparams; + dparams.max_downsampling = downsampling; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &output)); + EXPECT_EQ(output.xsize(), io.xsize()) << "downsampling = " << downsampling; + EXPECT_EQ(output.ysize(), io.ysize()) << "downsampling = " << downsampling; + EXPECT_LE(ButteraugliDistance(io.frames, output.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + target_butteraugli[downsampling]) + << "downsampling: " << downsampling; + }; + EXPECT_TRUE(RunOnPool(&pool, 0, downsamplings.size(), ThreadPool::NoInit, + check, "TestQProgressive")); +} + +TEST(PassesTest, ProgressiveDownsample2DegradesCorrectlyGrayscale) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = ReadTestData( + "external/wesaturate/500px/cvo9xd_keong_macan_grayscale.png"); + CodecInOut io_orig; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io_orig, &pool)); + Rect rect(0, 0, io_orig.xsize(), 128); + // need 2 DC groups for the DC frame to actually be progressive. + Image3F large(4242, rect.ysize()); + ZeroFillImage(&large); + CopyImageTo(rect, *io_orig.Main().color(), rect, &large); + CodecInOut io; + io.metadata = io_orig.metadata; + io.SetFromImage(std::move(large), io_orig.Main().c_current()); + + std::vector<uint8_t> compressed; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_dc = 1; + cparams.responsive = true; + cparams.qprogressive_mode = Override::kOn; + cparams.butteraugli_distance = 1.0; + ASSERT_TRUE(test::EncodeFile(cparams, &io, &compressed, &pool)); + + EXPECT_LE(compressed.size(), 10000u); + + extras::JXLDecompressParams dparams; + dparams.max_downsampling = 1; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &output)); + + dparams.max_downsampling = 2; + CodecInOut output_d2; + ASSERT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &output_d2)); + + // 0 if reading all the passes, ~15 if skipping the 8x pass. + float butteraugli_distance_down2_full = ButteraugliDistance( + output.frames, output_d2.frames, ButteraugliParams(), *JxlGetDefaultCms(), + /*distmap=*/nullptr); + + EXPECT_LE(butteraugli_distance_down2_full, 3.2f); + EXPECT_GE(butteraugli_distance_down2_full, 1.0f); +} + +TEST(PassesTest, ProgressiveDownsample2DegradesCorrectly) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io_orig; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io_orig, &pool)); + Rect rect(0, 0, io_orig.xsize(), 128); + // need 2 DC groups for the DC frame to actually be progressive. + Image3F large(4242, rect.ysize()); + ZeroFillImage(&large); + CopyImageTo(rect, *io_orig.Main().color(), rect, &large); + CodecInOut io; + io.SetFromImage(std::move(large), io_orig.Main().c_current()); + + std::vector<uint8_t> compressed; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_dc = 1; + cparams.responsive = true; + cparams.qprogressive_mode = Override::kOn; + cparams.butteraugli_distance = 1.0; + ASSERT_TRUE(test::EncodeFile(cparams, &io, &compressed, &pool)); + + EXPECT_LE(compressed.size(), 220000u); + + extras::JXLDecompressParams dparams; + dparams.max_downsampling = 1; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &output)); + + dparams.max_downsampling = 2; + CodecInOut output_d2; + ASSERT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &output_d2)); + + // 0 if reading all the passes, ~15 if skipping the 8x pass. + float butteraugli_distance_down2_full = ButteraugliDistance( + output.frames, output_d2.frames, ButteraugliParams(), *JxlGetDefaultCms(), + /*distmap=*/nullptr); + + EXPECT_LE(butteraugli_distance_down2_full, 3.0f); + EXPECT_GE(butteraugli_distance_down2_full, 1.0f); +} + +TEST(PassesTest, NonProgressiveDCImage) { + ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = ReadTestData("jxl/flower/flower.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io, &pool)); + + std::vector<uint8_t> compressed; + + CompressParams cparams; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.progressive_mode = Override::kOff; + cparams.butteraugli_distance = 2.0; + ASSERT_TRUE(test::EncodeFile(cparams, &io, &compressed, &pool)); + + // Even in non-progressive mode, it should be possible to return a DC-only + // image. + extras::JXLDecompressParams dparams; + dparams.max_downsampling = 100; + CodecInOut output; + ASSERT_TRUE(test::DecodeFile(dparams, Bytes(compressed), &output, &pool)); + EXPECT_EQ(output.xsize(), io.xsize()); + EXPECT_EQ(output.ysize(), io.ysize()); +} + +TEST(PassesTest, RoundtripSmallNoGaborishPasses) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + + CompressParams cparams; + cparams.gaborish = Override::kOff; + cparams.butteraugli_distance = 1.0; + cparams.progressive_mode = Override::kOn; + cparams.SetCms(*JxlGetDefaultCms()); + + CodecInOut io2; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _)); + EXPECT_THAT(ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + IsSlightlyBelow(1.2)); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/patch_dictionary_internal.h b/third_party/jpeg-xl/lib/jxl/patch_dictionary_internal.h new file mode 100644 index 0000000000..e4172f6db6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/patch_dictionary_internal.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_PATCH_DICTIONARY_INTERNAL_H_ +#define LIB_JXL_PATCH_DICTIONARY_INTERNAL_H_ + +#include "lib/jxl/dec_patch_dictionary.h" +#include "lib/jxl/passes_state.h" // for PassesSharedState + +namespace jxl { + +// Context numbers as specified in Section C.4.5, Listing C.2: +enum Contexts { + kNumRefPatchContext = 0, + kReferenceFrameContext = 1, + kPatchSizeContext = 2, + kPatchReferencePositionContext = 3, + kPatchPositionContext = 4, + kPatchBlendModeContext = 5, + kPatchOffsetContext = 6, + kPatchCountContext = 7, + kPatchAlphaChannelContext = 8, + kPatchClampContext = 9, + kNumPatchDictionaryContexts +}; + +} // namespace jxl + +#endif // LIB_JXL_PATCH_DICTIONARY_INTERNAL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/patch_dictionary_test.cc b/third_party/jpeg-xl/lib/jxl/patch_dictionary_test.cc new file mode 100644 index 0000000000..60f7c32229 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/patch_dictionary_test.cc @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> + +#include <cstddef> +#include <cstdint> +#include <vector> + +#include "lib/extras/codec.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +using test::ReadTestData; +using test::Roundtrip; + +TEST(PatchDictionaryTest, GrayscaleModular) { + const std::vector<uint8_t> orig = ReadTestData("jxl/grayscale_patches.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + + CompressParams cparams; + cparams.SetLossless(); + cparams.patches = jxl::Override::kOn; + + CodecInOut io2; + // Without patches: ~25k + size_t compressed_size; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _, &compressed_size)); + EXPECT_LE(compressed_size, 8000u); + JXL_ASSERT_OK(VerifyRelativeError(*io.Main().color(), *io2.Main().color(), + 1e-7f, 0, _)); +} + +TEST(PatchDictionaryTest, GrayscaleVarDCT) { + const std::vector<uint8_t> orig = ReadTestData("jxl/grayscale_patches.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + + CompressParams cparams; + cparams.patches = jxl::Override::kOn; + + CodecInOut io2; + // Without patches: ~47k + size_t compressed_size; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _, &compressed_size)); + EXPECT_LE(compressed_size, 14000u); + // Without patches: ~1.2 + EXPECT_LE(ButteraugliDistance(io.frames, io2.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + 1.1); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/preview_test.cc b/third_party/jpeg-xl/lib/jxl/preview_test.cc new file mode 100644 index 0000000000..b7fe855d4d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/preview_test.cc @@ -0,0 +1,66 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> + +#include <cstddef> +#include <cstdint> +#include <string> +#include <vector> + +#include "lib/extras/codec.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { +using test::ReadTestData; +using test::Roundtrip; + +TEST(PreviewTest, RoundtripGivenPreview) { + const std::vector<uint8_t> orig = + ReadTestData("external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + CodecInOut io; + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io)); + io.ShrinkTo(io.xsize() / 8, io.ysize() / 8); + // Same as main image + io.preview_frame = io.Main().Copy(); + const size_t preview_xsize = 15; + const size_t preview_ysize = 27; + io.preview_frame.ShrinkTo(preview_xsize, preview_ysize); + io.metadata.m.have_preview = true; + ASSERT_TRUE(io.metadata.m.preview_size.Set(io.preview_frame.xsize(), + io.preview_frame.ysize())); + + CompressParams cparams; + cparams.butteraugli_distance = 2.0; + cparams.speed_tier = SpeedTier::kSquirrel; + cparams.SetCms(*JxlGetDefaultCms()); + + CodecInOut io2; + JXL_EXPECT_OK(Roundtrip(&io, cparams, {}, &io2, _)); + EXPECT_EQ(preview_xsize, io2.metadata.m.preview_size.xsize()); + EXPECT_EQ(preview_ysize, io2.metadata.m.preview_size.ysize()); + EXPECT_EQ(preview_xsize, io2.preview_frame.xsize()); + EXPECT_EQ(preview_ysize, io2.preview_frame.ysize()); + + EXPECT_LE(ButteraugliDistance(io.preview_frame, io2.preview_frame, + ButteraugliParams(), *JxlGetDefaultCms(), + /*distmap=*/nullptr), + 2.5); + EXPECT_LE(ButteraugliDistance(io.Main(), io2.Main(), ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr), + 2.5); +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/quant_weights.cc b/third_party/jpeg-xl/lib/jxl/quant_weights.cc new file mode 100644 index 0000000000..70b3b9e451 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quant_weights.cc @@ -0,0 +1,1238 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#include "lib/jxl/quant_weights.h" + +#include <stdio.h> +#include <stdlib.h> + +#include <algorithm> +#include <cmath> +#include <limits> +#include <utility> + +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_modular.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/quant_weights.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/fast_math-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Sqrt; + +// kQuantWeights[N * N * c + N * y + x] is the relative weight of the (x, y) +// coefficient in component c. Higher weights correspond to finer quantization +// intervals and more bits spent in encoding. + +static constexpr const float kAlmostZero = 1e-8f; + +void GetQuantWeightsDCT2(const QuantEncoding::DCT2Weights& dct2weights, + float* weights) { + for (size_t c = 0; c < 3; c++) { + size_t start = c * 64; + weights[start] = 0xBAD; + weights[start + 1] = weights[start + 8] = dct2weights[c][0]; + weights[start + 9] = dct2weights[c][1]; + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + weights[start + y * 8 + x + 2] = dct2weights[c][2]; + weights[start + (y + 2) * 8 + x] = dct2weights[c][2]; + } + } + for (size_t y = 0; y < 2; y++) { + for (size_t x = 0; x < 2; x++) { + weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3]; + } + } + for (size_t y = 0; y < 4; y++) { + for (size_t x = 0; x < 4; x++) { + weights[start + y * 8 + x + 4] = dct2weights[c][4]; + weights[start + (y + 4) * 8 + x] = dct2weights[c][4]; + } + } + for (size_t y = 0; y < 4; y++) { + for (size_t x = 0; x < 4; x++) { + weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5]; + } + } + } +} + +void GetQuantWeightsIdentity(const QuantEncoding::IdWeights& idweights, + float* weights) { + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 64; i++) { + weights[64 * c + i] = idweights[c][0]; + } + weights[64 * c + 1] = idweights[c][1]; + weights[64 * c + 8] = idweights[c][1]; + weights[64 * c + 9] = idweights[c][2]; + } +} + +float Interpolate(float pos, float max, const float* array, size_t len) { + float scaled_pos = pos * (len - 1) / max; + size_t idx = scaled_pos; + JXL_DASSERT(idx + 1 < len); + float a = array[idx]; + float b = array[idx + 1]; + return a * FastPowf(b / a, scaled_pos - idx); +} + +float Mult(float v) { + if (v > 0.0f) return 1.0f + v; + return 1.0f / (1.0f - v); +} + +using DF4 = HWY_CAPPED(float, 4); + +hwy::HWY_NAMESPACE::Vec<DF4> InterpolateVec( + hwy::HWY_NAMESPACE::Vec<DF4> scaled_pos, const float* array) { + HWY_CAPPED(int32_t, 4) di; + + auto idx = ConvertTo(di, scaled_pos); + + auto frac = Sub(scaled_pos, ConvertTo(DF4(), idx)); + + // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but + // it's probably slower. + auto a = GatherIndex(DF4(), array, idx); + auto b = GatherIndex(DF4(), array + 1, idx); + + return Mul(a, FastPowf(DF4(), Div(b, a), frac)); +} + +// Computes quant weights for a COLS*ROWS-sized transform, using num_bands +// eccentricity bands and num_ebands eccentricity bands. If print_mode is 1, +// prints the resulting matrix; if print_mode is 2, prints the matrix in a +// format suitable for a 3d plot with gnuplot. +Status GetQuantWeights( + size_t ROWS, size_t COLS, + const DctQuantWeightParams::DistanceBandsArray& distance_bands, + size_t num_bands, float* out) { + for (size_t c = 0; c < 3; c++) { + float bands[DctQuantWeightParams::kMaxDistanceBands] = { + distance_bands[c][0]}; + if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); + for (size_t i = 1; i < num_bands; i++) { + bands[i] = bands[i - 1] * Mult(distance_bands[c][i]); + if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); + } + float scale = (num_bands - 1) / (kSqrt2 + 1e-6f); + float rcpcol = scale / (COLS - 1); + float rcprow = scale / (ROWS - 1); + JXL_ASSERT(COLS >= Lanes(DF4())); + HWY_ALIGN float l0123[4] = {0, 1, 2, 3}; + for (uint32_t y = 0; y < ROWS; y++) { + float dy = y * rcprow; + float dy2 = dy * dy; + for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) { + auto dx = + Mul(Add(Set(DF4(), x), Load(DF4(), l0123)), Set(DF4(), rcpcol)); + auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2))); + auto weight = num_bands == 1 ? Set(DF4(), bands[0]) + : InterpolateVec(scaled_distance, bands); + StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x); + } + } + } + return true; +} + +// TODO(veluca): SIMD-fy. With 256x256, this is actually slow. +Status ComputeQuantTable(const QuantEncoding& encoding, + float* JXL_RESTRICT table, + float* JXL_RESTRICT inv_table, size_t table_num, + DequantMatrices::QuantTable kind, size_t* pos) { + constexpr size_t N = kBlockDim; + size_t wrows = 8 * DequantMatrices::required_size_x[kind], + wcols = 8 * DequantMatrices::required_size_y[kind]; + size_t num = wrows * wcols; + + std::vector<float> weights(3 * num); + + switch (encoding.mode) { + case QuantEncoding::kQuantModeLibrary: { + // Library and copy quant encoding should get replaced by the actual + // parameters by the caller. + JXL_ASSERT(false); + break; + } + case QuantEncoding::kQuantModeID: { + JXL_ASSERT(num == kDCTBlockSize); + GetQuantWeightsIdentity(encoding.idweights, weights.data()); + break; + } + case QuantEncoding::kQuantModeDCT2: { + JXL_ASSERT(num == kDCTBlockSize); + GetQuantWeightsDCT2(encoding.dct2weights, weights.data()); + break; + } + case QuantEncoding::kQuantModeDCT4: { + JXL_ASSERT(num == kDCTBlockSize); + float weights4x4[3 * 4 * 4]; + // Always use 4x4 GetQuantWeights for DCT4 quantization tables. + JXL_RETURN_IF_ERROR( + GetQuantWeights(4, 4, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x4)); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + weights[c * num + y * kBlockDim + x] = + weights4x4[c * 16 + (y / 2) * 4 + (x / 2)]; + } + } + weights[c * num + 1] /= encoding.dct4multipliers[c][0]; + weights[c * num + N] /= encoding.dct4multipliers[c][0]; + weights[c * num + N + 1] /= encoding.dct4multipliers[c][1]; + } + break; + } + case QuantEncoding::kQuantModeDCT4X8: { + JXL_ASSERT(num == kDCTBlockSize); + float weights4x8[3 * 4 * 8]; + // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables. + JXL_RETURN_IF_ERROR( + GetQuantWeights(4, 8, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x8)); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < kBlockDim; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + weights[c * num + y * kBlockDim + x] = + weights4x8[c * 32 + (y / 2) * 8 + x]; + } + } + weights[c * num + N] /= encoding.dct4x8multipliers[c]; + } + break; + } + case QuantEncoding::kQuantModeDCT: { + JXL_RETURN_IF_ERROR(GetQuantWeights( + wrows, wcols, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights.data())); + break; + } + case QuantEncoding::kQuantModeRAW: { + if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) { + return JXL_FAILURE("Invalid table encoding"); + } + for (size_t i = 0; i < 3 * num; i++) { + weights[i] = + 1.f / (encoding.qraw.qtable_den * (*encoding.qraw.qtable)[i]); + } + break; + } + case QuantEncoding::kQuantModeAFV: { + constexpr float kFreqs[] = { + 0xBAD, + 0xBAD, + 0.8517778890324296, + 5.37778436506804, + 0xBAD, + 0xBAD, + 4.734747904497923, + 5.449245381693219, + 1.6598270267479331, + 4, + 7.275749096817861, + 10.423227632456525, + 2.662932286148962, + 7.630657783650829, + 8.962388608184032, + 12.97166202570235, + }; + + float weights4x8[3 * 4 * 8]; + JXL_RETURN_IF_ERROR(( + GetQuantWeights(4, 8, encoding.dct_params.distance_bands, + encoding.dct_params.num_distance_bands, weights4x8))); + float weights4x4[3 * 4 * 4]; + JXL_RETURN_IF_ERROR((GetQuantWeights( + 4, 4, encoding.dct_params_afv_4x4.distance_bands, + encoding.dct_params_afv_4x4.num_distance_bands, weights4x4))); + + constexpr float lo = 0.8517778890324296; + constexpr float hi = 12.97166202570235f - lo + 1e-6f; + for (size_t c = 0; c < 3; c++) { + float bands[4]; + bands[0] = encoding.afv_weights[c][5]; + if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); + for (size_t i = 1; i < 4; i++) { + bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]); + if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); + } + size_t start = c * 64; + auto set_weight = [&start, &weights](size_t x, size_t y, float val) { + weights[start + y * 8 + x] = val; + }; + weights[start] = 1; // Not used, but causes MSAN error otherwise. + // Weights for (0, 1) and (1, 0). + set_weight(0, 1, encoding.afv_weights[c][0]); + set_weight(1, 0, encoding.afv_weights[c][1]); + // AFV special weights for 3-pixel corner. + set_weight(0, 2, encoding.afv_weights[c][2]); + set_weight(2, 0, encoding.afv_weights[c][3]); + set_weight(2, 2, encoding.afv_weights[c][4]); + + // All other AFV weights. + for (size_t y = 0; y < 4; y++) { + for (size_t x = 0; x < 4; x++) { + if (x < 2 && y < 2) continue; + float val = Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4); + set_weight(2 * x, 2 * y, val); + } + } + + // Put 4x8 weights in odd rows, except (1, 0). + for (size_t y = 0; y < kBlockDim / 2; y++) { + for (size_t x = 0; x < kBlockDim; x++) { + if (x == 0 && y == 0) continue; + weights[c * num + (2 * y + 1) * kBlockDim + x] = + weights4x8[c * 32 + y * 8 + x]; + } + } + // Put 4x4 weights in even rows / odd columns, except (0, 1). + for (size_t y = 0; y < kBlockDim / 2; y++) { + for (size_t x = 0; x < kBlockDim / 2; x++) { + if (x == 0 && y == 0) continue; + weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] = + weights4x4[c * 16 + y * 4 + x]; + } + } + } + break; + } + } + size_t prev_pos = *pos; + HWY_CAPPED(float, 64) d; + for (size_t i = 0; i < num * 3; i += Lanes(d)) { + auto inv_val = LoadU(d, weights.data() + i); + if (JXL_UNLIKELY(!AllFalse(d, Ge(inv_val, Set(d, 1.0f / kAlmostZero))) || + !AllFalse(d, Lt(inv_val, Set(d, kAlmostZero))))) { + return JXL_FAILURE("Invalid quantization table"); + } + auto val = Div(Set(d, 1.0f), inv_val); + StoreU(val, d, table + *pos + i); + StoreU(inv_val, d, inv_table + *pos + i); + } + (*pos) += 3 * num; + + // Ensure that the lowest frequencies have a 0 inverse table. + // This does not affect en/decoding, but allows AC strategy selection to be + // slightly simpler. + size_t xs = DequantMatrices::required_size_x[kind]; + size_t ys = DequantMatrices::required_size_y[kind]; + CoefficientLayout(&ys, &xs); + for (size_t c = 0; c < 3; c++) { + for (size_t y = 0; y < ys; y++) { + for (size_t x = 0; x < xs; x++) { + inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs + + x] = 0; + } + } + } + return true; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { +namespace { + +HWY_EXPORT(ComputeQuantTable); + +static constexpr const float kAlmostZero = 1e-8f; + +Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) { + params->num_distance_bands = + br->ReadFixedBits<DctQuantWeightParams::kLog2MaxDistanceBands>() + 1; + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < params->num_distance_bands; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, ¶ms->distance_bands[c][i])); + } + if (params->distance_bands[c][0] < kAlmostZero) { + return JXL_FAILURE("Distance band seed is too small"); + } + params->distance_bands[c][0] *= 64.0f; + } + return true; +} + +Status Decode(BitReader* br, QuantEncoding* encoding, size_t required_size_x, + size_t required_size_y, size_t idx, + ModularFrameDecoder* modular_frame_decoder) { + size_t required_size = required_size_x * required_size_y; + required_size_x *= kBlockDim; + required_size_y *= kBlockDim; + int mode = br->ReadFixedBits<kLog2NumQuantModes>(); + switch (mode) { + case QuantEncoding::kQuantModeLibrary: { + encoding->predefined = br->ReadFixedBits<kCeilLog2NumPredefinedTables>(); + if (encoding->predefined >= kNumPredefinedTables) { + return JXL_FAILURE("Invalid predefined table"); + } + break; + } + case QuantEncoding::kQuantModeID: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 3; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->idweights[c][i])); + if (std::abs(encoding->idweights[c][i]) < kAlmostZero) { + return JXL_FAILURE("ID Quantizer is too small"); + } + encoding->idweights[c][i] *= 64; + } + } + break; + } + case QuantEncoding::kQuantModeDCT2: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 6; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->dct2weights[c][i])); + if (std::abs(encoding->dct2weights[c][i]) < kAlmostZero) { + return JXL_FAILURE("Quantizer is too small"); + } + encoding->dct2weights[c][i] *= 64; + } + } + break; + } + case QuantEncoding::kQuantModeDCT4X8: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + JXL_RETURN_IF_ERROR( + F16Coder::Read(br, &encoding->dct4x8multipliers[c])); + if (std::abs(encoding->dct4x8multipliers[c]) < kAlmostZero) { + return JXL_FAILURE("DCT4X8 multiplier is too small"); + } + } + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + break; + } + case QuantEncoding::kQuantModeDCT4: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 2; i++) { + JXL_RETURN_IF_ERROR( + F16Coder::Read(br, &encoding->dct4multipliers[c][i])); + if (std::abs(encoding->dct4multipliers[c][i]) < kAlmostZero) { + return JXL_FAILURE("DCT4 multiplier is too small"); + } + } + } + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + break; + } + case QuantEncoding::kQuantModeAFV: { + if (required_size != 1) return JXL_FAILURE("Invalid mode"); + for (size_t c = 0; c < 3; c++) { + for (size_t i = 0; i < 9; i++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->afv_weights[c][i])); + } + for (size_t i = 0; i < 6; i++) { + encoding->afv_weights[c][i] *= 64; + } + } + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params_afv_4x4)); + break; + } + case QuantEncoding::kQuantModeDCT: { + JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); + break; + } + case QuantEncoding::kQuantModeRAW: { + // Set mode early, to avoid mem-leak. + encoding->mode = QuantEncoding::kQuantModeRAW; + JXL_RETURN_IF_ERROR(ModularFrameDecoder::DecodeQuantTable( + required_size_x, required_size_y, br, encoding, idx, + modular_frame_decoder)); + break; + } + default: + return JXL_FAILURE("Invalid quantization table encoding"); + } + encoding->mode = QuantEncoding::Mode(mode); + return true; +} + +} // namespace + +// These definitions are needed before C++17. +constexpr size_t DequantMatrices::required_size_[]; +constexpr size_t DequantMatrices::required_size_x[]; +constexpr size_t DequantMatrices::required_size_y[]; +constexpr DequantMatrices::QuantTable DequantMatrices::kQuantTable[]; + +Status DequantMatrices::Decode(BitReader* br, + ModularFrameDecoder* modular_frame_decoder) { + size_t all_default = br->ReadBits(1); + size_t num_tables = all_default ? 0 : static_cast<size_t>(kNum); + encodings_.clear(); + encodings_.resize(kNum, QuantEncoding::Library(0)); + for (size_t i = 0; i < num_tables; i++) { + JXL_RETURN_IF_ERROR( + jxl::Decode(br, &encodings_[i], required_size_x[i % kNum], + required_size_y[i % kNum], i, modular_frame_decoder)); + } + computed_mask_ = 0; + return true; +} + +Status DequantMatrices::DecodeDC(BitReader* br) { + bool all_default = br->ReadBits(1); + if (!br->AllReadsWithinBounds()) return JXL_FAILURE("EOS during DecodeDC"); + if (!all_default) { + for (size_t c = 0; c < 3; c++) { + JXL_RETURN_IF_ERROR(F16Coder::Read(br, &dc_quant_[c])); + dc_quant_[c] *= 1.0f / 128.0f; + // Negative values and nearly zero are invalid values. + if (dc_quant_[c] < kAlmostZero) { + return JXL_FAILURE("Invalid dc_quant: coefficient is too small."); + } + inv_dc_quant_[c] = 1.0f / dc_quant_[c]; + } + } + return true; +} + +constexpr float V(float v) { return static_cast<float>(v); } + +namespace { +struct DequantMatricesLibraryDef { + // DCT8 + static constexpr QuantEncodingInternal DCT() { + return QuantEncodingInternal::DCT(DctQuantWeightParams({{{{ + V(3150.0), + V(0.0), + V(-0.4), + V(-0.4), + V(-0.4), + V(-2.0), + }}, + {{ + V(560.0), + V(0.0), + V(-0.3), + V(-0.3), + V(-0.3), + V(-0.3), + }}, + {{ + V(512.0), + V(-2.0), + V(-1.0), + V(0.0), + V(-1.0), + V(-2.0), + }}}}, + 6)); + } + + // Identity + static constexpr QuantEncodingInternal IDENTITY() { + return QuantEncodingInternal::Identity({{{{ + V(280.0), + V(3160.0), + V(3160.0), + }}, + {{ + V(60.0), + V(864.0), + V(864.0), + }}, + {{ + V(18.0), + V(200.0), + V(200.0), + }}}}); + } + + // DCT2 + static constexpr QuantEncodingInternal DCT2X2() { + return QuantEncodingInternal::DCT2({{{{ + V(3840.0), + V(2560.0), + V(1280.0), + V(640.0), + V(480.0), + V(300.0), + }}, + {{ + V(960.0), + V(640.0), + V(320.0), + V(180.0), + V(140.0), + V(120.0), + }}, + {{ + V(640.0), + V(320.0), + V(128.0), + V(64.0), + V(32.0), + V(16.0), + }}}}); + } + + // DCT4 (quant_kind 3) + static constexpr QuantEncodingInternal DCT4X4() { + return QuantEncodingInternal::DCT4(DctQuantWeightParams({{{{ + V(2200.0), + V(0.0), + V(0.0), + V(0.0), + }}, + {{ + V(392.0), + V(0.0), + V(0.0), + V(0.0), + }}, + {{ + V(112.0), + V(-0.25), + V(-0.25), + V(-0.5), + }}}}, + 4), + /* kMul */ + {{{{ + V(1.0), + V(1.0), + }}, + {{ + V(1.0), + V(1.0), + }}, + {{ + V(1.0), + V(1.0), + }}}}); + } + + // DCT16 + static constexpr QuantEncodingInternal DCT16X16() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(8996.8725711814115328), + V(-1.3000777393353804), + V(-0.49424529824571225), + V(-0.439093774457103443), + V(-0.6350101832695744), + V(-0.90177264050827612), + V(-1.6162099239887414), + }}, + {{ + V(3191.48366296844234752), + V(-0.67424582104194355), + V(-0.80745813428471001), + V(-0.44925837484843441), + V(-0.35865440981033403), + V(-0.31322389111877305), + V(-0.37615025315725483), + }}, + {{ + V(1157.50408145487200256), + V(-2.0531423165804414), + V(-1.4), + V(-0.50687130033378396), + V(-0.42708730624733904), + V(-1.4856834539296244), + V(-4.9209142884401604), + }}}}, + 7)); + } + + // DCT32 + static constexpr QuantEncodingInternal DCT32X32() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(15718.40830982518931456), + V(-1.025), + V(-0.98), + V(-0.9012), + V(-0.4), + V(-0.48819395464), + V(-0.421064), + V(-0.27), + }}, + {{ + V(7305.7636810695983104), + V(-0.8041958212306401), + V(-0.7633036457487539), + V(-0.55660379990111464), + V(-0.49785304658857626), + V(-0.43699592683512467), + V(-0.40180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(3803.53173721215041536), + V(-3.060733579805728), + V(-2.0413270132490346), + V(-2.0235650159727417), + V(-0.5495389509954993), + V(-0.4), + V(-0.4), + V(-0.3), + }}}}, + 8)); + } + + // DCT16X8 + static constexpr QuantEncodingInternal DCT8X16() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(7240.7734393502), + V(-0.7), + V(-0.7), + V(-0.2), + V(-0.2), + V(-0.2), + V(-0.5), + }}, + {{ + V(1448.15468787004), + V(-0.5), + V(-0.5), + V(-0.5), + V(-0.2), + V(-0.2), + V(-0.2), + }}, + {{ + V(506.854140754517), + V(-1.4), + V(-0.2), + V(-0.5), + V(-0.5), + V(-1.5), + V(-3.6), + }}}}, + 7)); + } + + // DCT32X8 + static constexpr QuantEncodingInternal DCT8X32() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(16283.2494710648897), + V(-1.7812845336559429), + V(-1.6309059012653515), + V(-1.0382179034313539), + V(-0.85), + V(-0.7), + V(-0.9), + V(-1.2360638576849587), + }}, + {{ + V(5089.15750884921511936), + V(-0.320049391452786891), + V(-0.35362849922161446), + V(-0.30340000000000003), + V(-0.61), + V(-0.5), + V(-0.5), + V(-0.6), + }}, + {{ + V(3397.77603275308720128), + V(-0.321327362693153371), + V(-0.34507619223117997), + V(-0.70340000000000003), + V(-0.9), + V(-1.0), + V(-1.0), + V(-1.1754605576265209), + }}}}, + 8)); + } + + // DCT32X16 + static constexpr QuantEncodingInternal DCT16X32() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(13844.97076442300573), + V(-0.97113799999999995), + V(-0.658), + V(-0.42026), + V(-0.22712), + V(-0.2206), + V(-0.226), + V(-0.6), + }}, + {{ + V(4798.964084220744293), + V(-0.61125308982767057), + V(-0.83770786552491361), + V(-0.79014862079498627), + V(-0.2692727459704829), + V(-0.38272769465388551), + V(-0.22924222653091453), + V(-0.20719098826199578), + }}, + {{ + V(1807.236946760964614), + V(-1.2), + V(-1.2), + V(-0.7), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + + // DCT4X8 and 8x4 + static constexpr QuantEncodingInternal DCT4X8() { + return QuantEncodingInternal::DCT4X8( + DctQuantWeightParams({{ + {{ + V(2198.050556016380522), + V(-0.96269623020744692), + V(-0.76194253026666783), + V(-0.6551140670773547), + }}, + {{ + V(764.3655248643528689), + V(-0.92630200888366945), + V(-0.9675229603596517), + V(-0.27845290869168118), + }}, + {{ + V(527.107573587542228), + V(-1.4594385811273854), + V(-1.450082094097871593), + V(-1.5843722511996204), + }}, + }}, + 4), + /* kMuls */ + {{ + V(1.0), + V(1.0), + V(1.0), + }}); + } + // AFV + static QuantEncodingInternal AFV0() { + return QuantEncodingInternal::AFV(DCT4X8().dct_params, DCT4X4().dct_params, + {{{{ + // 4x4/4x8 DC tendency. + V(3072.0), + V(3072.0), + // AFV corner. + V(256.0), + V(256.0), + V(256.0), + // AFV high freqs. + V(414.0), + V(0.0), + V(0.0), + V(0.0), + }}, + {{ + // 4x4/4x8 DC tendency. + V(1024.0), + V(1024.0), + // AFV corner. + V(50), + V(50), + V(50), + // AFV high freqs. + V(58.0), + V(0.0), + V(0.0), + V(0.0), + }}, + {{ + // 4x4/4x8 DC tendency. + V(384.0), + V(384.0), + // AFV corner. + V(12.0), + V(12.0), + V(12.0), + // AFV high freqs. + V(22.0), + V(-0.25), + V(-0.25), + V(-0.25), + }}}}); + } + + // DCT64 + static QuantEncodingInternal DCT64X64() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(0.9 * 26629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(0.9 * 9311.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(0.9 * 4992.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + + // DCT64X32 + static QuantEncodingInternal DCT32X64() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(0.65 * 23629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(0.65 * 8611.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(0.65 * 4492.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + // DCT128X128 + static QuantEncodingInternal DCT128X128() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(1.8 * 26629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(1.8 * 9311.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(1.8 * 4992.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + + // DCT128X64 + static QuantEncodingInternal DCT64X128() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(1.3 * 23629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(1.3 * 8611.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(1.3 * 4492.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + // DCT256X256 + static QuantEncodingInternal DCT256X256() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(3.6 * 26629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(3.6 * 9311.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(3.6 * 4992.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } + + // DCT256X128 + static QuantEncodingInternal DCT128X256() { + return QuantEncodingInternal::DCT( + DctQuantWeightParams({{{{ + V(2.6 * 23629.073922049845), + V(-1.025), + V(-0.78), + V(-0.65012), + V(-0.19041574084286472), + V(-0.20819395464), + V(-0.421064), + V(-0.32733845535848671), + }}, + {{ + V(2.6 * 8611.3238710010046), + V(-0.3041958212306401), + V(-0.3633036457487539), + V(-0.35660379990111464), + V(-0.3443074455424403), + V(-0.33699592683512467), + V(-0.30180866526242109), + V(-0.27321683125358037), + }}, + {{ + V(2.6 * 4492.2486445538634), + V(-1.2), + V(-1.2), + V(-0.8), + V(-0.7), + V(-0.7), + V(-0.4), + V(-0.5), + }}}}, + 8)); + } +}; +} // namespace + +DequantMatrices::DequantLibraryInternal DequantMatrices::LibraryInit() { + static_assert(kNum == 17, + "Update this function when adding new quantization kinds."); + static_assert(kNumPredefinedTables == 1, + "Update this function when adding new quantization matrices to " + "the library."); + + // The library and the indices need to be kept in sync manually. + static_assert(0 == DCT, "Update the DequantLibrary array below."); + static_assert(1 == IDENTITY, "Update the DequantLibrary array below."); + static_assert(2 == DCT2X2, "Update the DequantLibrary array below."); + static_assert(3 == DCT4X4, "Update the DequantLibrary array below."); + static_assert(4 == DCT16X16, "Update the DequantLibrary array below."); + static_assert(5 == DCT32X32, "Update the DequantLibrary array below."); + static_assert(6 == DCT8X16, "Update the DequantLibrary array below."); + static_assert(7 == DCT8X32, "Update the DequantLibrary array below."); + static_assert(8 == DCT16X32, "Update the DequantLibrary array below."); + static_assert(9 == DCT4X8, "Update the DequantLibrary array below."); + static_assert(10 == AFV0, "Update the DequantLibrary array below."); + static_assert(11 == DCT64X64, "Update the DequantLibrary array below."); + static_assert(12 == DCT32X64, "Update the DequantLibrary array below."); + static_assert(13 == DCT128X128, "Update the DequantLibrary array below."); + static_assert(14 == DCT64X128, "Update the DequantLibrary array below."); + static_assert(15 == DCT256X256, "Update the DequantLibrary array below."); + static_assert(16 == DCT128X256, "Update the DequantLibrary array below."); + return DequantMatrices::DequantLibraryInternal{{ + DequantMatricesLibraryDef::DCT(), + DequantMatricesLibraryDef::IDENTITY(), + DequantMatricesLibraryDef::DCT2X2(), + DequantMatricesLibraryDef::DCT4X4(), + DequantMatricesLibraryDef::DCT16X16(), + DequantMatricesLibraryDef::DCT32X32(), + DequantMatricesLibraryDef::DCT8X16(), + DequantMatricesLibraryDef::DCT8X32(), + DequantMatricesLibraryDef::DCT16X32(), + DequantMatricesLibraryDef::DCT4X8(), + DequantMatricesLibraryDef::AFV0(), + DequantMatricesLibraryDef::DCT64X64(), + DequantMatricesLibraryDef::DCT32X64(), + // Same default for large transforms (128+) as for 64x* transforms. + DequantMatricesLibraryDef::DCT128X128(), + DequantMatricesLibraryDef::DCT64X128(), + DequantMatricesLibraryDef::DCT256X256(), + DequantMatricesLibraryDef::DCT128X256(), + }}; +} + +const QuantEncoding* DequantMatrices::Library() { + static const DequantMatrices::DequantLibraryInternal kDequantLibrary = + DequantMatrices::LibraryInit(); + // Downcast the result to a const QuantEncoding* from QuantEncodingInternal* + // since the subclass (QuantEncoding) doesn't add any new members and users + // will need to upcast to QuantEncodingInternal to access the members of that + // class. This allows to have kDequantLibrary as a constexpr value while still + // allowing to create QuantEncoding::RAW() instances that use std::vector in + // C++11. + return reinterpret_cast<const QuantEncoding*>(kDequantLibrary.data()); +} + +DequantMatrices::DequantMatrices() { + encodings_.resize(size_t(QuantTable::kNum), QuantEncoding::Library(0)); + size_t pos = 0; + size_t offsets[kNum * 3]; + for (size_t i = 0; i < size_t(QuantTable::kNum); i++) { + size_t num = required_size_[i] * kDCTBlockSize; + for (size_t c = 0; c < 3; c++) { + offsets[3 * i + c] = pos + c * num; + } + pos += 3 * num; + } + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + for (size_t c = 0; c < 3; c++) { + table_offsets_[i * 3 + c] = offsets[kQuantTable[i] * 3 + c]; + } + } +} + +Status DequantMatrices::EnsureComputed(uint32_t acs_mask) { + const QuantEncoding* library = Library(); + + if (!table_storage_) { + table_storage_ = hwy::AllocateAligned<float>(2 * kTotalTableSize); + table_ = table_storage_.get(); + inv_table_ = table_storage_.get() + kTotalTableSize; + } + + size_t offsets[kNum * 3 + 1]; + size_t pos = 0; + for (size_t i = 0; i < kNum; i++) { + size_t num = required_size_[i] * kDCTBlockSize; + for (size_t c = 0; c < 3; c++) { + offsets[3 * i + c] = pos + c * num; + } + pos += 3 * num; + } + offsets[kNum * 3] = pos; + JXL_ASSERT(pos == kTotalTableSize); + + uint32_t kind_mask = 0; + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + if (acs_mask & (1u << i)) { + kind_mask |= 1u << kQuantTable[i]; + } + } + uint32_t computed_kind_mask = 0; + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + if (computed_mask_ & (1u << i)) { + computed_kind_mask |= 1u << kQuantTable[i]; + } + } + for (size_t table = 0; table < kNum; table++) { + if ((1 << table) & computed_kind_mask) continue; + if ((1 << table) & ~kind_mask) continue; + size_t pos = offsets[table * 3]; + if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) { + JXL_CHECK(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( + library[table], table_storage_.get(), + table_storage_.get() + kTotalTableSize, table, QuantTable(table), + &pos)); + } else { + JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( + encodings_[table], table_storage_.get(), + table_storage_.get() + kTotalTableSize, table, QuantTable(table), + &pos)); + } + JXL_ASSERT(pos == offsets[table * 3 + 3]); + } + computed_mask_ |= acs_mask; + + return true; +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/quant_weights.h b/third_party/jpeg-xl/lib/jxl/quant_weights.h new file mode 100644 index 0000000000..3004176aba --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quant_weights.h @@ -0,0 +1,446 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_QUANT_WEIGHTS_H_ +#define LIB_JXL_QUANT_WEIGHTS_H_ + +#include <stdint.h> +#include <string.h> + +#include <array> +#include <hwy/aligned_allocator.h> +#include <utility> +#include <vector> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/image.h" + +namespace jxl { + +template <typename T, size_t N> +constexpr T ArraySum(T (&a)[N], size_t i = N - 1) { + static_assert(N > 0, "Trying to compute the sum of an empty array"); + return i == 0 ? a[0] : a[i] + ArraySum(a, i - 1); +} + +static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea; +static constexpr size_t kNumPredefinedTables = 1; +static constexpr size_t kCeilLog2NumPredefinedTables = 0; +static constexpr size_t kLog2NumQuantModes = 3; + +struct DctQuantWeightParams { + static constexpr size_t kLog2MaxDistanceBands = 4; + static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands); + typedef std::array<std::array<float, kMaxDistanceBands>, 3> + DistanceBandsArray; + + size_t num_distance_bands = 0; + DistanceBandsArray distance_bands = {}; + + constexpr DctQuantWeightParams() : num_distance_bands(0) {} + + constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands, + size_t num_dist_bands) + : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {} + + template <size_t num_dist_bands> + explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) { + num_distance_bands = num_dist_bands; + for (size_t c = 0; c < 3; c++) { + memcpy(distance_bands[c].data(), dist_bands[c], + sizeof(float) * num_dist_bands); + } + } +}; + +// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) +struct QuantEncodingInternal { + enum Mode { + kQuantModeLibrary, + kQuantModeID, + kQuantModeDCT2, + kQuantModeDCT4, + kQuantModeDCT4X8, + kQuantModeAFV, + kQuantModeDCT, + kQuantModeRAW, + }; + + template <Mode mode> + struct Tag {}; + + typedef std::array<std::array<float, 3>, 3> IdWeights; + typedef std::array<std::array<float, 6>, 3> DCT2Weights; + typedef std::array<std::array<float, 2>, 3> DCT4Multipliers; + typedef std::array<std::array<float, 9>, 3> AFVWeights; + typedef std::array<float, 3> DCT4x8Multipliers; + + static constexpr QuantEncodingInternal Library(uint8_t predefined) { + return ((predefined < kNumPredefinedTables) || + JXL_ABORT("Assert predefined < kNumPredefinedTables")), + QuantEncodingInternal(Tag<kQuantModeLibrary>(), predefined); + } + constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */, + uint8_t predefined) + : mode(kQuantModeLibrary), predefined(predefined) {} + + // Identity + // xybweights is an array of {xweights, yweights, bweights}. + static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) { + return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights); + } + constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */, + const IdWeights& xybweights) + : mode(kQuantModeID), idweights(xybweights) {} + + // DCT2 + static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) { + return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights); + } + constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */, + const DCT2Weights& xybweights) + : mode(kQuantModeDCT2), dct2weights(xybweights) {} + + // DCT4 + static constexpr QuantEncodingInternal DCT4( + const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) { + return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul); + } + constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */, + const DctQuantWeightParams& params, + const DCT4Multipliers& xybmul) + : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {} + + // DCT4x8 + static constexpr QuantEncodingInternal DCT4X8( + const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) { + return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul); + } + constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */, + const DctQuantWeightParams& params, + const DCT4x8Multipliers& xybmul) + : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {} + + // DCT + static constexpr QuantEncodingInternal DCT( + const DctQuantWeightParams& params) { + return QuantEncodingInternal(Tag<kQuantModeDCT>(), params); + } + constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */, + const DctQuantWeightParams& params) + : mode(kQuantModeDCT), dct_params(params) {} + + // AFV + static constexpr QuantEncodingInternal AFV( + const DctQuantWeightParams& params4x8, + const DctQuantWeightParams& params4x4, const AFVWeights& weights) { + return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4, + weights); + } + constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */, + const DctQuantWeightParams& params4x8, + const DctQuantWeightParams& params4x4, + const AFVWeights& weights) + : mode(kQuantModeAFV), + dct_params(params4x8), + afv_weights(weights), + dct_params_afv_4x4(params4x4) {} + + // This constructor is not constexpr so it can't be used in any of the + // constexpr cases above. + explicit QuantEncodingInternal(Mode mode) : mode(mode) {} + + Mode mode; + + // Weights for DCT4+ tables. + DctQuantWeightParams dct_params; + + union { + // Weights for identity. + IdWeights idweights; + + // Weights for DCT2. + DCT2Weights dct2weights; + + // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV. + DCT4Multipliers dct4multipliers; + + // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1, + // 0); {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) + + // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated + // as in GetQuantWeights for DC and are used for other coefficients. + AFVWeights afv_weights = {}; + + // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4. + DCT4x8Multipliers dct4x8multipliers; + + // Only used in kQuantModeRAW mode. + struct { + // explicit quantization table (like in JPEG) + std::vector<int>* qtable = nullptr; + float qtable_den = 1.f / (8 * 255); + } qraw; + }; + + // Weights for 4x4 sub-block in AFV. + DctQuantWeightParams dct_params_afv_4x4; + + union { + // Which predefined table to use. Only used if mode is kQuantModeLibrary. + uint8_t predefined = 0; + + // Which other quant table to copy; must copy from a table that comes before + // the current one. Only used if mode is kQuantModeCopy. + uint8_t source; + }; +}; + +class QuantEncoding final : public QuantEncodingInternal { + public: + QuantEncoding(const QuantEncoding& other) + : QuantEncodingInternal( + static_cast<const QuantEncodingInternal&>(other)) { + if (mode == kQuantModeRAW && qraw.qtable) { + // Need to make a copy of the passed *qtable. + qraw.qtable = new std::vector<int>(*other.qraw.qtable); + } + } + QuantEncoding(QuantEncoding&& other) noexcept + : QuantEncodingInternal( + static_cast<const QuantEncodingInternal&>(other)) { + // Steal the qtable from the other object if any. + if (mode == kQuantModeRAW) { + other.qraw.qtable = nullptr; + } + } + QuantEncoding& operator=(const QuantEncoding& other) { + if (mode == kQuantModeRAW && qraw.qtable) { + delete qraw.qtable; + } + *static_cast<QuantEncodingInternal*>(this) = + QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other)); + if (mode == kQuantModeRAW && qraw.qtable) { + // Need to make a copy of the passed *qtable. + qraw.qtable = new std::vector<int>(*other.qraw.qtable); + } + return *this; + } + + ~QuantEncoding() { + if (mode == kQuantModeRAW && qraw.qtable) { + delete qraw.qtable; + } + } + + // Wrappers of the QuantEncodingInternal:: static functions that return a + // QuantEncoding instead. This is using the explicit and private cast from + // QuantEncodingInternal to QuantEncoding, which would be inlined anyway. + // In general, you should use this wrappers. The only reason to directly + // create a QuantEncodingInternal instance is if you need a constexpr version + // of this class. Note that RAW() is not supported in that case since it uses + // a std::vector. + static QuantEncoding Library(uint8_t predefined_arg) { + return QuantEncoding(QuantEncodingInternal::Library(predefined_arg)); + } + static QuantEncoding Identity(const IdWeights& xybweights) { + return QuantEncoding(QuantEncodingInternal::Identity(xybweights)); + } + static QuantEncoding DCT2(const DCT2Weights& xybweights) { + return QuantEncoding(QuantEncodingInternal::DCT2(xybweights)); + } + static QuantEncoding DCT4(const DctQuantWeightParams& params, + const DCT4Multipliers& xybmul) { + return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul)); + } + static QuantEncoding DCT4X8(const DctQuantWeightParams& params, + const DCT4x8Multipliers& xybmul) { + return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul)); + } + static QuantEncoding DCT(const DctQuantWeightParams& params) { + return QuantEncoding(QuantEncodingInternal::DCT(params)); + } + static QuantEncoding AFV(const DctQuantWeightParams& params4x8, + const DctQuantWeightParams& params4x4, + const AFVWeights& weights) { + return QuantEncoding( + QuantEncodingInternal::AFV(params4x8, params4x4, weights)); + } + + // RAW, note that this one is not a constexpr one. + static QuantEncoding RAW(const std::vector<int>& qtable, int shift = 0) { + QuantEncoding encoding(kQuantModeRAW); + encoding.qraw.qtable = new std::vector<int>(); + *encoding.qraw.qtable = qtable; + encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255)); + return encoding; + } + + private: + explicit QuantEncoding(const QuantEncodingInternal& other) + : QuantEncodingInternal(other) {} + + explicit QuantEncoding(QuantEncodingInternal::Mode mode_arg) + : QuantEncodingInternal(mode_arg) {} +}; + +// A constexpr QuantEncodingInternal instance is often downcasted to the +// QuantEncoding subclass even if the instance wasn't an instance of the +// subclass. This is safe because user will upcast to QuantEncodingInternal to +// access any of its members. +static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal), + "Don't add any members to QuantEncoding"); + +// Let's try to keep these 2**N for possible future simplicity. +const float kInvDCQuant[3] = { + 4096.0f, + 512.0f, + 256.0f, +}; + +const float kDCQuant[3] = { + 1.0f / kInvDCQuant[0], + 1.0f / kInvDCQuant[1], + 1.0f / kInvDCQuant[2], +}; + +class ModularFrameEncoder; +class ModularFrameDecoder; + +class DequantMatrices { + public: + enum QuantTable : size_t { + DCT = 0, + IDENTITY, + DCT2X2, + DCT4X4, + DCT16X16, + DCT32X32, + // DCT16X8 + DCT8X16, + // DCT32X8 + DCT8X32, + // DCT32X16 + DCT16X32, + DCT4X8, + // DCT8X4 + AFV0, + // AFV1 + // AFV2 + // AFV3 + DCT64X64, + // DCT64X32, + DCT32X64, + DCT128X128, + // DCT128X64, + DCT64X128, + DCT256X256, + // DCT256X128, + DCT128X256, + kNum + }; + + static constexpr QuantTable kQuantTable[] = { + QuantTable::DCT, QuantTable::IDENTITY, QuantTable::DCT2X2, + QuantTable::DCT4X4, QuantTable::DCT16X16, QuantTable::DCT32X32, + QuantTable::DCT8X16, QuantTable::DCT8X16, QuantTable::DCT8X32, + QuantTable::DCT8X32, QuantTable::DCT16X32, QuantTable::DCT16X32, + QuantTable::DCT4X8, QuantTable::DCT4X8, QuantTable::AFV0, + QuantTable::AFV0, QuantTable::AFV0, QuantTable::AFV0, + QuantTable::DCT64X64, QuantTable::DCT32X64, QuantTable::DCT32X64, + QuantTable::DCT128X128, QuantTable::DCT64X128, QuantTable::DCT64X128, + QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256, + }; + static_assert(AcStrategy::kNumValidStrategies == + sizeof(kQuantTable) / sizeof *kQuantTable, + "Update this array when adding or removing AC strategies."); + + DequantMatrices(); + + static const QuantEncoding* Library(); + + typedef std::array<QuantEncodingInternal, kNumPredefinedTables * kNum> + DequantLibraryInternal; + // Return the array of library kNumPredefinedTables QuantEncoding entries as + // a constexpr array. Use Library() to obtain a pointer to the copy in the + // .cc file. + static DequantLibraryInternal LibraryInit(); + + // Returns aligned memory. + JXL_INLINE const float* Matrix(size_t quant_kind, size_t c) const { + JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); + JXL_DASSERT((1 << quant_kind) & computed_mask_); + return &table_[table_offsets_[quant_kind * 3 + c]]; + } + + JXL_INLINE const float* InvMatrix(size_t quant_kind, size_t c) const { + JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); + JXL_DASSERT((1 << quant_kind) & computed_mask_); + return &inv_table_[table_offsets_[quant_kind * 3 + c]]; + } + + // DC quants are used in modular mode for XYB multipliers. + JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; } + JXL_INLINE const float* DCQuants() const { return dc_quant_; } + + JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; } + + // For encoder. + void SetEncodings(const std::vector<QuantEncoding>& encodings) { + encodings_ = encodings; + computed_mask_ = 0; + } + + // For encoder. + void SetDCQuant(const float dc[3]) { + for (size_t c = 0; c < 3; c++) { + dc_quant_[c] = 1.0f / dc[c]; + inv_dc_quant_[c] = dc[c]; + } + } + + Status Decode(BitReader* br, + ModularFrameDecoder* modular_frame_decoder = nullptr); + Status DecodeDC(BitReader* br); + + const std::vector<QuantEncoding>& encodings() const { return encodings_; } + + static constexpr size_t required_size_x[] = {1, 1, 1, 1, 2, 4, 1, 1, 2, + 1, 1, 8, 4, 16, 8, 32, 16}; + static_assert(kNum == sizeof(required_size_x) / sizeof(*required_size_x), + "Update this array when adding or removing quant tables."); + + static constexpr size_t required_size_y[] = {1, 1, 1, 1, 2, 4, 2, 4, 4, + 1, 1, 8, 8, 16, 16, 32, 32}; + static_assert(kNum == sizeof(required_size_y) / sizeof(*required_size_y), + "Update this array when adding or removing quant tables."); + + Status EnsureComputed(uint32_t acs_mask); + + private: + static constexpr size_t required_size_[] = { + 1, 1, 1, 1, 4, 16, 2, 4, 8, 1, 1, 64, 32, 256, 128, 1024, 512}; + static_assert(kNum == sizeof(required_size_) / sizeof(*required_size_), + "Update this array when adding or removing quant tables."); + static constexpr size_t kTotalTableSize = + ArraySum(required_size_) * kDCTBlockSize * 3; + + uint32_t computed_mask_ = 0; + // kTotalTableSize entries followed by kTotalTableSize for inv_table + hwy::AlignedFreeUniquePtr<float[]> table_storage_; + const float* table_; + const float* inv_table_; + float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]}; + float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]}; + size_t table_offsets_[AcStrategy::kNumValidStrategies * 3]; + std::vector<QuantEncoding> encodings_; +}; + +} // namespace jxl + +#endif // LIB_JXL_QUANT_WEIGHTS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc b/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc new file mode 100644 index 0000000000..2dd513804c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quant_weights_test.cc @@ -0,0 +1,241 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +#include "lib/jxl/quant_weights.h" + +#include <stdlib.h> + +#include <algorithm> +#include <cmath> +#include <hwy/base.h> // HWY_ALIGN_MAX +#include <hwy/tests/hwy_gtest.h> +#include <numeric> + +#include "lib/jxl/base/random.h" +#include "lib/jxl/dct_for_test.h" +#include "lib/jxl/dec_transforms_testonly.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_quant_weights.h" +#include "lib/jxl/enc_transforms.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +template <typename T> +void CheckSimilar(T a, T b) { + EXPECT_EQ(a, b); +} +// minimum exponent = -15. +template <> +void CheckSimilar(float a, float b) { + float m = std::max(std::abs(a), std::abs(b)); + // 10 bits of precision are used in the format. Relative error should be + // below 2^-10. + EXPECT_LE(std::abs(a - b), m / 1024.0f) << "a: " << a << " b: " << b; +} + +TEST(QuantWeightsTest, DC) { + DequantMatrices mat; + float dc_quant[3] = {1e+5, 1e+3, 1e+1}; + DequantMatricesSetCustomDC(&mat, dc_quant); + for (size_t c = 0; c < 3; c++) { + CheckSimilar(mat.InvDCQuant(c), dc_quant[c]); + } +} + +void RoundtripMatrices(const std::vector<QuantEncoding>& encodings) { + ASSERT_TRUE(encodings.size() == DequantMatrices::kNum); + DequantMatrices mat; + CodecMetadata metadata; + FrameHeader frame_header(&metadata); + ModularFrameEncoder encoder(frame_header, CompressParams{}); + DequantMatricesSetCustom(&mat, encodings, &encoder); + const std::vector<QuantEncoding>& encodings_dec = mat.encodings(); + for (size_t i = 0; i < encodings.size(); i++) { + const QuantEncoding& e = encodings[i]; + const QuantEncoding& d = encodings_dec[i]; + // Check values roundtripped correctly. + EXPECT_EQ(e.mode, d.mode); + EXPECT_EQ(e.predefined, d.predefined); + EXPECT_EQ(e.source, d.source); + + EXPECT_EQ(static_cast<uint64_t>(e.dct_params.num_distance_bands), + static_cast<uint64_t>(d.dct_params.num_distance_bands)); + for (size_t c = 0; c < 3; c++) { + for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) { + CheckSimilar(e.dct_params.distance_bands[c][j], + d.dct_params.distance_bands[c][j]); + } + } + + if (e.mode == QuantEncoding::kQuantModeRAW) { + EXPECT_FALSE(!e.qraw.qtable); + EXPECT_FALSE(!d.qraw.qtable); + EXPECT_EQ(e.qraw.qtable->size(), d.qraw.qtable->size()); + for (size_t j = 0; j < e.qraw.qtable->size(); j++) { + EXPECT_EQ((*e.qraw.qtable)[j], (*d.qraw.qtable)[j]); + } + EXPECT_NEAR(e.qraw.qtable_den, d.qraw.qtable_den, 1e-7f); + } else { + // modes different than kQuantModeRAW use one of the other fields used + // here, which all happen to be arrays of floats. + for (size_t c = 0; c < 3; c++) { + for (size_t j = 0; j < 3; j++) { + CheckSimilar(e.idweights[c][j], d.idweights[c][j]); + } + for (size_t j = 0; j < 6; j++) { + CheckSimilar(e.dct2weights[c][j], d.dct2weights[c][j]); + } + for (size_t j = 0; j < 2; j++) { + CheckSimilar(e.dct4multipliers[c][j], d.dct4multipliers[c][j]); + } + CheckSimilar(e.dct4x8multipliers[c], d.dct4x8multipliers[c]); + for (size_t j = 0; j < 9; j++) { + CheckSimilar(e.afv_weights[c][j], d.afv_weights[c][j]); + } + for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) { + CheckSimilar(e.dct_params_afv_4x4.distance_bands[c][j], + d.dct_params_afv_4x4.distance_bands[c][j]); + } + } + } + } +} + +TEST(QuantWeightsTest, AllDefault) { + std::vector<QuantEncoding> encodings(DequantMatrices::kNum, + QuantEncoding::Library(0)); + RoundtripMatrices(encodings); +} + +void TestSingleQuantMatrix(DequantMatrices::QuantTable kind) { + std::vector<QuantEncoding> encodings(DequantMatrices::kNum, + QuantEncoding::Library(0)); + encodings[kind] = DequantMatrices::Library()[kind]; + RoundtripMatrices(encodings); +} + +// Ensure we can reasonably represent default quant tables. +TEST(QuantWeightsTest, DCT) { TestSingleQuantMatrix(DequantMatrices::DCT); } +TEST(QuantWeightsTest, IDENTITY) { + TestSingleQuantMatrix(DequantMatrices::IDENTITY); +} +TEST(QuantWeightsTest, DCT2X2) { + TestSingleQuantMatrix(DequantMatrices::DCT2X2); +} +TEST(QuantWeightsTest, DCT4X4) { + TestSingleQuantMatrix(DequantMatrices::DCT4X4); +} +TEST(QuantWeightsTest, DCT16X16) { + TestSingleQuantMatrix(DequantMatrices::DCT16X16); +} +TEST(QuantWeightsTest, DCT32X32) { + TestSingleQuantMatrix(DequantMatrices::DCT32X32); +} +TEST(QuantWeightsTest, DCT8X16) { + TestSingleQuantMatrix(DequantMatrices::DCT8X16); +} +TEST(QuantWeightsTest, DCT8X32) { + TestSingleQuantMatrix(DequantMatrices::DCT8X32); +} +TEST(QuantWeightsTest, DCT16X32) { + TestSingleQuantMatrix(DequantMatrices::DCT16X32); +} +TEST(QuantWeightsTest, DCT4X8) { + TestSingleQuantMatrix(DequantMatrices::DCT4X8); +} +TEST(QuantWeightsTest, AFV0) { TestSingleQuantMatrix(DequantMatrices::AFV0); } +TEST(QuantWeightsTest, RAW) { + std::vector<QuantEncoding> encodings(DequantMatrices::kNum, + QuantEncoding::Library(0)); + std::vector<int> matrix(3 * 32 * 32); + Rng rng(0); + for (size_t i = 0; i < matrix.size(); i++) matrix[i] = rng.UniformI(1, 256); + encodings[DequantMatrices::kQuantTable[AcStrategy::DCT32X32]] = + QuantEncoding::RAW(matrix, 2); + RoundtripMatrices(encodings); +} + +class QuantWeightsTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest); + +TEST_P(QuantWeightsTargetTest, DCTUniform) { + constexpr float kUniformQuant = 4; + float weights[3][2] = {{1.0f / kUniformQuant, 0}, + {1.0f / kUniformQuant, 0}, + {1.0f / kUniformQuant, 0}}; + DctQuantWeightParams dct_params(weights); + std::vector<QuantEncoding> encodings(DequantMatrices::kNum, + QuantEncoding::DCT(dct_params)); + DequantMatrices dequant_matrices; + CodecMetadata metadata; + FrameHeader frame_header(&metadata); + ModularFrameEncoder encoder(frame_header, CompressParams{}); + DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder); + JXL_CHECK(dequant_matrices.EnsureComputed(~0u)); + + const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant, + 1.0f / kUniformQuant}; + DequantMatricesSetCustomDC(&dequant_matrices, dc_quant); + + HWY_ALIGN_MAX float scratch_space[16 * 16 * 5]; + + // DCT8 + { + HWY_ALIGN_MAX float pixels[64]; + std::iota(std::begin(pixels), std::end(pixels), 0); + HWY_ALIGN_MAX float coeffs[64]; + const AcStrategy::Type dct = AcStrategy::DCT; + TransformFromPixels(dct, pixels, 8, coeffs, scratch_space); + HWY_ALIGN_MAX double slow_coeffs[64]; + for (size_t i = 0; i < 64; i++) slow_coeffs[i] = pixels[i]; + DCTSlow<8>(slow_coeffs); + + for (size_t i = 0; i < 64; i++) { + // DCTSlow doesn't multiply/divide by 1/N, so we do it manually. + slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; + coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * + dequant_matrices.Matrix(dct, 0)[i]; + } + IDCTSlow<8>(slow_coeffs); + TransformToPixels(dct, coeffs, pixels, 8, scratch_space); + for (size_t i = 0; i < 64; i++) { + EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); + } + } + + // DCT16 + { + HWY_ALIGN_MAX float pixels[64 * 4]; + std::iota(std::begin(pixels), std::end(pixels), 0); + HWY_ALIGN_MAX float coeffs[64 * 4]; + const AcStrategy::Type dct = AcStrategy::DCT16X16; + TransformFromPixels(dct, pixels, 16, coeffs, scratch_space); + HWY_ALIGN_MAX double slow_coeffs[64 * 4]; + for (size_t i = 0; i < 64 * 4; i++) slow_coeffs[i] = pixels[i]; + DCTSlow<16>(slow_coeffs); + + for (size_t i = 0; i < 64 * 4; i++) { + slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; + coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * + dequant_matrices.Matrix(dct, 0)[i]; + } + + IDCTSlow<16>(slow_coeffs); + TransformToPixels(dct, coeffs, pixels, 16, scratch_space); + for (size_t i = 0; i < 64 * 4; i++) { + EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); + } + } + + // Check that all matrices have the same DC quantization, i.e. that they all + // have the same scaling. + for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { + EXPECT_NEAR(dequant_matrices.Matrix(i, 0)[0], kUniformQuant, 1e-6); + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/quantizer-inl.h b/third_party/jpeg-xl/lib/jxl/quantizer-inl.h new file mode 100644 index 0000000000..64d273c552 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quantizer-inl.h @@ -0,0 +1,74 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#if defined(LIB_JXL_QUANTIZER_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_QUANTIZER_INL_H_ +#undef LIB_JXL_QUANTIZER_INL_H_ +#else +#define LIB_JXL_QUANTIZER_INL_H_ +#endif + +#include <stddef.h> + +#include <hwy/highway.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::AndNot; +using hwy::HWY_NAMESPACE::ApproximateReciprocal; +using hwy::HWY_NAMESPACE::Gt; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::IfThenElseZero; +using hwy::HWY_NAMESPACE::Lt; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::Xor; + +template <class DI> +HWY_INLINE HWY_MAYBE_UNUSED Vec<Rebind<float, DI>> AdjustQuantBias( + DI di, const size_t c, const Vec<DI> quant_i, + const float* HWY_RESTRICT biases) { + const Rebind<float, DI> df; + + const auto quant = ConvertTo(df, quant_i); + + // Compare |quant|, keep sign bit for negating result. + const auto kSign = BitCast(df, Set(di, INT32_MIN)); + const auto sign = And(quant, kSign); // TODO(janwas): = abs ^ orig + const auto abs_quant = AndNot(kSign, quant); + + // If |x| is 1, kZeroBias creates a different bias for each channel. + // We're implementing the following: + // if (quant == 0) return 0; + // if (quant == 1) return biases[c]; + // if (quant == -1) return -biases[c]; + // return quant - biases[3] / quant; + + // Integer comparison is not helpful because Clang incurs bypass penalties + // from unnecessarily mixing integer and float. + const auto is_01 = Lt(abs_quant, Set(df, 1.125f)); + const auto not_0 = Gt(abs_quant, Zero(df)); + + // Bitwise logic is faster than quant * biases[c]. + const auto one_bias = IfThenElseZero(not_0, Xor(Set(df, biases[c]), sign)); + + // About 2E-5 worse than ReciprocalNR or division. + const auto bias = + NegMulAdd(Set(df, biases[3]), ApproximateReciprocal(quant), quant); + + return IfThenElse(is_01, one_bias, bias); +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_QUANTIZER_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/quantizer.cc b/third_party/jpeg-xl/lib/jxl/quantizer.cc new file mode 100644 index 0000000000..b9ea43e0d6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quantizer.cc @@ -0,0 +1,155 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/quantizer.h" + +#include <string.h> + +#include <algorithm> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/field_encodings.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/quant_weights.h" + +namespace jxl { + +static const int32_t kDefaultQuant = 64; + +constexpr int32_t Quantizer::kQuantMax; + +Quantizer::Quantizer(const DequantMatrices* dequant) + : Quantizer(dequant, kDefaultQuant, kGlobalScaleDenom / kDefaultQuant) {} + +Quantizer::Quantizer(const DequantMatrices* dequant, int quant_dc, + int global_scale) + : global_scale_(global_scale), quant_dc_(quant_dc), dequant_(dequant) { + JXL_ASSERT(dequant_ != nullptr); + RecomputeFromGlobalScale(); + inv_quant_dc_ = inv_global_scale_ / quant_dc_; + + memcpy(zero_bias_, kZeroBiasDefault, sizeof(kZeroBiasDefault)); +} + +void Quantizer::ComputeGlobalScaleAndQuant(float quant_dc, float quant_median, + float quant_median_absd) { + // Target value for the median value in the quant field. + const float kQuantFieldTarget = 5; + // We reduce the median of the quant field by the median absolute deviation: + // higher resolution on highly varying quant fields. + float scale = kGlobalScaleDenom * (quant_median - quant_median_absd) / + kQuantFieldTarget; + // Ensure that new_global_scale is positive and no more than 1<<15. + if (scale < 1) scale = 1; + if (scale > (1 << 15)) scale = 1 << 15; + int new_global_scale = static_cast<int>(scale); + // Ensure that quant_dc_ will always be at least + // 0.625 * kGlobalScaleDenom/kGlobalScaleNumerator = 10. + const int scaled_quant_dc = + static_cast<int>(quant_dc * kGlobalScaleNumerator * 1.6); + if (new_global_scale > scaled_quant_dc) { + new_global_scale = scaled_quant_dc; + if (new_global_scale <= 0) new_global_scale = 1; + } + global_scale_ = new_global_scale; + // Code below uses inv_global_scale_. + RecomputeFromGlobalScale(); + + float fval = quant_dc * inv_global_scale_ + 0.5f; + fval = std::min<float>(1 << 16, fval); + const int new_quant_dc = static_cast<int>(fval); + quant_dc_ = new_quant_dc; + + // quant_dc_ was updated, recompute values. + RecomputeFromGlobalScale(); +} + +void Quantizer::SetQuantFieldRect(const ImageF& qf, const Rect& rect, + ImageI* JXL_RESTRICT raw_quant_field) const { + for (size_t y = 0; y < rect.ysize(); ++y) { + const float* JXL_RESTRICT row_qf = rect.ConstRow(qf, y); + int32_t* JXL_RESTRICT row_qi = rect.Row(raw_quant_field, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + int val = ClampVal(row_qf[x] * inv_global_scale_ + 0.5f); + row_qi[x] = val; + } + } +} + +void Quantizer::SetQuantField(const float quant_dc, const ImageF& qf, + ImageI* JXL_RESTRICT raw_quant_field) { + std::vector<float> data(qf.xsize() * qf.ysize()); + for (size_t y = 0; y < qf.ysize(); ++y) { + const float* JXL_RESTRICT row_qf = qf.Row(y); + for (size_t x = 0; x < qf.xsize(); ++x) { + float quant = row_qf[x]; + data[qf.xsize() * y + x] = quant; + } + } + std::nth_element(data.begin(), data.begin() + data.size() / 2, data.end()); + const float quant_median = data[data.size() / 2]; + std::vector<float> deviations(data.size()); + for (size_t i = 0; i < data.size(); i++) { + deviations[i] = fabsf(data[i] - quant_median); + } + std::nth_element(deviations.begin(), + deviations.begin() + deviations.size() / 2, + deviations.end()); + const float quant_median_absd = deviations[deviations.size() / 2]; + ComputeGlobalScaleAndQuant(quant_dc, quant_median, quant_median_absd); + if (raw_quant_field) { + JXL_CHECK(SameSize(*raw_quant_field, qf)); + SetQuantFieldRect(qf, Rect(qf), raw_quant_field); + } +} + +void Quantizer::SetQuant(float quant_dc, float quant_ac, + ImageI* JXL_RESTRICT raw_quant_field) { + ComputeGlobalScaleAndQuant(quant_dc, quant_ac, 0); + int32_t val = ClampVal(quant_ac * inv_global_scale_ + 0.5f); + FillImage(val, raw_quant_field); +} + +Status QuantizerParams::VisitFields(Visitor* JXL_RESTRICT visitor) { + JXL_QUIET_RETURN_IF_ERROR(visitor->U32( + BitsOffset(11, 1), BitsOffset(11, 2049), BitsOffset(12, 4097), + BitsOffset(16, 8193), 1, &global_scale)); + JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(16), BitsOffset(5, 1), + BitsOffset(8, 1), BitsOffset(16, 1), 1, + &quant_dc)); + return true; +} + +QuantizerParams Quantizer::GetParams() const { + QuantizerParams params; + params.global_scale = global_scale_; + params.quant_dc = quant_dc_; + return params; +} + +Status Quantizer::Decode(BitReader* reader) { + QuantizerParams params; + JXL_RETURN_IF_ERROR(Bundle::Read(reader, ¶ms)); + global_scale_ = static_cast<int>(params.global_scale); + quant_dc_ = static_cast<int>(params.quant_dc); + RecomputeFromGlobalScale(); + return true; +} + +void Quantizer::DumpQuantizationMap(const ImageI& raw_quant_field) const { + printf("Global scale: %d (%.7f)\nDC quant: %d\n", global_scale_, + global_scale_ * 1.0 / kGlobalScaleDenom, quant_dc_); + printf("AC quantization Map:\n"); + for (size_t y = 0; y < raw_quant_field.ysize(); ++y) { + for (size_t x = 0; x < raw_quant_field.xsize(); ++x) { + printf(" %3d", raw_quant_field.Row(y)[x]); + } + printf("\n"); + } +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/quantizer.h b/third_party/jpeg-xl/lib/jxl/quantizer.h new file mode 100644 index 0000000000..4e34ac78e8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quantizer.h @@ -0,0 +1,180 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_QUANTIZER_H_ +#define LIB_JXL_QUANTIZER_H_ + +#include <stddef.h> +#include <stdint.h> +#include <stdlib.h> + +#include <algorithm> +#include <cmath> +#include <utility> +#include <vector> + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/bits.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/dct_util.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/image.h" +#include "lib/jxl/quant_weights.h" + +// Quantizes DC and AC coefficients, with separate quantization tables according +// to the quant_kind (which is currently computed from the AC strategy and the +// block index inside that strategy). + +namespace jxl { + +static constexpr int kGlobalScaleDenom = 1 << 16; +static constexpr int kGlobalScaleNumerator = 4096; + +// zero-biases for quantizing channels X, Y, B +static constexpr float kZeroBiasDefault[3] = {0.5f, 0.5f, 0.5f}; + +// Returns adjusted version of a quantized integer, such that its value is +// closer to the expected value of the original. +// The residuals of AC coefficients that we quantize are not uniformly +// distributed. Numerical experiments show that they have a distribution with +// the "shape" of 1/(1+x^2) [up to some coefficients]. This means that the +// expected value of a coefficient that gets quantized to x will not be x +// itself, but (at least with reasonable approximation): +// - 0 if x is 0 +// - x * biases[c] if x is 1 or -1 +// - x - biases[3]/x otherwise +// This follows from computing the distribution of the quantization bias, which +// can be approximated fairly well by <constant>/x when |x| is at least two. +static constexpr float kBiasNumerator = 0.145f; + +static constexpr float kDefaultQuantBias[4] = { + 1.0f - 0.05465007330715401f, + 1.0f - 0.07005449891748593f, + 1.0f - 0.049935103337343655f, + 0.145f, +}; + +struct QuantizerParams; + +class Quantizer { + public: + explicit Quantizer(const DequantMatrices* dequant); + Quantizer(const DequantMatrices* dequant, int quant_dc, int global_scale); + + static constexpr int32_t kQuantMax = 256; + + static JXL_INLINE int32_t ClampVal(float val) { + return static_cast<int32_t>( + std::max(1.0f, std::min<float>(val, kQuantMax))); + } + + float ScaleGlobalScale(const float scale) { + int new_global_scale = static_cast<int>(global_scale_ * scale + 0.5f); + float scale_out = new_global_scale * 1.0f / global_scale_; + global_scale_ = new_global_scale; + RecomputeFromGlobalScale(); + return scale_out; + } + + // Recomputes other derived fields after global_scale_ has changed. + void RecomputeFromGlobalScale() { + global_scale_float_ = global_scale_ * (1.0 / kGlobalScaleDenom); + inv_global_scale_ = 1.0 * kGlobalScaleDenom / global_scale_; + inv_quant_dc_ = inv_global_scale_ / quant_dc_; + for (size_t c = 0; c < 3; c++) { + mul_dc_[c] = GetDcStep(c); + inv_mul_dc_[c] = GetInvDcStep(c); + } + } + + // Returns scaling factor such that Scale() * (RawDC() or RawQuantField()) + // pixels yields the same float values returned by GetQuantField. + JXL_INLINE float Scale() const { return global_scale_float_; } + + // Reciprocal of Scale(). + JXL_INLINE float InvGlobalScale() const { return inv_global_scale_; } + + void SetQuantFieldRect(const ImageF& qf, const Rect& rect, + ImageI* JXL_RESTRICT raw_quant_field) const; + + void SetQuantField(float quant_dc, const ImageF& qf, + ImageI* JXL_RESTRICT raw_quant_field); + + void SetQuant(float quant_dc, float quant_ac, + ImageI* JXL_RESTRICT raw_quant_field); + + // Returns the DC quantization base value, which is currently global (not + // adaptive). The actual scale factor used to dequantize pixels in channel c + // is: inv_quant_dc() * dequant_->DCQuant(c). + float inv_quant_dc() const { return inv_quant_dc_; } + + // Dequantize by multiplying with this times dequant_matrix. + float inv_quant_ac(int32_t quant) const { return inv_global_scale_ / quant; } + + QuantizerParams GetParams() const; + + Status Decode(BitReader* reader); + + void DumpQuantizationMap(const ImageI& raw_quant_field) const; + + JXL_INLINE const float* DequantMatrix(size_t quant_kind, size_t c) const { + return dequant_->Matrix(quant_kind, c); + } + + JXL_INLINE const float* InvDequantMatrix(size_t quant_kind, size_t c) const { + return dequant_->InvMatrix(quant_kind, c); + } + + // Calculates DC quantization step. + JXL_INLINE float GetDcStep(size_t c) const { + return inv_quant_dc_ * dequant_->DCQuant(c); + } + JXL_INLINE float GetInvDcStep(size_t c) const { + return dequant_->InvDCQuant(c) * (global_scale_float_ * quant_dc_); + } + + JXL_INLINE const float* MulDC() const { return mul_dc_; } + JXL_INLINE const float* InvMulDC() const { return inv_mul_dc_; } + + JXL_INLINE void ClearDCMul() { + std::fill(mul_dc_, mul_dc_ + 4, 1.f); + std::fill(inv_mul_dc_, inv_mul_dc_ + 4, 1.f); + } + + void ComputeGlobalScaleAndQuant(float quant_dc, float quant_median, + float quant_median_absd); + + private: + float mul_dc_[4]; + float inv_mul_dc_[4]; + + // These are serialized: + int global_scale_; + int quant_dc_; + + // These are derived from global_scale_: + float inv_global_scale_; + float global_scale_float_; // reciprocal of inv_global_scale_ + float inv_quant_dc_; + + float zero_bias_[3]; + const DequantMatrices* dequant_; +}; + +struct QuantizerParams : public Fields { + QuantizerParams() { Bundle::Init(this); } + JXL_FIELDS_NAME(QuantizerParams) + + Status VisitFields(Visitor* JXL_RESTRICT visitor) override; + + uint32_t global_scale; + uint32_t quant_dc; +}; + +} // namespace jxl + +#endif // LIB_JXL_QUANTIZER_H_ diff --git a/third_party/jpeg-xl/lib/jxl/quantizer_test.cc b/third_party/jpeg-xl/lib/jxl/quantizer_test.cc new file mode 100644 index 0000000000..aff19f42c1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/quantizer_test.cc @@ -0,0 +1,80 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/quantizer.h" + +#include "lib/jxl/base/span.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +void TestEquivalence(int qxsize, int qysize, const Quantizer& quantizer1, + const Quantizer& quantizer2) { + ASSERT_NEAR(quantizer1.inv_quant_dc(), quantizer2.inv_quant_dc(), 1e-7); +} + +TEST(QuantizerTest, QuantizerParams) { + for (uint32_t i = 1; i < 10000; ++i) { + QuantizerParams p; + p.global_scale = i; + size_t extension_bits = 0, total_bits = 0; + EXPECT_TRUE(Bundle::CanEncode(p, &extension_bits, &total_bits)); + EXPECT_EQ(0u, extension_bits); + EXPECT_GE(total_bits, 4u); + } +} + +TEST(QuantizerTest, BitStreamRoundtripSameQuant) { + const int qxsize = 8; + const int qysize = 8; + DequantMatrices dequant; + Quantizer quantizer1(&dequant); + ImageI raw_quant_field(qxsize, qysize); + quantizer1.SetQuant(0.17f, 0.17f, &raw_quant_field); + BitWriter writer; + QuantizerParams params = quantizer1.GetParams(); + EXPECT_TRUE(WriteQuantizerParams(params, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + const size_t bits_written = writer.BitsWritten(); + Quantizer quantizer2(&dequant); + BitReader reader(writer.GetSpan()); + EXPECT_TRUE(quantizer2.Decode(&reader)); + EXPECT_TRUE(reader.JumpToByteBoundary()); + EXPECT_EQ(reader.TotalBitsConsumed(), bits_written); + EXPECT_TRUE(reader.Close()); + TestEquivalence(qxsize, qysize, quantizer1, quantizer2); +} + +TEST(QuantizerTest, BitStreamRoundtripRandomQuant) { + const int qxsize = 8; + const int qysize = 8; + DequantMatrices dequant; + Quantizer quantizer1(&dequant); + ImageI raw_quant_field(qxsize, qysize); + quantizer1.SetQuant(0.17f, 0.17f, &raw_quant_field); + float quant_dc = 0.17f; + ImageF qf(qxsize, qysize); + RandomFillImage(&qf, 0.0f, 1.0f); + quantizer1.SetQuantField(quant_dc, qf, &raw_quant_field); + BitWriter writer; + QuantizerParams params = quantizer1.GetParams(); + EXPECT_TRUE(WriteQuantizerParams(params, &writer, 0, nullptr)); + writer.ZeroPadToByte(); + const size_t bits_written = writer.BitsWritten(); + Quantizer quantizer2(&dequant); + BitReader reader(writer.GetSpan()); + EXPECT_TRUE(quantizer2.Decode(&reader)); + EXPECT_TRUE(reader.JumpToByteBoundary()); + EXPECT_EQ(reader.TotalBitsConsumed(), bits_written); + EXPECT_TRUE(reader.Close()); + TestEquivalence(qxsize, qysize, quantizer1, quantizer2); +} +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/rational_polynomial_test.cc b/third_party/jpeg-xl/lib/jxl/rational_polynomial_test.cc new file mode 100644 index 0000000000..bc31cdd092 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/rational_polynomial_test.cc @@ -0,0 +1,237 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <cmath> +#include <string> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/rational_polynomial_test.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> +#include <hwy/tests/hwy_gtest.h> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/rational_polynomial-inl.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/testing.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +using T = float; // required by EvalLog2 +using D = HWY_FULL(T); + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::GetLane; +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; + +// Generic: only computes polynomial +struct EvalPoly { + template <size_t NP, size_t NQ> + T operator()(T x, const T (&p)[NP], const T (&q)[NQ]) const { + const HWY_FULL(T) d; + const auto vx = Set(d, x); + const auto approx = EvalRationalPolynomial(d, vx, p, q); + return GetLane(approx); + } +}; + +// Range reduction for log2 +struct EvalLog2 { + template <size_t NP, size_t NQ> + T operator()(T x, const T (&p)[NP], const T (&q)[NQ]) const { + const HWY_FULL(T) d; + auto vx = Set(d, x); + + const HWY_FULL(int32_t) di; + const auto x_bits = BitCast(di, vx); + // Cannot handle negative numbers / NaN. + JXL_DASSERT(AllTrue(di, Eq(Abs(x_bits), x_bits))); + + // Range reduction to [-1/3, 1/3] - 3 integer, 2 float ops + const auto exp_bits = Sub(x_bits, Set(di, 0x3f2aaaab)); // = 2/3 + // Shifted exponent = log2; also used to clear mantissa. + const auto exp_shifted = ShiftRight<23>(exp_bits); + const auto mantissa = BitCast(d, Sub(x_bits, ShiftLeft<23>(exp_shifted))); + const auto exp_val = ConvertTo(d, exp_shifted); + vx = Sub(mantissa, Set(d, 1.0f)); + + const auto approx = Add(EvalRationalPolynomial(d, vx, p, q), exp_val); + return GetLane(approx); + } +}; + +// Functions to approximate: + +T LinearToSrgb8Direct(T val) { + if (val < 0.0) return 0.0; + if (val >= 255.0) return 255.0; + if (val <= 10.0 / 12.92) return val * 12.92; + return 255.0 * (std::pow(val / 255.0, 1.0 / 2.4) * 1.055 - 0.055); +} + +T SimpleGamma(T v) { + static const T kGamma = 0.387494322593; + static const T limit = 43.01745241042018; + T bright = v - limit; + if (bright >= 0) { + static const T mul = 0.0383723643799; + v -= bright * mul; + } + static const T limit2 = 94.68634353321337; + T bright2 = v - limit2; + if (bright2 >= 0) { + static const T mul = 0.22885405968; + v -= bright2 * mul; + } + static const T offset = 0.156775786057; + static const T scale = 8.898059160493739; + T retval = scale * (offset + pow(v, kGamma)); + return retval; +} + +// Runs CaratheodoryFejer and verifies the polynomial using a lot of samples to +// return the biggest error. +template <size_t NP, size_t NQ, class Eval> +T RunApproximation(T x0, T x1, const T (&p)[NP], const T (&q)[NQ], + const Eval& eval, T func_to_approx(T)) { + float maxerr = 0; + T lastPrint = 0; + // NOLINTNEXTLINE(clang-analyzer-security.FloatLoopCounter) + for (T x = x0; x <= x1; x += (x1 - x0) / 10000.0) { + const T f = func_to_approx(x); + const T g = eval(x, p, q); + maxerr = std::max(fabsf(g - f), maxerr); + if (x == x0 || x - lastPrint > (x1 - x0) / 20.0) { + printf("x: %11.6f, f: %11.6f, g: %11.6f, e: %11.6f\n", x, f, g, + fabs(g - f)); + lastPrint = x; + } + } + return maxerr; +} + +void TestSimpleGamma() { + const T p[4 * (6 + 1)] = { + HWY_REP4(-5.0646949363741811E-05), HWY_REP4(6.7369380528439771E-05), + HWY_REP4(8.9376652530412794E-05), HWY_REP4(2.1153513301520462E-06), + HWY_REP4(-6.9130322970386449E-08), HWY_REP4(3.9424752749293728E-10), + HWY_REP4(1.2360288207619576E-13)}; + + const T q[4 * (6 + 1)] = { + HWY_REP4(-6.6389733798591366E-06), HWY_REP4(1.3299859726565908E-05), + HWY_REP4(3.8538748358398873E-06), HWY_REP4(-2.8707687262928236E-08), + HWY_REP4(-6.6897385800005434E-10), HWY_REP4(6.1428748869186003E-12), + HWY_REP4(-2.5475738169252870E-15)}; + + const T err = RunApproximation(0.77, 274.579999999999984, p, q, EvalPoly(), + SimpleGamma); + EXPECT_LT(err, 0.05); +} + +void TestLinearToSrgb8Direct() { + const T p[4 * (5 + 1)] = { + HWY_REP4(-9.5357499040105154E-05), HWY_REP4(4.6761186249798248E-04), + HWY_REP4(2.5708174333943594E-04), HWY_REP4(1.5250087770436082E-05), + HWY_REP4(1.1946768008931187E-07), HWY_REP4(5.9916446295972850E-11)}; + + const T q[4 * (4 + 1)] = { + HWY_REP4(1.8932479758079768E-05), HWY_REP4(2.7312342474687321E-05), + HWY_REP4(4.3901204783327006E-06), HWY_REP4(1.0417787306920273E-07), + HWY_REP4(3.0084206762140419E-10)}; + + const T err = + RunApproximation(0.77, 255, p, q, EvalPoly(), LinearToSrgb8Direct); + EXPECT_LT(err, 0.05); +} + +void TestExp() { + const T p[4 * (2 + 1)] = {HWY_REP4(9.6266879665530902E-01), + HWY_REP4(4.8961265681586763E-01), + HWY_REP4(8.2619259189548433E-02)}; + const T q[4 * (2 + 1)] = {HWY_REP4(9.6259895571622622E-01), + HWY_REP4(-4.7272457588933831E-01), + HWY_REP4(7.4802088567547664E-02)}; + const T err = + RunApproximation(-1, 1, p, q, EvalPoly(), [](T x) { return T(exp(x)); }); + EXPECT_LT(err, 1E-4); +} + +void TestNegExp() { + // 4,3 is the min required for monotonicity; max error in 0,10: 751 ppm + // no benefit for k>50. + const T p[4 * (4 + 1)] = { + HWY_REP4(5.9580258551150123E-02), HWY_REP4(-2.5073728806886408E-02), + HWY_REP4(4.1561830213689248E-03), HWY_REP4(-3.1815408488900372E-04), + HWY_REP4(9.3866690094906802E-06)}; + const T q[4 * (3 + 1)] = { + HWY_REP4(5.9579108238812878E-02), HWY_REP4(3.4542074345478582E-02), + HWY_REP4(8.7263562483501714E-03), HWY_REP4(1.4095109143061216E-03)}; + + const T err = + RunApproximation(0, 10, p, q, EvalPoly(), [](T x) { return T(exp(-x)); }); + EXPECT_LT(err, sizeof(T) == 8 ? 2E-5 : 3E-5); +} + +void TestSin() { + const T p[4 * (6 + 1)] = { + HWY_REP4(1.5518122109203780E-05), HWY_REP4(2.3388958643675966E+00), + HWY_REP4(-8.6705520940849157E-01), HWY_REP4(-1.9702294764873535E-01), + HWY_REP4(1.2193404314472320E-01), HWY_REP4(-1.7373966109788839E-02), + HWY_REP4(7.8829435883034796E-04)}; + const T q[4 * (5 + 1)] = { + HWY_REP4(2.3394371422557279E+00), HWY_REP4(-8.7028221081288615E-01), + HWY_REP4(2.0052872219658430E-01), HWY_REP4(-3.2460335995264836E-02), + HWY_REP4(3.1546157932479282E-03), HWY_REP4(-1.6692542019380155E-04)}; + + const T err = RunApproximation(0, Pi<T>(1) * 2, p, q, EvalPoly(), + [](T x) { return T(sin(x)); }); + EXPECT_LT(err, sizeof(T) == 8 ? 5E-4 : 7E-4); +} + +void TestLog() { + HWY_ALIGN const T p[4 * (2 + 1)] = {HWY_REP4(-1.8503833400518310E-06), + HWY_REP4(1.4287160470083755E+00), + HWY_REP4(7.4245873327820566E-01)}; + HWY_ALIGN const T q[4 * (2 + 1)] = {HWY_REP4(9.9032814277590719E-01), + HWY_REP4(1.0096718572241148E+00), + HWY_REP4(1.7409343003366853E-01)}; + const T err = RunApproximation(1E-6, 1000, p, q, EvalLog2(), std::log2); + printf("%E\n", err); +} + +HWY_NOINLINE void TestRationalPolynomial() { + TestSimpleGamma(); + TestLinearToSrgb8Direct(); + TestExp(); + TestNegExp(); + TestSin(); + TestLog(); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class RationalPolynomialTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(RationalPolynomialTest); + +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestSimpleGamma); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestLinearToSrgb8Direct); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestExp); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestNegExp); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestSin); +HWY_EXPORT_AND_TEST_P(RationalPolynomialTest, TestLog); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/low_memory_render_pipeline.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/low_memory_render_pipeline.cc new file mode 100644 index 0000000000..9aefdd007d --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/low_memory_render_pipeline.cc @@ -0,0 +1,864 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/low_memory_render_pipeline.h" + +#include <algorithm> +#include <queue> +#include <tuple> + +#include "lib/jxl/base/arch_macros.h" +#include "lib/jxl/image_ops.h" + +namespace jxl { +std::pair<size_t, size_t> +LowMemoryRenderPipeline::ColorDimensionsToChannelDimensions( + std::pair<size_t, size_t> in, size_t c, size_t stage) const { + std::pair<size_t, size_t> ret; + std::pair<size_t, size_t> shift = channel_shifts_[stage][c]; + ret.first = + ((in.first << base_color_shift_) + (1 << shift.first) - 1) >> shift.first; + ret.second = ((in.second << base_color_shift_) + (1 << shift.second) - 1) >> + shift.second; + return ret; +} + +std::pair<size_t, size_t> LowMemoryRenderPipeline::BorderToStore( + size_t c) const { + auto ret = ColorDimensionsToChannelDimensions(group_border_, c, 0); + ret.first += padding_[0][c].first; + ret.second += padding_[0][c].second; + return ret; +} + +void LowMemoryRenderPipeline::SaveBorders(size_t group_id, size_t c, + const ImageF& in) { + size_t gy = group_id / frame_dimensions_.xsize_groups; + size_t gx = group_id % frame_dimensions_.xsize_groups; + size_t hshift = channel_shifts_[0][c].first; + size_t vshift = channel_shifts_[0][c].second; + size_t x0 = gx * GroupInputXSize(c); + size_t x1 = std::min((gx + 1) * GroupInputXSize(c), + DivCeil(frame_dimensions_.xsize_upsampled, 1 << hshift)); + size_t y0 = gy * GroupInputYSize(c); + size_t y1 = std::min((gy + 1) * GroupInputYSize(c), + DivCeil(frame_dimensions_.ysize_upsampled, 1 << vshift)); + + auto borders = BorderToStore(c); + size_t borderx_write = borders.first; + size_t bordery_write = borders.second; + + if (gy > 0) { + Rect from(group_data_x_border_, group_data_y_border_, x1 - x0, + bordery_write); + Rect to(x0, (gy * 2 - 1) * bordery_write, x1 - x0, bordery_write); + CopyImageTo(from, in, to, &borders_horizontal_[c]); + } + if (gy + 1 < frame_dimensions_.ysize_groups) { + Rect from(group_data_x_border_, + group_data_y_border_ + y1 - y0 - bordery_write, x1 - x0, + bordery_write); + Rect to(x0, (gy * 2) * bordery_write, x1 - x0, bordery_write); + CopyImageTo(from, in, to, &borders_horizontal_[c]); + } + if (gx > 0) { + Rect from(group_data_x_border_, group_data_y_border_, borderx_write, + y1 - y0); + Rect to((gx * 2 - 1) * borderx_write, y0, borderx_write, y1 - y0); + CopyImageTo(from, in, to, &borders_vertical_[c]); + } + if (gx + 1 < frame_dimensions_.xsize_groups) { + Rect from(group_data_x_border_ + x1 - x0 - borderx_write, + group_data_y_border_, borderx_write, y1 - y0); + Rect to((gx * 2) * borderx_write, y0, borderx_write, y1 - y0); + CopyImageTo(from, in, to, &borders_vertical_[c]); + } +} + +void LowMemoryRenderPipeline::LoadBorders(size_t group_id, size_t c, + const Rect& r, ImageF* out) { + size_t gy = group_id / frame_dimensions_.xsize_groups; + size_t gx = group_id % frame_dimensions_.xsize_groups; + size_t hshift = channel_shifts_[0][c].first; + size_t vshift = channel_shifts_[0][c].second; + // Coordinates of the group in the image. + size_t x0 = gx * GroupInputXSize(c); + size_t x1 = std::min((gx + 1) * GroupInputXSize(c), + DivCeil(frame_dimensions_.xsize_upsampled, 1 << hshift)); + size_t y0 = gy * GroupInputYSize(c); + size_t y1 = std::min((gy + 1) * GroupInputYSize(c), + DivCeil(frame_dimensions_.ysize_upsampled, 1 << vshift)); + + size_t paddingx = padding_[0][c].first; + size_t paddingy = padding_[0][c].second; + + auto borders = BorderToStore(c); + size_t borderx_write = borders.first; + size_t bordery_write = borders.second; + + // Limits of the area to copy from, in image coordinates. + JXL_DASSERT(r.x0() == 0 || (r.x0() << base_color_shift_) >= paddingx); + size_t x0src = DivCeil(r.x0() << base_color_shift_, 1 << hshift); + if (x0src != 0) { + x0src -= paddingx; + } + // r may be such that r.x1 (namely x0() + xsize()) is within paddingx of the + // right side of the image, so we use min() here. + size_t x1src = + DivCeil((r.x0() + r.xsize()) << base_color_shift_, 1 << hshift); + x1src = std::min(x1src + paddingx, + DivCeil(frame_dimensions_.xsize_upsampled, 1 << hshift)); + + // Similar computation for y. + JXL_DASSERT(r.y0() == 0 || (r.y0() << base_color_shift_) >= paddingy); + size_t y0src = DivCeil(r.y0() << base_color_shift_, 1 << vshift); + if (y0src != 0) { + y0src -= paddingy; + } + size_t y1src = + DivCeil((r.y0() + r.ysize()) << base_color_shift_, 1 << vshift); + y1src = std::min(y1src + paddingy, + DivCeil(frame_dimensions_.ysize_upsampled, 1 << vshift)); + + // Copy other groups' borders from the border storage. + if (y0src < y0) { + JXL_DASSERT(gy > 0); + CopyImageTo( + Rect(x0src, (gy * 2 - 2) * bordery_write, x1src - x0src, bordery_write), + borders_horizontal_[c], + Rect(group_data_x_border_ + x0src - x0, + group_data_y_border_ - bordery_write, x1src - x0src, + bordery_write), + out); + } + if (y1src > y1) { + // When copying the bottom border we must not be on the bottom groups. + JXL_DASSERT(gy + 1 < frame_dimensions_.ysize_groups); + CopyImageTo( + Rect(x0src, (gy * 2 + 1) * bordery_write, x1src - x0src, bordery_write), + borders_horizontal_[c], + Rect(group_data_x_border_ + x0src - x0, group_data_y_border_ + y1 - y0, + x1src - x0src, bordery_write), + out); + } + if (x0src < x0) { + JXL_DASSERT(gx > 0); + CopyImageTo( + Rect((gx * 2 - 2) * borderx_write, y0src, borderx_write, y1src - y0src), + borders_vertical_[c], + Rect(group_data_x_border_ - borderx_write, + group_data_y_border_ + y0src - y0, borderx_write, y1src - y0src), + out); + } + if (x1src > x1) { + // When copying the right border we must not be on the rightmost groups. + JXL_DASSERT(gx + 1 < frame_dimensions_.xsize_groups); + CopyImageTo( + Rect((gx * 2 + 1) * borderx_write, y0src, borderx_write, y1src - y0src), + borders_vertical_[c], + Rect(group_data_x_border_ + x1 - x0, group_data_y_border_ + y0src - y0, + borderx_write, y1src - y0src), + out); + } +} + +size_t LowMemoryRenderPipeline::GroupInputXSize(size_t c) const { + return (frame_dimensions_.group_dim << base_color_shift_) >> + channel_shifts_[0][c].first; +} + +size_t LowMemoryRenderPipeline::GroupInputYSize(size_t c) const { + return (frame_dimensions_.group_dim << base_color_shift_) >> + channel_shifts_[0][c].second; +} + +void LowMemoryRenderPipeline::EnsureBordersStorage() { + const auto& shifts = channel_shifts_[0]; + if (borders_horizontal_.size() < shifts.size()) { + borders_horizontal_.resize(shifts.size()); + borders_vertical_.resize(shifts.size()); + } + for (size_t c = 0; c < shifts.size(); c++) { + auto borders = BorderToStore(c); + size_t borderx = borders.first; + size_t bordery = borders.second; + JXL_DASSERT(frame_dimensions_.xsize_groups > 0); + size_t num_xborders = (frame_dimensions_.xsize_groups - 1) * 2; + JXL_DASSERT(frame_dimensions_.ysize_groups > 0); + size_t num_yborders = (frame_dimensions_.ysize_groups - 1) * 2; + size_t downsampled_xsize = + DivCeil(frame_dimensions_.xsize_upsampled_padded, 1 << shifts[c].first); + size_t downsampled_ysize = DivCeil(frame_dimensions_.ysize_upsampled_padded, + 1 << shifts[c].second); + Rect horizontal = Rect(0, 0, downsampled_xsize, bordery * num_yborders); + if (!SameSize(horizontal, borders_horizontal_[c])) { + borders_horizontal_[c] = ImageF(horizontal.xsize(), horizontal.ysize()); + } + Rect vertical = Rect(0, 0, borderx * num_xborders, downsampled_ysize); + if (!SameSize(vertical, borders_vertical_[c])) { + borders_vertical_[c] = ImageF(vertical.xsize(), vertical.ysize()); + } + } +} + +void LowMemoryRenderPipeline::Init() { + group_border_ = {0, 0}; + base_color_shift_ = CeilLog2Nonzero(frame_dimensions_.xsize_upsampled_padded / + frame_dimensions_.xsize_padded); + + const auto& shifts = channel_shifts_[0]; + + // Ensure that each channel has enough many border pixels. + for (size_t c = 0; c < shifts.size(); c++) { + group_border_.first = + std::max(group_border_.first, + DivCeil(padding_[0][c].first << channel_shifts_[0][c].first, + 1 << base_color_shift_)); + group_border_.second = + std::max(group_border_.second, + DivCeil(padding_[0][c].second << channel_shifts_[0][c].second, + 1 << base_color_shift_)); + } + + // Ensure that all channels have an integer number of border pixels in the + // input. + for (size_t c = 0; c < shifts.size(); c++) { + if (channel_shifts_[0][c].first >= base_color_shift_) { + group_border_.first = + RoundUpTo(group_border_.first, + 1 << (channel_shifts_[0][c].first - base_color_shift_)); + } + if (channel_shifts_[0][c].second >= base_color_shift_) { + group_border_.second = + RoundUpTo(group_border_.second, + 1 << (channel_shifts_[0][c].second - base_color_shift_)); + } + } + // Ensure that the X border on color channels is a multiple of kBlockDim or + // the vector size (required for EPF stages). Vectors on ARM NEON are never + // wider than 4 floats, so rounding to multiples of 4 is enough. +#if JXL_ARCH_ARM + constexpr size_t kGroupXAlign = 4; +#else + constexpr size_t kGroupXAlign = 16; +#endif + group_border_.first = RoundUpTo(group_border_.first, kGroupXAlign); + // Allocate borders in group images that are just enough for storing the + // borders to be copied in, plus any rounding to ensure alignment. + std::pair<size_t, size_t> max_border = {0, 0}; + for (size_t c = 0; c < shifts.size(); c++) { + max_border.first = std::max(BorderToStore(c).first, max_border.first); + max_border.second = std::max(BorderToStore(c).second, max_border.second); + } + group_data_x_border_ = RoundUpTo(max_border.first, kGroupXAlign); + group_data_y_border_ = max_border.second; + + EnsureBordersStorage(); + group_border_assigner_.Init(frame_dimensions_); + + for (first_trailing_stage_ = stages_.size(); first_trailing_stage_ > 0; + first_trailing_stage_--) { + bool has_inout_c = false; + for (size_t c = 0; c < shifts.size(); c++) { + if (stages_[first_trailing_stage_ - 1]->GetChannelMode(c) == + RenderPipelineChannelMode::kInOut) { + has_inout_c = true; + } + } + if (has_inout_c) { + break; + } + } + + first_image_dim_stage_ = stages_.size(); + for (size_t i = 0; i < stages_.size(); i++) { + std::vector<std::pair<size_t, size_t>> input_sizes(shifts.size()); + for (size_t c = 0; c < shifts.size(); c++) { + input_sizes[c] = + std::make_pair(DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[i][c].first), + DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[i][c].second)); + } + stages_[i]->SetInputSizes(input_sizes); + if (stages_[i]->SwitchToImageDimensions()) { + // We don't allow kInOut after switching to image dimensions. + JXL_ASSERT(i >= first_trailing_stage_); + first_image_dim_stage_ = i + 1; + stages_[i]->GetImageDimensions(&full_image_xsize_, &full_image_ysize_, + &frame_origin_); + break; + } + } + for (size_t i = first_image_dim_stage_; i < stages_.size(); i++) { + if (stages_[i]->SwitchToImageDimensions()) { + JXL_UNREACHABLE("Cannot switch to image dimensions multiple times"); + } + std::vector<std::pair<size_t, size_t>> input_sizes(shifts.size()); + for (size_t c = 0; c < shifts.size(); c++) { + input_sizes[c] = {full_image_xsize_, full_image_ysize_}; + } + stages_[i]->SetInputSizes(input_sizes); + } + + anyc_.resize(stages_.size()); + for (size_t i = 0; i < stages_.size(); i++) { + for (size_t c = 0; c < shifts.size(); c++) { + if (stages_[i]->GetChannelMode(c) != + RenderPipelineChannelMode::kIgnored) { + anyc_[i] = c; + } + } + } + + stage_input_for_channel_ = std::vector<std::vector<int32_t>>( + stages_.size(), std::vector<int32_t>(shifts.size())); + for (size_t c = 0; c < shifts.size(); c++) { + int input = -1; + for (size_t i = 0; i < stages_.size(); i++) { + stage_input_for_channel_[i][c] = input; + if (stages_[i]->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + input = i; + } + } + } + + image_rect_.resize(stages_.size()); + for (size_t i = 0; i < stages_.size(); i++) { + size_t x1 = DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[i][anyc_[i]].first); + size_t y1 = DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[i][anyc_[i]].second); + image_rect_[i] = Rect(0, 0, x1, y1); + } + + virtual_ypadding_for_output_.resize(stages_.size()); + xpadding_for_output_.resize(stages_.size()); + for (size_t c = 0; c < shifts.size(); c++) { + int ypad = 0; + int xpad = 0; + for (size_t i = stages_.size(); i-- > 0;) { + if (stages_[i]->GetChannelMode(c) != + RenderPipelineChannelMode::kIgnored) { + virtual_ypadding_for_output_[i] = + std::max(ypad, virtual_ypadding_for_output_[i]); + xpadding_for_output_[i] = std::max(xpad, xpadding_for_output_[i]); + } + if (stages_[i]->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + ypad = (DivCeil(ypad, 1 << channel_shifts_[i][c].second) + + stages_[i]->settings_.border_y) + << channel_shifts_[i][c].second; + xpad = DivCeil(xpad, 1 << stages_[i]->settings_.shift_x) + + stages_[i]->settings_.border_x; + } + } + } +} + +void LowMemoryRenderPipeline::PrepareForThreadsInternal(size_t num, + bool use_group_ids) { + const auto& shifts = channel_shifts_[0]; + use_group_ids_ = use_group_ids; + size_t num_buffers = use_group_ids_ ? frame_dimensions_.num_groups : num; + for (size_t t = group_data_.size(); t < num_buffers; t++) { + group_data_.emplace_back(); + group_data_[t].resize(shifts.size()); + for (size_t c = 0; c < shifts.size(); c++) { + group_data_[t][c] = ImageF(GroupInputXSize(c) + group_data_x_border_ * 2, + GroupInputYSize(c) + group_data_y_border_ * 2); + } + } + // TODO(veluca): avoid reallocating buffers if not needed. + stage_data_.resize(num); + size_t upsampling = 1u << base_color_shift_; + size_t group_dim = frame_dimensions_.group_dim * upsampling; + size_t padding = + 2 * group_data_x_border_ * upsampling + // maximum size of a rect + 2 * kRenderPipelineXOffset; // extra padding for processing + size_t stage_buffer_xsize = group_dim + padding; + for (size_t t = 0; t < num; t++) { + stage_data_[t].resize(shifts.size()); + for (size_t c = 0; c < shifts.size(); c++) { + stage_data_[t][c].resize(stages_.size()); + size_t next_y_border = 0; + for (size_t i = stages_.size(); i-- > 0;) { + if (stages_[i]->GetChannelMode(c) == + RenderPipelineChannelMode::kInOut) { + size_t stage_buffer_ysize = + 2 * next_y_border + (1 << stages_[i]->settings_.shift_y); + stage_buffer_ysize = 1 << CeilLog2Nonzero(stage_buffer_ysize); + next_y_border = stages_[i]->settings_.border_y; + stage_data_[t][c][i] = ImageF(stage_buffer_xsize, stage_buffer_ysize); + } + } + } + } + if (first_image_dim_stage_ != stages_.size()) { + RectT<ssize_t> image_rect(0, 0, frame_dimensions_.xsize_upsampled, + frame_dimensions_.ysize_upsampled); + RectT<ssize_t> full_image_rect(0, 0, full_image_xsize_, full_image_ysize_); + image_rect = image_rect.Translate(frame_origin_.x0, frame_origin_.y0); + image_rect = image_rect.Intersection(full_image_rect); + if (image_rect.xsize() == 0 || image_rect.ysize() == 0) { + image_rect = RectT<ssize_t>(0, 0, 0, 0); + } + size_t left_padding = image_rect.x0(); + size_t middle_padding = group_dim; + size_t right_padding = full_image_xsize_ - image_rect.x1(); + size_t out_of_frame_xsize = + padding + + std::max(left_padding, std::max(middle_padding, right_padding)); + out_of_frame_data_.resize(num); + for (size_t t = 0; t < num; t++) { + out_of_frame_data_[t] = ImageF(out_of_frame_xsize, shifts.size()); + } + } +} + +std::vector<std::pair<ImageF*, Rect>> LowMemoryRenderPipeline::PrepareBuffers( + size_t group_id, size_t thread_id) { + std::vector<std::pair<ImageF*, Rect>> ret(channel_shifts_[0].size()); + const size_t gx = group_id % frame_dimensions_.xsize_groups; + const size_t gy = group_id / frame_dimensions_.xsize_groups; + for (size_t c = 0; c < channel_shifts_[0].size(); c++) { + ret[c].first = &group_data_[use_group_ids_ ? group_id : thread_id][c]; + ret[c].second = Rect(group_data_x_border_, group_data_y_border_, + GroupInputXSize(c), GroupInputYSize(c), + DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[0][c].first) - + gx * GroupInputXSize(c) + group_data_x_border_, + DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[0][c].second) - + gy * GroupInputYSize(c) + group_data_y_border_); + } + return ret; +} + +namespace { + +JXL_INLINE int GetMirroredY(int y, ssize_t group_y0, ssize_t image_ysize) { + if (group_y0 == 0 && (y < 0 || y + group_y0 >= image_ysize)) { + return Mirror(y, image_ysize); + } + if (y + group_y0 >= image_ysize) { + // Here we know that the one mirroring step is sufficient. + return 2 * image_ysize - (y + group_y0) - 1 - group_y0; + } + return y; +} + +JXL_INLINE void ApplyXMirroring(float* row, ssize_t borderx, ssize_t group_x0, + ssize_t group_xsize, ssize_t image_xsize) { + if (image_xsize <= borderx) { + if (group_x0 == 0) { + for (ssize_t ix = 0; ix < borderx; ix++) { + row[kRenderPipelineXOffset - ix - 1] = + row[kRenderPipelineXOffset + Mirror(-ix - 1, image_xsize)]; + } + } + if (group_xsize + borderx + group_x0 >= image_xsize) { + for (ssize_t ix = 0; ix < borderx; ix++) { + row[kRenderPipelineXOffset + image_xsize + ix - group_x0] = + row[kRenderPipelineXOffset + Mirror(image_xsize + ix, image_xsize) - + group_x0]; + } + } + } else { + // Here we know that the one mirroring step is sufficient. + if (group_x0 == 0) { + for (ssize_t ix = 0; ix < borderx; ix++) { + row[kRenderPipelineXOffset - ix - 1] = row[kRenderPipelineXOffset + ix]; + } + } + if (group_xsize + borderx + group_x0 >= image_xsize) { + for (ssize_t ix = 0; ix < borderx; ix++) { + row[kRenderPipelineXOffset + image_xsize - group_x0 + ix] = + row[kRenderPipelineXOffset + image_xsize - group_x0 - ix - 1]; + } + } + } +} + +// Information about where the *output* of each stage is stored. +class Rows { + public: + Rows(const std::vector<std::unique_ptr<RenderPipelineStage>>& stages, + const Rect data_max_color_channel_rect, int group_data_x_border, + int group_data_y_border, + const std::vector<std::pair<size_t, size_t>>& group_data_shift, + size_t base_color_shift, std::vector<std::vector<ImageF>>& thread_data, + std::vector<ImageF>& input_data) { + size_t num_stages = stages.size(); + size_t num_channels = input_data.size(); + + JXL_ASSERT(thread_data.size() == num_channels); + JXL_ASSERT(group_data_shift.size() == num_channels); + +#if JXL_ENABLE_ASSERT + for (const auto& td : thread_data) { + JXL_ASSERT(td.size() == num_stages); + } +#endif + + rows_.resize(num_stages + 1, std::vector<RowInfo>(num_channels)); + + for (size_t i = 0; i < num_stages; i++) { + for (size_t c = 0; c < input_data.size(); c++) { + if (stages[i]->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + rows_[i + 1][c].ymod_minus_1 = thread_data[c][i].ysize() - 1; + rows_[i + 1][c].base_ptr = thread_data[c][i].Row(0); + rows_[i + 1][c].stride = thread_data[c][i].PixelsPerRow(); + } + } + } + + for (size_t c = 0; c < input_data.size(); c++) { + auto channel_group_data_rect = + data_max_color_channel_rect.As<ssize_t>() + .Translate(-group_data_x_border, -group_data_y_border) + .ShiftLeft(base_color_shift) + .CeilShiftRight(group_data_shift[c]) + .Translate(group_data_x_border - ssize_t(kRenderPipelineXOffset), + group_data_y_border); + rows_[0][c].base_ptr = channel_group_data_rect.Row(&input_data[c], 0); + rows_[0][c].stride = input_data[c].PixelsPerRow(); + rows_[0][c].ymod_minus_1 = -1; + } + } + + // Stage -1 refers to the input data; all other values must be nonnegative and + // refer to the data for the output of that stage. + JXL_INLINE float* GetBuffer(int stage, int y, size_t c) const { + JXL_DASSERT(stage >= -1); + const RowInfo& info = rows_[stage + 1][c]; + return info.base_ptr + ssize_t(info.stride) * (y & info.ymod_minus_1); + } + + private: + struct RowInfo { + // Pointer to beginning of the first row. + float* base_ptr; + // Modulo value for the y axis minus 1 (ymod is guaranteed to be a power of + // 2, which allows efficient mod computation by masking). + int ymod_minus_1; + // Number of floats per row. + size_t stride; + }; + std::vector<std::vector<RowInfo>> rows_; +}; + +} // namespace + +void LowMemoryRenderPipeline::RenderRect(size_t thread_id, + std::vector<ImageF>& input_data, + Rect data_max_color_channel_rect, + Rect image_max_color_channel_rect) { + // For each stage, the rect corresponding to the image area currently being + // processed, in the coordinates of that stage (i.e. with the scaling factor + // that that stage has). + std::vector<Rect> group_rect; + group_rect.resize(stages_.size()); + Rect image_area_rect = + image_max_color_channel_rect.ShiftLeft(base_color_shift_) + .Crop(frame_dimensions_.xsize_upsampled, + frame_dimensions_.ysize_upsampled); + for (size_t i = 0; i < stages_.size(); i++) { + group_rect[i] = + image_area_rect.CeilShiftRight(channel_shifts_[i][anyc_[i]]); + } + + ssize_t frame_x0 = + first_image_dim_stage_ == stages_.size() ? 0 : frame_origin_.x0; + ssize_t frame_y0 = + first_image_dim_stage_ == stages_.size() ? 0 : frame_origin_.y0; + size_t full_image_xsize = first_image_dim_stage_ == stages_.size() + ? frame_dimensions_.xsize_upsampled + : full_image_xsize_; + size_t full_image_ysize = first_image_dim_stage_ == stages_.size() + ? frame_dimensions_.ysize_upsampled + : full_image_ysize_; + + // Compute actual x-axis bounds for the current image area in the context of + // the full image this frame is part of. As the left boundary may be negative, + // we also create the x_pixels_skip value, defined as follows: + // - both x_pixels_skip and full_image_x0 are >= 0, and at least one is 0; + // - full_image_x0 - x_pixels_skip is the position of the current frame area + // in the full image. + ssize_t full_image_x0 = frame_x0 + image_area_rect.x0(); + ssize_t x_pixels_skip = 0; + if (full_image_x0 < 0) { + x_pixels_skip = -full_image_x0; + full_image_x0 = 0; + } + ssize_t full_image_x1 = frame_x0 + image_area_rect.x1(); + full_image_x1 = std::min<ssize_t>(full_image_x1, full_image_xsize); + + // If the current image area is entirely outside of the visible image, there + // is no point in proceeding. Note: this uses the assumption that if there is + // a stage with observable effects (i.e. a kInput stage), it only appears + // after the stage that switches to image dimensions. + if (full_image_x1 <= full_image_x0) return; + + // Data structures to hold information about input/output rows and their + // buffers. + Rows rows(stages_, data_max_color_channel_rect, group_data_x_border_, + group_data_y_border_, channel_shifts_[0], base_color_shift_, + stage_data_[thread_id], input_data); + + std::vector<RenderPipelineStage::RowInfo> input_rows(first_trailing_stage_ + + 1); + for (size_t i = 0; i < first_trailing_stage_; i++) { + input_rows[i].resize(input_data.size()); + } + input_rows[first_trailing_stage_].resize(input_data.size(), + std::vector<float*>(1)); + + // Maximum possible shift is 3. + RenderPipelineStage::RowInfo output_rows(input_data.size(), + std::vector<float*>(8)); + + // Fills in input_rows and output_rows for a given y value (relative to the + // start of the group, measured in actual pixels at the appropriate vertical + // scaling factor) and a given stage, applying mirroring if necessary. This + // function is somewhat inefficient for trailing kInOut or kInput stages, + // where just filling the input row once ought to be sufficient. + auto prepare_io_rows = [&](int y, size_t i) { + ssize_t bordery = stages_[i]->settings_.border_y; + size_t shifty = stages_[i]->settings_.shift_y; + auto make_row = [&](size_t c, ssize_t iy) { + size_t mirrored_y = GetMirroredY(y + iy - bordery, group_rect[i].y0(), + image_rect_[i].ysize()); + input_rows[i][c][iy] = + rows.GetBuffer(stage_input_for_channel_[i][c], mirrored_y, c); + ApplyXMirroring(input_rows[i][c][iy], stages_[i]->settings_.border_x, + group_rect[i].x0(), group_rect[i].xsize(), + image_rect_[i].xsize()); + }; + for (size_t c = 0; c < input_data.size(); c++) { + RenderPipelineChannelMode mode = stages_[i]->GetChannelMode(c); + if (mode == RenderPipelineChannelMode::kIgnored) { + continue; + } + // If we already have rows from a previous iteration, we can just shift + // the rows by 1 and insert the new one. + if (input_rows[i][c].size() == 2 * size_t(bordery) + 1) { + for (ssize_t iy = 0; iy < 2 * bordery; iy++) { + input_rows[i][c][iy] = input_rows[i][c][iy + 1]; + } + make_row(c, bordery * 2); + } else { + input_rows[i][c].resize(2 * bordery + 1); + for (ssize_t iy = 0; iy < 2 * bordery + 1; iy++) { + make_row(c, iy); + } + } + + // If necessary, get the output buffers. + if (mode == RenderPipelineChannelMode::kInOut) { + for (size_t iy = 0; iy < (1u << shifty); iy++) { + output_rows[c][iy] = rows.GetBuffer(i, y * (1 << shifty) + iy, c); + } + } + } + }; + + // We pretend that every stage has a vertical shift of 0, i.e. it is as tall + // as the final image. + // We call each such row a "virtual" row, because it may or may not correspond + // to an actual row of the current processing stage; actual processing happens + // when vy % (1<<vshift) == 0. + + int num_extra_rows = *std::max_element(virtual_ypadding_for_output_.begin(), + virtual_ypadding_for_output_.end()); + + for (int vy = -num_extra_rows; + vy < int(image_area_rect.ysize()) + num_extra_rows; vy++) { + for (size_t i = 0; i < first_trailing_stage_; i++) { + int stage_vy = vy - num_extra_rows + virtual_ypadding_for_output_[i]; + + if (stage_vy % (1 << channel_shifts_[i][anyc_[i]].second) != 0) { + continue; + } + + if (stage_vy < -virtual_ypadding_for_output_[i]) { + continue; + } + + int y = stage_vy >> channel_shifts_[i][anyc_[i]].second; + + ssize_t image_y = ssize_t(group_rect[i].y0()) + y; + // Do not produce rows in out-of-bounds areas. + if (image_y < 0 || image_y >= ssize_t(image_rect_[i].ysize())) { + continue; + } + + // Get the input/output rows and potentially apply mirroring to the input. + prepare_io_rows(y, i); + + // Produce output rows. + stages_[i]->ProcessRow(input_rows[i], output_rows, + xpadding_for_output_[i], group_rect[i].xsize(), + group_rect[i].x0(), image_y, thread_id); + } + + // Process trailing stages, i.e. the final set of non-kInOut stages; they + // all have the same input buffer and no need to use any mirroring. + + int y = vy - num_extra_rows; + + for (size_t c = 0; c < input_data.size(); c++) { + // Skip pixels that are not part of the actual final image area. + input_rows[first_trailing_stage_][c][0] = + rows.GetBuffer(stage_input_for_channel_[first_trailing_stage_][c], y, + c) + + x_pixels_skip; + } + + // Check that we are not outside of the bounds for the current rendering + // rect. Not doing so might result in overwriting some rows that have been + // written (or will be written) by other threads. + if (y < 0 || y >= ssize_t(image_area_rect.ysize())) { + continue; + } + + // Avoid running pipeline stages on pixels that are outside the full image + // area. As trailing stages have no borders, this is a free optimization + // (and may be necessary for correctness, as some stages assume coordinates + // are within bounds). + ssize_t full_image_y = frame_y0 + image_area_rect.y0() + y; + if (full_image_y < 0 || full_image_y >= ssize_t(full_image_ysize)) { + continue; + } + + for (size_t i = first_trailing_stage_; i < stages_.size(); i++) { + // Before the first_image_dim_stage_, coordinates are relative to the + // current frame. + size_t x0 = + i < first_image_dim_stage_ ? full_image_x0 - frame_x0 : full_image_x0; + size_t y = + i < first_image_dim_stage_ ? full_image_y - frame_y0 : full_image_y; + stages_[i]->ProcessRow(input_rows[first_trailing_stage_], output_rows, + /*xextra=*/0, full_image_x1 - full_image_x0, x0, y, + thread_id); + } + } +} + +void LowMemoryRenderPipeline::RenderPadding(size_t thread_id, Rect rect) { + if (rect.xsize() == 0) return; + size_t numc = channel_shifts_[0].size(); + RenderPipelineStage::RowInfo input_rows(numc, std::vector<float*>(1)); + RenderPipelineStage::RowInfo output_rows; + + for (size_t c = 0; c < numc; c++) { + input_rows[c][0] = out_of_frame_data_[thread_id].Row(c); + } + + for (size_t y = 0; y < rect.ysize(); y++) { + stages_[first_image_dim_stage_ - 1]->ProcessPaddingRow( + input_rows, rect.xsize(), rect.x0(), rect.y0() + y); + for (size_t i = first_image_dim_stage_; i < stages_.size(); i++) { + stages_[i]->ProcessRow(input_rows, output_rows, + /*xextra=*/0, rect.xsize(), rect.x0(), + rect.y0() + y, thread_id); + } + } +} + +void LowMemoryRenderPipeline::ProcessBuffers(size_t group_id, + size_t thread_id) { + std::vector<ImageF>& input_data = + group_data_[use_group_ids_ ? group_id : thread_id]; + + // Copy the group borders to the border storage. + for (size_t c = 0; c < input_data.size(); c++) { + SaveBorders(group_id, c, input_data[c]); + } + + size_t gy = group_id / frame_dimensions_.xsize_groups; + size_t gx = group_id % frame_dimensions_.xsize_groups; + + if (first_image_dim_stage_ != stages_.size()) { + size_t group_dim = frame_dimensions_.group_dim << base_color_shift_; + RectT<ssize_t> group_rect(gx * group_dim, gy * group_dim, group_dim, + group_dim); + RectT<ssize_t> image_rect(0, 0, frame_dimensions_.xsize_upsampled, + frame_dimensions_.ysize_upsampled); + RectT<ssize_t> full_image_rect(0, 0, full_image_xsize_, full_image_ysize_); + group_rect = group_rect.Translate(frame_origin_.x0, frame_origin_.y0); + image_rect = image_rect.Translate(frame_origin_.x0, frame_origin_.y0); + image_rect = image_rect.Intersection(full_image_rect); + group_rect = group_rect.Intersection(image_rect); + size_t x0 = group_rect.x0(); + size_t y0 = group_rect.y0(); + size_t x1 = group_rect.x1(); + size_t y1 = group_rect.y1(); + JXL_DEBUG_V(6, + "Rendering padding for full image rect %s " + "outside group rect %s", + Description(full_image_rect).c_str(), + Description(group_rect).c_str()); + + if (group_id == 0 && (image_rect.xsize() == 0 || image_rect.ysize() == 0)) { + // If this frame does not intersect with the full image, we have to + // initialize the whole image area with RenderPadding. + RenderPadding(thread_id, + Rect(0, 0, full_image_xsize_, full_image_ysize_)); + } + + // Render padding for groups that intersect with the full image. The case + // where no groups intersect was handled above. + if (group_rect.xsize() > 0 && group_rect.ysize() > 0) { + if (gx == 0 && gy == 0) { + RenderPadding(thread_id, Rect(0, 0, x0, y0)); + } + if (gy == 0) { + RenderPadding(thread_id, Rect(x0, 0, x1 - x0, y0)); + } + if (gx == 0) { + RenderPadding(thread_id, Rect(0, y0, x0, y1 - y0)); + } + if (gx == 0 && gy + 1 == frame_dimensions_.ysize_groups) { + RenderPadding(thread_id, Rect(0, y1, x0, full_image_ysize_ - y1)); + } + if (gy + 1 == frame_dimensions_.ysize_groups) { + RenderPadding(thread_id, Rect(x0, y1, x1 - x0, full_image_ysize_ - y1)); + } + if (gy == 0 && gx + 1 == frame_dimensions_.xsize_groups) { + RenderPadding(thread_id, Rect(x1, 0, full_image_xsize_ - x1, y0)); + } + if (gx + 1 == frame_dimensions_.xsize_groups) { + RenderPadding(thread_id, Rect(x1, y0, full_image_xsize_ - x1, y1 - y0)); + } + if (gy + 1 == frame_dimensions_.ysize_groups && + gx + 1 == frame_dimensions_.xsize_groups) { + RenderPadding(thread_id, Rect(x1, y1, full_image_xsize_ - x1, + full_image_ysize_ - y1)); + } + } + } + + Rect ready_rects[GroupBorderAssigner::kMaxToFinalize]; + size_t num_ready_rects = 0; + group_border_assigner_.GroupDone(group_id, group_border_.first, + group_border_.second, ready_rects, + &num_ready_rects); + for (size_t i = 0; i < num_ready_rects; i++) { + const Rect& image_max_color_channel_rect = ready_rects[i]; + for (size_t c = 0; c < input_data.size(); c++) { + LoadBorders(group_id, c, image_max_color_channel_rect, &input_data[c]); + } + Rect data_max_color_channel_rect( + group_data_x_border_ + image_max_color_channel_rect.x0() - + gx * frame_dimensions_.group_dim, + group_data_y_border_ + image_max_color_channel_rect.y0() - + gy * frame_dimensions_.group_dim, + image_max_color_channel_rect.xsize(), + image_max_color_channel_rect.ysize()); + RenderRect(thread_id, input_data, data_max_color_channel_rect, + image_max_color_channel_rect); + } +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/low_memory_render_pipeline.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/low_memory_render_pipeline.h new file mode 100644 index 0000000000..b386f7c078 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/low_memory_render_pipeline.h @@ -0,0 +1,111 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_LOW_MEMORY_RENDER_PIPELINE_H_ +#define LIB_JXL_RENDER_PIPELINE_LOW_MEMORY_RENDER_PIPELINE_H_ + +#include <stdint.h> + +#include "lib/jxl/dec_group_border.h" +#include "lib/jxl/render_pipeline/render_pipeline.h" + +namespace jxl { + +// A multithreaded, low-memory rendering pipeline that only allocates a minimal +// amount of buffers. +class LowMemoryRenderPipeline final : public RenderPipeline { + private: + std::vector<std::pair<ImageF*, Rect>> PrepareBuffers( + size_t group_id, size_t thread_id) override; + + void PrepareForThreadsInternal(size_t num, bool use_group_ids) override; + + void ProcessBuffers(size_t group_id, size_t thread_id) override; + + void ClearDone(size_t i) override { group_border_assigner_.ClearDone(i); } + + void Init() override; + + void EnsureBordersStorage(); + size_t GroupInputXSize(size_t c) const; + size_t GroupInputYSize(size_t c) const; + void RenderRect(size_t thread_id, std::vector<ImageF>& input_data, + Rect data_max_color_channel_rect, + Rect image_max_color_channel_rect); + void RenderPadding(size_t thread_id, Rect rect); + + void SaveBorders(size_t group_id, size_t c, const ImageF& in); + void LoadBorders(size_t group_id, size_t c, const Rect& r, ImageF* out); + + std::pair<size_t, size_t> ColorDimensionsToChannelDimensions( + std::pair<size_t, size_t> in, size_t c, size_t stage) const; + + std::pair<size_t, size_t> BorderToStore(size_t c) const; + + bool use_group_ids_; + + // Storage for borders between groups. Borders of adjacent groups are stacked + // together, e.g. bottom border of current group is followed by top border + // of next group. + std::vector<ImageF> borders_horizontal_; + std::vector<ImageF> borders_vertical_; + + // Manages the status of borders. + GroupBorderAssigner group_border_assigner_; + + // Size (in color-channel-pixels) of the border around each group that might + // be assigned to that group. + std::pair<size_t, size_t> group_border_; + // base_color_shift_ defines the size of groups in terms of final image + // pixels. + size_t base_color_shift_; + + // Buffer for decoded pixel data for a group, indexed by [thread][channel] or + // [group][channel] depending on `use_group_ids_`. + std::vector<std::vector<ImageF>> group_data_; + + // Borders for storing group data. + size_t group_data_x_border_; + size_t group_data_y_border_; + + // Buffers for intermediate rows for the various stages, indexed by + // [thread][channel][stage]. + std::vector<std::vector<std::vector<ImageF>>> stage_data_; + + // Buffers for out-of-frame data, indexed by [thread]; every row is a + // different channel. + std::vector<ImageF> out_of_frame_data_; + + // For each stage, a non-kIgnored channel. + std::vector<int32_t> anyc_; + + // Size of the image at each stage. + std::vector<Rect> image_rect_; + + // For each stage, for each channel, keep track of the kInOut stage that + // produced the input to that stage (which corresponds to the buffer index + // containing the data). -1 if data comes from the original input. + std::vector<std::vector<int32_t>> stage_input_for_channel_; + + // Number of (virtual) extra rows that must be processed at each stage + // to produce sufficient output for future stages. + std::vector<int> virtual_ypadding_for_output_; + + // Same thing for columns, except these are real columns and not virtual ones. + std::vector<int> xpadding_for_output_; + + // First stage that doesn't have any kInOut channel. + size_t first_trailing_stage_; + + // Origin and size of the frame after switching to image dimensions. + FrameOrigin frame_origin_; + size_t full_image_xsize_; + size_t full_image_ysize_; + size_t first_image_dim_stage_; +}; + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_LOW_MEMORY_RENDER_PIPELINE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline.cc new file mode 100644 index 0000000000..68b6ef613f --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline.cc @@ -0,0 +1,132 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/render_pipeline.h" + +#include <algorithm> + +#include "lib/jxl/render_pipeline/low_memory_render_pipeline.h" +#include "lib/jxl/render_pipeline/simple_render_pipeline.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { + +void RenderPipeline::Builder::AddStage( + std::unique_ptr<RenderPipelineStage> stage) { + stages_.push_back(std::move(stage)); +} + +std::unique_ptr<RenderPipeline> RenderPipeline::Builder::Finalize( + FrameDimensions frame_dimensions) && { +#if JXL_ENABLE_ASSERT + // Check that the last stage is not an kInOut stage for any channel, and that + // there is at least one stage. + JXL_ASSERT(!stages_.empty()); + for (size_t c = 0; c < num_c_; c++) { + JXL_ASSERT(stages_.back()->GetChannelMode(c) != + RenderPipelineChannelMode::kInOut); + } +#endif + + std::unique_ptr<RenderPipeline> res; + if (use_simple_implementation_) { + res = jxl::make_unique<SimpleRenderPipeline>(); + } else { + res = jxl::make_unique<LowMemoryRenderPipeline>(); + } + + res->padding_.resize(stages_.size()); + for (size_t i = stages_.size(); i-- > 0;) { + const auto& stage = stages_[i]; + res->padding_[i].resize(num_c_); + if (i + 1 == stages_.size()) { + continue; + } + for (size_t c = 0; c < num_c_; c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + res->padding_[i][c].first = DivCeil(res->padding_[i + 1][c].first, + 1 << stage->settings_.shift_x) + + stage->settings_.border_x; + res->padding_[i][c].second = DivCeil(res->padding_[i + 1][c].second, + 1 << stage->settings_.shift_y) + + stage->settings_.border_y; + } else { + res->padding_[i][c] = res->padding_[i + 1][c]; + } + } + } + + res->frame_dimensions_ = frame_dimensions; + res->group_completed_passes_.resize(frame_dimensions.num_groups); + res->channel_shifts_.resize(stages_.size()); + res->channel_shifts_[0].resize(num_c_); + for (size_t i = 1; i < stages_.size(); i++) { + auto& stage = stages_[i - 1]; + for (size_t c = 0; c < num_c_; c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + res->channel_shifts_[0][c].first += stage->settings_.shift_x; + res->channel_shifts_[0][c].second += stage->settings_.shift_y; + } + } + } + for (size_t i = 1; i < stages_.size(); i++) { + auto& stage = stages_[i - 1]; + res->channel_shifts_[i].resize(num_c_); + for (size_t c = 0; c < num_c_; c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kInOut) { + res->channel_shifts_[i][c].first = + res->channel_shifts_[i - 1][c].first - stage->settings_.shift_x; + res->channel_shifts_[i][c].second = + res->channel_shifts_[i - 1][c].second - stage->settings_.shift_y; + } else { + res->channel_shifts_[i][c].first = res->channel_shifts_[i - 1][c].first; + res->channel_shifts_[i][c].second = + res->channel_shifts_[i - 1][c].second; + } + } + } + res->stages_ = std::move(stages_); + res->Init(); + return res; +} + +RenderPipelineInput RenderPipeline::GetInputBuffers(size_t group_id, + size_t thread_id) { + RenderPipelineInput ret; + JXL_DASSERT(group_id < group_completed_passes_.size()); + ret.group_id_ = group_id; + ret.thread_id_ = thread_id; + ret.pipeline_ = this; + ret.buffers_ = PrepareBuffers(group_id, thread_id); + return ret; +} + +void RenderPipeline::InputReady( + size_t group_id, size_t thread_id, + const std::vector<std::pair<ImageF*, Rect>>& buffers) { + JXL_DASSERT(group_id < group_completed_passes_.size()); + group_completed_passes_[group_id]++; + for (size_t i = 0; i < buffers.size(); ++i) { + (void)i; + JXL_CHECK_PLANE_INITIALIZED(*buffers[i].first, buffers[i].second, i); + } + + ProcessBuffers(group_id, thread_id); +} + +Status RenderPipeline::PrepareForThreads(size_t num, bool use_group_ids) { + for (const auto& stage : stages_) { + JXL_RETURN_IF_ERROR(stage->PrepareForThreads(num)); + } + PrepareForThreadsInternal(num, use_group_ids); + return true; +} + +void RenderPipelineInput::Done() { + JXL_ASSERT(pipeline_); + pipeline_->InputReady(group_id_, thread_id_, buffers_); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline.h new file mode 100644 index 0000000000..bf3ad4975e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline.h @@ -0,0 +1,139 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_H_ +#define LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_H_ + +#include <stdint.h> + +#include "lib/jxl/image.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Interface to provide input to the rendering pipeline. When this object is +// destroyed, all the data in the provided ImageF's Rects must have been +// initialized. +class RenderPipelineInput { + public: + RenderPipelineInput(const RenderPipelineInput&) = delete; + RenderPipelineInput(RenderPipelineInput&& other) noexcept { + *this = std::move(other); + } + RenderPipelineInput& operator=(RenderPipelineInput&& other) noexcept { + pipeline_ = other.pipeline_; + group_id_ = other.group_id_; + thread_id_ = other.thread_id_; + buffers_ = std::move(other.buffers_); + other.pipeline_ = nullptr; + return *this; + } + + RenderPipelineInput() = default; + void Done(); + + const std::pair<ImageF*, Rect>& GetBuffer(size_t c) const { + JXL_ASSERT(c < buffers_.size()); + return buffers_[c]; + } + + private: + RenderPipeline* pipeline_ = nullptr; + size_t group_id_; + size_t thread_id_; + std::vector<std::pair<ImageF*, Rect>> buffers_; + friend class RenderPipeline; +}; + +class RenderPipeline { + public: + class Builder { + public: + explicit Builder(size_t num_c) : num_c_(num_c) { JXL_ASSERT(num_c > 0); } + + // Adds a stage to the pipeline. Must be called at least once; the last + // added stage cannot have kInOut channels. + void AddStage(std::unique_ptr<RenderPipelineStage> stage); + + // Enables using the simple (i.e. non-memory-efficient) implementation of + // the pipeline. + void UseSimpleImplementation() { use_simple_implementation_ = true; } + + // Finalizes setup of the pipeline. Shifts for all channels should be 0 at + // this point. + std::unique_ptr<RenderPipeline> Finalize( + FrameDimensions frame_dimensions) &&; + + private: + std::vector<std::unique_ptr<RenderPipelineStage>> stages_; + size_t num_c_; + bool use_simple_implementation_ = false; + }; + + friend class Builder; + + virtual ~RenderPipeline() = default; + + Status IsInitialized() const { + for (const auto& stage : stages_) { + JXL_RETURN_IF_ERROR(stage->IsInitialized()); + } + return true; + } + + // Allocates storage to run with `num` threads. If `use_group_ids` is true, + // storage is allocated for each group, not each thread. The behaviour is + // undefined if calling this function multiple times with a different value + // for `use_group_ids`. + Status PrepareForThreads(size_t num, bool use_group_ids); + + // Retrieves a buffer where input data should be stored by the callee. When + // input has been provided for all buffers, the pipeline will complete its + // processing. This method may be called multiple times concurrently from + // different threads, provided that a different `thread_id` is given. + RenderPipelineInput GetInputBuffers(size_t group_id, size_t thread_id); + + size_t PassesWithAllInput() const { + return *std::min_element(group_completed_passes_.begin(), + group_completed_passes_.end()); + } + + virtual void ClearDone(size_t i) {} + + protected: + std::vector<std::unique_ptr<RenderPipelineStage>> stages_; + // Shifts for every channel at the input of each stage. + std::vector<std::vector<std::pair<size_t, size_t>>> channel_shifts_; + + // Amount of (cumulative) padding required by each stage and channel, in + // either direction. + std::vector<std::vector<std::pair<size_t, size_t>>> padding_; + + FrameDimensions frame_dimensions_; + + std::vector<uint8_t> group_completed_passes_; + + friend class RenderPipelineInput; + + private: + void InputReady(size_t group_id, size_t thread_id, + const std::vector<std::pair<ImageF*, Rect>>& buffers); + + virtual std::vector<std::pair<ImageF*, Rect>> PrepareBuffers( + size_t group_id, size_t thread_id) = 0; + + virtual void ProcessBuffers(size_t group_id, size_t thread_id) = 0; + + // Note that this method may be called multiple times with different (or + // equal) `num`. + virtual void PrepareForThreadsInternal(size_t num, bool use_group_ids) = 0; + + // Called once frame dimensions and stages are known. + virtual void Init() {} +}; + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline_stage.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline_stage.h new file mode 100644 index 0000000000..d1a0074161 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline_stage.h @@ -0,0 +1,171 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_STAGE_H_ +#define LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_STAGE_H_ + +#include <stdint.h> + +#include "lib/jxl/base/arch_macros.h" +#include "lib/jxl/frame_header.h" + +namespace jxl { + +// The first pixel in the input to RenderPipelineStage will be located at +// this position. Pixels before this position may be accessed as padding. +// This should be at least the RoundUpTo(maximum padding / 2, maximum vector +// size) times 2: this is realized when using Gaborish + EPF + upsampling + +// chroma subsampling. +#if JXL_ARCH_ARM +constexpr size_t kRenderPipelineXOffset = 16; +#else +constexpr size_t kRenderPipelineXOffset = 32; +#endif + +enum class RenderPipelineChannelMode { + // This channel is not modified by this stage. + kIgnored = 0, + // This channel is modified in-place. + kInPlace = 1, + // This channel is modified and written to a new buffer. + kInOut = 2, + // This channel is only read. These are the only stages that are assumed to + // have observable effects, i.e. calls to ProcessRow for other stages may be + // omitted if it can be shown they can't affect any kInput stage ProcessRow + // call that happens inside image boundaries. + kInput = 3, +}; + +class RenderPipeline; + +class RenderPipelineStage { + protected: + using Row = float*; + using ChannelRows = std::vector<Row>; + + public: + using RowInfo = std::vector<ChannelRows>; + struct Settings { + // Amount of padding required in the various directions by all channels + // that have kInOut mode. + size_t border_x = 0; + size_t border_y = 0; + + // Log2 of the number of columns/rows of output that this stage will produce + // for every input row for kInOut channels. + size_t shift_x = 0; + size_t shift_y = 0; + + static Settings ShiftX(size_t shift, size_t border) { + Settings settings; + settings.border_x = border; + settings.shift_x = shift; + return settings; + } + + static Settings ShiftY(size_t shift, size_t border) { + Settings settings; + settings.border_y = border; + settings.shift_y = shift; + return settings; + } + + static Settings Symmetric(size_t shift, size_t border) { + Settings settings; + settings.border_x = settings.border_y = border; + settings.shift_x = settings.shift_y = shift; + return settings; + } + + static Settings SymmetricBorderOnly(size_t border) { + return Symmetric(0, border); + } + }; + + virtual ~RenderPipelineStage() = default; + + // Processes one row of input, producing the appropriate number of rows of + // output. Input/output rows can be obtained by calls to + // `GetInputRow`/`GetOutputRow`. `xsize+2*xextra` represents the total number + // of pixels to be processed in the input row, where the first pixel is at + // position `kRenderPipelineXOffset-xextra`. All pixels in the + // `[kRenderPipelineXOffset-xextra-border_x, + // kRenderPipelineXOffset+xsize+xextra+border_x)` range are initialized and + // accessible. `xpos` and `ypos` represent the position of the first + // (non-extra, i.e. in position kRenderPipelineXOffset) pixel in the center + // row of the input in the full image. `xpos` is a multiple of + // `GroupBorderAssigner::kPaddingXRound`. If `settings_.temp_buffer_size` is + // nonzero, `temp` will point to an HWY-aligned buffer of at least that number + // of floats; concurrent calls will have different buffers. + virtual void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const = 0; + + // How each channel will be processed. Channels are numbered starting from + // color channels (always 3) and followed by all other channels. + virtual RenderPipelineChannelMode GetChannelMode(size_t c) const = 0; + + protected: + explicit RenderPipelineStage(Settings settings) : settings_(settings) {} + + virtual Status IsInitialized() const { return true; } + + // Informs the stage about the total size of each channel. Few stages will + // actually need to use this information. + virtual void SetInputSizes( + const std::vector<std::pair<size_t, size_t>>& input_sizes) {} + + virtual Status PrepareForThreads(size_t num_threads) { return true; } + + // Returns a pointer to the input row of channel `c` with offset `y`. + // `y` must be in [-settings_.border_y, settings_.border_y]. `c` must be such + // that `GetChannelMode(c) != kIgnored`. The returned pointer points to the + // offset-ed row (i.e. kRenderPipelineXOffset has been applied). + float* GetInputRow(const RowInfo& input_rows, size_t c, int offset) const { + JXL_DASSERT(GetChannelMode(c) != RenderPipelineChannelMode::kIgnored); + JXL_DASSERT(-offset <= static_cast<int>(settings_.border_y)); + JXL_DASSERT(offset <= static_cast<int>(settings_.border_y)); + return input_rows[c][settings_.border_y + offset] + kRenderPipelineXOffset; + } + // Similar to `GetInputRow`, but can only be used if `GetChannelMode(c) == + // kInOut`. Offset must be less than `1<<settings_.shift_y`.. The returned + // pointer points to the offset-ed row (i.e. kRenderPipelineXOffset has been + // applied). + float* GetOutputRow(const RowInfo& output_rows, size_t c, + size_t offset) const { + JXL_DASSERT(GetChannelMode(c) == RenderPipelineChannelMode::kInOut); + JXL_DASSERT(offset <= 1ul << settings_.shift_y); + return output_rows[c][offset] + kRenderPipelineXOffset; + } + + // Indicates whether, from this stage on, the pipeline will operate on an + // image- rather than frame-sized buffer. Only one stage in the pipeline + // should return true, and it should implement ProcessPaddingRow below too. + // It is assumed that, if there is a SwitchToImageDimensions() == true stage, + // all kInput stages appear after it. + virtual bool SwitchToImageDimensions() const { return false; } + + // If SwitchToImageDimensions returns true, then this should set xsize and + // ysize to the image size, and frame_origin to the location of the frame + // within the image. Otherwise, this is not called at all. + virtual void GetImageDimensions(size_t* xsize, size_t* ysize, + FrameOrigin* frame_origin) const {} + + // Produces the appropriate output data outside of the frame dimensions. xpos + // and ypos are now relative to the full image. + virtual void ProcessPaddingRow(const RowInfo& output_rows, size_t xsize, + size_t xpos, size_t ypos) const {} + + virtual const char* GetName() const = 0; + + Settings settings_; + friend class RenderPipeline; + friend class SimpleRenderPipeline; + friend class LowMemoryRenderPipeline; +}; + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_RENDER_PIPELINE_STAGE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline_test.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline_test.cc new file mode 100644 index 0000000000..51b9f273f8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/render_pipeline_test.cc @@ -0,0 +1,579 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/render_pipeline.h" + +#include <jxl/cms.h> + +#include <algorithm> +#include <cctype> +#include <cstdint> +#include <cstdio> +#include <ostream> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include "lib/extras/codec.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/override.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" // JXL_HIGH_PRECISION, JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_frame.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/fake_parallel_runner_testonly.h" +#include "lib/jxl/fields.h" +#include "lib/jxl/frame_dimensions.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/headers.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/jpeg/enc_jpeg_data.h" +#include "lib/jxl/render_pipeline/test_render_pipeline_stages.h" +#include "lib/jxl/splines.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +Status DecodeFile(const Span<const uint8_t> file, bool use_slow_pipeline, + CodecInOut* io, ThreadPool* pool) { + Status ret = true; + { + BitReader reader(file); + BitReaderScopedCloser reader_closer(&reader, &ret); + JXL_RETURN_IF_ERROR(reader.ReadFixedBits<16>() == 0x0AFF); + JXL_RETURN_IF_ERROR(ReadSizeHeader(&reader, &io->metadata.size)); + JXL_RETURN_IF_ERROR(ReadImageMetadata(&reader, &io->metadata.m)); + io->metadata.transform_data.nonserialized_xyb_encoded = + io->metadata.m.xyb_encoded; + JXL_RETURN_IF_ERROR(Bundle::Read(&reader, &io->metadata.transform_data)); + if (io->metadata.m.color_encoding.WantICC()) { + std::vector<uint8_t> icc; + JXL_RETURN_IF_ERROR(test::ReadICC(&reader, &icc)); + JXL_RETURN_IF_ERROR(io->metadata.m.color_encoding.SetICC( + std::move(icc), JxlGetDefaultCms())); + } + PassesDecoderState dec_state; + JXL_RETURN_IF_ERROR( + dec_state.output_encoding_info.SetFromMetadata(io->metadata)); + JXL_RETURN_IF_ERROR(reader.JumpToByteBoundary()); + io->frames.clear(); + FrameHeader frame_header(&io->metadata); + do { + io->frames.emplace_back(&io->metadata.m); + // Skip frames that are not displayed. + do { + size_t frame_start = reader.TotalBitsConsumed() / kBitsPerByte; + size_t size_left = file.size() - frame_start; + JXL_RETURN_IF_ERROR(DecodeFrame(&dec_state, pool, + file.data() + frame_start, size_left, + &frame_header, &io->frames.back(), + io->metadata, use_slow_pipeline)); + reader.SkipBits(io->frames.back().decoded_bytes() * kBitsPerByte); + } while (frame_header.frame_type != FrameType::kRegularFrame && + frame_header.frame_type != FrameType::kSkipProgressive); + } while (!frame_header.is_last); + + if (io->frames.empty()) return JXL_FAILURE("Not enough data."); + + if (reader.TotalBitsConsumed() != file.size() * kBitsPerByte) { + return JXL_FAILURE("Reader position not at EOF."); + } + if (!reader.AllReadsWithinBounds()) { + return JXL_FAILURE("Reader out of bounds read."); + } + io->CheckMetadata(); + // reader is closed here. + } + return ret; +} + +TEST(RenderPipelineTest, Build) { + RenderPipeline::Builder builder(/*num_c=*/1); + builder.AddStage(jxl::make_unique<UpsampleXSlowStage>()); + builder.AddStage(jxl::make_unique<UpsampleYSlowStage>()); + builder.AddStage(jxl::make_unique<Check0FinalStage>()); + builder.UseSimpleImplementation(); + FrameDimensions frame_dimensions; + frame_dimensions.Set(/*xsize=*/1024, /*ysize=*/1024, /*group_size_shift=*/0, + /*max_hshift=*/0, /*max_vshift=*/0, + /*modular_mode=*/false, /*upsampling=*/1); + std::move(builder).Finalize(frame_dimensions); +} + +TEST(RenderPipelineTest, CallAllGroups) { + RenderPipeline::Builder builder(/*num_c=*/1); + builder.AddStage(jxl::make_unique<UpsampleXSlowStage>()); + builder.AddStage(jxl::make_unique<UpsampleYSlowStage>()); + builder.AddStage(jxl::make_unique<Check0FinalStage>()); + builder.UseSimpleImplementation(); + FrameDimensions frame_dimensions; + frame_dimensions.Set(/*xsize=*/1024, /*ysize=*/1024, /*group_size_shift=*/0, + /*max_hshift=*/0, /*max_vshift=*/0, + /*modular_mode=*/false, /*upsampling=*/1); + auto pipeline = std::move(builder).Finalize(frame_dimensions); + ASSERT_TRUE(pipeline->PrepareForThreads(1, /*use_group_ids=*/false)); + + for (size_t i = 0; i < frame_dimensions.num_groups; i++) { + auto input_buffers = pipeline->GetInputBuffers(i, 0); + FillPlane(0.0f, input_buffers.GetBuffer(0).first, + input_buffers.GetBuffer(0).second); + input_buffers.Done(); + } + + EXPECT_EQ(pipeline->PassesWithAllInput(), 1); +} + +TEST(RenderPipelineTest, BuildFast) { + RenderPipeline::Builder builder(/*num_c=*/1); + builder.AddStage(jxl::make_unique<UpsampleXSlowStage>()); + builder.AddStage(jxl::make_unique<UpsampleYSlowStage>()); + builder.AddStage(jxl::make_unique<Check0FinalStage>()); + FrameDimensions frame_dimensions; + frame_dimensions.Set(/*xsize=*/1024, /*ysize=*/1024, /*group_size_shift=*/0, + /*max_hshift=*/0, /*max_vshift=*/0, + /*modular_mode=*/false, /*upsampling=*/1); + std::move(builder).Finalize(frame_dimensions); +} + +TEST(RenderPipelineTest, CallAllGroupsFast) { + RenderPipeline::Builder builder(/*num_c=*/1); + builder.AddStage(jxl::make_unique<UpsampleXSlowStage>()); + builder.AddStage(jxl::make_unique<UpsampleYSlowStage>()); + builder.AddStage(jxl::make_unique<Check0FinalStage>()); + builder.UseSimpleImplementation(); + FrameDimensions frame_dimensions; + frame_dimensions.Set(/*xsize=*/1024, /*ysize=*/1024, /*group_size_shift=*/0, + /*max_hshift=*/0, /*max_vshift=*/0, + /*modular_mode=*/false, /*upsampling=*/1); + auto pipeline = std::move(builder).Finalize(frame_dimensions); + ASSERT_TRUE(pipeline->PrepareForThreads(1, /*use_group_ids=*/false)); + + for (size_t i = 0; i < frame_dimensions.num_groups; i++) { + auto input_buffers = pipeline->GetInputBuffers(i, 0); + FillPlane(0.0f, input_buffers.GetBuffer(0).first, + input_buffers.GetBuffer(0).second); + input_buffers.Done(); + } + + EXPECT_EQ(pipeline->PassesWithAllInput(), 1); +} + +struct RenderPipelineTestInputSettings { + // Input image. + std::string input_path; + size_t xsize, ysize; + bool jpeg_transcode = false; + // Encoding settings. + CompressParams cparams; + // Short name for the encoder settings. + std::string cparams_descr; + + bool add_spot_color = false; + + Splines splines; +}; + +class RenderPipelineTestParam + : public ::testing::TestWithParam<RenderPipelineTestInputSettings> {}; + +TEST_P(RenderPipelineTestParam, PipelineTest) { + RenderPipelineTestInputSettings config = GetParam(); + + // Use a parallel runner that randomly shuffles tasks to detect possible + // border handling bugs. + FakeParallelRunner fake_pool(/*order_seed=*/123, /*num_threads=*/8); + ThreadPool pool(&JxlFakeParallelRunner, &fake_pool); + const std::vector<uint8_t> orig = jxl::test::ReadTestData(config.input_path); + + CodecInOut io; + if (config.jpeg_transcode) { + ASSERT_TRUE(jpeg::DecodeImageJPG(Bytes(orig), &io)); + } else { + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io, &pool)); + } + io.ShrinkTo(config.xsize, config.ysize); + + if (config.add_spot_color) { + jxl::ImageF spot(config.xsize, config.ysize); + jxl::ZeroFillImage(&spot); + + for (size_t y = 0; y < config.ysize; y++) { + float* JXL_RESTRICT row = spot.Row(y); + for (size_t x = 0; x < config.xsize; x++) { + row[x] = ((x ^ y) & 255) * (1.f / 255.f); + } + } + ExtraChannelInfo info; + info.bit_depth.bits_per_sample = 8; + info.dim_shift = 0; + info.type = jxl::ExtraChannel::kSpotColor; + info.spot_color[0] = 0.5f; + info.spot_color[1] = 0.2f; + info.spot_color[2] = 1.f; + info.spot_color[3] = 0.5f; + + io.metadata.m.extra_channel_info.push_back(info); + std::vector<jxl::ImageF> ec; + ec.push_back(std::move(spot)); + io.frames[0].SetExtraChannels(std::move(ec)); + } + + std::vector<uint8_t> compressed; + + config.cparams.custom_splines = config.splines; + ASSERT_TRUE(test::EncodeFile(config.cparams, &io, &compressed, &pool)); + + CodecInOut io_default; + ASSERT_TRUE(DecodeFile(Bytes(compressed), + /*use_slow_pipeline=*/false, &io_default, &pool)); + CodecInOut io_slow_pipeline; + ASSERT_TRUE(DecodeFile(Bytes(compressed), + /*use_slow_pipeline=*/true, &io_slow_pipeline, &pool)); + + ASSERT_EQ(io_default.frames.size(), io_slow_pipeline.frames.size()); + for (size_t i = 0; i < io_default.frames.size(); i++) { +#if JXL_HIGH_PRECISION + constexpr float kMaxError = 5e-5; +#else + constexpr float kMaxError = 5e-4; +#endif + Image3F def = std::move(*io_default.frames[i].color()); + Image3F pip = std::move(*io_slow_pipeline.frames[i].color()); + JXL_ASSERT_OK(VerifyRelativeError(pip, def, kMaxError, kMaxError, _)); + for (size_t ec = 0; ec < io_default.frames[i].extra_channels().size(); + ec++) { + JXL_ASSERT_OK(VerifyRelativeError( + io_slow_pipeline.frames[i].extra_channels()[ec], + io_default.frames[i].extra_channels()[ec], kMaxError, kMaxError, _)); + } + } +} + +Splines CreateTestSplines() { + const ColorCorrelationMap cmap; + std::vector<Spline::Point> control_points{{9, 54}, {118, 159}, {97, 3}, + {10, 40}, {150, 25}, {120, 300}}; + const Spline spline{ + control_points, + /*color_dct=*/ + {{0.03125f, 0.00625f, 0.003125f}, {1.f, 0.321875f}, {1.f, 0.24375f}}, + /*sigma_dct=*/{0.3125f, 0.f, 0.f, 0.0625f}}; + std::vector<Spline> spline_data = {spline}; + std::vector<QuantizedSpline> quantized_splines; + std::vector<Spline::Point> starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, /*quantization_adjustment=*/0, + cmap.YtoXRatio(0), cmap.YtoBRatio(0)); + starting_points.push_back(spline.control_points.front()); + } + return Splines(/*quantization_adjustment=*/0, std::move(quantized_splines), + std::move(starting_points)); +} + +std::vector<RenderPipelineTestInputSettings> GeneratePipelineTests() { + std::vector<RenderPipelineTestInputSettings> all_tests; + + std::pair<size_t, size_t> sizes[] = { + {3, 8}, {128, 128}, {256, 256}, {258, 258}, {533, 401}, {777, 777}, + }; + + for (auto size : sizes) { + RenderPipelineTestInputSettings settings; + settings.input_path = "jxl/flower/flower.png"; + settings.xsize = size.first; + settings.ysize = size.second; + + // Base settings. + settings.cparams.butteraugli_distance = 1.0; + settings.cparams.patches = Override::kOff; + settings.cparams.dots = Override::kOff; + settings.cparams.gaborish = Override::kOff; + settings.cparams.epf = 0; + settings.cparams.color_transform = ColorTransform::kXYB; + + { + auto s = settings; + s.cparams_descr = "NoGabNoEpfNoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.color_transform = ColorTransform::kNone; + s.cparams_descr = "NoGabNoEpfNoPatchesNoXYB"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.gaborish = Override::kOn; + s.cparams_descr = "GabNoEpfNoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.epf = 1; + s.cparams_descr = "NoGabEpf1NoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.epf = 2; + s.cparams_descr = "NoGabEpf2NoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.epf = 3; + s.cparams_descr = "NoGabEpf3NoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.gaborish = Override::kOn; + s.cparams.epf = 3; + s.cparams_descr = "GabEpf3NoPatches"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "Splines"; + s.splines = CreateTestSplines(); + all_tests.push_back(s); + } + + for (size_t ups : {2, 4, 8}) { + { + auto s = settings; + s.cparams.resampling = ups; + s.cparams_descr = "Ups" + std::to_string(ups); + all_tests.push_back(s); + } + { + auto s = settings; + s.cparams.resampling = ups; + s.cparams.epf = 1; + s.cparams_descr = "Ups" + std::to_string(ups) + "EPF1"; + all_tests.push_back(s); + } + { + auto s = settings; + s.cparams.resampling = ups; + s.cparams.gaborish = Override::kOn; + s.cparams.epf = 1; + s.cparams_descr = "Ups" + std::to_string(ups) + "GabEPF1"; + all_tests.push_back(s); + } + } + + { + auto s = settings; + s.cparams_descr = "Noise"; + s.cparams.photon_noise_iso = 3200; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "NoiseUps"; + s.cparams.photon_noise_iso = 3200; + s.cparams.resampling = 2; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "ModularLossless"; + s.cparams.modular_mode = true; + s.cparams.butteraugli_distance = 0; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "ProgressiveDC"; + s.cparams.progressive_dc = 1; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "ModularLossy"; + s.cparams.modular_mode = true; + s.cparams.butteraugli_distance = 1.f; + all_tests.push_back(s); + } + + { + auto s = settings; + s.input_path = "jxl/flower/flower_alpha.png"; + s.cparams_descr = "AlphaVarDCT"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.input_path = "jxl/flower/flower_alpha.png"; + s.cparams_descr = "AlphaVarDCTUpsamplingEPF"; + s.cparams.epf = 1; + s.cparams.ec_resampling = 2; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams.modular_mode = true; + s.cparams.butteraugli_distance = 0; + s.input_path = "jxl/flower/flower_alpha.png"; + s.cparams_descr = "AlphaLossless"; + all_tests.push_back(s); + } + + { + auto s = settings; + s.input_path = "jxl/flower/flower_alpha.png"; + s.cparams_descr = "AlphaDownsample"; + s.cparams.ec_resampling = 2; + all_tests.push_back(s); + } + + { + auto s = settings; + s.cparams_descr = "SpotColor"; + s.add_spot_color = true; + all_tests.push_back(s); + } + } + +#if JPEGXL_ENABLE_TRANSCODE_JPEG + for (const char* input : {"jxl/flower/flower.png.im_q85_444.jpg", + "jxl/flower/flower.png.im_q85_420.jpg", + "jxl/flower/flower.png.im_q85_422.jpg", + "jxl/flower/flower.png.im_q85_440.jpg"}) { + RenderPipelineTestInputSettings settings; + settings.input_path = input; + settings.jpeg_transcode = true; + settings.xsize = 2268; + settings.ysize = 1512; + settings.cparams_descr = "Default"; + all_tests.push_back(settings); + } + +#endif + + { + RenderPipelineTestInputSettings settings; + settings.input_path = "jxl/grayscale_patches.png"; + settings.xsize = 1011; + settings.ysize = 277; + settings.cparams_descr = "Patches"; + all_tests.push_back(settings); + } + + { + RenderPipelineTestInputSettings settings; + settings.input_path = "jxl/grayscale_patches.png"; + settings.xsize = 1011; + settings.ysize = 277; + settings.cparams.photon_noise_iso = 1000; + settings.cparams_descr = "PatchesAndNoise"; + all_tests.push_back(settings); + } + + { + RenderPipelineTestInputSettings settings; + settings.input_path = "jxl/grayscale_patches.png"; + settings.xsize = 1011; + settings.ysize = 277; + settings.cparams.resampling = 2; + settings.cparams_descr = "PatchesAndUps2"; + all_tests.push_back(settings); + } + + return all_tests; +} + +std::ostream& operator<<(std::ostream& os, + const RenderPipelineTestInputSettings& c) { + std::string filename; + size_t pos = c.input_path.find_last_of('/'); + if (pos == std::string::npos) { + filename = c.input_path; + } else { + filename = c.input_path.substr(pos + 1); + } + std::replace_if( + filename.begin(), filename.end(), [](char c) { return !isalnum(c); }, + '_'); + os << filename << "_" << (c.jpeg_transcode ? "JPEG_" : "") << c.xsize << "x" + << c.ysize << "_" << c.cparams_descr; + return os; +} + +std::string PipelineTestDescription( + const testing::TestParamInfo<RenderPipelineTestParam::ParamType>& info) { + std::stringstream name; + name << info.param; + return name.str(); +} + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P(RenderPipelineTest, RenderPipelineTestParam, + testing::ValuesIn(GeneratePipelineTests()), + PipelineTestDescription); + +TEST(RenderPipelineDecodingTest, Animation) { + FakeParallelRunner fake_pool(/*order_seed=*/123, /*num_threads=*/8); + ThreadPool pool(&JxlFakeParallelRunner, &fake_pool); + + std::vector<uint8_t> compressed = + jxl::test::ReadTestData("jxl/blending/cropped_traffic_light.jxl"); + + CodecInOut io_default; + ASSERT_TRUE(DecodeFile(Bytes(compressed), + /*use_slow_pipeline=*/false, &io_default, &pool)); + CodecInOut io_slow_pipeline; + ASSERT_TRUE(DecodeFile(Bytes(compressed), + /*use_slow_pipeline=*/true, &io_slow_pipeline, &pool)); + + ASSERT_EQ(io_default.frames.size(), io_slow_pipeline.frames.size()); + for (size_t i = 0; i < io_default.frames.size(); i++) { +#if JXL_HIGH_PRECISION + constexpr float kMaxError = 1e-5; +#else + constexpr float kMaxError = 1e-4; +#endif + + Image3F fast_pipeline = std::move(*io_default.frames[i].color()); + Image3F slow_pipeline = std::move(*io_slow_pipeline.frames[i].color()); + JXL_ASSERT_OK(VerifyRelativeError(slow_pipeline, fast_pipeline, kMaxError, + kMaxError, _)) + for (size_t ec = 0; ec < io_default.frames[i].extra_channels().size(); + ec++) { + JXL_ASSERT_OK(VerifyRelativeError( + io_slow_pipeline.frames[i].extra_channels()[ec], + io_default.frames[i].extra_channels()[ec], kMaxError, kMaxError, _)); + } + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/simple_render_pipeline.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/simple_render_pipeline.cc new file mode 100644 index 0000000000..4495288860 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/simple_render_pipeline.cc @@ -0,0 +1,266 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/simple_render_pipeline.h" + +#include <hwy/base.h> + +#include "lib/jxl/image_ops.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" +#include "lib/jxl/sanitizers.h" + +namespace jxl { + +void SimpleRenderPipeline::PrepareForThreadsInternal(size_t num, + bool use_group_ids) { + if (!channel_data_.empty()) { + return; + } + auto ch_size = [](size_t frame_size, size_t shift) { + return DivCeil(frame_size, 1 << shift) + kRenderPipelineXOffset * 2; + }; + for (size_t c = 0; c < channel_shifts_[0].size(); c++) { + channel_data_.push_back(ImageF( + ch_size(frame_dimensions_.xsize_upsampled, channel_shifts_[0][c].first), + ch_size(frame_dimensions_.ysize_upsampled, + channel_shifts_[0][c].second))); + msan::PoisonImage(channel_data_.back()); + } +} + +Rect SimpleRenderPipeline::MakeChannelRect(size_t group_id, size_t channel) { + size_t base_color_shift = + CeilLog2Nonzero(frame_dimensions_.xsize_upsampled_padded / + frame_dimensions_.xsize_padded); + + const size_t gx = group_id % frame_dimensions_.xsize_groups; + const size_t gy = group_id / frame_dimensions_.xsize_groups; + size_t xgroupdim = (frame_dimensions_.group_dim << base_color_shift) >> + channel_shifts_[0][channel].first; + size_t ygroupdim = (frame_dimensions_.group_dim << base_color_shift) >> + channel_shifts_[0][channel].second; + return Rect( + kRenderPipelineXOffset + gx * xgroupdim, + kRenderPipelineXOffset + gy * ygroupdim, xgroupdim, ygroupdim, + kRenderPipelineXOffset + DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[0][channel].first), + kRenderPipelineXOffset + + DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[0][channel].second)); +} + +std::vector<std::pair<ImageF*, Rect>> SimpleRenderPipeline::PrepareBuffers( + size_t group_id, size_t thread_id) { + std::vector<std::pair<ImageF*, Rect>> ret; + for (size_t c = 0; c < channel_data_.size(); c++) { + ret.emplace_back(&channel_data_[c], MakeChannelRect(group_id, c)); + } + return ret; +} + +void SimpleRenderPipeline::ProcessBuffers(size_t group_id, size_t thread_id) { + for (size_t c = 0; c < channel_data_.size(); c++) { + Rect r = MakeChannelRect(group_id, c); + (void)r; + JXL_CHECK_PLANE_INITIALIZED(channel_data_[c], r, c); + } + + if (PassesWithAllInput() <= processed_passes_) return; + processed_passes_++; + + for (size_t stage_id = 0; stage_id < stages_.size(); stage_id++) { + const auto& stage = stages_[stage_id]; + // Prepare buffers for kInOut channels. + std::vector<ImageF> new_channels(channel_data_.size()); + std::vector<ImageF*> output_channels(channel_data_.size()); + + std::vector<std::pair<size_t, size_t>> input_sizes(channel_data_.size()); + for (size_t c = 0; c < channel_data_.size(); c++) { + input_sizes[c] = + std::make_pair(channel_data_[c].xsize() - kRenderPipelineXOffset * 2, + channel_data_[c].ysize() - kRenderPipelineXOffset * 2); + } + + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) != RenderPipelineChannelMode::kInOut) { + continue; + } + // Ensure that the newly allocated channels are large enough to avoid + // problems with padding. + new_channels[c] = + ImageF(frame_dimensions_.xsize_upsampled_padded + + kRenderPipelineXOffset * 2 + hwy::kMaxVectorSize * 8, + frame_dimensions_.ysize_upsampled_padded + + kRenderPipelineXOffset * 2); + new_channels[c].ShrinkTo( + (input_sizes[c].first << stage->settings_.shift_x) + + kRenderPipelineXOffset * 2, + (input_sizes[c].second << stage->settings_.shift_y) + + kRenderPipelineXOffset * 2); + output_channels[c] = &new_channels[c]; + } + + auto get_row = [&](size_t c, int64_t y) { + return channel_data_[c].Row(kRenderPipelineXOffset + y) + + kRenderPipelineXOffset; + }; + + // Add mirrored pixes to all kInOut channels. + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) != RenderPipelineChannelMode::kInOut) { + continue; + } + // Horizontal mirroring. + for (size_t y = 0; y < input_sizes[c].second; y++) { + float* row = get_row(c, y); + for (size_t ix = 0; ix < stage->settings_.border_x; ix++) { + *(row - ix - 1) = row[Mirror(-ssize_t(ix) - 1, input_sizes[c].first)]; + } + for (size_t ix = 0; ix < stage->settings_.border_x; ix++) { + *(row + ix + input_sizes[c].first) = + row[Mirror(ix + input_sizes[c].first, input_sizes[c].first)]; + } + } + // Vertical mirroring. + for (int y = 0; y < static_cast<int>(stage->settings_.border_y); y++) { + memcpy(get_row(c, -y - 1) - stage->settings_.border_x, + get_row(c, Mirror(-ssize_t(y) - 1, input_sizes[c].second)) - + stage->settings_.border_x, + sizeof(float) * + (input_sizes[c].first + 2 * stage->settings_.border_x)); + } + for (int y = 0; y < static_cast<int>(stage->settings_.border_y); y++) { + memcpy( + get_row(c, input_sizes[c].second + y) - stage->settings_.border_x, + get_row(c, + Mirror(input_sizes[c].second + y, input_sizes[c].second)) - + stage->settings_.border_x, + sizeof(float) * + (input_sizes[c].first + 2 * stage->settings_.border_x)); + } + } + + size_t ysize = 0; + size_t xsize = 0; + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kIgnored) { + continue; + } + ysize = std::max(input_sizes[c].second, ysize); + xsize = std::max(input_sizes[c].first, xsize); + } + + JXL_ASSERT(ysize != 0); + JXL_ASSERT(xsize != 0); + + RenderPipelineStage::RowInfo input_rows(channel_data_.size()); + RenderPipelineStage::RowInfo output_rows(channel_data_.size()); + + // Run the pipeline. + { + stage->SetInputSizes(input_sizes); + int border_y = stage->settings_.border_y; + for (size_t y = 0; y < ysize; y++) { + // Prepare input rows. + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) == RenderPipelineChannelMode::kIgnored) { + continue; + } + input_rows[c].resize(2 * border_y + 1); + for (int iy = -border_y; iy <= border_y; iy++) { + input_rows[c][iy + border_y] = + channel_data_[c].Row(y + kRenderPipelineXOffset + iy); + } + } + // Prepare output rows. + for (size_t c = 0; c < channel_data_.size(); c++) { + if (!output_channels[c]) continue; + output_rows[c].resize(1 << stage->settings_.shift_y); + for (size_t iy = 0; iy < output_rows[c].size(); iy++) { + output_rows[c][iy] = output_channels[c]->Row( + (y << stage->settings_.shift_y) + iy + kRenderPipelineXOffset); + } + } + stage->ProcessRow(input_rows, output_rows, /*xextra=*/0, xsize, + /*xpos=*/0, y, thread_id); + } + } + + // Move new channels to current channels. + for (size_t c = 0; c < channel_data_.size(); c++) { + if (stage->GetChannelMode(c) != RenderPipelineChannelMode::kInOut) { + continue; + } + channel_data_[c] = std::move(new_channels[c]); + } + for (size_t c = 0; c < channel_data_.size(); c++) { + size_t next_stage = std::min(stage_id + 1, channel_shifts_.size() - 1); + size_t xsize = DivCeil(frame_dimensions_.xsize_upsampled, + 1 << channel_shifts_[next_stage][c].first); + size_t ysize = DivCeil(frame_dimensions_.ysize_upsampled, + 1 << channel_shifts_[next_stage][c].second); + channel_data_[c].ShrinkTo(xsize + 2 * kRenderPipelineXOffset, + ysize + 2 * kRenderPipelineXOffset); + JXL_CHECK_PLANE_INITIALIZED( + channel_data_[c], + Rect(kRenderPipelineXOffset, kRenderPipelineXOffset, xsize, ysize), + c); + } + + if (stage->SwitchToImageDimensions()) { + size_t image_xsize, image_ysize; + FrameOrigin frame_origin; + stage->GetImageDimensions(&image_xsize, &image_ysize, &frame_origin); + frame_dimensions_.Set(image_xsize, image_ysize, 0, 0, 0, false, 1); + std::vector<ImageF> old_channels = std::move(channel_data_); + channel_data_.clear(); + channel_data_.reserve(old_channels.size()); + for (size_t c = 0; c < old_channels.size(); c++) { + channel_data_.emplace_back(2 * kRenderPipelineXOffset + image_xsize, + 2 * kRenderPipelineXOffset + image_ysize); + } + for (size_t y = 0; y < image_ysize; ++y) { + for (size_t c = 0; c < channel_data_.size(); c++) { + output_rows[c].resize(1); + output_rows[c][0] = channel_data_[c].Row(kRenderPipelineXOffset + y); + } + // TODO(sboukortt): consider doing this only on the parts of the + // background that won't be occluded. + stage->ProcessPaddingRow(output_rows, image_xsize, 0, y); + } + ssize_t x0 = frame_origin.x0; + ssize_t y0 = frame_origin.y0; + size_t x0_fg = 0; + size_t y0_fg = 0; + if (x0 < 0) { + xsize += x0; + x0_fg -= x0; + x0 = 0; + } + if (x0 + xsize > image_xsize) { + xsize = image_xsize - x0; + } + if (y0 < 0) { + ysize += y0; + y0_fg -= x0; + y0 = 0; + } + if (y0 + ysize > image_ysize) { + ysize = image_ysize - y0; + } + const Rect rect_fg_relative_to_image = + Rect(x0, y0, xsize, ysize) + .Translate(kRenderPipelineXOffset, kRenderPipelineXOffset); + const Rect rect_fg = + Rect(x0_fg, y0_fg, xsize, ysize) + .Translate(kRenderPipelineXOffset, kRenderPipelineXOffset); + for (size_t c = 0; c < channel_data_.size(); c++) { + CopyImageTo(rect_fg, old_channels[c], rect_fg_relative_to_image, + &channel_data_[c]); + } + } + } +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/simple_render_pipeline.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/simple_render_pipeline.h new file mode 100644 index 0000000000..10f4505912 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/simple_render_pipeline.h @@ -0,0 +1,37 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_SIMPLE_RENDER_PIPELINE_H_ +#define LIB_JXL_RENDER_PIPELINE_SIMPLE_RENDER_PIPELINE_H_ + +#include <stdint.h> + +#include "lib/jxl/render_pipeline/render_pipeline.h" + +namespace jxl { + +// A RenderPipeline that is "obviously correct"; it may use potentially large +// amounts of memory and be slow. It is intended to be used mostly for testing +// purposes. +class SimpleRenderPipeline : public RenderPipeline { + std::vector<std::pair<ImageF*, Rect>> PrepareBuffers( + size_t group_id, size_t thread_id) override; + + void ProcessBuffers(size_t group_id, size_t thread_id) override; + + void PrepareForThreadsInternal(size_t num, bool use_group_ids) override; + + // Full frame buffers. Both X and Y dimensions are padded by + // kRenderPipelineXOffset. + std::vector<ImageF> channel_data_; + size_t processed_passes_ = 0; + + private: + Rect MakeChannelRect(size_t group_id, size_t channel); +}; + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_SIMPLE_RENDER_PIPELINE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_blending.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_blending.cc new file mode 100644 index 0000000000..b68105f4c9 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_blending.cc @@ -0,0 +1,250 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_blending.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_blending.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/blending.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +class BlendingStage : public RenderPipelineStage { + public: + explicit BlendingStage(const FrameHeader& frame_header, + const PassesDecoderState* dec_state, + const ColorEncoding& frame_color_encoding) + : RenderPipelineStage(RenderPipelineStage::Settings()), + frame_header_(frame_header), + state_(*dec_state->shared) { + image_xsize_ = frame_header_.nonserialized_metadata->xsize(); + image_ysize_ = frame_header_.nonserialized_metadata->ysize(); + extra_channel_info_ = + &frame_header_.nonserialized_metadata->m.extra_channel_info; + info_ = frame_header_.blending_info; + const std::vector<BlendingInfo>& ec_info = + frame_header_.extra_channel_blending_info; + const ImageBundle& bg = state_.reference_frames[info_.source].frame; + bg_ = &bg; + if (bg.xsize() == 0 || bg.ysize() == 0) { + zeroes_.resize(image_xsize_, 0.f); + } else if (state_.reference_frames[info_.source].ib_is_in_xyb) { + initialized_ = JXL_FAILURE( + "Trying to blend XYB reference frame %i and non-XYB frame", + info_.source); + return; + } else if (std::any_of(ec_info.begin(), ec_info.end(), + [this](const BlendingInfo& info) { + const ImageBundle& bg = + state_.reference_frames[info.source].frame; + return bg.xsize() == 0 || bg.ysize() == 0; + })) { + zeroes_.resize(image_xsize_, 0.f); + } + + auto verify_bg_size = [&](const ImageBundle& bg) -> Status { + if (bg.xsize() != 0 && bg.ysize() != 0 && + (bg.xsize() < image_xsize_ || bg.ysize() < image_ysize_ || + bg.origin.x0 != 0 || bg.origin.y0 != 0)) { + return JXL_FAILURE("Trying to use a %" PRIuS "x%" PRIuS + " crop as a background", + bg.xsize(), bg.ysize()); + } + return true; + }; + + Status ok = verify_bg_size(bg); + for (const auto& info : ec_info) { + const ImageBundle& bg = state_.reference_frames[info.source].frame; + if (!!ok) ok = verify_bg_size(bg); + } + if (!ok) { + initialized_ = ok; + return; + } + + if (state_.metadata->m.xyb_encoded) { + if (!dec_state->output_encoding_info.color_encoding_is_original) { + initialized_ = JXL_FAILURE("Blending in unsupported color space"); + return; + } + } + + blending_info_.resize(ec_info.size() + 1); + auto make_blending = [&](const BlendingInfo& info, PatchBlending* pb) { + pb->alpha_channel = info.alpha_channel; + pb->clamp = info.clamp; + switch (info.mode) { + case BlendMode::kReplace: { + pb->mode = PatchBlendMode::kReplace; + break; + } + case BlendMode::kAdd: { + pb->mode = PatchBlendMode::kAdd; + break; + } + case BlendMode::kMul: { + pb->mode = PatchBlendMode::kMul; + break; + } + case BlendMode::kBlend: { + pb->mode = PatchBlendMode::kBlendAbove; + break; + } + case BlendMode::kAlphaWeightedAdd: { + pb->mode = PatchBlendMode::kAlphaWeightedAddAbove; + break; + } + default: { + JXL_UNREACHABLE( + "Invalid blend mode"); // should have failed to decode + } + } + }; + make_blending(info_, &blending_info_[0]); + for (size_t i = 0; i < ec_info.size(); i++) { + make_blending(ec_info[i], &blending_info_[1 + i]); + } + } + + Status IsInitialized() const override { return initialized_; } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + JXL_ASSERT(initialized_); + const FrameOrigin& frame_origin = frame_header_.frame_origin; + ssize_t bg_xpos = frame_origin.x0 + static_cast<ssize_t>(xpos); + ssize_t bg_ypos = frame_origin.y0 + static_cast<ssize_t>(ypos); + int offset = 0; + if (bg_xpos + static_cast<ssize_t>(xsize) <= 0 || + frame_origin.x0 >= static_cast<ssize_t>(image_xsize_) || bg_ypos < 0 || + bg_ypos >= static_cast<ssize_t>(image_ysize_)) { + return; + } + if (bg_xpos < 0) { + offset -= bg_xpos; + xsize += bg_xpos; + bg_xpos = 0; + } + if (bg_xpos + xsize > image_xsize_) { + xsize = + std::max<ssize_t>(0, static_cast<ssize_t>(image_xsize_) - bg_xpos); + } + std::vector<const float*> bg_row_ptrs_(input_rows.size()); + std::vector<float*> fg_row_ptrs_(input_rows.size()); + size_t num_c = std::min(input_rows.size(), extra_channel_info_->size() + 3); + for (size_t c = 0; c < num_c; ++c) { + fg_row_ptrs_[c] = GetInputRow(input_rows, c, 0) + offset; + if (c < 3) { + bg_row_ptrs_[c] = bg_->xsize() != 0 && bg_->ysize() != 0 + ? bg_->color().ConstPlaneRow(c, bg_ypos) + bg_xpos + : zeroes_.data(); + } else { + const ImageBundle& ec_bg = + state_ + .reference_frames + [frame_header_.extra_channel_blending_info[c - 3].source] + .frame; + bg_row_ptrs_[c] = + ec_bg.xsize() != 0 && ec_bg.ysize() != 0 + ? ec_bg.extra_channels()[c - 3].ConstRow(bg_ypos) + bg_xpos + : zeroes_.data(); + } + } + PerformBlending(bg_row_ptrs_.data(), fg_row_ptrs_.data(), + fg_row_ptrs_.data(), 0, xsize, blending_info_[0], + blending_info_.data() + 1, *extra_channel_info_); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInPlace; + } + + bool SwitchToImageDimensions() const override { return true; } + + void GetImageDimensions(size_t* xsize, size_t* ysize, + FrameOrigin* frame_origin) const override { + *xsize = image_xsize_; + *ysize = image_ysize_; + *frame_origin = frame_header_.frame_origin; + } + + void ProcessPaddingRow(const RowInfo& output_rows, size_t xsize, size_t xpos, + size_t ypos) const override { + if (bg_->xsize() == 0 || bg_->ysize() == 0) { + for (size_t c = 0; c < 3; ++c) { + memset(GetInputRow(output_rows, c, 0), 0, xsize * sizeof(float)); + } + } else { + for (size_t c = 0; c < 3; ++c) { + memcpy(GetInputRow(output_rows, c, 0), + bg_->color().ConstPlaneRow(c, ypos) + xpos, + xsize * sizeof(float)); + } + } + for (size_t ec = 0; ec < extra_channel_info_->size(); ++ec) { + const ImageBundle& ec_bg = + state_ + .reference_frames[frame_header_.extra_channel_blending_info[ec] + .source] + .frame; + if (ec_bg.xsize() == 0 || ec_bg.ysize() == 0) { + memset(GetInputRow(output_rows, 3 + ec, 0), 0, xsize * sizeof(float)); + } else { + memcpy(GetInputRow(output_rows, 3 + ec, 0), + ec_bg.extra_channels()[ec].ConstRow(ypos) + xpos, + xsize * sizeof(float)); + } + } + } + + const char* GetName() const override { return "Blending"; } + + private: + const FrameHeader& frame_header_; + const PassesSharedState& state_; + BlendingInfo info_; + const ImageBundle* bg_; + Status initialized_ = true; + size_t image_xsize_; + size_t image_ysize_; + std::vector<PatchBlending> blending_info_; + const std::vector<ExtraChannelInfo>* extra_channel_info_; + std::vector<float> zeroes_; +}; + +std::unique_ptr<RenderPipelineStage> GetBlendingStage( + const FrameHeader& frame_header, const PassesDecoderState* dec_state, + const ColorEncoding& frame_color_encoding) { + return jxl::make_unique<BlendingStage>(frame_header, dec_state, + frame_color_encoding); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetBlendingStage); + +std::unique_ptr<RenderPipelineStage> GetBlendingStage( + const FrameHeader& frame_header, const PassesDecoderState* dec_state, + const ColorEncoding& frame_color_encoding) { + return HWY_DYNAMIC_DISPATCH(GetBlendingStage)(frame_header, dec_state, + frame_color_encoding); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_blending.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_blending.h new file mode 100644 index 0000000000..aedc8c2e99 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_blending.h @@ -0,0 +1,25 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_BLENDING_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_BLENDING_H_ + +#include <memory> + +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Applies blending if applicable. +std::unique_ptr<RenderPipelineStage> GetBlendingStage( + const FrameHeader& frame_header, const PassesDecoderState* dec_state, + const ColorEncoding& frame_color_encoding); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_BLENDING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_chroma_upsampling.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_chroma_upsampling.cc new file mode 100644 index 0000000000..936fbd3a44 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_chroma_upsampling.cc @@ -0,0 +1,127 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_chroma_upsampling.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_chroma_upsampling.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/simd_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; + +class HorizontalChromaUpsamplingStage : public RenderPipelineStage { + public: + explicit HorizontalChromaUpsamplingStage(size_t channel) + : RenderPipelineStage(RenderPipelineStage::Settings::ShiftX( + /*shift=*/1, /*border=*/1)), + c_(channel) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + HWY_FULL(float) df; + xextra = RoundUpTo(xextra, Lanes(df)); + auto threefour = Set(df, 0.75f); + auto onefour = Set(df, 0.25f); + const float* row_in = GetInputRow(input_rows, c_, 0); + float* row_out = GetOutputRow(output_rows, c_, 0); + for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra); + x += Lanes(df)) { + auto current = Mul(LoadU(df, row_in + x), threefour); + auto prev = LoadU(df, row_in + x - 1); + auto next = LoadU(df, row_in + x + 1); + auto left = MulAdd(onefour, prev, current); + auto right = MulAdd(onefour, next, current); + StoreInterleaved(df, left, right, row_out + x * 2); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c == c_ ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "HChromaUps"; } + + private: + size_t c_; +}; + +class VerticalChromaUpsamplingStage : public RenderPipelineStage { + public: + explicit VerticalChromaUpsamplingStage(size_t channel) + : RenderPipelineStage(RenderPipelineStage::Settings::ShiftY( + /*shift=*/1, /*border=*/1)), + c_(channel) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + HWY_FULL(float) df; + xextra = RoundUpTo(xextra, Lanes(df)); + auto threefour = Set(df, 0.75f); + auto onefour = Set(df, 0.25f); + const float* row_top = GetInputRow(input_rows, c_, -1); + const float* row_mid = GetInputRow(input_rows, c_, 0); + const float* row_bot = GetInputRow(input_rows, c_, 1); + float* row_out0 = GetOutputRow(output_rows, c_, 0); + float* row_out1 = GetOutputRow(output_rows, c_, 1); + for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra); + x += Lanes(df)) { + auto it = LoadU(df, row_top + x); + auto im = LoadU(df, row_mid + x); + auto ib = LoadU(df, row_bot + x); + auto im_scaled = Mul(im, threefour); + Store(MulAdd(it, onefour, im_scaled), df, row_out0 + x); + Store(MulAdd(ib, onefour, im_scaled), df, row_out1 + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c == c_ ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "VChromaUps"; } + + private: + size_t c_; +}; + +std::unique_ptr<RenderPipelineStage> GetChromaUpsamplingStage(size_t channel, + bool horizontal) { + if (horizontal) { + return jxl::make_unique<HorizontalChromaUpsamplingStage>(channel); + } else { + return jxl::make_unique<VerticalChromaUpsamplingStage>(channel); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetChromaUpsamplingStage); + +std::unique_ptr<RenderPipelineStage> GetChromaUpsamplingStage(size_t channel, + bool horizontal) { + return HWY_DYNAMIC_DISPATCH(GetChromaUpsamplingStage)(channel, horizontal); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_chroma_upsampling.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_chroma_upsampling.h new file mode 100644 index 0000000000..b4d0cbdfd3 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_chroma_upsampling.h @@ -0,0 +1,26 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_CHROMA_UPSAMPLING_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_CHROMA_UPSAMPLING_H_ +#include <math.h> +#include <stdint.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Applies simple upsampling, either horizontal or vertical, to the given +// channel. +std::unique_ptr<RenderPipelineStage> GetChromaUpsamplingStage(size_t channel, + bool horizontal); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_CHROMA_UPSAMPLING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_cms.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_cms.cc new file mode 100644 index 0000000000..2465146b47 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_cms.cc @@ -0,0 +1,134 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_cms.h" + +#include <memory> + +#include "jxl/cms_interface.h" +#include "jxl/color_encoding.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/common.h" +#include "lib/jxl/dec_xyb.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_cms.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/dec_xyb-inl.h" +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +class CmsStage : public RenderPipelineStage { + public: + explicit CmsStage(OutputEncodingInfo output_encoding_info) + : RenderPipelineStage(RenderPipelineStage::Settings()), + output_encoding_info_(std::move(output_encoding_info)) { + c_src_ = output_encoding_info_.linear_color_encoding; + } + + bool IsNeeded() const { + const size_t channels_src = (c_src_.IsCMYK() ? 4 : c_src_.Channels()); + const size_t channels_dst = output_encoding_info_.color_encoding.Channels(); + const bool not_mixing_color_and_grey = + (channels_src == channels_dst || + (channels_src == 4 && channels_dst == 3)); + return (output_encoding_info_.cms_set) && + !c_src_.SameColorEncoding(output_encoding_info_.color_encoding) && + not_mixing_color_and_grey; + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + JXL_ASSERT(xsize == xsize_); + // TODO(firsching): handle grey case seperately + // interleave + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + float* mutable_buf_src = color_space_transform->BufSrc(thread_id); + + for (size_t x = 0; x < xsize; x++) { + mutable_buf_src[3 * x + 0] = row0[x]; + mutable_buf_src[3 * x + 1] = row1[x]; + mutable_buf_src[3 * x + 2] = row2[x]; + } + const float* buf_src = mutable_buf_src; + float* JXL_RESTRICT buf_dst = color_space_transform->BufDst(thread_id); + if (!color_space_transform->Run(thread_id, buf_src, buf_dst)) { + // TODO(firsching): somehow mark failing here? + return; + } + // de-interleave + for (size_t x = 0; x < xsize; x++) { + row0[x] = buf_dst[3 * x + 0]; + row1[x] = buf_dst[3 * x + 1]; + row2[x] = buf_dst[3 * x + 2]; + } + } + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Cms"; } + + private: + OutputEncodingInfo output_encoding_info_; + size_t xsize_; + std::unique_ptr<jxl::ColorSpaceTransform> color_space_transform; + ColorEncoding c_src_; + + void SetInputSizes( + const std::vector<std::pair<size_t, size_t>>& input_sizes) override { +#if JXL_ENABLE_ASSERT + JXL_ASSERT(input_sizes.size() >= 3); + for (size_t c = 1; c < input_sizes.size(); c++) { + JXL_ASSERT(input_sizes[c].first == input_sizes[0].first); + JXL_ASSERT(input_sizes[c].second == input_sizes[0].second); + } +#endif + xsize_ = input_sizes[0].first; + } + + Status PrepareForThreads(size_t num_threads) override { + color_space_transform = jxl::make_unique<jxl::ColorSpaceTransform>( + output_encoding_info_.color_management_system); + JXL_RETURN_IF_ERROR(color_space_transform->Init( + c_src_, output_encoding_info_.color_encoding, + output_encoding_info_.desired_intensity_target, xsize_, num_threads)); + return true; + } +}; + +std::unique_ptr<RenderPipelineStage> GetCmsStage( + const OutputEncodingInfo& output_encoding_info) { + auto stage = jxl::make_unique<CmsStage>(output_encoding_info); + if (!stage->IsNeeded()) return nullptr; + return stage; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetCmsStage); + +std::unique_ptr<RenderPipelineStage> GetCmsStage( + const OutputEncodingInfo& output_encoding_info) { + return HWY_DYNAMIC_DISPATCH(GetCmsStage)(output_encoding_info); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_cms.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_cms.h new file mode 100644 index 0000000000..23277ae6f7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_cms.h @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_CMS_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_CMS_H_ + +#include <memory> + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +std::unique_ptr<RenderPipelineStage> GetCmsStage( + const OutputEncodingInfo& output_encoding_info); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_CMS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.cc new file mode 100644 index 0000000000..5d1a379ede --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.cc @@ -0,0 +1,526 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_epf.h" + +#include "lib/jxl/base/common.h" +#include "lib/jxl/common.h" // JXL_HIGH_PRECISION +#include "lib/jxl/epf.h" +#include "lib/jxl/sanitizers.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_epf.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +// TODO(veluca): In principle, vectors could be not capped, if we want to deal +// with having two different sigma values in a single vector. +using DF = HWY_CAPPED(float, 8); + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::AbsDiff; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::VFromD; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +JXL_INLINE Vec<DF> Weight(Vec<DF> sad, Vec<DF> inv_sigma, Vec<DF> thres) { + auto v = MulAdd(sad, inv_sigma, Set(DF(), 1.0f)); + return ZeroIfNegative(v); +} + +// 5x5 plus-shaped kernel with 5 SADs per pixel (3x3 plus-shaped). So this makes +// this filter a 7x7 filter. +class EPF0Stage : public RenderPipelineStage { + public: + EPF0Stage(const LoopFilter& lf, const ImageF& sigma) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/3)), + lf_(lf), + sigma_(&sigma) {} + + template <bool aligned> + JXL_INLINE void AddPixel(int row, float* JXL_RESTRICT rows[3][7], ssize_t x, + Vec<DF> sad, Vec<DF> inv_sigma, + Vec<DF>* JXL_RESTRICT X, Vec<DF>* JXL_RESTRICT Y, + Vec<DF>* JXL_RESTRICT B, + Vec<DF>* JXL_RESTRICT w) const { + auto cx = aligned ? Load(DF(), rows[0][3 + row] + x) + : LoadU(DF(), rows[0][3 + row] + x); + auto cy = aligned ? Load(DF(), rows[1][3 + row] + x) + : LoadU(DF(), rows[1][3 + row] + x); + auto cb = aligned ? Load(DF(), rows[2][3 + row] + x) + : LoadU(DF(), rows[2][3 + row] + x); + + auto weight = Weight(sad, inv_sigma, Set(DF(), lf_.epf_pass1_zeroflush)); + *w = Add(*w, weight); + *X = MulAdd(weight, cx, *X); + *Y = MulAdd(weight, cy, *Y); + *B = MulAdd(weight, cb, *B); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + DF df; + + using V = decltype(Zero(df)); + V t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, tA, tB; + V* sads[12] = {&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7, &t8, &t9, &tA, &tB}; + + xextra = RoundUpTo(xextra, Lanes(df)); + const float* JXL_RESTRICT row_sigma = + sigma_->Row(ypos / kBlockDim + kSigmaPadding); + + float sm = lf_.epf_pass0_sigma_scale * 1.65; + float bsm = sm * lf_.epf_border_sad_mul; + + HWY_ALIGN float sad_mul_center[kBlockDim] = {bsm, sm, sm, sm, + sm, sm, sm, bsm}; + HWY_ALIGN float sad_mul_border[kBlockDim] = {bsm, bsm, bsm, bsm, + bsm, bsm, bsm, bsm}; + float* JXL_RESTRICT rows[3][7]; + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 7; i++) { + rows[c][i] = GetInputRow(input_rows, c, i - 3); + } + } + + const float* sad_mul = + (ypos % kBlockDim == 0 || ypos % kBlockDim == kBlockDim - 1) + ? sad_mul_border + : sad_mul_center; + + for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra); + x += Lanes(df)) { + size_t bx = (x + xpos + kSigmaPadding * kBlockDim) / kBlockDim; + size_t ix = (x + xpos) % kBlockDim; + + if (row_sigma[bx] < kMinSigma) { + for (size_t c = 0; c < 3; c++) { + auto px = Load(df, rows[c][3 + 0] + x); + StoreU(px, df, GetOutputRow(output_rows, c, 0) + x); + } + continue; + } + + const auto sm = Load(df, sad_mul + ix); + const auto inv_sigma = Mul(Set(df, row_sigma[bx]), sm); + + for (size_t i = 0; i < 12; i++) *sads[i] = Zero(df); + constexpr std::array<int, 2> sads_off[12] = { + {{-2, 0}}, {{-1, -1}}, {{-1, 0}}, {{-1, 1}}, {{0, -2}}, {{0, -1}}, + {{0, 1}}, {{0, 2}}, {{1, -1}}, {{1, 0}}, {{1, 1}}, {{2, 0}}, + }; + + // compute sads + // TODO(veluca): consider unrolling and optimizing this. + for (size_t c = 0; c < 3; c++) { + auto scale = Set(df, lf_.epf_channel_scale[c]); + for (size_t i = 0; i < 12; i++) { + auto sad = Zero(df); + constexpr std::array<int, 2> plus_off[] = { + {{0, 0}}, {{-1, 0}}, {{0, -1}}, {{1, 0}}, {{0, 1}}}; + for (size_t j = 0; j < 5; j++) { + const auto r11 = + LoadU(df, rows[c][3 + plus_off[j][0]] + x + plus_off[j][1]); + const auto c11 = + LoadU(df, rows[c][3 + sads_off[i][0] + plus_off[j][0]] + x + + sads_off[i][1] + plus_off[j][1]); + sad = Add(sad, AbsDiff(r11, c11)); + } + *sads[i] = MulAdd(sad, scale, *sads[i]); + } + } + const auto x_cc = Load(df, rows[0][3 + 0] + x); + const auto y_cc = Load(df, rows[1][3 + 0] + x); + const auto b_cc = Load(df, rows[2][3 + 0] + x); + + auto w = Set(df, 1); + auto X = x_cc; + auto Y = y_cc; + auto B = b_cc; + + for (size_t i = 0; i < 12; i++) { + AddPixel</*aligned=*/false>(/*row=*/sads_off[i][0], rows, + x + sads_off[i][1], *sads[i], inv_sigma, &X, + &Y, &B, &w); + } +#if JXL_HIGH_PRECISION + auto inv_w = Div(Set(df, 1.0f), w); +#else + auto inv_w = ApproximateReciprocal(w); +#endif + StoreU(Mul(X, inv_w), df, GetOutputRow(output_rows, 0, 0) + x); + StoreU(Mul(Y, inv_w), df, GetOutputRow(output_rows, 1, 0) + x); + StoreU(Mul(B, inv_w), df, GetOutputRow(output_rows, 2, 0) + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "EPF0"; } + + private: + LoopFilter lf_; + const ImageF* sigma_; +}; + +// 3x3 plus-shaped kernel with 5 SADs per pixel (also 3x3 plus-shaped). So this +// makes this filter a 5x5 filter. +class EPF1Stage : public RenderPipelineStage { + public: + EPF1Stage(const LoopFilter& lf, const ImageF& sigma) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/2)), + lf_(lf), + sigma_(&sigma) {} + + template <bool aligned> + JXL_INLINE void AddPixel(int row, float* JXL_RESTRICT rows[3][5], ssize_t x, + Vec<DF> sad, Vec<DF> inv_sigma, + Vec<DF>* JXL_RESTRICT X, Vec<DF>* JXL_RESTRICT Y, + Vec<DF>* JXL_RESTRICT B, + Vec<DF>* JXL_RESTRICT w) const { + auto cx = aligned ? Load(DF(), rows[0][2 + row] + x) + : LoadU(DF(), rows[0][2 + row] + x); + auto cy = aligned ? Load(DF(), rows[1][2 + row] + x) + : LoadU(DF(), rows[1][2 + row] + x); + auto cb = aligned ? Load(DF(), rows[2][2 + row] + x) + : LoadU(DF(), rows[2][2 + row] + x); + + auto weight = Weight(sad, inv_sigma, Set(DF(), lf_.epf_pass1_zeroflush)); + *w = Add(*w, weight); + *X = MulAdd(weight, cx, *X); + *Y = MulAdd(weight, cy, *Y); + *B = MulAdd(weight, cb, *B); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + DF df; + xextra = RoundUpTo(xextra, Lanes(df)); + const float* JXL_RESTRICT row_sigma = + sigma_->Row(ypos / kBlockDim + kSigmaPadding); + + float sm = 1.65f; + float bsm = sm * lf_.epf_border_sad_mul; + + HWY_ALIGN float sad_mul_center[kBlockDim] = {bsm, sm, sm, sm, + sm, sm, sm, bsm}; + HWY_ALIGN float sad_mul_border[kBlockDim] = {bsm, bsm, bsm, bsm, + bsm, bsm, bsm, bsm}; + + float* JXL_RESTRICT rows[3][5]; + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 5; i++) { + rows[c][i] = GetInputRow(input_rows, c, i - 2); + } + } + + const float* sad_mul = + (ypos % kBlockDim == 0 || ypos % kBlockDim == kBlockDim - 1) + ? sad_mul_border + : sad_mul_center; + + for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra); + x += Lanes(df)) { + size_t bx = (x + xpos + kSigmaPadding * kBlockDim) / kBlockDim; + size_t ix = (x + xpos) % kBlockDim; + + if (row_sigma[bx] < kMinSigma) { + for (size_t c = 0; c < 3; c++) { + auto px = Load(df, rows[c][2 + 0] + x); + Store(px, df, GetOutputRow(output_rows, c, 0) + x); + } + continue; + } + + const auto sm = Load(df, sad_mul + ix); + const auto inv_sigma = Mul(Set(df, row_sigma[bx]), sm); + auto sad0 = Zero(df); + auto sad1 = Zero(df); + auto sad2 = Zero(df); + auto sad3 = Zero(df); + + // compute sads + for (size_t c = 0; c < 3; c++) { + // center px = 22, px above = 21 + auto t = Undefined(df); + + const auto p20 = Load(df, rows[c][2 + -2] + x); + const auto p21 = Load(df, rows[c][2 + -1] + x); + auto sad0c = AbsDiff(p20, p21); // SAD 2, 1 + + const auto p11 = LoadU(df, rows[c][2 + -1] + x - 1); + auto sad1c = AbsDiff(p11, p21); // SAD 1, 2 + + const auto p31 = LoadU(df, rows[c][2 + -1] + x + 1); + auto sad2c = AbsDiff(p31, p21); // SAD 3, 2 + + const auto p02 = LoadU(df, rows[c][2 + 0] + x - 2); + const auto p12 = LoadU(df, rows[c][2 + 0] + x - 1); + sad1c = Add(sad1c, AbsDiff(p02, p12)); // SAD 1, 2 + sad0c = Add(sad0c, AbsDiff(p11, p12)); // SAD 2, 1 + + const auto p22 = LoadU(df, rows[c][2 + 0] + x); + t = AbsDiff(p12, p22); + sad1c = Add(sad1c, t); // SAD 1, 2 + sad2c = Add(sad2c, t); // SAD 3, 2 + t = AbsDiff(p22, p21); + auto sad3c = t; // SAD 2, 3 + sad0c = Add(sad0c, t); // SAD 2, 1 + + const auto p32 = LoadU(df, rows[c][2 + 0] + x + 1); + sad0c = Add(sad0c, AbsDiff(p31, p32)); // SAD 2, 1 + t = AbsDiff(p22, p32); + sad1c = Add(sad1c, t); // SAD 1, 2 + sad2c = Add(sad2c, t); // SAD 3, 2 + + const auto p42 = LoadU(df, rows[c][2 + 0] + x + 2); + sad2c = Add(sad2c, AbsDiff(p42, p32)); // SAD 3, 2 + + const auto p13 = LoadU(df, rows[c][2 + 1] + x - 1); + sad3c = Add(sad3c, AbsDiff(p13, p12)); // SAD 2, 3 + + const auto p23 = Load(df, rows[c][2 + 1] + x); + t = AbsDiff(p22, p23); + sad0c = Add(sad0c, t); // SAD 2, 1 + sad3c = Add(sad3c, t); // SAD 2, 3 + sad1c = Add(sad1c, AbsDiff(p13, p23)); // SAD 1, 2 + + const auto p33 = LoadU(df, rows[c][2 + 1] + x + 1); + sad2c = Add(sad2c, AbsDiff(p33, p23)); // SAD 3, 2 + sad3c = Add(sad3c, AbsDiff(p33, p32)); // SAD 2, 3 + + const auto p24 = Load(df, rows[c][2 + 2] + x); + sad3c = Add(sad3c, AbsDiff(p24, p23)); // SAD 2, 3 + + auto scale = Set(df, lf_.epf_channel_scale[c]); + sad0 = MulAdd(sad0c, scale, sad0); + sad1 = MulAdd(sad1c, scale, sad1); + sad2 = MulAdd(sad2c, scale, sad2); + sad3 = MulAdd(sad3c, scale, sad3); + } + const auto x_cc = Load(df, rows[0][2 + 0] + x); + const auto y_cc = Load(df, rows[1][2 + 0] + x); + const auto b_cc = Load(df, rows[2][2 + 0] + x); + + auto w = Set(df, 1); + auto X = x_cc; + auto Y = y_cc; + auto B = b_cc; + + // Top row + AddPixel</*aligned=*/true>(/*row=*/-1, rows, x, sad0, inv_sigma, &X, &Y, + &B, &w); + // Center + AddPixel</*aligned=*/false>(/*row=*/0, rows, x - 1, sad1, inv_sigma, &X, + &Y, &B, &w); + AddPixel</*aligned=*/false>(/*row=*/0, rows, x + 1, sad2, inv_sigma, &X, + &Y, &B, &w); + // Bottom + AddPixel</*aligned=*/true>(/*row=*/1, rows, x, sad3, inv_sigma, &X, &Y, + &B, &w); +#if JXL_HIGH_PRECISION + auto inv_w = Div(Set(df, 1.0f), w); +#else + auto inv_w = ApproximateReciprocal(w); +#endif + Store(Mul(X, inv_w), df, GetOutputRow(output_rows, 0, 0) + x); + Store(Mul(Y, inv_w), df, GetOutputRow(output_rows, 1, 0) + x); + Store(Mul(B, inv_w), df, GetOutputRow(output_rows, 2, 0) + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "EPF1"; } + + private: + LoopFilter lf_; + const ImageF* sigma_; +}; + +// 3x3 plus-shaped kernel with 1 SAD per pixel. So this makes this filter a 3x3 +// filter. +class EPF2Stage : public RenderPipelineStage { + public: + EPF2Stage(const LoopFilter& lf, const ImageF& sigma) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/1)), + lf_(lf), + sigma_(&sigma) {} + + template <bool aligned> + JXL_INLINE void AddPixel(int row, float* JXL_RESTRICT rows[3][3], ssize_t x, + Vec<DF> rx, Vec<DF> ry, Vec<DF> rb, + Vec<DF> inv_sigma, Vec<DF>* JXL_RESTRICT X, + Vec<DF>* JXL_RESTRICT Y, Vec<DF>* JXL_RESTRICT B, + Vec<DF>* JXL_RESTRICT w) const { + auto cx = aligned ? Load(DF(), rows[0][1 + row] + x) + : LoadU(DF(), rows[0][1 + row] + x); + auto cy = aligned ? Load(DF(), rows[1][1 + row] + x) + : LoadU(DF(), rows[1][1 + row] + x); + auto cb = aligned ? Load(DF(), rows[2][1 + row] + x) + : LoadU(DF(), rows[2][1 + row] + x); + + auto sad = Mul(AbsDiff(cx, rx), Set(DF(), lf_.epf_channel_scale[0])); + sad = MulAdd(AbsDiff(cy, ry), Set(DF(), lf_.epf_channel_scale[1]), sad); + sad = MulAdd(AbsDiff(cb, rb), Set(DF(), lf_.epf_channel_scale[2]), sad); + + auto weight = Weight(sad, inv_sigma, Set(DF(), lf_.epf_pass2_zeroflush)); + + *w = Add(*w, weight); + *X = MulAdd(weight, cx, *X); + *Y = MulAdd(weight, cy, *Y); + *B = MulAdd(weight, cb, *B); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + DF df; + xextra = RoundUpTo(xextra, Lanes(df)); + const float* JXL_RESTRICT row_sigma = + sigma_->Row(ypos / kBlockDim + kSigmaPadding); + + float sm = lf_.epf_pass2_sigma_scale * 1.65; + float bsm = sm * lf_.epf_border_sad_mul; + + HWY_ALIGN float sad_mul_center[kBlockDim] = {bsm, sm, sm, sm, + sm, sm, sm, bsm}; + HWY_ALIGN float sad_mul_border[kBlockDim] = {bsm, bsm, bsm, bsm, + bsm, bsm, bsm, bsm}; + + float* JXL_RESTRICT rows[3][3]; + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 3; i++) { + rows[c][i] = GetInputRow(input_rows, c, i - 1); + } + } + + const float* sad_mul = + (ypos % kBlockDim == 0 || ypos % kBlockDim == kBlockDim - 1) + ? sad_mul_border + : sad_mul_center; + + for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra); + x += Lanes(df)) { + size_t bx = (x + xpos + kSigmaPadding * kBlockDim) / kBlockDim; + size_t ix = (x + xpos) % kBlockDim; + + if (row_sigma[bx] < kMinSigma) { + for (size_t c = 0; c < 3; c++) { + auto px = Load(df, rows[c][1 + 0] + x); + Store(px, df, GetOutputRow(output_rows, c, 0) + x); + } + continue; + } + + const auto sm = Load(df, sad_mul + ix); + const auto inv_sigma = Mul(Set(df, row_sigma[bx]), sm); + + const auto x_cc = Load(df, rows[0][1 + 0] + x); + const auto y_cc = Load(df, rows[1][1 + 0] + x); + const auto b_cc = Load(df, rows[2][1 + 0] + x); + + auto w = Set(df, 1); + auto X = x_cc; + auto Y = y_cc; + auto B = b_cc; + + // Top row + AddPixel</*aligned=*/true>(/*row=*/-1, rows, x, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); + // Center + AddPixel</*aligned=*/false>(/*row=*/0, rows, x - 1, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); + AddPixel</*aligned=*/false>(/*row=*/0, rows, x + 1, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); + // Bottom + AddPixel</*aligned=*/true>(/*row=*/1, rows, x, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); +#if JXL_HIGH_PRECISION + auto inv_w = Div(Set(df, 1.0f), w); +#else + auto inv_w = ApproximateReciprocal(w); +#endif + Store(Mul(X, inv_w), df, GetOutputRow(output_rows, 0, 0) + x); + Store(Mul(Y, inv_w), df, GetOutputRow(output_rows, 1, 0) + x); + Store(Mul(B, inv_w), df, GetOutputRow(output_rows, 2, 0) + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "EPF2"; } + + private: + LoopFilter lf_; + const ImageF* sigma_; +}; + +std::unique_ptr<RenderPipelineStage> GetEPFStage0(const LoopFilter& lf, + const ImageF& sigma) { + return jxl::make_unique<EPF0Stage>(lf, sigma); +} + +std::unique_ptr<RenderPipelineStage> GetEPFStage1(const LoopFilter& lf, + const ImageF& sigma) { + return jxl::make_unique<EPF1Stage>(lf, sigma); +} + +std::unique_ptr<RenderPipelineStage> GetEPFStage2(const LoopFilter& lf, + const ImageF& sigma) { + return jxl::make_unique<EPF2Stage>(lf, sigma); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetEPFStage0); +HWY_EXPORT(GetEPFStage1); +HWY_EXPORT(GetEPFStage2); + +std::unique_ptr<RenderPipelineStage> GetEPFStage(const LoopFilter& lf, + const ImageF& sigma, + size_t epf_stage) { + JXL_ASSERT(lf.epf_iters != 0); + switch (epf_stage) { + case 0: + return HWY_DYNAMIC_DISPATCH(GetEPFStage0)(lf, sigma); + case 1: + return HWY_DYNAMIC_DISPATCH(GetEPFStage1)(lf, sigma); + case 2: + return HWY_DYNAMIC_DISPATCH(GetEPFStage2)(lf, sigma); + default: + JXL_UNREACHABLE("Invalid EPF stage"); + } +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.h new file mode 100644 index 0000000000..c9d0d0c785 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_EPF_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_EPF_H_ +#include <math.h> +#include <stdint.h> +#include <stdio.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "lib/jxl/image.h" +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Applies the `epf_stage`-th EPF step with the given settings and `sigma`. +// `sigma` will be accessed with an offset of (kSigmaPadding, kSigmaPadding), +// and should have (kSigmaBorder, kSigmaBorder) mirrored sigma values available +// around the main image. See also filters.(h|cc) +std::unique_ptr<RenderPipelineStage> GetEPFStage(const LoopFilter& lf, + const ImageF& sigma, + size_t epf_stage); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_EPF_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_from_linear.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_from_linear.cc new file mode 100644 index 0000000000..6b1f646cd5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_from_linear.cc @@ -0,0 +1,194 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_from_linear.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_from_linear.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/cms/tone_mapping-inl.h" +#include "lib/jxl/cms/transfer_functions-inl.h" +#include "lib/jxl/common.h" // JXL_HIGH_PRECISION +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::IfThenZeroElse; + +template <typename Op> +struct PerChannelOp { + explicit PerChannelOp(Op op) : op(op) {} + template <typename D, typename T> + void Transform(D d, T* r, T* g, T* b) const { + *r = op.Transform(d, *r); + *g = op.Transform(d, *g); + *b = op.Transform(d, *b); + } + + Op op; +}; +template <typename Op> +PerChannelOp<Op> MakePerChannelOp(Op&& op) { + return PerChannelOp<Op>(std::forward<Op>(op)); +} + +struct OpLinear { + template <typename D, typename T> + T Transform(D d, const T& linear) const { + return linear; + } +}; + +struct OpRgb { + template <typename D, typename T> + T Transform(D d, const T& linear) const { +#if JXL_HIGH_PRECISION + return TF_SRGB().EncodedFromDisplay(d, linear); +#else + return FastLinearToSRGB(d, linear); +#endif + } +}; + +struct OpPq { + explicit OpPq(const float intensity_target) : tf_pq_(intensity_target) {} + template <typename D, typename T> + T Transform(D d, const T& linear) const { + return tf_pq_.EncodedFromDisplay(d, linear); + } + TF_PQ tf_pq_; +}; + +struct OpHlg { + explicit OpHlg(const float luminances[3], const float intensity_target) + : hlg_ootf_(HlgOOTF::ToSceneLight(/*display_luminance=*/intensity_target, + luminances)) {} + + template <typename D, typename T> + void Transform(D d, T* r, T* g, T* b) const { + hlg_ootf_.Apply(r, g, b); + *r = TF_HLG().EncodedFromDisplay(d, *r); + *g = TF_HLG().EncodedFromDisplay(d, *g); + *b = TF_HLG().EncodedFromDisplay(d, *b); + } + HlgOOTF hlg_ootf_; +}; + +struct Op709 { + template <typename D, typename T> + T Transform(D d, const T& linear) const { + return TF_709().EncodedFromDisplay(d, linear); + } +}; + +struct OpGamma { + const float inverse_gamma; + template <typename D, typename T> + T Transform(D d, const T& linear) const { + return IfThenZeroElse(Le(linear, Set(d, 1e-5f)), + FastPowf(d, linear, Set(d, inverse_gamma))); + } +}; + +template <typename Op> +class FromLinearStage : public RenderPipelineStage { + public: + explicit FromLinearStage(Op op) + : RenderPipelineStage(RenderPipelineStage::Settings()), + op_(std::move(op)) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + const HWY_FULL(float) d; + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // All calculations are lane-wise, still some might require + // value-dependent behaviour (e.g. NearestInt). Temporary unpoison last + // vector tail. + msan::UnpoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + for (ssize_t x = -xextra; x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + auto r = LoadU(d, row0 + x); + auto g = LoadU(d, row1 + x); + auto b = LoadU(d, row2 + x); + op_.Transform(d, &r, &g, &b); + StoreU(r, d, row0 + x); + StoreU(g, d, row1 + x); + StoreU(b, d, row2 + x); + } + msan::PoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "FromLinear"; } + + private: + Op op_; +}; + +template <typename Op> +std::unique_ptr<FromLinearStage<Op>> MakeFromLinearStage(Op&& op) { + return jxl::make_unique<FromLinearStage<Op>>(std::forward<Op>(op)); +} + +std::unique_ptr<RenderPipelineStage> GetFromLinearStage( + const OutputEncodingInfo& output_encoding_info) { + const auto& tf = output_encoding_info.color_encoding.Tf(); + if (tf.IsLinear()) { + return MakeFromLinearStage(MakePerChannelOp(OpLinear())); + } else if (tf.IsSRGB()) { + return MakeFromLinearStage(MakePerChannelOp(OpRgb())); + } else if (tf.IsPQ()) { + return MakeFromLinearStage( + MakePerChannelOp(OpPq(output_encoding_info.orig_intensity_target))); + } else if (tf.IsHLG()) { + return MakeFromLinearStage( + OpHlg(output_encoding_info.luminances, + output_encoding_info.desired_intensity_target)); + } else if (tf.Is709()) { + return MakeFromLinearStage(MakePerChannelOp(Op709())); + } else if (tf.have_gamma || tf.IsDCI()) { + return MakeFromLinearStage( + MakePerChannelOp(OpGamma{output_encoding_info.inverse_gamma})); + } else { + // This is a programming error. + JXL_UNREACHABLE("Invalid target encoding"); + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetFromLinearStage); + +std::unique_ptr<RenderPipelineStage> GetFromLinearStage( + const OutputEncodingInfo& output_encoding_info) { + return HWY_DYNAMIC_DISPATCH(GetFromLinearStage)(output_encoding_info); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_from_linear.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_from_linear.h new file mode 100644 index 0000000000..548ab50b8c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_from_linear.h @@ -0,0 +1,20 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_FROM_LINEAR_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_FROM_LINEAR_H_ + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Converts the color channels from linear to the specified output encoding. +std::unique_ptr<RenderPipelineStage> GetFromLinearStage( + const OutputEncodingInfo& output_encoding_info); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_FROM_LINEAR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_gaborish.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_gaborish.cc new file mode 100644 index 0000000000..0917db3f9a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_gaborish.cc @@ -0,0 +1,120 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_gaborish.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_gaborish.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; + +class GaborishStage : public RenderPipelineStage { + public: + explicit GaborishStage(const LoopFilter& lf) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/1)) { + weights_[0] = 1; + weights_[1] = lf.gab_x_weight1; + weights_[2] = lf.gab_x_weight2; + weights_[3] = 1; + weights_[4] = lf.gab_y_weight1; + weights_[5] = lf.gab_y_weight2; + weights_[6] = 1; + weights_[7] = lf.gab_b_weight1; + weights_[8] = lf.gab_b_weight2; + // Normalize + for (size_t c = 0; c < 3; c++) { + const float div = + weights_[3 * c] + 4 * (weights_[3 * c + 1] + weights_[3 * c + 2]); + const float mul = 1.0f / div; + weights_[3 * c] *= mul; + weights_[3 * c + 1] *= mul; + weights_[3 * c + 2] *= mul; + } + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + const HWY_FULL(float) d; + for (size_t c = 0; c < 3; c++) { + float* JXL_RESTRICT row_t = GetInputRow(input_rows, c, -1); + float* JXL_RESTRICT row_m = GetInputRow(input_rows, c, 0); + float* JXL_RESTRICT row_b = GetInputRow(input_rows, c, 1); + float* JXL_RESTRICT row_out = GetOutputRow(output_rows, c, 0); + const auto w0 = Set(d, weights_[3 * c + 0]); + const auto w1 = Set(d, weights_[3 * c + 1]); + const auto w2 = Set(d, weights_[3 * c + 2]); +// Group data need only be aligned to a block; for >=512 bit vectors, this may +// result in unaligned loads. +#if HWY_CAP_GE512 +#define LoadMaybeU LoadU +#else +#define LoadMaybeU Load +#endif + // Since GetInputRow(input_rows, c, {-1, 0, 1}) is aligned, rounding + // xextra up to Lanes(d) doesn't access anything problematic. + for (ssize_t x = -RoundUpTo(xextra, Lanes(d)); + x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + const auto t = LoadMaybeU(d, row_t + x); + const auto tl = LoadU(d, row_t + x - 1); + const auto tr = LoadU(d, row_t + x + 1); + const auto m = LoadMaybeU(d, row_m + x); + const auto l = LoadU(d, row_m + x - 1); + const auto r = LoadU(d, row_m + x + 1); + const auto b = LoadMaybeU(d, row_b + x); + const auto bl = LoadU(d, row_b + x - 1); + const auto br = LoadU(d, row_b + x + 1); + const auto sum0 = m; + const auto sum1 = Add(Add(l, r), Add(t, b)); + const auto sum2 = Add(Add(tl, tr), Add(bl, br)); + auto pixels = MulAdd(sum2, w2, MulAdd(sum1, w1, Mul(sum0, w0))); + Store(pixels, d, row_out + x); + } + } + } +#undef LoadMaybeU + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Gab"; } + + private: + float weights_[9]; +}; + +std::unique_ptr<RenderPipelineStage> GetGaborishStage(const LoopFilter& lf) { + return jxl::make_unique<GaborishStage>(lf); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetGaborishStage); + +std::unique_ptr<RenderPipelineStage> GetGaborishStage(const LoopFilter& lf) { + JXL_ASSERT(lf.gab == 1); + return HWY_DYNAMIC_DISPATCH(GetGaborishStage)(lf); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_gaborish.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_gaborish.h new file mode 100644 index 0000000000..55166e3ed8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_gaborish.h @@ -0,0 +1,24 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_GABORISH_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_GABORISH_H_ +#include <math.h> +#include <stdint.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "lib/jxl/loop_filter.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Applies decoder-side Gaborish with the given settings. `lf.gab` must be 1. +std::unique_ptr<RenderPipelineStage> GetGaborishStage(const LoopFilter& lf); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_GABORISH_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_noise.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_noise.cc new file mode 100644 index 0000000000..5cf8a6ed51 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_noise.cc @@ -0,0 +1,316 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_noise.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_noise.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::Floor; +using hwy::HWY_NAMESPACE::Ge; +using hwy::HWY_NAMESPACE::IfThenElse; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Min; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Or; +using hwy::HWY_NAMESPACE::Sub; +using hwy::HWY_NAMESPACE::TableLookupBytes; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +using D = HWY_CAPPED(float, kBlockDim); +using DI = hwy::HWY_NAMESPACE::Rebind<int32_t, D>; +using DI8 = hwy::HWY_NAMESPACE::Repartition<uint8_t, D>; + +// [0, max_value] +template <class D, class V> +static HWY_INLINE V Clamp0ToMax(D d, const V x, const V max_value) { + const auto clamped = Min(x, max_value); + return ZeroIfNegative(clamped); +} + +// x is in [0+delta, 1+delta], delta ~= 0.06 +template <class StrengthEval> +typename StrengthEval::V NoiseStrength(const StrengthEval& eval, + const typename StrengthEval::V x) { + return Clamp0ToMax(D(), eval(x), Set(D(), 1.0f)); +} + +// TODO(veluca): SIMD-fy. +class StrengthEvalLut { + public: + using V = Vec<D>; + + explicit StrengthEvalLut(const NoiseParams& noise_params) +#if HWY_TARGET == HWY_SCALAR + : noise_params_(noise_params) +#endif + { +#if HWY_TARGET != HWY_SCALAR + uint32_t lut[8]; + memcpy(lut, noise_params.lut, sizeof(lut)); + for (size_t i = 0; i < 8; i++) { + low16_lut[2 * i] = (lut[i] >> 0) & 0xFF; + low16_lut[2 * i + 1] = (lut[i] >> 8) & 0xFF; + high16_lut[2 * i] = (lut[i] >> 16) & 0xFF; + high16_lut[2 * i + 1] = (lut[i] >> 24) & 0xFF; + } +#endif + } + + V operator()(const V vx) const { + constexpr size_t kScale = NoiseParams::kNumNoisePoints - 2; + auto scaled_vx = Max(Zero(D()), Mul(vx, Set(D(), kScale))); + auto floor_x = Floor(scaled_vx); + auto frac_x = Sub(scaled_vx, floor_x); + floor_x = IfThenElse(Ge(scaled_vx, Set(D(), kScale + 1)), Set(D(), kScale), + floor_x); + frac_x = + IfThenElse(Ge(scaled_vx, Set(D(), kScale + 1)), Set(D(), 1), frac_x); + auto floor_x_int = ConvertTo(DI(), floor_x); +#if HWY_TARGET == HWY_SCALAR + auto low = Set(D(), noise_params_.lut[floor_x_int.raw]); + auto hi = Set(D(), noise_params_.lut[floor_x_int.raw + 1]); +#else + // Set each lane's bytes to {0, 0, 2x+1, 2x}. + auto floorx_indices_low = + Add(Mul(floor_x_int, Set(DI(), 0x0202)), Set(DI(), 0x0100)); + // Set each lane's bytes to {2x+1, 2x, 0, 0}. + auto floorx_indices_hi = + Add(Mul(floor_x_int, Set(DI(), 0x02020000)), Set(DI(), 0x01000000)); + // load LUT + auto low16 = BitCast(DI(), LoadDup128(DI8(), low16_lut)); + auto lowm = Set(DI(), 0xFFFF); + auto hi16 = BitCast(DI(), LoadDup128(DI8(), high16_lut)); + auto him = Set(DI(), 0xFFFF0000); + // low = noise_params.lut[floor_x] + auto low = + BitCast(D(), Or(And(TableLookupBytes(low16, floorx_indices_low), lowm), + And(TableLookupBytes(hi16, floorx_indices_hi), him))); + // hi = noise_params.lut[floor_x+1] + floorx_indices_low = Add(floorx_indices_low, Set(DI(), 0x0202)); + floorx_indices_hi = Add(floorx_indices_hi, Set(DI(), 0x02020000)); + auto hi = + BitCast(D(), Or(And(TableLookupBytes(low16, floorx_indices_low), lowm), + And(TableLookupBytes(hi16, floorx_indices_hi), him))); +#endif + return MulAdd(Sub(hi, low), frac_x, low); + } + + private: +#if HWY_TARGET != HWY_SCALAR + // noise_params.lut transformed into two 16-bit lookup tables. + HWY_ALIGN uint8_t high16_lut[16]; + HWY_ALIGN uint8_t low16_lut[16]; +#else + const NoiseParams& noise_params_; +#endif +}; + +template <class D> +void AddNoiseToRGB(const D d, const Vec<D> rnd_noise_r, + const Vec<D> rnd_noise_g, const Vec<D> rnd_noise_cor, + const Vec<D> noise_strength_g, const Vec<D> noise_strength_r, + float ytox, float ytob, float* JXL_RESTRICT out_x, + float* JXL_RESTRICT out_y, float* JXL_RESTRICT out_b) { + const auto kRGCorr = Set(d, 0.9921875f); // 127/128 + const auto kRGNCorr = Set(d, 0.0078125f); // 1/128 + + const auto red_noise = + Mul(noise_strength_r, + MulAdd(kRGNCorr, rnd_noise_r, Mul(kRGCorr, rnd_noise_cor))); + const auto green_noise = + Mul(noise_strength_g, + MulAdd(kRGNCorr, rnd_noise_g, Mul(kRGCorr, rnd_noise_cor))); + + auto vx = LoadU(d, out_x); + auto vy = LoadU(d, out_y); + auto vb = LoadU(d, out_b); + + const auto rg_noise = Add(red_noise, green_noise); + vx = Add(MulAdd(Set(d, ytox), rg_noise, Sub(red_noise, green_noise)), vx); + vy = Add(vy, rg_noise); + vb = MulAdd(Set(d, ytob), rg_noise, vb); + + StoreU(vx, d, out_x); + StoreU(vy, d, out_y); + StoreU(vb, d, out_b); +} + +class AddNoiseStage : public RenderPipelineStage { + public: + AddNoiseStage(const NoiseParams& noise_params, + const ColorCorrelationMap& cmap, size_t first_c) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/0)), + noise_params_(noise_params), + cmap_(cmap), + first_c_(first_c) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + if (!noise_params_.HasAny()) return; + const StrengthEvalLut noise_model(noise_params_); + D d; + const auto half = Set(d, 0.5f); + + // With the prior subtract-random Laplacian approximation, rnd_* ranges were + // about [-1.5, 1.6]; Laplacian3 about doubles this to [-3.6, 3.6], so the + // normalizer is half of what it was before (0.5). + const auto norm_const = Set(d, 0.22f); + + float ytox = cmap_.YtoXRatio(0); + float ytob = cmap_.YtoBRatio(0); + + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + + float* JXL_RESTRICT row_x = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row_y = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row_b = GetInputRow(input_rows, 2, 0); + const float* JXL_RESTRICT row_rnd_r = + GetInputRow(input_rows, first_c_ + 0, 0); + const float* JXL_RESTRICT row_rnd_g = + GetInputRow(input_rows, first_c_ + 1, 0); + const float* JXL_RESTRICT row_rnd_c = + GetInputRow(input_rows, first_c_ + 2, 0); + // Needed by the calls to Floor() in StrengthEvalLut. Only arithmetic and + // shuffles are otherwise done on the data, so this is safe. + msan::UnpoisonMemory(row_x + xsize, (xsize_v - xsize) * sizeof(float)); + msan::UnpoisonMemory(row_y + xsize, (xsize_v - xsize) * sizeof(float)); + for (size_t x = 0; x < xsize_v; x += Lanes(d)) { + const auto vx = LoadU(d, row_x + x); + const auto vy = LoadU(d, row_y + x); + const auto in_g = Sub(vy, vx); + const auto in_r = Add(vy, vx); + const auto noise_strength_g = NoiseStrength(noise_model, Mul(in_g, half)); + const auto noise_strength_r = NoiseStrength(noise_model, Mul(in_r, half)); + const auto addit_rnd_noise_red = Mul(LoadU(d, row_rnd_r + x), norm_const); + const auto addit_rnd_noise_green = + Mul(LoadU(d, row_rnd_g + x), norm_const); + const auto addit_rnd_noise_correlated = + Mul(LoadU(d, row_rnd_c + x), norm_const); + AddNoiseToRGB(D(), addit_rnd_noise_red, addit_rnd_noise_green, + addit_rnd_noise_correlated, noise_strength_g, + noise_strength_r, ytox, ytob, row_x + x, row_y + x, + row_b + x); + } + msan::PoisonMemory(row_x + xsize, (xsize_v - xsize) * sizeof(float)); + msan::PoisonMemory(row_y + xsize, (xsize_v - xsize) * sizeof(float)); + msan::PoisonMemory(row_b + xsize, (xsize_v - xsize) * sizeof(float)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c >= first_c_ ? RenderPipelineChannelMode::kInput + : c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "AddNoise"; } + + private: + const NoiseParams& noise_params_; + const ColorCorrelationMap& cmap_; + size_t first_c_; +}; + +std::unique_ptr<RenderPipelineStage> GetAddNoiseStage( + const NoiseParams& noise_params, const ColorCorrelationMap& cmap, + size_t noise_c_start) { + return jxl::make_unique<AddNoiseStage>(noise_params, cmap, noise_c_start); +} + +class ConvolveNoiseStage : public RenderPipelineStage { + public: + explicit ConvolveNoiseStage(size_t first_c) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/2)), + first_c_(first_c) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + const HWY_FULL(float) d; + for (size_t c = first_c_; c < first_c_ + 3; c++) { + float* JXL_RESTRICT rows[5]; + for (size_t i = 0; i < 5; i++) { + rows[i] = GetInputRow(input_rows, c, i - 2); + } + float* JXL_RESTRICT row_out = GetOutputRow(output_rows, c, 0); + for (ssize_t x = -RoundUpTo(xextra, Lanes(d)); + x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + const auto p00 = LoadU(d, rows[2] + x); + auto others = Zero(d); + // TODO(eustas): sum loaded values to reduce the calculation chain + for (ssize_t i = -2; i <= 2; i++) { + others = Add(others, LoadU(d, rows[0] + x + i)); + others = Add(others, LoadU(d, rows[1] + x + i)); + others = Add(others, LoadU(d, rows[3] + x + i)); + others = Add(others, LoadU(d, rows[4] + x + i)); + } + others = Add(others, LoadU(d, rows[2] + x - 2)); + others = Add(others, LoadU(d, rows[2] + x - 1)); + others = Add(others, LoadU(d, rows[2] + x + 1)); + others = Add(others, LoadU(d, rows[2] + x + 2)); + // 4 * (1 - box kernel) + auto pixels = MulAdd(others, Set(d, 0.16), Mul(p00, Set(d, -3.84))); + StoreU(pixels, d, row_out + x); + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c >= first_c_ ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "ConvNoise"; } + + private: + size_t first_c_; +}; + +std::unique_ptr<RenderPipelineStage> GetConvolveNoiseStage( + size_t noise_c_start) { + return jxl::make_unique<ConvolveNoiseStage>(noise_c_start); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetAddNoiseStage); +HWY_EXPORT(GetConvolveNoiseStage); + +std::unique_ptr<RenderPipelineStage> GetAddNoiseStage( + const NoiseParams& noise_params, const ColorCorrelationMap& cmap, + size_t noise_c_start) { + return HWY_DYNAMIC_DISPATCH(GetAddNoiseStage)(noise_params, cmap, + noise_c_start); +} + +std::unique_ptr<RenderPipelineStage> GetConvolveNoiseStage( + size_t noise_c_start) { + return HWY_DYNAMIC_DISPATCH(GetConvolveNoiseStage)(noise_c_start); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_noise.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_noise.h new file mode 100644 index 0000000000..bd7797f991 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_noise.h @@ -0,0 +1,32 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_NOISE_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_NOISE_H_ +#include <math.h> +#include <stdint.h> +#include <stdio.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "lib/jxl/dec_noise.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Adds noise to color channels. +std::unique_ptr<RenderPipelineStage> GetAddNoiseStage( + const NoiseParams& noise_params, const ColorCorrelationMap& cmap, + size_t noise_c_start); + +// Applies a 5x5 subtract-box-filter convolution to the noise input channels. +std::unique_ptr<RenderPipelineStage> GetConvolveNoiseStage( + size_t noise_c_start); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_NOISE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_patches.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_patches.cc new file mode 100644 index 0000000000..c5a75b09f7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_patches.cc @@ -0,0 +1,47 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_patches.h" + +namespace jxl { +namespace { +class PatchDictionaryStage : public RenderPipelineStage { + public: + PatchDictionaryStage(const PatchDictionary* patches, size_t num_channels) + : RenderPipelineStage(RenderPipelineStage::Settings()), + patches_(*patches), + num_channels_(num_channels) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + JXL_ASSERT(xpos == 0 || xpos >= xextra); + size_t x0 = xpos ? xpos - xextra : 0; + std::vector<float*> row_ptrs(num_channels_); + for (size_t i = 0; i < num_channels_; i++) { + row_ptrs[i] = GetInputRow(input_rows, i, 0) + x0 - xpos; + } + patches_.AddOneRow(row_ptrs.data(), ypos, x0, xsize + xextra + xpos - x0); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < num_channels_ ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Patches"; } + + private: + const PatchDictionary& patches_; + const size_t num_channels_; +}; +} // namespace + +std::unique_ptr<RenderPipelineStage> GetPatchesStage( + const PatchDictionary* patches, size_t num_channels) { + return jxl::make_unique<PatchDictionaryStage>(patches, num_channels); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_patches.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_patches.h new file mode 100644 index 0000000000..b35abdc2eb --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_patches.h @@ -0,0 +1,22 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_PATCHES_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_PATCHES_H_ + +#include <utility> + +#include "lib/jxl/patch_dictionary_internal.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Draws patches if applicable. +std::unique_ptr<RenderPipelineStage> GetPatchesStage( + const PatchDictionary* patches, size_t num_channels); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_PATCHES_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_splines.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_splines.cc new file mode 100644 index 0000000000..4a0529ce2c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_splines.cc @@ -0,0 +1,62 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_splines.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_splines.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +class SplineStage : public RenderPipelineStage { + public: + explicit SplineStage(const Splines* splines) + : RenderPipelineStage(RenderPipelineStage::Settings()), + splines_(*splines) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + float* row_x = GetInputRow(input_rows, 0, 0); + float* row_y = GetInputRow(input_rows, 1, 0); + float* row_b = GetInputRow(input_rows, 2, 0); + splines_.AddToRow(row_x, row_y, row_b, Rect(xpos, ypos, xsize, 1)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Splines"; } + + private: + const Splines& splines_; +}; + +std::unique_ptr<RenderPipelineStage> GetSplineStage(const Splines* splines) { + return jxl::make_unique<SplineStage>(splines); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetSplineStage); + +std::unique_ptr<RenderPipelineStage> GetSplineStage(const Splines* splines) { + return HWY_DYNAMIC_DISPATCH(GetSplineStage)(splines); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_splines.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_splines.h new file mode 100644 index 0000000000..363af393ec --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_splines.h @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_SPLINES_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_SPLINES_H_ + +#include <utility> + +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" +#include "lib/jxl/splines.h" + +namespace jxl { + +// Draws splines if applicable. +std::unique_ptr<RenderPipelineStage> GetSplineStage(const Splines* splines); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_SPLINES_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_spot.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_spot.cc new file mode 100644 index 0000000000..a43cb4e1ab --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_spot.cc @@ -0,0 +1,51 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_spot.h" + +namespace jxl { +class SpotColorStage : public RenderPipelineStage { + public: + explicit SpotColorStage(size_t spot_c, const float* spot_color) + : RenderPipelineStage(RenderPipelineStage::Settings()), + spot_c_(spot_c), + spot_color_(spot_color) { + JXL_ASSERT(spot_c_ >= 3); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + // TODO(veluca): add SIMD. + float scale = spot_color_[3]; + for (size_t c = 0; c < 3; c++) { + float* JXL_RESTRICT p = GetInputRow(input_rows, c, 0); + const float* JXL_RESTRICT s = GetInputRow(input_rows, spot_c_, 0); + for (ssize_t x = -xextra; x < ssize_t(xsize + xextra); x++) { + float mix = scale * s[x]; + p[x] = mix * spot_color_[c] + (1.0f - mix) * p[x]; + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : c == spot_c_ ? RenderPipelineChannelMode::kInput + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Spot"; } + + private: + size_t spot_c_; + const float* spot_color_; +}; + +std::unique_ptr<RenderPipelineStage> GetSpotColorStage( + size_t spot_c, const float* spot_color) { + return jxl::make_unique<SpotColorStage>(spot_c, spot_color); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_spot.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_spot.h new file mode 100644 index 0000000000..3e79c75823 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_spot.h @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_SPOT_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_SPOT_H_ + +#include <utility> + +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Render the spot color channels. +std::unique_ptr<RenderPipelineStage> GetSpotColorStage(size_t spot_c, + const float* spot_color); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_SPOT_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_to_linear.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_to_linear.cc new file mode 100644 index 0000000000..85eca2f039 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_to_linear.cc @@ -0,0 +1,203 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_to_linear.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_to_linear.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/cms/tone_mapping-inl.h" +#include "lib/jxl/cms/transfer_functions-inl.h" +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::IfThenZeroElse; + +template <typename Op> +struct PerChannelOp { + explicit PerChannelOp(Op op) : op(op) {} + template <typename D, typename T> + void Transform(D d, T* r, T* g, T* b) const { + *r = op.Transform(d, *r); + *g = op.Transform(d, *g); + *b = op.Transform(d, *b); + } + + Op op; +}; +template <typename Op> +PerChannelOp<Op> MakePerChannelOp(Op&& op) { + return PerChannelOp<Op>(std::forward<Op>(op)); +} + +struct OpLinear { + template <typename D, typename T> + T Transform(D d, const T& encoded) const { + return encoded; + } +}; + +struct OpRgb { + template <typename D, typename T> + T Transform(D d, const T& encoded) const { + return TF_SRGB().DisplayFromEncoded(encoded); + } +}; + +struct OpPq { + explicit OpPq(const float intensity_target) : tf_pq_(intensity_target) {} + template <typename D, typename T> + T Transform(D d, const T& encoded) const { + return tf_pq_.DisplayFromEncoded(d, encoded); + } + TF_PQ tf_pq_; +}; + +struct OpHlg { + explicit OpHlg(const float luminances[3], const float intensity_target) + : hlg_ootf_(HlgOOTF::FromSceneLight( + /*display_luminance=*/intensity_target, luminances)) {} + + template <typename D, typename T> + void Transform(D d, T* r, T* g, T* b) const { + for (T* val : {r, g, b}) { + HWY_ALIGN float vals[MaxLanes(d)]; + Store(*val, d, vals); + for (size_t i = 0; i < Lanes(d); ++i) { + vals[i] = TF_HLG_Base::DisplayFromEncoded(vals[i]); + } + *val = Load(d, vals); + } + hlg_ootf_.Apply(r, g, b); + } + HlgOOTF hlg_ootf_; +}; + +struct Op709 { + template <typename D, typename T> + T Transform(D d, const T& encoded) const { + return TF_709().DisplayFromEncoded(d, encoded); + } +}; + +struct OpGamma { + const float gamma; + template <typename D, typename T> + T Transform(D d, const T& encoded) const { + return IfThenZeroElse(Le(encoded, Set(d, 1e-5f)), + FastPowf(d, encoded, Set(d, gamma))); + } +}; + +struct OpInvalid { + template <typename D, typename T> + void Transform(D d, T* r, T* g, T* b) const {} +}; + +template <typename Op> +class ToLinearStage : public RenderPipelineStage { + public: + explicit ToLinearStage(Op op) + : RenderPipelineStage(RenderPipelineStage::Settings()), + op_(std::move(op)) {} + + explicit ToLinearStage() + : RenderPipelineStage(RenderPipelineStage::Settings()), valid_(false) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + const HWY_FULL(float) d; + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // All calculations are lane-wise, still some might require + // value-dependent behaviour (e.g. NearestInt). Temporary unpoison last + // vector tail. + msan::UnpoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + for (ssize_t x = -xextra; x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + auto r = LoadU(d, row0 + x); + auto g = LoadU(d, row1 + x); + auto b = LoadU(d, row2 + x); + op_.Transform(d, &r, &g, &b); + StoreU(r, d, row0 + x); + StoreU(g, d, row1 + x); + StoreU(b, d, row2 + x); + } + msan::PoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "ToLinear"; } + + private: + Status IsInitialized() const override { return valid_; } + + Op op_; + bool valid_ = true; +}; + +template <typename Op> +std::unique_ptr<ToLinearStage<Op>> MakeToLinearStage(Op&& op) { + return jxl::make_unique<ToLinearStage<Op>>(std::forward<Op>(op)); +} + +std::unique_ptr<RenderPipelineStage> GetToLinearStage( + const OutputEncodingInfo& output_encoding_info) { + const auto& tf = output_encoding_info.color_encoding.Tf(); + if (tf.IsLinear()) { + return MakeToLinearStage(MakePerChannelOp(OpLinear())); + } else if (tf.IsSRGB()) { + return MakeToLinearStage(MakePerChannelOp(OpRgb())); + } else if (tf.IsPQ()) { + return MakeToLinearStage( + MakePerChannelOp(OpPq(output_encoding_info.orig_intensity_target))); + } else if (tf.IsHLG()) { + return MakeToLinearStage(OpHlg(output_encoding_info.luminances, + output_encoding_info.orig_intensity_target)); + } else if (tf.Is709()) { + return MakeToLinearStage(MakePerChannelOp(Op709())); + } else if (tf.have_gamma || tf.IsDCI()) { + return MakeToLinearStage( + MakePerChannelOp(OpGamma{1.f / output_encoding_info.inverse_gamma})); + } else { + return jxl::make_unique<ToLinearStage<OpInvalid>>(); + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetToLinearStage); + +std::unique_ptr<RenderPipelineStage> GetToLinearStage( + const OutputEncodingInfo& output_encoding_info) { + return HWY_DYNAMIC_DISPATCH(GetToLinearStage)(output_encoding_info); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_to_linear.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_to_linear.h new file mode 100644 index 0000000000..ccee7b09f0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_to_linear.h @@ -0,0 +1,21 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_TO_LINEAR_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_TO_LINEAR_H_ + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Converts the color channels from `output_encoding_info.color_encoding` to +// linear. +std::unique_ptr<RenderPipelineStage> GetToLinearStage( + const OutputEncodingInfo& output_encoding_info); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_TO_LINEAR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_tone_mapping.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_tone_mapping.cc new file mode 100644 index 0000000000..2a272e15dc --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_tone_mapping.cc @@ -0,0 +1,147 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_tone_mapping.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_tone_mapping.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/cms/tone_mapping-inl.h" +#include "lib/jxl/dec_xyb-inl.h" +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +class ToneMappingStage : public RenderPipelineStage { + public: + explicit ToneMappingStage(OutputEncodingInfo output_encoding_info) + : RenderPipelineStage(RenderPipelineStage::Settings()), + output_encoding_info_(std::move(output_encoding_info)) { + if (output_encoding_info_.desired_intensity_target == + output_encoding_info_.orig_intensity_target) { + // No tone mapping requested. + return; + } + const auto& orig_tf = output_encoding_info_.orig_color_encoding.Tf(); + const auto& dest_tf = output_encoding_info_.color_encoding.Tf(); + if (orig_tf.IsPQ() && output_encoding_info_.desired_intensity_target < + output_encoding_info_.orig_intensity_target) { + tone_mapper_ = jxl::make_unique<ToneMapper>( + /*source_range=*/std::pair<float, float>( + 0, output_encoding_info_.orig_intensity_target), + /*target_range=*/ + std::pair<float, float>( + 0, output_encoding_info_.desired_intensity_target), + output_encoding_info_.luminances); + } else if (orig_tf.IsHLG() && !dest_tf.IsHLG()) { + hlg_ootf_ = jxl::make_unique<HlgOOTF>( + /*source_luminance=*/output_encoding_info_.orig_intensity_target, + /*target_luminance=*/output_encoding_info_.desired_intensity_target, + output_encoding_info_.luminances); + } + + if (dest_tf.IsPQ() && (tone_mapper_ || hlg_ootf_)) { + to_intensity_target_ = + 10000.f / output_encoding_info_.orig_intensity_target; + from_desired_intensity_target_ = + output_encoding_info_.desired_intensity_target / 10000.f; + } + } + + bool IsNeeded() const { return tone_mapper_ || hlg_ootf_; } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + if (!(tone_mapper_ || hlg_ootf_)) return; + + const HWY_FULL(float) d; + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // All calculations are lane-wise, still some might require + // value-dependent behaviour (e.g. NearestInt). Temporary unpoison last + // vector tail. + msan::UnpoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + for (ssize_t x = -xextra; x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + auto r = LoadU(d, row0 + x); + auto g = LoadU(d, row1 + x); + auto b = LoadU(d, row2 + x); + if (tone_mapper_ || hlg_ootf_) { + r = Mul(r, Set(d, to_intensity_target_)); + g = Mul(g, Set(d, to_intensity_target_)); + b = Mul(b, Set(d, to_intensity_target_)); + if (tone_mapper_) { + tone_mapper_->ToneMap(&r, &g, &b); + } else { + JXL_ASSERT(hlg_ootf_); + hlg_ootf_->Apply(&r, &g, &b); + } + if (tone_mapper_ || hlg_ootf_->WarrantsGamutMapping()) { + GamutMap(&r, &g, &b, output_encoding_info_.luminances); + } + r = Mul(r, Set(d, from_desired_intensity_target_)); + g = Mul(g, Set(d, from_desired_intensity_target_)); + b = Mul(b, Set(d, from_desired_intensity_target_)); + } + StoreU(r, d, row0 + x); + StoreU(g, d, row1 + x); + StoreU(b, d, row2 + x); + } + msan::PoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "ToneMapping"; } + + private: + using ToneMapper = Rec2408ToneMapper<HWY_FULL(float)>; + OutputEncodingInfo output_encoding_info_; + std::unique_ptr<ToneMapper> tone_mapper_; + std::unique_ptr<HlgOOTF> hlg_ootf_; + // When the target colorspace is PQ, 1 represents 10000 nits instead of + // orig_intensity_target. This temporarily changes this if the tone mappers + // require it. + float to_intensity_target_ = 1.f; + float from_desired_intensity_target_ = 1.f; +}; + +std::unique_ptr<RenderPipelineStage> GetToneMappingStage( + const OutputEncodingInfo& output_encoding_info) { + auto stage = jxl::make_unique<ToneMappingStage>(output_encoding_info); + if (!stage->IsNeeded()) return nullptr; + return stage; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetToneMappingStage); + +std::unique_ptr<RenderPipelineStage> GetToneMappingStage( + const OutputEncodingInfo& output_encoding_info) { + return HWY_DYNAMIC_DISPATCH(GetToneMappingStage)(output_encoding_info); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_tone_mapping.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_tone_mapping.h new file mode 100644 index 0000000000..57eb9a9abf --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_tone_mapping.h @@ -0,0 +1,36 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_TONE_MAPPING_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_TONE_MAPPING_H_ +#include <math.h> +#include <stdint.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Tone maps the image if appropriate. It must be in linear space and +// `output_encoding_info.luminances` must contain the luminance for the +// primaries of that space. It must also be encoded such that (1, 1, 1) +// represents `output_encoding_info.orig_intensity_target` nits, unless +// `output_encoding_info.color_encoding.tf.IsPQ()`, in which case (1, 1, 1) must +// represent 10000 nits. This corresponds to what XYBStage outputs. After this +// stage, (1, 1, 1) will represent +// `output_encoding_info.desired_intensity_target` nits, except in the PQ +// special case in which it remains 10000. +// +// If no tone mapping is necessary, this will return nullptr. +std::unique_ptr<RenderPipelineStage> GetToneMappingStage( + const OutputEncodingInfo& output_encoding_info); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_TONE_MAPPING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.cc new file mode 100644 index 0000000000..ade37d59a6 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.cc @@ -0,0 +1,192 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_upsampling.h" + +#include "lib/jxl/base/status.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_upsampling.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/simd_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Clamp; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Min; +using hwy::HWY_NAMESPACE::MulAdd; + +class UpsamplingStage : public RenderPipelineStage { + public: + explicit UpsamplingStage(const CustomTransformData& ups_factors, size_t c, + size_t shift) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/shift, /*border=*/2)), + c_(c) { + const float* weights = shift == 1 ? ups_factors.upsampling2_weights + : shift == 2 ? ups_factors.upsampling4_weights + : ups_factors.upsampling8_weights; + size_t N = 1 << (shift - 1); + for (size_t i = 0; i < 5 * N; i++) { + for (size_t j = 0; j < 5 * N; j++) { + size_t y = std::min(i, j); + size_t x = std::max(i, j); + kernel_[j / 5][i / 5][j % 5][i % 5] = + weights[5 * N * y - y * (y - 1) / 2 + x - y]; + } + } + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + static HWY_FULL(float) df; + size_t shift = settings_.shift_x; + size_t N = 1 << shift; + const size_t xsize_v = RoundUpTo(xsize, Lanes(df)); + for (ssize_t iy = -2; iy <= 2; iy++) { + msan::UnpoisonMemory(GetInputRow(input_rows, c_, iy) + xsize + 2, + sizeof(float) * (xsize_v - xsize)); + } + JXL_ASSERT(xextra == 0); + ssize_t x0 = 0; + ssize_t x1 = xsize; + if (N == 2) { + ProcessRowImpl<2>(input_rows, output_rows, x0, x1); + } + if (N == 4) { + ProcessRowImpl<4>(input_rows, output_rows, x0, x1); + } + if (N == 8) { + ProcessRowImpl<8>(input_rows, output_rows, x0, x1); + } + for (size_t oy = 0; oy < N; oy++) { + float* dst_row = GetOutputRow(output_rows, c_, oy); + msan::PoisonMemory(dst_row + xsize * N, + sizeof(float) * (xsize_v - xsize) * N); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c == c_ ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Upsample"; } + + private: + template <size_t N> + JXL_INLINE float Kernel(size_t x, size_t y, ssize_t ix, ssize_t iy) const { + ix += 2; + iy += 2; + if (N == 2) { + return kernel_[0][0][y % 2 ? 4 - iy : iy][x % 2 ? 4 - ix : ix]; + } + if (N == 4) { + return kernel_[y % 4 < 2 ? y % 2 : 1 - y % 2] + [x % 4 < 2 ? x % 2 : 1 - x % 2][y % 4 < 2 ? iy : 4 - iy] + [x % 4 < 2 ? ix : 4 - ix]; + } + if (N == 8) { + return kernel_[y % 8 < 4 ? y % 4 : 3 - y % 4] + [x % 8 < 4 ? x % 4 : 3 - x % 4][y % 8 < 4 ? iy : 4 - iy] + [x % 8 < 4 ? ix : 4 - ix]; + } + JXL_UNREACHABLE("Invalid upsample"); + } + + template <ssize_t N> + void ProcessRowImpl(const RowInfo& input_rows, const RowInfo& output_rows, + ssize_t x0, ssize_t x1) const { + static HWY_FULL(float) df; + using V = hwy::HWY_NAMESPACE::Vec<HWY_FULL(float)>; + V ups0, ups1, ups2, ups3, ups4, ups5, ups6, ups7; + (void)ups2, (void)ups3, (void)ups4, (void)ups5, (void)ups6, (void)ups7; + // Once we have C++17 available, change this back to `V* ups[N]` and + // initialize using `if constexpr` below. + V* ups[8] = {}; + static_assert(N == 2 || N == 4 || N == 8, "N must be 2, 4, or 8"); + if (N >= 2) { + ups[0] = &ups0; + ups[1] = &ups1; + } + if (N >= 4) { + ups[2] = &ups2; + ups[3] = &ups3; + } + if (N == 8) { + ups[4] = &ups4; + ups[5] = &ups5; + ups[6] = &ups6; + ups[7] = &ups7; + } + + for (size_t oy = 0; oy < N; oy++) { + float* dst_row = GetOutputRow(output_rows, c_, oy); + for (ssize_t x = x0; x < x1; x += Lanes(df)) { + for (size_t ox = 0; ox < N; ox++) { + auto result = Zero(df); + auto min = LoadU(df, GetInputRow(input_rows, c_, 0) + x); + auto max = min; + for (ssize_t iy = -2; iy <= 2; iy++) { + for (ssize_t ix = -2; ix <= 2; ix++) { + auto v = LoadU(df, GetInputRow(input_rows, c_, iy) + x + ix); + result = MulAdd(Set(df, Kernel<N>(ox, oy, ix, iy)), v, result); + min = Min(v, min); + max = Max(v, max); + } + } + // Avoid overshooting. + *ups[ox] = Clamp(result, min, max); + } + if (N == 2) { + StoreInterleaved(df, ups0, ups1, dst_row + x * N); + } + if (N == 4) { + StoreInterleaved(df, ups0, ups1, ups2, ups3, dst_row + x * N); + } + if (N == 8) { + StoreInterleaved(df, ups0, ups1, ups2, ups3, ups4, ups5, ups6, ups7, + dst_row + x * N); + } + } + } + } + + size_t c_; + float kernel_[4][4][5][5]; +}; + +std::unique_ptr<RenderPipelineStage> GetUpsamplingStage( + const CustomTransformData& ups_factors, size_t c, size_t shift) { + return jxl::make_unique<UpsamplingStage>(ups_factors, c, shift); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetUpsamplingStage); + +std::unique_ptr<RenderPipelineStage> GetUpsamplingStage( + const CustomTransformData& ups_factors, size_t c, size_t shift) { + JXL_ASSERT(shift != 0); + JXL_ASSERT(shift <= 3); + return HWY_DYNAMIC_DISPATCH(GetUpsamplingStage)(ups_factors, c, shift); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.h new file mode 100644 index 0000000000..7d5defd23c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.h @@ -0,0 +1,26 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_UPSAMPLING_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_UPSAMPLING_H_ +#include <math.h> +#include <stdint.h> +#include <stdio.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "lib/jxl/image_metadata.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Upsamples the given channel by the given factor. +std::unique_ptr<RenderPipelineStage> GetUpsamplingStage( + const CustomTransformData& ups_factors, size_t c, size_t shift); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_UPSAMPLING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_write.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_write.cc new file mode 100644 index 0000000000..847972acc8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_write.cc @@ -0,0 +1,671 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_write.h" + +#include <type_traits> + +#include "lib/jxl/alpha.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/sanitizers.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_write.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Clamp; +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::NearestInt; +using hwy::HWY_NAMESPACE::Or; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::ShiftLeftSame; +using hwy::HWY_NAMESPACE::ShiftRightSame; +using hwy::HWY_NAMESPACE::VFromD; + +// 8x8 ordered dithering pattern from +// https://en.wikipedia.org/wiki/Ordered_dithering +// scaled to have an average of 0 and be fully contained in (-0.5, 0.5). +// Matrix is duplicated in width to avoid inconsistencies or out-of-bound-reads +// if doing unaligned operations. +const float kDither[(2 * 8) * 8] = { + -0.4921875, 0.0078125, -0.3671875, 0.1328125, // + -0.4609375, 0.0390625, -0.3359375, 0.1640625, // + -0.4921875, 0.0078125, -0.3671875, 0.1328125, // + -0.4609375, 0.0390625, -0.3359375, 0.1640625, // + // + 0.2578125, -0.2421875, 0.3828125, -0.1171875, // + 0.2890625, -0.2109375, 0.4140625, -0.0859375, // + 0.2578125, -0.2421875, 0.3828125, -0.1171875, // + 0.2890625, -0.2109375, 0.4140625, -0.0859375, // + // + -0.3046875, 0.1953125, -0.4296875, 0.0703125, // + -0.2734375, 0.2265625, -0.3984375, 0.1015625, // + -0.3046875, 0.1953125, -0.4296875, 0.0703125, // + -0.2734375, 0.2265625, -0.3984375, 0.1015625, // + // + 0.4453125, -0.0546875, 0.3203125, -0.1796875, // + 0.4765625, -0.0234375, 0.3515625, -0.1484375, // + 0.4453125, -0.0546875, 0.3203125, -0.1796875, // + 0.4765625, -0.0234375, 0.3515625, -0.1484375, // + // + -0.4453125, 0.0546875, -0.3203125, 0.1796875, // + -0.4765625, 0.0234375, -0.3515625, 0.1484375, // + -0.4453125, 0.0546875, -0.3203125, 0.1796875, // + -0.4765625, 0.0234375, -0.3515625, 0.1484375, // + // + 0.3046875, -0.1953125, 0.4296875, -0.0703125, // + 0.2734375, -0.2265625, 0.3984375, -0.1015625, // + 0.3046875, -0.1953125, 0.4296875, -0.0703125, // + 0.2734375, -0.2265625, 0.3984375, -0.1015625, // + // + -0.2578125, 0.2421875, -0.3828125, 0.1171875, // + -0.2890625, 0.2109375, -0.4140625, 0.0859375, // + -0.2578125, 0.2421875, -0.3828125, 0.1171875, // + -0.2890625, 0.2109375, -0.4140625, 0.0859375, // + // + 0.4921875, -0.0078125, 0.3671875, -0.1328125, // + 0.4609375, -0.0390625, 0.3359375, -0.1640625, // + 0.4921875, -0.0078125, 0.3671875, -0.1328125, // + 0.4609375, -0.0390625, 0.3359375, -0.1640625, // +}; + +using DF = HWY_FULL(float); + +// Converts `v` to an appropriate value for the given unsigned type. +// If the unsigned type is an 8-bit type, performs ordered dithering. +template <typename T> +VFromD<Rebind<T, DF>> MakeUnsigned(VFromD<DF> v, size_t x0, size_t y0, + VFromD<DF> mul) { + static_assert(std::is_unsigned<T>::value, "T must be an unsigned type"); + using DU = Rebind<T, DF>; + v = Mul(v, mul); + // TODO(veluca): if constexpr with C++17 + if (sizeof(T) == 1) { + size_t pos = (y0 % 8) * (2 * 8) + (x0 % 8); +#if HWY_TARGET != HWY_SCALAR + auto dither = LoadDup128(DF(), kDither + pos); +#else + auto dither = LoadU(DF(), kDither + pos); +#endif + v = Add(v, dither); + } + v = Clamp(Zero(DF()), v, mul); + return DemoteTo(DU(), NearestInt(v)); +} + +class WriteToOutputStage : public RenderPipelineStage { + public: + WriteToOutputStage(const ImageOutput& main_output, size_t width, + size_t height, bool has_alpha, bool unpremul_alpha, + size_t alpha_c, Orientation undo_orientation, + const std::vector<ImageOutput>& extra_output) + : RenderPipelineStage(RenderPipelineStage::Settings()), + width_(width), + height_(height), + main_(main_output), + num_color_(main_.num_channels_ < 3 ? 1 : 3), + want_alpha_(main_.num_channels_ == 2 || main_.num_channels_ == 4), + has_alpha_(has_alpha), + unpremul_alpha_(unpremul_alpha), + alpha_c_(alpha_c), + flip_x_(ShouldFlipX(undo_orientation)), + flip_y_(ShouldFlipY(undo_orientation)), + transpose_(ShouldTranspose(undo_orientation)), + opaque_alpha_(kMaxPixelsPerCall, 1.0f) { + for (size_t ec = 0; ec < extra_output.size(); ++ec) { + if (extra_output[ec].callback.IsPresent() || extra_output[ec].buffer) { + Output extra(extra_output[ec]); + extra.channel_index_ = 3 + ec; + extra_channels_.push_back(extra); + } + } + } + + WriteToOutputStage(const WriteToOutputStage&) = delete; + WriteToOutputStage& operator=(const WriteToOutputStage&) = delete; + WriteToOutputStage(WriteToOutputStage&&) = delete; + WriteToOutputStage& operator=(WriteToOutputStage&&) = delete; + + ~WriteToOutputStage() override { + if (main_.run_opaque_) { + main_.pixel_callback_.destroy(main_.run_opaque_); + } + for (auto& extra : extra_channels_) { + if (extra.run_opaque_) { + extra.pixel_callback_.destroy(extra.run_opaque_); + } + } + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + JXL_DASSERT(xextra == 0); + JXL_DASSERT(main_.run_opaque_ || main_.buffer_); + if (ypos >= height_) return; + if (xpos >= width_) return; + if (flip_y_) { + ypos = height_ - 1u - ypos; + } + size_t limit = std::min(xsize, width_ - xpos); + for (size_t x0 = 0; x0 < limit; x0 += kMaxPixelsPerCall) { + size_t xstart = xpos + x0; + size_t len = std::min<size_t>(kMaxPixelsPerCall, limit - x0); + + const float* line_buffers[4]; + for (size_t c = 0; c < num_color_; c++) { + line_buffers[c] = GetInputRow(input_rows, c, 0) + x0; + } + if (has_alpha_) { + line_buffers[num_color_] = GetInputRow(input_rows, alpha_c_, 0) + x0; + } else { + // opaque_alpha_ is a way to set all values to 1.0f. + line_buffers[num_color_] = opaque_alpha_.data(); + } + if (has_alpha_ && want_alpha_ && unpremul_alpha_) { + UnpremulAlpha(thread_id, len, line_buffers); + } + OutputBuffers(main_, thread_id, ypos, xstart, len, line_buffers); + for (const auto& extra : extra_channels_) { + line_buffers[0] = GetInputRow(input_rows, extra.channel_index_, 0) + x0; + OutputBuffers(extra, thread_id, ypos, xstart, len, line_buffers); + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + if (c < num_color_ || (has_alpha_ && c == alpha_c_)) { + return RenderPipelineChannelMode::kInput; + } + for (const auto& extra : extra_channels_) { + if (c == extra.channel_index_) { + return RenderPipelineChannelMode::kInput; + } + } + return RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "WritePixelCB"; } + + private: + struct Output { + Output(const ImageOutput& image_out) + : pixel_callback_(image_out.callback), + buffer_(image_out.buffer), + buffer_size_(image_out.buffer_size), + stride_(image_out.stride), + num_channels_(image_out.format.num_channels), + swap_endianness_(SwapEndianness(image_out.format.endianness)), + data_type_(image_out.format.data_type), + bits_per_sample_(image_out.bits_per_sample) {} + + Status PrepareForThreads(size_t num_threads) { + if (pixel_callback_.IsPresent()) { + run_opaque_ = + pixel_callback_.Init(num_threads, /*num_pixels=*/kMaxPixelsPerCall); + JXL_RETURN_IF_ERROR(run_opaque_ != nullptr); + } else { + JXL_RETURN_IF_ERROR(buffer_ != nullptr); + } + return true; + } + + PixelCallback pixel_callback_; + void* run_opaque_ = nullptr; + void* buffer_ = nullptr; + size_t buffer_size_; + size_t stride_; + size_t num_channels_; + bool swap_endianness_; + JxlDataType data_type_; + size_t bits_per_sample_; + size_t channel_index_; // used for extra_channels + }; + + Status PrepareForThreads(size_t num_threads) override { + JXL_RETURN_IF_ERROR(main_.PrepareForThreads(num_threads)); + for (auto& extra : extra_channels_) { + JXL_RETURN_IF_ERROR(extra.PrepareForThreads(num_threads)); + } + temp_out_.resize(num_threads); + for (CacheAlignedUniquePtr& temp : temp_out_) { + temp = AllocateArray(sizeof(float) * kMaxPixelsPerCall * + main_.num_channels_); + } + if ((has_alpha_ && want_alpha_ && unpremul_alpha_) || flip_x_) { + temp_in_.resize(num_threads * main_.num_channels_); + for (CacheAlignedUniquePtr& temp : temp_in_) { + temp = AllocateArray(sizeof(float) * kMaxPixelsPerCall); + } + } + return true; + } + static bool ShouldFlipX(Orientation undo_orientation) { + return (undo_orientation == Orientation::kFlipHorizontal || + undo_orientation == Orientation::kRotate180 || + undo_orientation == Orientation::kRotate270 || + undo_orientation == Orientation::kAntiTranspose); + } + static bool ShouldFlipY(Orientation undo_orientation) { + return (undo_orientation == Orientation::kFlipVertical || + undo_orientation == Orientation::kRotate180 || + undo_orientation == Orientation::kRotate90 || + undo_orientation == Orientation::kAntiTranspose); + } + static bool ShouldTranspose(Orientation undo_orientation) { + return (undo_orientation == Orientation::kTranspose || + undo_orientation == Orientation::kRotate90 || + undo_orientation == Orientation::kRotate270 || + undo_orientation == Orientation::kAntiTranspose); + } + + void UnpremulAlpha(size_t thread_id, size_t len, + const float** line_buffers) const { + const HWY_FULL(float) d; + auto one = Set(d, 1.0f); + float* temp_in[4]; + for (size_t c = 0; c < main_.num_channels_; ++c) { + size_t tix = thread_id * main_.num_channels_ + c; + temp_in[c] = reinterpret_cast<float*>(temp_in_[tix].get()); + memcpy(temp_in[c], line_buffers[c], sizeof(float) * len); + } + auto small_alpha = Set(d, kSmallAlpha); + for (size_t ix = 0; ix < len; ix += Lanes(d)) { + auto alpha = LoadU(d, temp_in[num_color_] + ix); + auto mul = Div(one, Max(small_alpha, alpha)); + for (size_t c = 0; c < num_color_; ++c) { + auto val = LoadU(d, temp_in[c] + ix); + StoreU(Mul(val, mul), d, temp_in[c] + ix); + } + } + for (size_t c = 0; c < main_.num_channels_; ++c) { + line_buffers[c] = temp_in[c]; + } + } + + void OutputBuffers(const Output& out, size_t thread_id, size_t ypos, + size_t xstart, size_t len, const float* input[4]) const { + if (flip_x_) { + FlipX(out, thread_id, len, &xstart, input); + } + if (out.data_type_ == JXL_TYPE_UINT8) { + uint8_t* JXL_RESTRICT temp = + reinterpret_cast<uint8_t*>(temp_out_[thread_id].get()); + StoreUnsignedRow(out, input, len, temp, xstart, ypos); + WriteToOutput(out, thread_id, ypos, xstart, len, temp); + } else if (out.data_type_ == JXL_TYPE_UINT16 || + out.data_type_ == JXL_TYPE_FLOAT16) { + uint16_t* JXL_RESTRICT temp = + reinterpret_cast<uint16_t*>(temp_out_[thread_id].get()); + if (out.data_type_ == JXL_TYPE_UINT16) { + StoreUnsignedRow(out, input, len, temp, xstart, ypos); + } else { + StoreFloat16Row(out, input, len, temp); + } + if (out.swap_endianness_) { + const HWY_FULL(uint16_t) du; + size_t output_len = len * out.num_channels_; + for (size_t j = 0; j < output_len; j += Lanes(du)) { + auto v = LoadU(du, temp + j); + auto vswap = Or(ShiftRightSame(v, 8), ShiftLeftSame(v, 8)); + StoreU(vswap, du, temp + j); + } + } + WriteToOutput(out, thread_id, ypos, xstart, len, temp); + } else if (out.data_type_ == JXL_TYPE_FLOAT) { + float* JXL_RESTRICT temp = + reinterpret_cast<float*>(temp_out_[thread_id].get()); + StoreFloatRow(out, input, len, temp); + if (out.swap_endianness_) { + size_t output_len = len * out.num_channels_; + for (size_t j = 0; j < output_len; ++j) { + temp[j] = BSwapFloat(temp[j]); + } + } + WriteToOutput(out, thread_id, ypos, xstart, len, temp); + } + } + + void FlipX(const Output& out, size_t thread_id, size_t len, size_t* xstart, + const float** line_buffers) const { + float* temp_in[4]; + for (size_t c = 0; c < out.num_channels_; ++c) { + size_t tix = thread_id * main_.num_channels_ + c; + temp_in[c] = reinterpret_cast<float*>(temp_in_[tix].get()); + if (temp_in[c] != line_buffers[c]) { + memcpy(temp_in[c], line_buffers[c], sizeof(float) * len); + } + } + size_t last = (len - 1u); + size_t num = (len / 2); + for (size_t i = 0; i < num; ++i) { + for (size_t c = 0; c < out.num_channels_; ++c) { + std::swap(temp_in[c][i], temp_in[c][last - i]); + } + } + for (size_t c = 0; c < out.num_channels_; ++c) { + line_buffers[c] = temp_in[c]; + } + *xstart = width_ - *xstart - len; + } + + template <typename T> + void StoreUnsignedRow(const Output& out, const float* input[4], size_t len, + T* output, size_t xstart, size_t ypos) const { + const HWY_FULL(float) d; + auto mul = Set(d, (1u << (out.bits_per_sample_)) - 1); + const Rebind<T, decltype(d)> du; + const size_t padding = RoundUpTo(len, Lanes(d)) - len; + for (size_t c = 0; c < out.num_channels_; ++c) { + msan::UnpoisonMemory(input[c] + len, sizeof(input[c][0]) * padding); + } + if (out.num_channels_ == 1) { + for (size_t i = 0; i < len; i += Lanes(d)) { + StoreU(MakeUnsigned<T>(LoadU(d, &input[0][i]), xstart + i, ypos, mul), + du, &output[i]); + } + } else if (out.num_channels_ == 2) { + for (size_t i = 0; i < len; i += Lanes(d)) { + StoreInterleaved2( + MakeUnsigned<T>(LoadU(d, &input[0][i]), xstart + i, ypos, mul), + MakeUnsigned<T>(LoadU(d, &input[1][i]), xstart + i, ypos, mul), du, + &output[2 * i]); + } + } else if (out.num_channels_ == 3) { + for (size_t i = 0; i < len; i += Lanes(d)) { + StoreInterleaved3( + MakeUnsigned<T>(LoadU(d, &input[0][i]), xstart + i, ypos, mul), + MakeUnsigned<T>(LoadU(d, &input[1][i]), xstart + i, ypos, mul), + MakeUnsigned<T>(LoadU(d, &input[2][i]), xstart + i, ypos, mul), du, + &output[3 * i]); + } + } else if (out.num_channels_ == 4) { + for (size_t i = 0; i < len; i += Lanes(d)) { + StoreInterleaved4( + MakeUnsigned<T>(LoadU(d, &input[0][i]), xstart + i, ypos, mul), + MakeUnsigned<T>(LoadU(d, &input[1][i]), xstart + i, ypos, mul), + MakeUnsigned<T>(LoadU(d, &input[2][i]), xstart + i, ypos, mul), + MakeUnsigned<T>(LoadU(d, &input[3][i]), xstart + i, ypos, mul), du, + &output[4 * i]); + } + } + msan::PoisonMemory(output + out.num_channels_ * len, + sizeof(output[0]) * out.num_channels_ * padding); + } + + void StoreFloat16Row(const Output& out, const float* input[4], size_t len, + uint16_t* output) const { + const HWY_FULL(float) d; + const Rebind<uint16_t, decltype(d)> du; + const Rebind<hwy::float16_t, decltype(d)> df16; + const size_t padding = RoundUpTo(len, Lanes(d)) - len; + for (size_t c = 0; c < out.num_channels_; ++c) { + msan::UnpoisonMemory(input[c] + len, sizeof(input[c][0]) * padding); + } + if (out.num_channels_ == 1) { + for (size_t i = 0; i < len; i += Lanes(d)) { + auto v0 = LoadU(d, &input[0][i]); + StoreU(BitCast(du, DemoteTo(df16, v0)), du, &output[i]); + } + } else if (out.num_channels_ == 2) { + for (size_t i = 0; i < len; i += Lanes(d)) { + auto v0 = LoadU(d, &input[0][i]); + auto v1 = LoadU(d, &input[1][i]); + StoreInterleaved2(BitCast(du, DemoteTo(df16, v0)), + BitCast(du, DemoteTo(df16, v1)), du, &output[2 * i]); + } + } else if (out.num_channels_ == 3) { + for (size_t i = 0; i < len; i += Lanes(d)) { + auto v0 = LoadU(d, &input[0][i]); + auto v1 = LoadU(d, &input[1][i]); + auto v2 = LoadU(d, &input[2][i]); + StoreInterleaved3(BitCast(du, DemoteTo(df16, v0)), + BitCast(du, DemoteTo(df16, v1)), + BitCast(du, DemoteTo(df16, v2)), du, &output[3 * i]); + } + } else if (out.num_channels_ == 4) { + for (size_t i = 0; i < len; i += Lanes(d)) { + auto v0 = LoadU(d, &input[0][i]); + auto v1 = LoadU(d, &input[1][i]); + auto v2 = LoadU(d, &input[2][i]); + auto v3 = LoadU(d, &input[3][i]); + StoreInterleaved4(BitCast(du, DemoteTo(df16, v0)), + BitCast(du, DemoteTo(df16, v1)), + BitCast(du, DemoteTo(df16, v2)), + BitCast(du, DemoteTo(df16, v3)), du, &output[4 * i]); + } + } + msan::PoisonMemory(output + out.num_channels_ * len, + sizeof(output[0]) * out.num_channels_ * padding); + } + + void StoreFloatRow(const Output& out, const float* input[4], size_t len, + float* output) const { + const HWY_FULL(float) d; + if (out.num_channels_ == 1) { + memcpy(output, input[0], len * sizeof(output[0])); + } else if (out.num_channels_ == 2) { + for (size_t i = 0; i < len; i += Lanes(d)) { + StoreInterleaved2(LoadU(d, &input[0][i]), LoadU(d, &input[1][i]), d, + &output[2 * i]); + } + } else if (out.num_channels_ == 3) { + for (size_t i = 0; i < len; i += Lanes(d)) { + StoreInterleaved3(LoadU(d, &input[0][i]), LoadU(d, &input[1][i]), + LoadU(d, &input[2][i]), d, &output[3 * i]); + } + } else { + for (size_t i = 0; i < len; i += Lanes(d)) { + StoreInterleaved4(LoadU(d, &input[0][i]), LoadU(d, &input[1][i]), + LoadU(d, &input[2][i]), LoadU(d, &input[3][i]), d, + &output[4 * i]); + } + } + } + + template <typename T> + void WriteToOutput(const Output& out, size_t thread_id, size_t ypos, + size_t xstart, size_t len, T* output) const { + if (transpose_) { + // TODO(szabadka) Buffer 8x8 chunks and transpose with SIMD. + if (out.run_opaque_) { + for (size_t i = 0, j = 0; i < len; ++i, j += out.num_channels_) { + out.pixel_callback_.run(out.run_opaque_, thread_id, ypos, xstart + i, + 1, output + j); + } + } else { + const size_t pixel_stride = out.num_channels_ * sizeof(T); + const size_t offset = xstart * out.stride_ + ypos * pixel_stride; + for (size_t i = 0, j = 0; i < len; ++i, j += out.num_channels_) { + const size_t ix = offset + i * out.stride_; + JXL_DASSERT(ix + pixel_stride <= out.buffer_size_); + memcpy(reinterpret_cast<uint8_t*>(out.buffer_) + ix, output + j, + pixel_stride); + } + } + } else { + if (out.run_opaque_) { + out.pixel_callback_.run(out.run_opaque_, thread_id, xstart, ypos, len, + output); + } else { + const size_t pixel_stride = out.num_channels_ * sizeof(T); + const size_t offset = ypos * out.stride_ + xstart * pixel_stride; + JXL_DASSERT(offset + len * pixel_stride <= out.buffer_size_); + memcpy(reinterpret_cast<uint8_t*>(out.buffer_) + offset, output, + len * pixel_stride); + } + } + } + + static constexpr size_t kMaxPixelsPerCall = 1024; + size_t width_; + size_t height_; + Output main_; // color + alpha + size_t num_color_; + bool want_alpha_; + bool has_alpha_; + bool unpremul_alpha_; + size_t alpha_c_; + bool flip_x_; + bool flip_y_; + bool transpose_; + std::vector<Output> extra_channels_; + std::vector<float> opaque_alpha_; + std::vector<CacheAlignedUniquePtr> temp_in_; + std::vector<CacheAlignedUniquePtr> temp_out_; +}; + +constexpr size_t WriteToOutputStage::kMaxPixelsPerCall; + +std::unique_ptr<RenderPipelineStage> GetWriteToOutputStage( + const ImageOutput& main_output, size_t width, size_t height, bool has_alpha, + bool unpremul_alpha, size_t alpha_c, Orientation undo_orientation, + std::vector<ImageOutput>& extra_output) { + return jxl::make_unique<WriteToOutputStage>( + main_output, width, height, has_alpha, unpremul_alpha, alpha_c, + undo_orientation, extra_output); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace jxl { + +HWY_EXPORT(GetWriteToOutputStage); + +namespace { +class WriteToImageBundleStage : public RenderPipelineStage { + public: + explicit WriteToImageBundleStage(ImageBundle* image_bundle, + ColorEncoding color_encoding) + : RenderPipelineStage(RenderPipelineStage::Settings()), + image_bundle_(image_bundle), + color_encoding_(std::move(color_encoding)) {} + + void SetInputSizes( + const std::vector<std::pair<size_t, size_t>>& input_sizes) override { +#if JXL_ENABLE_ASSERT + JXL_ASSERT(input_sizes.size() >= 3); + for (size_t c = 1; c < input_sizes.size(); c++) { + JXL_ASSERT(input_sizes[c].first == input_sizes[0].first); + JXL_ASSERT(input_sizes[c].second == input_sizes[0].second); + } +#endif + // TODO(eustas): what should we do in the case of "want only ECs"? + image_bundle_->SetFromImage( + Image3F(input_sizes[0].first, input_sizes[0].second), color_encoding_); + // TODO(veluca): consider not reallocating ECs if not needed. + image_bundle_->extra_channels().clear(); + for (size_t c = 3; c < input_sizes.size(); c++) { + image_bundle_->extra_channels().emplace_back(input_sizes[c].first, + input_sizes[c].second); + } + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < 3; c++) { + memcpy(image_bundle_->color()->PlaneRow(c, ypos) + xpos - xextra, + GetInputRow(input_rows, c, 0) - xextra, + sizeof(float) * (xsize + 2 * xextra)); + } + for (size_t ec = 0; ec < image_bundle_->extra_channels().size(); ec++) { + JXL_ASSERT(image_bundle_->extra_channels()[ec].xsize() >= + xpos + xsize + xextra); + memcpy(image_bundle_->extra_channels()[ec].Row(ypos) + xpos - xextra, + GetInputRow(input_rows, 3 + ec, 0) - xextra, + sizeof(float) * (xsize + 2 * xextra)); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInput; + } + + const char* GetName() const override { return "WriteIB"; } + + private: + ImageBundle* image_bundle_; + ColorEncoding color_encoding_; +}; + +class WriteToImage3FStage : public RenderPipelineStage { + public: + explicit WriteToImage3FStage(Image3F* image) + : RenderPipelineStage(RenderPipelineStage::Settings()), image_(image) {} + + void SetInputSizes( + const std::vector<std::pair<size_t, size_t>>& input_sizes) override { +#if JXL_ENABLE_ASSERT + JXL_ASSERT(input_sizes.size() >= 3); + for (size_t c = 1; c < 3; ++c) { + JXL_ASSERT(input_sizes[c].first == input_sizes[0].first); + JXL_ASSERT(input_sizes[c].second == input_sizes[0].second); + } +#endif + *image_ = Image3F(input_sizes[0].first, input_sizes[0].second); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < 3; c++) { + memcpy(image_->PlaneRow(c, ypos) + xpos - xextra, + GetInputRow(input_rows, c, 0) - xextra, + sizeof(float) * (xsize + 2 * xextra)); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInput + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "WriteI3F"; } + + private: + Image3F* image_; +}; + +} // namespace + +std::unique_ptr<RenderPipelineStage> GetWriteToImageBundleStage( + ImageBundle* image_bundle, ColorEncoding color_encoding) { + return jxl::make_unique<WriteToImageBundleStage>(image_bundle, + std::move(color_encoding)); +} + +std::unique_ptr<RenderPipelineStage> GetWriteToImage3FStage(Image3F* image) { + return jxl::make_unique<WriteToImage3FStage>(image); +} + +std::unique_ptr<RenderPipelineStage> GetWriteToOutputStage( + const ImageOutput& main_output, size_t width, size_t height, bool has_alpha, + bool unpremul_alpha, size_t alpha_c, Orientation undo_orientation, + std::vector<ImageOutput>& extra_output) { + return HWY_DYNAMIC_DISPATCH(GetWriteToOutputStage)( + main_output, width, height, has_alpha, unpremul_alpha, alpha_c, + undo_orientation, extra_output); +} + +} // namespace jxl + +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_write.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_write.h new file mode 100644 index 0000000000..c5f844ebe8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_write.h @@ -0,0 +1,31 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_WRITE_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_WRITE_H_ + +#include <functional> + +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +std::unique_ptr<RenderPipelineStage> GetWriteToImageBundleStage( + ImageBundle* image_bundle, ColorEncoding color_encoding); + +// Gets a stage to write color channels to an Image3F. +std::unique_ptr<RenderPipelineStage> GetWriteToImage3FStage(Image3F* image); + +// Gets a stage to write to a pixel callback or image buffer. +std::unique_ptr<RenderPipelineStage> GetWriteToOutputStage( + const ImageOutput& main_output, size_t width, size_t height, bool has_alpha, + bool unpremul_alpha, size_t alpha_c, Orientation undo_orientation, + std::vector<ImageOutput>& extra_output); + +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_WRITE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_xyb.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_xyb.cc new file mode 100644 index 0000000000..56e86e6095 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_xyb.cc @@ -0,0 +1,178 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_xyb.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_xyb.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/cms/opsin_params.h" +#include "lib/jxl/common.h" // JXL_HIGH_PRECISION +#include "lib/jxl/dec_xyb-inl.h" +#include "lib/jxl/sanitizers.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +class XYBStage : public RenderPipelineStage { + public: + explicit XYBStage(const OutputEncodingInfo& output_encoding_info) + : RenderPipelineStage(RenderPipelineStage::Settings()), + opsin_params_(output_encoding_info.opsin_params), + output_is_xyb_(output_encoding_info.color_encoding.GetColorSpace() == + ColorSpace::kXYB) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + const HWY_FULL(float) d; + JXL_ASSERT(xextra == 0); + const size_t xsize_v = RoundUpTo(xsize, Lanes(d)); + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // All calculations are lane-wise, still some might require + // value-dependent behaviour (e.g. NearestInt). Temporary unpoison last + // vector tail. + msan::UnpoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::UnpoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + // TODO(eustas): when using frame origin, addresses might be unaligned; + // making them aligned will void performance penalty. + if (output_is_xyb_) { + const auto scale_x = Set(d, jxl::cms::kScaledXYBScale[0]); + const auto scale_y = Set(d, jxl::cms::kScaledXYBScale[1]); + const auto scale_bmy = Set(d, jxl::cms::kScaledXYBScale[2]); + const auto offset_x = Set(d, jxl::cms::kScaledXYBOffset[0]); + const auto offset_y = Set(d, jxl::cms::kScaledXYBOffset[1]); + const auto offset_bmy = Set(d, jxl::cms::kScaledXYBOffset[2]); + for (ssize_t x = -xextra; x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + const auto in_x = LoadU(d, row0 + x); + const auto in_y = LoadU(d, row1 + x); + const auto in_b = LoadU(d, row2 + x); + auto out_x = Mul(Add(in_x, offset_x), scale_x); + auto out_y = Mul(Add(in_y, offset_y), scale_y); + auto out_b = Mul(Add(Sub(in_b, in_y), offset_bmy), scale_bmy); + StoreU(out_x, d, row0 + x); + StoreU(out_y, d, row1 + x); + StoreU(out_b, d, row2 + x); + } + } else { + for (ssize_t x = -xextra; x < (ssize_t)(xsize + xextra); x += Lanes(d)) { + const auto in_opsin_x = LoadU(d, row0 + x); + const auto in_opsin_y = LoadU(d, row1 + x); + const auto in_opsin_b = LoadU(d, row2 + x); + auto r = Undefined(d); + auto g = Undefined(d); + auto b = Undefined(d); + XybToRgb(d, in_opsin_x, in_opsin_y, in_opsin_b, opsin_params_, &r, &g, + &b); + StoreU(r, d, row0 + x); + StoreU(g, d, row1 + x); + StoreU(b, d, row2 + x); + } + } + msan::PoisonMemory(row0 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row1 + xsize, sizeof(float) * (xsize_v - xsize)); + msan::PoisonMemory(row2 + xsize, sizeof(float) * (xsize_v - xsize)); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "XYB"; } + + private: + const OpsinParams opsin_params_; + const bool output_is_xyb_; +}; + +std::unique_ptr<RenderPipelineStage> GetXYBStage( + const OutputEncodingInfo& output_encoding_info) { + return jxl::make_unique<XYBStage>(output_encoding_info); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetXYBStage); + +std::unique_ptr<RenderPipelineStage> GetXYBStage( + const OutputEncodingInfo& output_encoding_info) { + return HWY_DYNAMIC_DISPATCH(GetXYBStage)(output_encoding_info); +} + +#if !JXL_HIGH_PRECISION +namespace { +class FastXYBStage : public RenderPipelineStage { + public: + FastXYBStage(uint8_t* rgb, size_t stride, size_t width, size_t height, + bool rgba, bool has_alpha, size_t alpha_c) + : RenderPipelineStage(RenderPipelineStage::Settings()), + rgb_(rgb), + stride_(stride), + width_(width), + height_(height), + rgba_(rgba), + has_alpha_(has_alpha), + alpha_c_(alpha_c) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + if (ypos >= height_) return; + JXL_ASSERT(xextra == 0); + const float* xyba[4] = { + GetInputRow(input_rows, 0, 0), GetInputRow(input_rows, 1, 0), + GetInputRow(input_rows, 2, 0), + has_alpha_ ? GetInputRow(input_rows, alpha_c_, 0) : nullptr}; + uint8_t* out_buf = rgb_ + stride_ * ypos + (rgba_ ? 4 : 3) * xpos; + FastXYBTosRGB8(xyba, out_buf, rgba_, + xsize + xpos <= width_ ? xsize : width_ - xpos); + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 || (has_alpha_ && c == alpha_c_) + ? RenderPipelineChannelMode::kInput + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "FastXYB"; } + + private: + uint8_t* rgb_; + size_t stride_; + size_t width_; + size_t height_; + bool rgba_; + bool has_alpha_; + size_t alpha_c_; + std::vector<float> opaque_alpha_; +}; + +} // namespace + +std::unique_ptr<RenderPipelineStage> GetFastXYBTosRGB8Stage( + uint8_t* rgb, size_t stride, size_t width, size_t height, bool rgba, + bool has_alpha, size_t alpha_c) { + JXL_ASSERT(HasFastXYBTosRGB8()); + return make_unique<FastXYBStage>(rgb, stride, width, height, rgba, has_alpha, + alpha_c); +} +#endif + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_xyb.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_xyb.h new file mode 100644 index 0000000000..7b06345c36 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_xyb.h @@ -0,0 +1,26 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_XYB_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_XYB_H_ +#include <stdint.h> + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Converts the color channels from XYB to linear with appropriate primaries. +std::unique_ptr<RenderPipelineStage> GetXYBStage( + const OutputEncodingInfo& output_encoding_info); + +// Gets a stage to convert with fixed point arithmetic from XYB to sRGB8 and +// write to a uint8 buffer. +std::unique_ptr<RenderPipelineStage> GetFastXYBTosRGB8Stage( + uint8_t* rgb, size_t stride, size_t width, size_t height, bool rgba, + bool has_alpha, size_t alpha_c); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_XYB_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_ycbcr.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_ycbcr.cc new file mode 100644 index 0000000000..30ad327221 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_ycbcr.cc @@ -0,0 +1,83 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_ycbcr.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_ycbcr.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::MulAdd; + +class kYCbCrStage : public RenderPipelineStage { + public: + kYCbCrStage() : RenderPipelineStage(RenderPipelineStage::Settings()) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + const HWY_FULL(float) df; + + // Full-range BT.601 as defined by JFIF Clause 7: + // https://www.itu.int/rec/T-REC-T.871-201105-I/en + const auto c128 = Set(df, 128.0f / 255); + const auto crcr = Set(df, 1.402f); + const auto cgcb = Set(df, -0.114f * 1.772f / 0.587f); + const auto cgcr = Set(df, -0.299f * 1.402f / 0.587f); + const auto cbcb = Set(df, 1.772f); + + float* JXL_RESTRICT row0 = GetInputRow(input_rows, 0, 0); + float* JXL_RESTRICT row1 = GetInputRow(input_rows, 1, 0); + float* JXL_RESTRICT row2 = GetInputRow(input_rows, 2, 0); + // TODO(eustas): when using frame origin, addresses might be unaligned; + // making them aligned will void performance penalty. + for (size_t x = 0; x < xsize; x += Lanes(df)) { + const auto y_vec = Add(LoadU(df, row1 + x), c128); + const auto cb_vec = LoadU(df, row0 + x); + const auto cr_vec = LoadU(df, row2 + x); + const auto r_vec = MulAdd(crcr, cr_vec, y_vec); + const auto g_vec = MulAdd(cgcr, cr_vec, MulAdd(cgcb, cb_vec, y_vec)); + const auto b_vec = MulAdd(cbcb, cb_vec, y_vec); + StoreU(r_vec, df, row0 + x); + StoreU(g_vec, df, row1 + x); + StoreU(b_vec, df, row2 + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInPlace + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "YCbCr"; } +}; + +std::unique_ptr<RenderPipelineStage> GetYCbCrStage() { + return jxl::make_unique<kYCbCrStage>(); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetYCbCrStage); + +std::unique_ptr<RenderPipelineStage> GetYCbCrStage() { + return HWY_DYNAMIC_DISPATCH(GetYCbCrStage)(); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_ycbcr.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_ycbcr.h new file mode 100644 index 0000000000..3e99af7a38 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_ycbcr.h @@ -0,0 +1,24 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_RENDER_PIPELINE_STAGE_YCBCR_H_ +#define LIB_JXL_RENDER_PIPELINE_STAGE_YCBCR_H_ +#include <math.h> +#include <stdint.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "lib/jxl/dec_xyb.h" +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +// Converts the color channels from YCbCr to RGB. +std::unique_ptr<RenderPipelineStage> GetYCbCrStage(); +} // namespace jxl + +#endif // LIB_JXL_RENDER_PIPELINE_STAGE_YCBCR_H_ diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/test_render_pipeline_stages.h b/third_party/jpeg-xl/lib/jxl/render_pipeline/test_render_pipeline_stages.h new file mode 100644 index 0000000000..789a52f8b2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/test_render_pipeline_stages.h @@ -0,0 +1,101 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <math.h> +#include <stdint.h> +#include <stdio.h> + +#include <algorithm> +#include <utility> +#include <vector> + +#include "lib/jxl/render_pipeline/render_pipeline_stage.h" + +namespace jxl { + +class UpsampleXSlowStage : public RenderPipelineStage { + public: + UpsampleXSlowStage() + : RenderPipelineStage(RenderPipelineStage::Settings::ShiftX(1, 1)) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < input_rows.size(); c++) { + const float* row = GetInputRow(input_rows, c, 0); + float* row_out = GetOutputRow(output_rows, c, 0); + for (int64_t x = -xextra; x < (int64_t)(xsize + xextra); x++) { + float xp = *(row + x - 1); + float xc = *(row + x); + float xn = *(row + x + 1); + float xout0 = xp * 0.25f + xc * 0.75f; + float xout1 = xc * 0.75f + xn * 0.25f; + *(row_out + 2 * x + 0) = xout0; + *(row_out + 2 * x + 1) = xout1; + } + } + } + + const char* GetName() const override { return "TEST::UpsampleXSlowStage"; } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInOut; + } +}; + +class UpsampleYSlowStage : public RenderPipelineStage { + public: + UpsampleYSlowStage() + : RenderPipelineStage(RenderPipelineStage::Settings::ShiftY(1, 1)) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < input_rows.size(); c++) { + const float* rowp = GetInputRow(input_rows, c, -1); + const float* rowc = GetInputRow(input_rows, c, 0); + const float* rown = GetInputRow(input_rows, c, 1); + float* row_out0 = GetOutputRow(output_rows, c, 0); + float* row_out1 = GetOutputRow(output_rows, c, 1); + for (int64_t x = -xextra; x < (int64_t)(xsize + xextra); x++) { + float xp = *(rowp + x); + float xc = *(rowc + x); + float xn = *(rown + x); + float yout0 = xp * 0.25f + xc * 0.75f; + float yout1 = xc * 0.75f + xn * 0.25f; + *(row_out0 + x) = yout0; + *(row_out1 + x) = yout1; + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInOut; + } + + const char* GetName() const override { return "TEST::UpsampleYSlowStage"; } +}; + +class Check0FinalStage : public RenderPipelineStage { + public: + Check0FinalStage() : RenderPipelineStage(RenderPipelineStage::Settings()) {} + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + for (size_t c = 0; c < input_rows.size(); c++) { + for (size_t x = 0; x < xsize; x++) { + JXL_CHECK(fabsf(GetInputRow(input_rows, c, 0)[x]) < 1e-8); + } + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return RenderPipelineChannelMode::kInput; + } + const char* GetName() const override { return "TEST::Check0FinalStage"; } +}; + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/roundtrip_test.cc b/third_party/jpeg-xl/lib/jxl/roundtrip_test.cc new file mode 100644 index 0000000000..c00fda0de1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/roundtrip_test.cc @@ -0,0 +1,985 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/cms.h> +#include <jxl/codestream_header.h> +#include <jxl/color_encoding.h> +#include <jxl/decode.h> +#include <jxl/decode_cxx.h> +#include <jxl/encode.h> +#include <jxl/encode_cxx.h> +#include <jxl/types.h> + +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <cstring> +#include <string> +#include <utility> +#include <vector> + +#include "lib/extras/codec.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/encode_internal.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace { + +// Converts a test image to a CodecInOut. +// icc_profile can be empty to automatically deduce profile from the pixel +// format, or filled in to force this ICC profile +jxl::CodecInOut ConvertTestImage(const std::vector<uint8_t>& buf, + const size_t xsize, const size_t ysize, + const JxlPixelFormat& pixel_format, + const jxl::Bytes& icc_profile) { + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + + bool is_gray = pixel_format.num_channels < 3; + bool has_alpha = + pixel_format.num_channels == 2 || pixel_format.num_channels == 4; + + io.metadata.m.color_encoding.SetColorSpace(is_gray ? jxl::ColorSpace::kGray + : jxl::ColorSpace::kRGB); + if (has_alpha) { + // Note: alpha > 16 not yet supported by the C++ codec + switch (pixel_format.data_type) { + case JXL_TYPE_UINT8: + io.metadata.m.SetAlphaBits(8); + break; + case JXL_TYPE_UINT16: + case JXL_TYPE_FLOAT: + case JXL_TYPE_FLOAT16: + io.metadata.m.SetAlphaBits(16); + break; + default: + ADD_FAILURE() << "Roundtrip tests for data type " + << pixel_format.data_type << " not yet implemented."; + } + } + size_t bitdepth = 0; + switch (pixel_format.data_type) { + case JXL_TYPE_FLOAT: + bitdepth = 32; + io.metadata.m.SetFloat32Samples(); + break; + case JXL_TYPE_FLOAT16: + bitdepth = 16; + io.metadata.m.SetFloat16Samples(); + break; + case JXL_TYPE_UINT8: + bitdepth = 8; + io.metadata.m.SetUintSamples(8); + break; + case JXL_TYPE_UINT16: + bitdepth = 16; + io.metadata.m.SetUintSamples(16); + break; + default: + ADD_FAILURE() << "Roundtrip tests for data type " + << pixel_format.data_type << " not yet implemented."; + } + jxl::ColorEncoding color_encoding; + if (!icc_profile.empty()) { + jxl::IccBytes icc_profile_copy; + icc_profile.AppendTo(&icc_profile_copy); + EXPECT_TRUE( + color_encoding.SetICC(std::move(icc_profile_copy), JxlGetDefaultCms())); + } else if (pixel_format.data_type == JXL_TYPE_FLOAT) { + color_encoding = jxl::ColorEncoding::LinearSRGB(is_gray); + } else { + color_encoding = jxl::ColorEncoding::SRGB(is_gray); + } + EXPECT_TRUE(ConvertFromExternal(jxl::Bytes(buf), xsize, ysize, color_encoding, + /*bits_per_sample=*/bitdepth, pixel_format, + /*pool=*/nullptr, &io.Main())); + return io; +} + +template <typename T> +T ConvertTestPixel(float val); + +template <> +float ConvertTestPixel<float>(const float val) { + return val; +} + +template <> +uint16_t ConvertTestPixel<uint16_t>(const float val) { + return (uint16_t)(val * UINT16_MAX); +} + +template <> +uint8_t ConvertTestPixel<uint8_t>(const float val) { + return (uint8_t)(val * UINT8_MAX); +} + +// Returns a test image. +template <typename T> +std::vector<uint8_t> GetTestImage(const size_t xsize, const size_t ysize, + const JxlPixelFormat& pixel_format) { + std::vector<T> pixels(xsize * ysize * pixel_format.num_channels); + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + for (size_t chan = 0; chan < pixel_format.num_channels; chan++) { + float val; + switch (chan % 4) { + case 0: + val = static_cast<float>(y) / static_cast<float>(ysize); + break; + case 1: + val = static_cast<float>(x) / static_cast<float>(xsize); + break; + case 2: + val = static_cast<float>(x + y) / static_cast<float>(xsize + ysize); + break; + case 3: + val = static_cast<float>(x * y) / static_cast<float>(xsize * ysize); + break; + } + pixels[(y * xsize + x) * pixel_format.num_channels + chan] = + ConvertTestPixel<T>(val); + } + } + } + std::vector<uint8_t> bytes(pixels.size() * sizeof(T)); + memcpy(bytes.data(), pixels.data(), sizeof(T) * pixels.size()); + return bytes; +} + +void EncodeWithEncoder(JxlEncoder* enc, std::vector<uint8_t>* compressed) { + compressed->resize(64); + uint8_t* next_out = compressed->data(); + size_t avail_out = compressed->size() - (next_out - compressed->data()); + JxlEncoderStatus process_result = JXL_ENC_NEED_MORE_OUTPUT; + while (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + process_result = JxlEncoderProcessOutput(enc, &next_out, &avail_out); + if (process_result == JXL_ENC_NEED_MORE_OUTPUT) { + size_t offset = next_out - compressed->data(); + compressed->resize(compressed->size() * 2); + next_out = compressed->data() + offset; + avail_out = compressed->size() - offset; + } + } + compressed->resize(next_out - compressed->data()); + EXPECT_EQ(JXL_ENC_SUCCESS, process_result); +} + +// Generates some pixels using some dimensions and pixel_format, +// compresses them, and verifies that the decoded version is similar to the +// original pixels. +// TODO(firsching): change this to be a parameterized test, like in +// decode_test.cc +template <typename T> +void VerifyRoundtripCompression( + const size_t xsize, const size_t ysize, + const JxlPixelFormat& input_pixel_format, + const JxlPixelFormat& output_pixel_format, const bool lossless, + const bool use_container, const uint32_t resampling = 1, + const bool already_downsampled = false, + const std::vector<std::pair<JxlExtraChannelType, std::string>>& + extra_channels = {}, + const int upsampling_mode = -1) { + size_t orig_xsize = xsize; + size_t orig_ysize = ysize; + if (already_downsampled) { + orig_xsize = jxl::DivCeil(xsize, resampling); + orig_ysize = jxl::DivCeil(ysize, resampling); + } + + JxlPixelFormat extra_channel_pixel_format = input_pixel_format; + extra_channel_pixel_format.num_channels = 1; + const std::vector<uint8_t> extra_channel_bytes = + GetTestImage<T>(xsize, ysize, extra_channel_pixel_format); + const std::vector<uint8_t> original_bytes = + GetTestImage<T>(orig_xsize, orig_ysize, input_pixel_format); + jxl::CodecInOut original_io = ConvertTestImage( + original_bytes, orig_xsize, orig_ysize, input_pixel_format, {}); + + JxlEncoder* enc = JxlEncoderCreate(nullptr); + EXPECT_NE(nullptr, enc); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc, 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc, use_container)); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &input_pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = lossless; + uint32_t num_channels = input_pixel_format.num_channels; + size_t has_interleaved_alpha = num_channels == 2 || num_channels == 4; + JxlPixelFormat output_pixel_format_with_extra_channel_alpha = + output_pixel_format; + + // In the case where we have an alpha channel, but it is provided as an extra + // channel and not interleaved, we do two things here: + // 1. modify the original_io to have the correct alpha channel + // 2. change the output_format_with_extra_alpha to have an alpha channel + bool alpha_in_extra_channels_vector = false; + for (const auto& extra_channel : extra_channels) { + if (extra_channel.first == JXL_CHANNEL_ALPHA) { + alpha_in_extra_channels_vector = true; + } + } + if (alpha_in_extra_channels_vector && !has_interleaved_alpha) { + jxl::ImageF alpha_channel(xsize, ysize); + EXPECT_TRUE(jxl::ConvertFromExternal( + extra_channel_bytes.data(), extra_channel_bytes.size(), xsize, ysize, + basic_info.bits_per_sample, extra_channel_pixel_format, 0, + /*pool=*/nullptr, &alpha_channel)); + + original_io.metadata.m.SetAlphaBits(basic_info.bits_per_sample); + original_io.Main().SetAlpha(std::move(alpha_channel)); + output_pixel_format_with_extra_channel_alpha.num_channels++; + } + // Those are the num_extra_channels including a potential alpha channel. + basic_info.num_extra_channels = extra_channels.size() + has_interleaved_alpha; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + EXPECT_EQ(enc->metadata.m.num_extra_channels, + extra_channels.size() + has_interleaved_alpha); + JxlColorEncoding color_encoding; + if (input_pixel_format.data_type == JXL_TYPE_FLOAT) { + JxlColorEncodingSetToLinearSRGB( + &color_encoding, + /*is_gray=*/input_pixel_format.num_channels < 3); + } else { + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/input_pixel_format.num_channels < 3); + } + + std::vector<JxlExtraChannelInfo> channel_infos; + for (const auto& extra_channel : extra_channels) { + auto channel_type = extra_channel.first; + JxlExtraChannelInfo channel_info; + JxlEncoderInitExtraChannelInfo(channel_type, &channel_info); + channel_info.bits_per_sample = (lossless ? basic_info.bits_per_sample : 8); + channel_info.exponent_bits_per_sample = + (lossless ? basic_info.exponent_bits_per_sample : 0); + channel_infos.push_back(channel_info); + } + for (size_t index = 0; index < channel_infos.size(); index++) { + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetExtraChannelInfo(enc, index + has_interleaved_alpha, + &channel_infos[index])); + std::string name = extra_channels[index].second; + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetExtraChannelName(enc, index + has_interleaved_alpha, + name.c_str(), name.length())); + } + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetColorEncoding(enc, &color_encoding)); + if (resampling > 1) { + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetUpsamplingMode(enc, 3, 0)); + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetUpsamplingMode(enc, resampling, -2)); + EXPECT_EQ(JXL_ENC_ERROR, JxlEncoderSetUpsamplingMode(enc, resampling, 2)); + } + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetUpsamplingMode(enc, resampling, upsampling_mode)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc, nullptr); + JxlEncoderSetFrameLossless(frame_settings, lossless); + if (resampling > 1) { + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_RESAMPLING, resampling)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption( + frame_settings, JXL_ENC_FRAME_SETTING_ALREADY_DOWNSAMPLED, + already_downsampled)); + } + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &input_pixel_format, + (void*)original_bytes.data(), + original_bytes.size())); + EXPECT_EQ(frame_settings->enc->input_queue.empty(), false); + for (size_t index = 0; index < channel_infos.size(); index++) { + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetExtraChannelBuffer( + frame_settings, &extra_channel_pixel_format, + (void*)extra_channel_bytes.data(), extra_channel_bytes.size(), + index + has_interleaved_alpha)); + } + JxlEncoderCloseInput(enc); + std::vector<uint8_t> compressed; + EncodeWithEncoder(enc, &compressed); + JxlEncoderDestroy(enc); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)); + + JxlDecoderSetInput(dec, next_in, avail_in); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize( + dec, &output_pixel_format_with_extra_channel_alpha, &buffer_size)); + if (&input_pixel_format == &output_pixel_format_with_extra_channel_alpha && + !already_downsampled) { + EXPECT_EQ(buffer_size, original_bytes.size()); + } + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + EXPECT_EQ(extra_channels.size() + has_interleaved_alpha, + info.num_extra_channels); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + size_t icc_profile_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_DATA, + &icc_profile_size)); + std::vector<uint8_t> icc_profile(icc_profile_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile.data(), icc_profile.size())); + + std::vector<uint8_t> decoded_bytes(buffer_size); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer( + dec, &output_pixel_format_with_extra_channel_alpha, + decoded_bytes.data(), decoded_bytes.size())); + std::vector<std::vector<uint8_t>> extra_channel_decoded_bytes( + info.num_extra_channels - has_interleaved_alpha); + + for (size_t index = has_interleaved_alpha; index < info.num_extra_channels; + index++) { + JxlExtraChannelInfo channel_info; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetExtraChannelInfo(dec, index, &channel_info)); + EXPECT_EQ(channel_info.type, + extra_channels[index - has_interleaved_alpha].first); + std::string input_name = + extra_channels[index - has_interleaved_alpha].second; + const size_t name_length = channel_info.name_length; + EXPECT_EQ(input_name.size(), name_length); + std::vector<char> output_name(name_length + 1); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetExtraChannelName(dec, index, output_name.data(), + output_name.size())); + EXPECT_EQ(0, + memcmp(input_name.data(), output_name.data(), input_name.size())); + size_t extra_buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderExtraChannelBufferSize(dec, &output_pixel_format, + &extra_buffer_size, index)); + std::vector<uint8_t> extra_decoded_bytes(extra_buffer_size); + extra_channel_decoded_bytes[index - has_interleaved_alpha] = + std::move(extra_decoded_bytes); + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSetExtraChannelBuffer( + dec, &output_pixel_format, + extra_channel_decoded_bytes[index - has_interleaved_alpha].data(), + extra_channel_decoded_bytes[index - has_interleaved_alpha].size(), + index)); + } + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + // Check if there are no further errors after getting the full image, e.g. + // check that the final codestream box is actually marked as last. + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); + + jxl::CodecInOut decoded_io = ConvertTestImage( + decoded_bytes, xsize, ysize, output_pixel_format_with_extra_channel_alpha, + jxl::Bytes(icc_profile)); + + if (already_downsampled) { + jxl::Image3F* color = decoded_io.Main().color(); + jxl::DownsampleImage(color, resampling); + if (decoded_io.Main().HasAlpha()) { + jxl::ImageF* alpha = decoded_io.Main().alpha(); + jxl::DownsampleImage(alpha, resampling); + } + decoded_io.SetSize(color->xsize(), color->ysize()); + } + + if (lossless && !already_downsampled) { + JXL_EXPECT_OK(jxl::SamePixels(*original_io.Main().color(), + *decoded_io.Main().color(), _)); + } else { + jxl::ButteraugliParams ba; + float butteraugli_score = ButteraugliDistance( + original_io.frames, decoded_io.frames, ba, *JxlGetDefaultCms(), + /*distmap=*/nullptr, nullptr); + float target_score = 1.3f; + // upsampling mode 1 (unlike default and NN) does not downscale back to the + // already downsampled image + if (upsampling_mode == 1 && resampling >= 4 && already_downsampled) + target_score = 15.f; + EXPECT_LE(butteraugli_score, target_score); + } + JxlPixelFormat extra_channel_output_pixel_format = output_pixel_format; + extra_channel_output_pixel_format.num_channels = 1; + for (auto& extra_channel : extra_channel_decoded_bytes) { + EXPECT_EQ(extra_channel.size(), extra_channel_bytes.size()); + if (lossless) { + EXPECT_EQ(jxl::test::ComparePixels(extra_channel.data(), + extra_channel_bytes.data(), xsize, + ysize, extra_channel_pixel_format, + extra_channel_output_pixel_format), + 0u); + EXPECT_EQ(extra_channel, extra_channel_bytes); + } + } +} + +} // namespace + +TEST(RoundtripTest, FloatFrameRoundtripTest) { + std::vector<std::vector<std::pair<JxlExtraChannelType, std::string>>> + extra_channels_cases = {{}, + {{JXL_CHANNEL_ALPHA, "my extra alpha channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}}, + {{JXL_CHANNEL_DEPTH, "depth"}, + {JXL_CHANNEL_SELECTION_MASK, "mask"}, + {JXL_CHANNEL_BLACK, "black"}, + {JXL_CHANNEL_CFA, "my cfa channel"}, + {JXL_CHANNEL_OPTIONAL, "optional channel"}}, + {{JXL_CHANNEL_DEPTH, "very deep"}}}; + for (int use_container = 0; use_container < 2; use_container++) { + for (int lossless = 0; lossless < 2; lossless++) { + for (uint32_t num_channels = 1; num_channels < 5; num_channels++) { + for (auto& extra_channels : extra_channels_cases) { + uint32_t has_alpha = static_cast<uint32_t>(num_channels % 2 == 0); + uint32_t total_extra_channels = has_alpha + extra_channels.size(); + // There's no support (yet) for lossless extra float + // channels, so we don't test it. + if (total_extra_channels == 0 || !lossless) { + JxlPixelFormat pixel_format = JxlPixelFormat{ + num_channels, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression<float>( + 63, 129, pixel_format, pixel_format, (bool)lossless, + (bool)use_container, 1, false, extra_channels); + } + } + } + } + } +} + +TEST(RoundtripTest, Uint16FrameRoundtripTest) { + std::vector<std::vector<std::pair<JxlExtraChannelType, std::string>>> + extra_channels_cases = {{}, + {{JXL_CHANNEL_ALPHA, "my extra alpha channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}, + {JXL_CHANNEL_BLACK, "k_channel"}}, + {{JXL_CHANNEL_DEPTH, "very deep"}}}; + for (int use_container = 0; use_container < 2; use_container++) { + for (int lossless = 0; lossless < 2; lossless++) { + for (uint32_t num_channels = 1; num_channels < 5; num_channels++) { + for (auto& extra_channels : extra_channels_cases) { + JxlPixelFormat pixel_format = JxlPixelFormat{ + num_channels, JXL_TYPE_UINT16, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression<uint16_t>( + 63, 129, pixel_format, pixel_format, (bool)lossless, + (bool)use_container, 1, false, extra_channels); + } + } + } + } +} + +TEST(RoundtripTest, Uint8FrameRoundtripTest) { + std::vector<std::vector<std::pair<JxlExtraChannelType, std::string>>> + extra_channels_cases = {{}, + {{JXL_CHANNEL_THERMAL, "temperature"}}, + {{JXL_CHANNEL_ALPHA, "my extra alpha channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}}, + {{JXL_CHANNEL_CFA, "my cfa channel"}, + {JXL_CHANNEL_BLACK, "k_channel"}}, + {{JXL_CHANNEL_DEPTH, "very deep"}}}; + for (int use_container = 0; use_container < 2; use_container++) { + for (int lossless = 0; lossless < 2; lossless++) { + for (uint32_t num_channels = 1; num_channels < 5; num_channels++) { + for (auto& extra_channels : extra_channels_cases) { + JxlPixelFormat pixel_format = JxlPixelFormat{ + num_channels, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression<uint8_t>( + 63, 129, pixel_format, pixel_format, (bool)lossless, + (bool)use_container, 1, false, extra_channels); + } + } + } + } +} + +TEST(RoundtripTest, TestNonlinearSrgbAsXybEncoded) { + for (int use_container = 0; use_container < 2; use_container++) { + for (uint32_t num_channels = 1; num_channels < 5; num_channels++) { + JxlPixelFormat pixel_format_in = + JxlPixelFormat{num_channels, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + JxlPixelFormat pixel_format_out = + JxlPixelFormat{num_channels, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression<uint8_t>( + 63, 129, pixel_format_in, pixel_format_out, + /*lossless=*/false, (bool)use_container, 1, false, {}); + } + } +} + +TEST(RoundtripTest, Resampling) { + JxlPixelFormat pixel_format = + JxlPixelFormat{3, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + VerifyRoundtripCompression<uint8_t>(63, 129, pixel_format, pixel_format, + /*lossless=*/false, + /*use_container=*/false, 2, + /*already_downsampled=*/false); + + // TODO(lode): also make this work for odd sizes. This requires a fix in + // enc_frame.cc to not set custom_size_or_origin to true due to even/odd + // mismatch. + for (int factor : {2, 4, 8}) { + for (int upsampling_mode : {-1, 0, 1}) { + VerifyRoundtripCompression<uint8_t>( + 64, 128, pixel_format, pixel_format, + /*lossless=*/true, + /*use_container=*/false, factor, + /*already_downsampled=*/true, /*extra_channels=*/{}, upsampling_mode); + } + } +} + +TEST(RoundtripTest, ExtraBoxesTest) { + JxlPixelFormat pixel_format = + JxlPixelFormat{4, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + const size_t xsize = 61; + const size_t ysize = 71; + + const std::vector<uint8_t> original_bytes = + GetTestImage<float>(xsize, ysize, pixel_format); + jxl::CodecInOut original_io = + ConvertTestImage(original_bytes, xsize, ysize, pixel_format, {}); + + JxlEncoder* enc = JxlEncoderCreate(nullptr); + EXPECT_NE(nullptr, enc); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc, true)); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = false; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc, 10)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + JxlColorEncoding color_encoding; + if (pixel_format.data_type == JXL_TYPE_FLOAT) { + JxlColorEncodingSetToLinearSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + } else { + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + } + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetColorEncoding(enc, &color_encoding)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc, nullptr); + JxlEncoderSetFrameLossless(frame_settings, false); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &pixel_format, + (void*)original_bytes.data(), + original_bytes.size())); + JxlEncoderCloseInput(enc); + + std::vector<uint8_t> compressed; + EncodeWithEncoder(enc, &compressed); + JxlEncoderDestroy(enc); + + std::vector<uint8_t> extra_data(1023); + jxl::AppendBoxHeader(jxl::MakeBoxType("crud"), extra_data.size(), false, + &compressed); + compressed.insert(compressed.end(), extra_data.begin(), extra_data.end()); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)); + + JxlDecoderSetInput(dec, next_in, avail_in); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &pixel_format, &buffer_size)); + EXPECT_EQ(buffer_size, original_bytes.size()); + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + size_t icc_profile_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_DATA, + &icc_profile_size)); + std::vector<uint8_t> icc_profile(icc_profile_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile.data(), icc_profile.size())); + + std::vector<uint8_t> decoded_bytes(buffer_size); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderSetImageOutBuffer(dec, &pixel_format, + decoded_bytes.data(), + decoded_bytes.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + + JxlDecoderDestroy(dec); + + jxl::CodecInOut decoded_io = ConvertTestImage( + decoded_bytes, xsize, ysize, pixel_format, jxl::Bytes(icc_profile)); + + jxl::ButteraugliParams ba; + float butteraugli_score = ButteraugliDistance( + original_io.frames, decoded_io.frames, ba, *JxlGetDefaultCms(), + /*distmap=*/nullptr, nullptr); + EXPECT_LE(butteraugli_score, 1.0f); +} + +TEST(RoundtripTest, MultiFrameTest) { + JxlPixelFormat pixel_format = + JxlPixelFormat{4, JXL_TYPE_FLOAT, JXL_NATIVE_ENDIAN, 0}; + const size_t xsize = 61; + const size_t ysize = 71; + const size_t nb_frames = 4; + size_t compressed_size = 0; + + for (int index_frames : {0, 1}) { + // use a vertical filmstrip of nb_frames frames + const std::vector<uint8_t> original_bytes = + GetTestImage<float>(xsize, ysize * nb_frames, pixel_format); + jxl::CodecInOut original_io = ConvertTestImage( + original_bytes, xsize, ysize * nb_frames, pixel_format, {}); + + JxlEncoder* enc = JxlEncoderCreate(nullptr); + EXPECT_NE(nullptr, enc); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc, true)); + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &pixel_format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = JXL_FALSE; + basic_info.have_animation = JXL_TRUE; + basic_info.animation.tps_numerator = 1; + basic_info.animation.tps_denominator = 1; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetCodestreamLevel(enc, 10)); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + JxlColorEncoding color_encoding; + if (pixel_format.data_type == JXL_TYPE_FLOAT) { + JxlColorEncodingSetToLinearSRGB( + &color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + } else { + JxlColorEncodingSetToSRGB(&color_encoding, + /*is_gray=*/pixel_format.num_channels < 3); + } + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetColorEncoding(enc, &color_encoding)); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc, nullptr); + JxlEncoderSetFrameLossless(frame_settings, false); + if (index_frames == 1) { + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderFrameSettingsSetOption(frame_settings, + JXL_ENC_FRAME_INDEX_BOX, 1)); + } + + size_t oneframesize = original_bytes.size() / nb_frames; + JxlFrameHeader frame_header; + JxlEncoderInitFrameHeader(&frame_header); + frame_header.duration = 1; + frame_header.is_last = JXL_FALSE; + + for (size_t i = 0; i < nb_frames; i++) { + if (i + 1 == nb_frames) frame_header.is_last = JXL_TRUE; + JxlEncoderSetFrameHeader(frame_settings, &frame_header); + EXPECT_EQ( + JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame( + frame_settings, &pixel_format, + (void*)(original_bytes.data() + oneframesize * i), oneframesize)); + } + JxlEncoderCloseInput(enc); + + std::vector<uint8_t> compressed; + EncodeWithEncoder(enc, &compressed); + JxlEncoderDestroy(enc); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + if (index_frames == 0) { + compressed_size = avail_in; + } else { + // a non-empty jxli box should be added + EXPECT_LE(avail_in, compressed_size + 50); + EXPECT_GE(avail_in, compressed_size + 10); + } + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)); + + JxlDecoderSetInput(dec, next_in, avail_in); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &pixel_format, &buffer_size)); + EXPECT_EQ(buffer_size, oneframesize); + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + size_t icc_profile_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_DATA, + &icc_profile_size)); + std::vector<uint8_t> icc_profile(icc_profile_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, JXL_COLOR_PROFILE_TARGET_DATA, + icc_profile.data(), icc_profile.size())); + + std::vector<uint8_t> decoded_bytes(buffer_size * nb_frames); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + for (size_t i = 0; i < nb_frames; i++) { + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer( + dec, &pixel_format, decoded_bytes.data() + i * oneframesize, + buffer_size)); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + } + JxlDecoderDestroy(dec); + jxl::CodecInOut decoded_io = + ConvertTestImage(decoded_bytes, xsize, ysize * nb_frames, pixel_format, + jxl::Bytes(icc_profile)); + + jxl::ButteraugliParams ba; + float butteraugli_score = ButteraugliDistance( + original_io.frames, decoded_io.frames, ba, *JxlGetDefaultCms(), + /*distmap=*/nullptr, nullptr); + EXPECT_LE(butteraugli_score, 1.0f); + } +} + +static const unsigned char kEncodedTestProfile[] = { + 0x1f, 0x8b, 0x1, 0x13, 0x10, 0x0, 0x0, 0x0, 0x20, 0x4c, 0xcc, 0x3, + 0xe7, 0xa0, 0xa5, 0xa2, 0x90, 0xa4, 0x27, 0xe8, 0x79, 0x1d, 0xe3, 0x26, + 0x57, 0x54, 0xef, 0x0, 0xe8, 0x97, 0x2, 0xce, 0xa1, 0xd7, 0x85, 0x16, + 0xb4, 0x29, 0x94, 0x58, 0xf2, 0x56, 0xc0, 0x76, 0xea, 0x23, 0xec, 0x7c, + 0x73, 0x51, 0x41, 0x40, 0x23, 0x21, 0x95, 0x4, 0x75, 0x12, 0xc9, 0xcc, + 0x16, 0xbd, 0xb6, 0x99, 0xad, 0xf8, 0x75, 0x35, 0xb6, 0x42, 0xae, 0xae, + 0xae, 0x86, 0x56, 0xf8, 0xcc, 0x16, 0x30, 0xb3, 0x45, 0xad, 0xd, 0x40, + 0xd6, 0xd1, 0xd6, 0x99, 0x40, 0xbe, 0xe2, 0xdc, 0x31, 0x7, 0xa6, 0xb9, + 0x27, 0x92, 0x38, 0x0, 0x3, 0x5e, 0x2c, 0xbe, 0xe6, 0xfb, 0x19, 0xbf, + 0xf3, 0x6d, 0xbc, 0x4d, 0x64, 0xe5, 0xba, 0x76, 0xde, 0x31, 0x65, 0x66, + 0x14, 0xa6, 0x3a, 0xc5, 0x8f, 0xb1, 0xb4, 0xba, 0x1f, 0xb1, 0xb8, 0xd4, + 0x75, 0xba, 0x18, 0x86, 0x95, 0x3c, 0x26, 0xf6, 0x25, 0x62, 0x53, 0xfd, + 0x9c, 0x94, 0x76, 0xf6, 0x95, 0x2c, 0xb1, 0xfd, 0xdc, 0xc0, 0xe4, 0x3f, + 0xb3, 0xff, 0x67, 0xde, 0xd5, 0x94, 0xcc, 0xb0, 0x83, 0x2f, 0x28, 0x93, + 0x92, 0x3, 0xa1, 0x41, 0x64, 0x60, 0x62, 0x70, 0x80, 0x87, 0xaf, 0xe7, + 0x60, 0x4a, 0x20, 0x23, 0xb3, 0x11, 0x7, 0x38, 0x38, 0xd4, 0xa, 0x66, + 0xb5, 0x93, 0x41, 0x90, 0x19, 0x17, 0x18, 0x60, 0xa5, 0xb, 0x7a, 0x24, + 0xaa, 0x20, 0x81, 0xac, 0xa9, 0xa1, 0x70, 0xa6, 0x12, 0x8a, 0x4a, 0xa3, + 0xa0, 0xf9, 0x9a, 0x97, 0xe7, 0xa8, 0xac, 0x8, 0xa8, 0xc4, 0x2a, 0x86, + 0xa7, 0x69, 0x1e, 0x67, 0xe6, 0xbe, 0xa4, 0xd3, 0xff, 0x91, 0x61, 0xf6, + 0x8a, 0xe6, 0xb5, 0xb3, 0x61, 0x9f, 0x19, 0x17, 0x98, 0x27, 0x6b, 0xe9, + 0x8, 0x98, 0xe1, 0x21, 0x4a, 0x9, 0xb5, 0xd7, 0xca, 0xfa, 0x94, 0xd0, + 0x69, 0x1a, 0xeb, 0x52, 0x1, 0x4e, 0xf5, 0xf6, 0xdf, 0x7f, 0xe7, 0x29, + 0x70, 0xee, 0x4, 0xda, 0x2f, 0xa4, 0xff, 0xfe, 0xbb, 0x6f, 0xa8, 0xff, + 0xfe, 0xdb, 0xaf, 0x8, 0xf6, 0x72, 0xa1, 0x40, 0x5d, 0xf0, 0x2d, 0x8, + 0x82, 0x5b, 0x87, 0xbd, 0x10, 0x8, 0xe9, 0x7, 0xee, 0x4b, 0x80, 0xda, + 0x4a, 0x4, 0xc5, 0x5e, 0xa0, 0xb7, 0x1e, 0x60, 0xb0, 0x59, 0x76, 0x60, + 0xb, 0x2e, 0x19, 0x8a, 0x2e, 0x1c, 0xe6, 0x6, 0x20, 0xb8, 0x64, 0x18, + 0x2a, 0xcf, 0x51, 0x94, 0xd4, 0xee, 0xc3, 0xfe, 0x39, 0x74, 0xd4, 0x2b, + 0x48, 0xc9, 0x83, 0x4c, 0x9b, 0xd0, 0x4c, 0x35, 0x10, 0xe3, 0x9, 0xf7, + 0x72, 0xf0, 0x7a, 0xe, 0xbf, 0x7d, 0x36, 0x2e, 0x19, 0x7e, 0x3f, 0xc, + 0xf7, 0x93, 0xe7, 0xf4, 0x1d, 0x32, 0xc6, 0xb0, 0x89, 0xad, 0xe0, 0x28, + 0xc1, 0xa7, 0x59, 0xe3, 0x0, +}; + +TEST(RoundtripTest, TestICCProfile) { + // JxlEncoderSetICCProfile parses the ICC profile, so a valid profile is + // needed. The profile should be passed correctly through the roundtrip. + jxl::BitReader reader( + jxl::Bytes(kEncodedTestProfile, sizeof(kEncodedTestProfile))); + std::vector<uint8_t> icc; + ASSERT_TRUE(jxl::test::ReadICC(&reader, &icc)); + ASSERT_TRUE(reader.Close()); + + JxlPixelFormat format = + JxlPixelFormat{3, JXL_TYPE_UINT8, JXL_NATIVE_ENDIAN, 0}; + + size_t xsize = 25; + size_t ysize = 37; + const std::vector<uint8_t> original_bytes = + GetTestImage<uint8_t>(xsize, ysize, format); + + JxlEncoder* enc = JxlEncoderCreate(nullptr); + EXPECT_NE(nullptr, enc); + + JxlBasicInfo basic_info; + jxl::test::JxlBasicInfoSetFromPixelFormat(&basic_info, &format); + basic_info.xsize = xsize; + basic_info.ysize = ysize; + basic_info.uses_original_profile = JXL_TRUE; + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderSetBasicInfo(enc, &basic_info)); + + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderSetICCProfile(enc, icc.data(), icc.size())); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc, nullptr); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddImageFrame(frame_settings, &format, + (void*)original_bytes.data(), + original_bytes.size())); + JxlEncoderCloseInput(enc); + + std::vector<uint8_t> compressed; + EncodeWithEncoder(enc, &compressed); + JxlEncoderDestroy(enc); + + JxlDecoder* dec = JxlDecoderCreate(nullptr); + EXPECT_NE(nullptr, dec); + + const uint8_t* next_in = compressed.data(); + size_t avail_in = compressed.size(); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents(dec, JXL_DEC_BASIC_INFO | + JXL_DEC_COLOR_ENCODING | + JXL_DEC_FULL_IMAGE)); + + JxlDecoderSetInput(dec, next_in, avail_in); + EXPECT_EQ(JXL_DEC_BASIC_INFO, JxlDecoderProcessInput(dec)); + size_t buffer_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderImageOutBufferSize(dec, &format, &buffer_size)); + EXPECT_EQ(buffer_size, original_bytes.size()); + + JxlBasicInfo info; + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetBasicInfo(dec, &info)); + EXPECT_EQ(xsize, info.xsize); + EXPECT_EQ(ysize, info.ysize); + + EXPECT_EQ(JXL_DEC_COLOR_ENCODING, JxlDecoderProcessInput(dec)); + + size_t dec_icc_size; + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderGetICCProfileSize(dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + &dec_icc_size)); + EXPECT_EQ(icc.size(), dec_icc_size); + std::vector<uint8_t> dec_icc(dec_icc_size); + EXPECT_EQ(JXL_DEC_SUCCESS, JxlDecoderGetColorAsICCProfile( + dec, JXL_COLOR_PROFILE_TARGET_ORIGINAL, + dec_icc.data(), dec_icc.size())); + + std::vector<uint8_t> decoded_bytes(buffer_size); + + EXPECT_EQ(JXL_DEC_NEED_IMAGE_OUT_BUFFER, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetImageOutBuffer(dec, &format, decoded_bytes.data(), + decoded_bytes.size())); + + EXPECT_EQ(JXL_DEC_FULL_IMAGE, JxlDecoderProcessInput(dec)); + + EXPECT_EQ(icc, dec_icc); + + JxlDecoderDestroy(dec); +} + +TEST(RoundtripTest, JXL_TRANSCODE_JPEG_TEST(TestJPEGReconstruction)) { + TEST_LIBJPEG_SUPPORT(); + const std::string jpeg_path = "jxl/flower/flower.png.im_q85_420.jpg"; + const std::vector<uint8_t> orig = jxl::test::ReadTestData(jpeg_path); + jxl::CodecInOut orig_io; + ASSERT_TRUE(SetFromBytes(jxl::Bytes(orig), &orig_io, /*pool=*/nullptr)); + + JxlEncoderPtr enc = JxlEncoderMake(nullptr); + JxlEncoderFrameSettings* frame_settings = + JxlEncoderFrameSettingsCreate(enc.get(), NULL); + + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderUseContainer(enc.get(), JXL_TRUE)); + EXPECT_EQ(JXL_ENC_SUCCESS, JxlEncoderStoreJPEGMetadata(enc.get(), JXL_TRUE)); + EXPECT_EQ(JXL_ENC_SUCCESS, + JxlEncoderAddJPEGFrame(frame_settings, orig.data(), orig.size())); + JxlEncoderCloseInput(enc.get()); + + std::vector<uint8_t> compressed; + EncodeWithEncoder(enc.get(), &compressed); + + JxlDecoderPtr dec = JxlDecoderMake(nullptr); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSubscribeEvents( + dec.get(), JXL_DEC_JPEG_RECONSTRUCTION | JXL_DEC_FULL_IMAGE)); + JxlDecoderSetInput(dec.get(), compressed.data(), compressed.size()); + EXPECT_EQ(JXL_DEC_JPEG_RECONSTRUCTION, JxlDecoderProcessInput(dec.get())); + std::vector<uint8_t> reconstructed_buffer(128); + EXPECT_EQ(JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data(), + reconstructed_buffer.size())); + size_t used = 0; + JxlDecoderStatus dec_process_result = JXL_DEC_JPEG_NEED_MORE_OUTPUT; + while (dec_process_result == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + reconstructed_buffer.resize(reconstructed_buffer.size() * 2); + EXPECT_EQ( + JXL_DEC_SUCCESS, + JxlDecoderSetJPEGBuffer(dec.get(), reconstructed_buffer.data() + used, + reconstructed_buffer.size() - used)); + dec_process_result = JxlDecoderProcessInput(dec.get()); + } + ASSERT_EQ(JXL_DEC_FULL_IMAGE, dec_process_result); + used = reconstructed_buffer.size() - JxlDecoderReleaseJPEGBuffer(dec.get()); + ASSERT_EQ(used, orig.size()); + EXPECT_EQ(0, memcmp(reconstructed_buffer.data(), orig.data(), used)); +} diff --git a/third_party/jpeg-xl/lib/jxl/sanitizers.h b/third_party/jpeg-xl/lib/jxl/sanitizers.h new file mode 100644 index 0000000000..adeaea67ed --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/sanitizers.h @@ -0,0 +1,242 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_SANITIZERS_H_ +#define LIB_JXL_SANITIZERS_H_ + +#include <stddef.h> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/sanitizer_definitions.h" +#include "lib/jxl/image.h" + +#if JXL_MEMORY_SANITIZER +#include <inttypes.h> +#include <stdio.h> + +#include <algorithm> +#include <string> +#include <vector> + +#include "lib/jxl/base/status.h" +#include "sanitizer/msan_interface.h" +#endif + +namespace jxl { +namespace msan { + +#if JXL_MEMORY_SANITIZER + +// Chosen so that kSanitizerSentinel is four copies of kSanitizerSentinelByte. +constexpr uint8_t kSanitizerSentinelByte = 0x48; +constexpr float kSanitizerSentinel = 205089.125f; + +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonMemory(const volatile void* m, + size_t size) { + __msan_poison(m, size); +} + +static JXL_INLINE JXL_MAYBE_UNUSED void UnpoisonMemory(const volatile void* m, + size_t size) { + __msan_unpoison(m, size); +} + +static JXL_INLINE JXL_MAYBE_UNUSED void UnpoisonCStr(const char* c) { + do { + UnpoisonMemory(c, 1); + } while (*c++); +} + +static JXL_INLINE JXL_MAYBE_UNUSED void MemoryIsInitialized( + const volatile void* m, size_t size) { + __msan_check_mem_is_initialized(m, size); +} + +// Mark all the bytes of an image (including padding) as poisoned bytes. +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonImage(const PlaneBase& im) { + PoisonMemory(im.bytes(), im.bytes_per_row() * im.ysize()); +} + +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonImage(const Image3<T>& im) { + PoisonImage(im.Plane(0)); + PoisonImage(im.Plane(1)); + PoisonImage(im.Plane(2)); +} + +// Print the uninitialized regions of an image. +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED void PrintImageUninitialized( + const Plane<T>& im) { + fprintf(stderr, + "Uninitialized regions for image of size %" PRIu64 "x%" PRIu64 ":\n", + static_cast<uint64_t>(im.xsize()), static_cast<uint64_t>(im.ysize())); + + // A segment of uninitialized pixels in a row, in the format [first, second). + typedef std::pair<size_t, size_t> PixelSegment; + + // Helper class to merge and print a list of rows of PixelSegment that may be + // the same over big ranges of rows. This compacts the output to ranges of + // rows like "[y0, y1): [x0, x1) [x2, x3)". + class RowsMerger { + public: + // Add a new row the list of rows. If the row is the same as the previous + // one it will be merged showing a range of rows [y0, y1), but if the new + // row is different the current range of rows (if any) will be printed and a + // new one will be started. + void AddRow(size_t y, std::vector<PixelSegment>&& new_row) { + if (start_y_ != -1 && new_row != segments_) { + PrintRow(y); + } + if (new_row.empty()) { + // Skip ranges with no uninitialized pixels. + start_y_ = -1; + segments_.clear(); + return; + } + if (start_y_ == -1) { + start_y_ = y; + segments_ = std::move(new_row); + } + } + + // Print the contents of the range of rows [start_y_, end_y) if any. + void PrintRow(size_t end_y) { + if (start_y_ == -1) return; + if (segments_.empty()) { + start_y_ = -1; + return; + } + if (end_y - start_y_ > 1) { + fprintf(stderr, " y=[%" PRId64 ", %" PRIu64 "):", + static_cast<int64_t>(start_y_), static_cast<uint64_t>(end_y)); + } else { + fprintf(stderr, " y=[%" PRId64 "]:", static_cast<int64_t>(start_y_)); + } + for (const auto& seg : segments_) { + if (seg.first + 1 == seg.second) { + fprintf(stderr, " [%" PRId64 "]", static_cast<int64_t>(seg.first)); + } else { + fprintf(stderr, " [%" PRId64 ", %" PRIu64 ")", + static_cast<int64_t>(seg.first), + static_cast<uint64_t>(seg.second)); + } + } + fprintf(stderr, "\n"); + start_y_ = -1; + } + + private: + std::vector<PixelSegment> segments_; + // Row number of the first row in the range of rows that have |segments| as + // the undefined segments. + ssize_t start_y_ = -1; + } rows_merger; + + class SegmentsMerger { + public: + void AddValue(size_t x) { + if (row.empty() || row.back().second != x) { + row.emplace_back(x, x + 1); + } else { + row.back().second = x + 1; + } + } + + std::vector<PixelSegment> row; + }; + + for (size_t y = 0; y < im.ysize(); y++) { + auto* row = im.Row(y); + SegmentsMerger seg_merger; + size_t x = 0; + while (x < im.xsize()) { + intptr_t ret = + __msan_test_shadow(row + x, (im.xsize() - x) * sizeof(row[0])); + if (ret < 0) break; + size_t next_x = x + ret / sizeof(row[0]); + seg_merger.AddValue(next_x); + x = next_x + 1; + } + rows_merger.AddRow(y, std::move(seg_merger.row)); + } + rows_merger.PrintRow(im.ysize()); +} + +// Check that all the pixels in the provided rect of the image are initialized +// (not poisoned). If any of the values is poisoned it will abort. +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED void CheckImageInitialized( + const Plane<T>& im, const Rect& r, size_t c, const char* message) { + JXL_ASSERT(r.x0() <= im.xsize()); + JXL_ASSERT(r.x0() + r.xsize() <= im.xsize()); + JXL_ASSERT(r.y0() <= im.ysize()); + JXL_ASSERT(r.y0() + r.ysize() <= im.ysize()); + for (size_t y = r.y0(); y < r.y0() + r.ysize(); y++) { + const auto* row = im.Row(y); + intptr_t ret = __msan_test_shadow(row + r.x0(), sizeof(*row) * r.xsize()); + if (ret != -1) { + JXL_DEBUG( + 1, + "Checking an image of %" PRIu64 " x %" PRIu64 ", rect x0=%" PRIu64 + ", y0=%" PRIu64 + ", " + "xsize=%" PRIu64 ", ysize=%" PRIu64, + static_cast<uint64_t>(im.xsize()), static_cast<uint64_t>(im.ysize()), + static_cast<uint64_t>(r.x0()), static_cast<uint64_t>(r.y0()), + static_cast<uint64_t>(r.xsize()), static_cast<uint64_t>(r.ysize())); + size_t x = ret / sizeof(*row); + JXL_DEBUG(1, + "CheckImageInitialized failed at x=%" PRIu64 ", y=%" PRIu64 + ", c=%" PRIu64 ": %s", + static_cast<uint64_t>(r.x0() + x), static_cast<uint64_t>(y), + static_cast<uint64_t>(c), message ? message : ""); + PrintImageUninitialized(im); + } + // This will report an error if memory is not initialized. + __msan_check_mem_is_initialized(row + r.x0(), sizeof(*row) * r.xsize()); + } +} + +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED void CheckImageInitialized( + const Image3<T>& im, const Rect& r, const char* message) { + for (size_t c = 0; c < 3; c++) { + std::string str_message(message); + str_message += " c=" + std::to_string(c); + CheckImageInitialized(im.Plane(c), r, c, str_message.c_str()); + } +} + +#define JXL_CHECK_IMAGE_INITIALIZED(im, r) \ + ::jxl::msan::CheckImageInitialized(im, r, "im=" #im ", r=" #r); + +#define JXL_CHECK_PLANE_INITIALIZED(im, r, c) \ + ::jxl::msan::CheckImageInitialized(im, r, c, "im=" #im ", r=" #r ", c=" #c); + +#else // JXL_MEMORY_SANITIZER + +// In non-msan mode these functions don't use volatile since it is not needed +// for the empty functions. + +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonMemory(const void*, size_t) {} +static JXL_INLINE JXL_MAYBE_UNUSED void UnpoisonMemory(const void*, size_t) {} +static JXL_INLINE JXL_MAYBE_UNUSED void UnpoisonCStr(const char*) {} +static JXL_INLINE JXL_MAYBE_UNUSED void MemoryIsInitialized(const void*, + size_t) {} + +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonImage(const PlaneBase& im) {} +template <typename T> +static JXL_INLINE JXL_MAYBE_UNUSED void PoisonImage(const Plane<T>& im) {} + +#define JXL_CHECK_IMAGE_INITIALIZED(im, r) +#define JXL_CHECK_PLANE_INITIALIZED(im, r, c) + +#endif + +} // namespace msan +} // namespace jxl + +#endif // LIB_JXL_SANITIZERS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/simd_util-inl.h b/third_party/jpeg-xl/lib/jxl/simd_util-inl.h new file mode 100644 index 0000000000..77b207ffe8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/simd_util-inl.h @@ -0,0 +1,349 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Misc utilities for SIMD operations + +#if defined(LIB_JXL_SIMD_UTIL_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_SIMD_UTIL_INL_H_ +#undef LIB_JXL_SIMD_UTIL_INL_H_ +#else +#define LIB_JXL_SIMD_UTIL_INL_H_ +#endif + +#include <hwy/highway.h> + +#include "lib/jxl/base/compiler_specific.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +#if HWY_CAP_GE512 +using hwy::HWY_NAMESPACE::Half; +using hwy::HWY_NAMESPACE::Vec; +template <size_t i, class DF, class V> +HWY_INLINE Vec<Half<Half<DF>>> Quarter(const DF df, V v) { + using HF = Half<DF>; + using HHF = Half<HF>; + auto half = i >= 2 ? UpperHalf(HF(), v) : LowerHalf(HF(), v); + return i & 1 ? UpperHalf(HHF(), half) : LowerHalf(HHF(), half); +} + +template <class DF, class V> +HWY_INLINE Vec<DF> Concat4(const DF df, V v0, V v1, V v2, V v3) { + using HF = Half<DF>; + return Combine(DF(), Combine(HF(), v3, v2), Combine(HF(), v1, v0)); +} + +#endif + +// Stores v0[0], v1[0], v0[1], v1[1], ... to mem, in this order. Mem must be +// aligned. +template <class DF, class V, typename T> +void StoreInterleaved(const DF df, V v0, V v1, T* mem) { + static_assert(sizeof(T) == 4, "only use StoreInterleaved for 4-byte types"); +#if HWY_TARGET == HWY_SCALAR + Store(v0, df, mem); + Store(v1, df, mem + 1); +#elif !HWY_CAP_GE256 + Store(InterleaveLower(df, v0, v1), df, mem); + Store(InterleaveUpper(df, v0, v1), df, mem + Lanes(df)); +#else + if (!HWY_CAP_GE512 || Lanes(df) == 8) { + auto t0 = InterleaveLower(df, v0, v1); + auto t1 = InterleaveUpper(df, v0, v1); + Store(ConcatLowerLower(df, t1, t0), df, mem); + Store(ConcatUpperUpper(df, t1, t0), df, mem + Lanes(df)); + } else { +#if HWY_CAP_GE512 + auto t0 = InterleaveLower(df, v0, v1); + auto t1 = InterleaveUpper(df, v0, v1); + Store(Concat4(df, Quarter<0>(df, t0), Quarter<0>(df, t1), + Quarter<1>(df, t0), Quarter<1>(df, t1)), + df, mem); + Store(Concat4(df, Quarter<2>(df, t0), Quarter<2>(df, t1), + Quarter<3>(df, t0), Quarter<3>(df, t1)), + df, mem + Lanes(df)); +#endif + } +#endif +} + +// Stores v0[0], v1[0], v2[0], v3[0], v0[1] ... to mem, in this order. Mem must +// be aligned. +template <class DF, class V, typename T> +void StoreInterleaved(const DF df, V v0, V v1, V v2, V v3, T* mem) { + static_assert(sizeof(T) == 4, "only use StoreInterleaved for 4-byte types"); +#if HWY_TARGET == HWY_SCALAR + Store(v0, df, mem); + Store(v1, df, mem + 1); + Store(v2, df, mem + 2); + Store(v3, df, mem + 3); +#elif !HWY_CAP_GE256 + auto t0 = InterleaveLower(df, v0, v2); + auto t1 = InterleaveLower(df, v1, v3); + auto t2 = InterleaveUpper(df, v0, v2); + auto t3 = InterleaveUpper(df, v1, v3); + Store(InterleaveLower(df, t0, t1), df, mem); + Store(InterleaveUpper(df, t0, t1), df, mem + Lanes(df)); + Store(InterleaveLower(df, t2, t3), df, mem + 2 * Lanes(df)); + Store(InterleaveUpper(df, t2, t3), df, mem + 3 * Lanes(df)); +#elif !HWY_CAP_GE512 + auto t0 = InterleaveLower(df, v0, v2); + auto t1 = InterleaveLower(df, v1, v3); + auto t2 = InterleaveUpper(df, v0, v2); + auto t3 = InterleaveUpper(df, v1, v3); + + auto m0 = InterleaveLower(df, t0, t1); + auto m1 = InterleaveUpper(df, t0, t1); + auto m2 = InterleaveLower(df, t2, t3); + auto m3 = InterleaveUpper(df, t2, t3); + + Store(ConcatLowerLower(df, m1, m0), df, mem); + Store(ConcatLowerLower(df, m3, m2), df, mem + Lanes(df)); + Store(ConcatUpperUpper(df, m1, m0), df, mem + 2 * Lanes(df)); + Store(ConcatUpperUpper(df, m3, m2), df, mem + 3 * Lanes(df)); +#else + auto t0 = InterleaveLower(df, v0, v2); + auto t1 = InterleaveLower(df, v1, v3); + auto t2 = InterleaveUpper(df, v0, v2); + auto t3 = InterleaveUpper(df, v1, v3); + + auto m0 = InterleaveLower(df, t0, t1); + auto m1 = InterleaveUpper(df, t0, t1); + auto m2 = InterleaveLower(df, t2, t3); + auto m3 = InterleaveUpper(df, t2, t3); + + Store(Concat4(df, Quarter<0>(df, m0), Quarter<0>(df, m1), Quarter<0>(df, m2), + Quarter<0>(df, m3)), + df, mem); + Store(Concat4(df, Quarter<1>(df, m0), Quarter<1>(df, m1), Quarter<1>(df, m2), + Quarter<1>(df, m3)), + df, mem + Lanes(df)); + Store(Concat4(df, Quarter<2>(df, m0), Quarter<2>(df, m1), Quarter<2>(df, m2), + Quarter<2>(df, m3)), + df, mem + 2 * Lanes(df)); + Store(Concat4(df, Quarter<3>(df, m0), Quarter<3>(df, m1), Quarter<3>(df, m2), + Quarter<3>(df, m3)), + df, mem + 3 * Lanes(df)); +#endif +} + +// Stores v0[0], v1[0], v2[0], v3[0], v4[0], v5[0], v6[0], v7[0], v0[1] ... to +// mem, in this order. Mem must be aligned. +template <class DF, class V> +void StoreInterleaved(const DF df, V v0, V v1, V v2, V v3, V v4, V v5, V v6, + V v7, float* mem) { +#if HWY_TARGET == HWY_SCALAR + Store(v0, df, mem); + Store(v1, df, mem + 1); + Store(v2, df, mem + 2); + Store(v3, df, mem + 3); + Store(v4, df, mem + 4); + Store(v5, df, mem + 5); + Store(v6, df, mem + 6); + Store(v7, df, mem + 7); +#elif !HWY_CAP_GE256 + auto t0 = InterleaveLower(df, v0, v4); + auto t1 = InterleaveLower(df, v1, v5); + auto t2 = InterleaveLower(df, v2, v6); + auto t3 = InterleaveLower(df, v3, v7); + auto t4 = InterleaveUpper(df, v0, v4); + auto t5 = InterleaveUpper(df, v1, v5); + auto t6 = InterleaveUpper(df, v2, v6); + auto t7 = InterleaveUpper(df, v3, v7); + + auto w0 = InterleaveLower(df, t0, t2); + auto w1 = InterleaveLower(df, t1, t3); + auto w2 = InterleaveUpper(df, t0, t2); + auto w3 = InterleaveUpper(df, t1, t3); + auto w4 = InterleaveLower(df, t4, t6); + auto w5 = InterleaveLower(df, t5, t7); + auto w6 = InterleaveUpper(df, t4, t6); + auto w7 = InterleaveUpper(df, t5, t7); + + Store(InterleaveLower(df, w0, w1), df, mem); + Store(InterleaveUpper(df, w0, w1), df, mem + Lanes(df)); + Store(InterleaveLower(df, w2, w3), df, mem + 2 * Lanes(df)); + Store(InterleaveUpper(df, w2, w3), df, mem + 3 * Lanes(df)); + Store(InterleaveLower(df, w4, w5), df, mem + 4 * Lanes(df)); + Store(InterleaveUpper(df, w4, w5), df, mem + 5 * Lanes(df)); + Store(InterleaveLower(df, w6, w7), df, mem + 6 * Lanes(df)); + Store(InterleaveUpper(df, w6, w7), df, mem + 7 * Lanes(df)); +#elif !HWY_CAP_GE512 + auto t0 = InterleaveLower(df, v0, v4); + auto t1 = InterleaveLower(df, v1, v5); + auto t2 = InterleaveLower(df, v2, v6); + auto t3 = InterleaveLower(df, v3, v7); + auto t4 = InterleaveUpper(df, v0, v4); + auto t5 = InterleaveUpper(df, v1, v5); + auto t6 = InterleaveUpper(df, v2, v6); + auto t7 = InterleaveUpper(df, v3, v7); + + auto w0 = InterleaveLower(df, t0, t2); + auto w1 = InterleaveLower(df, t1, t3); + auto w2 = InterleaveUpper(df, t0, t2); + auto w3 = InterleaveUpper(df, t1, t3); + auto w4 = InterleaveLower(df, t4, t6); + auto w5 = InterleaveLower(df, t5, t7); + auto w6 = InterleaveUpper(df, t4, t6); + auto w7 = InterleaveUpper(df, t5, t7); + + auto m0 = InterleaveLower(df, w0, w1); + auto m1 = InterleaveUpper(df, w0, w1); + auto m2 = InterleaveLower(df, w2, w3); + auto m3 = InterleaveUpper(df, w2, w3); + auto m4 = InterleaveLower(df, w4, w5); + auto m5 = InterleaveUpper(df, w4, w5); + auto m6 = InterleaveLower(df, w6, w7); + auto m7 = InterleaveUpper(df, w6, w7); + + Store(ConcatLowerLower(df, m1, m0), df, mem); + Store(ConcatLowerLower(df, m3, m2), df, mem + Lanes(df)); + Store(ConcatLowerLower(df, m5, m4), df, mem + 2 * Lanes(df)); + Store(ConcatLowerLower(df, m7, m6), df, mem + 3 * Lanes(df)); + Store(ConcatUpperUpper(df, m1, m0), df, mem + 4 * Lanes(df)); + Store(ConcatUpperUpper(df, m3, m2), df, mem + 5 * Lanes(df)); + Store(ConcatUpperUpper(df, m5, m4), df, mem + 6 * Lanes(df)); + Store(ConcatUpperUpper(df, m7, m6), df, mem + 7 * Lanes(df)); +#else + auto t0 = InterleaveLower(df, v0, v4); + auto t1 = InterleaveLower(df, v1, v5); + auto t2 = InterleaveLower(df, v2, v6); + auto t3 = InterleaveLower(df, v3, v7); + auto t4 = InterleaveUpper(df, v0, v4); + auto t5 = InterleaveUpper(df, v1, v5); + auto t6 = InterleaveUpper(df, v2, v6); + auto t7 = InterleaveUpper(df, v3, v7); + + auto w0 = InterleaveLower(df, t0, t2); + auto w1 = InterleaveLower(df, t1, t3); + auto w2 = InterleaveUpper(df, t0, t2); + auto w3 = InterleaveUpper(df, t1, t3); + auto w4 = InterleaveLower(df, t4, t6); + auto w5 = InterleaveLower(df, t5, t7); + auto w6 = InterleaveUpper(df, t4, t6); + auto w7 = InterleaveUpper(df, t5, t7); + + auto m0 = InterleaveLower(df, w0, w1); + auto m1 = InterleaveUpper(df, w0, w1); + auto m2 = InterleaveLower(df, w2, w3); + auto m3 = InterleaveUpper(df, w2, w3); + auto m4 = InterleaveLower(df, w4, w5); + auto m5 = InterleaveUpper(df, w4, w5); + auto m6 = InterleaveLower(df, w6, w7); + auto m7 = InterleaveUpper(df, w6, w7); + + Store(Concat4(df, Quarter<0>(df, m0), Quarter<0>(df, m1), Quarter<0>(df, m2), + Quarter<0>(df, m3)), + df, mem); + Store(Concat4(df, Quarter<0>(df, m4), Quarter<0>(df, m5), Quarter<0>(df, m6), + Quarter<0>(df, m7)), + df, mem + Lanes(df)); + Store(Concat4(df, Quarter<1>(df, m0), Quarter<1>(df, m1), Quarter<1>(df, m2), + Quarter<1>(df, m3)), + df, mem + 2 * Lanes(df)); + Store(Concat4(df, Quarter<1>(df, m4), Quarter<1>(df, m5), Quarter<1>(df, m6), + Quarter<1>(df, m7)), + df, mem + 3 * Lanes(df)); + Store(Concat4(df, Quarter<2>(df, m0), Quarter<2>(df, m1), Quarter<2>(df, m2), + Quarter<2>(df, m3)), + df, mem + 4 * Lanes(df)); + Store(Concat4(df, Quarter<2>(df, m4), Quarter<2>(df, m5), Quarter<2>(df, m6), + Quarter<2>(df, m7)), + df, mem + 5 * Lanes(df)); + Store(Concat4(df, Quarter<3>(df, m0), Quarter<3>(df, m1), Quarter<3>(df, m2), + Quarter<3>(df, m3)), + df, mem + 6 * Lanes(df)); + Store(Concat4(df, Quarter<3>(df, m4), Quarter<3>(df, m5), Quarter<3>(df, m6), + Quarter<3>(df, m7)), + df, mem + 7 * Lanes(df)); +#endif +} + +#if HWY_CAP_GE256 +JXL_INLINE void Transpose8x8Block(const int32_t* JXL_RESTRICT from, + int32_t* JXL_RESTRICT to, size_t fromstride) { + const HWY_CAPPED(int32_t, 8) d; + auto i0 = Load(d, from); + auto i1 = Load(d, from + 1 * fromstride); + auto i2 = Load(d, from + 2 * fromstride); + auto i3 = Load(d, from + 3 * fromstride); + auto i4 = Load(d, from + 4 * fromstride); + auto i5 = Load(d, from + 5 * fromstride); + auto i6 = Load(d, from + 6 * fromstride); + auto i7 = Load(d, from + 7 * fromstride); + + const auto q0 = InterleaveLower(d, i0, i2); + const auto q1 = InterleaveLower(d, i1, i3); + const auto q2 = InterleaveUpper(d, i0, i2); + const auto q3 = InterleaveUpper(d, i1, i3); + const auto q4 = InterleaveLower(d, i4, i6); + const auto q5 = InterleaveLower(d, i5, i7); + const auto q6 = InterleaveUpper(d, i4, i6); + const auto q7 = InterleaveUpper(d, i5, i7); + + const auto r0 = InterleaveLower(d, q0, q1); + const auto r1 = InterleaveUpper(d, q0, q1); + const auto r2 = InterleaveLower(d, q2, q3); + const auto r3 = InterleaveUpper(d, q2, q3); + const auto r4 = InterleaveLower(d, q4, q5); + const auto r5 = InterleaveUpper(d, q4, q5); + const auto r6 = InterleaveLower(d, q6, q7); + const auto r7 = InterleaveUpper(d, q6, q7); + + i0 = ConcatLowerLower(d, r4, r0); + i1 = ConcatLowerLower(d, r5, r1); + i2 = ConcatLowerLower(d, r6, r2); + i3 = ConcatLowerLower(d, r7, r3); + i4 = ConcatUpperUpper(d, r4, r0); + i5 = ConcatUpperUpper(d, r5, r1); + i6 = ConcatUpperUpper(d, r6, r2); + i7 = ConcatUpperUpper(d, r7, r3); + + Store(i0, d, to); + Store(i1, d, to + 1 * 8); + Store(i2, d, to + 2 * 8); + Store(i3, d, to + 3 * 8); + Store(i4, d, to + 4 * 8); + Store(i5, d, to + 5 * 8); + Store(i6, d, to + 6 * 8); + Store(i7, d, to + 7 * 8); +} +#elif HWY_TARGET != HWY_SCALAR +JXL_INLINE void Transpose8x8Block(const int32_t* JXL_RESTRICT from, + int32_t* JXL_RESTRICT to, size_t fromstride) { + const HWY_CAPPED(int32_t, 4) d; + for (size_t n = 0; n < 8; n += 4) { + for (size_t m = 0; m < 8; m += 4) { + auto p0 = Load(d, from + n * fromstride + m); + auto p1 = Load(d, from + (n + 1) * fromstride + m); + auto p2 = Load(d, from + (n + 2) * fromstride + m); + auto p3 = Load(d, from + (n + 3) * fromstride + m); + const auto q0 = InterleaveLower(d, p0, p2); + const auto q1 = InterleaveLower(d, p1, p3); + const auto q2 = InterleaveUpper(d, p0, p2); + const auto q3 = InterleaveUpper(d, p1, p3); + + const auto r0 = InterleaveLower(d, q0, q1); + const auto r1 = InterleaveUpper(d, q0, q1); + const auto r2 = InterleaveLower(d, q2, q3); + const auto r3 = InterleaveUpper(d, q2, q3); + Store(r0, d, to + m * 8 + n); + Store(r1, d, to + (1 + m) * 8 + n); + Store(r2, d, to + (2 + m) * 8 + n); + Store(r3, d, to + (3 + m) * 8 + n); + } + } +} + +#endif + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_SIMD_UTIL_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/simd_util.cc b/third_party/jpeg-xl/lib/jxl/simd_util.cc new file mode 100644 index 0000000000..a3971ff900 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/simd_util.cc @@ -0,0 +1,40 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/simd_util.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/simd_util.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +size_t MaxVectorSize() { + HWY_FULL(float) df; + return Lanes(df) * sizeof(float); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(MaxVectorSize); + +size_t MaxVectorSize() { + // Ideally HWY framework should provide us this value. + // Less than ideal is to check all available targets and choose maximal. + // As for now, we just ask current active target, assuming it won't change. + return HWY_DYNAMIC_DISPATCH(MaxVectorSize)(); +} + +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/simd_util.h b/third_party/jpeg-xl/lib/jxl/simd_util.h new file mode 100644 index 0000000000..84938a931a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/simd_util.h @@ -0,0 +1,17 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_SIMD_UTIL_H_ +#define LIB_JXL_SIMD_UTIL_H_ +#include <stddef.h> + +namespace jxl { + +// Maximal vector size in bytes. +size_t MaxVectorSize(); + +} // namespace jxl + +#endif // LIB_JXL_SIMD_UTIL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/simd_util_test.cc b/third_party/jpeg-xl/lib/jxl/simd_util_test.cc new file mode 100644 index 0000000000..94f7788a8e --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/simd_util_test.cc @@ -0,0 +1,84 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/testing.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/simd_util_test.cc" +#include <hwy/foreach_target.h> + +#include "lib/jxl/simd_util-inl.h" + +// Test utils +#include <hwy/highway.h> +#include <hwy/tests/hwy_gtest.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +HWY_NOINLINE void TestInterleave2() { + HWY_FULL(float) d; + auto vec1 = Iota(d, 0 * 128.0); + auto vec2 = Iota(d, 1 * 128.0); + HWY_ALIGN float mem[MaxLanes(d) * 2]; + StoreInterleaved(d, vec1, vec2, mem); + for (size_t i = 0; i < Lanes(d); i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ(mem[2 * i + j], j * 128 + i) << "i: " << i << " j: " << j; + } + } +} +HWY_NOINLINE void TestInterleave4() { + HWY_FULL(float) d; + auto vec1 = Iota(d, 0 * 128.0); + auto vec2 = Iota(d, 1 * 128.0); + auto vec3 = Iota(d, 2 * 128.0); + auto vec4 = Iota(d, 3 * 128.0); + HWY_ALIGN float mem[MaxLanes(d) * 4]; + StoreInterleaved(d, vec1, vec2, vec3, vec4, mem); + for (size_t i = 0; i < Lanes(d); i++) { + for (size_t j = 0; j < 4; j++) { + EXPECT_EQ(mem[4 * i + j], j * 128 + i) << "i: " << i << " j: " << j; + } + } +} +HWY_NOINLINE void TestInterleave8() { + HWY_FULL(float) d; + auto vec1 = Iota(d, 0 * 128.0); + auto vec2 = Iota(d, 1 * 128.0); + auto vec3 = Iota(d, 2 * 128.0); + auto vec4 = Iota(d, 3 * 128.0); + auto vec5 = Iota(d, 4 * 128.0); + auto vec6 = Iota(d, 5 * 128.0); + auto vec7 = Iota(d, 6 * 128.0); + auto vec8 = Iota(d, 7 * 128.0); + HWY_ALIGN float mem[MaxLanes(d) * 8]; + StoreInterleaved(d, vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8, mem); + for (size_t i = 0; i < Lanes(d); i++) { + for (size_t j = 0; j < 8; j++) { + EXPECT_EQ(mem[8 * i + j], j * 128 + i) << "i: " << i << " j: " << j; + } + } +} + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class SimdUtilTargetTest : public hwy::TestWithParamTarget {}; +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(SimdUtilTargetTest); + +HWY_EXPORT_AND_TEST_P(SimdUtilTargetTest, TestInterleave2); +HWY_EXPORT_AND_TEST_P(SimdUtilTargetTest, TestInterleave4); +HWY_EXPORT_AND_TEST_P(SimdUtilTargetTest, TestInterleave8); + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/speed_tier_test.cc b/third_party/jpeg-xl/lib/jxl/speed_tier_test.cc new file mode 100644 index 0000000000..7874bdc158 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/speed_tier_test.cc @@ -0,0 +1,118 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <jxl/encode.h> +#include <jxl/types.h> + +#include <cstddef> +#include <cstdint> +#include <ios> +#include <ostream> +#include <string> +#include <vector> + +#include "lib/extras/dec/jxl.h" +#include "lib/extras/enc/jxl.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/test_image.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +struct SpeedTierTestParams { + explicit SpeedTierTestParams(const SpeedTier speed_tier, + const bool shrink8 = false) + : speed_tier(speed_tier), shrink8(shrink8) {} + SpeedTier speed_tier; + bool shrink8; +}; + +std::ostream& operator<<(std::ostream& os, SpeedTierTestParams params) { + auto previous_flags = os.flags(); + os << std::boolalpha; + os << "SpeedTierTestParams{" << static_cast<size_t>(params.speed_tier) + << ", /*shrink8=*/" << params.shrink8 << "}"; + os.flags(previous_flags); + return os; +} + +class SpeedTierTest : public testing::TestWithParam<SpeedTierTestParams> {}; + +JXL_GTEST_INSTANTIATE_TEST_SUITE_P( + SpeedTierTestInstantiation, SpeedTierTest, + testing::Values(SpeedTierTestParams{SpeedTier::kCheetah, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kCheetah, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kThunder, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kThunder, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kLightning, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kLightning, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kFalcon, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kFalcon, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kHare, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kHare, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kWombat, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kWombat, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kSquirrel, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kSquirrel, + /*shrink8=*/false}, + SpeedTierTestParams{SpeedTier::kKitten, + /*shrink8=*/false}, + // Only downscaled image for Tortoise mode. + SpeedTierTestParams{SpeedTier::kTortoise, + /*shrink8=*/true}, + SpeedTierTestParams{SpeedTier::kGlacier, + /*shrink8=*/true})); + +TEST_P(SpeedTierTest, Roundtrip) { + const SpeedTierTestParams& params = GetParam(); + test::ThreadPoolForTests pool(8); + const std::vector<uint8_t> orig = jxl::test::ReadTestData( + "external/wesaturate/500px/u76c0g_bliznaca_srgb8.png"); + test::TestImage t; + t.DecodeFromBytes(orig).ClearMetadata(); + if (params.speed_tier == SpeedTier::kGlacier) { + // just a few pixels will already take enough time at this setting + t.SetDimensions(8, 8); + } else if (params.shrink8) { + t.SetDimensions(t.ppf().xsize() / 8, t.ppf().ysize() / 8); + } + + extras::JXLCompressParams cparams; + cparams.distance = 1.0f; + cparams.allow_expert_options = true; + cparams.AddOption(JXL_ENC_FRAME_SETTING_EFFORT, + 10 - static_cast<int>(params.speed_tier)); + extras::JXLDecompressParams dparams; + dparams.accepted_formats = {{3, JXL_TYPE_UINT16, JXL_LITTLE_ENDIAN, 0}}; + + { + extras::PackedPixelFile ppf_out; + test::Roundtrip(t.ppf(), cparams, dparams, nullptr, &ppf_out); + EXPECT_LE(test::ButteraugliDistance(t.ppf(), ppf_out), 1.6); + } + if (params.shrink8) { + cparams.distance = 0.0f; + extras::PackedPixelFile ppf_out; + test::Roundtrip(t.ppf(), cparams, dparams, nullptr, &ppf_out); + EXPECT_EQ(0.0f, test::ComputeDistance2(t.ppf(), ppf_out)); + } +} +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/splines.cc b/third_party/jpeg-xl/lib/jxl/splines.cc new file mode 100644 index 0000000000..acbaf38428 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/splines.cc @@ -0,0 +1,721 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/splines.h" + +#include <algorithm> +#include <cinttypes> +#include <cmath> +#include <limits> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/common.h" // JXL_HIGH_PRECISION +#include "lib/jxl/dct_scales.h" +#include "lib/jxl/dec_ans.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/pack_signed.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/splines.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/base/fast_math-inl.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::MulSub; +using hwy::HWY_NAMESPACE::Sqrt; +using hwy::HWY_NAMESPACE::Sub; + +// Given a set of DCT coefficients, this returns the result of performing cosine +// interpolation on the original samples. +float ContinuousIDCT(const float dct[32], const float t) { + // We compute here the DCT-3 of the `dct` vector, rescaled by a factor of + // sqrt(32). This is such that an input vector vector {x, 0, ..., 0} produces + // a constant result of x. dct[0] was scaled in Dequantize() to allow uniform + // treatment of all the coefficients. + constexpr float kMultipliers[32] = { + kPi / 32 * 0, kPi / 32 * 1, kPi / 32 * 2, kPi / 32 * 3, kPi / 32 * 4, + kPi / 32 * 5, kPi / 32 * 6, kPi / 32 * 7, kPi / 32 * 8, kPi / 32 * 9, + kPi / 32 * 10, kPi / 32 * 11, kPi / 32 * 12, kPi / 32 * 13, kPi / 32 * 14, + kPi / 32 * 15, kPi / 32 * 16, kPi / 32 * 17, kPi / 32 * 18, kPi / 32 * 19, + kPi / 32 * 20, kPi / 32 * 21, kPi / 32 * 22, kPi / 32 * 23, kPi / 32 * 24, + kPi / 32 * 25, kPi / 32 * 26, kPi / 32 * 27, kPi / 32 * 28, kPi / 32 * 29, + kPi / 32 * 30, kPi / 32 * 31, + }; + HWY_CAPPED(float, 32) df; + auto result = Zero(df); + const auto tandhalf = Set(df, t + 0.5f); + for (int i = 0; i < 32; i += Lanes(df)) { + auto cos_arg = Mul(LoadU(df, kMultipliers + i), tandhalf); + auto cos = FastCosf(df, cos_arg); + auto local_res = Mul(LoadU(df, dct + i), cos); + result = MulAdd(Set(df, kSqrt2), local_res, result); + } + return GetLane(SumOfLanes(df, result)); +} + +template <typename DF> +void DrawSegment(DF df, const SplineSegment& segment, const bool add, + const size_t y, const size_t x, float* JXL_RESTRICT rows[3]) { + Rebind<int32_t, DF> di; + const auto inv_sigma = Set(df, segment.inv_sigma); + const auto half = Set(df, 0.5f); + const auto one_over_2s2 = Set(df, 0.353553391f); + const auto sigma_over_4_times_intensity = + Set(df, segment.sigma_over_4_times_intensity); + const auto dx = Sub(ConvertTo(df, Iota(di, x)), Set(df, segment.center_x)); + const auto dy = Set(df, y - segment.center_y); + const auto sqd = MulAdd(dx, dx, Mul(dy, dy)); + const auto distance = Sqrt(sqd); + const auto one_dimensional_factor = + Sub(FastErff(df, Mul(MulAdd(distance, half, one_over_2s2), inv_sigma)), + FastErff(df, Mul(MulSub(distance, half, one_over_2s2), inv_sigma))); + auto local_intensity = + Mul(sigma_over_4_times_intensity, + Mul(one_dimensional_factor, one_dimensional_factor)); + for (size_t c = 0; c < 3; ++c) { + const auto cm = Set(df, add ? segment.color[c] : -segment.color[c]); + const auto in = LoadU(df, rows[c] + x); + StoreU(MulAdd(cm, local_intensity, in), df, rows[c] + x); + } +} + +void DrawSegment(const SplineSegment& segment, const bool add, const size_t y, + const ssize_t x0, ssize_t x1, float* JXL_RESTRICT rows[3]) { + ssize_t x = + std::max<ssize_t>(x0, segment.center_x - segment.maximum_distance + 0.5f); + // one-past-the-end + x1 = + std::min<ssize_t>(x1, segment.center_x + segment.maximum_distance + 1.5f); + HWY_FULL(float) df; + for (; x + static_cast<ssize_t>(Lanes(df)) <= x1; x += Lanes(df)) { + DrawSegment(df, segment, add, y, x, rows); + } + for (; x < x1; ++x) { + DrawSegment(HWY_CAPPED(float, 1)(), segment, add, y, x, rows); + } +} + +void ComputeSegments(const Spline::Point& center, const float intensity, + const float color[3], const float sigma, + std::vector<SplineSegment>& segments, + std::vector<std::pair<size_t, size_t>>& segments_by_y) { + // Sanity check sigma, inverse sigma and intensity + if (!(std::isfinite(sigma) && sigma != 0.0f && std::isfinite(1.0f / sigma) && + std::isfinite(intensity))) { + return; + } +#if JXL_HIGH_PRECISION + constexpr float kDistanceExp = 5; +#else + // About 30% faster. + constexpr float kDistanceExp = 3; +#endif + // We cap from below colors to at least 0.01. + float max_color = 0.01f; + for (size_t c = 0; c < 3; c++) { + max_color = std::max(max_color, std::abs(color[c] * intensity)); + } + // Distance beyond which max_color*intensity*exp(-d^2 / (2 * sigma^2)) drops + // below 10^-kDistanceExp. + const float maximum_distance = + std::sqrt(-2 * sigma * sigma * + (std::log(0.1) * kDistanceExp - std::log(max_color))); + SplineSegment segment; + segment.center_y = center.y; + segment.center_x = center.x; + memcpy(segment.color, color, sizeof(segment.color)); + segment.inv_sigma = 1.0f / sigma; + segment.sigma_over_4_times_intensity = .25f * sigma * intensity; + segment.maximum_distance = maximum_distance; + ssize_t y0 = center.y - maximum_distance + .5f; + ssize_t y1 = center.y + maximum_distance + 1.5f; // one-past-the-end + for (ssize_t y = std::max<ssize_t>(y0, 0); y < y1; y++) { + segments_by_y.emplace_back(y, segments.size()); + } + segments.push_back(segment); +} + +void DrawSegments(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, const Rect& image_rect, + const bool add, const SplineSegment* segments, + const size_t* segment_indices, + const size_t* segment_y_start) { + JXL_ASSERT(image_rect.ysize() == 1); + float* JXL_RESTRICT rows[3] = {row_x - image_rect.x0(), + row_y - image_rect.x0(), + row_b - image_rect.x0()}; + size_t y = image_rect.y0(); + for (size_t i = segment_y_start[y]; i < segment_y_start[y + 1]; i++) { + DrawSegment(segments[segment_indices[i]], add, y, image_rect.x0(), + image_rect.x0() + image_rect.xsize(), rows); + } +} + +void SegmentsFromPoints( + const Spline& spline, + const std::vector<std::pair<Spline::Point, float>>& points_to_draw, + const float arc_length, std::vector<SplineSegment>& segments, + std::vector<std::pair<size_t, size_t>>& segments_by_y) { + const float inv_arc_length = 1.0f / arc_length; + int k = 0; + for (const auto& point_to_draw : points_to_draw) { + const Spline::Point& point = point_to_draw.first; + const float multiplier = point_to_draw.second; + const float progress_along_arc = + std::min(1.f, (k * kDesiredRenderingDistance) * inv_arc_length); + ++k; + float color[3]; + for (size_t c = 0; c < 3; ++c) { + color[c] = + ContinuousIDCT(spline.color_dct[c], (32 - 1) * progress_along_arc); + } + const float sigma = + ContinuousIDCT(spline.sigma_dct, (32 - 1) * progress_along_arc); + ComputeSegments(point, multiplier, color, sigma, segments, segments_by_y); + } +} +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(SegmentsFromPoints); +HWY_EXPORT(DrawSegments); + +namespace { + +// It is not in spec, but reasonable limit to avoid overflows. +template <typename T> +Status ValidateSplinePointPos(const T& x, const T& y) { + constexpr T kSplinePosLimit = 1u << 23; + if ((x >= kSplinePosLimit) || (x <= -kSplinePosLimit) || + (y >= kSplinePosLimit) || (y <= -kSplinePosLimit)) { + return JXL_FAILURE("Spline coordinates out of bounds"); + } + return true; +} + +// Maximum number of spline control points per frame is +// std::min(kMaxNumControlPoints, xsize * ysize / 2) +constexpr size_t kMaxNumControlPoints = 1u << 20u; +constexpr size_t kMaxNumControlPointsPerPixelRatio = 2; + +float AdjustedQuant(const int32_t adjustment) { + return (adjustment >= 0) ? (1.f + .125f * adjustment) + : 1.f / (1.f - .125f * adjustment); +} + +float InvAdjustedQuant(const int32_t adjustment) { + return (adjustment >= 0) ? 1.f / (1.f + .125f * adjustment) + : (1.f - .125f * adjustment); +} + +// X, Y, B, sigma. +static constexpr float kChannelWeight[] = {0.0042f, 0.075f, 0.07f, .3333f}; + +Status DecodeAllStartingPoints(std::vector<Spline::Point>* const points, + BitReader* const br, ANSSymbolReader* reader, + const std::vector<uint8_t>& context_map, + const size_t num_splines) { + points->clear(); + points->reserve(num_splines); + int64_t last_x = 0; + int64_t last_y = 0; + for (size_t i = 0; i < num_splines; i++) { + int64_t x = + reader->ReadHybridUint(kStartingPositionContext, br, context_map); + int64_t y = + reader->ReadHybridUint(kStartingPositionContext, br, context_map); + if (i != 0) { + x = UnpackSigned(x) + last_x; + y = UnpackSigned(y) + last_y; + } + JXL_RETURN_IF_ERROR(ValidateSplinePointPos(x, y)); + points->emplace_back(static_cast<float>(x), static_cast<float>(y)); + last_x = x; + last_y = y; + } + return true; +} + +struct Vector { + float x, y; + Vector operator-() const { return {-x, -y}; } + Vector operator+(const Vector& other) const { + return {x + other.x, y + other.y}; + } + float SquaredNorm() const { return x * x + y * y; } +}; +Vector operator*(const float k, const Vector& vec) { + return {k * vec.x, k * vec.y}; +} + +Spline::Point operator+(const Spline::Point& p, const Vector& vec) { + return {p.x + vec.x, p.y + vec.y}; +} +Vector operator-(const Spline::Point& a, const Spline::Point& b) { + return {a.x - b.x, a.y - b.y}; +} + +// TODO(eustas): avoid making a copy of "points". +void DrawCentripetalCatmullRomSpline(std::vector<Spline::Point> points, + std::vector<Spline::Point>& result) { + if (points.empty()) return; + if (points.size() == 1) { + result.push_back(points[0]); + return; + } + // Number of points to compute between each control point. + static constexpr int kNumPoints = 16; + result.reserve((points.size() - 1) * kNumPoints + 1); + points.insert(points.begin(), points[0] + (points[0] - points[1])); + points.push_back(points[points.size() - 1] + + (points[points.size() - 1] - points[points.size() - 2])); + // points has at least 4 elements at this point. + for (size_t start = 0; start < points.size() - 3; ++start) { + // 4 of them are used, and we draw from p[1] to p[2]. + const Spline::Point* const p = &points[start]; + result.push_back(p[1]); + float d[3]; + float t[4]; + t[0] = 0; + for (int k = 0; k < 3; ++k) { + // TODO(eustas): for each segment delta is calculated 3 times... + // TODO(eustas): restrict d[k] with reasonable limit and spec it. + d[k] = std::sqrt(hypotf(p[k + 1].x - p[k].x, p[k + 1].y - p[k].y)); + t[k + 1] = t[k] + d[k]; + } + for (int i = 1; i < kNumPoints; ++i) { + const float tt = d[0] + (static_cast<float>(i) / kNumPoints) * d[1]; + Spline::Point a[3]; + for (int k = 0; k < 3; ++k) { + // TODO(eustas): reciprocal multiplication would be faster. + a[k] = p[k] + ((tt - t[k]) / d[k]) * (p[k + 1] - p[k]); + } + Spline::Point b[2]; + for (int k = 0; k < 2; ++k) { + b[k] = a[k] + ((tt - t[k]) / (d[k] + d[k + 1])) * (a[k + 1] - a[k]); + } + result.push_back(b[0] + ((tt - t[1]) / d[1]) * (b[1] - b[0])); + } + } + result.push_back(points[points.size() - 2]); +} + +// Move along the line segments defined by `points`, `kDesiredRenderingDistance` +// pixels at a time, and call `functor` with each point and the actual distance +// to the previous point (which will always be kDesiredRenderingDistance except +// possibly for the very last point). +// TODO(eustas): this method always adds the last point, but never the first +// (unless those are one); I believe both ends matter. +template <typename Points, typename Functor> +void ForEachEquallySpacedPoint(const Points& points, const Functor& functor) { + JXL_ASSERT(!points.empty()); + Spline::Point current = points.front(); + functor(current, kDesiredRenderingDistance); + auto next = points.begin(); + while (next != points.end()) { + const Spline::Point* previous = ¤t; + float arclength_from_previous = 0.f; + for (;;) { + if (next == points.end()) { + functor(*previous, arclength_from_previous); + return; + } + const float arclength_to_next = + std::sqrt((*next - *previous).SquaredNorm()); + if (arclength_from_previous + arclength_to_next >= + kDesiredRenderingDistance) { + current = + *previous + ((kDesiredRenderingDistance - arclength_from_previous) / + arclength_to_next) * + (*next - *previous); + functor(current, kDesiredRenderingDistance); + break; + } + arclength_from_previous += arclength_to_next; + previous = &*next; + ++next; + } + } +} + +} // namespace + +QuantizedSpline::QuantizedSpline(const Spline& original, + const int32_t quantization_adjustment, + const float y_to_x, const float y_to_b) { + JXL_ASSERT(!original.control_points.empty()); + control_points_.reserve(original.control_points.size() - 1); + const Spline::Point& starting_point = original.control_points.front(); + int previous_x = static_cast<int>(std::roundf(starting_point.x)); + int previous_y = static_cast<int>(std::roundf(starting_point.y)); + int previous_delta_x = 0, previous_delta_y = 0; + for (auto it = original.control_points.begin() + 1; + it != original.control_points.end(); ++it) { + const int new_x = static_cast<int>(std::roundf(it->x)); + const int new_y = static_cast<int>(std::roundf(it->y)); + const int new_delta_x = new_x - previous_x; + const int new_delta_y = new_y - previous_y; + control_points_.emplace_back(new_delta_x - previous_delta_x, + new_delta_y - previous_delta_y); + previous_delta_x = new_delta_x; + previous_delta_y = new_delta_y; + previous_x = new_x; + previous_y = new_y; + } + + const auto to_int = [](float v) -> int { + // Maximal int representable with float. + constexpr float kMax = std::numeric_limits<int>::max() - 127; + constexpr float kMin = -kMax; + return static_cast<int>(std::roundf(Clamp1(v, kMin, kMax))); + }; + + const auto quant = AdjustedQuant(quantization_adjustment); + const auto inv_quant = InvAdjustedQuant(quantization_adjustment); + for (int c : {1, 0, 2}) { + float factor = (c == 0) ? y_to_x : (c == 1) ? 0 : y_to_b; + for (int i = 0; i < 32; ++i) { + const float dct_factor = (i == 0) ? kSqrt2 : 1.0f; + const float inv_dct_factor = (i == 0) ? kSqrt0_5 : 1.0f; + auto restored_y = + color_dct_[1][i] * inv_dct_factor * kChannelWeight[1] * inv_quant; + auto decorellated = original.color_dct[c][i] - factor * restored_y; + color_dct_[c][i] = + to_int(decorellated * dct_factor * quant / kChannelWeight[c]); + } + } + for (int i = 0; i < 32; ++i) { + const float dct_factor = (i == 0) ? kSqrt2 : 1.0f; + sigma_dct_[i] = + to_int(original.sigma_dct[i] * dct_factor * quant / kChannelWeight[3]); + } +} + +Status QuantizedSpline::Dequantize(const Spline::Point& starting_point, + const int32_t quantization_adjustment, + const float y_to_x, const float y_to_b, + const uint64_t image_size, + uint64_t* total_estimated_area_reached, + Spline& result) const { + constexpr uint64_t kOne = static_cast<uint64_t>(1); + const uint64_t area_limit = + std::min(1024 * image_size + (kOne << 32), kOne << 42); + + result.control_points.clear(); + result.control_points.reserve(control_points_.size() + 1); + float px = std::roundf(starting_point.x); + float py = std::roundf(starting_point.y); + JXL_RETURN_IF_ERROR(ValidateSplinePointPos(px, py)); + int current_x = static_cast<int>(px); + int current_y = static_cast<int>(py); + result.control_points.push_back(Spline::Point{static_cast<float>(current_x), + static_cast<float>(current_y)}); + int current_delta_x = 0, current_delta_y = 0; + uint64_t manhattan_distance = 0; + for (const auto& point : control_points_) { + current_delta_x += point.first; + current_delta_y += point.second; + manhattan_distance += std::abs(current_delta_x) + std::abs(current_delta_y); + if (manhattan_distance > area_limit) { + return JXL_FAILURE("Too large manhattan_distance reached: %" PRIu64, + manhattan_distance); + } + JXL_RETURN_IF_ERROR( + ValidateSplinePointPos(current_delta_x, current_delta_y)); + current_x += current_delta_x; + current_y += current_delta_y; + JXL_RETURN_IF_ERROR(ValidateSplinePointPos(current_x, current_y)); + result.control_points.push_back(Spline::Point{ + static_cast<float>(current_x), static_cast<float>(current_y)}); + } + + const auto inv_quant = InvAdjustedQuant(quantization_adjustment); + for (int c = 0; c < 3; ++c) { + for (int i = 0; i < 32; ++i) { + const float inv_dct_factor = (i == 0) ? kSqrt0_5 : 1.0f; + result.color_dct[c][i] = + color_dct_[c][i] * inv_dct_factor * kChannelWeight[c] * inv_quant; + } + } + for (int i = 0; i < 32; ++i) { + result.color_dct[0][i] += y_to_x * result.color_dct[1][i]; + result.color_dct[2][i] += y_to_b * result.color_dct[1][i]; + } + uint64_t width_estimate = 0; + + uint64_t color[3] = {}; + for (int c = 0; c < 3; ++c) { + for (int i = 0; i < 32; ++i) { + color[c] += static_cast<uint64_t>( + std::ceil(inv_quant * std::abs(color_dct_[c][i]))); + } + } + color[0] += static_cast<uint64_t>(std::ceil(std::abs(y_to_x))) * color[1]; + color[2] += static_cast<uint64_t>(std::ceil(std::abs(y_to_b))) * color[1]; + // This is not taking kChannelWeight into account, but up to constant factors + // it gives an indication of the influence of the color values on the area + // that will need to be rendered. + const uint64_t max_color = std::max({color[1], color[0], color[2]}); + uint64_t logcolor = + std::max(kOne, static_cast<uint64_t>(CeilLog2Nonzero(kOne + max_color))); + + const float weight_limit = + std::ceil(std::sqrt((static_cast<float>(area_limit) / logcolor) / + std::max<size_t>(1, manhattan_distance))); + + for (int i = 0; i < 32; ++i) { + const float inv_dct_factor = (i == 0) ? kSqrt0_5 : 1.0f; + result.sigma_dct[i] = + sigma_dct_[i] * inv_dct_factor * kChannelWeight[3] * inv_quant; + // If we include the factor kChannelWeight[3]=.3333f here, we get a + // realistic area estimate. We leave it out to simplify the calculations, + // and understand that this way we underestimate the area by a factor of + // 1/(0.3333*0.3333). This is taken into account in the limits below. + float weight_f = std::ceil(inv_quant * std::abs(sigma_dct_[i])); + uint64_t weight = + static_cast<uint64_t>(std::min(weight_limit, std::max(1.0f, weight_f))); + width_estimate += weight * weight * logcolor; + } + *total_estimated_area_reached += (width_estimate * manhattan_distance); + if (*total_estimated_area_reached > area_limit) { + return JXL_FAILURE("Too large total_estimated_area eached: %" PRIu64, + *total_estimated_area_reached); + } + + return true; +} + +Status QuantizedSpline::Decode(const std::vector<uint8_t>& context_map, + ANSSymbolReader* const decoder, + BitReader* const br, + const size_t max_control_points, + size_t* total_num_control_points) { + const size_t num_control_points = + decoder->ReadHybridUint(kNumControlPointsContext, br, context_map); + if (num_control_points > max_control_points) { + return JXL_FAILURE("Too many control points: %" PRIuS, num_control_points); + } + *total_num_control_points += num_control_points; + if (*total_num_control_points > max_control_points) { + return JXL_FAILURE("Too many control points: %" PRIuS, + *total_num_control_points); + } + control_points_.resize(num_control_points); + // Maximal image dimension. + constexpr int64_t kDeltaLimit = 1u << 30; + for (std::pair<int64_t, int64_t>& control_point : control_points_) { + control_point.first = UnpackSigned( + decoder->ReadHybridUint(kControlPointsContext, br, context_map)); + control_point.second = UnpackSigned( + decoder->ReadHybridUint(kControlPointsContext, br, context_map)); + // Check delta-deltas are not outrageous; it is not in spec, but there is + // no reason to allow larger values. + if ((control_point.first >= kDeltaLimit) || + (control_point.first <= -kDeltaLimit) || + (control_point.second >= kDeltaLimit) || + (control_point.second <= -kDeltaLimit)) { + return JXL_FAILURE("Spline delta-delta is out of bounds"); + } + } + + const auto decode_dct = [decoder, br, &context_map](int dct[32]) -> Status { + constexpr int kWeirdNumber = std::numeric_limits<int>::min(); + for (int i = 0; i < 32; ++i) { + dct[i] = + UnpackSigned(decoder->ReadHybridUint(kDCTContext, br, context_map)); + if (dct[i] == kWeirdNumber) { + return JXL_FAILURE("The weird number in spline DCT"); + } + } + return true; + }; + for (int c = 0; c < 3; ++c) { + JXL_RETURN_IF_ERROR(decode_dct(color_dct_[c])); + } + JXL_RETURN_IF_ERROR(decode_dct(sigma_dct_)); + return true; +} + +void Splines::Clear() { + quantization_adjustment_ = 0; + splines_.clear(); + starting_points_.clear(); + segments_.clear(); + segment_indices_.clear(); + segment_y_start_.clear(); +} + +Status Splines::Decode(jxl::BitReader* br, const size_t num_pixels) { + std::vector<uint8_t> context_map; + ANSCode code; + JXL_RETURN_IF_ERROR( + DecodeHistograms(br, kNumSplineContexts, &code, &context_map)); + ANSSymbolReader decoder(&code, br); + size_t num_splines = + decoder.ReadHybridUint(kNumSplinesContext, br, context_map); + size_t max_control_points = std::min( + kMaxNumControlPoints, num_pixels / kMaxNumControlPointsPerPixelRatio); + if (num_splines > max_control_points || + num_splines + 1 > max_control_points) { + return JXL_FAILURE("Too many splines: %" PRIuS, num_splines); + } + num_splines++; + JXL_RETURN_IF_ERROR(DecodeAllStartingPoints(&starting_points_, br, &decoder, + context_map, num_splines)); + + quantization_adjustment_ = UnpackSigned( + decoder.ReadHybridUint(kQuantizationAdjustmentContext, br, context_map)); + + splines_.clear(); + splines_.reserve(num_splines); + size_t num_control_points = num_splines; + for (size_t i = 0; i < num_splines; ++i) { + QuantizedSpline spline; + JXL_RETURN_IF_ERROR(spline.Decode(context_map, &decoder, br, + max_control_points, &num_control_points)); + splines_.push_back(std::move(spline)); + } + + JXL_RETURN_IF_ERROR(decoder.CheckANSFinalState()); + + if (!HasAny()) { + return JXL_FAILURE("Decoded splines but got none"); + } + + return true; +} + +void Splines::AddTo(Image3F* const opsin, const Rect& opsin_rect, + const Rect& image_rect) const { + return Apply</*add=*/true>(opsin, opsin_rect, image_rect); +} +void Splines::AddToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, const Rect& image_row) const { + return ApplyToRow</*add=*/true>(row_x, row_y, row_b, image_row); +} + +void Splines::SubtractFrom(Image3F* const opsin) const { + return Apply</*add=*/false>(opsin, Rect(*opsin), Rect(*opsin)); +} + +Status Splines::InitializeDrawCache(const size_t image_xsize, + const size_t image_ysize, + const ColorCorrelationMap& cmap) { + // TODO(veluca): avoid storing segments that are entirely outside image + // boundaries. + segments_.clear(); + segment_indices_.clear(); + segment_y_start_.clear(); + std::vector<std::pair<size_t, size_t>> segments_by_y; + std::vector<Spline::Point> intermediate_points; + uint64_t total_estimated_area_reached = 0; + std::vector<Spline> splines; + for (size_t i = 0; i < splines_.size(); ++i) { + Spline spline; + JXL_RETURN_IF_ERROR(splines_[i].Dequantize( + starting_points_[i], quantization_adjustment_, cmap.YtoXRatio(0), + cmap.YtoBRatio(0), image_xsize * image_ysize, + &total_estimated_area_reached, spline)); + if (std::adjacent_find(spline.control_points.begin(), + spline.control_points.end()) != + spline.control_points.end()) { + // Otherwise division by zero might occur. Once control points coincide, + // the direction of curve is undefined... + return JXL_FAILURE( + "identical successive control points in spline %" PRIuS, i); + } + splines.push_back(spline); + } + // TODO(firsching) Change this into a JXL_FAILURE for level 5 codestreams. + if (total_estimated_area_reached > + std::min((8 * image_xsize * image_ysize + (uint64_t(1) << 25)), + (uint64_t(1) << 30))) { + JXL_WARNING( + "Large total_estimated_area_reached, expect slower decoding: %" PRIu64, + total_estimated_area_reached); +#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + return JXL_FAILURE("Total spline area is too large"); +#endif + } + + for (Spline& spline : splines) { + std::vector<std::pair<Spline::Point, float>> points_to_draw; + auto add_point = [&](const Spline::Point& point, const float multiplier) { + points_to_draw.emplace_back(point, multiplier); + }; + intermediate_points.clear(); + DrawCentripetalCatmullRomSpline(spline.control_points, intermediate_points); + ForEachEquallySpacedPoint(intermediate_points, add_point); + const float arc_length = + (points_to_draw.size() - 2) * kDesiredRenderingDistance + + points_to_draw.back().second; + if (arc_length <= 0.f) { + // This spline wouldn't have any effect. + continue; + } + HWY_DYNAMIC_DISPATCH(SegmentsFromPoints) + (spline, points_to_draw, arc_length, segments_, segments_by_y); + } + + // TODO(eustas): consider linear sorting here. + std::sort(segments_by_y.begin(), segments_by_y.end()); + segment_indices_.resize(segments_by_y.size()); + segment_y_start_.resize(image_ysize + 1); + for (size_t i = 0; i < segments_by_y.size(); i++) { + segment_indices_[i] = segments_by_y[i].second; + size_t y = segments_by_y[i].first; + if (y < image_ysize) { + segment_y_start_[y + 1]++; + } + } + for (size_t y = 0; y < image_ysize; y++) { + segment_y_start_[y + 1] += segment_y_start_[y]; + } + return true; +} + +template <bool add> +void Splines::ApplyToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, + const Rect& image_row) const { + if (segments_.empty()) return; + JXL_ASSERT(image_row.ysize() == 1); + for (size_t iy = 0; iy < image_row.ysize(); iy++) { + HWY_DYNAMIC_DISPATCH(DrawSegments) + (row_x, row_y, row_b, image_row.Line(iy), add, segments_.data(), + segment_indices_.data(), segment_y_start_.data()); + } +} + +template <bool add> +void Splines::Apply(Image3F* const opsin, const Rect& opsin_rect, + const Rect& image_rect) const { + if (segments_.empty()) return; + for (size_t iy = 0; iy < image_rect.ysize(); iy++) { + const size_t y0 = opsin_rect.Line(iy).y0(); + const size_t x0 = opsin_rect.x0(); + ApplyToRow<add>(opsin->PlaneRow(0, y0) + x0, opsin->PlaneRow(1, y0) + x0, + opsin->PlaneRow(2, y0) + x0, image_rect.Line(iy)); + } +} + +} // namespace jxl +#endif // HWY_ONCE diff --git a/third_party/jpeg-xl/lib/jxl/splines.h b/third_party/jpeg-xl/lib/jxl/splines.h new file mode 100644 index 0000000000..acdd0857d0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/splines.h @@ -0,0 +1,148 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_SPLINES_H_ +#define LIB_JXL_SPLINES_H_ + +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <utility> +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/image.h" + +namespace jxl { + +class ANSSymbolReader; +class BitReader; + +static constexpr float kDesiredRenderingDistance = 1.f; + +enum SplineEntropyContexts : size_t { + kQuantizationAdjustmentContext = 0, + kStartingPositionContext, + kNumSplinesContext, + kNumControlPointsContext, + kControlPointsContext, + kDCTContext, + kNumSplineContexts +}; + +struct Spline { + struct Point { + Point() : x(0.0f), y(0.0f) {} + Point(float x, float y) : x(x), y(y) {} + float x, y; + bool operator==(const Point& other) const { + return std::fabs(x - other.x) < 1e-3f && std::fabs(y - other.y) < 1e-3f; + } + }; + std::vector<Point> control_points; + // X, Y, B. + float color_dct[3][32]; + // Splines are draws by normalized Gaussian splatting. This controls the + // Gaussian's parameter along the spline. + float sigma_dct[32]; +}; + +class QuantizedSplineEncoder; + +class QuantizedSpline { + public: + QuantizedSpline() = default; + explicit QuantizedSpline(const Spline& original, + int32_t quantization_adjustment, float y_to_x, + float y_to_b); + + Status Dequantize(const Spline::Point& starting_point, + int32_t quantization_adjustment, float y_to_x, float y_to_b, + uint64_t image_size, uint64_t* total_estimated_area_reached, + Spline& result) const; + + Status Decode(const std::vector<uint8_t>& context_map, + ANSSymbolReader* decoder, BitReader* br, + size_t max_control_points, size_t* total_num_control_points); + + private: + friend class QuantizedSplineEncoder; + + std::vector<std::pair<int64_t, int64_t>> + control_points_; // Double delta-encoded. + int color_dct_[3][32] = {}; + int sigma_dct_[32] = {}; +}; + +// A single "drawable unit" of a spline, i.e. a line of the region in which we +// render each Gaussian. The structure doesn't actually depend on the exact +// row, which allows reuse for different y values (which are tracked +// separately). +struct SplineSegment { + float center_x, center_y; + float maximum_distance; + float inv_sigma; + float sigma_over_4_times_intensity; + float color[3]; +}; + +class Splines { + public: + Splines() = default; + explicit Splines(const int32_t quantization_adjustment, + std::vector<QuantizedSpline> splines, + std::vector<Spline::Point> starting_points) + : quantization_adjustment_(quantization_adjustment), + splines_(std::move(splines)), + starting_points_(std::move(starting_points)) {} + + bool HasAny() const { return !splines_.empty(); } + + void Clear(); + + Status Decode(BitReader* br, size_t num_pixels); + + void AddTo(Image3F* opsin, const Rect& opsin_rect, + const Rect& image_rect) const; + void AddToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, const Rect& image_row) const; + void SubtractFrom(Image3F* opsin) const; + + const std::vector<QuantizedSpline>& QuantizedSplines() const { + return splines_; + } + const std::vector<Spline::Point>& StartingPoints() const { + return starting_points_; + } + + int32_t GetQuantizationAdjustment() const { return quantization_adjustment_; } + + Status InitializeDrawCache(size_t image_xsize, size_t image_ysize, + const ColorCorrelationMap& cmap); + + private: + template <bool> + void ApplyToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y, + float* JXL_RESTRICT row_b, const Rect& image_row) const; + template <bool> + void Apply(Image3F* opsin, const Rect& opsin_rect, + const Rect& image_rect) const; + + // If positive, quantization weights are multiplied by 1 + this/8, which + // increases precision. If negative, they are divided by 1 - this/8. If 0, + // they are unchanged. + int32_t quantization_adjustment_ = 0; + std::vector<QuantizedSpline> splines_; + std::vector<Spline::Point> starting_points_; + std::vector<SplineSegment> segments_; + std::vector<size_t> segment_indices_; + std::vector<size_t> segment_y_start_; +}; + +} // namespace jxl + +#endif // LIB_JXL_SPLINES_H_ diff --git a/third_party/jpeg-xl/lib/jxl/splines_gbench.cc b/third_party/jpeg-xl/lib/jxl/splines_gbench.cc new file mode 100644 index 0000000000..78ff6d41c0 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/splines_gbench.cc @@ -0,0 +1,52 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" +#include "lib/jxl/splines.h" + +namespace jxl { +namespace { + +constexpr int kQuantizationAdjustment = 0; +const ColorCorrelationMap* const cmap = new ColorCorrelationMap; +const float kYToX = cmap->YtoXRatio(0); +const float kYToB = cmap->YtoBRatio(0); + +void BM_Splines(benchmark::State& state) { + const size_t n = state.range(); + + std::vector<Spline> spline_data = { + {/*control_points=*/{ + {9, 54}, {118, 159}, {97, 3}, {10, 40}, {150, 25}, {120, 300}}, + /*color_dct=*/ + {{0.03125f, 0.00625f, 0.003125f}, {1.f, 0.321875f}, {1.f, 0.24375f}}, + /*sigma_dct=*/{0.3125f, 0.f, 0.f, 0.0625f}}}; + std::vector<QuantizedSpline> quantized_splines; + std::vector<Spline::Point> starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + + Image3F drawing_area(320, 320); + ZeroFillImage(&drawing_area); + for (auto _ : state) { + for (size_t i = 0; i < n; ++i) { + JXL_CHECK(splines.InitializeDrawCache(drawing_area.xsize(), + drawing_area.ysize(), *cmap)); + splines.AddTo(&drawing_area, Rect(drawing_area), Rect(drawing_area)); + } + } + + state.SetItemsProcessed(n * state.iterations()); +} + +BENCHMARK(BM_Splines)->Range(1, 1 << 10); + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/splines_test.cc b/third_party/jpeg-xl/lib/jxl/splines_test.cc new file mode 100644 index 0000000000..d812545a37 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/splines_test.cc @@ -0,0 +1,366 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/splines.h" + +#include <jxl/cms.h> + +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <ostream> +#include <utility> +#include <vector> + +#include "lib/extras/codec.h" +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/chroma_from_luma.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_splines.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/image_test_utils.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" + +namespace jxl { + +std::ostream& operator<<(std::ostream& os, const Spline::Point& p) { + return os << "(" << p.x << ", " << p.y << ")"; +} + +std::ostream& operator<<(std::ostream& os, const Spline& spline) { + return os << "(spline with " << spline.control_points.size() + << " control points)"; +} + +namespace { + +using test::ReadTestData; +using ::testing::AllOf; +using ::testing::Field; +using ::testing::FloatNear; +using ::testing::Pointwise; + +constexpr int kQuantizationAdjustment = 0; +const ColorCorrelationMap* const cmap = new ColorCorrelationMap; +const float kYToX = cmap->YtoXRatio(0); +const float kYToB = cmap->YtoBRatio(0); + +constexpr float kTolerance = 0.003125; + +std::vector<Spline> DequantizeSplines(const Splines& splines) { + const auto& quantized_splines = splines.QuantizedSplines(); + const auto& starting_points = splines.StartingPoints(); + JXL_CHECK(quantized_splines.size() == starting_points.size()); + + std::vector<Spline> dequantized; + uint64_t total = 0; + for (size_t i = 0; i < quantized_splines.size(); ++i) { + dequantized.emplace_back(); + JXL_CHECK(quantized_splines[i].Dequantize( + starting_points[i], kQuantizationAdjustment, kYToX, kYToB, 2u << 30u, + &total, dequantized.back())); + } + return dequantized; +} + +MATCHER(ControlPointIs, "") { + const Spline::Point& actual = std::get<0>(arg); + const Spline::Point& expected = std::get<1>(arg); + return testing::ExplainMatchResult( + AllOf(Field(&Spline::Point::x, FloatNear(expected.x, kTolerance)), + Field(&Spline::Point::y, FloatNear(expected.y, kTolerance))), + actual, result_listener); +} + +MATCHER(ControlPointsMatch, "") { + const Spline& actual = std::get<0>(arg); + const Spline& expected = std::get<1>(arg); + return testing::ExplainMatchResult( + Field(&Spline::control_points, + Pointwise(ControlPointIs(), expected.control_points)), + actual, result_listener); +} + +MATCHER(SplinesMatch, "") { + const Spline& actual = std::get<0>(arg); + const Spline& expected = std::get<1>(arg); + if (!testing::ExplainMatchResult(ControlPointsMatch(), arg, + result_listener)) { + return false; + } + for (int i = 0; i < 3; ++i) { + size_t color_dct_size = + sizeof(expected.color_dct[i]) / sizeof(expected.color_dct[i][0]); + for (size_t j = 0; j < color_dct_size; j++) { + testing::StringMatchResultListener color_dct_listener; + if (!testing::ExplainMatchResult( + FloatNear(expected.color_dct[i][j], kTolerance), + actual.color_dct[i][j], &color_dct_listener)) { + *result_listener << ", where color_dct[" << i << "][" << j + << "] don't match, " << color_dct_listener.str(); + return false; + } + } + } + size_t sigma_dct_size = + sizeof(expected.sigma_dct) / sizeof(expected.sigma_dct[0]); + for (size_t i = 0; i < sigma_dct_size; i++) { + testing::StringMatchResultListener sigma_listener; + if (!testing::ExplainMatchResult( + FloatNear(expected.sigma_dct[i], kTolerance), actual.sigma_dct[i], + &sigma_listener)) { + *result_listener << ", where sigma_dct[" << i << "] don't match, " + << sigma_listener.str(); + return false; + } + } + return true; +} + +} // namespace + +TEST(SplinesTest, Serialization) { + std::vector<Spline> spline_data = { + {/*control_points=*/{ + {109, 54}, {218, 159}, {80, 3}, {110, 274}, {94, 185}, {17, 277}}, + /*color_dct=*/ + {{36.3, 39.7, 23.2, 67.5, 4.4, 71.5, 62.3, 32.3, 92.2, 10.1, 10.8, + 9.2, 6.1, 10.5, 79.1, 7, 24.6, 90.8, 5.5, 84, 43.8, 49, + 33.5, 78.9, 54.5, 77.9, 62.1, 51.4, 36.4, 14.3, 83.7, 35.4}, + {9.4, 53.4, 9.5, 74.9, 72.7, 26.7, 7.9, 0.9, 84.9, 23.2, 26.5, + 31.1, 91, 11.7, 74.1, 39.3, 23.7, 82.5, 4.8, 2.7, 61.2, 96.4, + 13.7, 66.7, 62.9, 82.4, 5.9, 98.7, 21.5, 7.9, 51.7, 63.1}, + {48, 39.3, 6.9, 26.3, 33.3, 6.2, 1.7, 98.9, 59.9, 59.6, 95, + 61.3, 82.7, 53, 6.1, 30.4, 34.7, 96.9, 93.4, 17, 38.8, 80.8, + 63, 18.6, 43.6, 32.3, 61, 20.2, 24.3, 28.3, 69.1, 62.4}}, + /*sigma_dct=*/{32.7, 21.5, 44.4, 1.8, 45.8, 90.6, 29.3, 59.2, + 23.7, 85.2, 84.8, 27.2, 42.1, 84.1, 50.6, 17.6, + 93.7, 4.9, 2.6, 69.8, 94.9, 52, 24.3, 18.8, + 12.1, 95.7, 28.5, 81.4, 89.9, 31.4, 74.8, 52}}, + {/*control_points=*/{{172, 309}, + {196, 277}, + {42, 238}, + {114, 350}, + {307, 290}, + {316, 269}, + {124, 66}, + {233, 267}}, + /*color_dct=*/ + {{15, 28.9, 22, 6.6, 41.8, 83, 8.6, 56.8, 68.9, 9.7, 5.4, + 19.8, 70.8, 90, 52.5, 65.2, 7.8, 23.5, 26.4, 72.2, 64.7, 87.1, + 1.3, 67.5, 46, 68.4, 65.4, 35.5, 29.1, 13, 41.6, 23.9}, + {47.7, 79.4, 62.7, 29.1, 96.8, 18.5, 17.6, 15.2, 80.5, 56, 96.2, + 59.9, 26.7, 96.1, 92.3, 42.1, 35.8, 54, 23.2, 55, 76, 35.8, + 58.4, 88.7, 2.4, 78.1, 95.6, 27.5, 6.6, 78.5, 24.1, 69.8}, + {43.8, 96.5, 0.9, 95.1, 49.1, 71.2, 25.1, 33.6, 75.2, 95, 82.1, + 19.7, 10.5, 44.9, 50, 93.3, 83.5, 99.5, 64.6, 54, 3.5, 99.7, + 45.3, 82.1, 22.4, 37.9, 60, 32.2, 12.6, 4.6, 65.5, 96.4}}, + /*sigma_dct=*/{72.5, 2.6, 41.7, 2.2, 39.7, 79.1, 69.6, 19.9, + 92.3, 71.5, 41.9, 62.1, 30, 49.4, 70.3, 45.3, + 62.5, 47.2, 46.7, 41.2, 90.8, 46.8, 91.2, 55, + 8.1, 69.6, 25.4, 84.7, 61.7, 27.6, 3.7, 46.9}}, + {/*control_points=*/{{100, 186}, + {257, 97}, + {170, 49}, + {25, 169}, + {309, 104}, + {232, 237}, + {385, 101}, + {122, 168}, + {26, 300}, + {390, 88}}, + /*color_dct=*/ + {{16.9, 64.8, 4.2, 10.6, 23.5, 17, 79.3, 5.7, 60.4, 16.6, 94.9, + 63.7, 87.6, 10.5, 3.8, 61.1, 22.9, 81.9, 80.4, 40.5, 45.9, 25.4, + 39.8, 30, 50.2, 90.4, 27.9, 93.7, 65.1, 48.2, 22.3, 43.9}, + {24.9, 66, 3.5, 90.2, 97.1, 15.8, 35.6, 0.6, 68, 39.6, 24.4, + 85.9, 57.7, 77.6, 47.5, 67.9, 4.3, 5.4, 91.2, 58.5, 0.1, 52.2, + 3.5, 47.8, 63.2, 43.5, 85.8, 35.8, 50.2, 35.9, 19.2, 48.2}, + {82.8, 44.9, 76.4, 39.5, 94.1, 14.3, 89.8, 10, 10.5, 74.5, 56.3, + 65.8, 7.8, 23.3, 52.8, 99.3, 56.8, 46, 76.7, 13.5, 67, 22.4, + 29.9, 43.3, 70.3, 26, 74.3, 53.9, 62, 19.1, 49.3, 46.7}}, + /*sigma_dct=*/{83.5, 1.7, 25.1, 18.7, 46.5, 75.3, 28, 62.3, + 50.3, 23.3, 85.6, 96, 45.8, 33.1, 33.4, 52.9, + 26.3, 58.5, 19.6, 70, 92.6, 22.5, 57, 21.6, + 76.8, 87.5, 22.9, 66.3, 35.7, 35.6, 56.8, 67.2}}, + }; + + std::vector<QuantizedSpline> quantized_splines; + std::vector<Spline::Point> starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + const std::vector<Spline> quantized_spline_data = DequantizeSplines(splines); + EXPECT_THAT(quantized_spline_data, + Pointwise(ControlPointsMatch(), spline_data)); + + BitWriter writer; + EncodeSplines(splines, &writer, kLayerSplines, HistogramParams(), nullptr); + writer.ZeroPadToByte(); + const size_t bits_written = writer.BitsWritten(); + + printf("Wrote %" PRIuS " bits of splines.\n", bits_written); + + BitReader reader(writer.GetSpan()); + Splines decoded_splines; + ASSERT_TRUE(decoded_splines.Decode(&reader, /*num_pixels=*/1000)); + ASSERT_TRUE(reader.JumpToByteBoundary()); + EXPECT_EQ(reader.TotalBitsConsumed(), bits_written); + ASSERT_TRUE(reader.Close()); + + const std::vector<Spline> decoded_spline_data = + DequantizeSplines(decoded_splines); + EXPECT_THAT(decoded_spline_data, + Pointwise(SplinesMatch(), quantized_spline_data)); +} + +#ifdef JXL_CRASH_ON_ERROR +TEST(SplinesTest, DISABLED_TooManySplinesTest) { +#else +TEST(SplinesTest, TooManySplinesTest) { +#endif + // This is more than the limit for 1000 pixels. + const size_t kNumSplines = 300; + + std::vector<QuantizedSpline> quantized_splines; + std::vector<Spline::Point> starting_points; + for (size_t i = 0; i < kNumSplines; i++) { + Spline spline = { + /*control_points=*/{{1.f + i, 2}, {10.f + i, 25}, {30.f + i, 300}}, + /*color_dct=*/ + {{1.f, 0.2f, 0.1f}, {35.7f, 10.3f}, {35.7f, 7.8f}}, + /*sigma_dct=*/{10.f, 0.f, 0.f, 2.f}}; + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + BitWriter writer; + EncodeSplines(splines, &writer, kLayerSplines, + HistogramParams(SpeedTier::kFalcon, 1), nullptr); + writer.ZeroPadToByte(); + // Re-read splines. + BitReader reader(writer.GetSpan()); + Splines decoded_splines; + EXPECT_FALSE(decoded_splines.Decode(&reader, /*num_pixels=*/1000)); + EXPECT_TRUE(reader.Close()); +} + +#ifdef JXL_CRASH_ON_ERROR +TEST(SplinesTest, DISABLED_DuplicatePoints) { +#else +TEST(SplinesTest, DuplicatePoints) { +#endif + std::vector<Spline::Point> control_points{ + {9, 54}, {118, 159}, {97, 3}, // Repeated. + {97, 3}, {10, 40}, {150, 25}, {120, 300}}; + Spline spline{control_points, + /*color_dct=*/ + {{1.f, 0.2f, 0.1f}, {35.7f, 10.3f}, {35.7f, 7.8f}}, + /*sigma_dct=*/{10.f, 0.f, 0.f, 2.f}}; + std::vector<Spline> spline_data{spline}; + std::vector<QuantizedSpline> quantized_splines; + std::vector<Spline::Point> starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + + Image3F image(320, 320); + ZeroFillImage(&image); + EXPECT_FALSE( + splines.InitializeDrawCache(image.xsize(), image.ysize(), *cmap)); +} + +TEST(SplinesTest, Drawing) { + CodecInOut io_expected; + const std::vector<uint8_t> orig = ReadTestData("jxl/splines.pfm"); + ASSERT_TRUE(SetFromBytes(Bytes(orig), &io_expected, + /*pool=*/nullptr)); + + std::vector<Spline::Point> control_points{{9, 54}, {118, 159}, {97, 3}, + {10, 40}, {150, 25}, {120, 300}}; + // Use values that survive quant/decorellation roundtrip. + const Spline spline{ + control_points, + /*color_dct=*/ + {{0.4989345073699951171875000f, 0.4997999966144561767578125f}, + {0.4772970676422119140625000f, 0.f, 0.5250000357627868652343750f}, + {-0.0176776945590972900390625f, 0.4900000095367431640625000f, + 0.5250000357627868652343750f}}, + /*sigma_dct=*/ + {0.9427147507667541503906250f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.6665999889373779296875000f}}; + std::vector<Spline> spline_data = {spline}; + std::vector<QuantizedSpline> quantized_splines; + std::vector<Spline::Point> starting_points; + for (const Spline& spline : spline_data) { + quantized_splines.emplace_back(spline, kQuantizationAdjustment, kYToX, + kYToB); + starting_points.push_back(spline.control_points.front()); + } + Splines splines(kQuantizationAdjustment, std::move(quantized_splines), + std::move(starting_points)); + + Image3F image(320, 320); + ZeroFillImage(&image); + ASSERT_TRUE(splines.InitializeDrawCache(image.xsize(), image.ysize(), *cmap)); + splines.AddTo(&image, Rect(image), Rect(image)); + + CodecInOut io_actual; + Image3F image2(320, 320); + CopyImageTo(image, &image2); + io_actual.SetFromImage(std::move(image2), ColorEncoding::SRGB()); + ASSERT_TRUE(io_actual.frames[0].TransformTo(io_expected.Main().c_current(), + *JxlGetDefaultCms())); + + JXL_ASSERT_OK(VerifyRelativeError( + *io_expected.Main().color(), *io_actual.Main().color(), 1e-2f, 1e-1f, _)); +} + +TEST(SplinesTest, ClearedEveryFrame) { + CodecInOut io_expected; + const std::vector<uint8_t> bytes_expected = + ReadTestData("jxl/spline_on_first_frame.png"); + ASSERT_TRUE(SetFromBytes(Bytes(bytes_expected), &io_expected, + /*pool=*/nullptr)); + CodecInOut io_actual; + const std::vector<uint8_t> bytes_actual = + ReadTestData("jxl/spline_on_first_frame.jxl"); + ASSERT_TRUE(test::DecodeFile({}, Bytes(bytes_actual), &io_actual)); + + ASSERT_TRUE(io_actual.frames[0].TransformTo(ColorEncoding::SRGB(), + *JxlGetDefaultCms())); + for (size_t c = 0; c < 3; ++c) { + for (size_t y = 0; y < io_actual.ysize(); ++y) { + float* const JXL_RESTRICT row = io_actual.Main().color()->PlaneRow(c, y); + for (size_t x = 0; x < io_actual.xsize(); ++x) { + row[x] = Clamp1(row[x], 0.f, 1.f); + } + } + } + JXL_ASSERT_OK(VerifyRelativeError( + *io_expected.Main().color(), *io_actual.Main().color(), 1e-2f, 1e-1f, _)); +} + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/test_image.cc b/third_party/jpeg-xl/lib/jxl/test_image.cc new file mode 100644 index 0000000000..098e9c25a1 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/test_image.cc @@ -0,0 +1,456 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/test_image.h" + +#include <jxl/encode.h> + +#include <algorithm> +#include <cstring> +#include <utility> + +#include "lib/extras/dec/color_description.h" +#include "lib/extras/dec/color_hints.h" +#include "lib/extras/dec/decode.h" +#include "lib/jxl/base/byte_order.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/color_encoding_internal.h" + +namespace jxl { +namespace test { + +namespace { + +void StoreValue(float val, size_t bits_per_sample, JxlPixelFormat format, + uint8_t** out) { + const float mul = (1u << bits_per_sample) - 1; + if (format.data_type == JXL_TYPE_UINT8) { + **out = val * mul; + } else if (format.data_type == JXL_TYPE_UINT16) { + uint16_t uval = val * mul; + if (SwapEndianness(format.endianness)) { + uval = JXL_BSWAP16(uval); + } + memcpy(*out, &uval, 2); + } else if (format.data_type == JXL_TYPE_FLOAT) { + // TODO(szabadka) Add support for custom bits / exponent bits floats. + if (SwapEndianness(format.endianness)) { + val = BSwapFloat(val); + } + memcpy(*out, &val, 4); + } else { + // TODO(szabadka) Add support for FLOAT16. + } + *out += extras::PackedImage::BitsPerChannel(format.data_type) / 8; +} + +void FillPackedImage(size_t bits_per_sample, uint16_t seed, + extras::PackedImage* image) { + const size_t xsize = image->xsize; + const size_t ysize = image->ysize; + const JxlPixelFormat format = image->format; + + // Cause more significant image difference for successive seeds. + Rng generator(seed); + + // Returns random integer in interval [0, max_value) + auto rngu = [&generator](size_t max_value) -> size_t { + return generator.UniformU(0, max_value); + }; + + // Returns random float in interval [0.0, max_value) + auto rngf = [&generator](float max_value) { + return generator.UniformF(0.0f, max_value); + }; + + // Dark background gradient color + float r0 = rngf(0.5f); + float g0 = rngf(0.5f); + float b0 = rngf(0.5f); + float a0 = rngf(0.5f); + float r1 = rngf(0.5f); + float g1 = rngf(0.5f); + float b1 = rngf(0.5f); + float a1 = rngf(0.5f); + + // Circle with different color + size_t circle_x = rngu(xsize); + size_t circle_y = rngu(ysize); + size_t circle_r = rngu(std::min(xsize, ysize)); + + // Rectangle with random noise + size_t rect_x0 = rngu(xsize); + size_t rect_y0 = rngu(ysize); + size_t rect_x1 = rngu(xsize); + size_t rect_y1 = rngu(ysize); + if (rect_x1 < rect_x0) std::swap(rect_x0, rect_y1); + if (rect_y1 < rect_y0) std::swap(rect_y0, rect_y1); + + // Create pixel content to test, actual content does not matter as long as it + // can be compared after roundtrip. + const float imul16 = 1.0f / 65536.0f; + for (size_t y = 0; y < ysize; y++) { + uint8_t* out = + reinterpret_cast<uint8_t*>(image->pixels()) + y * image->stride; + for (size_t x = 0; x < xsize; x++) { + float r = r0 * (ysize - y - 1) / ysize + r1 * y / ysize; + float g = g0 * (ysize - y - 1) / ysize + g1 * y / ysize; + float b = b0 * (ysize - y - 1) / ysize + b1 * y / ysize; + float a = a0 * (ysize - y - 1) / ysize + a1 * y / ysize; + // put some shape in there for visual debugging + if ((x - circle_x) * (x - circle_x) + (y - circle_y) * (y - circle_y) < + circle_r * circle_r) { + r = std::min(1.0f, ((65535 - x * y) ^ seed) * imul16); + g = std::min(1.0f, ((x << 8) + y + seed) * imul16); + b = std::min(1.0f, ((y << 8) + x * seed) * imul16); + a = std::min(1.0f, (32768 + x * 256 - y) * imul16); + } else if (x > rect_x0 && x < rect_x1 && y > rect_y0 && y < rect_y1) { + r = rngf(1.0f); + g = rngf(1.0f); + b = rngf(1.0f); + a = rngf(1.0f); + } + if (format.num_channels == 1) { + StoreValue(g, bits_per_sample, format, &out); + } else if (format.num_channels == 2) { + StoreValue(g, bits_per_sample, format, &out); + StoreValue(a, bits_per_sample, format, &out); + } else if (format.num_channels == 3) { + StoreValue(r, bits_per_sample, format, &out); + StoreValue(g, bits_per_sample, format, &out); + StoreValue(b, bits_per_sample, format, &out); + } else if (format.num_channels == 4) { + StoreValue(r, bits_per_sample, format, &out); + StoreValue(g, bits_per_sample, format, &out); + StoreValue(b, bits_per_sample, format, &out); + StoreValue(a, bits_per_sample, format, &out); + } + } + } +} + +} // namespace + +std::vector<uint8_t> GetSomeTestImage(size_t xsize, size_t ysize, + size_t num_channels, uint16_t seed) { + // Cause more significant image difference for successive seeds. + Rng generator(seed); + + // Returns random integer in interval [0, max_value) + auto rng = [&generator](size_t max_value) -> size_t { + return generator.UniformU(0, max_value); + }; + + // Dark background gradient color + uint16_t r0 = rng(32768); + uint16_t g0 = rng(32768); + uint16_t b0 = rng(32768); + uint16_t a0 = rng(32768); + uint16_t r1 = rng(32768); + uint16_t g1 = rng(32768); + uint16_t b1 = rng(32768); + uint16_t a1 = rng(32768); + + // Circle with different color + size_t circle_x = rng(xsize); + size_t circle_y = rng(ysize); + size_t circle_r = rng(std::min(xsize, ysize)); + + // Rectangle with random noise + size_t rect_x0 = rng(xsize); + size_t rect_y0 = rng(ysize); + size_t rect_x1 = rng(xsize); + size_t rect_y1 = rng(ysize); + if (rect_x1 < rect_x0) std::swap(rect_x0, rect_y1); + if (rect_y1 < rect_y0) std::swap(rect_y0, rect_y1); + + size_t num_pixels = xsize * ysize; + // 16 bits per channel, big endian, 4 channels + std::vector<uint8_t> pixels(num_pixels * num_channels * 2); + // Create pixel content to test, actual content does not matter as long as it + // can be compared after roundtrip. + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + uint16_t r = r0 * (ysize - y - 1) / ysize + r1 * y / ysize; + uint16_t g = g0 * (ysize - y - 1) / ysize + g1 * y / ysize; + uint16_t b = b0 * (ysize - y - 1) / ysize + b1 * y / ysize; + uint16_t a = a0 * (ysize - y - 1) / ysize + a1 * y / ysize; + // put some shape in there for visual debugging + if ((x - circle_x) * (x - circle_x) + (y - circle_y) * (y - circle_y) < + circle_r * circle_r) { + r = (65535 - x * y) ^ seed; + g = (x << 8) + y + seed; + b = (y << 8) + x * seed; + a = 32768 + x * 256 - y; + } else if (x > rect_x0 && x < rect_x1 && y > rect_y0 && y < rect_y1) { + r = rng(65536); + g = rng(65536); + b = rng(65536); + a = rng(65536); + } + size_t i = (y * xsize + x) * 2 * num_channels; + pixels[i + 0] = (r >> 8); + pixels[i + 1] = (r & 255); + if (num_channels >= 2) { + // This may store what is called 'g' in the alpha channel of a 2-channel + // image, but that's ok since the content is arbitrary + pixels[i + 2] = (g >> 8); + pixels[i + 3] = (g & 255); + } + if (num_channels >= 3) { + pixels[i + 4] = (b >> 8); + pixels[i + 5] = (b & 255); + } + if (num_channels >= 4) { + pixels[i + 6] = (a >> 8); + pixels[i + 7] = (a & 255); + } + } + } + return pixels; +} + +TestImage::TestImage() { + SetChannels(3); + SetAllBitDepths(8); + SetColorEncoding("RGB_D65_SRG_Rel_SRG"); +} + +TestImage& TestImage::DecodeFromBytes(const std::vector<uint8_t>& bytes) { + ColorEncoding c_enc; + JXL_CHECK(c_enc.FromExternal(ppf_.color_encoding)); + extras::ColorHints color_hints; + color_hints.Add("color_space", Description(c_enc)); + JXL_CHECK(extras::DecodeBytes(Bytes(bytes), color_hints, &ppf_)); + return *this; +} + +TestImage& TestImage::ClearMetadata() { + ppf_.metadata = extras::PackedMetadata(); + return *this; +} + +TestImage& TestImage::SetDimensions(size_t xsize, size_t ysize) { + if (xsize <= ppf_.info.xsize && ysize <= ppf_.info.ysize) { + for (auto& frame : ppf_.frames) { + CropLayerInfo(xsize, ysize, &frame.frame_info.layer_info); + CropImage(xsize, ysize, &frame.color); + for (auto& ec : frame.extra_channels) { + CropImage(xsize, ysize, &ec); + } + } + } else { + JXL_CHECK(ppf_.info.xsize == 0 && ppf_.info.ysize == 0); + } + ppf_.info.xsize = xsize; + ppf_.info.ysize = ysize; + return *this; +} + +TestImage& TestImage::SetChannels(size_t num_channels) { + JXL_CHECK(ppf_.frames.empty()); + JXL_CHECK(!ppf_.preview_frame); + ppf_.info.num_color_channels = num_channels < 3 ? 1 : 3; + ppf_.info.num_extra_channels = num_channels - ppf_.info.num_color_channels; + if (ppf_.info.num_extra_channels > 0 && ppf_.info.alpha_bits == 0) { + ppf_.info.alpha_bits = ppf_.info.bits_per_sample; + ppf_.info.alpha_exponent_bits = ppf_.info.exponent_bits_per_sample; + } + ppf_.extra_channels_info.clear(); + for (size_t i = 1; i < ppf_.info.num_extra_channels; ++i) { + extras::PackedExtraChannel ec; + ec.index = i; + JxlEncoderInitExtraChannelInfo(JXL_CHANNEL_ALPHA, &ec.ec_info); + if (ec.ec_info.bits_per_sample == 0) { + ec.ec_info.bits_per_sample = ppf_.info.bits_per_sample; + ec.ec_info.exponent_bits_per_sample = ppf_.info.exponent_bits_per_sample; + } + ppf_.extra_channels_info.emplace_back(std::move(ec)); + } + format_.num_channels = std::min(static_cast<size_t>(4), num_channels); + if (ppf_.info.num_color_channels == 1 && + ppf_.color_encoding.color_space != JXL_COLOR_SPACE_GRAY) { + SetColorEncoding("Gra_D65_Rel_SRG"); + } + return *this; +} + +// Sets the same bit depth on color, alpha and all extra channels. +TestImage& TestImage::SetAllBitDepths(uint32_t bits_per_sample, + uint32_t exponent_bits_per_sample) { + ppf_.info.bits_per_sample = bits_per_sample; + ppf_.info.exponent_bits_per_sample = exponent_bits_per_sample; + if (ppf_.info.num_extra_channels > 0) { + ppf_.info.alpha_bits = bits_per_sample; + ppf_.info.alpha_exponent_bits = exponent_bits_per_sample; + } + for (size_t i = 0; i < ppf_.extra_channels_info.size(); ++i) { + extras::PackedExtraChannel& ec = ppf_.extra_channels_info[i]; + ec.ec_info.bits_per_sample = bits_per_sample; + ec.ec_info.exponent_bits_per_sample = exponent_bits_per_sample; + } + format_.data_type = DefaultDataType(ppf_.info); + return *this; +} + +TestImage& TestImage::SetDataType(JxlDataType data_type) { + format_.data_type = data_type; + return *this; +} + +TestImage& TestImage::SetEndianness(JxlEndianness endianness) { + format_.endianness = endianness; + return *this; +} + +TestImage& TestImage::SetRowAlignment(size_t align) { + format_.align = align; + return *this; +} + +TestImage& TestImage::SetColorEncoding(const std::string& description) { + JXL_CHECK(ParseDescription(description, &ppf_.color_encoding)); + ColorEncoding c_enc; + JXL_CHECK(c_enc.FromExternal(ppf_.color_encoding)); + IccBytes icc = c_enc.ICC(); + JXL_CHECK(!icc.empty()); + ppf_.icc.assign(icc.begin(), icc.end()); + return *this; +} + +TestImage& TestImage::CoalesceGIFAnimationWithAlpha() { + extras::PackedFrame canvas = ppf_.frames[0].Copy(); + JXL_CHECK(canvas.color.format.num_channels == 3); + JXL_CHECK(canvas.color.format.data_type == JXL_TYPE_UINT8); + JXL_CHECK(canvas.extra_channels.size() == 1); + for (size_t i = 1; i < ppf_.frames.size(); i++) { + const extras::PackedFrame& frame = ppf_.frames[i]; + JXL_CHECK(frame.extra_channels.size() == 1); + const JxlLayerInfo& layer_info = frame.frame_info.layer_info; + extras::PackedFrame rendered = canvas.Copy(); + uint8_t* pixels_rendered = + reinterpret_cast<uint8_t*>(rendered.color.pixels()); + const uint8_t* pixels_frame = + reinterpret_cast<const uint8_t*>(frame.color.pixels()); + uint8_t* alpha_rendered = + reinterpret_cast<uint8_t*>(rendered.extra_channels[0].pixels()); + const uint8_t* alpha_frame = + reinterpret_cast<const uint8_t*>(frame.extra_channels[0].pixels()); + for (size_t y = 0; y < frame.color.ysize; y++) { + for (size_t x = 0; x < frame.color.xsize; x++) { + size_t idx_frame = y * frame.color.xsize + x; + size_t idx_rendered = ((layer_info.crop_y0 + y) * rendered.color.xsize + + (layer_info.crop_x0 + x)); + if (alpha_frame[idx_frame] != 0) { + memcpy(&pixels_rendered[idx_rendered * 3], + &pixels_frame[idx_frame * 3], 3); + alpha_rendered[idx_rendered] = alpha_frame[idx_frame]; + } + } + } + if (layer_info.save_as_reference != 0) { + canvas = rendered.Copy(); + } + ppf_.frames[i] = std::move(rendered); + } + return *this; +} + +TestImage::Frame::Frame(TestImage* parent, bool is_preview, size_t index) + : parent_(parent), is_preview_(is_preview), index_(index) {} + +void TestImage::Frame::ZeroFill() { + memset(frame().color.pixels(), 0, frame().color.pixels_size); + for (auto& ec : frame().extra_channels) { + memset(ec.pixels(), 0, ec.pixels_size); + } +} + +void TestImage::Frame::RandomFill(uint16_t seed) { + FillPackedImage(ppf().info.bits_per_sample, seed, &frame().color); + for (size_t i = 0; i < ppf().extra_channels_info.size(); ++i) { + FillPackedImage(ppf().extra_channels_info[i].ec_info.bits_per_sample, + seed + 1 + i, &frame().extra_channels[i]); + } +} + +void TestImage::Frame::SetValue(size_t y, size_t x, size_t c, float val) { + const extras::PackedImage& color = frame().color; + JxlPixelFormat format = color.format; + JXL_CHECK(y < ppf().info.ysize); + JXL_CHECK(x < ppf().info.xsize); + JXL_CHECK(c < format.num_channels); + size_t pwidth = extras::PackedImage::BitsPerChannel(format.data_type) / 8; + size_t idx = ((y * color.xsize + x) * format.num_channels + c) * pwidth; + uint8_t* pixels = reinterpret_cast<uint8_t*>(frame().color.pixels()); + uint8_t* p = pixels + idx; + StoreValue(val, ppf().info.bits_per_sample, frame().color.format, &p); +} + +TestImage::Frame TestImage::AddFrame() { + size_t index = ppf_.frames.size(); + extras::PackedFrame frame(ppf_.info.xsize, ppf_.info.ysize, format_); + for (size_t i = 0; i < ppf_.extra_channels_info.size(); ++i) { + JxlPixelFormat ec_format = {1, format_.data_type, format_.endianness, 0}; + extras::PackedImage image(ppf_.info.xsize, ppf_.info.ysize, ec_format); + frame.extra_channels.emplace_back(std::move(image)); + } + ppf_.frames.emplace_back(std::move(frame)); + return Frame(this, false, index); +} + +TestImage::Frame TestImage::AddPreview(size_t xsize, size_t ysize) { + extras::PackedFrame frame(xsize, ysize, format_); + for (size_t i = 0; i < ppf_.extra_channels_info.size(); ++i) { + JxlPixelFormat ec_format = {1, format_.data_type, format_.endianness, 0}; + extras::PackedImage image(xsize, ysize, ec_format); + frame.extra_channels.emplace_back(std::move(image)); + } + ppf_.preview_frame = make_unique<extras::PackedFrame>(std::move(frame)); + return Frame(this, true, 0); +} + +void TestImage::CropLayerInfo(size_t xsize, size_t ysize, JxlLayerInfo* info) { + if (info->crop_x0 < static_cast<ssize_t>(xsize)) { + info->xsize = std::min<size_t>(info->xsize, xsize - info->crop_x0); + } else { + info->xsize = 0; + } + if (info->crop_y0 < static_cast<ssize_t>(ysize)) { + info->ysize = std::min<size_t>(info->ysize, ysize - info->crop_y0); + } else { + info->ysize = 0; + } +} + +void TestImage::CropImage(size_t xsize, size_t ysize, + extras::PackedImage* image) { + size_t new_stride = (image->stride / image->xsize) * xsize; + uint8_t* buf = reinterpret_cast<uint8_t*>(image->pixels()); + for (size_t y = 0; y < ysize; ++y) { + memmove(&buf[y * new_stride], &buf[y * image->stride], new_stride); + } + image->xsize = xsize; + image->ysize = ysize; + image->stride = new_stride; + image->pixels_size = ysize * new_stride; +} + +JxlDataType TestImage::DefaultDataType(const JxlBasicInfo& info) { + if (info.bits_per_sample == 16 && info.exponent_bits_per_sample == 5) { + return JXL_TYPE_FLOAT16; + } else if (info.exponent_bits_per_sample > 0 || info.bits_per_sample > 16) { + return JXL_TYPE_FLOAT; + } else if (info.bits_per_sample > 8) { + return JXL_TYPE_UINT16; + } else { + return JXL_TYPE_UINT8; + } +} + +} // namespace test +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/test_image.h b/third_party/jpeg-xl/lib/jxl/test_image.h new file mode 100644 index 0000000000..13d0806ec8 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/test_image.h @@ -0,0 +1,96 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_TEST_IMAGE_H_ +#define LIB_JXL_TEST_IMAGE_H_ + +#include <jxl/codestream_header.h> +#include <jxl/types.h> +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/span.h" + +namespace jxl { +namespace test { + +// Returns a test image with some autogenerated pixel content, using 16 bits per +// channel, big endian order, 1 to 4 channels +// The seed parameter allows to create images with different pixel content. +std::vector<uint8_t> GetSomeTestImage(size_t xsize, size_t ysize, + size_t num_channels, uint16_t seed); + +class TestImage { + public: + TestImage(); + + extras::PackedPixelFile& ppf() { return ppf_; } + + TestImage& DecodeFromBytes(const std::vector<uint8_t>& bytes); + + TestImage& ClearMetadata(); + + TestImage& SetDimensions(size_t xsize, size_t ysize); + + TestImage& SetChannels(size_t num_channels); + + // Sets the same bit depth on color, alpha and all extra channels. + TestImage& SetAllBitDepths(uint32_t bits_per_sample, + uint32_t exponent_bits_per_sample = 0); + + TestImage& SetDataType(JxlDataType data_type); + + TestImage& SetEndianness(JxlEndianness endianness); + + TestImage& SetRowAlignment(size_t align); + + TestImage& SetColorEncoding(const std::string& description); + + TestImage& CoalesceGIFAnimationWithAlpha(); + + class Frame { + public: + Frame(TestImage* parent, bool is_preview, size_t index); + + void ZeroFill(); + void RandomFill(uint16_t seed = 177); + + void SetValue(size_t y, size_t x, size_t c, float val); + + private: + extras::PackedPixelFile& ppf() const { return parent_->ppf(); } + + extras::PackedFrame& frame() { + return is_preview_ ? *ppf().preview_frame : ppf().frames[index_]; + } + + TestImage* parent_; + bool is_preview_; + size_t index_; + }; + + Frame AddFrame(); + + Frame AddPreview(size_t xsize, size_t ysize); + + private: + extras::PackedPixelFile ppf_; + JxlPixelFormat format_ = {3, JXL_TYPE_UINT8, JXL_LITTLE_ENDIAN, 0}; + + static void CropLayerInfo(size_t xsize, size_t ysize, JxlLayerInfo* info); + + static void CropImage(size_t xsize, size_t ysize, extras::PackedImage* image); + + static JxlDataType DefaultDataType(const JxlBasicInfo& info); +}; + +} // namespace test +} // namespace jxl + +#endif // LIB_JXL_TEST_IMAGE_H_ diff --git a/third_party/jpeg-xl/lib/jxl/test_utils.cc b/third_party/jpeg-xl/lib/jxl/test_utils.cc new file mode 100644 index 0000000000..451f2a0a03 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/test_utils.cc @@ -0,0 +1,805 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/test_utils.h" + +#include <jxl/cms.h> +#include <jxl/cms_interface.h> +#include <jxl/types.h> + +#include <cstddef> +#include <fstream> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "lib/extras/metrics.h" +#include "lib/extras/packed_image_convert.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/float.h" +#include "lib/jxl/base/printf_macros.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_bit_writer.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_external_image.h" +#include "lib/jxl/enc_fields.h" +#include "lib/jxl/enc_frame.h" +#include "lib/jxl/enc_icc_codec.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/frame_header.h" +#include "lib/jxl/icc_codec.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/padded_bytes.h" + +#if !defined(TEST_DATA_PATH) +#include "tools/cpp/runfiles/runfiles.h" +#endif + +namespace jxl { +namespace test { + +#if defined(TEST_DATA_PATH) +std::string GetTestDataPath(const std::string& filename) { + return std::string(TEST_DATA_PATH "/") + filename; +} +#else +using bazel::tools::cpp::runfiles::Runfiles; +const std::unique_ptr<Runfiles> kRunfiles(Runfiles::Create("")); +std::string GetTestDataPath(const std::string& filename) { + std::string root(JPEGXL_ROOT_PACKAGE "/testdata/"); + return kRunfiles->Rlocation(root + filename); +} +#endif + +std::vector<uint8_t> ReadTestData(const std::string& filename) { + std::string full_path = GetTestDataPath(filename); + fprintf(stderr, "ReadTestData %s\n", full_path.c_str()); + std::ifstream file(full_path, std::ios::binary); + std::vector<char> str((std::istreambuf_iterator<char>(file)), + std::istreambuf_iterator<char>()); + JXL_CHECK(file.good()); + const uint8_t* raw = reinterpret_cast<const uint8_t*>(str.data()); + std::vector<uint8_t> data(raw, raw + str.size()); + printf("Test data %s is %d bytes long.\n", filename.c_str(), + static_cast<int>(data.size())); + return data; +} + +void DefaultAcceptedFormats(extras::JXLDecompressParams& dparams) { + if (dparams.accepted_formats.empty()) { + for (const uint32_t num_channels : {1, 2, 3, 4}) { + dparams.accepted_formats.push_back( + {num_channels, JXL_TYPE_FLOAT, JXL_LITTLE_ENDIAN, /*align=*/0}); + } + } +} + +Status DecodeFile(extras::JXLDecompressParams dparams, + const Span<const uint8_t> file, CodecInOut* JXL_RESTRICT io, + ThreadPool* pool) { + DefaultAcceptedFormats(dparams); + SetThreadParallelRunner(dparams, pool); + extras::PackedPixelFile ppf; + JXL_RETURN_IF_ERROR(DecodeImageJXL(file.data(), file.size(), dparams, + /*decoded_bytes=*/nullptr, &ppf)); + JXL_RETURN_IF_ERROR(ConvertPackedPixelFileToCodecInOut(ppf, pool, io)); + return true; +} + +void JxlBasicInfoSetFromPixelFormat(JxlBasicInfo* basic_info, + const JxlPixelFormat* pixel_format) { + JxlEncoderInitBasicInfo(basic_info); + switch (pixel_format->data_type) { + case JXL_TYPE_FLOAT: + basic_info->bits_per_sample = 32; + basic_info->exponent_bits_per_sample = 8; + break; + case JXL_TYPE_FLOAT16: + basic_info->bits_per_sample = 16; + basic_info->exponent_bits_per_sample = 5; + break; + case JXL_TYPE_UINT8: + basic_info->bits_per_sample = 8; + basic_info->exponent_bits_per_sample = 0; + break; + case JXL_TYPE_UINT16: + basic_info->bits_per_sample = 16; + basic_info->exponent_bits_per_sample = 0; + break; + default: + JXL_ABORT("Unhandled JxlDataType"); + } + if (pixel_format->num_channels < 3) { + basic_info->num_color_channels = 1; + } else { + basic_info->num_color_channels = 3; + } + if (pixel_format->num_channels == 2 || pixel_format->num_channels == 4) { + basic_info->alpha_exponent_bits = basic_info->exponent_bits_per_sample; + basic_info->alpha_bits = basic_info->bits_per_sample; + basic_info->num_extra_channels = 1; + } else { + basic_info->alpha_exponent_bits = 0; + basic_info->alpha_bits = 0; + } +} + +ColorEncoding ColorEncodingFromDescriptor(const ColorEncodingDescriptor& desc) { + ColorEncoding c; + c.SetColorSpace(desc.color_space); + if (desc.color_space != ColorSpace::kXYB) { + JXL_CHECK(c.SetWhitePointType(desc.white_point)); + if (desc.color_space != ColorSpace::kGray) { + JXL_CHECK(c.SetPrimariesType(desc.primaries)); + } + c.Tf().SetTransferFunction(desc.tf); + } + c.SetRenderingIntent(desc.rendering_intent); + JXL_CHECK(c.CreateICC()); + return c; +} + +namespace { +void CheckSameEncodings(const std::vector<ColorEncoding>& a, + const std::vector<ColorEncoding>& b, + const std::string& check_name, + std::stringstream& failures) { + JXL_CHECK(a.size() == b.size()); + for (size_t i = 0; i < a.size(); ++i) { + if ((a[i].ICC() == b[i].ICC()) || + ((a[i].GetPrimariesType() == b[i].GetPrimariesType()) && + a[i].Tf().IsSame(b[i].Tf()))) { + continue; + } + failures << "CheckSameEncodings " << check_name << ": " << i + << "-th encoding mismatch\n"; + } +} +} // namespace + +bool Roundtrip(const CodecInOut* io, const CompressParams& cparams, + extras::JXLDecompressParams dparams, + CodecInOut* JXL_RESTRICT io2, std::stringstream& failures, + size_t* compressed_size, ThreadPool* pool) { + DefaultAcceptedFormats(dparams); + if (compressed_size) { + *compressed_size = static_cast<size_t>(-1); + } + std::vector<uint8_t> compressed; + + std::vector<ColorEncoding> original_metadata_encodings; + std::vector<ColorEncoding> original_current_encodings; + std::vector<ColorEncoding> metadata_encodings_1; + std::vector<ColorEncoding> metadata_encodings_2; + std::vector<ColorEncoding> current_encodings_2; + original_metadata_encodings.reserve(io->frames.size()); + original_current_encodings.reserve(io->frames.size()); + metadata_encodings_1.reserve(io->frames.size()); + metadata_encodings_2.reserve(io->frames.size()); + current_encodings_2.reserve(io->frames.size()); + + for (const ImageBundle& ib : io->frames) { + // Remember original encoding, will be returned by decoder. + original_metadata_encodings.push_back(ib.metadata()->color_encoding); + // c_current should not change during encoding. + original_current_encodings.push_back(ib.c_current()); + } + + JXL_CHECK(test::EncodeFile(cparams, io, &compressed, pool)); + + for (const ImageBundle& ib1 : io->frames) { + metadata_encodings_1.push_back(ib1.metadata()->color_encoding); + } + + // Should still be in the same color space after encoding. + CheckSameEncodings(metadata_encodings_1, original_metadata_encodings, + "original vs after encoding", failures); + + JXL_CHECK(DecodeFile(dparams, Bytes(compressed), io2, pool)); + JXL_CHECK(io2->frames.size() == io->frames.size()); + + for (const ImageBundle& ib2 : io2->frames) { + metadata_encodings_2.push_back(ib2.metadata()->color_encoding); + current_encodings_2.push_back(ib2.c_current()); + } + + // We always produce the original color encoding if a color transform hook is + // set. + CheckSameEncodings(current_encodings_2, original_current_encodings, + "current: original vs decoded", failures); + + // Decoder returns the originals passed to the encoder. + CheckSameEncodings(metadata_encodings_2, original_metadata_encodings, + "metadata: original vs decoded", failures); + + if (compressed_size) { + *compressed_size = compressed.size(); + } + + return failures.str().empty(); +} + +size_t Roundtrip(const extras::PackedPixelFile& ppf_in, + extras::JXLCompressParams cparams, + extras::JXLDecompressParams dparams, ThreadPool* pool, + extras::PackedPixelFile* ppf_out) { + DefaultAcceptedFormats(dparams); + SetThreadParallelRunner(cparams, pool); + SetThreadParallelRunner(dparams, pool); + std::vector<uint8_t> compressed; + JXL_CHECK(extras::EncodeImageJXL(cparams, ppf_in, /*jpeg_bytes=*/nullptr, + &compressed)); + size_t decoded_bytes = 0; + JXL_CHECK(extras::DecodeImageJXL(compressed.data(), compressed.size(), + dparams, &decoded_bytes, ppf_out)); + JXL_CHECK(decoded_bytes == compressed.size()); + return compressed.size(); +} + +std::vector<ColorEncodingDescriptor> AllEncodings() { + std::vector<ColorEncodingDescriptor> all_encodings; + all_encodings.reserve(300); + + for (ColorSpace cs : Values<ColorSpace>()) { + if (cs == ColorSpace::kUnknown || cs == ColorSpace::kXYB || + cs == ColorSpace::kGray) { + continue; + } + + for (WhitePoint wp : Values<WhitePoint>()) { + if (wp == WhitePoint::kCustom) continue; + for (Primaries primaries : Values<Primaries>()) { + if (primaries == Primaries::kCustom) continue; + for (TransferFunction tf : Values<TransferFunction>()) { + if (tf == TransferFunction::kUnknown) continue; + for (RenderingIntent ri : Values<RenderingIntent>()) { + ColorEncodingDescriptor cdesc; + cdesc.color_space = cs; + cdesc.white_point = wp; + cdesc.primaries = primaries; + cdesc.tf = tf; + cdesc.rendering_intent = ri; + all_encodings.push_back(cdesc); + } + } + } + } + } + + return all_encodings; +} + +jxl::CodecInOut SomeTestImageToCodecInOut(const std::vector<uint8_t>& buf, + size_t num_channels, size_t xsize, + size_t ysize) { + jxl::CodecInOut io; + io.SetSize(xsize, ysize); + io.metadata.m.SetAlphaBits(16); + io.metadata.m.color_encoding = jxl::ColorEncoding::SRGB( + /*is_gray=*/num_channels == 1 || num_channels == 2); + JxlPixelFormat format = {static_cast<uint32_t>(num_channels), JXL_TYPE_UINT16, + JXL_BIG_ENDIAN, 0}; + JXL_CHECK(ConvertFromExternal( + jxl::Bytes(buf.data(), buf.size()), xsize, ysize, + jxl::ColorEncoding::SRGB(/*is_gray=*/num_channels < 3), + /*bits_per_sample=*/16, format, + /*pool=*/nullptr, + /*ib=*/&io.Main())); + return io; +} + +bool Near(double expected, double value, double max_dist) { + double dist = expected > value ? expected - value : value - expected; + return dist <= max_dist; +} + +float LoadLEFloat16(const uint8_t* p) { + uint16_t bits16 = LoadLE16(p); + return LoadFloat16(bits16); +} + +float LoadBEFloat16(const uint8_t* p) { + uint16_t bits16 = LoadBE16(p); + return LoadFloat16(bits16); +} + +size_t GetPrecision(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_UINT8: + return 8; + case JXL_TYPE_UINT16: + return 16; + case JXL_TYPE_FLOAT: + // Floating point mantissa precision + return 24; + case JXL_TYPE_FLOAT16: + return 11; + default: + JXL_ABORT("Unhandled JxlDataType"); + } +} + +size_t GetDataBits(JxlDataType data_type) { + switch (data_type) { + case JXL_TYPE_UINT8: + return 8; + case JXL_TYPE_UINT16: + return 16; + case JXL_TYPE_FLOAT: + return 32; + case JXL_TYPE_FLOAT16: + return 16; + default: + JXL_ABORT("Unhandled JxlDataType"); + } +} + +std::vector<double> ConvertToRGBA32(const uint8_t* pixels, size_t xsize, + size_t ysize, const JxlPixelFormat& format, + double factor) { + std::vector<double> result(xsize * ysize * 4); + size_t num_channels = format.num_channels; + bool gray = num_channels == 1 || num_channels == 2; + bool alpha = num_channels == 2 || num_channels == 4; + JxlEndianness endianness = format.endianness; + // Compute actual type: + if (endianness == JXL_NATIVE_ENDIAN) { + endianness = IsLittleEndian() ? JXL_LITTLE_ENDIAN : JXL_BIG_ENDIAN; + } + + size_t stride = + xsize * jxl::DivCeil(GetDataBits(format.data_type) * num_channels, + jxl::kBitsPerByte); + if (format.align > 1) stride = jxl::RoundUpTo(stride, format.align); + + if (format.data_type == JXL_TYPE_UINT8) { + // Multiplier to bring to 0-1.0 range + double mul = factor > 0.0 ? factor : 1.0 / 255.0; + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + size_t j = (y * xsize + x) * 4; + size_t i = y * stride + x * num_channels; + double r = pixels[i]; + double g = gray ? r : pixels[i + 1]; + double b = gray ? r : pixels[i + 2]; + double a = alpha ? pixels[i + num_channels - 1] : 255; + result[j + 0] = r * mul; + result[j + 1] = g * mul; + result[j + 2] = b * mul; + result[j + 3] = a * mul; + } + } + } else if (format.data_type == JXL_TYPE_UINT16) { + JXL_ASSERT(endianness != JXL_NATIVE_ENDIAN); + // Multiplier to bring to 0-1.0 range + double mul = factor > 0.0 ? factor : 1.0 / 65535.0; + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + size_t j = (y * xsize + x) * 4; + size_t i = y * stride + x * num_channels * 2; + double r, g, b, a; + if (endianness == JXL_BIG_ENDIAN) { + r = (pixels[i + 0] << 8) + pixels[i + 1]; + g = gray ? r : (pixels[i + 2] << 8) + pixels[i + 3]; + b = gray ? r : (pixels[i + 4] << 8) + pixels[i + 5]; + a = alpha ? (pixels[i + num_channels * 2 - 2] << 8) + + pixels[i + num_channels * 2 - 1] + : 65535; + } else { + r = (pixels[i + 1] << 8) + pixels[i + 0]; + g = gray ? r : (pixels[i + 3] << 8) + pixels[i + 2]; + b = gray ? r : (pixels[i + 5] << 8) + pixels[i + 4]; + a = alpha ? (pixels[i + num_channels * 2 - 1] << 8) + + pixels[i + num_channels * 2 - 2] + : 65535; + } + result[j + 0] = r * mul; + result[j + 1] = g * mul; + result[j + 2] = b * mul; + result[j + 3] = a * mul; + } + } + } else if (format.data_type == JXL_TYPE_FLOAT) { + JXL_ASSERT(endianness != JXL_NATIVE_ENDIAN); + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + size_t j = (y * xsize + x) * 4; + size_t i = y * stride + x * num_channels * 4; + double r, g, b, a; + if (endianness == JXL_BIG_ENDIAN) { + r = LoadBEFloat(pixels + i); + g = gray ? r : LoadBEFloat(pixels + i + 4); + b = gray ? r : LoadBEFloat(pixels + i + 8); + a = alpha ? LoadBEFloat(pixels + i + num_channels * 4 - 4) : 1.0; + } else { + r = LoadLEFloat(pixels + i); + g = gray ? r : LoadLEFloat(pixels + i + 4); + b = gray ? r : LoadLEFloat(pixels + i + 8); + a = alpha ? LoadLEFloat(pixels + i + num_channels * 4 - 4) : 1.0; + } + result[j + 0] = r; + result[j + 1] = g; + result[j + 2] = b; + result[j + 3] = a; + } + } + } else if (format.data_type == JXL_TYPE_FLOAT16) { + JXL_ASSERT(endianness != JXL_NATIVE_ENDIAN); + for (size_t y = 0; y < ysize; ++y) { + for (size_t x = 0; x < xsize; ++x) { + size_t j = (y * xsize + x) * 4; + size_t i = y * stride + x * num_channels * 2; + double r, g, b, a; + if (endianness == JXL_BIG_ENDIAN) { + r = LoadBEFloat16(pixels + i); + g = gray ? r : LoadBEFloat16(pixels + i + 2); + b = gray ? r : LoadBEFloat16(pixels + i + 4); + a = alpha ? LoadBEFloat16(pixels + i + num_channels * 2 - 2) : 1.0; + } else { + r = LoadLEFloat16(pixels + i); + g = gray ? r : LoadLEFloat16(pixels + i + 2); + b = gray ? r : LoadLEFloat16(pixels + i + 4); + a = alpha ? LoadLEFloat16(pixels + i + num_channels * 2 - 2) : 1.0; + } + result[j + 0] = r; + result[j + 1] = g; + result[j + 2] = b; + result[j + 3] = a; + } + } + } else { + JXL_ASSERT(false); // Unsupported type + } + return result; +} + +size_t ComparePixels(const uint8_t* a, const uint8_t* b, size_t xsize, + size_t ysize, const JxlPixelFormat& format_a, + const JxlPixelFormat& format_b, + double threshold_multiplier) { + // Convert both images to equal full precision for comparison. + std::vector<double> a_full = ConvertToRGBA32(a, xsize, ysize, format_a); + std::vector<double> b_full = ConvertToRGBA32(b, xsize, ysize, format_b); + bool gray_a = format_a.num_channels < 3; + bool gray_b = format_b.num_channels < 3; + bool alpha_a = !(format_a.num_channels & 1); + bool alpha_b = !(format_b.num_channels & 1); + size_t bits_a = GetPrecision(format_a.data_type); + size_t bits_b = GetPrecision(format_b.data_type); + size_t bits = std::min(bits_a, bits_b); + // How much distance is allowed in case of pixels with lower bit depths, given + // that the double precision float images use range 0-1.0. + // E.g. in case of 1-bit this is 0.5 since 0.499 must map to 0 and 0.501 must + // map to 1. + double precision = 0.5 * threshold_multiplier / ((1ull << bits) - 1ull); + if (format_a.data_type == JXL_TYPE_FLOAT16 || + format_b.data_type == JXL_TYPE_FLOAT16) { + // Lower the precision for float16, because it currently looks like the + // scalar and wasm implementations of hwy have 1 less bit of precision + // than the x86 implementations. + // TODO(lode): Set the required precision back to 11 bits when possible. + precision = 0.5 * threshold_multiplier / ((1ull << (bits - 1)) - 1ull); + } + if (format_b.data_type == JXL_TYPE_UINT8) { + // Increase the threshold by the maximum difference introduced by dithering. + precision += 63.0 / 128.0; + } + size_t numdiff = 0; + for (size_t y = 0; y < ysize; y++) { + for (size_t x = 0; x < xsize; x++) { + size_t i = (y * xsize + x) * 4; + bool ok = true; + if (gray_a || gray_b) { + if (!Near(a_full[i + 0], b_full[i + 0], precision)) ok = false; + // If the input was grayscale and the output not, then the output must + // have all channels equal. + if (gray_a && b_full[i + 0] != b_full[i + 1] && + b_full[i + 2] != b_full[i + 2]) { + ok = false; + } + } else { + if (!Near(a_full[i + 0], b_full[i + 0], precision) || + !Near(a_full[i + 1], b_full[i + 1], precision) || + !Near(a_full[i + 2], b_full[i + 2], precision)) { + ok = false; + } + } + if (alpha_a && alpha_b) { + if (!Near(a_full[i + 3], b_full[i + 3], precision)) ok = false; + } else { + // If the input had no alpha channel, the output should be opaque + // after roundtrip. + if (alpha_b && !Near(1.0, b_full[i + 3], precision)) ok = false; + } + if (!ok) numdiff++; + } + } + return numdiff; +} + +double DistanceRMS(const uint8_t* a, const uint8_t* b, size_t xsize, + size_t ysize, const JxlPixelFormat& format) { + // Convert both images to equal full precision for comparison. + std::vector<double> a_full = ConvertToRGBA32(a, xsize, ysize, format); + std::vector<double> b_full = ConvertToRGBA32(b, xsize, ysize, format); + double sum = 0.0; + for (size_t y = 0; y < ysize; y++) { + double row_sum = 0.0; + for (size_t x = 0; x < xsize; x++) { + size_t i = (y * xsize + x) * 4; + for (size_t c = 0; c < format.num_channels; ++c) { + double diff = a_full[i + c] - b_full[i + c]; + row_sum += diff * diff; + } + } + sum += row_sum; + } + sum /= (xsize * ysize); + return sqrt(sum); +} + +float ButteraugliDistance(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b, ThreadPool* pool) { + CodecInOut io0; + JXL_CHECK(ConvertPackedPixelFileToCodecInOut(a, pool, &io0)); + CodecInOut io1; + JXL_CHECK(ConvertPackedPixelFileToCodecInOut(b, pool, &io1)); + // TODO(eustas): simplify? + return ButteraugliDistance(io0.frames, io1.frames, ButteraugliParams(), + *JxlGetDefaultCms(), + /*distmap=*/nullptr, pool); +} + +float Butteraugli3Norm(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b, ThreadPool* pool) { + CodecInOut io0; + JXL_CHECK(ConvertPackedPixelFileToCodecInOut(a, pool, &io0)); + CodecInOut io1; + JXL_CHECK(ConvertPackedPixelFileToCodecInOut(b, pool, &io1)); + ButteraugliParams ba; + ImageF distmap; + ButteraugliDistance(io0.frames, io1.frames, ba, *JxlGetDefaultCms(), &distmap, + pool); + return ComputeDistanceP(distmap, ba, 3); +} + +float ComputeDistance2(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b) { + CodecInOut io0; + JXL_CHECK(ConvertPackedPixelFileToCodecInOut(a, nullptr, &io0)); + CodecInOut io1; + JXL_CHECK(ConvertPackedPixelFileToCodecInOut(b, nullptr, &io1)); + return ComputeDistance2(io0.Main(), io1.Main(), *JxlGetDefaultCms()); +} + +float ComputePSNR(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b) { + CodecInOut io0; + JXL_CHECK(ConvertPackedPixelFileToCodecInOut(a, nullptr, &io0)); + CodecInOut io1; + JXL_CHECK(ConvertPackedPixelFileToCodecInOut(b, nullptr, &io1)); + return ComputePSNR(io0.Main(), io1.Main(), *JxlGetDefaultCms()); +} + +bool SameAlpha(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b) { + JXL_CHECK(a.info.xsize == b.info.xsize); + JXL_CHECK(a.info.ysize == b.info.ysize); + JXL_CHECK(a.info.alpha_bits == b.info.alpha_bits); + JXL_CHECK(a.info.alpha_exponent_bits == b.info.alpha_exponent_bits); + JXL_CHECK(a.info.alpha_bits > 0); + JXL_CHECK(a.frames.size() == b.frames.size()); + for (size_t i = 0; i < a.frames.size(); ++i) { + const extras::PackedImage& color_a = a.frames[i].color; + const extras::PackedImage& color_b = b.frames[i].color; + JXL_CHECK(color_a.format.num_channels == color_b.format.num_channels); + JXL_CHECK(color_a.format.data_type == color_b.format.data_type); + JXL_CHECK(color_a.format.endianness == color_b.format.endianness); + JXL_CHECK(color_a.pixels_size == color_b.pixels_size); + size_t pwidth = + extras::PackedImage::BitsPerChannel(color_a.format.data_type) / 8; + size_t num_color = color_a.format.num_channels < 3 ? 1 : 3; + const uint8_t* p_a = reinterpret_cast<const uint8_t*>(color_a.pixels()); + const uint8_t* p_b = reinterpret_cast<const uint8_t*>(color_b.pixels()); + for (size_t y = 0; y < a.info.ysize; ++y) { + for (size_t x = 0; x < a.info.xsize; ++x) { + size_t idx = + ((y * a.info.xsize + x) * color_a.format.num_channels + num_color) * + pwidth; + if (memcmp(&p_a[idx], &p_b[idx], pwidth) != 0) { + return false; + } + } + } + } + return true; +} + +bool SamePixels(const extras::PackedImage& a, const extras::PackedImage& b) { + JXL_CHECK(a.xsize == b.xsize); + JXL_CHECK(a.ysize == b.ysize); + JXL_CHECK(a.format.num_channels == b.format.num_channels); + JXL_CHECK(a.format.data_type == b.format.data_type); + JXL_CHECK(a.format.endianness == b.format.endianness); + JXL_CHECK(a.pixels_size == b.pixels_size); + const uint8_t* p_a = reinterpret_cast<const uint8_t*>(a.pixels()); + const uint8_t* p_b = reinterpret_cast<const uint8_t*>(b.pixels()); + for (size_t y = 0; y < a.ysize; ++y) { + for (size_t x = 0; x < a.xsize; ++x) { + size_t idx = (y * a.xsize + x) * a.pixel_stride(); + if (memcmp(&p_a[idx], &p_b[idx], a.pixel_stride()) != 0) { + printf("Mismatch at row %" PRIuS " col %" PRIuS "\n", y, x); + printf(" a: "); + for (size_t j = 0; j < a.pixel_stride(); ++j) { + printf(" %3u", p_a[idx + j]); + } + printf("\n b: "); + for (size_t j = 0; j < a.pixel_stride(); ++j) { + printf(" %3u", p_b[idx + j]); + } + printf("\n"); + return false; + } + } + } + return true; +} + +bool SamePixels(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b) { + JXL_CHECK(a.info.xsize == b.info.xsize); + JXL_CHECK(a.info.ysize == b.info.ysize); + JXL_CHECK(a.info.bits_per_sample == b.info.bits_per_sample); + JXL_CHECK(a.info.exponent_bits_per_sample == b.info.exponent_bits_per_sample); + JXL_CHECK(a.frames.size() == b.frames.size()); + for (size_t i = 0; i < a.frames.size(); ++i) { + const auto& frame_a = a.frames[i]; + const auto& frame_b = b.frames[i]; + if (!SamePixels(frame_a.color, frame_b.color)) { + return false; + } + JXL_CHECK(frame_a.extra_channels.size() == frame_b.extra_channels.size()); + for (size_t j = 0; j < frame_a.extra_channels.size(); ++j) { + if (!SamePixels(frame_a.extra_channels[i], frame_b.extra_channels[i])) { + return false; + } + } + } + return true; +} + +Status ReadICC(BitReader* JXL_RESTRICT reader, + std::vector<uint8_t>* JXL_RESTRICT icc, size_t output_limit) { + icc->clear(); + ICCReader icc_reader; + PaddedBytes icc_buffer; + JXL_RETURN_IF_ERROR(icc_reader.Init(reader, output_limit)); + JXL_RETURN_IF_ERROR(icc_reader.Process(reader, &icc_buffer)); + Bytes(icc_buffer).AppendTo(icc); + return true; +} + +namespace { // For EncodeFile +Status PrepareCodecMetadataFromIO(const CompressParams& cparams, + const CodecInOut* io, + CodecMetadata* metadata) { + *metadata = io->metadata; + size_t ups = 1; + if (cparams.already_downsampled) ups = cparams.resampling; + + JXL_RETURN_IF_ERROR(metadata->size.Set(io->xsize() * ups, io->ysize() * ups)); + + // Keep ICC profile in lossless modes because a reconstructed profile may be + // slightly different (quantization). + // Also keep ICC in JPEG reconstruction mode as we need byte-exact profiles. + if (!cparams.IsLossless() && !io->Main().IsJPEG() && cparams.cms_set) { + metadata->m.color_encoding.DecideIfWantICC(cparams.cms); + } + + metadata->m.xyb_encoded = + cparams.color_transform == ColorTransform::kXYB ? true : false; + + // TODO(firsching): move this EncodeFile to test_utils / re-implement this + // using API functions + return true; +} + +Status EncodePreview(const CompressParams& cparams, const ImageBundle& ib, + const CodecMetadata* metadata, const JxlCmsInterface& cms, + ThreadPool* pool, BitWriter* JXL_RESTRICT writer) { + BitWriter preview_writer; + // TODO(janwas): also support generating preview by downsampling + if (ib.HasColor()) { + AuxOut aux_out; + // TODO(lode): check if we want all extra channels and matching xyb_encoded + // for the preview, such that using the main ImageMetadata object for + // encoding this frame is warrented. + FrameInfo frame_info; + frame_info.is_preview = true; + JXL_RETURN_IF_ERROR(EncodeFrame(cparams, frame_info, metadata, ib, cms, + pool, &preview_writer, &aux_out)); + preview_writer.ZeroPadToByte(); + } + + if (preview_writer.BitsWritten() != 0) { + writer->ZeroPadToByte(); + writer->AppendByteAligned(preview_writer); + } + + return true; +} + +} // namespace + +Status EncodeFile(const CompressParams& params, const CodecInOut* io, + std::vector<uint8_t>* compressed, ThreadPool* pool) { + compressed->clear(); + const JxlCmsInterface& cms = *JxlGetDefaultCms(); + io->CheckMetadata(); + BitWriter writer; + + CompressParams cparams = params; + if (io->Main().color_transform != ColorTransform::kNone) { + // Set the color transform to YCbCr or XYB if the original image is such. + cparams.color_transform = io->Main().color_transform; + } + + JXL_RETURN_IF_ERROR(ParamsPostInit(&cparams)); + + std::unique_ptr<CodecMetadata> metadata = jxl::make_unique<CodecMetadata>(); + JXL_RETURN_IF_ERROR(PrepareCodecMetadataFromIO(cparams, io, metadata.get())); + JXL_RETURN_IF_ERROR( + WriteCodestreamHeaders(metadata.get(), &writer, /*aux_out*/ nullptr)); + + // Only send ICC (at least several hundred bytes) if fields aren't enough. + if (metadata->m.color_encoding.WantICC()) { + JXL_RETURN_IF_ERROR(WriteICC(metadata->m.color_encoding.ICC(), &writer, + kLayerHeader, /* aux_out */ nullptr)); + } + + if (metadata->m.have_preview) { + JXL_RETURN_IF_ERROR(EncodePreview(cparams, io->preview_frame, + metadata.get(), cms, pool, &writer)); + } + + // Each frame should start on byte boundaries. + BitWriter::Allotment allotment(&writer, 8); + writer.ZeroPadToByte(); + allotment.ReclaimAndCharge(&writer, kLayerHeader, /* aux_out */ nullptr); + + for (size_t i = 0; i < io->frames.size(); i++) { + FrameInfo info; + info.is_last = i == io->frames.size() - 1; + if (io->frames[i].use_for_next_frame) { + info.save_as_reference = 1; + } + JXL_RETURN_IF_ERROR(EncodeFrame(cparams, info, metadata.get(), + io->frames[i], cms, pool, &writer, + /* aux_out */ nullptr)); + } + + PaddedBytes output = std::move(writer).TakeBytes(); + Bytes(output).AppendTo(compressed); + return true; +} + +} // namespace test + +bool operator==(const jxl::Bytes& a, const jxl::Bytes& b) { + if (a.size() != b.size()) return false; + if (memcmp(a.data(), b.data(), a.size()) != 0) return false; + return true; +} + +// Allow using EXPECT_EQ on jxl::Bytes +bool operator!=(const jxl::Bytes& a, const jxl::Bytes& b) { return !(a == b); } + +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/test_utils.h b/third_party/jpeg-xl/lib/jxl/test_utils.h new file mode 100644 index 0000000000..6734380bf5 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/test_utils.h @@ -0,0 +1,200 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_TEST_UTILS_H_ +#define LIB_JXL_TEST_UTILS_H_ + +// TODO(eustas): reduce includes (move to .cc) + +// Macros and functions useful for tests. + +#include <jxl/codestream_header.h> +#include <jxl/thread_parallel_runner_cxx.h> + +#include <cstddef> +#include <cstdint> +#include <ostream> +#include <vector> + +#include "lib/extras/dec/decode.h" +#include "lib/extras/dec/jxl.h" +#include "lib/extras/enc/jxl.h" +#include "lib/extras/packed_image.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/codec_in_out.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/enc_params.h" + +#define TEST_LIBJPEG_SUPPORT() \ + do { \ + if (!jxl::extras::CanDecode(jxl::extras::Codec::kJPG)) { \ + fprintf(stderr, "Skipping test because of missing libjpeg codec.\n"); \ + return; \ + } \ + } while (0) + +namespace jxl { + +struct AuxOut; +class CodecInOut; +class PaddedBytes; +struct PassesEncoderState; +class ThreadPool; + +namespace test { + +std::string GetTestDataPath(const std::string& filename); +std::vector<uint8_t> ReadTestData(const std::string& filename); + +void JxlBasicInfoSetFromPixelFormat(JxlBasicInfo* basic_info, + const JxlPixelFormat* pixel_format); + +void DefaultAcceptedFormats(extras::JXLDecompressParams& dparams); + +template <typename Params> +void SetThreadParallelRunner(Params params, ThreadPool* pool) { + if (pool && !params.runner_opaque) { + params.runner = pool->runner(); + params.runner_opaque = pool->runner_opaque(); + } +} + +Status DecodeFile(extras::JXLDecompressParams dparams, + const Span<const uint8_t> file, CodecInOut* JXL_RESTRICT io, + ThreadPool* pool = nullptr); + +bool Roundtrip(const CodecInOut* io, const CompressParams& cparams, + extras::JXLDecompressParams dparams, + CodecInOut* JXL_RESTRICT io2, std::stringstream& failures, + size_t* compressed_size = nullptr, ThreadPool* pool = nullptr); + +// Returns compressed size [bytes]. +size_t Roundtrip(const extras::PackedPixelFile& ppf_in, + extras::JXLCompressParams cparams, + extras::JXLDecompressParams dparams, ThreadPool* pool, + extras::PackedPixelFile* ppf_out); + +// A POD descriptor of a ColorEncoding. Only used in tests as the return value +// of AllEncodings(). +struct ColorEncodingDescriptor { + ColorSpace color_space; + WhitePoint white_point; + Primaries primaries; + TransferFunction tf; + RenderingIntent rendering_intent; +}; + +ColorEncoding ColorEncodingFromDescriptor(const ColorEncodingDescriptor& desc); + +// Define the operator<< for tests. +static inline ::std::ostream& operator<<(::std::ostream& os, + const ColorEncodingDescriptor& c) { + return os << "ColorEncoding/" << Description(ColorEncodingFromDescriptor(c)); +} + +// Returns ColorEncodingDescriptors, which are only used in tests. To obtain a +// ColorEncoding object call ColorEncodingFromDescriptor and then call +// ColorEncoding::CreateProfile() on that object to generate a profile. +std::vector<ColorEncodingDescriptor> AllEncodings(); + +// Returns a CodecInOut based on the buf, xsize, ysize, and the assumption +// that the buffer was created using `GetSomeTestImage`. +jxl::CodecInOut SomeTestImageToCodecInOut(const std::vector<uint8_t>& buf, + size_t num_channels, size_t xsize, + size_t ysize); + +bool Near(double expected, double value, double max_dist); + +float LoadLEFloat16(const uint8_t* p); + +float LoadBEFloat16(const uint8_t* p); + +size_t GetPrecision(JxlDataType data_type); + +size_t GetDataBits(JxlDataType data_type); + +// Procedure to convert pixels to double precision, not efficient, but +// well-controlled for testing. It uses double, to be able to represent all +// precisions needed for the maximum data types the API supports: uint32_t +// integers, and, single precision float. The values are in range 0-1 for SDR. +std::vector<double> ConvertToRGBA32(const uint8_t* pixels, size_t xsize, + size_t ysize, const JxlPixelFormat& format, + double factor = 0.0); + +// Returns amount of pixels which differ between the two pictures. Image b is +// the image after roundtrip after roundtrip, image a before roundtrip. There +// are more strict requirements for the alpha channel and grayscale values of +// the output image. +size_t ComparePixels(const uint8_t* a, const uint8_t* b, size_t xsize, + size_t ysize, const JxlPixelFormat& format_a, + const JxlPixelFormat& format_b, + double threshold_multiplier = 1.0); + +double DistanceRMS(const uint8_t* a, const uint8_t* b, size_t xsize, + size_t ysize, const JxlPixelFormat& format); + +float ButteraugliDistance(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b, + ThreadPool* pool = nullptr); + +float Butteraugli3Norm(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b, + ThreadPool* pool = nullptr); + +float ComputeDistance2(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b); + +float ComputePSNR(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b); + +bool SameAlpha(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b); + +bool SamePixels(const extras::PackedImage& a, const extras::PackedImage& b); + +bool SamePixels(const extras::PackedPixelFile& a, + const extras::PackedPixelFile& b); + +class ThreadPoolForTests { + public: + explicit ThreadPoolForTests(int num_threads) { + runner_ = + JxlThreadParallelRunnerMake(/* memory_manager */ nullptr, num_threads); + pool_ = + jxl::make_unique<ThreadPool>(JxlThreadParallelRunner, runner_.get()); + } + ThreadPoolForTests(const ThreadPoolForTests&) = delete; + ThreadPoolForTests& operator&(const ThreadPoolForTests&) = delete; + ThreadPool* operator&() { return pool_.get(); } + + private: + JxlThreadParallelRunnerPtr runner_; + std::unique_ptr<ThreadPool> pool_; +}; + +// `icc` may be empty afterwards - if so, call CreateProfile. Does not append, +// clears any original data that was in icc. +// If `output_limit` is not 0, then returns error if resulting profile would be +// longer than `output_limit` +Status ReadICC(BitReader* JXL_RESTRICT reader, + std::vector<uint8_t>* JXL_RESTRICT icc, size_t output_limit = 0); + +// Compresses pixels from `io` (given in any ColorEncoding). +// `io->metadata.m.original` must be set. +Status EncodeFile(const CompressParams& params, const CodecInOut* io, + std::vector<uint8_t>* compressed, ThreadPool* pool = nullptr); + +} // namespace test + +bool operator==(const jxl::Bytes& a, const jxl::Bytes& b); + +// Allow using EXPECT_EQ on jxl::Bytes +bool operator!=(const jxl::Bytes& a, const jxl::Bytes& b); + +} // namespace jxl + +#endif // LIB_JXL_TEST_UTILS_H_ diff --git a/third_party/jpeg-xl/lib/jxl/testing.h b/third_party/jpeg-xl/lib/jxl/testing.h new file mode 100644 index 0000000000..5344399c4c --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/testing.h @@ -0,0 +1,79 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_TESTING_H_ +#define LIB_JXL_TESTING_H_ + +// GTest/GMock specific macros / wrappers. + +// gmock unconditionally redefines those macros (to wrong values). +// Lets include it only here and mitigate the problem. +#pragma push_macro("PRIdS") +#pragma push_macro("PRIuS") +#include "gmock/gmock.h" +#pragma pop_macro("PRIuS") +#pragma pop_macro("PRIdS") + +#include "gtest/gtest.h" +// JPEGXL_ENABLE_BOXES, JPEGXL_ENABLE_TRANSCODE_JPEG +#include "lib/jxl/common.h" + +#ifdef JXL_DISABLE_SLOW_TESTS +#define JXL_SLOW_TEST(X) DISABLED_##X +#else +#define JXL_SLOW_TEST(X) X +#endif // JXL_DISABLE_SLOW_TESTS + +#if JPEGXL_ENABLE_TRANSCODE_JPEG +#define JXL_TRANSCODE_JPEG_TEST(X) X +#else +#define JXL_TRANSCODE_JPEG_TEST(X) DISABLED_##X +#endif // JPEGXL_ENABLE_TRANSCODE_JPEG + +#if JPEGXL_ENABLE_BOXES +#define JXL_BOXES_TEST(X) X +#else +#define JXL_BOXES_TEST(X) DISABLED_##X +#endif // JPEGXL_ENABLE_BOXES + +#ifdef THREAD_SANITIZER +#define JXL_TSAN_SLOW_TEST(X) DISABLED_##X +#else +#define JXL_TSAN_SLOW_TEST(X) X +#endif // THREAD_SANITIZER + +#if defined(__x86_64__) +#define JXL_X86_64_TEST(X) X +#else +#define JXL_X86_64_TEST(X) DISABLED_##X +#endif // defined(__x86_64__) + +// googletest before 1.10 didn't define INSTANTIATE_TEST_SUITE_P() but instead +// used INSTANTIATE_TEST_CASE_P which is now deprecated. +#ifdef INSTANTIATE_TEST_SUITE_P +#define JXL_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_SUITE_P +#else +#define JXL_GTEST_INSTANTIATE_TEST_SUITE_P INSTANTIATE_TEST_CASE_P +#endif + +// Ensures that we don't make our test bounds too lax, effectively disabling the +// tests. +MATCHER_P(IsSlightlyBelow, max, "") { + return max * 0.75 <= arg && arg <= max * 1.0; +} + +#define JXL_EXPECT_OK(F) \ + { \ + std::stringstream _; \ + EXPECT_TRUE(F) << _.str(); \ + } + +#define JXL_ASSERT_OK(F) \ + { \ + std::stringstream _; \ + ASSERT_TRUE(F) << _.str(); \ + } + +#endif // LIB_JXL_TESTING_H_ diff --git a/third_party/jpeg-xl/lib/jxl/tf_gbench.cc b/third_party/jpeg-xl/lib/jxl/tf_gbench.cc new file mode 100644 index 0000000000..e93a936c90 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/tf_gbench.cc @@ -0,0 +1,145 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "benchmark/benchmark.h" +#include "lib/jxl/image_ops.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/tf_gbench.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/cms/transfer_functions-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +#define RUN_BENCHMARK(F) \ + constexpr size_t kNum = 1 << 12; \ + HWY_FULL(float) d; \ + /* Three parallel runs, as this will run on R, G and B. */ \ + auto sum1 = Zero(d); \ + auto sum2 = Zero(d); \ + auto sum3 = Zero(d); \ + for (auto _ : state) { \ + auto x = Set(d, 1e-5); \ + auto v1 = Set(d, 1e-5); \ + auto v2 = Set(d, 1.1e-5); \ + auto v3 = Set(d, 1.2e-5); \ + for (size_t i = 0; i < kNum; i++) { \ + sum1 += F(d, v1); \ + sum2 += F(d, v2); \ + sum3 += F(d, v3); \ + v1 += x; \ + v2 += x; \ + v3 += x; \ + } \ + } \ + /* floats per second */ \ + state.SetItemsProcessed(kNum* state.iterations() * Lanes(d) * 3); \ + benchmark::DoNotOptimize(sum1 + sum2 + sum3); + +#define RUN_BENCHMARK_SCALAR(F, I) \ + constexpr size_t kNum = 1 << 12; \ + /* Three parallel runs, as this will run on R, G and B. */ \ + float sum1 = 0, sum2 = 0, sum3 = 0; \ + for (auto _ : state) { \ + float x = 1e-5; \ + float v1 = 1e-5; \ + float v2 = 1.1e-5; \ + float v3 = 1.2e-5; \ + for (size_t i = 0; i < kNum; i++) { \ + sum1 += F(I, v1); \ + sum2 += F(I, v2); \ + sum3 += F(I, v3); \ + v1 += x; \ + v2 += x; \ + v3 += x; \ + } \ + } \ + /* floats per second */ \ + state.SetItemsProcessed(kNum* state.iterations() * 3); \ + benchmark::DoNotOptimize(sum1 + sum2 + sum3); + +HWY_NOINLINE void BM_FastSRGB(benchmark::State& state) { + RUN_BENCHMARK(FastLinearToSRGB); +} + +HWY_NOINLINE void BM_TFSRGB(benchmark::State& state) { + RUN_BENCHMARK(TF_SRGB().EncodedFromDisplay); +} + +HWY_NOINLINE void BM_PQDFE(benchmark::State& state) { + TF_PQ tf_pq(10000.0); + RUN_BENCHMARK(tf_pq.DisplayFromEncoded); +} + +HWY_NOINLINE void BM_PQEFD(benchmark::State& state) { + TF_PQ tf_pq(10000.0); + RUN_BENCHMARK(tf_pq.EncodedFromDisplay); +} + +HWY_NOINLINE void BM_PQSlowDFE(benchmark::State& state) { + RUN_BENCHMARK_SCALAR(TF_PQ_Base::DisplayFromEncoded, 10000.0); +} + +HWY_NOINLINE void BM_PQSlowEFD(benchmark::State& state) { + RUN_BENCHMARK_SCALAR(TF_PQ_Base::EncodedFromDisplay, 10000.0); +} +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +namespace { + +HWY_EXPORT(BM_FastSRGB); +HWY_EXPORT(BM_TFSRGB); +HWY_EXPORT(BM_PQDFE); +HWY_EXPORT(BM_PQEFD); +HWY_EXPORT(BM_PQSlowDFE); +HWY_EXPORT(BM_PQSlowEFD); + +float SRGB_pow(float _, float x) { + return x < 0.0031308f ? 12.92f * x : 1.055f * powf(x, 1.0f / 2.4f) - 0.055f; +} + +void BM_FastSRGB(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_FastSRGB)(state); +} +void BM_TFSRGB(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_TFSRGB)(state); +} +void BM_PQDFE(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_PQDFE)(state); +} +void BM_PQEFD(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_PQEFD)(state); +} +void BM_PQSlowDFE(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_PQSlowDFE)(state); +} +void BM_PQSlowEFD(benchmark::State& state) { + HWY_DYNAMIC_DISPATCH(BM_PQSlowEFD)(state); +} + +void BM_SRGB_pow(benchmark::State& state) { RUN_BENCHMARK_SCALAR(SRGB_pow, 0); } + +BENCHMARK(BM_FastSRGB); +BENCHMARK(BM_TFSRGB); +BENCHMARK(BM_SRGB_pow); +BENCHMARK(BM_PQDFE); +BENCHMARK(BM_PQEFD); +BENCHMARK(BM_PQSlowDFE); +BENCHMARK(BM_PQSlowEFD); + +} // namespace +} // namespace jxl +#endif diff --git a/third_party/jpeg-xl/lib/jxl/toc.cc b/third_party/jpeg-xl/lib/jxl/toc.cc new file mode 100644 index 0000000000..72c8ac01cd --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/toc.cc @@ -0,0 +1,105 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/toc.h" + +#include <stdint.h> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/coeff_order.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/fields.h" + +namespace jxl { +size_t MaxBits(const size_t num_sizes) { + const size_t entry_bits = U32Coder::MaxEncodedBits(kTocDist) * num_sizes; + // permutation bit (not its tokens!), padding, entries, padding. + return 1 + kBitsPerByte + entry_bits + kBitsPerByte; +} + +Status ReadToc(size_t toc_entries, BitReader* JXL_RESTRICT reader, + std::vector<uint32_t>* JXL_RESTRICT sizes, + std::vector<coeff_order_t>* JXL_RESTRICT permutation) { + if (toc_entries > 65536) { + // Prevent out of memory if invalid JXL codestream causes a bogus amount + // of toc_entries such as 2720436919446 to be computed. + // TODO(lode): verify whether 65536 is a reasonable upper bound + return JXL_FAILURE("too many toc entries"); + } + + sizes->clear(); + sizes->resize(toc_entries); + if (reader->TotalBitsConsumed() >= reader->TotalBytes() * kBitsPerByte) { + return JXL_STATUS(StatusCode::kNotEnoughBytes, "Not enough bytes for TOC"); + } + const auto check_bit_budget = [&](size_t num_entries) -> Status { + // U32Coder reads 2 bits to recognize variant and kTocDist cheapest variant + // is Bits(10), this way at least 12 bits are required per toc-entry. + size_t minimal_bit_cost = num_entries * (2 + 10); + size_t bit_budget = reader->TotalBytes() * 8; + size_t expenses = reader->TotalBitsConsumed(); + if ((expenses <= bit_budget) && + (minimal_bit_cost <= bit_budget - expenses)) { + return true; + } + return JXL_STATUS(StatusCode::kNotEnoughBytes, "Not enough bytes for TOC"); + }; + + JXL_DASSERT(toc_entries > 0); + if (reader->ReadFixedBits<1>() == 1) { + JXL_RETURN_IF_ERROR(check_bit_budget(toc_entries)); + permutation->resize(toc_entries); + JXL_RETURN_IF_ERROR(DecodePermutation(/*skip=*/0, toc_entries, + permutation->data(), reader)); + } + JXL_RETURN_IF_ERROR(reader->JumpToByteBoundary()); + JXL_RETURN_IF_ERROR(check_bit_budget(toc_entries)); + for (size_t i = 0; i < toc_entries; ++i) { + (*sizes)[i] = U32Coder::Read(kTocDist, reader); + } + JXL_RETURN_IF_ERROR(reader->JumpToByteBoundary()); + JXL_RETURN_IF_ERROR(check_bit_budget(0)); + return true; +} + +Status ReadGroupOffsets(size_t toc_entries, BitReader* JXL_RESTRICT reader, + std::vector<uint64_t>* JXL_RESTRICT offsets, + std::vector<uint32_t>* JXL_RESTRICT sizes, + uint64_t* total_size) { + std::vector<coeff_order_t> permutation; + JXL_RETURN_IF_ERROR(ReadToc(toc_entries, reader, sizes, &permutation)); + + offsets->clear(); + offsets->resize(toc_entries); + + // Prefix sum starting with 0 and ending with the offset of the last group + uint64_t offset = 0; + for (size_t i = 0; i < toc_entries; ++i) { + if (offset + (*sizes)[i] < offset) { + return JXL_FAILURE("group offset overflow"); + } + (*offsets)[i] = offset; + offset += (*sizes)[i]; + } + if (total_size) { + *total_size = offset; + } + + if (!permutation.empty()) { + std::vector<uint64_t> permuted_offsets; + std::vector<uint32_t> permuted_sizes; + permuted_offsets.reserve(toc_entries); + permuted_sizes.reserve(toc_entries); + for (coeff_order_t index : permutation) { + permuted_offsets.push_back((*offsets)[index]); + permuted_sizes.push_back((*sizes)[index]); + } + std::swap(*offsets, permuted_offsets); + std::swap(*sizes, permuted_sizes); + } + + return true; +} +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/toc.h b/third_party/jpeg-xl/lib/jxl/toc.h new file mode 100644 index 0000000000..00006440b7 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/toc.h @@ -0,0 +1,53 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef LIB_JXL_TOC_H_ +#define LIB_JXL_TOC_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <vector> + +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/dec_bit_reader.h" +#include "lib/jxl/field_encodings.h" + +namespace jxl { + +// (2+bits) = 2,3,4 bytes so encoders can patch TOC after encoding. +// 30 is sufficient for 4K channels of uncompressed 16-bit samples. +constexpr U32Enc kTocDist(Bits(10), BitsOffset(14, 1024), BitsOffset(22, 17408), + BitsOffset(30, 4211712)); + +size_t MaxBits(const size_t num_sizes); + +// TODO(veluca): move these to FrameDimensions. +static JXL_INLINE size_t AcGroupIndex(size_t pass, size_t group, + size_t num_groups, size_t num_dc_groups) { + return 2 + num_dc_groups + pass * num_groups + group; +} + +static JXL_INLINE size_t NumTocEntries(size_t num_groups, size_t num_dc_groups, + size_t num_passes) { + if (num_groups == 1 && num_passes == 1) return 1; + return AcGroupIndex(0, 0, num_groups, num_dc_groups) + + num_groups * num_passes; +} + +Status ReadToc(size_t toc_entries, BitReader* JXL_RESTRICT reader, + std::vector<uint32_t>* JXL_RESTRICT sizes, + std::vector<coeff_order_t>* JXL_RESTRICT permutation); + +Status ReadGroupOffsets(size_t toc_entries, BitReader* JXL_RESTRICT reader, + std::vector<uint64_t>* JXL_RESTRICT offsets, + std::vector<uint32_t>* JXL_RESTRICT sizes, + uint64_t* total_size); + +} // namespace jxl + +#endif // LIB_JXL_TOC_H_ diff --git a/third_party/jpeg-xl/lib/jxl/toc_test.cc b/third_party/jpeg-xl/lib/jxl/toc_test.cc new file mode 100644 index 0000000000..8c95f8bc26 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/toc_test.cc @@ -0,0 +1,97 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/toc.h" + +#include <vector> + +#include "lib/jxl/base/common.h" +#include "lib/jxl/base/random.h" +#include "lib/jxl/base/span.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_toc.h" +#include "lib/jxl/testing.h" + +namespace jxl { +namespace { + +void Roundtrip(size_t num_entries, bool permute, Rng* rng) { + // Generate a random permutation. + std::vector<coeff_order_t> permutation; + std::vector<coeff_order_t> inv_permutation(num_entries); + for (size_t i = 0; i < num_entries; i++) { + inv_permutation[i] = i; + } + if (permute) { + permutation.resize(num_entries); + for (size_t i = 0; i < num_entries; i++) { + permutation[i] = i; + } + rng->Shuffle(permutation.data(), permutation.size()); + for (size_t i = 0; i < num_entries; i++) { + inv_permutation[permutation[i]] = i; + } + } + + // Generate num_entries groups of random (byte-aligned) length + std::vector<BitWriter> group_codes(num_entries); + for (BitWriter& writer : group_codes) { + const size_t max_bits = (*rng)() & 0xFFF; + BitWriter::Allotment allotment(&writer, max_bits + kBitsPerByte); + size_t i = 0; + for (; i + BitWriter::kMaxBitsPerCall < max_bits; + i += BitWriter::kMaxBitsPerCall) { + writer.Write(BitWriter::kMaxBitsPerCall, 0); + } + for (; i < max_bits; i += 1) { + writer.Write(/*n_bits=*/1, 0); + } + writer.ZeroPadToByte(); + AuxOut aux_out; + allotment.ReclaimAndCharge(&writer, 0, &aux_out); + } + + BitWriter writer; + AuxOut aux_out; + ASSERT_TRUE(WriteGroupOffsets(group_codes, permutation, &writer, &aux_out)); + + BitReader reader(writer.GetSpan()); + std::vector<uint64_t> group_offsets; + std::vector<uint32_t> group_sizes; + uint64_t total_size; + ASSERT_TRUE(ReadGroupOffsets(num_entries, &reader, &group_offsets, + &group_sizes, &total_size)); + ASSERT_EQ(num_entries, group_offsets.size()); + ASSERT_EQ(num_entries, group_sizes.size()); + EXPECT_TRUE(reader.Close()); + + uint64_t prefix_sum = 0; + for (size_t i = 0; i < num_entries; ++i) { + EXPECT_EQ(prefix_sum, group_offsets[inv_permutation[i]]); + + EXPECT_EQ(0u, group_codes[i].BitsWritten() % kBitsPerByte); + prefix_sum += group_codes[i].BitsWritten() / kBitsPerByte; + + if (i + 1 < num_entries) { + EXPECT_EQ( + group_offsets[inv_permutation[i]] + group_sizes[inv_permutation[i]], + group_offsets[inv_permutation[i + 1]]); + } + } + EXPECT_EQ(prefix_sum, total_size); +} + +TEST(TocTest, Test) { + Rng rng(0); + for (size_t num_entries = 1; num_entries < 10; ++num_entries) { + for (bool permute : std::vector<bool>{false, true}) { + Roundtrip(num_entries, permute, &rng); + } + } +} + +} // namespace +} // namespace jxl diff --git a/third_party/jpeg-xl/lib/jxl/transpose-inl.h b/third_party/jpeg-xl/lib/jxl/transpose-inl.h new file mode 100644 index 0000000000..4674420737 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/transpose-inl.h @@ -0,0 +1,203 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Block transpose for DCT/IDCT + +#if defined(LIB_JXL_TRANSPOSE_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_TRANSPOSE_INL_H_ +#undef LIB_JXL_TRANSPOSE_INL_H_ +#else +#define LIB_JXL_TRANSPOSE_INL_H_ +#endif + +#include <stddef.h> + +#include <hwy/highway.h> +#include <type_traits> + +#include "lib/jxl/base/status.h" +#include "lib/jxl/dct_block-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +#ifndef JXL_INLINE_TRANSPOSE +// Workaround for issue #42 - (excessive?) inlining causes invalid codegen. +#if defined(__arm__) +#define JXL_INLINE_TRANSPOSE HWY_NOINLINE +#else +#define JXL_INLINE_TRANSPOSE HWY_INLINE +#endif +#endif // JXL_INLINE_TRANSPOSE + +// Simple wrapper that ensures that a function will not be inlined. +template <typename T, typename... Args> +JXL_NOINLINE void NoInlineWrapper(const T& f, const Args&... args) { + return f(args...); +} + +template <bool enabled> +struct TransposeSimdTag {}; + +// TODO(veluca): it's not super useful to have this in the SIMD namespace. +template <size_t ROWS_or_0, size_t COLS_or_0, class From, class To> +JXL_INLINE_TRANSPOSE void GenericTransposeBlock(TransposeSimdTag<false>, + const From& from, const To& to, + size_t ROWSp, size_t COLSp) { + size_t ROWS = ROWS_or_0 == 0 ? ROWSp : ROWS_or_0; + size_t COLS = COLS_or_0 == 0 ? COLSp : COLS_or_0; + for (size_t n = 0; n < ROWS; ++n) { + for (size_t m = 0; m < COLS; ++m) { + to.Write(from.Read(n, m), m, n); + } + } +} + +// TODO(veluca): AVX3? +#if HWY_CAP_GE256 +constexpr bool TransposeUseSimd(size_t ROWS, size_t COLS) { + return ROWS % 8 == 0 && COLS % 8 == 0; +} + +template <size_t ROWS_or_0, size_t COLS_or_0, class From, class To> +JXL_INLINE_TRANSPOSE void GenericTransposeBlock(TransposeSimdTag<true>, + const From& from, const To& to, + size_t ROWSp, size_t COLSp) { + size_t ROWS = ROWS_or_0 == 0 ? ROWSp : ROWS_or_0; + size_t COLS = COLS_or_0 == 0 ? COLSp : COLS_or_0; + static_assert(MaxLanes(BlockDesc<8>()) == 8, "Invalid descriptor size"); + static_assert(ROWS_or_0 % 8 == 0, "Invalid number of rows"); + static_assert(COLS_or_0 % 8 == 0, "Invalid number of columns"); + for (size_t n = 0; n < ROWS; n += 8) { + for (size_t m = 0; m < COLS; m += 8) { + const BlockDesc<8> d; + auto i0 = from.LoadPart(d, n + 0, m + 0); + auto i1 = from.LoadPart(d, n + 1, m + 0); + auto i2 = from.LoadPart(d, n + 2, m + 0); + auto i3 = from.LoadPart(d, n + 3, m + 0); + auto i4 = from.LoadPart(d, n + 4, m + 0); + auto i5 = from.LoadPart(d, n + 5, m + 0); + auto i6 = from.LoadPart(d, n + 6, m + 0); + auto i7 = from.LoadPart(d, n + 7, m + 0); + // Surprisingly, this straightforward implementation (24 cycles on port5) + // is faster than load128+insert and LoadDup128+ConcatUpperLower+blend. + const auto q0 = InterleaveLower(d, i0, i2); + const auto q1 = InterleaveLower(d, i1, i3); + const auto q2 = InterleaveUpper(d, i0, i2); + const auto q3 = InterleaveUpper(d, i1, i3); + const auto q4 = InterleaveLower(d, i4, i6); + const auto q5 = InterleaveLower(d, i5, i7); + const auto q6 = InterleaveUpper(d, i4, i6); + const auto q7 = InterleaveUpper(d, i5, i7); + + const auto r0 = InterleaveLower(d, q0, q1); + const auto r1 = InterleaveUpper(d, q0, q1); + const auto r2 = InterleaveLower(d, q2, q3); + const auto r3 = InterleaveUpper(d, q2, q3); + const auto r4 = InterleaveLower(d, q4, q5); + const auto r5 = InterleaveUpper(d, q4, q5); + const auto r6 = InterleaveLower(d, q6, q7); + const auto r7 = InterleaveUpper(d, q6, q7); + + i0 = ConcatLowerLower(d, r4, r0); + i1 = ConcatLowerLower(d, r5, r1); + i2 = ConcatLowerLower(d, r6, r2); + i3 = ConcatLowerLower(d, r7, r3); + i4 = ConcatUpperUpper(d, r4, r0); + i5 = ConcatUpperUpper(d, r5, r1); + i6 = ConcatUpperUpper(d, r6, r2); + i7 = ConcatUpperUpper(d, r7, r3); + to.StorePart(d, i0, m + 0, n + 0); + to.StorePart(d, i1, m + 1, n + 0); + to.StorePart(d, i2, m + 2, n + 0); + to.StorePart(d, i3, m + 3, n + 0); + to.StorePart(d, i4, m + 4, n + 0); + to.StorePart(d, i5, m + 5, n + 0); + to.StorePart(d, i6, m + 6, n + 0); + to.StorePart(d, i7, m + 7, n + 0); + } + } +} +#elif HWY_TARGET != HWY_SCALAR +constexpr bool TransposeUseSimd(size_t ROWS, size_t COLS) { + return ROWS % 4 == 0 && COLS % 4 == 0; +} + +template <size_t ROWS_or_0, size_t COLS_or_0, class From, class To> +JXL_INLINE_TRANSPOSE void GenericTransposeBlock(TransposeSimdTag<true>, + const From& from, const To& to, + size_t ROWSp, size_t COLSp) { + size_t ROWS = ROWS_or_0 == 0 ? ROWSp : ROWS_or_0; + size_t COLS = COLS_or_0 == 0 ? COLSp : COLS_or_0; + static_assert(MaxLanes(BlockDesc<4>()) == 4, "Invalid descriptor size"); + static_assert(ROWS_or_0 % 4 == 0, "Invalid number of rows"); + static_assert(COLS_or_0 % 4 == 0, "Invalid number of columns"); + for (size_t n = 0; n < ROWS; n += 4) { + for (size_t m = 0; m < COLS; m += 4) { + const BlockDesc<4> d; + const auto p0 = from.LoadPart(d, n + 0, m + 0); + const auto p1 = from.LoadPart(d, n + 1, m + 0); + const auto p2 = from.LoadPart(d, n + 2, m + 0); + const auto p3 = from.LoadPart(d, n + 3, m + 0); + + const auto q0 = InterleaveLower(d, p0, p2); + const auto q1 = InterleaveLower(d, p1, p3); + const auto q2 = InterleaveUpper(d, p0, p2); + const auto q3 = InterleaveUpper(d, p1, p3); + + const auto r0 = InterleaveLower(d, q0, q1); + const auto r1 = InterleaveUpper(d, q0, q1); + const auto r2 = InterleaveLower(d, q2, q3); + const auto r3 = InterleaveUpper(d, q2, q3); + + to.StorePart(d, r0, m + 0, n + 0); + to.StorePart(d, r1, m + 1, n + 0); + to.StorePart(d, r2, m + 2, n + 0); + to.StorePart(d, r3, m + 3, n + 0); + } + } +} +#else +constexpr bool TransposeUseSimd(size_t ROWS, size_t COLS) { return false; } +#endif + +template <size_t N, size_t M, typename = void> +struct Transpose { + template <typename From, typename To> + static void Run(const From& from, const To& to) { + // This does not guarantee anything, just saves from the most stupid + // mistakes. + JXL_DASSERT(from.Address(0, 0) != to.Address(0, 0)); + TransposeSimdTag<TransposeUseSimd(N, M)> tag; + GenericTransposeBlock<N, M>(tag, from, to, N, M); + } +}; + +// Avoid inlining and unrolling transposes for large blocks. +template <size_t N, size_t M> +struct Transpose< + N, M, typename std::enable_if<(N >= 8 && M >= 8 && N * M >= 512)>::type> { + template <typename From, typename To> + static void Run(const From& from, const To& to) { + // This does not guarantee anything, just saves from the most stupid + // mistakes. + JXL_DASSERT(from.Address(0, 0) != to.Address(0, 0)); + TransposeSimdTag<TransposeUseSimd(N, M)> tag; + constexpr void (*transpose)(TransposeSimdTag<TransposeUseSimd(N, M)>, + const From&, const To&, size_t, size_t) = + GenericTransposeBlock<0, 0, From, To>; + NoInlineWrapper(transpose, tag, from, to, N, M); + } +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_TRANSPOSE_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/version.h.in b/third_party/jpeg-xl/lib/jxl/version.h.in new file mode 100644 index 0000000000..d077abec79 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/version.h.in @@ -0,0 +1,39 @@ +/* Copyright (c) the JPEG XL Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style + * license that can be found in the LICENSE file. + */ + +/** @addtogroup libjxl_common + * @{ + * @file version.h + * @brief libjxl version information + */ + +#ifndef JXL_VERSION_H_ +#define JXL_VERSION_H_ + +#define JPEGXL_MAJOR_VERSION @JPEGXL_MAJOR_VERSION@ ///< JPEG XL Major version +#define JPEGXL_MINOR_VERSION @JPEGXL_MINOR_VERSION@ ///< JPEG XL Minor version +#define JPEGXL_PATCH_VERSION @JPEGXL_PATCH_VERSION@ ///< JPEG XL Patch version + +/** Can be used to conditionally compile code for a specific JXL version + * @param[maj] major version + * @param[min] minor version + * + * @code + * #if JPEGXL_NUMERIC_VERSION < JPEGXL_COMPUTE_NUMERIC_VERSION(0,8,0) + * // use old/deprecated api + * #else + * // use current api + * #endif + * @endcode + */ +#define JPEGXL_COMPUTE_NUMERIC_VERSION(major,minor,patch) ((major<<24) | (minor<<16) | (patch<<8) | 0) + +/* Numeric representation of the version */ +#define JPEGXL_NUMERIC_VERSION JPEGXL_COMPUTE_NUMERIC_VERSION(JPEGXL_MAJOR_VERSION,JPEGXL_MINOR_VERSION,JPEGXL_PATCH_VERSION) + +#endif /* JXL_VERSION_H_ */ + +/** @}*/ diff --git a/third_party/jpeg-xl/lib/jxl/xorshift128plus-inl.h b/third_party/jpeg-xl/lib/jxl/xorshift128plus-inl.h new file mode 100644 index 0000000000..a473d591f2 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/xorshift128plus-inl.h @@ -0,0 +1,103 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Fast but weak random generator. + +#if defined(LIB_JXL_XORSHIFT128PLUS_INL_H_) == defined(HWY_TARGET_TOGGLE) +#ifdef LIB_JXL_XORSHIFT128PLUS_INL_H_ +#undef LIB_JXL_XORSHIFT128PLUS_INL_H_ +#else +#define LIB_JXL_XORSHIFT128PLUS_INL_H_ +#endif + +#include <stddef.h> + +#include <hwy/highway.h> +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::ShiftLeft; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Xor; + +// Adapted from https://github.com/vpxyz/xorshift/blob/master/xorshift128plus/ +// (MIT-license) +class Xorshift128Plus { + public: + // 8 independent generators (= single iteration for AVX-512) + enum { N = 8 }; + + explicit HWY_MAYBE_UNUSED Xorshift128Plus(const uint64_t seed) { + // Init state using SplitMix64 generator + s0_[0] = SplitMix64(seed + 0x9E3779B97F4A7C15ull); + s1_[0] = SplitMix64(s0_[0]); + for (size_t i = 1; i < N; ++i) { + s0_[i] = SplitMix64(s1_[i - 1]); + s1_[i] = SplitMix64(s0_[i]); + } + } + + HWY_MAYBE_UNUSED Xorshift128Plus(const uint32_t seed1, const uint32_t seed2, + const uint32_t seed3, const uint32_t seed4) { + // Init state using SplitMix64 generator + s0_[0] = SplitMix64(((static_cast<uint64_t>(seed1) << 32) + seed2) + + 0x9E3779B97F4A7C15ull); + s1_[0] = SplitMix64(((static_cast<uint64_t>(seed3) << 32) + seed4) + + 0x9E3779B97F4A7C15ull); + for (size_t i = 1; i < N; ++i) { + s0_[i] = SplitMix64(s0_[i - 1]); + s1_[i] = SplitMix64(s1_[i - 1]); + } + } + + HWY_INLINE HWY_MAYBE_UNUSED void Fill(uint64_t* HWY_RESTRICT random_bits) { +#if HWY_CAP_INTEGER64 + const HWY_FULL(uint64_t) d; + for (size_t i = 0; i < N; i += Lanes(d)) { + auto s1 = Load(d, s0_ + i); + const auto s0 = Load(d, s1_ + i); + const auto bits = Add(s1, s0); // b, c + Store(s0, d, s0_ + i); + s1 = Xor(s1, ShiftLeft<23>(s1)); + Store(bits, d, random_bits + i); + s1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0)))); + Store(s1, d, s1_ + i); + } +#else + for (size_t i = 0; i < N; ++i) { + auto s1 = s0_[i]; + const auto s0 = s1_[i]; + const auto bits = s1 + s0; // b, c + s0_[i] = s0; + s1 ^= s1 << 23; + random_bits[i] = bits; + s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5); + s1_[i] = s1; + } +#endif + } + + private: + static uint64_t SplitMix64(uint64_t z) { + z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull; + z = (z ^ (z >> 27)) * 0x94D049BB133111EBull; + return z ^ (z >> 31); + } + + HWY_ALIGN uint64_t s0_[N]; + HWY_ALIGN uint64_t s1_[N]; +}; + +} // namespace +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#endif // LIB_JXL_XORSHIFT128PLUS_INL_H_ diff --git a/third_party/jpeg-xl/lib/jxl/xorshift128plus_test.cc b/third_party/jpeg-xl/lib/jxl/xorshift128plus_test.cc new file mode 100644 index 0000000000..2ee4535284 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/xorshift128plus_test.cc @@ -0,0 +1,378 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include <stdint.h> + +#include <algorithm> +#include <vector> + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/xorshift128plus_test.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> +#include <hwy/tests/hwy_gtest.h> + +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/test_utils.h" +#include "lib/jxl/testing.h" +#include "lib/jxl/xorshift128plus-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Or; +using hwy::HWY_NAMESPACE::ShiftRight; +using hwy::HWY_NAMESPACE::Sub; + +// Define to nonzero in order to print the (new) golden outputs. +#define PRINT_RESULTS 0 + +const size_t kVectors = 64; + +#if PRINT_RESULTS + +template <int kNumLanes> +void Print(const uint64_t (&result)[kNumLanes]) { + printf("{ "); + for (int i = 0; i < kNumLanes; ++i) { + if (i != 0) { + printf(", "); + } + printf("0x%016llXull", result[i]); + } + printf("},\n"); +} + +#else // PRINT_RESULTS + +const uint64_t kExpected[kVectors][Xorshift128Plus::N] = { + {0x6E901576D477CBB1ull, 0xE9E53789195DA2A2ull, 0xB681F6DDA5E0AE99ull, + 0x8EFD18CE21FD6896ull, 0xA898A80DF75CF532ull, 0x50CEB2C9E2DE7E32ull, + 0x3CA7C2FEB25C0DD0ull, 0xA4D0866B80B4D836ull}, + {0x8CD6A1E6233D3A26ull, 0x3D4603ADE98B112Dull, 0xDC427AF674019E36ull, + 0xE28B4D230705AC53ull, 0x7297E9BBA88783DDull, 0x34D3D23CFCD9B41Aull, + 0x5A223615ADBE96B8ull, 0xE5EB529027CFBD01ull}, + {0xC1894CF00DFAC6A2ull, 0x18EDF8AE9085E404ull, 0x8E936625296B4CCDull, + 0x31971EF3A14A899Bull, 0xBE87535FCE0BF26Aull, 0x576F7A752BC6649Full, + 0xA44CBADCE0C6B937ull, 0x3DBA819BB17A353Aull}, + {0x27CE38DFCC1C5EB6ull, 0x920BEB5606340256ull, 0x3986CBC40C9AFC2Cull, + 0xE22BCB3EEB1E191Eull, 0x6E1FCDD3602A8FBAull, 0x052CB044E5415A29ull, + 0x46266646EFB9ECD7ull, 0x8F44914618D29335ull}, + {0xDD30AEDF72A362C5ull, 0xBC1D824E16BB98F4ull, 0x9EA6009C2AA3D2F1ull, + 0xF65C0FBBE17AF081ull, 0x22424D06A8738991ull, 0x8A62763F2B7611D2ull, + 0x2F3E89F722637939ull, 0x84D338BEF50AFD50ull}, + {0x00F46494898E2B0Bull, 0x81239DC4FB8E8003ull, 0x414AD93EC5773FE7ull, + 0x791473C450E4110Full, 0x87F127BF68C959ACull, 0x6429282D695EF67Bull, + 0x661082E11546CBA8ull, 0x5815D53FA5436BFDull}, + {0xB3DEADAB9BE6E0F9ull, 0xAA1B7B8F7CED0202ull, 0x4C5ED437699D279Eull, + 0xA4471727F1CB39D3ull, 0xE439DA193F802F70ull, 0xF89401BB04FA6493ull, + 0x3B08045A4FE898BAull, 0x32137BFE98227950ull}, + {0xFBAE4A092897FEF3ull, 0x0639F6CE56E71C8Eull, 0xF0AD6465C07F0C1Eull, + 0xFF8E28563361DCE5ull, 0xC2013DB7F86BC6B9ull, 0x8EFCC0503330102Full, + 0x3F6B767EA5C4DA40ull, 0xB9864B950B2232E1ull}, + {0x76EB58DE8E5EC22Aull, 0x9BBBF49A18B32F4Full, 0xC8405F02B2B2FAB9ull, + 0xC3E122A5F146BC34ull, 0xC90BB046660F5765ull, 0xB933981310DBECCFull, + 0x5A2A7BFC9126FD1Cull, 0x8BB388C94DF87901ull}, + {0x753EB89AD63EF3C3ull, 0xF24AAF40C89D65ADull, 0x23F68931C1A6AA6Dull, + 0xF47E79BF702C6DD0ull, 0xA3AD113244EE7EAEull, 0xD42CBEA28F793DC3ull, + 0xD896FCF1820F497Cull, 0x042B86D2818948C1ull}, + {0x8F2A4FC5A4265763ull, 0xEC499E6F95EAA10Cull, 0xE3786D4ECCD0DEB5ull, + 0xC725C53D3AC4CC43ull, 0x065A4ACBBF83610Eull, 0x35C61C9FEF167129ull, + 0x7B720AEAA7D70048ull, 0x14206B841377D039ull}, + {0xAD27D78BF96055F6ull, 0x5F43B20FF47ADCD4ull, 0xE184C2401E2BF71Eull, + 0x30B263D78990045Dull, 0xC22F00EBFF9BA201ull, 0xAE7F86522B53A562ull, + 0x2853312BC039F0A4ull, 0x868D619E6549C3C8ull}, + {0xFD5493D8AE9A8371ull, 0x773D5E224DF61B3Bull, 0x5377C54FBB1A8280ull, + 0xCAD4DE3B8265CAFAull, 0xCDF3F19C91EBD5F6ull, 0xC8EA0F182D73BD78ull, + 0x220502D593433FF1ull, 0xB81205E612DC31B1ull}, + {0x8F32A39EAEDA4C70ull, 0x1D4B0914AA4DAC7Full, 0x56EF1570F3A8B405ull, + 0x29812CB17404A592ull, 0x97A2AAF69CAE90F2ull, 0x12BF5E02778BBFE5ull, + 0x9D4B55AD42A05FD2ull, 0x06C2BAB5E6086620ull}, + {0x8DB4B9648302B253ull, 0xD756AD9E3AEA12C7ull, 0x68709B7F11D4B188ull, + 0x7CC299DDCD707A4Bull, 0x97B860C370A7661Dull, 0xCECD314FC20E64F5ull, + 0x55F412CDFB4C7EC3ull, 0x55EE97591193B525ull}, + {0xCF70F3ACA96E6254ull, 0x022FEDECA2E09F46ull, 0x686823DB60AE1ECFull, + 0xFD36190D3739830Eull, 0x74E1C09027F68120ull, 0xB5883A835C093842ull, + 0x93E1EFB927E9E4E3ull, 0xB2721E249D7E5EBEull}, + {0x69B6E21C44188CB8ull, 0x5D6CFB853655A7AAull, 0x3E001A0B425A66DCull, + 0x8C57451103A5138Full, 0x7BF8B4BE18EAB402ull, 0x494102EB8761A365ull, + 0xB33796A9F6A81F0Eull, 0x10005AB3BCCFD960ull}, + {0xB2CF25740AE965DCull, 0x6F7C1DF7EF53D670ull, 0x648DD6087AC2251Eull, + 0x040955D9851D487Dull, 0xBD550FC7E21A7F66ull, 0x57408F484DEB3AB5ull, + 0x481E24C150B506C1ull, 0x72C0C3EAF91A40D6ull}, + {0x1997A481858A5D39ull, 0x539718F4BEF50DC1ull, 0x2EC4DC4787E7E368ull, + 0xFF1CE78879419845ull, 0xE219A93DD6F6DD30ull, 0x85328618D02FEC1Aull, + 0xC86E02D969181B20ull, 0xEBEC8CD8BBA34E6Eull}, + {0x28B55088A16CE947ull, 0xDD25AC11E6350195ull, 0xBD1F176694257B1Cull, + 0x09459CCF9FCC9402ull, 0xF8047341E386C4E4ull, 0x7E8E9A9AD984C6C0ull, + 0xA4661E95062AA092ull, 0x70A9947005ED1152ull}, + {0x4C01CF75DBE98CCDull, 0x0BA076CDFC7373B9ull, 0x6C5E7A004B57FB59ull, + 0x336B82297FD3BC56ull, 0x7990C0BE74E8D60Full, 0xF0275CC00EC5C8C8ull, + 0x6CF29E682DFAD2E9ull, 0xFA4361524BD95D72ull}, + {0x631D2A19FF62F018ull, 0x41C43863B985B3FAull, 0xE052B2267038EFD9ull, + 0xE2A535FAC575F430ull, 0xE004EEA90B1FF5B8ull, 0x42DFE2CA692A1F26ull, + 0x90FB0BFC9A189ECCull, 0x4484102BD3536BD0ull}, + {0xD027134E9ACCA5A5ull, 0xBBAB4F966D476A9Bull, 0x713794A96E03D693ull, + 0x9F6335E6B94CD44Aull, 0xC5090C80E7471617ull, 0x6D9C1B0C87B58E33ull, + 0x1969CE82E31185A5ull, 0x2099B97E87754EBEull}, + {0x60EBAF4ED934350Full, 0xC26FBF0BA5E6ECFFull, 0x9E54150F0312EC57ull, + 0x0973B48364ED0041ull, 0x800A523241426CFCull, 0x03AB5EC055F75989ull, + 0x8CF315935DEEB40Aull, 0x83D3FC0190BD1409ull}, + {0x26D35394CF720A51ull, 0xCE9EAA15243CBAFEull, 0xE2B45FBAF21B29E0ull, + 0xDB92E98EDE73F9E0ull, 0x79B16F5101C26387ull, 0x1AC15959DE88C86Full, + 0x387633AEC6D6A580ull, 0xA6FC05807BFC5EB8ull}, + {0x2D26C8E47C6BADA9ull, 0x820E6EC832D52D73ull, 0xB8432C3E0ED0EE5Bull, + 0x0F84B3C4063AAA87ull, 0xF393E4366854F651ull, 0x749E1B4D2366A567ull, + 0x805EACA43480D004ull, 0x244EBF3AA54400A5ull}, + {0xBFDC3763AA79F75Aull, 0x9E3A74CC751F41DBull, 0xF401302A149DBC55ull, + 0x6B25F7973D7BF7BCull, 0x13371D34FDBC3DAEull, 0xC5E1998C8F484DCDull, + 0x7031B8AE5C364464ull, 0x3847F0C4F3DA2C25ull}, + {0x24C6387D2C0F1225ull, 0x77CCE960255C67A4ull, 0x21A0947E497B10EBull, + 0xBB5DB73A825A9D7Eull, 0x26294A41999E553Dull, 0x3953E0089F87D925ull, + 0x3DAE6E5D4E5EAAFEull, 0x74B545460341A7AAull}, + {0x710E5EB08A7DB820ull, 0x7E43C4E77CAEA025ull, 0xD4C91529C8B060C1ull, + 0x09AE26D8A7B0CA29ull, 0xAB9F356BB360A772ull, 0xB68834A25F19F6E9ull, + 0x79B8D9894C5734E2ull, 0xC6847E7C8FFD265Full}, + {0x10C4BCB06A5111E6ull, 0x57CB50955B6A2516ull, 0xEF53C87798B6995Full, + 0xAB38E15BBD8D0197ull, 0xA51C6106EFF73C93ull, 0x83D7F0E2270A7134ull, + 0x0923FD330397FCE5ull, 0xF9DE54EDFE58FB45ull}, + {0x07D44833ACCD1A94ull, 0xAAD3C9E945E2F9F3ull, 0xABF4C879B876AA37ull, + 0xF29C69A21B301619ull, 0x2DDCE959111C788Bull, 0x7CEDB48F8AC1729Bull, + 0x93F3BA9A02B659BEull, 0xF20A87FF17933CBEull}, + {0x8E96EBE93180CFE6ull, 0x94CAA12873937079ull, 0x05F613D9380D4189ull, + 0xBCAB40C1DC79F38Aull, 0x0AD8907B7C61D19Eull, 0x88534E189D103910ull, + 0x2DB2FAABA160AB8Full, 0xA070E7506B06F15Cull}, + {0x6FB1FCDAFFEF87A9ull, 0xE735CF25337A090Dull, 0x172C6EDCEFEF1825ull, + 0x76957EA49EF0542Dull, 0x819BF4CD250F7C49ull, 0xD6FF23E4AD00C4D4ull, + 0xE79673C1EC358FF0ull, 0xAC9C048144337938ull}, + {0x4C5387FF258B3AF4ull, 0xEDB68FAEC2CB1AA3ull, 0x02A624E67B4E1DA4ull, + 0x5C44797A38E08AF2ull, 0x36546A70E9411B4Bull, 0x47C17B24D2FD9675ull, + 0x101957AAA020CA26ull, 0x47A1619D4779F122ull}, + {0xF84B8BCDC92D9A3Cull, 0x951D7D2C74B3066Bull, 0x7AC287C06EDDD9B2ull, + 0x4C38FC476608D38Full, 0x224D793B19CB4BCDull, 0x835A255899BF1A41ull, + 0x4AD250E9F62DB4ABull, 0xD9B44F4B58781096ull}, + {0xABBAF99A8EB5C6B8ull, 0xFB568E900D3A9F56ull, 0x11EDF63D23C5DF11ull, + 0xA9C3011D3FA7C5A8ull, 0xAEDD3CF11AFFF725ull, 0xABCA472B5F1EDD6Bull, + 0x0600B6BB5D879804ull, 0xDB4DE007F22191A0ull}, + {0xD76CC9EFF0CE9392ull, 0xF5E0A772B59BA49Aull, 0x7D1AE1ED0C1261B5ull, + 0x79224A33B5EA4F4Aull, 0x6DD825D80C40EA60ull, 0x47FC8E747E51C953ull, + 0x695C05F72888BF98ull, 0x1A012428440B9015ull}, + {0xD754DD61F9B772BFull, 0xC4A2FCF4C0F9D4EBull, 0x461167CDF67A24A2ull, + 0x434748490EBCB9D4ull, 0x274DD9CDCA5781DEull, 0x36BAC63BA9A85209ull, + 0x30324DAFDA36B70Full, 0x337570DB4FE6DAB3ull}, + {0xF46CBDD57C551546ull, 0x8E02507E676DA3E3ull, 0xD826245A8C15406Dull, + 0xDFB38A5B71113B72ull, 0x5EA38454C95B16B5ull, 0x28C054FB87ABF3E1ull, + 0xAA2724C0BA1A8096ull, 0xECA83EC980304F2Full}, + {0x6AA76EC294EB3303ull, 0x42D4CDB2A8032E3Bull, 0x7999EDF75DCD8735ull, + 0xB422BFFE696CCDCCull, 0x8F721461FD7CCDFEull, 0x148E1A5814FDE253ull, + 0x4DC941F4375EF8FFull, 0x27B2A9E0EB5B49CFull}, + {0xCEA592EF9343EBE1ull, 0xF7D38B5FA7698903ull, 0x6CCBF352203FEAB6ull, + 0x830F3095FCCDA9C5ull, 0xDBEEF4B81B81C8F4ull, 0x6D7EB9BCEECA5CF9ull, + 0xC58ABB0FBE436C69ull, 0xE4B97E6DB2041A4Bull}, + {0x7E40FC772978AF14ull, 0xCDDA4BBAE28354A1ull, 0xE4F993B832C32613ull, + 0xD3608093C68A4B35ull, 0x9A3B60E01BEE3699ull, 0x03BEF248F3288713ull, + 0x70B9294318F3E9B4ull, 0x8D2ABB913B8610DEull}, + {0x37F209128E7D8B2Cull, 0x81D2AB375BD874BCull, 0xA716A1B7373F7408ull, + 0x0CEE97BEC4706540ull, 0xA40C5FD9CDBC1512ull, 0x73CAF6C8918409E7ull, + 0x45E11BCEDF0BBAA1ull, 0x612C612BFF6E6605ull}, + {0xF8ECB14A12D0F649ull, 0xDA683CD7C01BA1ACull, 0xA2203F7510E124C1ull, + 0x7F83E52E162F3C78ull, 0x77D2BB73456ACADBull, 0x37FC34FC840BBA6Full, + 0x3076BC7D4C6EBC1Full, 0x4F514123632B5FA9ull}, + {0x44D789DED935E884ull, 0xF8291591E09FEC9Full, 0xD9CED2CF32A2E4B7ull, + 0x95F70E1EB604904Aull, 0xDE438FE43C14F6ABull, 0x4C8D23E4FAFCF8D8ull, + 0xC716910A3067EB86ull, 0x3D6B7915315095D3ull}, + {0x3170FDBADAB92095ull, 0x8F1963933FC5650Bull, 0x72F94F00ABECFEABull, + 0x6E3AE826C6AAB4CEull, 0xA677A2BF31068258ull, 0x9660CDC4F363AF10ull, + 0xD81A15A152379EF1ull, 0x5D7D285E1080A3F9ull}, + {0xDAD5DDFF9A2249B3ull, 0x6F9721D926103FAEull, 0x1418CBB83FFA349Aull, + 0xE71A30AD48C012B2ull, 0xBE76376C63751132ull, 0x3496467ACA713AE6ull, + 0x8D7EC01369F991A3ull, 0xD8C73A88B96B154Eull}, + {0x8B5D9C74AEB4833Aull, 0xF914FB3F867B912Full, 0xB894EA034936B1DCull, + 0x8A16D21BE51C4F5Bull, 0x31FF048ED582D98Eull, 0xB95AB2F4DC65B820ull, + 0x04082B9170561AF7ull, 0xA215610A5DC836FAull}, + {0xB2ADE592C092FAACull, 0x7A1E683BCBF13294ull, 0xC7A4DBF86858C096ull, + 0x3A49940F97BFF316ull, 0xCAE5C06B82C46703ull, 0xC7F413A0F951E2BDull, + 0x6665E7BB10EB5916ull, 0x86F84A5A94EDE319ull}, + {0x4EA199D8FAA79CA3ull, 0xDFA26E5BF1981704ull, 0x0F5E081D37FA4E01ull, + 0x9CB632F89CD675CDull, 0x4A09DB89D48C0304ull, 0x88142742EA3C7672ull, + 0xAC4F149E6D2E9BDBull, 0x6D9E1C23F8B1C6C6ull}, + {0xD58BE47B92DEC0E9ull, 0x8E57573645E34328ull, 0x4CC094CCB5FB5126ull, + 0x5F1D66AF6FB40E3Cull, 0x2BA15509132D3B00ull, 0x0D6545646120E567ull, + 0x3CF680C45C223666ull, 0x96B28E32930179DAull}, + {0x5900C45853AC7990ull, 0x61881E3E8B7FF169ull, 0x4DE5F835DF2230FFull, + 0x4427A9E7932F73FFull, 0x9B641BAD379A8C8Dull, 0xDF271E5BF98F4E5Cull, + 0xDFDA16DB830FF5EEull, 0x371C7E7CFB89C0E9ull}, + {0x4410A8576247A250ull, 0x6AD2DA12B45AC0D9ull, 0x18DFC72AAC85EECCull, + 0x06FC8BB2A0EF25C8ull, 0xEB287619C85E6118ull, 0x19553ECA67F25A2Cull, + 0x3B9557F1DCEC5BAAull, 0x7BAD9E8B710D1079ull}, + {0x34F365D66BD22B28ull, 0xE6E124B9F10F835Dull, 0x0573C38ABF2B24DCull, + 0xD32E6AF10A0125AEull, 0x383590ACEA979519ull, 0x8376ED7A39E28205ull, + 0xF0B7F184DCBDA435ull, 0x062A203390E31794ull}, + {0xA2AFFD7E41918760ull, 0x7F90FC1BD0819C86ull, 0x5033C08E5A969533ull, + 0x2707AF5C6D039590ull, 0x57BBD5980F17DF9Cull, 0xD3FE6E61D763268Aull, + 0x9E0A0AE40F335A3Bull, 0x43CF4EB0A99613C5ull}, + {0xD4D2A397CE1A7C2Eull, 0x3DF7CE7CC3212DADull, 0x0880F0D5D356C75Aull, + 0xA8AFC44DD03B1346ull, 0x79263B46C13A29E0ull, 0x11071B3C0ED58E7Aull, + 0xED46DC9F538406BFull, 0x2C94974F2B94843Dull}, + {0xE246E13C39AB5D5Eull, 0xAC1018489D955B20ull, 0x8601B558771852B8ull, + 0x110BD4C06DB40173ull, 0x738FC8A18CCA0EBBull, 0x6673E09BE0EA76E5ull, + 0x024BC7A0C7527877ull, 0x45E6B4652E2EC34Eull}, + {0xD1ED26A1A375CDC8ull, 0xAABC4E896A617CB8ull, 0x0A9C9E8E57D753C6ull, + 0xA3774A75FEB4C30Eull, 0x30B816C01C93E49Eull, 0xF405BABC06D2408Cull, + 0xCC0CE6B4CE788ABCull, 0x75E7922D0447956Cull}, + {0xD07C1676A698BC95ull, 0x5F9AEA4840E2D860ull, 0xD5FC10D58BDF6F02ull, + 0xF190A2AD4BC2EEA7ull, 0x0C24D11F51726931ull, 0xDB646899A16B6512ull, + 0x7BC10670047B1DD8ull, 0x2413A5ABCD45F092ull}, + {0x4E66892190CFD923ull, 0xF10162440365EC8Eull, 0x158ACA5A6A2280AEull, + 0x0D60ED11C0224166ull, 0x7CD2E9A71B9D7488ull, 0x450D7289706AB2A3ull, + 0x88FAE34EC9A0D7DCull, 0x96FF9103575A97DAull}, + {0x77990FAC6046C446ull, 0xB174B5FB30C76676ull, 0xE352CE3EB56CF82Aull, + 0xC6039B6873A9A082ull, 0xE3F80F3AE333148Aull, 0xB853BA24BA3539B9ull, + 0xE8863E52ECCB0C74ull, 0x309B4CC1092CC245ull}, + {0xBC2B70BEE8388D9Full, 0xE48D92AE22216DCEull, 0xF15F3BF3E2C15D8Full, + 0x1DD964D4812D8B24ull, 0xD56AF02FB4665E4Cull, 0x98002200595BD9A3ull, + 0x049246D50BB8FA12ull, 0x1B542DF485B579B9ull}, + {0x2347409ADFA8E497ull, 0x36015C2211D62498ull, 0xE9F141F32EB82690ull, + 0x1F839912D0449FB9ull, 0x4E4DCFFF2D02D97Cull, 0xF8A03AB4C0F625C9ull, + 0x0605F575795DAC5Cull, 0x4746C9BEA0DDA6B1ull}, + {0xCA5BB519ECE7481Bull, 0xFD496155E55CA945ull, 0xF753B9DBB1515F81ull, + 0x50549E8BAC0F70E7ull, 0x8614FB0271E21C60ull, 0x60C72947EB0F0070ull, + 0xA6511C10AEE742B6ull, 0x48FB48F2CACCB43Eull}}; + +#endif // PRINT_RESULTS + +// Ensures Xorshift128+ returns consistent and unchanging values. +void TestGolden() { + HWY_ALIGN Xorshift128Plus rng(12345); + for (uint64_t vector = 0; vector < kVectors; ++vector) { + HWY_ALIGN uint64_t lanes[Xorshift128Plus::N]; + rng.Fill(lanes); +#if PRINT_RESULTS + Print(lanes); +#else + for (size_t i = 0; i < Xorshift128Plus::N; ++i) { + ASSERT_EQ(kExpected[vector][i], lanes[i]) + << "Where vector=" << vector << " i=" << i; + } +#endif + } +} + +// Output changes when given different seeds +void TestSeedChanges() { + HWY_ALIGN uint64_t lanes[Xorshift128Plus::N]; + + std::vector<uint64_t> first; + constexpr size_t kNumSeeds = 16384; + first.reserve(kNumSeeds); + + // All 14-bit seeds + for (size_t seed = 0; seed < kNumSeeds; ++seed) { + HWY_ALIGN Xorshift128Plus rng(seed); + + rng.Fill(lanes); + first.push_back(lanes[0]); + } + + // All outputs are unique + ASSERT_EQ(kNumSeeds, first.size()); + std::sort(first.begin(), first.end()); + first.erase(std::unique(first.begin(), first.end()), first.end()); + EXPECT_EQ(kNumSeeds, first.size()); +} + +void TestFloat() { + test::ThreadPoolForTests pool(8); + +#ifdef JXL_DISABLE_SLOW_TESTS + const uint32_t kMaxSeed = 256; +#else // JXL_DISABLE_SLOW_TESTS + const uint32_t kMaxSeed = 4096; +#endif // JXL_DISABLE_SLOW_TESTS + EXPECT_TRUE(RunOnPool( + &pool, 0, kMaxSeed, ThreadPool::NoInit, + [](const uint32_t seed, size_t /*thread*/) { + HWY_ALIGN Xorshift128Plus rng(seed); + + const HWY_FULL(uint32_t) du; + const HWY_FULL(float) df; + HWY_ALIGN uint64_t batch[Xorshift128Plus::N]; + HWY_ALIGN float lanes[MaxLanes(df)]; + double sum = 0.0; + size_t count = 0; + const size_t kReps = 2000; + for (size_t reps = 0; reps < kReps; ++reps) { + rng.Fill(batch); + for (size_t i = 0; i < Xorshift128Plus::N * 2; i += Lanes(df)) { + const auto bits = + Load(du, reinterpret_cast<const uint32_t*>(batch) + i); + // 1.0 + 23 random mantissa bits = [1, 2) + const auto rand12 = + BitCast(df, Or(ShiftRight<9>(bits), Set(du, 0x3F800000))); + const auto rand01 = Sub(rand12, Set(df, 1.0f)); + Store(rand01, df, lanes); + for (float lane : lanes) { + sum += lane; + count += 1; + EXPECT_LE(lane, 1.0f); + EXPECT_GE(lane, 0.0f); + } + } + } + + // Verify average (uniform distribution) + EXPECT_NEAR(0.5, sum / count, 0.00702); + }, + "TestXorShift")); +} + +// Not more than one 64-bit zero +void TestNotZero() { + test::ThreadPoolForTests pool(8); + +#ifdef JXL_DISABLE_SLOW_TESTS + const uint32_t kMaxSeed = 500; +#else // JXL_DISABLE_SLOW_TESTS + const uint32_t kMaxSeed = 2000; +#endif // JXL_DISABLE_SLOW_TESTS + EXPECT_TRUE(RunOnPool( + &pool, 0, kMaxSeed, ThreadPool::NoInit, + [](const uint32_t task, size_t /*thread*/) { + HWY_ALIGN uint64_t lanes[Xorshift128Plus::N]; + + HWY_ALIGN Xorshift128Plus rng(task); + size_t num_zero = 0; + for (size_t vectors = 0; vectors < 10000; ++vectors) { + rng.Fill(lanes); + for (uint64_t lane : lanes) { + num_zero += static_cast<size_t>(lane == 0); + } + } + EXPECT_LE(num_zero, 1u); + }, + "TestNotZero")); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +class Xorshift128Test : public hwy::TestWithParamTarget {}; + +HWY_TARGET_INSTANTIATE_TEST_SUITE_P(Xorshift128Test); + +HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestNotZero); +HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestGolden); +HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestSeedChanges); +HWY_EXPORT_AND_TEST_P(Xorshift128Test, TestFloat); + +} // namespace jxl +#endif |