From 6bf0a5cb5034a7e684dcc3500e841785237ce2dd Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 19:32:43 +0200 Subject: Adding upstream version 1:115.7.0. Signed-off-by: Daniel Baumann --- .../jpeg-xl/lib/jxl/enc_adaptive_quantization.cc | 1145 ++++++++++++++++++++ 1 file changed, 1145 insertions(+) create mode 100644 third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc (limited to 'third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc') diff --git a/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc new file mode 100644 index 0000000000..f54204b059 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc @@ -0,0 +1,1145 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/enc_adaptive_quantization.h" + +#include +#include +#include + +#include +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc" +#include +#include + +#include "lib/jxl/ac_strategy.h" +#include "lib/jxl/base/compiler_specific.h" +#include "lib/jxl/base/data_parallel.h" +#include "lib/jxl/base/profiler.h" +#include "lib/jxl/base/status.h" +#include "lib/jxl/butteraugli/butteraugli.h" +#include "lib/jxl/coeff_order_fwd.h" +#include "lib/jxl/color_encoding_internal.h" +#include "lib/jxl/color_management.h" +#include "lib/jxl/common.h" +#include "lib/jxl/convolve.h" +#include "lib/jxl/dec_cache.h" +#include "lib/jxl/dec_group.h" +#include "lib/jxl/enc_aux_out.h" +#include "lib/jxl/enc_butteraugli_comparator.h" +#include "lib/jxl/enc_cache.h" +#include "lib/jxl/enc_group.h" +#include "lib/jxl/enc_modular.h" +#include "lib/jxl/enc_params.h" +#include "lib/jxl/enc_transforms-inl.h" +#include "lib/jxl/epf.h" +#include "lib/jxl/fast_math-inl.h" +#include "lib/jxl/gauss_blur.h" +#include "lib/jxl/image.h" +#include "lib/jxl/image_bundle.h" +#include "lib/jxl/image_ops.h" +#include "lib/jxl/opsin_params.h" +#include "lib/jxl/quant_weights.h" +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +namespace { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::AbsDiff; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::And; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Rebind; +using hwy::HWY_NAMESPACE::Sqrt; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +// The following functions modulate an exponent (out_val) and return the updated +// value. Their descriptor is limited to 8 lanes for 8x8 blocks. + +// Hack for mask estimation. Eventually replace this code with butteraugli's +// masking. +float ComputeMaskForAcStrategyUse(const float out_val) { + const float kMul = 1.0f; + const float kOffset = 0.001f; + return kMul / (out_val + kOffset); +} + +template +V ComputeMask(const D d, const V out_val) { + const auto kBase = Set(d, -0.76471879237038032f); + const auto kMul4 = Set(d, 4.4585596705216615f); + const auto kMul2 = Set(d, 17.282053892620215f); + const auto kOffset2 = Set(d, 302.36961315317848f); + const auto kMul3 = Set(d, 7.0561261998705858f); + const auto kOffset3 = Set(d, 2.3179635626140773f); + const auto kOffset4 = Mul(Set(d, 0.25f), kOffset3); + const auto kMul0 = Set(d, 0.80061762862741759f); + const auto k1 = Set(d, 1.0f); + + // Avoid division by zero. + const auto v1 = Max(Mul(out_val, kMul0), Set(d, 1e-3f)); + const auto v2 = Div(k1, Add(v1, kOffset2)); + const auto v3 = Div(k1, MulAdd(v1, v1, kOffset3)); + const auto v4 = Div(k1, MulAdd(v1, v1, kOffset4)); + // TODO(jyrki): + // A log or two here could make sense. In butteraugli we have effectively + // log(log(x + C)) for this kind of use, as a single log is used in + // saturating visual masking and here the modulation values are exponential, + // another log would counter that. + return Add(kBase, MulAdd(kMul4, v4, MulAdd(kMul2, v2, Mul(kMul3, v3)))); +} + +// mul and mul2 represent a scaling difference between jxl and butteraugli. +static const float kSGmul = 226.77216153508914f; +static const float kSGmul2 = 1.0f / 73.377132366608819f; +static const float kLog2 = 0.693147181f; +// Includes correction factor for std::log -> log2. +static const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2; +static const float kSGVOffset = 7.7825991679894591f; + +template +V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) { + // The opsin space in jxl is the cubic root of photons, i.e., v * v * v + // is related to the number of photons. + // + // SimpleGamma(v * v * v) is the psychovisual space in butteraugli. + // This ratio allows quantization to move from jxl's opsin space to + // butteraugli's log-gamma space. + float kEpsilon = 1e-2; + v = ZeroIfNegative(v); + const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul); + const auto kVOffset = Set(d, kSGVOffset * kLog2 + kEpsilon); + const auto kDenMul = Set(d, kLog2 * kSGmul); + + const auto v2 = Mul(v, v); + + const auto num = MulAdd(kNumMul, v2, Set(d, kEpsilon)); + const auto den = MulAdd(Mul(kDenMul, v), v2, kVOffset); + return invert ? Div(num, den) : Div(den, num); +} + +template +static float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) { + using DScalar = HWY_CAPPED(float, 1); + auto vscalar = Load(DScalar(), &v); + return GetLane( + RatioOfDerivativesOfCubicRootToSimpleGamma(DScalar(), vscalar)); +} + +// TODO(veluca): this function computes an approximation of the derivative of +// SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or +// exact derivatives. For reference, SimpleGamma was: +/* +template +V SimpleGamma(const D d, V v) { + // A simple HDR compatible gamma function. + const auto mul = Set(d, kSGmul); + const auto kRetMul = Set(d, kSGRetMul); + const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f); + const auto kVOffset = Set(d, kSGVOffset); + + v *= mul; + + // This should happen rarely, but may lead to a NaN, which is rather + // undesirable. Since negative photons don't exist we solve the NaNs by + // clamping here. + // TODO(veluca): with FastLog2f, this no longer leads to NaNs. + v = ZeroIfNegative(v); + return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd; +} +*/ + +template +V GammaModulation(const D d, const size_t x, const size_t y, + const ImageF& xyb_x, const ImageF& xyb_y, const V out_val) { + const float kBias = 0.16f; + JXL_DASSERT(kBias > kOpsinAbsorbanceBias[0]); + JXL_DASSERT(kBias > kOpsinAbsorbanceBias[1]); + JXL_DASSERT(kBias > kOpsinAbsorbanceBias[2]); + auto overall_ratio = Zero(d); + auto bias = Set(d, kBias); + auto half = Set(d, 0.5f); + for (size_t dy = 0; dy < 8; ++dy) { + const float* const JXL_RESTRICT row_in_x = xyb_x.Row(y + dy); + const float* const JXL_RESTRICT row_in_y = xyb_y.Row(y + dy); + for (size_t dx = 0; dx < 8; dx += Lanes(d)) { + const auto iny = Add(Load(d, row_in_y + x + dx), bias); + const auto inx = Load(d, row_in_x + x + dx); + const auto r = Sub(iny, inx); + const auto g = Add(iny, inx); + const auto ratio_r = + RatioOfDerivativesOfCubicRootToSimpleGamma(d, r); + const auto ratio_g = + RatioOfDerivativesOfCubicRootToSimpleGamma(d, g); + const auto avg_ratio = Mul(half, Add(ratio_r, ratio_g)); + + overall_ratio = Add(overall_ratio, avg_ratio); + } + } + overall_ratio = Mul(SumOfLanes(d, overall_ratio), Set(d, 1.0f / 64)); + // ideally -1.0, but likely optimal correction adds some entropy, so slightly + // less than that. + // ln(2) constant folded in because we want std::log but have FastLog2f. + const auto kGam = Set(d, -0.15526878023684174f * 0.693147180559945f); + return MulAdd(kGam, FastLog2f(d, overall_ratio), out_val); +} + +template +V ColorModulation(const D d, const size_t x, const size_t y, + const ImageF& xyb_x, const ImageF& xyb_y, const ImageF& xyb_b, + const double butteraugli_target, V out_val) { + static const float kStrengthMul = 4.2456542701250122f; + static const float kRedRampStart = 0.18748564245760829f; + static const float kRedRampLength = 0.16701783842516479f; + static const float kBlueRampLength = 0.16117602661852037f; + static const float kBlueRampStart = 0.47897504338287333f; + const float strength = kStrengthMul * (1.0f - 0.15f * butteraugli_target); + if (strength < 0) { + return out_val; + } + // x values are smaller than y and b values, need to take the difference into + // account. + const float red_strength = strength * 6.0f; + const float blue_strength = strength; + { + // Reduce some bits from areas not blue or red. + const float offset = strength * -0.007; // 9174542291185913f; + out_val = Add(out_val, Set(d, offset)); + } + // Calculate how much of the 8x8 block is covered with blue or red. + auto blue_coverage = Zero(d); + auto red_coverage = Zero(d); + auto bias_y = Set(d, 0.2f); + auto bias_y_add = Set(d, 0.1f); + for (size_t dy = 0; dy < 8; ++dy) { + const float* const JXL_RESTRICT row_in_x = xyb_x.Row(y + dy); + const float* const JXL_RESTRICT row_in_y = xyb_y.Row(y + dy); + const float* const JXL_RESTRICT row_in_b = xyb_b.Row(y + dy); + for (size_t dx = 0; dx < 8; dx += Lanes(d)) { + const auto pixel_y = Load(d, row_in_y + x + dx); + // Estimate redness-greeness relative to the intensity. + const auto pixel_xpy = Div(Abs(Load(d, row_in_x + x + dx)), + Max(Add(bias_y_add, pixel_y), bias_y)); + const auto pixel_x = + Max(Set(d, 0.0f), Sub(pixel_xpy, Set(d, kRedRampStart))); + const auto pixel_b = + Max(Set(d, 0.0f), Sub(Load(d, row_in_b + x + dx), + Add(pixel_y, Set(d, kBlueRampStart)))); + const auto blue_slope = Min(pixel_b, Set(d, kBlueRampLength)); + const auto red_slope = Min(pixel_x, Set(d, kRedRampLength)); + red_coverage = Add(red_coverage, red_slope); + blue_coverage = Add(blue_coverage, blue_slope); + } + } + + // Saturate when the high red or high blue coverage is above a level. + // The idea here is that if a certain fraction of the block is red or + // blue we consider as if it was fully red or blue. + static const float ratio = 28.0f; // out of 64 pixels. + + auto overall_red_coverage = SumOfLanes(d, red_coverage); + overall_red_coverage = + Min(overall_red_coverage, Set(d, ratio * kRedRampLength)); + overall_red_coverage = + Mul(overall_red_coverage, Set(d, red_strength / ratio)); + + auto overall_blue_coverage = SumOfLanes(d, blue_coverage); + overall_blue_coverage = + Min(overall_blue_coverage, Set(d, ratio * kBlueRampLength)); + overall_blue_coverage = + Mul(overall_blue_coverage, Set(d, blue_strength / ratio)); + + return Add(overall_red_coverage, Add(overall_blue_coverage, out_val)); +} + +// Change precision in 8x8 blocks that have high frequency content. +template +V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb, + const V out_val) { + // Zero out the invalid differences for the rightmost value per row. + const Rebind du; + HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u, + ~0u, ~0u, ~0u, 0}; + + auto sum = Zero(d); // sum of absolute differences with right and below + + static const float valmin = 0.52489909479039587f; + auto valminv = Set(d, valmin); + for (size_t dy = 0; dy < 8; ++dy) { + const float* JXL_RESTRICT row_in = xyb.Row(y + dy) + x; + const float* JXL_RESTRICT row_in_next = + dy == 7 ? row_in : xyb.Row(y + dy + 1) + x; + + // In SCALAR, there is no guarantee of having extra row padding. + // Hence, we need to ensure we don't access pixels outside the row itself. + // In SIMD modes, however, rows are padded, so it's safe to access one + // garbage value after the row. The vector then gets masked with kMaskRight + // to remove the influence of that value. +#if HWY_TARGET != HWY_SCALAR + for (size_t dx = 0; dx < 8; dx += Lanes(d)) { +#else + for (size_t dx = 0; dx < 7; dx += Lanes(d)) { +#endif + const auto p = Load(d, row_in + dx); + const auto pr = LoadU(d, row_in + dx + 1); + const auto mask = BitCast(d, Load(du, kMaskRight + dx)); + sum = Add(sum, And(mask, Min(valminv, AbsDiff(p, pr)))); + + const auto pd = Load(d, row_in_next + dx); + sum = Add(sum, Min(valminv, AbsDiff(p, pd))); + } +#if HWY_TARGET == HWY_SCALAR + const auto p = Load(d, row_in + 7); + const auto pd = Load(d, row_in_next + 7); + sum = Add(sum, Min(valminv, AbsDiff(p, pd))); +#endif + } + // more negative value gives more bpp + static const float kOffset = -2.6545897672771526; + static const float kMul = -0.049868161744916512; + + sum = SumOfLanes(d, sum); + float scalar_sum = GetLane(sum); + static const float maxsum = 7.9076877647025947f; + static const float minsum = 0.53640540945659809f; + scalar_sum = std::min(maxsum, scalar_sum); + scalar_sum = std::max(minsum, scalar_sum); + scalar_sum += kOffset; + scalar_sum *= kMul; + return Add(Set(d, scalar_sum), out_val); +} + +void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x, + const ImageF& xyb_y, const ImageF& xyb_b, + const float scale, const Rect& rect, ImageF* out) { + JXL_ASSERT(SameSize(xyb_x, xyb_y)); + JXL_ASSERT(DivCeil(xyb_x.xsize(), kBlockDim) == out->xsize()); + JXL_ASSERT(DivCeil(xyb_x.ysize(), kBlockDim) == out->ysize()); + + float base_level = 0.48f * scale; + float kDampenRampStart = 2.0f; + float kDampenRampEnd = 14.0f; + float dampen = 1.0f; + if (butteraugli_target >= kDampenRampStart) { + dampen = 1.0f - ((butteraugli_target - kDampenRampStart) / + (kDampenRampEnd - kDampenRampStart)); + if (dampen < 0) { + dampen = 0; + } + } + const float mul = scale * dampen; + const float add = (1.0f - dampen) * base_level; + for (size_t iy = rect.y0(); iy < rect.y0() + rect.ysize(); iy++) { + const size_t y = iy * 8; + float* const JXL_RESTRICT row_out = out->Row(iy); + const HWY_CAPPED(float, kBlockDim) df; + for (size_t ix = rect.x0(); ix < rect.x0() + rect.xsize(); ix++) { + size_t x = ix * 8; + auto out_val = Set(df, row_out[ix]); + out_val = ComputeMask(df, out_val); + out_val = HfModulation(df, x, y, xyb_y, out_val); + out_val = ColorModulation(df, x, y, xyb_x, xyb_y, xyb_b, + butteraugli_target, out_val); + out_val = GammaModulation(df, x, y, xyb_x, xyb_y, out_val); + // We want multiplicative quantization field, so everything + // until this point has been modulating the exponent. + row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add; + } + } +} + +template +V MaskingSqrt(const D d, V v) { + static const float kLogOffset = 27.97044946785558f; + static const float kMul = 211.53333281566171f; + const auto mul_v = Set(d, kMul * 1e8); + const auto offset_v = Set(d, kLogOffset); + return Mul(Set(d, 0.25f), Sqrt(MulAdd(v, Sqrt(mul_v), offset_v))); +} + +float MaskingSqrt(const float v) { + using DScalar = HWY_CAPPED(float, 1); + auto vscalar = Load(DScalar(), &v); + return GetLane(MaskingSqrt(DScalar(), vscalar)); +} + +void StoreMin4(const float v, float& min0, float& min1, float& min2, + float& min3) { + if (v < min3) { + if (v < min0) { + min3 = min2; + min2 = min1; + min1 = min0; + min0 = v; + } else if (v < min1) { + min3 = min2; + min2 = min1; + min1 = v; + } else if (v < min2) { + min3 = min2; + min2 = v; + } else { + min3 = v; + } + } +} + +// Look for smooth areas near the area of degradation. +// If the areas are generally smooth, don't do masking. +// Output is downsampled 2x. +void FuzzyErosion(const Rect& from_rect, const ImageF& from, + const Rect& to_rect, ImageF* to) { + const size_t xsize = from.xsize(); + const size_t ysize = from.ysize(); + constexpr int kStep = 1; + static_assert(kStep == 1, "Step must be 1"); + JXL_ASSERT(to_rect.xsize() * 2 == from_rect.xsize()); + JXL_ASSERT(to_rect.ysize() * 2 == from_rect.ysize()); + for (size_t fy = 0; fy < from_rect.ysize(); ++fy) { + size_t y = fy + from_rect.y0(); + size_t ym1 = y >= kStep ? y - kStep : y; + size_t yp1 = y + kStep < ysize ? y + kStep : y; + const float* rowt = from.Row(ym1); + const float* row = from.Row(y); + const float* rowb = from.Row(yp1); + float* row_out = to_rect.Row(to, fy / 2); + for (size_t fx = 0; fx < from_rect.xsize(); ++fx) { + size_t x = fx + from_rect.x0(); + size_t xm1 = x >= kStep ? x - kStep : x; + size_t xp1 = x + kStep < xsize ? x + kStep : x; + float min0 = row[x]; + float min1 = row[xm1]; + float min2 = row[xp1]; + float min3 = rowt[xm1]; + // Sort the first four values. + if (min0 > min1) std::swap(min0, min1); + if (min0 > min2) std::swap(min0, min2); + if (min0 > min3) std::swap(min0, min3); + if (min1 > min2) std::swap(min1, min2); + if (min1 > min3) std::swap(min1, min3); + if (min2 > min3) std::swap(min2, min3); + // The remaining five values of a 3x3 neighbourhood. + StoreMin4(rowt[x], min0, min1, min2, min3); + StoreMin4(rowt[xp1], min0, min1, min2, min3); + StoreMin4(rowb[xm1], min0, min1, min2, min3); + StoreMin4(rowb[x], min0, min1, min2, min3); + StoreMin4(rowb[xp1], min0, min1, min2, min3); + static const float kMul0 = 0.125f; + static const float kMul1 = 0.075f; + static const float kMul2 = 0.06f; + static const float kMul3 = 0.05f; + float v = kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + kMul3 * min3; + if (fx % 2 == 0 && fy % 2 == 0) { + row_out[fx / 2] = v; + } else { + row_out[fx / 2] += v; + } + } + } +} + +struct AdaptiveQuantizationImpl { + void Init(const Image3F& xyb) { + JXL_DASSERT(xyb.xsize() % kBlockDim == 0); + JXL_DASSERT(xyb.ysize() % kBlockDim == 0); + const size_t xsize = xyb.xsize(); + const size_t ysize = xyb.ysize(); + aq_map = ImageF(xsize / kBlockDim, ysize / kBlockDim); + } + void PrepareBuffers(size_t num_threads) { + diff_buffer = ImageF(kEncTileDim + 8, num_threads); + for (size_t i = pre_erosion.size(); i < num_threads; i++) { + pre_erosion.emplace_back(kEncTileDimInBlocks * 2 + 2, + kEncTileDimInBlocks * 2 + 2); + } + } + + void ComputeTile(float butteraugli_target, float scale, const Image3F& xyb, + const Rect& rect, const int thread, ImageF* mask) { + PROFILER_ZONE("aq DiffPrecompute"); + const size_t xsize = xyb.xsize(); + const size_t ysize = xyb.ysize(); + + // The XYB gamma is 3.0 to be able to decode faster with two muls. + // Butteraugli's gamma is matching the gamma of human eye, around 2.6. + // We approximate the gamma difference by adding one cubic root into + // the adaptive quantization. This gives us a total gamma of 2.6666 + // for quantization uses. + const float match_gamma_offset = 0.019; + + const HWY_FULL(float) df; + + size_t y_start = rect.y0() * 8; + size_t y_end = y_start + rect.ysize() * 8; + + size_t x0 = rect.x0() * 8; + size_t x1 = x0 + rect.xsize() * 8; + if (x0 != 0) x0 -= 4; + if (x1 != xyb.xsize()) x1 += 4; + if (y_start != 0) y_start -= 4; + if (y_end != xyb.ysize()) y_end += 4; + pre_erosion[thread].ShrinkTo((x1 - x0) / 4, (y_end - y_start) / 4); + + static const float limit = 0.2f; + // Computes image (padded to multiple of 8x8) of local pixel differences. + // Subsample both directions by 4. + for (size_t y = y_start; y < y_end; ++y) { + size_t y2 = y + 1 < ysize ? y + 1 : y; + size_t y1 = y > 0 ? y - 1 : y; + + const float* row_in = xyb.PlaneRow(1, y); + const float* row_in1 = xyb.PlaneRow(1, y1); + const float* row_in2 = xyb.PlaneRow(1, y2); + float* JXL_RESTRICT row_out = diff_buffer.Row(thread); + + auto scalar_pixel = [&](size_t x) { + const size_t x2 = x + 1 < xsize ? x + 1 : x; + const size_t x1 = x > 0 ? x - 1 : x; + const float base = + 0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]); + const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma( + row_in[x] + match_gamma_offset); + float diff = gammac * (row_in[x] - base); + diff *= diff; + if (diff >= limit) { + diff = limit; + } + diff = MaskingSqrt(diff); + if ((y % 4) != 0) { + row_out[x - x0] += diff; + } else { + row_out[x - x0] = diff; + } + }; + + size_t x = x0; + // First pixel of the row. + if (x0 == 0) { + scalar_pixel(x0); + ++x; + } + // SIMD + const auto match_gamma_offset_v = Set(df, match_gamma_offset); + const auto quarter = Set(df, 0.25f); + for (; x + 1 + Lanes(df) < x1; x += Lanes(df)) { + const auto in = LoadU(df, row_in + x); + const auto in_r = LoadU(df, row_in + x + 1); + const auto in_l = LoadU(df, row_in + x - 1); + const auto in_t = LoadU(df, row_in2 + x); + const auto in_b = LoadU(df, row_in1 + x); + auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b))); + auto gammacv = + RatioOfDerivativesOfCubicRootToSimpleGamma( + df, Add(in, match_gamma_offset_v)); + auto diff = Mul(gammacv, Sub(in, base)); + diff = Mul(diff, diff); + diff = Min(diff, Set(df, limit)); + diff = MaskingSqrt(df, diff); + if ((y & 3) != 0) { + diff = Add(diff, LoadU(df, row_out + x - x0)); + } + StoreU(diff, df, row_out + x - x0); + } + // Scalar + for (; x < x1; ++x) { + scalar_pixel(x); + } + if (y % 4 == 3) { + float* row_dout = pre_erosion[thread].Row((y - y_start) / 4); + for (size_t x = 0; x < (x1 - x0) / 4; x++) { + row_dout[x] = (row_out[x * 4] + row_out[x * 4 + 1] + + row_out[x * 4 + 2] + row_out[x * 4 + 3]) * + 0.25f; + } + } + } + Rect from_rect(x0 % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 0 : 1, + rect.xsize() * 2, rect.ysize() * 2); + FuzzyErosion(from_rect, pre_erosion[thread], rect, &aq_map); + for (size_t y = 0; y < rect.ysize(); ++y) { + const float* aq_map_row = rect.ConstRow(aq_map, y); + float* mask_row = rect.Row(mask, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]); + } + } + PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1), + xyb.Plane(2), scale, rect, &aq_map); + } + std::vector pre_erosion; + ImageF aq_map; + ImageF diff_buffer; +}; + +ImageF AdaptiveQuantizationMap(const float butteraugli_target, + const Image3F& xyb, + const FrameDimensions& frame_dim, float scale, + ThreadPool* pool, ImageF* mask) { + PROFILER_ZONE("aq AdaptiveQuantMap"); + + AdaptiveQuantizationImpl impl; + impl.Init(xyb); + *mask = ImageF(frame_dim.xsize_blocks, frame_dim.ysize_blocks); + JXL_CHECK(RunOnPool( + pool, 0, + DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks) * + DivCeil(frame_dim.ysize_blocks, kEncTileDimInBlocks), + [&](const size_t num_threads) { + impl.PrepareBuffers(num_threads); + return true; + }, + [&](const uint32_t tid, const size_t thread) { + size_t n_enc_tiles = + DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks); + size_t tx = tid % n_enc_tiles; + size_t ty = tid / n_enc_tiles; + size_t by0 = ty * kEncTileDimInBlocks; + size_t by1 = + std::min((ty + 1) * kEncTileDimInBlocks, frame_dim.ysize_blocks); + size_t bx0 = tx * kEncTileDimInBlocks; + size_t bx1 = + std::min((tx + 1) * kEncTileDimInBlocks, frame_dim.xsize_blocks); + Rect r(bx0, by0, bx1 - bx0, by1 - by0); + impl.ComputeTile(butteraugli_target, scale, xyb, r, thread, mask); + }, + "AQ DiffPrecompute")); + + return std::move(impl).aq_map; +} + +} // namespace + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { +HWY_EXPORT(AdaptiveQuantizationMap); + +namespace { +// If true, prints the quantization maps at each iteration. +bool FLAGS_dump_quant_state = false; + +void DumpHeatmap(const AuxOut* aux_out, const std::string& label, + const ImageF& image, float good_threshold, + float bad_threshold) { + Image3F heatmap = CreateHeatMapImage(image, good_threshold, bad_threshold); + char filename[200]; + snprintf(filename, sizeof(filename), "%s%05d", label.c_str(), + aux_out->num_butteraugli_iters); + aux_out->DumpImage(filename, heatmap); +} + +void DumpHeatmaps(const AuxOut* aux_out, float ba_target, + const ImageF& quant_field, const ImageF& tile_heatmap, + const ImageF& bt_diffmap) { + if (!WantDebugOutput(aux_out)) return; + ImageF inv_qmap(quant_field.xsize(), quant_field.ysize()); + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* JXL_RESTRICT row_q = quant_field.ConstRow(y); + float* JXL_RESTRICT row_inv_q = inv_qmap.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + row_inv_q[x] = 1.0f / row_q[x]; // never zero + } + } + DumpHeatmap(aux_out, "quant_heatmap", inv_qmap, 4.0f * ba_target, + 6.0f * ba_target); + DumpHeatmap(aux_out, "tile_heatmap", tile_heatmap, ba_target, + 1.5f * ba_target); + // matches heat maps produced by the command line tool. + DumpHeatmap(aux_out, "bt_diffmap", bt_diffmap, ButteraugliFuzzyInverse(1.5), + ButteraugliFuzzyInverse(0.5)); +} + +ImageF TileDistMap(const ImageF& distmap, int tile_size, int margin, + const AcStrategyImage& ac_strategy) { + PROFILER_FUNC; + const int tile_xsize = (distmap.xsize() + tile_size - 1) / tile_size; + const int tile_ysize = (distmap.ysize() + tile_size - 1) / tile_size; + ImageF tile_distmap(tile_xsize, tile_ysize); + size_t distmap_stride = tile_distmap.PixelsPerRow(); + for (int tile_y = 0; tile_y < tile_ysize; ++tile_y) { + AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(tile_y); + float* JXL_RESTRICT dist_row = tile_distmap.Row(tile_y); + for (int tile_x = 0; tile_x < tile_xsize; ++tile_x) { + AcStrategy acs = ac_strategy_row[tile_x]; + if (!acs.IsFirstBlock()) continue; + int this_tile_xsize = acs.covered_blocks_x() * tile_size; + int this_tile_ysize = acs.covered_blocks_y() * tile_size; + int y_begin = std::max(0, tile_size * tile_y - margin); + int y_end = std::min(distmap.ysize(), + tile_size * tile_y + this_tile_ysize + margin); + int x_begin = std::max(0, tile_size * tile_x - margin); + int x_end = std::min(distmap.xsize(), + tile_size * tile_x + this_tile_xsize + margin); + float dist_norm = 0.0; + double pixels = 0; + for (int y = y_begin; y < y_end; ++y) { + float ymul = 1.0; + constexpr float kBorderMul = 0.98f; + constexpr float kCornerMul = 0.7f; + if (margin != 0 && (y == y_begin || y == y_end - 1)) { + ymul = kBorderMul; + } + const float* const JXL_RESTRICT row = distmap.Row(y); + for (int x = x_begin; x < x_end; ++x) { + float xmul = ymul; + if (margin != 0 && (x == x_begin || x == x_end - 1)) { + if (xmul == 1.0) { + xmul = kBorderMul; + } else { + xmul = kCornerMul; + } + } + float v = row[x]; + v *= v; + v *= v; + v *= v; + v *= v; + dist_norm += xmul * v; + pixels += xmul; + } + } + if (pixels == 0) pixels = 1; + // 16th norm is less than the max norm, we reduce the difference + // with this normalization factor. + constexpr float kTileNorm = 1.2f; + const float tile_dist = + kTileNorm * std::pow(dist_norm / pixels, 1.0f / 16.0f); + dist_row[tile_x] = tile_dist; + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + dist_row[tile_x + distmap_stride * iy + ix] = tile_dist; + } + } + } + } + return tile_distmap; +} + +static const float kDcQuantPow = 0.83; +static const float kDcQuant = 1.095924047623553f; +static const float kAcQuant = 0.80751132443618624f; + +void FindBestQuantization(const ImageBundle& linear, const Image3F& opsin, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out) { + const CompressParams& cparams = enc_state->cparams; + if (cparams.resampling > 1 && + cparams.original_butteraugli_distance <= 4.0 * cparams.resampling) { + // For downsampled opsin image, the butteraugli based adaptive quantization + // loop would only make the size bigger without improving the distance much, + // so in this case we enable it only for very high butteraugli targets. + return; + } + Quantizer& quantizer = enc_state->shared.quantizer; + ImageI& raw_quant_field = enc_state->shared.raw_quant_field; + ImageF& quant_field = enc_state->initial_quant_field; + + // TODO(veluca): this should really be rather handled on the + // ButteraugliComparator side. + struct TemporaryShrink { + TemporaryShrink(ImageBundle& bundle, size_t xsize, size_t ysize) + : bundle(bundle), + orig_xsize(bundle.xsize()), + orig_ysize(bundle.ysize()) { + bundle.ShrinkTo(xsize, ysize); + } + TemporaryShrink(const TemporaryShrink&) = delete; + TemporaryShrink(TemporaryShrink&&) = delete; + + ~TemporaryShrink() { bundle.ShrinkTo(orig_xsize, orig_ysize); } + + ImageBundle& bundle; + size_t orig_xsize; + size_t orig_ysize; + } t(const_cast(linear), + enc_state->shared.frame_header.nonserialized_metadata->xsize(), + enc_state->shared.frame_header.nonserialized_metadata->ysize()); + + const float butteraugli_target = cparams.butteraugli_distance; + const float original_butteraugli = cparams.original_butteraugli_distance; + ButteraugliParams params = cparams.ba_params; + params.intensity_target = linear.metadata()->IntensityTarget(); + // Hack the default intensity target value to be 80.0, the intensity + // target of sRGB images and a more reasonable viewing default than + // JPEG XL file format's default. + if (fabs(params.intensity_target - 255.0f) < 1e-3) { + params.intensity_target = 80.0f; + } + JxlButteraugliComparator comparator(params, cms); + JXL_CHECK(comparator.SetReferenceImage(linear)); + bool lower_is_better = + (comparator.GoodQualityScore() < comparator.BadQualityScore()); + const float initial_quant_dc = InitialQuantDC(butteraugli_target); + AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field), + &quant_field); + ImageF tile_distmap; + ImageF initial_quant_field = CopyImage(quant_field); + + float initial_qf_min, initial_qf_max; + ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max); + float initial_qf_ratio = initial_qf_max / initial_qf_min; + float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio); + float asymmetry = 2; + if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low; + float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low); + float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry); + + JXL_ASSERT(qf_higher / qf_lower < 253); + + constexpr int kOriginalComparisonRound = 1; + int iters = cparams.max_butteraugli_iters; + if (iters > 7) { + iters = 7; + } + if (cparams.speed_tier != SpeedTier::kTortoise) { + iters = 2; + } + for (int i = 0; i < iters + 1; ++i) { + if (FLAGS_dump_quant_state) { + printf("\nQuantization field:\n"); + for (size_t y = 0; y < quant_field.ysize(); ++y) { + for (size_t x = 0; x < quant_field.xsize(); ++x) { + printf(" %.5f", quant_field.Row(y)[x]); + } + printf("\n"); + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); + ImageBundle dec_linear = RoundtripImage(opsin, enc_state, cms, pool); + PROFILER_ZONE("enc Butteraugli"); + float score; + ImageF diffmap; + JXL_CHECK(comparator.CompareWith(dec_linear, &diffmap, &score)); + if (!lower_is_better) { + score = -score; + diffmap = ScaleImage(-1.0f, diffmap); + } + tile_distmap = TileDistMap(diffmap, 8 * cparams.resampling, 0, + enc_state->shared.ac_strategy); + if (WantDebugOutput(aux_out)) { + aux_out->DumpImage(("dec" + ToString(i)).c_str(), *dec_linear.color()); + DumpHeatmaps(aux_out, butteraugli_target, quant_field, tile_distmap, + diffmap); + } + if (aux_out != nullptr) ++aux_out->num_butteraugli_iters; + if (cparams.log_search_state) { + float minval, maxval; + ImageMinMax(quant_field, &minval, &maxval); + printf("\nButteraugli iter: %d/%d\n", i, cparams.max_butteraugli_iters); + printf("Butteraugli distance: %f (target = %f)\n", score, + original_butteraugli); + printf("quant range: %f ... %f DC quant: %f\n", minval, maxval, + initial_quant_dc); + if (FLAGS_dump_quant_state) { + quantizer.DumpQuantizationMap(raw_quant_field); + } + } + + if (i == iters) break; + + double kPow[8] = { + 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + }; + double kPowMod[8] = { + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + }; + if (i == kOriginalComparisonRound) { + // Don't allow optimization to make the quant field a lot worse than + // what the initial guess was. This allows the AC field to have enough + // precision to reduce the oscillations due to the dc reconstruction. + double kInitMul = 0.6; + const double kOneMinusInitMul = 1.0 - kInitMul; + for (size_t y = 0; y < quant_field.ysize(); ++y) { + float* const JXL_RESTRICT row_q = quant_field.Row(y); + const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x]; + if (row_q[x] < clamp) { + row_q[x] = clamp; + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } + } + + double cur_pow = 0.0; + if (i < 7) { + cur_pow = kPow[i] + (original_butteraugli - 1.0) * kPowMod[i]; + if (cur_pow < 0) { + cur_pow = 0; + } + } + if (cur_pow == 0.0) { + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y); + float* const JXL_RESTRICT row_q = quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + const float diff = row_dist[x] / original_butteraugli; + if (diff > 1.0f) { + float old = row_q[x]; + row_q[x] *= diff; + int qf_old = old * quantizer.InvGlobalScale() + 0.5; + int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5; + if (qf_old == qf_new) { + row_q[x] = old + quantizer.Scale(); + } + } + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } else { + for (size_t y = 0; y < quant_field.ysize(); ++y) { + const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y); + float* const JXL_RESTRICT row_q = quant_field.Row(y); + for (size_t x = 0; x < quant_field.xsize(); ++x) { + const float diff = row_dist[x] / original_butteraugli; + if (diff <= 1.0f) { + row_q[x] *= std::pow(diff, cur_pow); + } else { + float old = row_q[x]; + row_q[x] *= diff; + int qf_old = old * quantizer.InvGlobalScale() + 0.5; + int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5; + if (qf_old == qf_new) { + row_q[x] = old + quantizer.Scale(); + } + } + if (row_q[x] > qf_higher) row_q[x] = qf_higher; + if (row_q[x] < qf_lower) row_q[x] = qf_lower; + } + } + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); +} + +void FindBestQuantizationMaxError(const Image3F& opsin, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out) { + // TODO(szabadka): Make this work for non-opsin color spaces. + const CompressParams& cparams = enc_state->cparams; + Quantizer& quantizer = enc_state->shared.quantizer; + ImageI& raw_quant_field = enc_state->shared.raw_quant_field; + ImageF& quant_field = enc_state->initial_quant_field; + + // TODO(veluca): better choice of this value. + const float initial_quant_dc = + 16 * std::sqrt(0.1f / cparams.butteraugli_distance); + AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field), + &quant_field); + + const float inv_max_err[3] = {1.0f / enc_state->cparams.max_error[0], + 1.0f / enc_state->cparams.max_error[1], + 1.0f / enc_state->cparams.max_error[2]}; + + for (int i = 0; i < cparams.max_butteraugli_iters + 1; ++i) { + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); + if (aux_out) { + aux_out->DumpXybImage(("ops" + ToString(i)).c_str(), opsin); + } + ImageBundle decoded = RoundtripImage(opsin, enc_state, cms, pool); + if (aux_out) { + aux_out->DumpXybImage(("dec" + ToString(i)).c_str(), *decoded.color()); + } + + for (size_t by = 0; by < enc_state->shared.frame_dim.ysize_blocks; by++) { + AcStrategyRow ac_strategy_row = + enc_state->shared.ac_strategy.ConstRow(by); + for (size_t bx = 0; bx < enc_state->shared.frame_dim.xsize_blocks; bx++) { + AcStrategy acs = ac_strategy_row[bx]; + if (!acs.IsFirstBlock()) continue; + float max_error = 0; + for (size_t c = 0; c < 3; c++) { + for (size_t y = by * kBlockDim; + y < (by + acs.covered_blocks_y()) * kBlockDim; y++) { + if (y >= decoded.ysize()) continue; + const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y); + const float* JXL_RESTRICT dec_row = + decoded.color()->ConstPlaneRow(c, y); + for (size_t x = bx * kBlockDim; + x < (bx + acs.covered_blocks_x()) * kBlockDim; x++) { + if (x >= decoded.xsize()) continue; + max_error = std::max( + std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error); + } + } + } + // Target an error between max_error/2 and max_error. + // If the error in the varblock is above the target, increase the qf to + // compensate. If the error is below the target, decrease the qf. + // However, to avoid an excessive increase of the qf, only do so if the + // error is less than half the maximum allowed error. + const float qf_mul = (max_error < 0.5f) ? max_error * 2.0f + : (max_error > 1.0f) ? max_error + : 1.0f; + for (size_t qy = by; qy < by + acs.covered_blocks_y(); qy++) { + float* JXL_RESTRICT quant_field_row = quant_field.Row(qy); + for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); qx++) { + quant_field_row[qx] *= qf_mul; + } + } + } + } + } + quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field); +} + +} // namespace + +void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect, + ImageF* quant_field) { + // Replace the whole quant_field in non-8x8 blocks with the maximum of each + // 8x8 block. + size_t stride = quant_field->PixelsPerRow(); + for (size_t y = 0; y < rect.ysize(); ++y) { + AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y); + float* JXL_RESTRICT quant_row = rect.Row(quant_field, y); + for (size_t x = 0; x < rect.xsize(); ++x) { + AcStrategy acs = ac_strategy_row[x]; + if (!acs.IsFirstBlock()) continue; + JXL_ASSERT(x + acs.covered_blocks_x() <= quant_field->xsize()); + JXL_ASSERT(y + acs.covered_blocks_y() <= quant_field->ysize()); + float max = quant_row[x]; + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + max = std::max(quant_row[x + ix + iy * stride], max); + } + } + for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { + for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { + quant_row[x + ix + iy * stride] = max; + } + } + } + } +} + +float InitialQuantDC(float butteraugli_target) { + const float kDcMul = 0.3; // Butteraugli target where non-linearity kicks in. + const float butteraugli_target_dc = std::max( + 0.5f * butteraugli_target, + std::min(butteraugli_target, + kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target, + kDcQuantPow))); + // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc. + // The maximum DC value might not be in the kXybRange because of inverse + // gaborish, so we add some slack to the maximum theoretical quant obtained + // this way (64). + return std::min(kDcQuant / butteraugli_target_dc, 50.f); +} + +ImageF InitialQuantField(const float butteraugli_target, const Image3F& opsin, + const FrameDimensions& frame_dim, ThreadPool* pool, + float rescale, ImageF* mask) { + PROFILER_FUNC; + const float quant_ac = kAcQuant / butteraugli_target; + return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)( + butteraugli_target, opsin, frame_dim, quant_ac * rescale, pool, mask); +} + +void FindBestQuantizer(const ImageBundle* linear, const Image3F& opsin, + PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool, + AuxOut* aux_out, double rescale) { + const CompressParams& cparams = enc_state->cparams; + if (cparams.max_error_mode) { + PROFILER_ZONE("enc find best maxerr"); + FindBestQuantizationMaxError(opsin, enc_state, cms, pool, aux_out); + } else if (cparams.speed_tier <= SpeedTier::kKitten) { + // Normal encoding to a butteraugli score. + PROFILER_ZONE("enc find best2"); + FindBestQuantization(*linear, opsin, enc_state, cms, pool, aux_out); + } +} + +ImageBundle RoundtripImage(const Image3F& opsin, PassesEncoderState* enc_state, + const JxlCmsInterface& cms, ThreadPool* pool) { + PROFILER_ZONE("enc roundtrip"); + std::unique_ptr dec_state = + jxl::make_unique(); + JXL_CHECK(dec_state->output_encoding_info.SetFromMetadata( + *enc_state->shared.metadata)); + dec_state->shared = &enc_state->shared; + JXL_ASSERT(opsin.ysize() % kBlockDim == 0); + + const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim); + const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim); + const size_t num_groups = xsize_groups * ysize_groups; + + size_t num_special_frames = enc_state->special_frames.size(); + + std::unique_ptr modular_frame_encoder = + jxl::make_unique(enc_state->shared.frame_header, + enc_state->cparams); + JXL_CHECK(InitializePassesEncoder(opsin, cms, pool, enc_state, + modular_frame_encoder.get(), nullptr)); + JXL_CHECK(dec_state->Init()); + JXL_CHECK(dec_state->InitForAC(pool)); + + ImageBundle decoded(&enc_state->shared.metadata->m); + decoded.origin = enc_state->shared.frame_header.frame_origin; + decoded.SetFromImage(Image3F(opsin.xsize(), opsin.ysize()), + dec_state->output_encoding_info.color_encoding); + + PassesDecoderState::PipelineOptions options; + options.use_slow_render_pipeline = false; + options.coalescing = true; + options.render_spotcolors = false; + + // Same as dec_state->shared->frame_header.nonserialized_metadata->m + const ImageMetadata& metadata = *decoded.metadata(); + + JXL_CHECK(dec_state->PreparePipeline(&decoded, options)); + + hwy::AlignedUniquePtr group_dec_caches; + const auto allocate_storage = [&](const size_t num_threads) -> Status { + JXL_RETURN_IF_ERROR( + dec_state->render_pipeline->PrepareForThreads(num_threads, + /*use_group_ids=*/false)); + group_dec_caches = hwy::MakeUniqueAlignedArray(num_threads); + return true; + }; + const auto process_group = [&](const uint32_t group_index, + const size_t thread) { + if (dec_state->shared->frame_header.loop_filter.epf_iters > 0) { + ComputeSigma(dec_state->shared->BlockGroupRect(group_index), + dec_state.get()); + } + RenderPipelineInput input = + dec_state->render_pipeline->GetInputBuffers(group_index, thread); + JXL_CHECK(DecodeGroupForRoundtrip( + enc_state->coeffs, group_index, dec_state.get(), + &group_dec_caches[thread], thread, input, &decoded, nullptr)); + for (size_t c = 0; c < metadata.num_extra_channels; c++) { + std::pair ri = input.GetBuffer(3 + c); + FillPlane(0.0f, ri.first, ri.second); + } + input.Done(); + }; + JXL_CHECK(RunOnPool(pool, 0, num_groups, allocate_storage, process_group, + "AQ loop")); + + // Ensure we don't create any new special frames. + enc_state->special_frames.resize(num_special_frames); + + return decoded; +} + +} // namespace jxl +#endif // HWY_ONCE -- cgit v1.2.3