path: root/third_party/aom/av1/encoder/arm/neon/pickrst_sve.c
/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <arm_sve.h>
#include <assert.h>
#include <string.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "av1/common/restoration.h"
#include "av1/encoder/pickrst.h"

static INLINE uint8_t find_average_sve(const uint8_t *src, int src_stride,
                                       int width, int height) {
  uint32x4_t avg_u32 = vdupq_n_u32(0);
  uint8x16_t ones = vdupq_n_u8(1);
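  // Dotting each vector of pixels with a vector of all ones sums groups of
  // four bytes into the four 32-bit lanes of the accumulator, giving a
  // widening horizontal add in a single instruction.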

  // Use a predicate to process the final width % 16 columns; svld1 zeroes the
  // inactive lanes, so they do not perturb the sum.
  svbool_t pattern = svwhilelt_b8_u32(0, width % 16);

  int h = height;
  do {
    int j = width;
    const uint8_t *src_ptr = src;
    while (j >= 16) {
      uint8x16_t s = vld1q_u8(src_ptr);
      avg_u32 = vdotq_u32(avg_u32, s, ones);

      j -= 16;
      src_ptr += 16;
    }
    uint8x16_t s_end = svget_neonq_u8(svld1_u8(pattern, src_ptr));
    avg_u32 = vdotq_u32(avg_u32, s_end, ones);

    src += src_stride;
  } while (--h != 0);
  return (uint8_t)(vaddlvq_u32(avg_u32) / (width * height));
}

static INLINE void compute_sub_avg(const uint8_t *buf, int buf_stride, int avg,
                                   int16_t *buf_avg, int buf_avg_stride,
                                   int width, int height,
                                   int downsample_factor) {
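  // Note: when computing downsampled statistics, the caller passes a
  // buf_stride that is already multiplied by downsample_factor, so each
  // iteration below consumes one downsampled row.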
  uint8x8_t avg_u8 = vdup_n_u8(avg);

  // Use a predicate to process the final width % 8 columns: both the loaded
  // data and the duplicated average are zero in the inactive lanes, so the
  // stored difference is also zero there.
  svbool_t pattern = svwhilelt_b8_u32(0, width % 8);

  uint8x8_t avg_end = vget_low_u8(svget_neonq_u8(svdup_n_u8_z(pattern, avg)));

  do {
    int j = width;
    const uint8_t *buf_ptr = buf;
    int16_t *buf_avg_ptr = buf_avg;
    while (j >= 8) {
      uint8x8_t d = vld1_u8(buf_ptr);
      vst1q_s16(buf_avg_ptr, vreinterpretq_s16_u16(vsubl_u8(d, avg_u8)));

      j -= 8;
      buf_ptr += 8;
      buf_avg_ptr += 8;
    }
    uint8x8_t d_end = vget_low_u8(svget_neonq_u8(svld1_u8(pattern, buf_ptr)));
    vst1q_s16(buf_avg_ptr, vreinterpretq_s16_u16(vsubl_u8(d_end, avg_end)));

    buf += buf_stride;
    buf_avg += buf_avg_stride;
    height -= downsample_factor;
  } while (height > 0);
}

static INLINE void copy_upper_triangle(int64_t *H, int64_t *H_tmp,
                                       const int wiener_win2, const int scale) {
  for (int i = 0; i < wiener_win2 - 2; i = i + 2) {
    // Transpose the first 2x2 square. It needs a special case because its
    // bottom-left element lies on the diagonal.
    int64x2_t row0 = vld1q_s64(H_tmp + i * wiener_win2 + i + 1);
    int64x2_t row1 = vld1q_s64(H_tmp + (i + 1) * wiener_win2 + i + 1);

    int64x2_t tr_row = aom_vtrn2q_s64(row0, row1);

    vst1_s64(H_tmp + (i + 1) * wiener_win2 + i, vget_low_s64(row0));
    vst1q_s64(H_tmp + (i + 2) * wiener_win2 + i, tr_row);

    // Transpose and store all the remaining 2x2 squares of the line.
    for (int j = i + 3; j < wiener_win2; j = j + 2) {
      row0 = vld1q_s64(H_tmp + i * wiener_win2 + j);
      row1 = vld1q_s64(H_tmp + (i + 1) * wiener_win2 + j);

      int64x2_t tr_row0 = aom_vtrn1q_s64(row0, row1);
      int64x2_t tr_row1 = aom_vtrn2q_s64(row0, row1);

      vst1q_s64(H_tmp + j * wiener_win2 + i, tr_row0);
      vst1q_s64(H_tmp + (j + 1) * wiener_win2 + i, tr_row1);
    }
  }
  for (int i = 0; i < wiener_win2 * wiener_win2; i++) {
    H[i] += H_tmp[i] * scale;
  }
}

// Transpose the matrix that has just been computed and accumulate it in M.
static INLINE void acc_transpose_M(int64_t *M, const int64_t *M_trn,
                                   const int wiener_win, int scale) {
  for (int i = 0; i < wiener_win; ++i) {
    for (int j = 0; j < wiener_win; ++j) {
      int tr_idx = j * wiener_win + i;
      *M++ += (int64_t)(M_trn[tr_idx] * scale);
    }
  }
}

// Swap each half of the dgd vectors so that we can accumulate the result of
// the dot-products directly in the destination matrix.
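// For example, given dgd0 = { a0, ..., a7 } and dgd1 = { b0, ..., b7 }, this
// returns { a0, a1, a2, a3, b0, b1, b2, b3 } and
// { a4, a5, a6, a7, b4, b5, b6, b7 }.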
static INLINE int16x8x2_t transpose_dgd(int16x8_t dgd0, int16x8_t dgd1) {
  int16x8_t dgd_trn0 = vreinterpretq_s16_s64(
      vzip1q_s64(vreinterpretq_s64_s16(dgd0), vreinterpretq_s64_s16(dgd1)));
  int16x8_t dgd_trn1 = vreinterpretq_s16_s64(
      vzip2q_s64(vreinterpretq_s64_s16(dgd0), vreinterpretq_s64_s16(dgd1)));

  return (int16x8x2_t){ { dgd_trn0, dgd_trn1 } };
}

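// Accumulate the cross-correlation of src with two adjacent dgd columns into
// a pair of adjacent M entries at once: thanks to transpose_dgd, lane 0 of
// the 64-bit dot-product accumulator ends up with the full dot product
// src . dgd[n] and lane 1 with src . dgd[n + 1].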
static INLINE void compute_M_one_row_win5(int16x8_t src, int16x8_t dgd[5],
                                          int64_t *M, int row) {
  const int wiener_win = 5;

  int64x2_t m01 = vld1q_s64(M + row * wiener_win + 0);
  int16x8x2_t dgd01 = transpose_dgd(dgd[0], dgd[1]);

  int64x2_t cross_corr01 = aom_svdot_lane_s16(m01, dgd01.val[0], src, 0);
  cross_corr01 = aom_svdot_lane_s16(cross_corr01, dgd01.val[1], src, 1);
  vst1q_s64(M + row * wiener_win + 0, cross_corr01);

  int64x2_t m23 = vld1q_s64(M + row * wiener_win + 2);
  int16x8x2_t dgd23 = transpose_dgd(dgd[2], dgd[3]);

  int64x2_t cross_corr23 = aom_svdot_lane_s16(m23, dgd23.val[0], src, 0);
  cross_corr23 = aom_svdot_lane_s16(cross_corr23, dgd23.val[1], src, 1);
  vst1q_s64(M + row * wiener_win + 2, cross_corr23);

  int64x2_t m4 = aom_sdotq_s16(vdupq_n_s64(0), src, dgd[4]);
  M[row * wiener_win + 4] += vaddvq_s64(m4);
}

static INLINE void compute_M_one_row_win7(int16x8_t src, int16x8_t dgd[7],
                                          int64_t *M, int row) {
  const int wiener_win = 7;

  int64x2_t m01 = vld1q_s64(M + row * wiener_win + 0);
  int16x8x2_t dgd01 = transpose_dgd(dgd[0], dgd[1]);

  int64x2_t cross_corr01 = aom_svdot_lane_s16(m01, dgd01.val[0], src, 0);
  cross_corr01 = aom_svdot_lane_s16(cross_corr01, dgd01.val[1], src, 1);
  vst1q_s64(M + row * wiener_win + 0, cross_corr01);

  int64x2_t m23 = vld1q_s64(M + row * wiener_win + 2);
  int16x8x2_t dgd23 = transpose_dgd(dgd[2], dgd[3]);

  int64x2_t cross_corr23 = aom_svdot_lane_s16(m23, dgd23.val[0], src, 0);
  cross_corr23 = aom_svdot_lane_s16(cross_corr23, dgd23.val[1], src, 1);
  vst1q_s64(M + row * wiener_win + 2, cross_corr23);

  int64x2_t m45 = vld1q_s64(M + row * wiener_win + 4);
  int16x8x2_t dgd45 = transpose_dgd(dgd[4], dgd[5]);

  int64x2_t cross_corr45 = aom_svdot_lane_s16(m45, dgd45.val[0], src, 0);
  cross_corr45 = aom_svdot_lane_s16(cross_corr45, dgd45.val[1], src, 1);
  vst1q_s64(M + row * wiener_win + 4, cross_corr45);

  int64x2_t m6 = aom_sdotq_s16(vdupq_n_s64(0), src, dgd[6]);
  M[row * wiener_win + 6] += vaddvq_s64(m6);
}

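// Compute the auto-covariance of every pair of rows within window column col,
// i.e. the upper triangle of the wiener_win x wiener_win block of H that sits
// on H's diagonal for that column.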
static INLINE void compute_H_one_col(int16x8_t *dgd, int col, int64_t *H,
                                     const int wiener_win,
                                     const int wiener_win2) {
  for (int row0 = 0; row0 < wiener_win; row0++) {
    for (int row1 = row0; row1 < wiener_win; row1++) {
      int auto_cov_idx =
          (col * wiener_win + row0) * wiener_win2 + (col * wiener_win) + row1;

      int64x2_t auto_cov = aom_sdotq_s16(vdupq_n_s64(0), dgd[row0], dgd[row1]);
      H[auto_cov_idx] += vaddvq_s64(auto_cov);
    }
  }
}

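// Compute the 5x5 block of H at block position (row0, row1): the
// cross-covariance of every pair of elements between window columns row0 and
// row1. Since row0 < row1, this block lies above H's diagonal.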
static INLINE void compute_H_two_rows_win5(int16x8_t *dgd0, int16x8_t *dgd1,
                                           int row0, int row1, int64_t *H) {
  for (int col0 = 0; col0 < 5; col0++) {
    int auto_cov_idx = (row0 * 5 + col0) * 25 + (row1 * 5);

    int64x2_t h01 = vld1q_s64(H + auto_cov_idx);
    int16x8x2_t dgd01 = transpose_dgd(dgd1[0], dgd1[1]);

    int64x2_t auto_cov01 = aom_svdot_lane_s16(h01, dgd01.val[0], dgd0[col0], 0);
    auto_cov01 = aom_svdot_lane_s16(auto_cov01, dgd01.val[1], dgd0[col0], 1);
    vst1q_s64(H + auto_cov_idx, auto_cov01);

    int64x2_t h23 = vld1q_s64(H + auto_cov_idx + 2);
    int16x8x2_t dgd23 = transpose_dgd(dgd1[2], dgd1[3]);

    int64x2_t auto_cov23 = aom_svdot_lane_s16(h23, dgd23.val[0], dgd0[col0], 0);
    auto_cov23 = aom_svdot_lane_s16(auto_cov23, dgd23.val[1], dgd0[col0], 1);
    vst1q_s64(H + auto_cov_idx + 2, auto_cov23);

    int64x2_t auto_cov4 = aom_sdotq_s16(vdupq_n_s64(0), dgd0[col0], dgd1[4]);
    H[auto_cov_idx + 4] += vaddvq_s64(auto_cov4);
  }
}

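// As compute_H_two_rows_win5, but for the 7x7 window.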
static INLINE void compute_H_two_rows_win7(int16x8_t *dgd0, int16x8_t *dgd1,
                                           int row0, int row1, int64_t *H) {
  for (int col0 = 0; col0 < 7; col0++) {
    int auto_cov_idx = (row0 * 7 + col0) * 49 + (row1 * 7);

    int64x2_t h01 = vld1q_s64(H + auto_cov_idx);
    int16x8x2_t dgd01 = transpose_dgd(dgd1[0], dgd1[1]);

    int64x2_t auto_cov01 = aom_svdot_lane_s16(h01, dgd01.val[0], dgd0[col0], 0);
    auto_cov01 = aom_svdot_lane_s16(auto_cov01, dgd01.val[1], dgd0[col0], 1);
    vst1q_s64(H + auto_cov_idx, auto_cov01);

    int64x2_t h23 = vld1q_s64(H + auto_cov_idx + 2);
    int16x8x2_t dgd23 = transpose_dgd(dgd1[2], dgd1[3]);

    int64x2_t auto_cov23 = aom_svdot_lane_s16(h23, dgd23.val[0], dgd0[col0], 0);
    auto_cov23 = aom_svdot_lane_s16(auto_cov23, dgd23.val[1], dgd0[col0], 1);
    vst1q_s64(H + auto_cov_idx + 2, auto_cov23);

    int64x2_t h45 = vld1q_s64(H + auto_cov_idx + 4);
    int16x8x2_t dgd45 = transpose_dgd(dgd1[4], dgd1[5]);

    int64x2_t auto_cov45 = aom_svdot_lane_s16(h45, dgd45.val[0], dgd0[col0], 0);
    auto_cov45 = aom_svdot_lane_s16(auto_cov45, dgd45.val[1], dgd0[col0], 1);
    vst1q_s64(H + auto_cov_idx + 4, auto_cov45);

    int64x2_t auto_cov6 = aom_sdotq_s16(vdupq_n_s64(0), dgd0[col0], dgd1[6]);
    H[auto_cov_idx + 6] += vaddvq_s64(auto_cov6);
  }
}

// This function computes two matrices: the cross-correlation between the src
// buffer and dgd buffer (M), and the auto-covariance of the dgd buffer (H).
//
// M is of size 7 * 7. It needs to be filled such that multiplying one element
// from src with each element of a row of the wiener window will fill one
// column of M. However, this is inconvenient in terms of memory accesses, as
// it means we do contiguous loads from dgd but strided stores to M.
// As a result, we use an intermediate matrix M_trn which is instead filled
// such that one row of the wiener window gives one row of M_trn. Once fully
// computed, M_trn is then transposed to return M.
//
// H is of size 49 * 49. It is filled by multiplying every pair of elements of
// the wiener window together. Since it is a symmetric matrix, we only compute
// the upper triangle, and then copy it down to the lower one. Here we fill it
// by taking each distinct pair of columns and multiplying all the elements of
// the first one with all the elements of the second one, with a special case
// when multiplying a column by itself.
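//
// With window taps indexed as t = col * wiener_win + row (so the taps of one
// window column are contiguous), the statistics accumulated over the
// downsampled, average-subtracted buffers are:
//   M[t]            += sum_{i,j} src_avg[i][j] * dgd_avg[i + row][j + col]
//   H[t0 * 49 + t1] += sum_{i,j} dgd_avg[i + row0][j + col0] *
//                                dgd_avg[i + row1][j + col1]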
static INLINE void compute_stats_win7_sve(int16_t *dgd_avg, int dgd_avg_stride,
                                          int16_t *src_avg, int src_avg_stride,
                                          int width, int height, int64_t *M,
                                          int64_t *H, int downsample_factor) {
  const int wiener_win = 7;
  const int wiener_win2 = wiener_win * wiener_win;

  // Use a predicate to process the final width % 8 columns of the block for H.
  svbool_t pattern = svwhilelt_b16_u32(0, width % 8);

  // Use intermediate matrices for H and M to perform the computation; they
  // are accumulated into the original H and M at the end.
  int64_t M_trn[49];
  memset(M_trn, 0, sizeof(M_trn));

  int64_t H_tmp[49 * 49];
  memset(H_tmp, 0, sizeof(H_tmp));

  do {
    // Cross-correlation (M).
    for (int row = 0; row < wiener_win; row++) {
      int j = 0;
      while (j < width) {
        int16x8_t dgd[7];
        load_s16_8x7(dgd_avg + row * dgd_avg_stride + j, 1, &dgd[0], &dgd[1],
                     &dgd[2], &dgd[3], &dgd[4], &dgd[5], &dgd[6]);
        int16x8_t s = vld1q_s16(src_avg + j);

        // Compute all the elements of one row of M.
        compute_M_one_row_win7(s, dgd, M_trn, row);

        j += 8;
      }
    }

    // Auto-covariance (H).
    int j = 0;
    while (j <= width - 8) {
      for (int col0 = 0; col0 < wiener_win; col0++) {
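        // Load first column.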
        int16x8_t dgd0[7];
        load_s16_8x7(dgd_avg + j + col0, dgd_avg_stride, &dgd0[0], &dgd0[1],
                     &dgd0[2], &dgd0[3], &dgd0[4], &dgd0[5], &dgd0[6]);

        // Perform computation of the first column with itself (28 elements).
        // For the first column this will fill the upper triangle of the 7x7
        // matrix at the top left of the H matrix. For the next columns this
        // will fill the upper triangle of the other 7x7 matrices around H's
        // diagonal.
        compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2);

        // All computation next to the matrix diagonal has already been done.
        for (int col1 = col0 + 1; col1 < wiener_win; col1++) {
          // Load second column.
          int16x8_t dgd1[7];
          load_s16_8x7(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1],
                       &dgd1[2], &dgd1[3], &dgd1[4], &dgd1[5], &dgd1[6]);

          // Compute all elements from the combination of both columns (49
          // elements).
          compute_H_two_rows_win7(dgd0, dgd1, col0, col1, H_tmp);
        }
      }
      j += 8;
    }

    if (j < width) {
      // Process remaining columns using a predicate to discard excess elements.
      for (int col0 = 0; col0 < wiener_win; col0++) {
        // Load first column.
        int16x8_t dgd0[7];
        dgd0[0] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 0 * dgd_avg_stride + j + col0));
        dgd0[1] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 1 * dgd_avg_stride + j + col0));
        dgd0[2] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 2 * dgd_avg_stride + j + col0));
        dgd0[3] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 3 * dgd_avg_stride + j + col0));
        dgd0[4] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 4 * dgd_avg_stride + j + col0));
        dgd0[5] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 5 * dgd_avg_stride + j + col0));
        dgd0[6] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 6 * dgd_avg_stride + j + col0));

        // Perform computation of the first column with itself (28 elements).
        // For the first column this will fill the upper triangle of the 7x7
        // matrix at the top left of the H matrix. For the next columns this
        // will fill the upper triangle of the other 7x7 matrices around H's
        // diagonal.
        compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2);

        // All computation next to the matrix diagonal has already been done.
        for (int col1 = col0 + 1; col1 < wiener_win; col1++) {
          // Load second column.
          int16x8_t dgd1[7];
          load_s16_8x7(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1],
                       &dgd1[2], &dgd1[3], &dgd1[4], &dgd1[5], &dgd1[6]);

          // Compute all elements from the combination of both columns (49
          // elements).
          compute_H_two_rows_win7(dgd0, dgd1, col0, col1, H_tmp);
        }
      }
    }
    dgd_avg += downsample_factor * dgd_avg_stride;
    src_avg += src_avg_stride;
  } while (--height != 0);

  // Transpose M_trn.
  acc_transpose_M(M, M_trn, 7, downsample_factor);

  // Copy upper triangle of H in the lower one.
  copy_upper_triangle(H, H_tmp, wiener_win2, downsample_factor);
}

// This function computes two matrices: the cross-correlation between the src
// buffer and dgd buffer (M), and the auto-covariance of the dgd buffer (H).
//
// M is of size 5 * 5. It needs to be filled such that multiplying one element
// from src with each element of a row of the wiener window will fill one
// column of M. However, this is inconvenient in terms of memory accesses, as
// it means we do contiguous loads from dgd but strided stores to M.
// As a result, we use an intermediate matrix M_trn which is instead filled
// such that one row of the wiener window gives one row of M_trn. Once fully
// computed, M_trn is then transposed to return M.
//
// H is of size 25 * 25. It is filled by multiplying every pair of elements of
// the wiener window together. Since it is a symmetric matrix, we only compute
// the upper triangle, and then copy it down to the lower one. Here we fill it
// by taking each distinct pair of columns and multiplying all the elements of
// the first one with all the elements of the second one, with a special case
// when multiplying a column by itself.
static INLINE void compute_stats_win5_sve(int16_t *dgd_avg, int dgd_avg_stride,
                                          int16_t *src_avg, int src_avg_stride,
                                          int width, int height, int64_t *M,
                                          int64_t *H, int downsample_factor) {
  const int wiener_win = 5;
  const int wiener_win2 = wiener_win * wiener_win;

  // Use a predicate to process the final width % 8 columns of the block for H.
  svbool_t pattern = svwhilelt_b16_u32(0, width % 8);

  // Use intermediate matrices for H and M to perform the computation; they
  // are accumulated into the original H and M at the end.
  int64_t M_trn[25];
  memset(M_trn, 0, sizeof(M_trn));

  int64_t H_tmp[25 * 25];
  memset(H_tmp, 0, sizeof(H_tmp));

  do {
    // Cross-correlation (M).
    for (int row = 0; row < wiener_win; row++) {
      int j = 0;
      while (j < width) {
        int16x8_t dgd[5];
        load_s16_8x5(dgd_avg + row * dgd_avg_stride + j, 1, &dgd[0], &dgd[1],
                     &dgd[2], &dgd[3], &dgd[4]);
        int16x8_t s = vld1q_s16(src_avg + j);

        // Compute all the elements of one row of M.
        compute_M_one_row_win5(s, dgd, M_trn, row);

        j += 8;
      }
    }

    // Auto-covariance (H).
    int j = 0;
    while (j <= width - 8) {
      for (int col0 = 0; col0 < wiener_win; col0++) {
        // Load first column.
        int16x8_t dgd0[5];
        load_s16_8x5(dgd_avg + j + col0, dgd_avg_stride, &dgd0[0], &dgd0[1],
                     &dgd0[2], &dgd0[3], &dgd0[4]);

        // Perform computation of the first column with itself (15 elements).
        // For the first column this will fill the upper triangle of the 5x5
        // matrix at the top left of the H matrix. For the next columns this
        // will fill the upper triangle of the other 5x5 matrices around H's
        // diagonal.
        compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2);

        // All computation next to the matrix diagonal has already been done.
        for (int col1 = col0 + 1; col1 < wiener_win; col1++) {
          // Load second column.
          int16x8_t dgd1[5];
          load_s16_8x5(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1],
                       &dgd1[2], &dgd1[3], &dgd1[4]);

          // Compute all elements from the combination of both columns (25
          // elements).
          compute_H_two_rows_win5(dgd0, dgd1, col0, col1, H_tmp);
        }
      }
      j += 8;
    }

    // Process remaining columns using a predicate to discard excess elements.
    if (j < width) {
      for (int col0 = 0; col0 < wiener_win; col0++) {
        int16x8_t dgd0[5];
        dgd0[0] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 0 * dgd_avg_stride + j + col0));
        dgd0[1] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 1 * dgd_avg_stride + j + col0));
        dgd0[2] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 2 * dgd_avg_stride + j + col0));
        dgd0[3] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 3 * dgd_avg_stride + j + col0));
        dgd0[4] = svget_neonq_s16(
            svld1_s16(pattern, dgd_avg + 4 * dgd_avg_stride + j + col0));

        // Perform computation of the first column with itself (15 elements).
        // For the first column this will fill the upper triangle of the 5x5
        // matrix at the top left of the H matrix. For the next columns this
        // will fill the upper triangle of the other 5x5 matrices around H's
        // diagonal.
        compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2);

        // All computation next to the matrix diagonal has already been done.
        for (int col1 = col0 + 1; col1 < wiener_win; col1++) {
          // Load second column.
          int16x8_t dgd1[5];
          load_s16_8x5(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1],
                       &dgd1[2], &dgd1[3], &dgd1[4]);

          // Compute all elements from the combination of both columns (25
          // elements).
          compute_H_two_rows_win5(dgd0, dgd1, col0, col1, H_tmp);
        }
      }
    }
    dgd_avg += downsample_factor * dgd_avg_stride;
    src_avg += src_avg_stride;
  } while (--height != 0);

  // Transpose M_trn.
  acc_transpose_M(M, M_trn, 5, downsample_factor);

  // Copy upper triangle of H in the lower one.
  copy_upper_triangle(H, H_tmp, wiener_win2, downsample_factor);
}

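// Computes the Wiener filter statistics for the block [h_start, h_end) x
// [v_start, v_end): M, the cross-correlation between the source and degraded
// windows, and H, the auto-covariance of the degraded windows. dgd_avg and
// src_avg are scratch buffers that the caller has zero-initialized and sized
// for the padded block.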
void av1_compute_stats_sve(int wiener_win, const uint8_t *dgd,
                           const uint8_t *src, int16_t *dgd_avg,
                           int16_t *src_avg, int h_start, int h_end,
                           int v_start, int v_end, int dgd_stride,
                           int src_stride, int64_t *M, int64_t *H,
                           int use_downsampled_wiener_stats) {
  assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA);

  const int wiener_win2 = wiener_win * wiener_win;
  const int wiener_halfwin = wiener_win >> 1;
  const int32_t width = h_end - h_start;
  const int32_t height = v_end - v_start;
  const uint8_t *dgd_start = &dgd[v_start * dgd_stride + h_start];
  memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
  memset(M, 0, sizeof(*M) * wiener_win * wiener_win);

  const uint8_t avg = find_average_sve(dgd_start, dgd_stride, width, height);
  const int downsample_factor =
      use_downsampled_wiener_stats ? WIENER_STATS_DOWNSAMPLE_FACTOR : 1;

  // dgd_avg and src_avg have been memset to zero before calling this
  // function, so round each stride up to a multiple of 8 strictly greater
  // than the row width; the full-vector stores in the tail then stay within
  // the buffer and no tail loop is needed when computing M.
  const int dgd_avg_stride = ((width + 2 * wiener_halfwin) & ~7) + 8;
  const int src_avg_stride = (width & ~7) + 8;

  // Compute (dgd - avg) and store it in dgd_avg.
  // The wiener window will slide along the dgd frame, centered on each pixel.
  // For the top left pixel and all the pixels on the side of the frame this
  // means half of the window will be outside of the frame. As such the actual
  // buffer that we need to subtract the avg from will be 2 * wiener_halfwin
  // wider and 2 * wiener_halfwin higher than the original dgd buffer.
  const int vert_offset = v_start - wiener_halfwin;
  const int horiz_offset = h_start - wiener_halfwin;
  const uint8_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride;
  compute_sub_avg(dgd_win, dgd_stride, avg, dgd_avg, dgd_avg_stride,
                  width + 2 * wiener_halfwin, height + 2 * wiener_halfwin, 1);

  // Compute (src - avg), downsample if necessary and store it in src_avg.
  const uint8_t *src_start = src + h_start + v_start * src_stride;
  compute_sub_avg(src_start, src_stride * downsample_factor, avg, src_avg,
                  src_avg_stride, width, height, downsample_factor);

  const int downsample_height = height / downsample_factor;

  // Since the height is not necessarily a multiple of the downsample factor,
  // the last line of src will be scaled according to how many rows remain.
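  // For example, with height = 10 and downsample_factor = 4 the main pass
  // covers two downsampled rows with a scale of 4, and the extra pass below
  // scales the final row of statistics by the remaining 2.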
  const int downsample_remainder = height % downsample_factor;

  if (wiener_win == WIENER_WIN) {
    compute_stats_win7_sve(dgd_avg, dgd_avg_stride, src_avg, src_avg_stride,
                           width, downsample_height, M, H, downsample_factor);
  } else {
    compute_stats_win5_sve(dgd_avg, dgd_avg_stride, src_avg, src_avg_stride,
                           width, downsample_height, M, H, downsample_factor);
  }

  if (downsample_remainder > 0) {
    const int remainder_offset = height - downsample_remainder;
    if (wiener_win == WIENER_WIN) {
      compute_stats_win7_sve(
          dgd_avg + remainder_offset * dgd_avg_stride, dgd_avg_stride,
          src_avg + downsample_height * src_avg_stride, src_avg_stride, width,
          1, M, H, downsample_remainder);
    } else {
      compute_stats_win5_sve(
          dgd_avg + remainder_offset * dgd_avg_stride, dgd_avg_stride,
          src_avg + downsample_height * src_avg_stride, src_avg_stride, width,
          1, M, H, downsample_remainder);
    }
  }
}