diff options
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.cc')
-rw-r--r-- | third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.cc | 187 |
1 files changed, 187 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.cc b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.cc new file mode 100644 index 0000000000..a75e259865 --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/render_pipeline/stage_upsampling.cc @@ -0,0 +1,187 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include "lib/jxl/render_pipeline/stage_upsampling.h" + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_upsampling.cc" +#include <hwy/foreach_target.h> +#include <hwy/highway.h> + +#include "lib/jxl/sanitizers.h" +#include "lib/jxl/simd_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace jxl { +namespace HWY_NAMESPACE { + +// These templates are not found via ADL. +using hwy::HWY_NAMESPACE::Clamp; +using hwy::HWY_NAMESPACE::Max; +using hwy::HWY_NAMESPACE::Min; +using hwy::HWY_NAMESPACE::MulAdd; + +class UpsamplingStage : public RenderPipelineStage { + public: + explicit UpsamplingStage(const CustomTransformData& ups_factors, size_t c, + size_t shift) + : RenderPipelineStage(RenderPipelineStage::Settings::Symmetric( + /*shift=*/shift, /*border=*/2)), + c_(c) { + const float* weights = shift == 1 ? ups_factors.upsampling2_weights + : shift == 2 ? ups_factors.upsampling4_weights + : ups_factors.upsampling8_weights; + size_t N = 1 << (shift - 1); + for (size_t i = 0; i < 5 * N; i++) { + for (size_t j = 0; j < 5 * N; j++) { + size_t y = std::min(i, j); + size_t x = std::max(i, j); + kernel_[j / 5][i / 5][j % 5][i % 5] = + weights[5 * N * y - y * (y - 1) / 2 + x - y]; + } + } + } + + void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows, + size_t xextra, size_t xsize, size_t xpos, size_t ypos, + size_t thread_id) const final { + PROFILER_ZONE("Upsampling"); + static HWY_FULL(float) df; + size_t shift = settings_.shift_x; + size_t N = 1 << shift; + const size_t xsize_v = RoundUpTo(xsize, Lanes(df)); + for (ssize_t iy = -2; iy <= 2; iy++) { + msan::UnpoisonMemory(GetInputRow(input_rows, c_, iy) + xsize + 2, + sizeof(float) * (xsize_v - xsize)); + } + JXL_ASSERT(xextra == 0); + ssize_t x0 = 0; + ssize_t x1 = xsize; + if (N == 2) { + ProcessRowImpl<2>(input_rows, output_rows, x0, x1); + } + if (N == 4) { + ProcessRowImpl<4>(input_rows, output_rows, x0, x1); + } + if (N == 8) { + ProcessRowImpl<8>(input_rows, output_rows, x0, x1); + } + for (size_t oy = 0; oy < N; oy++) { + float* dst_row = GetOutputRow(output_rows, c_, oy); + msan::PoisonMemory(dst_row + xsize * N, + sizeof(float) * (xsize_v - xsize) * N); + } + } + + RenderPipelineChannelMode GetChannelMode(size_t c) const final { + return c == c_ ? RenderPipelineChannelMode::kInOut + : RenderPipelineChannelMode::kIgnored; + } + + const char* GetName() const override { return "Upsample"; } + + private: + template <size_t N> + JXL_INLINE float Kernel(size_t x, size_t y, ssize_t ix, ssize_t iy) const { + ix += 2; + iy += 2; + if (N == 2) { + return kernel_[0][0][y % 2 ? 4 - iy : iy][x % 2 ? 4 - ix : ix]; + } + if (N == 4) { + return kernel_[y % 4 < 2 ? y % 2 : 1 - y % 2] + [x % 4 < 2 ? x % 2 : 1 - x % 2][y % 4 < 2 ? iy : 4 - iy] + [x % 4 < 2 ? ix : 4 - ix]; + } + if (N == 8) { + return kernel_[y % 8 < 4 ? y % 4 : 3 - y % 4] + [x % 8 < 4 ? x % 4 : 3 - x % 4][y % 8 < 4 ? iy : 4 - iy] + [x % 8 < 4 ? ix : 4 - ix]; + } + JXL_ABORT("Invalid upsample"); + } + + template <ssize_t N> + void ProcessRowImpl(const RowInfo& input_rows, const RowInfo& output_rows, + ssize_t x0, ssize_t x1) const { + static HWY_FULL(float) df; + using V = hwy::HWY_NAMESPACE::Vec<HWY_FULL(float)>; + V ups0, ups1, ups2, ups3, ups4, ups5, ups6, ups7; + (void)ups2, (void)ups3, (void)ups4, (void)ups5, (void)ups6, (void)ups7; + V* ups[N]; + if (N >= 2) { + ups[0] = &ups0; + ups[1] = &ups1; + } + if (N >= 4) { + ups[2] = &ups2; + ups[3] = &ups3; + } + if (N == 8) { + ups[4] = &ups4; + ups[5] = &ups5; + ups[6] = &ups6; + ups[7] = &ups7; + } + for (size_t oy = 0; oy < N; oy++) { + float* dst_row = GetOutputRow(output_rows, c_, oy); + for (ssize_t x = x0; x < x1; x += Lanes(df)) { + for (size_t ox = 0; ox < N; ox++) { + auto result = Zero(df); + auto min = LoadU(df, GetInputRow(input_rows, c_, 0) + x); + auto max = min; + for (ssize_t iy = -2; iy <= 2; iy++) { + for (ssize_t ix = -2; ix <= 2; ix++) { + auto v = LoadU(df, GetInputRow(input_rows, c_, iy) + x + ix); + result = MulAdd(Set(df, Kernel<N>(ox, oy, ix, iy)), v, result); + min = Min(v, min); + max = Max(v, max); + } + } + // Avoid overshooting. + *ups[ox] = Clamp(result, min, max); + } + if (N == 2) { + StoreInterleaved(df, ups0, ups1, dst_row + x * N); + } + if (N == 4) { + StoreInterleaved(df, ups0, ups1, ups2, ups3, dst_row + x * N); + } + if (N == 8) { + StoreInterleaved(df, ups0, ups1, ups2, ups3, ups4, ups5, ups6, ups7, + dst_row + x * N); + } + } + } + } + + size_t c_; + float kernel_[4][4][5][5]; +}; + +std::unique_ptr<RenderPipelineStage> GetUpsamplingStage( + const CustomTransformData& ups_factors, size_t c, size_t shift) { + return jxl::make_unique<UpsamplingStage>(ups_factors, c, shift); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace jxl +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace jxl { + +HWY_EXPORT(GetUpsamplingStage); + +std::unique_ptr<RenderPipelineStage> GetUpsamplingStage( + const CustomTransformData& ups_factors, size_t c, size_t shift) { + JXL_ASSERT(shift != 0); + JXL_ASSERT(shift <= 3); + return HWY_DYNAMIC_DISPATCH(GetUpsamplingStage)(ups_factors, c, shift); +} + +} // namespace jxl +#endif |