Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc')
-rw-r--r--  third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc  1170
1 file changed, 1170 insertions(+), 0 deletions(-)
diff --git a/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc
new file mode 100644
index 0000000000..ae4cd3bd3b
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_adaptive_quantization.cc
@@ -0,0 +1,1170 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/enc_adaptive_quantization.h"
+
+#include <stddef.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <cmath>
+#include <string>
+#include <vector>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc"
+#include <hwy/foreach_target.h>
+#include <hwy/highway.h>
+
+#include "lib/jxl/ac_strategy.h"
+#include "lib/jxl/base/common.h"
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/base/data_parallel.h"
+#include "lib/jxl/base/fast_math-inl.h"
+#include "lib/jxl/base/status.h"
+#include "lib/jxl/butteraugli/butteraugli.h"
+#include "lib/jxl/cms/opsin_params.h"
+#include "lib/jxl/coeff_order_fwd.h"
+#include "lib/jxl/color_encoding_internal.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/dec_cache.h"
+#include "lib/jxl/dec_group.h"
+#include "lib/jxl/enc_aux_out.h"
+#include "lib/jxl/enc_butteraugli_comparator.h"
+#include "lib/jxl/enc_cache.h"
+#include "lib/jxl/enc_debug_image.h"
+#include "lib/jxl/enc_group.h"
+#include "lib/jxl/enc_modular.h"
+#include "lib/jxl/enc_params.h"
+#include "lib/jxl/enc_transforms-inl.h"
+#include "lib/jxl/epf.h"
+#include "lib/jxl/frame_dimensions.h"
+#include "lib/jxl/image.h"
+#include "lib/jxl/image_bundle.h"
+#include "lib/jxl/image_ops.h"
+#include "lib/jxl/quant_weights.h"
+
+// Set JXL_DEBUG_ADAPTIVE_QUANTIZATION to 1 to enable debugging.
+#ifndef JXL_DEBUG_ADAPTIVE_QUANTIZATION
+#define JXL_DEBUG_ADAPTIVE_QUANTIZATION 0
+#endif
+
+HWY_BEFORE_NAMESPACE();
+namespace jxl {
+namespace HWY_NAMESPACE {
+namespace {
+
+// These templates are not found via ADL.
+using hwy::HWY_NAMESPACE::AbsDiff;
+using hwy::HWY_NAMESPACE::Add;
+using hwy::HWY_NAMESPACE::And;
+using hwy::HWY_NAMESPACE::Max;
+using hwy::HWY_NAMESPACE::Rebind;
+using hwy::HWY_NAMESPACE::Sqrt;
+using hwy::HWY_NAMESPACE::ZeroIfNegative;
+
+// The following functions modulate an exponent (out_val) and return the
+// updated value. Their SIMD descriptor is capped to 8 lanes for 8x8 blocks.
+
+// Hack for mask estimation. Eventually replace this code with butteraugli's
+// masking.
+float ComputeMaskForAcStrategyUse(const float out_val) {
+ const float kMul = 1.0f;
+ const float kOffset = 0.001f;
+ return kMul / (out_val + kOffset);
+}
+
+template <class D, class V>
+V ComputeMask(const D d, const V out_val) {
+ const auto kBase = Set(d, -0.7647f);
+ const auto kMul4 = Set(d, 9.4708735624378946f);
+ const auto kMul2 = Set(d, 17.35036561631863f);
+ const auto kOffset2 = Set(d, 302.59587815579727f);
+ const auto kMul3 = Set(d, 6.7943250517376494f);
+ const auto kOffset3 = Set(d, 3.7179635626140772f);
+ const auto kOffset4 = Mul(Set(d, 0.25f), kOffset3);
+ const auto kMul0 = Set(d, 0.80061762862741759f);
+ const auto k1 = Set(d, 1.0f);
+
+ // Avoid division by zero.
+ const auto v1 = Max(Mul(out_val, kMul0), Set(d, 1e-3f));
+ const auto v2 = Div(k1, Add(v1, kOffset2));
+ const auto v3 = Div(k1, MulAdd(v1, v1, kOffset3));
+ const auto v4 = Div(k1, MulAdd(v1, v1, kOffset4));
+  // TODO(jyrki):
+  // A log or two here could make sense. In butteraugli we effectively have
+  // log(log(x + C)) for this kind of use: a single log is used in saturating
+  // visual masking, and since the modulation values here are exponential,
+  // another log would counter that.
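+  // Rough shape (hand-computed, approximate): the result decreases
+  // monotonically from about 11.3 at out_val = 0 towards kBase as out_val
+  // grows.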
+ return Add(kBase, MulAdd(kMul4, v4, MulAdd(kMul2, v2, Mul(kMul3, v3))));
+}
+
+// kSGmul and kSGmul2 represent a scaling difference between jxl and
+// butteraugli.
+static const float kSGmul = 226.77216153508914f;
+static const float kSGmul2 = 1.0f / 73.377132366608819f;
+static const float kLog2 = 0.693147181f;
+// Includes correction factor for std::log -> log2.
+static const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2;
+static const float kSGVOffset = 7.7825991679894591f;
+
+template <bool invert, typename D, typename V>
+V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) {
+ // The opsin space in jxl is the cubic root of photons, i.e., v * v * v
+ // is related to the number of photons.
+ //
+ // SimpleGamma(v * v * v) is the psychovisual space in butteraugli.
+ // This ratio allows quantization to move from jxl's opsin space to
+ // butteraugli's log-gamma space.
+  const float kEpsilon = 1e-2f;
+ v = ZeroIfNegative(v);
+ const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul);
+ const auto kVOffset = Set(d, kSGVOffset * kLog2 + kEpsilon);
+ const auto kDenMul = Set(d, kLog2 * kSGmul);
+
+ const auto v2 = Mul(v, v);
+
+ const auto num = MulAdd(kNumMul, v2, Set(d, kEpsilon));
+ const auto den = MulAdd(Mul(kDenMul, v), v2, kVOffset);
+ return invert ? Div(num, den) : Div(den, num);
+}
+
+template <bool invert = false>
+static float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) {
+ using DScalar = HWY_CAPPED(float, 1);
+ auto vscalar = Load(DScalar(), &v);
+ return GetLane(
+ RatioOfDerivativesOfCubicRootToSimpleGamma<invert>(DScalar(), vscalar));
+}
+
+// TODO(veluca): this function computes an approximation of the derivative of
+// SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or
+// exact derivatives. For reference, SimpleGamma was:
+/*
+template <typename D, typename V>
+V SimpleGamma(const D d, V v) {
+ // A simple HDR compatible gamma function.
+ const auto mul = Set(d, kSGmul);
+ const auto kRetMul = Set(d, kSGRetMul);
+ const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f);
+ const auto kVOffset = Set(d, kSGVOffset);
+
+ v *= mul;
+
+ // This should happen rarely, but may lead to a NaN, which is rather
+ // undesirable. Since negative photons don't exist we solve the NaNs by
+ // clamping here.
+ // TODO(veluca): with FastLog2f, this no longer leads to NaNs.
+ v = ZeroIfNegative(v);
+ return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd;
+}
+*/
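+
+// Derivation sketch (from the commented-out SimpleGamma above): with photons
+// w = v * v * v, g(v) = SimpleGamma(w) has the derivative
+//   g'(v) = (kSGRetMul * 3 * kSGmul * v^2) /
+//           (kLog2 * (kSGmul * v^3 + kSGVOffset)),
+// which is exactly num / den above, up to the kEpsilon regularizers. The
+// default invert=false returns the reciprocal 1 / g'(v).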
+
+template <class D, class V>
+V GammaModulation(const D d, const size_t x, const size_t y,
+ const ImageF& xyb_x, const ImageF& xyb_y, const Rect& rect,
+ const V out_val) {
+ const float kBias = 0.16f;
+ JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[0]);
+ JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[1]);
+ JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[2]);
+ auto overall_ratio = Zero(d);
+ auto bias = Set(d, kBias);
+ auto half = Set(d, 0.5f);
+ for (size_t dy = 0; dy < 8; ++dy) {
+ const float* const JXL_RESTRICT row_in_x = rect.ConstRow(xyb_x, y + dy);
+ const float* const JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy);
+ for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
+ const auto iny = Add(Load(d, row_in_y + x + dx), bias);
+ const auto inx = Load(d, row_in_x + x + dx);
+ const auto r = Sub(iny, inx);
+ const auto g = Add(iny, inx);
+ const auto ratio_r =
+ RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, r);
+ const auto ratio_g =
+ RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, g);
+ const auto avg_ratio = Mul(half, Add(ratio_r, ratio_g));
+
+ overall_ratio = Add(overall_ratio, avg_ratio);
+ }
+ }
+ overall_ratio = Mul(SumOfLanes(d, overall_ratio), Set(d, 1.0f / 64));
+  // Ideally this would be -1.0, but the likely-optimal correction adds some
+  // entropy, so it is slightly less than that.
+ // ln(2) constant folded in because we want std::log but have FastLog2f.
+ static const float v = 0.14507933746197058f;
+ const auto kGam = Set(d, v * 0.693147180559945f);
+ return MulAdd(kGam, FastLog2f(d, overall_ratio), out_val);
+}
+
+// Change precision in 8x8 blocks that have high frequency content.
+template <class D, class V>
+V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb,
+ const Rect& rect, const V out_val) {
+ // Zero out the invalid differences for the rightmost value per row.
+ const Rebind<uint32_t, D> du;
+ HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u,
+ ~0u, ~0u, ~0u, 0};
+
+ auto sum = Zero(d); // sum of absolute differences with right and below
+
+ static const float valmin = 0.020602694503245016f;
+ auto valminv = Set(d, valmin);
+ for (size_t dy = 0; dy < 8; ++dy) {
+ const float* JXL_RESTRICT row_in = rect.ConstRow(xyb, y + dy) + x;
+ const float* JXL_RESTRICT row_in_next =
+ dy == 7 ? row_in : rect.ConstRow(xyb, y + dy + 1) + x;
+
+ // In SCALAR, there is no guarantee of having extra row padding.
+ // Hence, we need to ensure we don't access pixels outside the row itself.
+ // In SIMD modes, however, rows are padded, so it's safe to access one
+ // garbage value after the row. The vector then gets masked with kMaskRight
+ // to remove the influence of that value.
+#if HWY_TARGET != HWY_SCALAR
+ for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
+#else
+ for (size_t dx = 0; dx < 7; dx += Lanes(d)) {
+#endif
+ const auto p = Load(d, row_in + dx);
+ const auto pr = LoadU(d, row_in + dx + 1);
+ const auto mask = BitCast(d, Load(du, kMaskRight + dx));
+ sum = Add(sum, And(mask, Min(valminv, AbsDiff(p, pr))));
+
+ const auto pd = Load(d, row_in_next + dx);
+ sum = Add(sum, Min(valminv, AbsDiff(p, pd)));
+ }
+#if HWY_TARGET == HWY_SCALAR
+ const auto p = Load(d, row_in + 7);
+ const auto pd = Load(d, row_in_next + 7);
+ sum = Add(sum, Min(valminv, AbsDiff(p, pd)));
+#endif
+ }
+  // A more negative value gives more bpp.
+ static const float kOffset = -1.110929106987477;
+ static const float kMul = -0.38078920620238305;
+ sum = SumOfLanes(d, sum);
+ float scalar_sum = GetLane(sum);
+ scalar_sum += kOffset;
+ scalar_sum *= kMul;
+ return Add(Set(d, scalar_sum), out_val);
+}
+
+void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x,
+ const ImageF& xyb_y, const ImageF& xyb_b,
+ const Rect& rect_in, const float scale,
+ const Rect& rect_out, ImageF* out) {
+ float base_level = 0.48f * scale;
+ float kDampenRampStart = 2.0f;
+ float kDampenRampEnd = 14.0f;
+ float dampen = 1.0f;
+ if (butteraugli_target >= kDampenRampStart) {
+ dampen = 1.0f - ((butteraugli_target - kDampenRampStart) /
+ (kDampenRampEnd - kDampenRampStart));
+ if (dampen < 0) {
+ dampen = 0;
+ }
+ }
+ const float mul = scale * dampen;
+ const float add = (1.0f - dampen) * base_level;
+ for (size_t iy = rect_out.y0(); iy < rect_out.y1(); iy++) {
+ const size_t y = iy * 8;
+ float* const JXL_RESTRICT row_out = out->Row(iy);
+ const HWY_CAPPED(float, kBlockDim) df;
+ for (size_t ix = rect_out.x0(); ix < rect_out.x1(); ix++) {
+ size_t x = ix * 8;
+ auto out_val = Set(df, row_out[ix]);
+ out_val = ComputeMask(df, out_val);
+ out_val = HfModulation(df, x, y, xyb_y, rect_in, out_val);
+ out_val = GammaModulation(df, x, y, xyb_x, xyb_y, rect_in, out_val);
+ // We want multiplicative quantization field, so everything
+ // until this point has been modulating the exponent.
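+      // 1.442695041f is 1 / ln(2), so FastPow2f(x * 1.442695041f) = exp(x).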
+ row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add;
+ }
+ }
+}
+
+template <typename D, typename V>
+V MaskingSqrt(const D d, V v) {
+ static const float kLogOffset = 27.97044946785558f;
+ static const float kMul = 211.53333281566171f;
+ const auto mul_v = Set(d, kMul * 1e8);
+ const auto offset_v = Set(d, kLogOffset);
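+  // Computes 0.25 * sqrt(v * sqrt(kMul * 1e8) + kLogOffset).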
+ return Mul(Set(d, 0.25f), Sqrt(MulAdd(v, Sqrt(mul_v), offset_v)));
+}
+
+float MaskingSqrt(const float v) {
+ using DScalar = HWY_CAPPED(float, 1);
+ auto vscalar = Load(DScalar(), &v);
+ return GetLane(MaskingSqrt(DScalar(), vscalar));
+}
+
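+// Inserts v among the four smallest values seen so far, assuming
+// min0 <= min1 <= min2 <= min3 on entry; the previous largest is dropped.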
+void StoreMin4(const float v, float& min0, float& min1, float& min2,
+ float& min3) {
+ if (v < min3) {
+ if (v < min0) {
+ min3 = min2;
+ min2 = min1;
+ min1 = min0;
+ min0 = v;
+ } else if (v < min1) {
+ min3 = min2;
+ min2 = min1;
+ min1 = v;
+ } else if (v < min2) {
+ min3 = min2;
+ min2 = v;
+ } else {
+ min3 = v;
+ }
+ }
+}
+
+// Look for smooth areas near the area of degradation.
+// If the areas are generally smooth, don't do masking.
+// Output is downsampled 2x.
+void FuzzyErosion(const float butteraugli_target, const Rect& from_rect,
+ const ImageF& from, const Rect& to_rect, ImageF* to) {
+ const size_t xsize = from.xsize();
+ const size_t ysize = from.ysize();
+ constexpr int kStep = 1;
+ static_assert(kStep == 1, "Step must be 1");
+ JXL_ASSERT(to_rect.xsize() * 2 == from_rect.xsize());
+ JXL_ASSERT(to_rect.ysize() * 2 == from_rect.ysize());
+ static const float kMulBase0 = 0.125;
+ static const float kMulBase1 = 0.10;
+ static const float kMulBase2 = 0.09;
+ static const float kMulBase3 = 0.06;
+ static const float kMulAdd0 = 0.0;
+ static const float kMulAdd1 = -0.10;
+ static const float kMulAdd2 = -0.09;
+ static const float kMulAdd3 = -0.06;
+
+ float mul = 0.0;
+ if (butteraugli_target < 2.0f) {
+ mul = (2.0f - butteraugli_target) * (1.0f / 2.0f);
+ }
+ float kMul0 = kMulBase0 + mul * kMulAdd0;
+ float kMul1 = kMulBase1 + mul * kMulAdd1;
+ float kMul2 = kMulBase2 + mul * kMulAdd2;
+ float kMul3 = kMulBase3 + mul * kMulAdd3;
+ static const float kTotal = 0.29959705784054957;
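+  // Rescale the four weights so that they sum to kTotal.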
+ float norm = kTotal / (kMul0 + kMul1 + kMul2 + kMul3);
+ kMul0 *= norm;
+ kMul1 *= norm;
+ kMul2 *= norm;
+ kMul3 *= norm;
+
+ for (size_t fy = 0; fy < from_rect.ysize(); ++fy) {
+ size_t y = fy + from_rect.y0();
+ size_t ym1 = y >= kStep ? y - kStep : y;
+ size_t yp1 = y + kStep < ysize ? y + kStep : y;
+ const float* rowt = from.Row(ym1);
+ const float* row = from.Row(y);
+ const float* rowb = from.Row(yp1);
+ float* row_out = to_rect.Row(to, fy / 2);
+ for (size_t fx = 0; fx < from_rect.xsize(); ++fx) {
+ size_t x = fx + from_rect.x0();
+ size_t xm1 = x >= kStep ? x - kStep : x;
+ size_t xp1 = x + kStep < xsize ? x + kStep : x;
+ float min0 = row[x];
+ float min1 = row[xm1];
+ float min2 = row[xp1];
+ float min3 = rowt[xm1];
+ // Sort the first four values.
+ if (min0 > min1) std::swap(min0, min1);
+ if (min0 > min2) std::swap(min0, min2);
+ if (min0 > min3) std::swap(min0, min3);
+ if (min1 > min2) std::swap(min1, min2);
+ if (min1 > min3) std::swap(min1, min3);
+ if (min2 > min3) std::swap(min2, min3);
+ // The remaining five values of a 3x3 neighbourhood.
+ StoreMin4(rowt[x], min0, min1, min2, min3);
+ StoreMin4(rowt[xp1], min0, min1, min2, min3);
+ StoreMin4(rowb[xm1], min0, min1, min2, min3);
+ StoreMin4(rowb[x], min0, min1, min2, min3);
+ StoreMin4(rowb[xp1], min0, min1, min2, min3);
+
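+      // Each 2x2 quad of source pixels accumulates into a single pixel of
+      // the 2x downsampled output.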
+ float v = kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + kMul3 * min3;
+ if (fx % 2 == 0 && fy % 2 == 0) {
+ row_out[fx / 2] = v;
+ } else {
+ row_out[fx / 2] += v;
+ }
+ }
+ }
+}
+
+struct AdaptiveQuantizationImpl {
+ void PrepareBuffers(size_t num_threads) {
+ diff_buffer = ImageF(kEncTileDim + 8, num_threads);
+ for (size_t i = pre_erosion.size(); i < num_threads; i++) {
+ pre_erosion.emplace_back(kEncTileDimInBlocks * 2 + 2,
+ kEncTileDimInBlocks * 2 + 2);
+ }
+ }
+
+ void ComputeTile(float butteraugli_target, float scale, const Image3F& xyb,
+ const Rect& rect_in, const Rect& rect_out, const int thread,
+ ImageF* mask, ImageF* mask1x1) {
+ JXL_ASSERT(rect_in.x0() % 8 == 0);
+ JXL_ASSERT(rect_in.y0() % 8 == 0);
+ const size_t xsize = xyb.xsize();
+ const size_t ysize = xyb.ysize();
+
+    // The XYB gamma is 3.0 so that decoding is fast: cubing takes two muls.
+    // Butteraugli's gamma matches the gamma of the human eye, around 2.6.
+    // We approximate the gamma difference by adding one cubic root into
+    // the adaptive quantization. This gives us a total gamma of 2.6666
+    // for quantization purposes.
+ const float match_gamma_offset = 0.019;
+
+ const HWY_FULL(float) df;
+
+ size_t y_start_1x1 = rect_in.y0() + rect_out.y0() * 8;
+ size_t y_end_1x1 = y_start_1x1 + rect_out.ysize() * 8;
+
+ size_t x_start_1x1 = rect_in.x0() + rect_out.x0() * 8;
+ size_t x_end_1x1 = x_start_1x1 + rect_out.xsize() * 8;
+
+ if (rect_in.x0() != 0 && rect_out.x0() == 0) x_start_1x1 -= 2;
+ if (rect_in.x1() < xsize && rect_out.x1() * 8 == rect_in.xsize()) {
+ x_end_1x1 += 2;
+ }
+ if (rect_in.y0() != 0 && rect_out.y0() == 0) y_start_1x1 -= 2;
+ if (rect_in.y1() < ysize && rect_out.y1() * 8 == rect_in.ysize()) {
+ y_end_1x1 += 2;
+ }
+
+    // Computes an image (padded to a multiple of 8x8) of local pixel
+    // differences, subsampled by 4 in both directions (second loop below).
+    // First, a 1x1 Laplacian of intensity feeds the mask1x1 image.
+ for (size_t y = y_start_1x1; y < y_end_1x1; ++y) {
+ const size_t y2 = y + 1 < ysize ? y + 1 : y;
+ const size_t y1 = y > 0 ? y - 1 : y;
+ const float* row_in = xyb.ConstPlaneRow(1, y);
+ const float* row_in1 = xyb.ConstPlaneRow(1, y1);
+ const float* row_in2 = xyb.ConstPlaneRow(1, y2);
+ float* mask1x1_out = mask1x1->Row(y);
+ auto scalar_pixel1x1 = [&](size_t x) {
+ const size_t x2 = x + 1 < xsize ? x + 1 : x;
+ const size_t x1 = x > 0 ? x - 1 : x;
+ const float base =
+ 0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
+ const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
+ row_in[x] + match_gamma_offset);
+ float diff = fabs(gammac * (row_in[x] - base));
+ static const double kScaler = 1.0;
+ diff *= kScaler;
+ diff = log1p(diff);
+ static const float kMul = 1.0;
+ static const float kOffset = 0.01;
+ mask1x1_out[x] = kMul / (diff + kOffset);
+ };
+ for (size_t x = x_start_1x1; x < x_end_1x1; ++x) {
+ scalar_pixel1x1(x);
+ }
+ }
+
+ size_t y_start = rect_in.y0() + rect_out.y0() * 8;
+ size_t y_end = y_start + rect_out.ysize() * 8;
+
+ size_t x_start = rect_in.x0() + rect_out.x0() * 8;
+ size_t x_end = x_start + rect_out.xsize() * 8;
+
+ if (x_start != 0) x_start -= 4;
+ if (x_end != xsize) x_end += 4;
+ if (y_start != 0) y_start -= 4;
+ if (y_end != ysize) y_end += 4;
+ pre_erosion[thread].ShrinkTo((x_end - x_start) / 4, (y_end - y_start) / 4);
+
+ static const float limit = 0.2f;
+ for (size_t y = y_start; y < y_end; ++y) {
+ size_t y2 = y + 1 < ysize ? y + 1 : y;
+ size_t y1 = y > 0 ? y - 1 : y;
+
+ const float* row_in = xyb.ConstPlaneRow(1, y);
+ const float* row_in1 = xyb.ConstPlaneRow(1, y1);
+ const float* row_in2 = xyb.ConstPlaneRow(1, y2);
+ float* JXL_RESTRICT row_out = diff_buffer.Row(thread);
+
+ auto scalar_pixel = [&](size_t x) {
+ const size_t x2 = x + 1 < xsize ? x + 1 : x;
+ const size_t x1 = x > 0 ? x - 1 : x;
+ const float base =
+ 0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
+ const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
+ row_in[x] + match_gamma_offset);
+ float diff = gammac * (row_in[x] - base);
+ diff *= diff;
+ if (diff >= limit) {
+ diff = limit;
+ }
+ diff = MaskingSqrt(diff);
+ if ((y % 4) != 0) {
+ row_out[x - x_start] += diff;
+ } else {
+ row_out[x - x_start] = diff;
+ }
+ };
+
+ size_t x = x_start;
+ // First pixel of the row.
+ if (x_start == 0) {
+ scalar_pixel(x_start);
+ ++x;
+ }
+ // SIMD
+ const auto match_gamma_offset_v = Set(df, match_gamma_offset);
+ const auto quarter = Set(df, 0.25f);
+ for (; x + 1 + Lanes(df) < x_end; x += Lanes(df)) {
+ const auto in = LoadU(df, row_in + x);
+ const auto in_r = LoadU(df, row_in + x + 1);
+ const auto in_l = LoadU(df, row_in + x - 1);
+ const auto in_t = LoadU(df, row_in2 + x);
+ const auto in_b = LoadU(df, row_in1 + x);
+ auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b)));
+ auto gammacv =
+ RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/false>(
+ df, Add(in, match_gamma_offset_v));
+ auto diff = Mul(gammacv, Sub(in, base));
+ diff = Mul(diff, diff);
+ diff = Min(diff, Set(df, limit));
+ diff = MaskingSqrt(df, diff);
+ if ((y & 3) != 0) {
+ diff = Add(diff, LoadU(df, row_out + x - x_start));
+ }
+ StoreU(diff, df, row_out + x - x_start);
+ }
+ // Scalar
+ for (; x < x_end; ++x) {
+ scalar_pixel(x);
+ }
+ if (y % 4 == 3) {
+ float* row_dout = pre_erosion[thread].Row((y - y_start) / 4);
+ for (size_t x = 0; x < (x_end - x_start) / 4; x++) {
+ row_dout[x] = (row_out[x * 4] + row_out[x * 4 + 1] +
+ row_out[x * 4 + 2] + row_out[x * 4 + 3]) *
+ 0.25f;
+ }
+ }
+ }
+ Rect from_rect(x_start % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 0 : 1,
+ rect_out.xsize() * 2, rect_out.ysize() * 2);
+ FuzzyErosion(butteraugli_target, from_rect, pre_erosion[thread], rect_out,
+ &aq_map);
+ for (size_t y = 0; y < rect_out.ysize(); ++y) {
+ const float* aq_map_row = rect_out.ConstRow(aq_map, y);
+ float* mask_row = rect_out.Row(mask, y);
+ for (size_t x = 0; x < rect_out.xsize(); ++x) {
+ mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]);
+ }
+ }
+ PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1),
+ xyb.Plane(2), rect_in, scale, rect_out, &aq_map);
+ }
+ std::vector<ImageF> pre_erosion;
+ ImageF aq_map;
+ ImageF diff_buffer;
+};
+
+static void Blur1x1Masking(ThreadPool* pool, ImageF* mask1x1,
+ const Rect& rect) {
+ // Blur the mask1x1 to obtain the masking image.
+ // Before blurring it contains an image of absolute value of the
+ // Laplacian of the intensity channel.
+ static const float kFilterMask1x1[5] = {
+ static_cast<float>(0.25647067633737227),
+ static_cast<float>(0.2050056912354399075),
+ static_cast<float>(0.154082048668497307),
+ static_cast<float>(0.08149576591362004441),
+ static_cast<float>(0.0512750104812308467),
+ };
+ double sum =
+ 1.0 + 4 * (kFilterMask1x1[0] + kFilterMask1x1[1] + kFilterMask1x1[2] +
+ kFilterMask1x1[4] + 2 * kFilterMask1x1[3]);
+ if (sum < 1e-5) {
+ sum = 1e-5;
+ }
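+  // Normalize so that the full 5x5 kernel sums to approximately 1.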
+ const float normalize = static_cast<float>(1.0 / sum);
+ const float normalize_mul = normalize;
+ WeightsSymmetric5 weights =
+ WeightsSymmetric5{{HWY_REP4(normalize)},
+ {HWY_REP4(normalize_mul * kFilterMask1x1[0])},
+ {HWY_REP4(normalize_mul * kFilterMask1x1[2])},
+ {HWY_REP4(normalize_mul * kFilterMask1x1[1])},
+ {HWY_REP4(normalize_mul * kFilterMask1x1[4])},
+ {HWY_REP4(normalize_mul * kFilterMask1x1[3])}};
+ ImageF temp(rect.xsize(), rect.ysize());
+ Symmetric5(*mask1x1, rect, weights, pool, &temp);
+ *mask1x1 = std::move(temp);
+}
+
+ImageF AdaptiveQuantizationMap(const float butteraugli_target,
+ const Image3F& xyb, const Rect& rect,
+ float scale, ThreadPool* pool, ImageF* mask,
+ ImageF* mask1x1) {
+ JXL_DASSERT(rect.xsize() % kBlockDim == 0);
+ JXL_DASSERT(rect.ysize() % kBlockDim == 0);
+ AdaptiveQuantizationImpl impl;
+ const size_t xsize_blocks = rect.xsize() / kBlockDim;
+ const size_t ysize_blocks = rect.ysize() / kBlockDim;
+ impl.aq_map = ImageF(xsize_blocks, ysize_blocks);
+ *mask = ImageF(xsize_blocks, ysize_blocks);
+ *mask1x1 = ImageF(xyb.xsize(), xyb.ysize());
+ JXL_CHECK(RunOnPool(
+ pool, 0,
+ DivCeil(xsize_blocks, kEncTileDimInBlocks) *
+ DivCeil(ysize_blocks, kEncTileDimInBlocks),
+ [&](const size_t num_threads) {
+ impl.PrepareBuffers(num_threads);
+ return true;
+ },
+ [&](const uint32_t tid, const size_t thread) {
+ size_t n_enc_tiles = DivCeil(xsize_blocks, kEncTileDimInBlocks);
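+      // Decompose the linear tile id into (tx, ty) tile coordinates and the
+      // corresponding block bounds.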
+ size_t tx = tid % n_enc_tiles;
+ size_t ty = tid / n_enc_tiles;
+ size_t by0 = ty * kEncTileDimInBlocks;
+ size_t by1 = std::min((ty + 1) * kEncTileDimInBlocks, ysize_blocks);
+ size_t bx0 = tx * kEncTileDimInBlocks;
+ size_t bx1 = std::min((tx + 1) * kEncTileDimInBlocks, xsize_blocks);
+ Rect rect_out(bx0, by0, bx1 - bx0, by1 - by0);
+ impl.ComputeTile(butteraugli_target, scale, xyb, rect, rect_out, thread,
+ mask, mask1x1);
+ },
+ "AQ DiffPrecompute"));
+
+ Blur1x1Masking(pool, mask1x1, rect);
+ return std::move(impl).aq_map;
+}
+
+} // namespace
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace jxl
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace jxl {
+HWY_EXPORT(AdaptiveQuantizationMap);
+
+namespace {
+
+// If true, prints the quantization maps at each iteration.
+constexpr bool FLAGS_dump_quant_state = false;
+
+void DumpHeatmap(const CompressParams& cparams, const AuxOut* aux_out,
+ const std::string& label, const ImageF& image,
+ float good_threshold, float bad_threshold) {
+ if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
+ Image3F heatmap = CreateHeatMapImage(image, good_threshold, bad_threshold);
+ char filename[200];
+ snprintf(filename, sizeof(filename), "%s%05d", label.c_str(),
+ aux_out->num_butteraugli_iters);
+ DumpImage(cparams, filename, heatmap);
+ }
+}
+
+void DumpHeatmaps(const CompressParams& cparams, const AuxOut* aux_out,
+ float ba_target, const ImageF& quant_field,
+ const ImageF& tile_heatmap, const ImageF& bt_diffmap) {
+ if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
+ if (!WantDebugOutput(cparams)) return;
+ ImageF inv_qmap(quant_field.xsize(), quant_field.ysize());
+ for (size_t y = 0; y < quant_field.ysize(); ++y) {
+ const float* JXL_RESTRICT row_q = quant_field.ConstRow(y);
+ float* JXL_RESTRICT row_inv_q = inv_qmap.Row(y);
+ for (size_t x = 0; x < quant_field.xsize(); ++x) {
+ row_inv_q[x] = 1.0f / row_q[x]; // never zero
+ }
+ }
+ DumpHeatmap(cparams, aux_out, "quant_heatmap", inv_qmap, 4.0f * ba_target,
+ 6.0f * ba_target);
+ DumpHeatmap(cparams, aux_out, "tile_heatmap", tile_heatmap, ba_target,
+ 1.5f * ba_target);
+    // Matches the heat maps produced by the command-line tool.
+ DumpHeatmap(cparams, aux_out, "bt_diffmap", bt_diffmap,
+ ButteraugliFuzzyInverse(1.5), ButteraugliFuzzyInverse(0.5));
+ }
+}
+
+ImageF TileDistMap(const ImageF& distmap, int tile_size, int margin,
+ const AcStrategyImage& ac_strategy) {
+ const int tile_xsize = (distmap.xsize() + tile_size - 1) / tile_size;
+ const int tile_ysize = (distmap.ysize() + tile_size - 1) / tile_size;
+ ImageF tile_distmap(tile_xsize, tile_ysize);
+ size_t distmap_stride = tile_distmap.PixelsPerRow();
+ for (int tile_y = 0; tile_y < tile_ysize; ++tile_y) {
+ AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(tile_y);
+ float* JXL_RESTRICT dist_row = tile_distmap.Row(tile_y);
+ for (int tile_x = 0; tile_x < tile_xsize; ++tile_x) {
+ AcStrategy acs = ac_strategy_row[tile_x];
+ if (!acs.IsFirstBlock()) continue;
+ int this_tile_xsize = acs.covered_blocks_x() * tile_size;
+ int this_tile_ysize = acs.covered_blocks_y() * tile_size;
+ int y_begin = std::max<int>(0, tile_size * tile_y - margin);
+ int y_end = std::min<int>(distmap.ysize(),
+ tile_size * tile_y + this_tile_ysize + margin);
+ int x_begin = std::max<int>(0, tile_size * tile_x - margin);
+ int x_end = std::min<int>(distmap.xsize(),
+ tile_size * tile_x + this_tile_xsize + margin);
+ float dist_norm = 0.0;
+ double pixels = 0;
+ for (int y = y_begin; y < y_end; ++y) {
+ float ymul = 1.0;
+ constexpr float kBorderMul = 0.98f;
+ constexpr float kCornerMul = 0.7f;
+ if (margin != 0 && (y == y_begin || y == y_end - 1)) {
+ ymul = kBorderMul;
+ }
+ const float* const JXL_RESTRICT row = distmap.Row(y);
+ for (int x = x_begin; x < x_end; ++x) {
+ float xmul = ymul;
+ if (margin != 0 && (x == x_begin || x == x_end - 1)) {
+ if (xmul == 1.0) {
+ xmul = kBorderMul;
+ } else {
+ xmul = kCornerMul;
+ }
+ }
+ float v = row[x];
+ v *= v;
+ v *= v;
+ v *= v;
+ v *= v;
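+          // v is now row[x]^16, accumulated for the 16th-power norm below.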
+ dist_norm += xmul * v;
+ pixels += xmul;
+ }
+ }
+ if (pixels == 0) pixels = 1;
+      // The 16th norm is less than the max norm; we reduce the difference
+      // with this normalization factor.
+ constexpr float kTileNorm = 1.2f;
+ const float tile_dist =
+ kTileNorm * std::pow(dist_norm / pixels, 1.0f / 16.0f);
+ dist_row[tile_x] = tile_dist;
+ for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
+ for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
+ dist_row[tile_x + distmap_stride * iy + ix] = tile_dist;
+ }
+ }
+ }
+ }
+ return tile_distmap;
+}
+
+static const float kDcQuantPow = 0.83f;
+static const float kDcQuant = 1.095924047623553f;
+static const float kAcQuant = 0.7381485255235064f;
+
+// Computes the decoded image for a given set of compression parameters.
+ImageBundle RoundtripImage(const FrameHeader& frame_header,
+ const Image3F& opsin, PassesEncoderState* enc_state,
+ const JxlCmsInterface& cms, ThreadPool* pool) {
+ std::unique_ptr<PassesDecoderState> dec_state =
+ jxl::make_unique<PassesDecoderState>();
+ JXL_CHECK(dec_state->output_encoding_info.SetFromMetadata(
+ *enc_state->shared.metadata));
+ dec_state->shared = &enc_state->shared;
+ JXL_ASSERT(opsin.ysize() % kBlockDim == 0);
+
+ const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim);
+ const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim);
+ const size_t num_groups = xsize_groups * ysize_groups;
+
+ size_t num_special_frames = enc_state->special_frames.size();
+ size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
+ ModularFrameEncoder modular_frame_encoder(frame_header, enc_state->cparams);
+ JXL_CHECK(InitializePassesEncoder(frame_header, opsin, Rect(opsin), cms, pool,
+ enc_state, &modular_frame_encoder,
+ nullptr));
+ JXL_CHECK(dec_state->Init(frame_header));
+ JXL_CHECK(dec_state->InitForAC(num_passes, pool));
+
+ ImageBundle decoded(&enc_state->shared.metadata->m);
+ decoded.origin = frame_header.frame_origin;
+ decoded.SetFromImage(Image3F(opsin.xsize(), opsin.ysize()),
+ dec_state->output_encoding_info.color_encoding);
+
+ PassesDecoderState::PipelineOptions options;
+ options.use_slow_render_pipeline = false;
+ options.coalescing = false;
+ options.render_spotcolors = false;
+ options.render_noise = false;
+
+ // Same as frame_header.nonserialized_metadata->m
+ const ImageMetadata& metadata = *decoded.metadata();
+
+ JXL_CHECK(dec_state->PreparePipeline(frame_header, &decoded, options));
+
+ hwy::AlignedUniquePtr<GroupDecCache[]> group_dec_caches;
+ const auto allocate_storage = [&](const size_t num_threads) -> Status {
+ JXL_RETURN_IF_ERROR(
+ dec_state->render_pipeline->PrepareForThreads(num_threads,
+ /*use_group_ids=*/false));
+ group_dec_caches = hwy::MakeUniqueAlignedArray<GroupDecCache>(num_threads);
+ return true;
+ };
+ const auto process_group = [&](const uint32_t group_index,
+ const size_t thread) {
+ if (frame_header.loop_filter.epf_iters > 0) {
+ ComputeSigma(frame_header.loop_filter,
+ dec_state->shared->frame_dim.BlockGroupRect(group_index),
+ dec_state.get());
+ }
+ RenderPipelineInput input =
+ dec_state->render_pipeline->GetInputBuffers(group_index, thread);
+ JXL_CHECK(DecodeGroupForRoundtrip(
+ frame_header, enc_state->coeffs, group_index, dec_state.get(),
+ &group_dec_caches[thread], thread, input, &decoded, nullptr));
+ for (size_t c = 0; c < metadata.num_extra_channels; c++) {
+ std::pair<ImageF*, Rect> ri = input.GetBuffer(3 + c);
+ FillPlane(0.0f, ri.first, ri.second);
+ }
+ input.Done();
+ };
+ JXL_CHECK(RunOnPool(pool, 0, num_groups, allocate_storage, process_group,
+ "AQ loop"));
+
+ // Ensure we don't create any new special frames.
+ enc_state->special_frames.resize(num_special_frames);
+
+ return decoded;
+}
+
+constexpr int kMaxButteraugliIters = 4;
+
+void FindBestQuantization(const FrameHeader& frame_header,
+ const Image3F& linear, const Image3F& opsin,
+ ImageF& quant_field, PassesEncoderState* enc_state,
+ const JxlCmsInterface& cms, ThreadPool* pool,
+ AuxOut* aux_out) {
+ const CompressParams& cparams = enc_state->cparams;
+ if (cparams.resampling > 1 &&
+ cparams.original_butteraugli_distance <= 4.0 * cparams.resampling) {
+    // For a downsampled opsin image, the butteraugli-based adaptive
+    // quantization loop would only increase the size without improving the
+    // distance much, so we enable it only for very high butteraugli targets.
+ return;
+ }
+ Quantizer& quantizer = enc_state->shared.quantizer;
+ ImageI& raw_quant_field = enc_state->shared.raw_quant_field;
+
+ const float butteraugli_target = cparams.butteraugli_distance;
+ const float original_butteraugli = cparams.original_butteraugli_distance;
+ ButteraugliParams params;
+ params.intensity_target = 80.f;
+ JxlButteraugliComparator comparator(params, cms);
+ JXL_CHECK(comparator.SetLinearReferenceImage(linear));
+ bool lower_is_better =
+ (comparator.GoodQualityScore() < comparator.BadQualityScore());
+ const float initial_quant_dc = InitialQuantDC(butteraugli_target);
+ AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field),
+ original_butteraugli, &quant_field);
+ ImageF tile_distmap;
+ ImageF initial_quant_field(quant_field.xsize(), quant_field.ysize());
+ CopyImageTo(quant_field, &initial_quant_field);
+
+ float initial_qf_min, initial_qf_max;
+ ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max);
+ float initial_qf_ratio = initial_qf_max / initial_qf_min;
+ float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio);
+ float asymmetry = 2;
+ if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low;
+ float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low);
+ float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry);
+
+ JXL_ASSERT(qf_higher / qf_lower < 253);
+
+ constexpr int kOriginalComparisonRound = 1;
+ int iters = kMaxButteraugliIters;
+ if (cparams.speed_tier != SpeedTier::kTortoise) {
+ iters = 2;
+ }
+ for (int i = 0; i < iters + 1; ++i) {
+ if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
+ printf("\nQuantization field:\n");
+ for (size_t y = 0; y < quant_field.ysize(); ++y) {
+ for (size_t x = 0; x < quant_field.xsize(); ++x) {
+ printf(" %.5f", quant_field.Row(y)[x]);
+ }
+ printf("\n");
+ }
+ }
+ quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
+ ImageBundle dec_linear =
+ RoundtripImage(frame_header, opsin, enc_state, cms, pool);
+ float score;
+ ImageF diffmap;
+ JXL_CHECK(comparator.CompareWith(dec_linear, &diffmap, &score));
+ if (!lower_is_better) {
+ score = -score;
+ ScaleImage(-1.0f, &diffmap);
+ }
+ tile_distmap = TileDistMap(diffmap, 8 * cparams.resampling, 0,
+ enc_state->shared.ac_strategy);
+ if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && WantDebugOutput(cparams)) {
+ DumpImage(cparams, ("dec" + ToString(i)).c_str(), *dec_linear.color());
+ DumpHeatmaps(cparams, aux_out, butteraugli_target, quant_field,
+ tile_distmap, diffmap);
+ }
+ if (aux_out != nullptr) ++aux_out->num_butteraugli_iters;
+ if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
+ float minval, maxval;
+ ImageMinMax(quant_field, &minval, &maxval);
+ printf("\nButteraugli iter: %d/%d\n", i, kMaxButteraugliIters);
+ printf("Butteraugli distance: %f (target = %f)\n", score,
+ original_butteraugli);
+ printf("quant range: %f ... %f DC quant: %f\n", minval, maxval,
+ initial_quant_dc);
+ if (FLAGS_dump_quant_state) {
+ quantizer.DumpQuantizationMap(raw_quant_field);
+ }
+ }
+
+ if (i == iters) break;
+
+ double kPow[8] = {
+ 0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ };
+ double kPowMod[8] = {
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ };
+ if (i == kOriginalComparisonRound) {
+ // Don't allow optimization to make the quant field a lot worse than
+ // what the initial guess was. This allows the AC field to have enough
+ // precision to reduce the oscillations due to the dc reconstruction.
+ double kInitMul = 0.6;
+ const double kOneMinusInitMul = 1.0 - kInitMul;
+ for (size_t y = 0; y < quant_field.ysize(); ++y) {
+ float* const JXL_RESTRICT row_q = quant_field.Row(y);
+ const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y);
+ for (size_t x = 0; x < quant_field.xsize(); ++x) {
+ double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x];
+ if (row_q[x] < clamp) {
+ row_q[x] = clamp;
+ if (row_q[x] > qf_higher) row_q[x] = qf_higher;
+ if (row_q[x] < qf_lower) row_q[x] = qf_lower;
+ }
+ }
+ }
+ }
+
+ double cur_pow = 0.0;
+ if (i < 7) {
+ cur_pow = kPow[i] + (original_butteraugli - 1.0) * kPowMod[i];
+ if (cur_pow < 0) {
+ cur_pow = 0;
+ }
+ }
+ if (cur_pow == 0.0) {
+ for (size_t y = 0; y < quant_field.ysize(); ++y) {
+ const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
+ float* const JXL_RESTRICT row_q = quant_field.Row(y);
+ for (size_t x = 0; x < quant_field.xsize(); ++x) {
+ const float diff = row_dist[x] / original_butteraugli;
+ if (diff > 1.0f) {
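+            // Increase qf proportionally to the overshoot, and make sure it
+            // moves by at least one quantizer step.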
+ float old = row_q[x];
+ row_q[x] *= diff;
+ int qf_old = old * quantizer.InvGlobalScale() + 0.5;
+ int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5;
+ if (qf_old == qf_new) {
+ row_q[x] = old + quantizer.Scale();
+ }
+ }
+ if (row_q[x] > qf_higher) row_q[x] = qf_higher;
+ if (row_q[x] < qf_lower) row_q[x] = qf_lower;
+ }
+ }
+ } else {
+ for (size_t y = 0; y < quant_field.ysize(); ++y) {
+ const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
+ float* const JXL_RESTRICT row_q = quant_field.Row(y);
+ for (size_t x = 0; x < quant_field.xsize(); ++x) {
+ const float diff = row_dist[x] / original_butteraugli;
+ if (diff <= 1.0f) {
+ row_q[x] *= std::pow(diff, cur_pow);
+ } else {
+ float old = row_q[x];
+ row_q[x] *= diff;
+ int qf_old = old * quantizer.InvGlobalScale() + 0.5;
+ int qf_new = row_q[x] * quantizer.InvGlobalScale() + 0.5;
+ if (qf_old == qf_new) {
+ row_q[x] = old + quantizer.Scale();
+ }
+ }
+ if (row_q[x] > qf_higher) row_q[x] = qf_higher;
+ if (row_q[x] < qf_lower) row_q[x] = qf_lower;
+ }
+ }
+ }
+ }
+ quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
+}
+
+void FindBestQuantizationMaxError(const FrameHeader& frame_header,
+ const Image3F& opsin, ImageF& quant_field,
+ PassesEncoderState* enc_state,
+ const JxlCmsInterface& cms, ThreadPool* pool,
+ AuxOut* aux_out) {
+ // TODO(szabadka): Make this work for non-opsin color spaces.
+ const CompressParams& cparams = enc_state->cparams;
+ Quantizer& quantizer = enc_state->shared.quantizer;
+ ImageI& raw_quant_field = enc_state->shared.raw_quant_field;
+
+ // TODO(veluca): better choice of this value.
+ const float initial_quant_dc =
+ 16 * std::sqrt(0.1f / cparams.butteraugli_distance);
+ AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field),
+ cparams.original_butteraugli_distance, &quant_field);
+
+ const float inv_max_err[3] = {1.0f / enc_state->cparams.max_error[0],
+ 1.0f / enc_state->cparams.max_error[1],
+ 1.0f / enc_state->cparams.max_error[2]};
+
+ for (int i = 0; i < kMaxButteraugliIters + 1; ++i) {
+ quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
+ if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
+ DumpXybImage(cparams, ("ops" + ToString(i)).c_str(), opsin);
+ }
+ ImageBundle decoded =
+ RoundtripImage(frame_header, opsin, enc_state, cms, pool);
+ if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
+ DumpXybImage(cparams, ("dec" + ToString(i)).c_str(), *decoded.color());
+ }
+ for (size_t by = 0; by < enc_state->shared.frame_dim.ysize_blocks; by++) {
+ AcStrategyRow ac_strategy_row =
+ enc_state->shared.ac_strategy.ConstRow(by);
+ for (size_t bx = 0; bx < enc_state->shared.frame_dim.xsize_blocks; bx++) {
+ AcStrategy acs = ac_strategy_row[bx];
+ if (!acs.IsFirstBlock()) continue;
+ float max_error = 0;
+ for (size_t c = 0; c < 3; c++) {
+ for (size_t y = by * kBlockDim;
+ y < (by + acs.covered_blocks_y()) * kBlockDim; y++) {
+ if (y >= decoded.ysize()) continue;
+ const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y);
+ const float* JXL_RESTRICT dec_row =
+ decoded.color()->ConstPlaneRow(c, y);
+ for (size_t x = bx * kBlockDim;
+ x < (bx + acs.covered_blocks_x()) * kBlockDim; x++) {
+ if (x >= decoded.xsize()) continue;
+ max_error = std::max(
+ std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error);
+ }
+ }
+ }
+ // Target an error between max_error/2 and max_error.
+ // If the error in the varblock is above the target, increase the qf to
+ // compensate. If the error is below the target, decrease the qf.
+ // However, to avoid an excessive increase of the qf, only do so if the
+ // error is less than half the maximum allowed error.
+ const float qf_mul = (max_error < 0.5f) ? max_error * 2.0f
+ : (max_error > 1.0f) ? max_error
+ : 1.0f;
+ for (size_t qy = by; qy < by + acs.covered_blocks_y(); qy++) {
+ float* JXL_RESTRICT quant_field_row = quant_field.Row(qy);
+ for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); qx++) {
+ quant_field_row[qx] *= qf_mul;
+ }
+ }
+ }
+ }
+ }
+ quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
+}
+
+} // namespace
+
+void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect,
+ float butteraugli_target, ImageF* quant_field) {
+  // Within each non-8x8 varblock, replace the quant_field values with the
+  // maximum over the covered 8x8 blocks.
+ size_t stride = quant_field->PixelsPerRow();
+
+  // At low distances it is best to use the max, but the mean works better
+  // at high distances. We interpolate between them over a distance range.
+ float mean_max_mixer = 1.0f;
+ {
+ static const float kLimit = 1.54138f;
+ static const float kMul = 0.56391f;
+ static const float kMin = 0.0f;
+ if (butteraugli_target > kLimit) {
+ mean_max_mixer -= (butteraugli_target - kLimit) * kMul;
+ if (mean_max_mixer < kMin) {
+ mean_max_mixer = kMin;
+ }
+ }
+ }
+ for (size_t y = 0; y < rect.ysize(); ++y) {
+ AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y);
+ float* JXL_RESTRICT quant_row = rect.Row(quant_field, y);
+ for (size_t x = 0; x < rect.xsize(); ++x) {
+ AcStrategy acs = ac_strategy_row[x];
+ if (!acs.IsFirstBlock()) continue;
+ JXL_ASSERT(x + acs.covered_blocks_x() <= quant_field->xsize());
+ JXL_ASSERT(y + acs.covered_blocks_y() <= quant_field->ysize());
+ float max = quant_row[x];
+ float mean = 0.0;
+ for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
+ for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
+ mean += quant_row[x + ix + iy * stride];
+ max = std::max(quant_row[x + ix + iy * stride], max);
+ }
+ }
+ mean /= acs.covered_blocks_y() * acs.covered_blocks_x();
+ if (acs.covered_blocks_y() * acs.covered_blocks_x() >= 4) {
+ max *= mean_max_mixer;
+ max += (1.0f - mean_max_mixer) * mean;
+ }
+ for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
+ for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
+ quant_row[x + ix + iy * stride] = max;
+ }
+ }
+ }
+ }
+}
+
+float InitialQuantDC(float butteraugli_target) {
+ const float kDcMul = 0.3; // Butteraugli target where non-linearity kicks in.
+ const float butteraugli_target_dc = std::max<float>(
+ 0.5f * butteraugli_target,
+ std::min<float>(butteraugli_target,
+ kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target,
+ kDcQuantPow)));
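+  // E.g. (hand-computed, approximate): butteraugli_target = 1.0 gives
+  // butteraugli_target_dc of about 0.81 and a result of about 1.35.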
+ // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc.
+ // The maximum DC value might not be in the kXybRange because of inverse
+ // gaborish, so we add some slack to the maximum theoretical quant obtained
+ // this way (64).
+ return std::min(kDcQuant / butteraugli_target_dc, 50.f);
+}
+
+ImageF InitialQuantField(const float butteraugli_target, const Image3F& opsin,
+ const Rect& rect, ThreadPool* pool, float rescale,
+ ImageF* mask, ImageF* mask1x1) {
+ const float quant_ac = kAcQuant / butteraugli_target;
+ return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)(
+ butteraugli_target, opsin, rect, quant_ac * rescale, pool, mask, mask1x1);
+}
+
+void FindBestQuantizer(const FrameHeader& frame_header, const Image3F* linear,
+ const Image3F& opsin, ImageF& quant_field,
+ PassesEncoderState* enc_state,
+ const JxlCmsInterface& cms, ThreadPool* pool,
+ AuxOut* aux_out, double rescale) {
+ const CompressParams& cparams = enc_state->cparams;
+ if (cparams.max_error_mode) {
+ FindBestQuantizationMaxError(frame_header, opsin, quant_field, enc_state,
+ cms, pool, aux_out);
+ } else if (linear && cparams.speed_tier <= SpeedTier::kKitten) {
+ // Normal encoding to a butteraugli score.
+ FindBestQuantization(frame_header, *linear, opsin, quant_field, enc_state,
+ cms, pool, aux_out);
+ }
+}
+
+} // namespace jxl
+#endif // HWY_ONCE