diff options
Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/base/matrix_ops.h')
-rw-r--r-- | third_party/jpeg-xl/lib/jxl/base/matrix_ops.h | 78 |
1 files changed, 40 insertions, 38 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h b/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h index 1a969bd4f0..cde6a64b1e 100644 --- a/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h +++ b/third_party/jpeg-xl/lib/jxl/base/matrix_ops.h @@ -8,6 +8,7 @@ // 3x3 matrix operations. +#include <array> #include <cmath> // abs #include <cstddef> @@ -15,66 +16,67 @@ namespace jxl { +typedef std::array<float, 3> Vector3; +typedef std::array<double, 3> Vector3d; +typedef std::array<Vector3, 3> Matrix3x3; +typedef std::array<Vector3d, 3> Matrix3x3d; + // Computes C = A * B, where A, B, C are 3x3 matrices. -template <typename T> -void Mul3x3Matrix(const T* a, const T* b, T* c) { - alignas(16) T temp[3]; // For transposed column +template <typename Matrix> +void Mul3x3Matrix(const Matrix& a, const Matrix& b, Matrix& c) { for (size_t x = 0; x < 3; x++) { - for (size_t z = 0; z < 3; z++) { - temp[z] = b[z * 3 + x]; - } + alignas(16) Vector3d temp{b[0][x], b[1][x], b[2][x]}; // transpose for (size_t y = 0; y < 3; y++) { - double e = 0; - for (size_t z = 0; z < 3; z++) { - e += a[y * 3 + z] * temp[z]; - } - c[y * 3 + x] = e; + c[y][x] = a[y][0] * temp[0] + a[y][1] * temp[1] + a[y][2] * temp[2]; } } } // Computes C = A * B, where A is 3x3 matrix and B is vector. -template <typename T> -void Mul3x3Vector(const T* a, const T* b, T* c) { +template <typename Matrix, typename Vector> +void Mul3x3Vector(const Matrix& a, const Vector& b, Vector& c) { for (size_t y = 0; y < 3; y++) { double e = 0; for (size_t x = 0; x < 3; x++) { - e += a[y * 3 + x] * b[x]; + e += a[y][x] * b[x]; } c[y] = e; } } // Inverts a 3x3 matrix in place. -template <typename T> -Status Inv3x3Matrix(T* matrix) { +template <typename Matrix> +Status Inv3x3Matrix(Matrix& matrix) { // Intermediate computation is done in double precision. - double temp[9]; - temp[0] = static_cast<double>(matrix[4]) * matrix[8] - - static_cast<double>(matrix[5]) * matrix[7]; - temp[1] = static_cast<double>(matrix[2]) * matrix[7] - - static_cast<double>(matrix[1]) * matrix[8]; - temp[2] = static_cast<double>(matrix[1]) * matrix[5] - - static_cast<double>(matrix[2]) * matrix[4]; - temp[3] = static_cast<double>(matrix[5]) * matrix[6] - - static_cast<double>(matrix[3]) * matrix[8]; - temp[4] = static_cast<double>(matrix[0]) * matrix[8] - - static_cast<double>(matrix[2]) * matrix[6]; - temp[5] = static_cast<double>(matrix[2]) * matrix[3] - - static_cast<double>(matrix[0]) * matrix[5]; - temp[6] = static_cast<double>(matrix[3]) * matrix[7] - - static_cast<double>(matrix[4]) * matrix[6]; - temp[7] = static_cast<double>(matrix[1]) * matrix[6] - - static_cast<double>(matrix[0]) * matrix[7]; - temp[8] = static_cast<double>(matrix[0]) * matrix[4] - - static_cast<double>(matrix[1]) * matrix[3]; - double det = matrix[0] * temp[0] + matrix[1] * temp[3] + matrix[2] * temp[6]; + Matrix3x3d temp; + temp[0][0] = static_cast<double>(matrix[1][1]) * matrix[2][2] - + static_cast<double>(matrix[1][2]) * matrix[2][1]; + temp[0][1] = static_cast<double>(matrix[0][2]) * matrix[2][1] - + static_cast<double>(matrix[0][1]) * matrix[2][2]; + temp[0][2] = static_cast<double>(matrix[0][1]) * matrix[1][2] - + static_cast<double>(matrix[0][2]) * matrix[1][1]; + temp[1][0] = static_cast<double>(matrix[1][2]) * matrix[2][0] - + static_cast<double>(matrix[1][0]) * matrix[2][2]; + temp[1][1] = static_cast<double>(matrix[0][0]) * matrix[2][2] - + static_cast<double>(matrix[0][2]) * matrix[2][0]; + temp[1][2] = static_cast<double>(matrix[0][2]) * matrix[1][0] - + static_cast<double>(matrix[0][0]) * matrix[1][2]; + temp[2][0] = static_cast<double>(matrix[1][0]) * matrix[2][1] - + static_cast<double>(matrix[1][1]) * matrix[2][0]; + temp[2][1] = static_cast<double>(matrix[0][1]) * matrix[2][0] - + static_cast<double>(matrix[0][0]) * matrix[2][1]; + temp[2][2] = static_cast<double>(matrix[0][0]) * matrix[1][1] - + static_cast<double>(matrix[0][1]) * matrix[1][0]; + double det = matrix[0][0] * temp[0][0] + matrix[0][1] * temp[1][0] + + matrix[0][2] * temp[2][0]; if (std::abs(det) < 1e-10) { return JXL_FAILURE("Matrix determinant is too close to 0"); } double idet = 1.0 / det; - for (size_t i = 0; i < 9; i++) { - matrix[i] = temp[i] * idet; + for (size_t j = 0; j < 3; j++) { + for (size_t i = 0; i < 3; i++) { + matrix[j][i] = temp[j][i] * idet; + } } return true; } |