summary refs log tree commit diff stats
path: root/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_gaborish.cc
blob: fc90acb476e7ce6d2ffdb07c5405e0ff1720af68 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/render_pipeline/stage_gaborish.h"

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_gaborish.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::Mul;
using hwy::HWY_NAMESPACE::MulAdd;

class GaborishStage : public RenderPipelineStage {
 public:
  explicit GaborishStage(const LoopFilter& lf)
      : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric(
            /*shift=*/0, /*border=*/1)) {
    weights_[0] = 1;
    weights_[1] = lf.gab_x_weight1;
    weights_[2] = lf.gab_x_weight2;
    weights_[3] = 1;
    weights_[4] = lf.gab_y_weight1;
    weights_[5] = lf.gab_y_weight2;
    weights_[6] = 1;
    weights_[7] = lf.gab_b_weight1;
    weights_[8] = lf.gab_b_weight2;
    // Normalize
    for (size_t c = 0; c < 3; c++) {
      const float div =
          weights_[3 * c] + 4 * (weights_[3 * c + 1] + weights_[3 * c + 2]);
      const float mul = 1.0f / div;
      weights_[3 * c] *= mul;
      weights_[3 * c + 1] *= mul;
      weights_[3 * c + 2] *= mul;
    }
  }

  void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
                  size_t xextra, size_t xsize, size_t xpos, size_t ypos,
                  size_t thread_id) const final {
    PROFILER_ZONE("Gaborish");

    const HWY_FULL(float) d;
    for (size_t c = 0; c < 3; c++) {
      float* JXL_RESTRICT row_t = GetInputRow(input_rows, c, -1);
      float* JXL_RESTRICT row_m = GetInputRow(input_rows, c, 0);
      float* JXL_RESTRICT row_b = GetInputRow(input_rows, c, 1);
      float* JXL_RESTRICT row_out = GetOutputRow(output_rows, c, 0);
      const auto w0 = Set(d, weights_[3 * c + 0]);
      const auto w1 = Set(d, weights_[3 * c + 1]);
      const auto w2 = Set(d, weights_[3 * c + 2]);
// Group data need only be aligned to a block; for >=512 bit vectors, this may
// result in unaligned loads.
#if HWY_CAP_GE512
#define LoadMaybeU LoadU
#else
#define LoadMaybeU Load
#endif
      // Since GetInputRow(input_rows, c, {-1, 0, 1}) is aligned, rounding
      // xextra up to Lanes(d) doesn't access anything problematic.
      for (ssize_t x = -RoundUpTo(xextra, Lanes(d));
           x < (ssize_t)(xsize + xextra); x += Lanes(d)) {
        const auto t = LoadMaybeU(d, row_t + x);
        const auto tl = LoadU(d, row_t + x - 1);
        const auto tr = LoadU(d, row_t + x + 1);
        const auto m = LoadMaybeU(d, row_m + x);
        const auto l = LoadU(d, row_m + x - 1);
        const auto r = LoadU(d, row_m + x + 1);
        const auto b = LoadMaybeU(d, row_b + x);
        const auto bl = LoadU(d, row_b + x - 1);
        const auto br = LoadU(d, row_b + x + 1);
        const auto sum0 = m;
        const auto sum1 = Add(Add(l, r), Add(t, b));
        const auto sum2 = Add(Add(tl, tr), Add(bl, br));
        auto pixels = MulAdd(sum2, w2, MulAdd(sum1, w1, Mul(sum0, w0)));
        Store(pixels, d, row_out + x);
      }
    }
  }
#undef LoadMaybeU

  RenderPipelineChannelMode GetChannelMode(size_t c) const final {
    return c < 3 ? RenderPipelineChannelMode::kInOut
                 : RenderPipelineChannelMode::kIgnored;
  }

  const char* GetName() const override { return "Gab"; }

 private:
  float weights_[9];
};

// Per-target factory: constructs a GaborishStage from `lf` and hands it back
// through the RenderPipelineStage base-class pointer.
std::unique_ptr<RenderPipelineStage> GetGaborishStage(const LoopFilter& lf) {
  std::unique_ptr<RenderPipelineStage> stage =
      jxl::make_unique<GaborishStage>(lf);
  return stage;
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jxl {

// Registers the per-target GetGaborishStage implementations for dispatch.
HWY_EXPORT(GetGaborishStage);

// Public entry point: asserts that `lf.gab == 1`, then forwards `lf` to the
// GetGaborishStage implementation chosen by HWY_DYNAMIC_DISPATCH.
std::unique_ptr<RenderPipelineStage> GetGaborishStage(const LoopFilter& lf) {
  JXL_ASSERT(lf.gab == 1);
  std::unique_ptr<RenderPipelineStage> stage =
      HWY_DYNAMIC_DISPATCH(GetGaborishStage)(lf);
  return stage;
}

}  // namespace jxl
#endif