diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/aom/test/av1_softmax_test.cc | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/aom/test/av1_softmax_test.cc')
-rw-r--r-- | third_party/aom/test/av1_softmax_test.cc | 122 |
1 files changed, 122 insertions, 0 deletions
diff --git a/third_party/aom/test/av1_softmax_test.cc b/third_party/aom/test/av1_softmax_test.cc new file mode 100644 index 0000000000..2b04af1342 --- /dev/null +++ b/third_party/aom/test/av1_softmax_test.cc @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2021, Alliance for Open Media. All rights reserved + * + * This source code is subject to the terms of the BSD 2 Clause License and + * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License + * was not distributed with this source code in the LICENSE file, you can + * obtain it at www.aomedia.org/license/software. If the Alliance for Open + * Media Patent License 1.0 was not distributed with this source code in the + * PATENTS file, you can obtain it at www.aomedia.org/license/patent. + */ + +#include <memory> +#include <new> +#include <tuple> + +#include "aom/aom_integer.h" +#include "aom_ports/aom_timer.h" +#include "av1/encoder/ml.h" +#include "config/aom_config.h" +#include "config/aom_dsp_rtcd.h" +#include "config/av1_rtcd.h" +#include "test/acm_random.h" +#include "test/register_state_check.h" +#include "test/util.h" +#include "third_party/googletest/src/googletest/include/gtest/gtest.h" + +namespace { +using FastSoftmaxFn = void (*)(const float *const input, float *output); +using FastSoftmaxTestParams = std::tuple<const FastSoftmaxFn, int>; + +// Error thresholds for functional equivalence +constexpr float kRelEpsilon = 5e-2f; +constexpr float kAbsEpsilon = 5e-3f; + +class FastSoftmaxTest : public ::testing::TestWithParam<FastSoftmaxTestParams> { + public: + FastSoftmaxTest() : target_fn_(GET_PARAM(0)), num_classes_(GET_PARAM(1)) {} + void SetUp() override { + ref_buf_.reset(new (std::nothrow) float[num_classes_]()); + ASSERT_NE(ref_buf_, nullptr); + dst_buf_.reset(new (std::nothrow) float[num_classes_]()); + ASSERT_NE(dst_buf_, nullptr); + input_.reset(new (std::nothrow) float[num_classes_]()); + ASSERT_NE(input_, nullptr); + } + void RunSoftmaxTest(); + void RunSoftmaxSpeedTest(const int run_times); + void FillInputBuf(); + + private: + const FastSoftmaxFn target_fn_; + const int num_classes_; + std::unique_ptr<float[]> ref_buf_, dst_buf_, input_; + libaom_test::ACMRandom rng_; +}; + +void FastSoftmaxTest::FillInputBuf() { + for (int idx = 0; idx < num_classes_; idx++) { + input_[idx] = ((float)rng_.Rand31() - (1 << 30)) / (1u << 30); + } +} + +void FastSoftmaxTest::RunSoftmaxTest() { + av1_nn_softmax(input_.get(), ref_buf_.get(), num_classes_); + target_fn_(input_.get(), dst_buf_.get()); + + for (int idx = 0; idx < num_classes_; idx++) { + if (ref_buf_[idx] < kAbsEpsilon) { + ASSERT_LE(dst_buf_[idx], kAbsEpsilon) + << "Reference output was near-zero, test output was not" << std::endl; + } else { + const float error = dst_buf_[idx] - ref_buf_[idx]; + const float relative_error = fabsf(error / ref_buf_[idx]); + ASSERT_LE(relative_error, kRelEpsilon) + << "Excessive relative error between reference and test output" + << std::endl; + ASSERT_LE(error, kAbsEpsilon) + << "Excessive absolute error between reference and test output" + << std::endl; + } + } +} + +void FastSoftmaxTest::RunSoftmaxSpeedTest(const int run_times) { + aom_usec_timer timer; + aom_usec_timer_start(&timer); + for (int idx = 0; idx < run_times; idx++) { + target_fn_(input_.get(), dst_buf_.get()); + } + aom_usec_timer_mark(&timer); + const int64_t time = aom_usec_timer_elapsed(&timer); + std::cout << "Test with " << num_classes_ << " classes took " << time + << " us." << std::endl; +} + +TEST_P(FastSoftmaxTest, RandomValues) { + FillInputBuf(); + RunSoftmaxTest(); +} + +TEST_P(FastSoftmaxTest, DISABLED_Speed) { + constexpr int kNumTimes = 1000000; + RunSoftmaxSpeedTest(kNumTimes); +} + +void AnchorSoftmax16Fn(const float *input, float *output) { + av1_nn_softmax(input, output, 16); +} + +const FastSoftmaxTestParams kArrayParams_c[] = { + FastSoftmaxTestParams(AnchorSoftmax16Fn, 16), + FastSoftmaxTestParams(av1_nn_fast_softmax_16_c, 16) +}; +INSTANTIATE_TEST_SUITE_P(C, FastSoftmaxTest, + ::testing::ValuesIn(kArrayParams_c)); + +#if HAVE_SSE3 && !CONFIG_EXCLUDE_SIMD_MISMATCH +INSTANTIATE_TEST_SUITE_P( + SSE3, FastSoftmaxTest, + ::testing::Values(FastSoftmaxTestParams(av1_nn_fast_softmax_16_sse3, 16))); +#endif +} // namespace |