// path: root/third_party/jpeg-xl/lib/jxl/enc_linalg_test.cc
// blob: 967b9a3afbb48bccd1cd53c2d94170c29eeb50cf
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/enc_linalg.h"

#include "lib/jxl/image_test_utils.h"
#include "lib/jxl/testing.h"

namespace jxl {
namespace {

// Builds an N x N image holding 1.0 wherever the row index equals the column
// index and 0.0 everywhere else.
ImageD Identity(const size_t N) {
  ImageD eye(N, N);
  for (size_t y = 0; y < N; ++y) {
    double* JXL_RESTRICT dst = eye.Row(y);
    for (size_t x = 0; x < N; ++x) {
      dst[x] = (x == y) ? 1.0 : 0.0;
    }
  }
  return eye;
}

// Expands the single-row image `d` into a square image of side d.xsize():
// entry k of `d` lands at Row(k)[k], every other entry is 0.0.
// `d` must have exactly one row (enforced with JXL_ASSERT).
ImageD Diagonal(const ImageD& d) {
  JXL_ASSERT(d.ysize() == 1);
  const size_t n = d.xsize();
  ImageD result(n, n);
  const double* JXL_RESTRICT diag = d.Row(0);
  for (size_t y = 0; y < n; ++y) {
    double* JXL_RESTRICT dst = result.Row(y);
    for (size_t x = 0; x < n; ++x) {
      dst[x] = (x == y) ? diag[y] : 0.0;
    }
  }
  return result;
}

// Multiplies two matrices stored as images. The result has xsize A.xsize()
// and ysize B.ysize(), with
//   out.Row(y)[x] = sum over k of A.Row(k)[x] * B.Row(y)[k].
// Reading Row(c)[r] as matrix element (r, c), this is the product A * B.
// Requires A.ysize() == B.xsize() (enforced with JXL_ASSERT).
ImageD MatMul(const ImageD& A, const ImageD& B) {
  JXL_ASSERT(A.ysize() == B.xsize());
  const size_t out_xsize = A.xsize();
  const size_t out_ysize = B.ysize();
  const size_t inner = B.xsize();
  ImageD product(out_xsize, out_ysize);
  for (size_t y = 0; y < out_ysize; ++y) {
    const double* const JXL_RESTRICT b_row = B.Row(y);
    double* const JXL_RESTRICT dst = product.Row(y);
    for (size_t x = 0; x < out_xsize; ++x) {
      // Accumulate in increasing k, starting from 0.0.
      double sum = 0.0;
      for (size_t k = 0; k < inner; ++k) {
        sum += A.Row(k)[x] * b_row[k];
      }
      dst[x] = sum;
    }
  }
  return product;
}

// Returns the transpose of `A`: an image of xsize A.ysize() and ysize
// A.xsize() in which Row(x)[y] equals A.Row(y)[x].
ImageD Transpose(const ImageD& A) {
  const size_t src_xsize = A.xsize();
  const size_t src_ysize = A.ysize();
  ImageD flipped(src_ysize, src_xsize);
  for (size_t col = 0; col < src_xsize; ++col) {
    // Column `col` of the source becomes row `col` of the result.
    double* const JXL_RESTRICT dst = flipped.Row(col);
    for (size_t row = 0; row < src_ysize; ++row) {
      dst[row] = A.Row(row)[col];
    }
  }
  return flipped;
}

// Produces an N x N symmetric matrix. The image is first populated by
// GenerateImage(rng, &image, vmin, vmax); afterwards every entry Row(j)[i]
// with j < i is overwritten with Row(i)[j], so Row(i)[j] == Row(j)[i] for
// all i, j.
ImageD RandomSymmetricMatrix(const size_t N, Rng& rng, const double vmin,
                             const double vmax) {
  ImageD sym(N, N);
  GenerateImage(rng, &sym, vmin, vmax);
  for (size_t j = 0; j < N; ++j) {
    double* const dst = sym.Row(j);
    // Sources all have i > j and targets all have j < i, so no value is read
    // after it has been overwritten.
    for (size_t i = j + 1; i < N; ++i) {
      dst[i] = sym.Row(i)[j];
    }
  }
  return sym;
}

// Asserts that `A` and `B` have identical dimensions and that every pair of
// corresponding entries differs by at most `eps`.
//
// NOTE(review): ASSERT_* are fatal GoogleTest assertions; per the GoogleTest
// documentation a fatal failure only returns from the function it occurs in,
// so a failure here leaves this helper but the calling test body keeps
// running (the test is still reported as failed).
void VerifyMatrixEqual(const ImageD& A, const ImageD& B, const double eps) {
  // Check the shapes first so the element loop never indexes past B's bounds.
  ASSERT_EQ(A.xsize(), B.xsize());
  ASSERT_EQ(A.ysize(), B.ysize());
  for (size_t y = 0; y < A.ysize(); ++y) {
    for (size_t x = 0; x < A.xsize(); ++x) {
      // Stops at the first entry that is out of tolerance.
      ASSERT_NEAR(A.Row(y)[x], B.Row(y)[x], eps);
    }
  }
}

void VerifyOrthogonal(const ImageD& A, const double eps) {
  VerifyMatrixEqual(Identity(A.xsize()), MatMul(Transpose(A), A), eps);
}

// Exercises ConvertToDiagonal on 2x2 symmetric inputs: the identity, a fixed
// matrix with off-diagonal entries, and 100 random symmetric matrices.
TEST(LinAlgTest, ConvertToDiagonal) {
  // Runs ConvertToDiagonal on `A`, then checks that U is orthogonal and that
  // MatMul(U, MatMul(Diagonal(d), Transpose(U))) reproduces `A`.
  const auto check_decomposition = [](const ImageD& A) {
    ImageD U(2, 2);
    ImageD d(2, 1);
    ConvertToDiagonal(A, &d, &U);
    VerifyOrthogonal(U, 1e-12);
    VerifyMatrixEqual(A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12);
  };

  // Case 1: the identity is expected to yield U == identity and d == (1, 1).
  {
    ImageD eye = Identity(2);
    ImageD U(2, 2);
    ImageD d(2, 1);
    ConvertToDiagonal(eye, &d, &U);
    VerifyMatrixEqual(eye, U, 1e-15);
    for (size_t k = 0; k < 2; ++k) {
      ASSERT_NEAR(d.Row(0)[k], 1.0, 1e-15);
    }
  }

  // Case 2: ones on the diagonal, 2.0 in both off-diagonal entries.
  {
    ImageD A = Identity(2);
    A.Row(1)[0] = 2.0;
    A.Row(0)[1] = 2.0;
    check_decomposition(A);
  }

  // Case 3: random symmetric matrices with entries generated from the range
  // arguments (-1.0, 1.0), using a fixed seed of 0.
  Rng rng(0);
  for (size_t i = 0; i < 100; ++i) {
    ImageD A = RandomSymmetricMatrix(2, rng, -1.0, 1.0);
    check_decomposition(A);
  }
}

}  // namespace
}  // namespace jxl