diff options
Diffstat (limited to 'src/ml/dlib/dlib/matrix/matrix_lu.h')
-rw-r--r-- | src/ml/dlib/dlib/matrix/matrix_lu.h | 361 |
1 files changed, 361 insertions, 0 deletions
diff --git a/src/ml/dlib/dlib/matrix/matrix_lu.h b/src/ml/dlib/dlib/matrix/matrix_lu.h new file mode 100644 index 000000000..3e49cd653 --- /dev/null +++ b/src/ml/dlib/dlib/matrix/matrix_lu.h @@ -0,0 +1,361 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +// This code was adapted from code from the JAMA part of NIST's TNT library. +// See: http://math.nist.gov/tnt/ +#ifndef DLIB_MATRIX_LU_DECOMPOSITION_H +#define DLIB_MATRIX_LU_DECOMPOSITION_H + +#include "matrix.h" +#include "matrix_utilities.h" +#include "matrix_subexp.h" +#include "matrix_trsm.h" +#include <algorithm> + +#ifdef DLIB_USE_LAPACK +#include "lapack/getrf.h" +#endif + + +namespace dlib +{ + + template < + typename matrix_exp_type + > + class lu_decomposition + { + public: + + const static long NR = matrix_exp_type::NR; + const static long NC = matrix_exp_type::NC; + typedef typename matrix_exp_type::type type; + typedef typename matrix_exp_type::mem_manager_type mem_manager_type; + typedef typename matrix_exp_type::layout_type layout_type; + + typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type; + typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type; + typedef matrix<long,NR,1,mem_manager_type,layout_type> pivot_column_vector_type; + + // You have supplied an invalid type of matrix_exp_type. You have + // to use this object with matrices that contain float or double type data. + COMPILE_TIME_ASSERT((is_same_type<float, type>::value || + is_same_type<double, type>::value )); + + template <typename EXP> + lu_decomposition ( + const matrix_exp<EXP> &A + ); + + bool is_square ( + ) const; + + bool is_singular ( + ) const; + + long nr( + ) const; + + long nc( + ) const; + + const matrix_type get_l ( + ) const; + + const matrix_type get_u ( + ) const; + + const pivot_column_vector_type& get_pivot ( + ) const; + + type det ( + ) const; + + template <typename EXP> + const matrix_type solve ( + const matrix_exp<EXP> &B + ) const; + + private: + + /* Array for internal storage of decomposition. */ + matrix<type,0,0,mem_manager_type,column_major_layout> LU; + long m, n, pivsign; + pivot_column_vector_type piv; + + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Public member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + template <typename EXP> + lu_decomposition<matrix_exp_type>:: + lu_decomposition ( + const matrix_exp<EXP>& A + ) : + LU(A), + m(A.nr()), + n(A.nc()) + { + using namespace std; + using std::abs; + + COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value)); + + // make sure requires clause is not broken + DLIB_ASSERT(A.size() > 0, + "\tlu_decomposition::lu_decomposition(A)" + << "\n\tInvalid inputs were given to this function" + << "\n\tA.size(): " << A.size() + << "\n\tthis: " << this + ); + +#ifdef DLIB_USE_LAPACK + matrix<lapack::integer,0,1,mem_manager_type,layout_type> piv_temp; + lapack::getrf(LU, piv_temp); + + pivsign = 1; + + // Turn the piv_temp vector into a more useful form. This way we will have the identity + // rowm(A,piv) == L*U. The permutation vector that comes out of LAPACK is somewhat + // different. + piv = trans(range(0,m-1)); + for (long i = 0; i < piv_temp.size(); ++i) + { + // -1 because FORTRAN is indexed starting with 1 instead of 0 + if (piv(piv_temp(i)-1) != piv(i)) + { + std::swap(piv(i), piv(piv_temp(i)-1)); + pivsign = -pivsign; + } + } + +#else + + // Use a "left-looking", dot-product, Crout/Doolittle algorithm. + + + piv = trans(range(0,m-1)); + pivsign = 1; + + column_vector_type LUcolj(m); + + // Outer loop. + for (long j = 0; j < n; j++) + { + + // Make a copy of the j-th column to localize references. + LUcolj = colm(LU,j); + + // Apply previous transformations. + for (long i = 0; i < m; i++) + { + // Most of the time is spent in the following dot product. + const long kmax = std::min(i,j); + type s; + if (kmax > 0) + s = rowm(LU,i, kmax)*colm(LUcolj,0,kmax); + else + s = 0; + + LU(i,j) = LUcolj(i) -= s; + } + + // Find pivot and exchange if necessary. + long p = j; + for (long i = j+1; i < m; i++) + { + if (abs(LUcolj(i)) > abs(LUcolj(p))) + { + p = i; + } + } + if (p != j) + { + long k=0; + for (k = 0; k < n; k++) + { + type t = LU(p,k); + LU(p,k) = LU(j,k); + LU(j,k) = t; + } + k = piv(p); + piv(p) = piv(j); + piv(j) = k; + pivsign = -pivsign; + } + + // Compute multipliers. + if ((j < m) && (LU(j,j) != 0.0)) + { + for (long i = j+1; i < m; i++) + { + LU(i,j) /= LU(j,j); + } + } + } + +#endif + } + +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + bool lu_decomposition<matrix_exp_type>:: + is_square ( + ) const + { + return m == n; + } + +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + long lu_decomposition<matrix_exp_type>:: + nr ( + ) const + { + return m; + } + +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + long lu_decomposition<matrix_exp_type>:: + nc ( + ) const + { + return n; + } + +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + bool lu_decomposition<matrix_exp_type>:: + is_singular ( + ) const + { + /* Is the matrix singular? + if upper triangular factor U (and hence A) is singular, false otherwise. + */ + // make sure requires clause is not broken + DLIB_ASSERT(is_square() == true, + "\tbool lu_decomposition::is_singular()" + << "\n\tYou can only use this on square matrices" + << "\n\tthis: " << this + ); + + type max_val, min_val; + find_min_and_max (abs(diag(LU)), min_val, max_val); + type eps = max_val; + if (eps != 0) + eps *= std::sqrt(std::numeric_limits<type>::epsilon())/10; + else + eps = 1; // there is no max so just use 1 + + return min_val < eps; + } + +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>:: + get_l ( + ) const + { + if (LU.nr() >= LU.nc()) + return lowerm(LU,1.0); + else + return lowerm(subm(LU,0,0,m,m), 1.0); + } + +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>:: + get_u ( + ) const + { + if (LU.nr() >= LU.nc()) + return upperm(subm(LU,0,0,n,n)); + else + return upperm(LU); + } + +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + const typename lu_decomposition<matrix_exp_type>::pivot_column_vector_type& lu_decomposition<matrix_exp_type>:: + get_pivot ( + ) const + { + return piv; + } + +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + typename lu_decomposition<matrix_exp_type>::type lu_decomposition<matrix_exp_type>:: + det ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_square() == true, + "\ttype lu_decomposition::det()" + << "\n\tYou can only use this on square matrices" + << "\n\tthis: " << this + ); + + // Check if it is singular and if it is just return 0. + // We want to do this because a prod() operation can easily + // overcome a single diagonal element that is effectively 0 when + // LU is a big enough matrix. + if (is_singular()) + return 0; + + return prod(diag(LU))*static_cast<type>(pivsign); + } + +// ---------------------------------------------------------------------------------------- + + template <typename matrix_exp_type> + template <typename EXP> + const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>:: + solve ( + const matrix_exp<EXP> &B + ) const + { + COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value)); + + // make sure requires clause is not broken + DLIB_ASSERT(is_square() == true && B.nr() == nr(), + "\ttype lu_decomposition::solve()" + << "\n\tInvalid arguments to this function" + << "\n\tis_square(): " << (is_square()? "true":"false" ) + << "\n\tB.nr(): " << B.nr() + << "\n\tnr(): " << nr() + << "\n\tthis: " << this + ); + + // Copy right hand side with pivoting + matrix<type,0,0,mem_manager_type,column_major_layout> X(rowm(B, piv)); + + using namespace blas_bindings; + // Solve L*Y = B(piv,:) + triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, X); + // Solve U*X = Y; + triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, X); + return X; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIX_LU_DECOMPOSITION_H + + |