summaryrefslogtreecommitdiffstats
path: root/third_party/jpeg-xl/lib/jpegli/dct-inl.h
blob: 66cc3b6b53b7603fd862c929df5f05a5d5edb7e1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#if defined(LIB_JPEGLI_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE)
#ifdef LIB_JPEGLI_DCT_INL_H_
#undef LIB_JPEGLI_DCT_INL_H_
#else
#define LIB_JPEGLI_DCT_INL_H_
#endif

#include "lib/jpegli/transpose-inl.h"
#include "lib/jxl/base/compiler_specific.h"

HWY_BEFORE_NAMESPACE();
namespace jpegli {
namespace HWY_NAMESPACE {
namespace {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Abs;
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::DemoteTo;
using hwy::HWY_NAMESPACE::Ge;
using hwy::HWY_NAMESPACE::IfThenElseZero;
using hwy::HWY_NAMESPACE::Mul;
using hwy::HWY_NAMESPACE::MulAdd;
using hwy::HWY_NAMESPACE::Rebind;
using hwy::HWY_NAMESPACE::Round;
using hwy::HWY_NAMESPACE::Sub;
using hwy::HWY_NAMESPACE::Vec;

// Highway descriptors for full-width vectors of float and int32_t lanes. They
// are used by QuantizeBlock and the StoreQuantizedValue overloads below; the
// DCT kernels instead use HWY_CAPPED(float, 8) so a vector never spans more
// than one 8-float row.
using D = HWY_FULL(float);
using DI = HWY_FULL(int32_t);

template <size_t N>
void AddReverse(const float* JXL_RESTRICT ain1, const float* JXL_RESTRICT ain2,
                float* JXL_RESTRICT aout) {
  HWY_CAPPED(float, 8) d8;
  for (size_t i = 0; i < N; i++) {
    auto in1 = Load(d8, ain1 + i * 8);
    auto in2 = Load(d8, ain2 + (N - i - 1) * 8);
    Store(Add(in1, in2), d8, aout + i * 8);
  }
}

template <size_t N>
void SubReverse(const float* JXL_RESTRICT ain1, const float* JXL_RESTRICT ain2,
                float* JXL_RESTRICT aout) {
  HWY_CAPPED(float, 8) d8;
  for (size_t i = 0; i < N; i++) {
    auto in1 = Load(d8, ain1 + i * 8);
    auto in2 = Load(d8, ain2 + (N - i - 1) * 8);
    Store(Sub(in1, in2), d8, aout + i * 8);
  }
}

// In-place combination step over N rows (8 floats apart) of `coeff`:
//   row 0      <- sqrt(2) * row 0 + row 1   (via MulAdd)
//   row i      <- row i + row i+1           for i in [1, N - 2]
//   row N - 1  is left unchanged.
// Rows are processed in ascending order, so every update reads a successor
// row that has not been modified yet.
template <size_t N>
void B(float* JXL_RESTRICT coeff) {
  HWY_CAPPED(float, 8) d8;
  constexpr float kSqrt2 = 1.41421356237f;
  const auto head = Load(d8, coeff);
  const auto next = Load(d8, coeff + 8);
  Store(MulAdd(head, Set(d8, kSqrt2), next), d8, coeff);
  for (size_t row = 1; row + 1 < N; ++row) {
    float* cur = coeff + 8 * row;
    Store(Add(Load(d8, cur), Load(d8, cur + 8)), d8, cur);
  }
}

// Interleaves two halves of `ain` into `aout` (rows are 8 floats apart):
// source rows [0, N/2) land on even output rows 0, 2, 4, ..., and source rows
// [N/2, N) land on odd output rows 1, 3, 5, .... This is pure data movement;
// ideally the compiler folds it into the surrounding stores.
template <size_t N>
void InverseEvenOdd(const float* JXL_RESTRICT ain, float* JXL_RESTRICT aout) {
  HWY_CAPPED(float, 8) d8;
  constexpr size_t kHalf = N / 2;
  for (size_t src = 0; src < N; ++src) {
    const size_t dst = (src < kHalf) ? 2 * src : 2 * (src - kHalf) + 1;
    Store(Load(d8, ain + 8 * src), d8, aout + 8 * dst);
  }
}

// Per-size multipliers applied by Multiply<N> to the odd half of a DCT stage.
// Entry i equals 1 / (2 * cos((i + 0.5) * pi / N)) for i in [0, N / 2).
// They can be regenerated with this Python snippet:
//   import math
//   for i in range(N // 2):
//       print(1.0 / (2 * math.cos((i + 0.5) * math.pi / N)), end=", ")
// Only N = 4 and N = 8 are specialized, which covers every Multiply<N>
// instantiation reachable from DCT1DImpl<8>.
template <size_t N>
struct WcMultipliers;

template <>
struct WcMultipliers<4> {
  static constexpr float kMultipliers[] = {0.541196100146197,
                                           1.3065629648763764};
};

template <>
struct WcMultipliers<8> {
  static constexpr float kMultipliers[] = {
      0.5097955791041592, 0.6013448869350453,  //
      0.8999762231364156, 2.5629154477415055,  //
  };
};

// Namespace-scope definitions of the static arrays: required for ODR-use
// before C++17, redundant (but still permitted) from C++17 on.
constexpr float WcMultipliers<4>::kMultipliers[];
constexpr float WcMultipliers<8>::kMultipliers[];

// Scales the upper half of `coeff` in place: row N/2 + k (rows are 8 floats
// apart) is multiplied by WcMultipliers<N>::kMultipliers[k] for k in
// [0, N / 2). The lower half is left untouched.
template <size_t N>
void Multiply(float* JXL_RESTRICT coeff) {
  HWY_CAPPED(float, 8) d8;
  float* upper = coeff + (N / 2) * 8;
  for (size_t k = 0; k < N / 2; ++k) {
    const auto factor = Set(d8, WcMultipliers<N>::kMultipliers[k]);
    Store(Mul(Load(d8, upper + 8 * k), factor), d8, upper + 8 * k);
  }
}

// Gathers a vertical strip of an 8-row block into the aligned buffer `coeff`.
// For each row r in [0, 8), Lanes(d8) floats starting at column `off` of
// `pixels` (row stride `pixels_stride`, unaligned load) are copied to
// coeff + 8 * r (aligned store).
void LoadFromBlock(const float* JXL_RESTRICT pixels, size_t pixels_stride,
                   size_t off, float* JXL_RESTRICT coeff) {
  HWY_CAPPED(float, 8) d8;
  for (size_t row = 0; row < 8; ++row) {
    const float* src_row = pixels + off + row * pixels_stride;
    const auto strip = LoadU(d8, src_row);
    Store(strip, d8, coeff + 8 * row);
  }
}

// Scatters the 8 rows held in the aligned buffer `coeff` (rows 8 floats apart)
// back into `output` (row stride 8) starting at column `off`, multiplying every
// value by 1/8 on the way. Lanes(d8) floats are written per row (unaligned
// store).
//
// `output` is JXL_RESTRICT like the pointer parameters of the other kernels in
// this file. Within this file the only caller (DCT1D) passes a stack-local
// buffer as `coeff`, so the two pointers can never alias.
void StoreToBlockAndScale(const float* JXL_RESTRICT coeff,
                          float* JXL_RESTRICT output, size_t off) {
  HWY_CAPPED(float, 8) d8;
  auto mul = Set(d8, 1.0f / 8);
  for (size_t i = 0; i < 8; i++) {
    StoreU(Mul(mul, Load(d8, coeff + i * 8)), d8, output + i * 8 + off);
  }
}

// In-place N-point 1D DCT over N rows of `mem`, where row k is the 8 floats at
// mem + 8 * k. Every operation is lane-wise, so each SIMD lane (up to
// Lanes(HWY_CAPPED(float, 8))) is an independent transform. The general case
// recurses on N / 2 and bottoms out at the <1> and <2> specializations; given
// the available WcMultipliers specializations, usable sizes are N = 1, 2, 4, 8.
// The output is not normalized (DCT1D applies the 1/8 scale afterwards).
template <size_t N>
struct DCT1DImpl;

// N = 1: the transform is the identity, so there is nothing to do.
template <>
struct DCT1DImpl<1> {
  JXL_INLINE void operator()(float* JXL_RESTRICT mem) {}
};

// N = 2: a single butterfly, row0 <- row0 + row1, row1 <- row0 - row1.
template <>
struct DCT1DImpl<2> {
  JXL_INLINE void operator()(float* JXL_RESTRICT mem) {
    HWY_CAPPED(float, 8) d8;
    auto in1 = Load(d8, mem);
    auto in2 = Load(d8, mem + 8);
    Store(Add(in1, in2), d8, mem);
    Store(Sub(in1, in2), d8, mem + 8);
  }
};

// General N: split into a sum half and a difference half, transform each half
// recursively, then interleave the two results back into `mem`.
template <size_t N>
struct DCT1DImpl {
  void operator()(float* JXL_RESTRICT mem) {
    // N rows of 8 floats: the sum half occupies rows [0, N/2), the difference
    // half rows [N/2, N) (offset N * 4 floats).
    HWY_ALIGN float tmp[N * 8];
    // Sum half: tmp[i] = mem[i] + mem[N - 1 - i] (mem + N * 4 is row N/2).
    AddReverse<N / 2>(mem, mem + N * 4, tmp);
    DCT1DImpl<N / 2>()(tmp);
    // Difference half: tmp[N/2 + i] = mem[i] - mem[N - 1 - i].
    SubReverse<N / 2>(mem, mem + N * 4, tmp + N * 4);
    // Scale the difference half by WcMultipliers<N> before transforming it.
    Multiply<N>(tmp);
    DCT1DImpl<N / 2>()(tmp + N * 4);
    // Fold adjacent rows of the transformed difference half (row 0 gets an
    // extra sqrt(2) factor).
    B<N / 2>(tmp + N * 4);
    // Sum-half results become the even output rows, difference-half results
    // the odd output rows.
    InverseEvenOdd<N>(tmp, mem);
  }
};

// Applies the 8-point 1D DCT down each of the 8 columns of an 8x8 block.
// `pixels` is read with row stride `pixels_stride`; results are written to
// `output` with row stride 8 and scaled by 1/8. Columns are processed
// Lanes(d8) at a time through a 64-float aligned stack buffer.
void DCT1D(const float* JXL_RESTRICT pixels, size_t pixels_stride,
           float* JXL_RESTRICT output) {
  HWY_CAPPED(float, 8) d8;
  const size_t columns_per_step = Lanes(d8);
  HWY_ALIGN float tmp[64];
  for (size_t col = 0; col < 8; col += columns_per_step) {
    // TODO(veluca): consider removing the temporary memory here (as is done in
    // IDCT), if it turns out that some compilers don't optimize away the loads
    // and this is performance-critical.
    LoadFromBlock(pixels, pixels_stride, col, tmp);
    DCT1DImpl<8>()(tmp);
    StoreToBlockAndScale(tmp, output, col);
  }
}

// Separable forward 2D DCT of one 8x8 block. Pass 1 runs the 1D DCT down the
// columns of `pixels` (row stride `pixels_stride`) and transposes the result
// into `coefficients`. Pass 2 repeats this on `coefficients` (row stride 8), so
// the original rows get transformed and the final transpose restores the
// original orientation. `scratch_space` receives 64 floats per pass and is
// clobbered.
JXL_INLINE JXL_MAYBE_UNUSED void TransformFromPixels(
    const float* JXL_RESTRICT pixels, size_t pixels_stride,
    float* JXL_RESTRICT coefficients, float* JXL_RESTRICT scratch_space) {
  const auto column_pass = [&](const float* src, size_t src_stride) {
    DCT1D(src, src_stride, scratch_space);
    Transpose8x8Block(scratch_space, coefficients);
  };
  column_pass(pixels, pixels_stride);
  column_pass(coefficients, 8);
}

// int16_t output variant: narrows the int32 lanes of `ival` with DemoteTo and
// stores them to `out` (aligned store of Lanes(DI) int16 values).
JXL_INLINE JXL_MAYBE_UNUSED void StoreQuantizedValue(const Vec<DI>& ival,
                                                     int16_t* out) {
  Rebind<int16_t, DI> d16;
  const auto narrowed = DemoteTo(d16, ival);
  Store(narrowed, d16, out);
}

// int32_t output variant: the lanes already have the target width, so they are
// stored to `out` unchanged (aligned store).
JXL_INLINE JXL_MAYBE_UNUSED void StoreQuantizedValue(const Vec<DI>& ival,
                                                     int32_t* out) {
  Store(ival, DI(), out);
}

// Quantizes DCTSIZE2 DCT coefficients into `block`. Each coefficient is
// multiplied by its entry in `qmc`. It is then rounded, or forced to zero when
// its magnitude is below the per-coefficient threshold
// zero_bias_offset[k] + zero_bias_mul[k] * aq_strength. All float inputs are
// read with aligned loads, Lanes(D) elements at a time.
template <typename T>
void QuantizeBlock(const float* dct, const float* qmc, float aq_strength,
                   const float* zero_bias_offset, const float* zero_bias_mul,
                   T* block) {
  D d;
  DI di;
  const size_t step = Lanes(d);
  const auto strength = Set(d, aq_strength);
  for (size_t k = 0; k < DCTSIZE2; k += step) {
    const auto scaled = Mul(Load(d, dct + k), Load(d, qmc + k));
    const auto threshold = Add(Load(d, zero_bias_offset + k),
                               Mul(Load(d, zero_bias_mul + k), strength));
    const auto keep = Ge(Abs(scaled), threshold);
    const auto rounded = IfThenElseZero(keep, Round(scaled));
    StoreQuantizedValue(ConvertTo(di, rounded), block + k);
  }
}

template <typename T>
void ComputeCoefficientBlock(const float* JXL_RESTRICT pixels, size_t stride,
                             const float* JXL_RESTRICT qmc,
                             int16_t last_dc_coeff, float aq_strength,
                             const float* zero_bias_offset,
                             const float* zero_bias_mul,
                             float* JXL_RESTRICT tmp, T* block) {
  float* JXL_RESTRICT dct = tmp;
  float* JXL_RESTRICT scratch_space = tmp + DCTSIZE2;
  TransformFromPixels(pixels, stride, dct, scratch_space);
  QuantizeBlock(dct, qmc, aq_strength, zero_bias_offset, zero_bias_mul, block);
  // Center DC values around zero.
  static constexpr float kDCBias = 128.0f;
  const float dc = (dct[0] - kDCBias) * qmc[0];
  float dc_threshold = zero_bias_offset[0] + aq_strength * zero_bias_mul[0];
  if (std::abs(dc - last_dc_coeff) < dc_threshold) {
    block[0] = last_dc_coeff;
  } else {
    block[0] = std::round(dc);
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace
}  // namespace HWY_NAMESPACE
}  // namespace jpegli
HWY_AFTER_NAMESPACE();
#endif  // LIB_JPEGLI_DCT_INL_H_