summaryrefslogtreecommitdiffstats
path: root/third_party/jpeg-xl/lib/jxl/enc_noise.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/enc_noise.cc')
-rw-r--r--third_party/jpeg-xl/lib/jxl/enc_noise.cc372
1 files changed, 372 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/enc_noise.cc b/third_party/jpeg-xl/lib/jxl/enc_noise.cc
new file mode 100644
index 0000000000..a12a9e6dc4
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_noise.cc
@@ -0,0 +1,372 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "lib/jxl/enc_noise.h"
+
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+
+#include "lib/jxl/base/compiler_specific.h"
+#include "lib/jxl/chroma_from_luma.h"
+#include "lib/jxl/convolve.h"
+#include "lib/jxl/enc_aux_out.h"
+#include "lib/jxl/enc_optimize.h"
+#include "lib/jxl/image_ops.h"
+
+namespace jxl {
+namespace {
+
+using OptimizeArray = optimize::Array<double, NoiseParams::kNumNoisePoints>;
+
+float GetScoreSumsOfAbsoluteDifferences(const Image3F& opsin, const int x,
+ const int y, const int block_size) {
+ const int small_bl_size_x = 3;
+ const int small_bl_size_y = 4;
+ const int kNumSAD =
+ (block_size - small_bl_size_x) * (block_size - small_bl_size_y);
+ // block_size x block_size reference pixels
+ int counter = 0;
+ const int offset = 2;
+
+ std::vector<float> sad(kNumSAD, 0);
+ for (int y_bl = 0; y_bl + small_bl_size_y < block_size; ++y_bl) {
+ for (int x_bl = 0; x_bl + small_bl_size_x < block_size; ++x_bl) {
+ float sad_sum = 0;
+ // size of the center patch, we compare all the patches inside window with
+ // the center one
+ for (int cy = 0; cy < small_bl_size_y; ++cy) {
+ for (int cx = 0; cx < small_bl_size_x; ++cx) {
+ float wnd = 0.5f * (opsin.PlaneRow(1, y + y_bl + cy)[x + x_bl + cx] +
+ opsin.PlaneRow(0, y + y_bl + cy)[x + x_bl + cx]);
+ float center =
+ 0.5f * (opsin.PlaneRow(1, y + offset + cy)[x + offset + cx] +
+ opsin.PlaneRow(0, y + offset + cy)[x + offset + cx]);
+ sad_sum += std::abs(center - wnd);
+ }
+ }
+ sad[counter++] = sad_sum;
+ }
+ }
+ const int kSamples = (kNumSAD) / 2;
+ // As with ROAD (rank order absolute distance), we keep the smallest half of
+ // the values in SAD (we use here the more robust patch SAD instead of
+ // absolute single-pixel differences).
+ std::sort(sad.begin(), sad.end());
+ const float total_sad_sum =
+ std::accumulate(sad.begin(), sad.begin() + kSamples, 0.0f);
+ return total_sad_sum / kSamples;
+}
+
+class NoiseHistogram {
+ public:
+ static constexpr int kBins = 256;
+
+ NoiseHistogram() { std::fill(bins, bins + kBins, 0); }
+
+ void Increment(const float x) { bins[Index(x)] += 1; }
+ int Get(const float x) const { return bins[Index(x)]; }
+ int Bin(const size_t bin) const { return bins[bin]; }
+
+ int Mode() const {
+ size_t max_idx = 0;
+ for (size_t i = 0; i < kBins; i++) {
+ if (bins[i] > bins[max_idx]) max_idx = i;
+ }
+ return max_idx;
+ }
+
+ double Quantile(double q01) const {
+ const int64_t total = std::accumulate(bins, bins + kBins, int64_t{1});
+ const int64_t target = static_cast<int64_t>(q01 * total);
+ // Until sum >= target:
+ int64_t sum = 0;
+ size_t i = 0;
+ for (; i < kBins; ++i) {
+ sum += bins[i];
+ // Exact match: assume middle of bin i
+ if (sum == target) {
+ return i + 0.5;
+ }
+ if (sum > target) break;
+ }
+
+ // Next non-empty bin (in case histogram is sparsely filled)
+ size_t next = i + 1;
+ while (next < kBins && bins[next] == 0) {
+ ++next;
+ }
+
+ // Linear interpolation according to how far into next we went
+ const double excess = target - sum;
+ const double weight_next = bins[Index(next)] / excess;
+ return ClampX(next * weight_next + i * (1.0 - weight_next));
+ }
+
+ // Inter-quartile range
+ double IQR() const { return Quantile(0.75) - Quantile(0.25); }
+
+ private:
+ template <typename T>
+ T ClampX(const T x) const {
+ return std::min(std::max(T(0), x), T(kBins - 1));
+ }
+ size_t Index(const float x) const { return ClampX(static_cast<int>(x)); }
+
+ uint32_t bins[kBins];
+};
+
+std::vector<float> GetSADScoresForPatches(const Image3F& opsin,
+ const size_t block_s,
+ const size_t num_bin,
+ NoiseHistogram* sad_histogram) {
+ std::vector<float> sad_scores(
+ (opsin.ysize() / block_s) * (opsin.xsize() / block_s), 0.0f);
+
+ int block_index = 0;
+
+ for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
+ for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) {
+ float sad_sc = GetScoreSumsOfAbsoluteDifferences(opsin, x, y, block_s);
+ sad_scores[block_index++] = sad_sc;
+ sad_histogram->Increment(sad_sc * num_bin);
+ }
+ }
+ return sad_scores;
+}
+
+float GetSADThreshold(const NoiseHistogram& histogram, const int num_bin) {
+ // Here we assume that the most patches with similar SAD value is a "flat"
+ // patches. However, some images might contain regular texture part and
+ // generate second strong peak at the histogram
+ // TODO(user) handle bimodal and heavy-tailed case
+ const int mode = histogram.Mode();
+ return static_cast<float>(mode) / NoiseHistogram::kBins;
+}
+
+// loss = sum asym * (F(x) - nl)^2 + kReg * num_points * sum (w[i] - w[i+1])^2
+// where asym = 1 if F(x) < nl, kAsym if F(x) > nl.
+struct LossFunction {
+ explicit LossFunction(std::vector<NoiseLevel> nl0) : nl(std::move(nl0)) {}
+
+ double Compute(const OptimizeArray& w, OptimizeArray* df,
+ bool skip_regularization = false) const {
+ constexpr double kReg = 0.005;
+ constexpr double kAsym = 1.1;
+ double loss_function = 0;
+ for (size_t i = 0; i < w.size(); i++) {
+ (*df)[i] = 0;
+ }
+ for (auto ind : nl) {
+ std::pair<int, float> pos = IndexAndFrac(ind.intensity);
+ JXL_DASSERT(pos.first >= 0 && static_cast<size_t>(pos.first) <
+ NoiseParams::kNumNoisePoints - 1);
+ double low = w[pos.first];
+ double hi = w[pos.first + 1];
+ double val = low * (1.0f - pos.second) + hi * pos.second;
+ double dist = val - ind.noise_level;
+ if (dist > 0) {
+ loss_function += kAsym * dist * dist;
+ (*df)[pos.first] -= kAsym * (1.0f - pos.second) * dist;
+ (*df)[pos.first + 1] -= kAsym * pos.second * dist;
+ } else {
+ loss_function += dist * dist;
+ (*df)[pos.first] -= (1.0f - pos.second) * dist;
+ (*df)[pos.first + 1] -= pos.second * dist;
+ }
+ }
+ if (skip_regularization) return loss_function;
+ for (size_t i = 0; i + 1 < w.size(); i++) {
+ double diff = w[i] - w[i + 1];
+ loss_function += kReg * nl.size() * diff * diff;
+ (*df)[i] -= kReg * diff * nl.size();
+ (*df)[i + 1] += kReg * diff * nl.size();
+ }
+ return loss_function;
+ }
+
+ std::vector<NoiseLevel> nl;
+};
+
+void OptimizeNoiseParameters(const std::vector<NoiseLevel>& noise_level,
+ NoiseParams* noise_params) {
+ constexpr double kMaxError = 1e-3;
+ static const double kPrecision = 1e-8;
+ static const int kMaxIter = 40;
+
+ float avg = 0;
+ for (const NoiseLevel& nl : noise_level) {
+ avg += nl.noise_level;
+ }
+ avg /= noise_level.size();
+
+ LossFunction loss_function(noise_level);
+ OptimizeArray parameter_vector;
+ for (size_t i = 0; i < parameter_vector.size(); i++) {
+ parameter_vector[i] = avg;
+ }
+
+ parameter_vector = optimize::OptimizeWithScaledConjugateGradientMethod(
+ loss_function, parameter_vector, kPrecision, kMaxIter);
+
+ OptimizeArray df = parameter_vector;
+ float loss = loss_function.Compute(parameter_vector, &df,
+ /*skip_regularization=*/true) /
+ noise_level.size();
+
+ // Approximation went too badly: escape with no noise at all.
+ if (loss > kMaxError) {
+ noise_params->Clear();
+ return;
+ }
+
+ for (size_t i = 0; i < parameter_vector.size(); i++) {
+ noise_params->lut[i] = std::max(parameter_vector[i], 0.0);
+ }
+}
+
+std::vector<NoiseLevel> GetNoiseLevel(
+ const Image3F& opsin, const std::vector<float>& texture_strength,
+ const float threshold, const size_t block_s) {
+ std::vector<NoiseLevel> noise_level_per_intensity;
+
+ const int filt_size = 1;
+ static const float kLaplFilter[filt_size * 2 + 1][filt_size * 2 + 1] = {
+ {-0.25f, -1.0f, -0.25f},
+ {-1.0f, 5.0f, -1.0f},
+ {-0.25f, -1.0f, -0.25f},
+ };
+
+ // The noise model is built based on channel 0.5 * (X+Y) as we notice that it
+ // is similar to the model 0.5 * (Y-X)
+ size_t patch_index = 0;
+
+ for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
+ for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) {
+ if (texture_strength[patch_index] <= threshold) {
+ // Calculate mean value
+ float mean_int = 0;
+ for (size_t y_bl = 0; y_bl < block_s; ++y_bl) {
+ for (size_t x_bl = 0; x_bl < block_s; ++x_bl) {
+ mean_int += 0.5f * (opsin.PlaneRow(1, y + y_bl)[x + x_bl] +
+ opsin.PlaneRow(0, y + y_bl)[x + x_bl]);
+ }
+ }
+ mean_int /= block_s * block_s;
+
+ // Calculate Noise level
+ float noise_level = 0;
+ size_t count = 0;
+ for (size_t y_bl = 0; y_bl < block_s; ++y_bl) {
+ for (size_t x_bl = 0; x_bl < block_s; ++x_bl) {
+ float filtered_value = 0;
+ for (int y_f = -1 * filt_size; y_f <= filt_size; ++y_f) {
+ if ((static_cast<ssize_t>(y_bl) + y_f) >= 0 &&
+ (y_bl + y_f) < block_s) {
+ for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) {
+ if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 &&
+ (x_bl + x_f) < block_s) {
+ filtered_value +=
+ 0.5f *
+ (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl + x_f] +
+ opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl + x_f]) *
+ kLaplFilter[y_f + filt_size][x_f + filt_size];
+ } else {
+ filtered_value +=
+ 0.5f *
+ (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl - x_f] +
+ opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl - x_f]) *
+ kLaplFilter[y_f + filt_size][x_f + filt_size];
+ }
+ }
+ } else {
+ for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) {
+ if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 &&
+ (x_bl + x_f) < block_s) {
+ filtered_value +=
+ 0.5f *
+ (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl + x_f] +
+ opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl + x_f]) *
+ kLaplFilter[y_f + filt_size][x_f + filt_size];
+ } else {
+ filtered_value +=
+ 0.5f *
+ (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl - x_f] +
+ opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl - x_f]) *
+ kLaplFilter[y_f + filt_size][x_f + filt_size];
+ }
+ }
+ }
+ }
+ noise_level += std::abs(filtered_value);
+ ++count;
+ }
+ }
+ noise_level /= count;
+ NoiseLevel nl;
+ nl.intensity = mean_int;
+ nl.noise_level = noise_level;
+ noise_level_per_intensity.push_back(nl);
+ }
+ ++patch_index;
+ }
+ }
+ return noise_level_per_intensity;
+}
+
+void EncodeFloatParam(float val, float precision, BitWriter* writer) {
+ JXL_ASSERT(val >= 0);
+ const int absval_quant = static_cast<int>(val * precision + 0.5f);
+ JXL_ASSERT(absval_quant < (1 << 10));
+ writer->Write(10, absval_quant);
+}
+
+} // namespace
+
+Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params,
+ float quality_coef) {
+ // The size of a patch in decoder might be different from encoder's patch
+ // size.
+ // For encoder: the patch size should be big enough to estimate
+ // noise level, but, at the same time, it should be not too big
+ // to be able to estimate intensity value of the patch
+ const size_t block_s = 8;
+ const size_t kNumBin = 256;
+ NoiseHistogram sad_histogram;
+ std::vector<float> sad_scores =
+ GetSADScoresForPatches(opsin, block_s, kNumBin, &sad_histogram);
+ float sad_threshold = GetSADThreshold(sad_histogram, kNumBin);
+ // If threshold is too large, the image has a strong pattern. This pattern
+ // fools our model and it will add too much noise. Therefore, we do not add
+ // noise for such images
+ if (sad_threshold > 0.15f || sad_threshold <= 0.0f) {
+ noise_params->Clear();
+ return false;
+ }
+ std::vector<NoiseLevel> nl =
+ GetNoiseLevel(opsin, sad_scores, sad_threshold, block_s);
+
+ OptimizeNoiseParameters(nl, noise_params);
+ for (float& i : noise_params->lut) {
+ i *= quality_coef * 1.4;
+ }
+ return noise_params->HasAny();
+}
+
+void EncodeNoise(const NoiseParams& noise_params, BitWriter* writer,
+ size_t layer, AuxOut* aux_out) {
+ JXL_ASSERT(noise_params.HasAny());
+
+ BitWriter::Allotment allotment(writer, NoiseParams::kNumNoisePoints * 16);
+ for (float i : noise_params.lut) {
+ EncodeFloatParam(i, kNoisePrecision, writer);
+ }
+ allotment.ReclaimAndCharge(writer, layer, aux_out);
+}
+
+} // namespace jxl