summaryrefslogtreecommitdiffstats
path: root/src/ml/dlib/dlib/matrix/matrix_lu.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/ml/dlib/dlib/matrix/matrix_lu.h')
-rw-r--r--src/ml/dlib/dlib/matrix/matrix_lu.h361
1 files changed, 361 insertions, 0 deletions
diff --git a/src/ml/dlib/dlib/matrix/matrix_lu.h b/src/ml/dlib/dlib/matrix/matrix_lu.h
new file mode 100644
index 000000000..3e49cd653
--- /dev/null
+++ b/src/ml/dlib/dlib/matrix/matrix_lu.h
@@ -0,0 +1,361 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+// This code was adapted from code from the JAMA part of NIST's TNT library.
+// See: http://math.nist.gov/tnt/
+#ifndef DLIB_MATRIX_LU_DECOMPOSITION_H
+#define DLIB_MATRIX_LU_DECOMPOSITION_H
+
+#include "matrix.h"
+#include "matrix_utilities.h"
+#include "matrix_subexp.h"
+#include "matrix_trsm.h"
+#include <algorithm>
+
+#ifdef DLIB_USE_LAPACK
+#include "lapack/getrf.h"
+#endif
+
+
+namespace dlib
+{
+
+ template <
+ typename matrix_exp_type
+ >
+ class lu_decomposition
+ {
+ public:
+
+ const static long NR = matrix_exp_type::NR;
+ const static long NC = matrix_exp_type::NC;
+ typedef typename matrix_exp_type::type type;
+ typedef typename matrix_exp_type::mem_manager_type mem_manager_type;
+ typedef typename matrix_exp_type::layout_type layout_type;
+
+ typedef matrix<type,0,0,mem_manager_type,layout_type> matrix_type;
+ typedef matrix<type,NR,1,mem_manager_type,layout_type> column_vector_type;
+ typedef matrix<long,NR,1,mem_manager_type,layout_type> pivot_column_vector_type;
+
+ // You have supplied an invalid type of matrix_exp_type. You have
+ // to use this object with matrices that contain float or double type data.
+ COMPILE_TIME_ASSERT((is_same_type<float, type>::value ||
+ is_same_type<double, type>::value ));
+
+ template <typename EXP>
+ lu_decomposition (
+ const matrix_exp<EXP> &A
+ );
+
+ bool is_square (
+ ) const;
+
+ bool is_singular (
+ ) const;
+
+ long nr(
+ ) const;
+
+ long nc(
+ ) const;
+
+ const matrix_type get_l (
+ ) const;
+
+ const matrix_type get_u (
+ ) const;
+
+ const pivot_column_vector_type& get_pivot (
+ ) const;
+
+ type det (
+ ) const;
+
+ template <typename EXP>
+ const matrix_type solve (
+ const matrix_exp<EXP> &B
+ ) const;
+
+ private:
+
+ /* Array for internal storage of decomposition. */
+ matrix<type,0,0,mem_manager_type,column_major_layout> LU;
+ long m, n, pivsign;
+ pivot_column_vector_type piv;
+
+
+ };
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// Public member functions
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ lu_decomposition<matrix_exp_type>::
+ lu_decomposition (
+ const matrix_exp<EXP>& A
+ ) :
+ LU(A),
+ m(A.nr()),
+ n(A.nc())
+ {
+ using namespace std;
+ using std::abs;
+
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(A.size() > 0,
+ "\tlu_decomposition::lu_decomposition(A)"
+ << "\n\tInvalid inputs were given to this function"
+ << "\n\tA.size(): " << A.size()
+ << "\n\tthis: " << this
+ );
+
+#ifdef DLIB_USE_LAPACK
+ matrix<lapack::integer,0,1,mem_manager_type,layout_type> piv_temp;
+ lapack::getrf(LU, piv_temp);
+
+ pivsign = 1;
+
+ // Turn the piv_temp vector into a more useful form. This way we will have the identity
+ // rowm(A,piv) == L*U. The permutation vector that comes out of LAPACK is somewhat
+ // different.
+ piv = trans(range(0,m-1));
+ for (long i = 0; i < piv_temp.size(); ++i)
+ {
+ // -1 because FORTRAN is indexed starting with 1 instead of 0
+ if (piv(piv_temp(i)-1) != piv(i))
+ {
+ std::swap(piv(i), piv(piv_temp(i)-1));
+ pivsign = -pivsign;
+ }
+ }
+
+#else
+
+ // Use a "left-looking", dot-product, Crout/Doolittle algorithm.
+
+
+ piv = trans(range(0,m-1));
+ pivsign = 1;
+
+ column_vector_type LUcolj(m);
+
+ // Outer loop.
+ for (long j = 0; j < n; j++)
+ {
+
+ // Make a copy of the j-th column to localize references.
+ LUcolj = colm(LU,j);
+
+ // Apply previous transformations.
+ for (long i = 0; i < m; i++)
+ {
+ // Most of the time is spent in the following dot product.
+ const long kmax = std::min(i,j);
+ type s;
+ if (kmax > 0)
+ s = rowm(LU,i, kmax)*colm(LUcolj,0,kmax);
+ else
+ s = 0;
+
+ LU(i,j) = LUcolj(i) -= s;
+ }
+
+ // Find pivot and exchange if necessary.
+ long p = j;
+ for (long i = j+1; i < m; i++)
+ {
+ if (abs(LUcolj(i)) > abs(LUcolj(p)))
+ {
+ p = i;
+ }
+ }
+ if (p != j)
+ {
+ long k=0;
+ for (k = 0; k < n; k++)
+ {
+ type t = LU(p,k);
+ LU(p,k) = LU(j,k);
+ LU(j,k) = t;
+ }
+ k = piv(p);
+ piv(p) = piv(j);
+ piv(j) = k;
+ pivsign = -pivsign;
+ }
+
+ // Compute multipliers.
+ if ((j < m) && (LU(j,j) != 0.0))
+ {
+ for (long i = j+1; i < m; i++)
+ {
+ LU(i,j) /= LU(j,j);
+ }
+ }
+ }
+
+#endif
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ bool lu_decomposition<matrix_exp_type>::
+ is_square (
+ ) const
+ {
+ return m == n;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ long lu_decomposition<matrix_exp_type>::
+ nr (
+ ) const
+ {
+ return m;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ long lu_decomposition<matrix_exp_type>::
+ nc (
+ ) const
+ {
+ return n;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ bool lu_decomposition<matrix_exp_type>::
+ is_singular (
+ ) const
+ {
+ /* Is the matrix singular?
+ if upper triangular factor U (and hence A) is singular, false otherwise.
+ */
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_square() == true,
+ "\tbool lu_decomposition::is_singular()"
+ << "\n\tYou can only use this on square matrices"
+ << "\n\tthis: " << this
+ );
+
+ type max_val, min_val;
+ find_min_and_max (abs(diag(LU)), min_val, max_val);
+ type eps = max_val;
+ if (eps != 0)
+ eps *= std::sqrt(std::numeric_limits<type>::epsilon())/10;
+ else
+ eps = 1; // there is no max so just use 1
+
+ return min_val < eps;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
+ get_l (
+ ) const
+ {
+ if (LU.nr() >= LU.nc())
+ return lowerm(LU,1.0);
+ else
+ return lowerm(subm(LU,0,0,m,m), 1.0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
+ get_u (
+ ) const
+ {
+ if (LU.nr() >= LU.nc())
+ return upperm(subm(LU,0,0,n,n));
+ else
+ return upperm(LU);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ const typename lu_decomposition<matrix_exp_type>::pivot_column_vector_type& lu_decomposition<matrix_exp_type>::
+ get_pivot (
+ ) const
+ {
+ return piv;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ typename lu_decomposition<matrix_exp_type>::type lu_decomposition<matrix_exp_type>::
+ det (
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_square() == true,
+ "\ttype lu_decomposition::det()"
+ << "\n\tYou can only use this on square matrices"
+ << "\n\tthis: " << this
+ );
+
+ // Check if it is singular and if it is just return 0.
+ // We want to do this because a prod() operation can easily
+ // overcome a single diagonal element that is effectively 0 when
+ // LU is a big enough matrix.
+ if (is_singular())
+ return 0;
+
+ return prod(diag(LU))*static_cast<type>(pivsign);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename matrix_exp_type>
+ template <typename EXP>
+ const typename lu_decomposition<matrix_exp_type>::matrix_type lu_decomposition<matrix_exp_type>::
+ solve (
+ const matrix_exp<EXP> &B
+ ) const
+ {
+ COMPILE_TIME_ASSERT((is_same_type<type, typename EXP::type>::value));
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_square() == true && B.nr() == nr(),
+ "\ttype lu_decomposition::solve()"
+ << "\n\tInvalid arguments to this function"
+ << "\n\tis_square(): " << (is_square()? "true":"false" )
+ << "\n\tB.nr(): " << B.nr()
+ << "\n\tnr(): " << nr()
+ << "\n\tthis: " << this
+ );
+
+ // Copy right hand side with pivoting
+ matrix<type,0,0,mem_manager_type,column_major_layout> X(rowm(B, piv));
+
+ using namespace blas_bindings;
+ // Solve L*Y = B(piv,:)
+ triangular_solver(CblasLeft, CblasLower, CblasNoTrans, CblasUnit, LU, X);
+ // Solve U*X = Y;
+ triangular_solver(CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, LU, X);
+ return X;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MATRIX_LU_DECOMPOSITION_H
+
+