Diffstat (limited to '')
-rw-r--r-- | third_party/aom/av1/encoder/x86/ml_avx2.c | 240 |
1 files changed, 240 insertions, 0 deletions
diff --git a/third_party/aom/av1/encoder/x86/ml_avx2.c b/third_party/aom/av1/encoder/x86/ml_avx2.c
new file mode 100644
index 0000000000..6432708416
--- /dev/null
+++ b/third_party/aom/av1/encoder/x86/ml_avx2.c
@@ -0,0 +1,240 @@
+/*
+ * Copyright (c) 2023, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#include <stdbool.h>
+#include <assert.h>
+#include <immintrin.h>
+
+#include "config/av1_rtcd.h"
+#include "av1/encoder/ml.h"
+#include "av1/encoder/x86/ml_sse3.h"
+
+#define CALC_OUTPUT_FOR_2ROWS                                                \
+  const int index = weight_idx + (2 * i * tot_num_inputs);                   \
+  const __m256 weight0 = _mm256_loadu_ps(&weights[index]);                   \
+  const __m256 weight1 = _mm256_loadu_ps(&weights[index + tot_num_inputs]);  \
+  const __m256 mul0 = _mm256_mul_ps(inputs256, weight0);                     \
+  const __m256 mul1 = _mm256_mul_ps(inputs256, weight1);                     \
+  hadd[i] = _mm256_hadd_ps(mul0, mul1);
+
+static INLINE void nn_propagate_8to1(
+    const float *const inputs, const float *const weights,
+    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
+    int num_outputs, float *const output_nodes, int is_clip_required) {
+  // Process one output row at a time.
+  for (int out = 0; out < num_outputs; out++) {
+    __m256 in_result = _mm256_setzero_ps();
+    float bias_val = bias[out];
+    for (int in = 0; in < num_inputs_to_process; in += 8) {
+      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
+      const int weight_idx = in + (out * tot_num_inputs);
+      const __m256 weight0 = _mm256_loadu_ps(&weights[weight_idx]);
+      const __m256 mul0 = _mm256_mul_ps(inputs256, weight0);
+      in_result = _mm256_add_ps(in_result, mul0);
+    }
+    const __m128 low_128 = _mm256_castps256_ps128(in_result);
+    const __m128 high_128 = _mm256_extractf128_ps(in_result, 1);
+    const __m128 sum_par_0 = _mm_add_ps(low_128, high_128);
+    const __m128 sum_par_1 = _mm_hadd_ps(sum_par_0, sum_par_0);
+    const __m128 sum_tot =
+        _mm_add_ps(_mm_shuffle_ps(sum_par_1, sum_par_1, 0x99), sum_par_1);
+
+    bias_val += _mm_cvtss_f32(sum_tot);
+    if (is_clip_required) bias_val = AOMMAX(bias_val, 0);
+    output_nodes[out] = bias_val;
+  }
+}
+
+static INLINE void nn_propagate_8to4(
+    const float *const inputs, const float *const weights,
+    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
+    int num_outputs, float *const output_nodes, int is_clip_required) {
+  __m256 hadd[2];
+  for (int out = 0; out < num_outputs; out += 4) {
+    __m128 bias_reg = _mm_loadu_ps(&bias[out]);
+    __m128 in_result = _mm_setzero_ps();
+    for (int in = 0; in < num_inputs_to_process; in += 8) {
+      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
+      const int weight_idx = in + (out * tot_num_inputs);
+      // Process two output rows at a time.
+      for (int i = 0; i < 2; i++) {
+        CALC_OUTPUT_FOR_2ROWS
+      }
+
+      const __m256 sum_par = _mm256_hadd_ps(hadd[0], hadd[1]);
+      const __m128 low_128 = _mm256_castps256_ps128(sum_par);
+      const __m128 high_128 = _mm256_extractf128_ps(sum_par, 1);
+      const __m128 result = _mm_add_ps(low_128, high_128);
+
+      in_result = _mm_add_ps(in_result, result);
+    }
+
+    in_result = _mm_add_ps(in_result, bias_reg);
+    if (is_clip_required) in_result = _mm_max_ps(in_result, _mm_setzero_ps());
+    _mm_storeu_ps(&output_nodes[out], in_result);
+  }
+}
+
+static INLINE void nn_propagate_8to8(
+    const float *const inputs, const float *const weights,
+    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
+    int num_outputs, float *const output_nodes, int is_clip_required) {
+  __m256 hadd[4];
+  for (int out = 0; out < num_outputs; out += 8) {
+    __m256 bias_reg = _mm256_loadu_ps(&bias[out]);
+    __m256 in_result = _mm256_setzero_ps();
+    for (int in = 0; in < num_inputs_to_process; in += 8) {
+      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
+      const int weight_idx = in + (out * tot_num_inputs);
+      // Process two output rows at a time.
+      for (int i = 0; i < 4; i++) {
+        CALC_OUTPUT_FOR_2ROWS
+      }
+      const __m256 hh0 = _mm256_hadd_ps(hadd[0], hadd[1]);
+      const __m256 hh1 = _mm256_hadd_ps(hadd[2], hadd[3]);
+
+      __m256 ht_0 = _mm256_permute2f128_ps(hh0, hh1, 0x20);
+      __m256 ht_1 = _mm256_permute2f128_ps(hh0, hh1, 0x31);
+
+      __m256 result = _mm256_add_ps(ht_0, ht_1);
+      in_result = _mm256_add_ps(in_result, result);
+    }
+    in_result = _mm256_add_ps(in_result, bias_reg);
+    if (is_clip_required)
+      in_result = _mm256_max_ps(in_result, _mm256_setzero_ps());
+    _mm256_storeu_ps(&output_nodes[out], in_result);
+  }
+}
+
+static INLINE void nn_propagate_input_multiple_of_8(
+    const float *const inputs, const float *const weights,
+    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
+    bool is_output_layer, int num_outputs, float *const output_nodes) {
+  // Clip the output (ReLU) only for hidden layers, and only when all of the
+  // layer's inputs are processed in this call.
+  const int is_clip_required =
+      !is_output_layer && num_inputs_to_process == tot_num_inputs;
+  if (num_outputs % 8 == 0) {
+    nn_propagate_8to8(inputs, weights, bias, num_inputs_to_process,
+                      tot_num_inputs, num_outputs, output_nodes,
+                      is_clip_required);
+  } else if (num_outputs % 4 == 0) {
+    nn_propagate_8to4(inputs, weights, bias, num_inputs_to_process,
+                      tot_num_inputs, num_outputs, output_nodes,
+                      is_clip_required);
+  } else {
+    nn_propagate_8to1(inputs, weights, bias, num_inputs_to_process,
+                      tot_num_inputs, num_outputs, output_nodes,
+                      is_clip_required);
+  }
+}
+
+void av1_nn_predict_avx2(const float *input_nodes,
+                         const NN_CONFIG *const nn_config, int reduce_prec,
+                         float *const output) {
+  float buf[2][NN_MAX_NODES_PER_LAYER];
+  int buf_index = 0;
+  int num_inputs = nn_config->num_inputs;
+  assert(num_inputs > 0 && num_inputs <= NN_MAX_NODES_PER_LAYER);
+
+  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
+    const float *layer_weights = nn_config->weights[layer];
+    const float *layer_bias = nn_config->bias[layer];
+    bool is_output_layer = layer == nn_config->num_hidden_layers;
+    float *const output_nodes = is_output_layer ? output : &buf[buf_index][0];
+    const int num_outputs = is_output_layer
+                                ? nn_config->num_outputs
+                                : nn_config->num_hidden_nodes[layer];
+    assert(num_outputs > 0 && num_outputs <= NN_MAX_NODES_PER_LAYER);
+
+    // Process inputs in multiples of 8 using AVX2 intrinsics.
+    if (num_inputs % 8 == 0) {
+      nn_propagate_input_multiple_of_8(input_nodes, layer_weights, layer_bias,
+                                       num_inputs, num_inputs, is_output_layer,
+                                       num_outputs, output_nodes);
+    } else {
+      // When the number of inputs is not a multiple of 8, use a hybrid of
+      // AVX2 and SSE3 as needed.
+      const int in_mul_8 = num_inputs / 8;
+      const int num_inputs_to_process = in_mul_8 * 8;
+      int bias_is_considered = 0;
+      if (in_mul_8) {
+        nn_propagate_input_multiple_of_8(
+            input_nodes, layer_weights, layer_bias, num_inputs_to_process,
+            num_inputs, is_output_layer, num_outputs, output_nodes);
+        bias_is_considered = 1;
+      }
+
+      const float *out_temp = bias_is_considered ? output_nodes : layer_bias;
+      const int input_remaining = num_inputs % 8;
+      if (input_remaining % 4 == 0 && num_outputs % 8 == 0) {
+        for (int out = 0; out < num_outputs; out += 8) {
+          __m128 out_h = _mm_loadu_ps(&out_temp[out + 4]);
+          __m128 out_l = _mm_loadu_ps(&out_temp[out]);
+          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
+            av1_nn_propagate_4to8_sse3(&input_nodes[in],
+                                       &layer_weights[out * num_inputs + in],
+                                       &out_h, &out_l, num_inputs);
+          }
+          if (!is_output_layer) {
+            const __m128 zero = _mm_setzero_ps();
+            out_h = _mm_max_ps(out_h, zero);
+            out_l = _mm_max_ps(out_l, zero);
+          }
+          _mm_storeu_ps(&output_nodes[out + 4], out_h);
+          _mm_storeu_ps(&output_nodes[out], out_l);
+        }
+      } else if (input_remaining % 4 == 0 && num_outputs % 4 == 0) {
+        for (int out = 0; out < num_outputs; out += 4) {
+          __m128 outputs = _mm_loadu_ps(&out_temp[out]);
+          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
+            av1_nn_propagate_4to4_sse3(&input_nodes[in],
+                                       &layer_weights[out * num_inputs + in],
+                                       &outputs, num_inputs);
+          }
+          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
+          _mm_storeu_ps(&output_nodes[out], outputs);
+        }
+      } else if (input_remaining % 4 == 0) {
+        for (int out = 0; out < num_outputs; out++) {
+          __m128 outputs = _mm_load1_ps(&out_temp[out]);
+          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
+            av1_nn_propagate_4to1_sse3(&input_nodes[in],
+                                       &layer_weights[out * num_inputs + in],
+                                       &outputs);
+          }
+          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
+          output_nodes[out] = _mm_cvtss_f32(outputs);
+        }
+      } else {
+        // Use SSE instructions for scalar operations to avoid the latency
+        // of swapping between SIMD and FPU modes.
+        for (int out = 0; out < num_outputs; out++) {
+          __m128 outputs = _mm_load1_ps(&out_temp[out]);
+          for (int in_node = in_mul_8 * 8; in_node < num_inputs; in_node++) {
+            __m128 input = _mm_load1_ps(&input_nodes[in_node]);
+            __m128 weight =
+                _mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
+            outputs = _mm_add_ps(outputs, _mm_mul_ps(input, weight));
+          }
+          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
+          output_nodes[out] = _mm_cvtss_f32(outputs);
+        }
+      }
+    }
+    // Before processing the next layer, treat the output of the current layer
+    // as the input to the next layer.
+    input_nodes = output_nodes;
+    num_inputs = num_outputs;
+    buf_index = 1 - buf_index;
+  }
+  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
+}
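
For reference, the per-layer computation that the AVX2 kernels in this patch vectorize is an ordinary fully connected layer: each output node is its bias plus the dot product of the inputs with one row of the row-major weight matrix (weights[out * num_inputs + in]), clipped at zero (ReLU) for hidden layers only. The scalar sketch below is illustrative and not part of the patch; the helper name nn_propagate_scalar_ref is hypothetical, and the weight layout and clipping rule are taken from the code above.

// Hypothetical scalar reference (not part of the patch): one fully connected
// layer, assuming the same row-major weight layout used by the intrinsics.
static void nn_propagate_scalar_ref(const float *inputs, const float *weights,
                                    const float *bias, int num_inputs,
                                    int num_outputs, int is_output_layer,
                                    float *outputs) {
  for (int out = 0; out < num_outputs; ++out) {
    // Start from this node's bias, then accumulate the dot product of the
    // inputs with row 'out' of the weight matrix.
    float val = bias[out];
    for (int in = 0; in < num_inputs; ++in) {
      val += weights[out * num_inputs + in] * inputs[in];
    }
    // Hidden layers are clipped at zero (ReLU); the final output layer is not.
    outputs[out] = is_output_layer ? val : (val > 0.0f ? val : 0.0f);
  }
}

The nn_propagate_8to1/8to4/8to8 variants compute the same sums eight inputs at a time for one, four, or eight output rows per iteration, while the SSE3 fallbacks handle input counts that are not a multiple of 8.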