diff options
Diffstat (limited to '')
-rw-r--r-- | third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.cc | 524 |
1 files changed, 524 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.cc new file mode 100644 index 0000000000..d59c497843 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_epf.cc @@ -0,0 +1,524 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_epf.h" + +#include "lib/jxl/epf.h" +#include "lib/jxl/sanitizers.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_epf.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { +// TODO(veluca): In principle, vectors could be not capped, if we want to deal +// with having two different sigma values in a single vector. +using DF = HWY_CAPPED(float, 8); + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::AbsDiff; +using hwy::HWY_NAMESPACE::Add; +using hwy::HWY_NAMESPACE::Div; +using hwy::HWY_NAMESPACE::Mul; +using hwy::HWY_NAMESPACE::MulAdd; +using hwy::HWY_NAMESPACE::Vec; +using hwy::HWY_NAMESPACE::VFromD; +using hwy::HWY_NAMESPACE::ZeroIfNegative; + +JXL_INLINE Vec<DF> Weight(Vec<DF> sad, Vec<DF> inv_sigma, Vec<DF> thres) { + auto v = MulAdd(sad, inv_sigma, Set(DF(), 1.0f)); + return ZeroIfNegative(v); +} + +// 5x5 plus-shaped kernel with 5 SADs per pixel (3x3 plus-shaped). So this makes +// this filter a 7x7 filter. +class EPF0Stage : public RenderPipelineStage { + public: + EPF0Stage(const LoopFilter& lf, const ImageF& sigma) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/3)), + lf_(lf), + sigma_(&sigma) {} + + template <bool aligned> + JXL_INLINE void AddPixel(int row, float* JXL_RESTRICT rows[3][7], ssize_t x, + Vec<DF> sad, Vec<DF> inv_sigma, + Vec<DF>* JXL_RESTRICT X, Vec<DF>* JXL_RESTRICT Y, + Vec<DF>* JXL_RESTRICT B, + Vec<DF>* JXL_RESTRICT w) const { + auto cx = aligned ? Load(DF(), rows[0][3 + row] + x) + : LoadU(DF(), rows[0][3 + row] + x); + auto cy = aligned ? Load(DF(), rows[1][3 + row] + x) + : LoadU(DF(), rows[1][3 + row] + x); + auto cb = aligned ? Load(DF(), rows[2][3 + row] + x) + : LoadU(DF(), rows[2][3 + row] + x); + + auto weight = Weight(sad, inv_sigma, Set(DF(), lf_.epf_pass1_zeroflush)); + *w = Add(*w, weight); + *X = MulAdd(weight, cx, *X); + *Y = MulAdd(weight, cy, *Y); + *B = MulAdd(weight, cb, *B); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + DF df; + + using V = decltype(Zero(df)); + V t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, tA, tB; + V* sads[12] = {&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7, &t8, &t9, &tA, &tB}; + + xextra = RoundUpTo(xextra, Lanes(df)); + const float* JXL_RESTRICT row_sigma = + sigma_->Row(ypos / kBlockDim + kSigmaPadding); + + float sm = lf_.epf_pass0_sigma_scale * 1.65; + float bsm = sm * lf_.epf_border_sad_mul; + + HWY_ALIGN float sad_mul_center[kBlockDim] = {bsm, sm, sm, sm, + sm, sm, sm, bsm}; + HWY_ALIGN float sad_mul_border[kBlockDim] = {bsm, bsm, bsm, bsm, + bsm, bsm, bsm, bsm}; + float* JXL_RESTRICT rows[3][7]; + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 7; i++) { + rows[c][i] = GetInputRow(input_rows, c, i - 3); + } + } + + const float* sad_mul = + (ypos % kBlockDim == 0 || ypos % kBlockDim == kBlockDim - 1) + ? sad_mul_border + : sad_mul_center; + + for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra); + x += Lanes(df)) { + size_t bx = (x + xpos + kSigmaPadding * kBlockDim) / kBlockDim; + size_t ix = (x + xpos) % kBlockDim; + + if (row_sigma[bx] < kMinSigma) { + for (size_t c = 0; c < 3; c++) { + auto px = Load(df, rows[c][3 + 0] + x); + StoreU(px, df, GetOutputRow(output_rows, c, 0) + x); + } + continue; + } + + const auto sm = Load(df, sad_mul + ix); + const auto inv_sigma = Mul(Set(df, row_sigma[bx]), sm); + + for (size_t i = 0; i < 12; i++) *sads[i] = Zero(df); + constexpr std::array<int, 2> sads_off[12] = { + {{-2, 0}}, {{-1, -1}}, {{-1, 0}}, {{-1, 1}}, {{0, -2}}, {{0, -1}}, + {{0, 1}}, {{0, 2}}, {{1, -1}}, {{1, 0}}, {{1, 1}}, {{2, 0}}, + }; + + // compute sads + // TODO(veluca): consider unrolling and optimizing this. + for (size_t c = 0; c < 3; c++) { + auto scale = Set(df, lf_.epf_channel_scale[c]); + for (size_t i = 0; i < 12; i++) { + auto sad = Zero(df); + constexpr std::array<int, 2> plus_off[] = { + {{0, 0}}, {{-1, 0}}, {{0, -1}}, {{1, 0}}, {{0, 1}}}; + for (size_t j = 0; j < 5; j++) { + const auto r11 = + LoadU(df, rows[c][3 + plus_off[j][0]] + x + plus_off[j][1]); + const auto c11 = + LoadU(df, rows[c][3 + sads_off[i][0] + plus_off[j][0]] + x + + sads_off[i][1] + plus_off[j][1]); + sad = Add(sad, AbsDiff(r11, c11)); + } + *sads[i] = MulAdd(sad, scale, *sads[i]); + } + } + const auto x_cc = Load(df, rows[0][3 + 0] + x); + const auto y_cc = Load(df, rows[1][3 + 0] + x); + const auto b_cc = Load(df, rows[2][3 + 0] + x); + + auto w = Set(df, 1); + auto X = x_cc; + auto Y = y_cc; + auto B = b_cc; + + for (size_t i = 0; i < 12; i++) { + AddPixel</*aligned=*/false>(/*row=*/sads_off[i][0], rows, + x + sads_off[i][1], *sads[i], inv_sigma, &X, + &Y, &B, &w); + } +#if JXL_HIGH_PRECISION + auto inv_w = Div(Set(df, 1.0f), w); +#else + auto inv_w = ApproximateReciprocal(w); +#endif + StoreU(Mul(X, inv_w), df, GetOutputRow(output_rows, 0, 0) + x); + StoreU(Mul(Y, inv_w), df, GetOutputRow(output_rows, 1, 0) + x); + StoreU(Mul(B, inv_w), df, GetOutputRow(output_rows, 2, 0) + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "EPF0"; } + + private: + LoopFilter lf_; + const ImageF* sigma_; +}; + +// 3x3 plus-shaped kernel with 5 SADs per pixel (also 3x3 plus-shaped). So this +// makes this filter a 5x5 filter. +class EPF1Stage : public RenderPipelineStage { + public: + EPF1Stage(const LoopFilter& lf, const ImageF& sigma) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/2)), + lf_(lf), + sigma_(&sigma) {} + + template <bool aligned> + JXL_INLINE void AddPixel(int row, float* JXL_RESTRICT rows[3][5], ssize_t x, + Vec<DF> sad, Vec<DF> inv_sigma, + Vec<DF>* JXL_RESTRICT X, Vec<DF>* JXL_RESTRICT Y, + Vec<DF>* JXL_RESTRICT B, + Vec<DF>* JXL_RESTRICT w) const { + auto cx = aligned ? Load(DF(), rows[0][2 + row] + x) + : LoadU(DF(), rows[0][2 + row] + x); + auto cy = aligned ? Load(DF(), rows[1][2 + row] + x) + : LoadU(DF(), rows[1][2 + row] + x); + auto cb = aligned ? Load(DF(), rows[2][2 + row] + x) + : LoadU(DF(), rows[2][2 + row] + x); + + auto weight = Weight(sad, inv_sigma, Set(DF(), lf_.epf_pass1_zeroflush)); + *w = Add(*w, weight); + *X = MulAdd(weight, cx, *X); + *Y = MulAdd(weight, cy, *Y); + *B = MulAdd(weight, cb, *B); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + DF df; + xextra = RoundUpTo(xextra, Lanes(df)); + const float* JXL_RESTRICT row_sigma = + sigma_->Row(ypos / kBlockDim + kSigmaPadding); + + float sm = 1.65f; + float bsm = sm * lf_.epf_border_sad_mul; + + HWY_ALIGN float sad_mul_center[kBlockDim] = {bsm, sm, sm, sm, + sm, sm, sm, bsm}; + HWY_ALIGN float sad_mul_border[kBlockDim] = {bsm, bsm, bsm, bsm, + bsm, bsm, bsm, bsm}; + + float* JXL_RESTRICT rows[3][5]; + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 5; i++) { + rows[c][i] = GetInputRow(input_rows, c, i - 2); + } + } + + const float* sad_mul = + (ypos % kBlockDim == 0 || ypos % kBlockDim == kBlockDim - 1) + ? sad_mul_border + : sad_mul_center; + + for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra); + x += Lanes(df)) { + size_t bx = (x + xpos + kSigmaPadding * kBlockDim) / kBlockDim; + size_t ix = (x + xpos) % kBlockDim; + + if (row_sigma[bx] < kMinSigma) { + for (size_t c = 0; c < 3; c++) { + auto px = Load(df, rows[c][2 + 0] + x); + Store(px, df, GetOutputRow(output_rows, c, 0) + x); + } + continue; + } + + const auto sm = Load(df, sad_mul + ix); + const auto inv_sigma = Mul(Set(df, row_sigma[bx]), sm); + auto sad0 = Zero(df); + auto sad1 = Zero(df); + auto sad2 = Zero(df); + auto sad3 = Zero(df); + + // compute sads + for (size_t c = 0; c < 3; c++) { + // center px = 22, px above = 21 + auto t = Undefined(df); + + const auto p20 = Load(df, rows[c][2 + -2] + x); + const auto p21 = Load(df, rows[c][2 + -1] + x); + auto sad0c = AbsDiff(p20, p21); // SAD 2, 1 + + const auto p11 = LoadU(df, rows[c][2 + -1] + x - 1); + auto sad1c = AbsDiff(p11, p21); // SAD 1, 2 + + const auto p31 = LoadU(df, rows[c][2 + -1] + x + 1); + auto sad2c = AbsDiff(p31, p21); // SAD 3, 2 + + const auto p02 = LoadU(df, rows[c][2 + 0] + x - 2); + const auto p12 = LoadU(df, rows[c][2 + 0] + x - 1); + sad1c = Add(sad1c, AbsDiff(p02, p12)); // SAD 1, 2 + sad0c = Add(sad0c, AbsDiff(p11, p12)); // SAD 2, 1 + + const auto p22 = LoadU(df, rows[c][2 + 0] + x); + t = AbsDiff(p12, p22); + sad1c = Add(sad1c, t); // SAD 1, 2 + sad2c = Add(sad2c, t); // SAD 3, 2 + t = AbsDiff(p22, p21); + auto sad3c = t; // SAD 2, 3 + sad0c = Add(sad0c, t); // SAD 2, 1 + + const auto p32 = LoadU(df, rows[c][2 + 0] + x + 1); + sad0c = Add(sad0c, AbsDiff(p31, p32)); // SAD 2, 1 + t = AbsDiff(p22, p32); + sad1c = Add(sad1c, t); // SAD 1, 2 + sad2c = Add(sad2c, t); // SAD 3, 2 + + const auto p42 = LoadU(df, rows[c][2 + 0] + x + 2); + sad2c = Add(sad2c, AbsDiff(p42, p32)); // SAD 3, 2 + + const auto p13 = LoadU(df, rows[c][2 + 1] + x - 1); + sad3c = Add(sad3c, AbsDiff(p13, p12)); // SAD 2, 3 + + const auto p23 = Load(df, rows[c][2 + 1] + x); + t = AbsDiff(p22, p23); + sad0c = Add(sad0c, t); // SAD 2, 1 + sad3c = Add(sad3c, t); // SAD 2, 3 + sad1c = Add(sad1c, AbsDiff(p13, p23)); // SAD 1, 2 + + const auto p33 = LoadU(df, rows[c][2 + 1] + x + 1); + sad2c = Add(sad2c, AbsDiff(p33, p23)); // SAD 3, 2 + sad3c = Add(sad3c, AbsDiff(p33, p32)); // SAD 2, 3 + + const auto p24 = Load(df, rows[c][2 + 2] + x); + sad3c = Add(sad3c, AbsDiff(p24, p23)); // SAD 2, 3 + + auto scale = Set(df, lf_.epf_channel_scale[c]); + sad0 = MulAdd(sad0c, scale, sad0); + sad1 = MulAdd(sad1c, scale, sad1); + sad2 = MulAdd(sad2c, scale, sad2); + sad3 = MulAdd(sad3c, scale, sad3); + } + const auto x_cc = Load(df, rows[0][2 + 0] + x); + const auto y_cc = Load(df, rows[1][2 + 0] + x); + const auto b_cc = Load(df, rows[2][2 + 0] + x); + + auto w = Set(df, 1); + auto X = x_cc; + auto Y = y_cc; + auto B = b_cc; + + // Top row + AddPixel</*aligned=*/true>(/*row=*/-1, rows, x, sad0, inv_sigma, &X, &Y, + &B, &w); + // Center + AddPixel</*aligned=*/false>(/*row=*/0, rows, x - 1, sad1, inv_sigma, &X, + &Y, &B, &w); + AddPixel</*aligned=*/false>(/*row=*/0, rows, x + 1, sad2, inv_sigma, &X, + &Y, &B, &w); + // Bottom + AddPixel</*aligned=*/true>(/*row=*/1, rows, x, sad3, inv_sigma, &X, &Y, + &B, &w); +#if JXL_HIGH_PRECISION + auto inv_w = Div(Set(df, 1.0f), w); +#else + auto inv_w = ApproximateReciprocal(w); +#endif + Store(Mul(X, inv_w), df, GetOutputRow(output_rows, 0, 0) + x); + Store(Mul(Y, inv_w), df, GetOutputRow(output_rows, 1, 0) + x); + Store(Mul(B, inv_w), df, GetOutputRow(output_rows, 2, 0) + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "EPF1"; } + + private: + LoopFilter lf_; + const ImageF* sigma_; +}; + +// 3x3 plus-shaped kernel with 1 SAD per pixel. So this makes this filter a 3x3 +// filter. +class EPF2Stage : public RenderPipelineStage { + public: + EPF2Stage(const LoopFilter& lf, const ImageF& sigma) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/0, /*border=*/1)), + lf_(lf), + sigma_(&sigma) {} + + template <bool aligned> + JXL_INLINE void AddPixel(int row, float* JXL_RESTRICT rows[3][3], ssize_t x, + Vec<DF> rx, Vec<DF> ry, Vec<DF> rb, + Vec<DF> inv_sigma, Vec<DF>* JXL_RESTRICT X, + Vec<DF>* JXL_RESTRICT Y, Vec<DF>* JXL_RESTRICT B, + Vec<DF>* JXL_RESTRICT w) const { + auto cx = aligned ? Load(DF(), rows[0][1 + row] + x) + : LoadU(DF(), rows[0][1 + row] + x); + auto cy = aligned ? Load(DF(), rows[1][1 + row] + x) + : LoadU(DF(), rows[1][1 + row] + x); + auto cb = aligned ? Load(DF(), rows[2][1 + row] + x) + : LoadU(DF(), rows[2][1 + row] + x); + + auto sad = Mul(AbsDiff(cx, rx), Set(DF(), lf_.epf_channel_scale[0])); + sad = MulAdd(AbsDiff(cy, ry), Set(DF(), lf_.epf_channel_scale[1]), sad); + sad = MulAdd(AbsDiff(cb, rb), Set(DF(), lf_.epf_channel_scale[2]), sad); + + auto weight = Weight(sad, inv_sigma, Set(DF(), lf_.epf_pass2_zeroflush)); + + *w = Add(*w, weight); + *X = MulAdd(weight, cx, *X); + *Y = MulAdd(weight, cy, *Y); + *B = MulAdd(weight, cb, *B); + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + DF df; + xextra = RoundUpTo(xextra, Lanes(df)); + const float* JXL_RESTRICT row_sigma = + sigma_->Row(ypos / kBlockDim + kSigmaPadding); + + float sm = lf_.epf_pass2_sigma_scale * 1.65; + float bsm = sm * lf_.epf_border_sad_mul; + + HWY_ALIGN float sad_mul_center[kBlockDim] = {bsm, sm, sm, sm, + sm, sm, sm, bsm}; + HWY_ALIGN float sad_mul_border[kBlockDim] = {bsm, bsm, bsm, bsm, + bsm, bsm, bsm, bsm}; + + float* JXL_RESTRICT rows[3][3]; + for (size_t c = 0; c < 3; c++) { + for (int i = 0; i < 3; i++) { + rows[c][i] = GetInputRow(input_rows, c, i - 1); + } + } + + const float* sad_mul = + (ypos % kBlockDim == 0 || ypos % kBlockDim == kBlockDim - 1) + ? sad_mul_border + : sad_mul_center; + + for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra); + x += Lanes(df)) { + size_t bx = (x + xpos + kSigmaPadding * kBlockDim) / kBlockDim; + size_t ix = (x + xpos) % kBlockDim; + + if (row_sigma[bx] < kMinSigma) { + for (size_t c = 0; c < 3; c++) { + auto px = Load(df, rows[c][1 + 0] + x); + Store(px, df, GetOutputRow(output_rows, c, 0) + x); + } + continue; + } + + const auto sm = Load(df, sad_mul + ix); + const auto inv_sigma = Mul(Set(df, row_sigma[bx]), sm); + + const auto x_cc = Load(df, rows[0][1 + 0] + x); + const auto y_cc = Load(df, rows[1][1 + 0] + x); + const auto b_cc = Load(df, rows[2][1 + 0] + x); + + auto w = Set(df, 1); + auto X = x_cc; + auto Y = y_cc; + auto B = b_cc; + + // Top row + AddPixel</*aligned=*/true>(/*row=*/-1, rows, x, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); + // Center + AddPixel</*aligned=*/false>(/*row=*/0, rows, x - 1, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); + AddPixel</*aligned=*/false>(/*row=*/0, rows, x + 1, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); + // Bottom + AddPixel</*aligned=*/true>(/*row=*/1, rows, x, x_cc, y_cc, b_cc, + inv_sigma, &X, &Y, &B, &w); +#if JXL_HIGH_PRECISION + auto inv_w = Div(Set(df, 1.0f), w); +#else + auto inv_w = ApproximateReciprocal(w); +#endif + Store(Mul(X, inv_w), df, GetOutputRow(output_rows, 0, 0) + x); + Store(Mul(Y, inv_w), df, GetOutputRow(output_rows, 1, 0) + x); + Store(Mul(B, inv_w), df, GetOutputRow(output_rows, 2, 0) + x); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c < 3 ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "EPF2"; } + + private: + LoopFilter lf_; + const ImageF* sigma_; +}; + +std::unique_ptr<RenderPipelineStage> GetEPFStage0(const LoopFilter& lf, + const ImageF& sigma) { + return jxl::make_unique<EPF0Stage>(lf, sigma); +} + +std::unique_ptr<RenderPipelineStage> GetEPFStage1(const LoopFilter& lf, + const ImageF& sigma) { + return jxl::make_unique<EPF1Stage>(lf, sigma); +} + +std::unique_ptr<RenderPipelineStage> GetEPFStage2(const LoopFilter& lf, + const ImageF& sigma) { + return jxl::make_unique<EPF2Stage>(lf, sigma); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetEPFStage0); +HWY_EXPORT(GetEPFStage1); +HWY_EXPORT(GetEPFStage2); + +std::unique_ptr<RenderPipelineStage> GetEPFStage(const LoopFilter& lf, + const ImageF& sigma, + size_t epf_stage) { + JXL_ASSERT(lf.epf_iters != 0); + switch (epf_stage) { + case 0: + return HWY_DYNAMIC_DISPATCH(GetEPFStage0)(lf, sigma); + case 1: + return HWY_DYNAMIC_DISPATCH(GetEPFStage1)(lf, sigma); + case 2: + return HWY_DYNAMIC_DISPATCH(GetEPFStage2)(lf, sigma); + default: + JXL_ABORT("Invalid EPF stage"); + } +} + +} // namespace jxl +#endif |