summaryrefslogtreecommitdiffstats
path: root/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/base/matrix_ops.h')
-rw-r--r--third_party/jpeg-xl/lib/jxl/base/matrix_ops.h78
1 files changed, 40 insertions, 38 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h b/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h
index 1a969bd4f0..cde6a64b1e 100644
--- a/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h
+++ b/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h
@@ -8,6 +8,7 @@
// 3x3 matrix operations.
+#include <array>
#include <cmath> // abs
#include <cstddef>
@@ -15,66 +16,67 @@
namespace jxl {
+typedef std::array<float, 3> Vector3;
+typedef std::array<double, 3> Vector3d;
+typedef std::array<Vector3, 3> Matrix3x3;
+typedef std::array<Vector3d, 3> Matrix3x3d;
+
// Computes C = A * B, where A, B, C are 3x3 matrices.
-template <typename T>
-void Mul3x3Matrix(const T* a, const T* b, T* c) {
- alignas(16) T temp[3]; // For transposed column
+template <typename Matrix>
+void Mul3x3Matrix(const Matrix& a, const Matrix& b, Matrix& c) {
for (size_t x = 0; x < 3; x++) {
- for (size_t z = 0; z < 3; z++) {
- temp[z] = b[z * 3 + x];
- }
+ alignas(16) Vector3d temp{b[0][x], b[1][x], b[2][x]}; // transpose
for (size_t y = 0; y < 3; y++) {
- double e = 0;
- for (size_t z = 0; z < 3; z++) {
- e += a[y * 3 + z] * temp[z];
- }
- c[y * 3 + x] = e;
+ c[y][x] = a[y][0] * temp[0] + a[y][1] * temp[1] + a[y][2] * temp[2];
}
}
}
// Computes C = A * B, where A is 3x3 matrix and B is vector.
-template <typename T>
-void Mul3x3Vector(const T* a, const T* b, T* c) {
+template <typename Matrix, typename Vector>
+void Mul3x3Vector(const Matrix& a, const Vector& b, Vector& c) {
for (size_t y = 0; y < 3; y++) {
double e = 0;
for (size_t x = 0; x < 3; x++) {
- e += a[y * 3 + x] * b[x];
+ e += a[y][x] * b[x];
}
c[y] = e;
}
}
// Inverts a 3x3 matrix in place.
-template <typename T>
-Status Inv3x3Matrix(T* matrix) {
+template <typename Matrix>
+Status Inv3x3Matrix(Matrix& matrix) {
// Intermediate computation is done in double precision.
- double temp[9];
- temp[0] = static_cast<double>(matrix[4]) * matrix[8] -
- static_cast<double>(matrix[5]) * matrix[7];
- temp[1] = static_cast<double>(matrix[2]) * matrix[7] -
- static_cast<double>(matrix[1]) * matrix[8];
- temp[2] = static_cast<double>(matrix[1]) * matrix[5] -
- static_cast<double>(matrix[2]) * matrix[4];
- temp[3] = static_cast<double>(matrix[5]) * matrix[6] -
- static_cast<double>(matrix[3]) * matrix[8];
- temp[4] = static_cast<double>(matrix[0]) * matrix[8] -
- static_cast<double>(matrix[2]) * matrix[6];
- temp[5] = static_cast<double>(matrix[2]) * matrix[3] -
- static_cast<double>(matrix[0]) * matrix[5];
- temp[6] = static_cast<double>(matrix[3]) * matrix[7] -
- static_cast<double>(matrix[4]) * matrix[6];
- temp[7] = static_cast<double>(matrix[1]) * matrix[6] -
- static_cast<double>(matrix[0]) * matrix[7];
- temp[8] = static_cast<double>(matrix[0]) * matrix[4] -
- static_cast<double>(matrix[1]) * matrix[3];
- double det = matrix[0] * temp[0] + matrix[1] * temp[3] + matrix[2] * temp[6];
+ Matrix3x3d temp;
+ temp[0][0] = static_cast<double>(matrix[1][1]) * matrix[2][2] -
+ static_cast<double>(matrix[1][2]) * matrix[2][1];
+ temp[0][1] = static_cast<double>(matrix[0][2]) * matrix[2][1] -
+ static_cast<double>(matrix[0][1]) * matrix[2][2];
+ temp[0][2] = static_cast<double>(matrix[0][1]) * matrix[1][2] -
+ static_cast<double>(matrix[0][2]) * matrix[1][1];
+ temp[1][0] = static_cast<double>(matrix[1][2]) * matrix[2][0] -
+ static_cast<double>(matrix[1][0]) * matrix[2][2];
+ temp[1][1] = static_cast<double>(matrix[0][0]) * matrix[2][2] -
+ static_cast<double>(matrix[0][2]) * matrix[2][0];
+ temp[1][2] = static_cast<double>(matrix[0][2]) * matrix[1][0] -
+ static_cast<double>(matrix[0][0]) * matrix[1][2];
+ temp[2][0] = static_cast<double>(matrix[1][0]) * matrix[2][1] -
+ static_cast<double>(matrix[1][1]) * matrix[2][0];
+ temp[2][1] = static_cast<double>(matrix[0][1]) * matrix[2][0] -
+ static_cast<double>(matrix[0][0]) * matrix[2][1];
+ temp[2][2] = static_cast<double>(matrix[0][0]) * matrix[1][1] -
+ static_cast<double>(matrix[0][1]) * matrix[1][0];
+ double det = matrix[0][0] * temp[0][0] + matrix[0][1] * temp[1][0] +
+ matrix[0][2] * temp[2][0];
if (std::abs(det) < 1e-10) {
return JXL_FAILURE("Matrix determinant is too close to 0");
}
double idet = 1.0 / det;
- for (size_t i = 0; i < 9; i++) {
- matrix[i] = temp[i] * idet;
+ for (size_t j = 0; j < 3; j++) {
+ for (size_t i = 0; i < 3; i++) {
+ matrix[j][i] = temp[j][i] * idet;
+ }
}
return true;
}