1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "lib/jxl/enc_linalg.h"
#include "lib/jxl/image_test_utils.h"
#include "lib/jxl/testing.h"
namespace jxl {
namespace {
// Builds the N x N identity matrix: result.Row(y)[x] is 1.0 when x == y and
// 0.0 everywhere else.
ImageD Identity(const size_t N) {
  ImageD result(N, N);
  for (size_t y = 0; y < N; ++y) {
    double* JXL_RESTRICT dst = result.Row(y);
    for (size_t x = 0; x < N; ++x) {
      dst[x] = (x == y) ? 1.0 : 0.0;
    }
  }
  return result;
}
// Expands a single-row image `d` into a square d.xsize() x d.xsize() matrix
// whose diagonal entries Row(k)[k] are d.Row(0)[k] and whose other entries
// are 0.0.
ImageD Diagonal(const ImageD& d) {
  JXL_ASSERT(d.ysize() == 1);
  const size_t n = d.xsize();
  ImageD result(n, n);
  const double* JXL_RESTRICT values = d.Row(0);
  for (size_t k = 0; k < n; ++k) {
    double* JXL_RESTRICT dst = result.Row(k);
    for (size_t j = 0; j < n; ++j) {
      dst[j] = (j == k) ? values[k] : 0.0;
    }
  }
  return result;
}
// Matrix product: out.Row(y)[x] = sum over k of A.Row(k)[x] * B.Row(y)[k].
// Reading Row(y)[x] as the entry at row x, column y, this is A * B.
// Requires A.ysize() == B.xsize(); the result is A.xsize() x B.ysize().
ImageD MatMul(const ImageD& A, const ImageD& B) {
  JXL_ASSERT(A.ysize() == B.xsize());
  const size_t rows = A.xsize();
  const size_t inner = B.xsize();
  ImageD product(rows, B.ysize());
  for (size_t y = 0; y < B.ysize(); ++y) {
    const double* const JXL_RESTRICT b_row = B.Row(y);
    double* const JXL_RESTRICT dst = product.Row(y);
    std::fill(dst, dst + rows, 0.0);
    // k-outer / x-inner reads A.Row(k) contiguously; every dst[x] still
    // accumulates its terms in increasing k, as in the naive triple loop.
    for (size_t k = 0; k < inner; ++k) {
      const double* const JXL_RESTRICT a_row = A.Row(k);
      const double coeff = b_row[k];
      for (size_t x = 0; x < rows; ++x) {
        dst[x] += a_row[x] * coeff;
      }
    }
  }
  return product;
}
// Returns the transpose of A: an A.ysize() x A.xsize() image with
// result.Row(x)[y] == A.Row(y)[x].
ImageD Transpose(const ImageD& A) {
  ImageD result(A.ysize(), A.xsize());
  for (size_t y = 0; y < A.ysize(); ++y) {
    const double* const JXL_RESTRICT src = A.Row(y);
    for (size_t x = 0; x < A.xsize(); ++x) {
      result.Row(x)[y] = src[x];
    }
  }
  return result;
}
// Returns an N x N matrix filled by GenerateImage(rng, ..., vmin, vmax) and
// then made symmetric: for every r < c, Row(r)[c] is overwritten with
// Row(c)[r], so Row(r)[c] == Row(c)[r] for all r, c.
ImageD RandomSymmetricMatrix(const size_t N, Rng& rng, const double vmin,
                             const double vmax) {
  ImageD A(N, N);
  GenerateImage(rng, &A, vmin, vmax);
  // Writes only touch entries with r < c and reads only entries with c > r
  // on the other side, so the copy order does not matter.
  for (size_t r = 0; r < N; ++r) {
    double* const dst = A.Row(r);
    for (size_t c = r + 1; c < N; ++c) {
      dst[c] = A.Row(c)[r];
    }
  }
  return A;
}
// Asserts that A and B have the same dimensions and that every pair of
// corresponding entries differs by at most eps. On the first mismatch this
// records a fatal gtest failure naming the (x, y) position and returns; the
// caller keeps running unless it wraps the call in ASSERT_NO_FATAL_FAILURE.
void VerifyMatrixEqual(const ImageD& A, const ImageD& B, const double eps) {
  ASSERT_EQ(A.xsize(), B.xsize());
  ASSERT_EQ(A.ysize(), B.ysize());
  for (size_t y = 0; y < A.ysize(); ++y) {
    const double* const row_a = A.Row(y);
    const double* const row_b = B.Row(y);
    for (size_t x = 0; x < A.xsize(); ++x) {
      ASSERT_NEAR(row_a[x], row_b[x], eps) << "at x=" << x << ", y=" << y;
    }
  }
}
void VerifyOrthogonal(const ImageD& A, const double eps) {
VerifyMatrixEqual(Identity(A.xsize()), MatMul(Transpose(A), A), eps);
}
// Checks ConvertToDiagonal(A, &d, &U) on 2x2 symmetric inputs: the identity,
// a fixed matrix, and 100 random matrices from a seeded Rng. For each input,
// U must be orthogonal and A must equal U * Diagonal(d) * Transpose(U).
// Helper assertions are wrapped in ASSERT_NO_FATAL_FAILURE because a fatal
// failure inside a helper only returns from that helper, not from the test.
TEST(LinAlgTest, ConvertToDiagonal) {
  {
    ImageD I = Identity(2);
    ImageD U(2, 2), d(2, 1);
    ConvertToDiagonal(I, &d, &U);
    // For the identity input, U is expected to be the identity itself.
    ASSERT_NO_FATAL_FAILURE(VerifyMatrixEqual(I, U, 1e-15));
    for (size_t k = 0; k < 2; ++k) {
      ASSERT_NEAR(d.Row(0)[k], 1.0, 1e-15);
    }
  }
  {
    ImageD A = Identity(2);
    A.Row(0)[1] = A.Row(1)[0] = 2.0;
    ImageD U(2, 2), d(2, 1);
    ConvertToDiagonal(A, &d, &U);
    ASSERT_NO_FATAL_FAILURE(VerifyOrthogonal(U, 1e-12));
    ASSERT_NO_FATAL_FAILURE(VerifyMatrixEqual(
        A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12));
  }
  Rng rng(0);
  for (size_t i = 0; i < 100; ++i) {
    // Identifies which random matrix failed in the test output.
    SCOPED_TRACE(::testing::Message() << "random matrix #" << i);
    ImageD A = RandomSymmetricMatrix(2, rng, -1.0, 1.0);
    ImageD U(2, 2), d(2, 1);
    ConvertToDiagonal(A, &d, &U);
    ASSERT_NO_FATAL_FAILURE(VerifyOrthogonal(U, 1e-12));
    ASSERT_NO_FATAL_FAILURE(VerifyMatrixEqual(
        A, MatMul(U, MatMul(Diagonal(d), Transpose(U))), 1e-12));
  }
}
} // namespace
} // namespace jxl
|