diff options
Diffstat (limited to 'ml/dlib/dlib/matrix/matrix_assign.h')
-rw-r--r-- | ml/dlib/dlib/matrix/matrix_assign.h | 978 |
1 files changed, 978 insertions, 0 deletions
diff --git a/ml/dlib/dlib/matrix/matrix_assign.h b/ml/dlib/dlib/matrix/matrix_assign.h new file mode 100644 index 000000000..da53050b1 --- /dev/null +++ b/ml/dlib/dlib/matrix/matrix_assign.h @@ -0,0 +1,978 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MATRIx_ASSIGn_ +#define DLIB_MATRIx_ASSIGn_ + +#include "matrix.h" +#include "matrix_utilities.h" +#include "matrix_subexp.h" +#include "../enable_if.h" +#include "matrix_assign_fwd.h" +#include "matrix_default_mul.h" +#include "matrix_conj_trans.h" +#include "matrix_mat.h" + +namespace dlib +{ + /* + This file contains some templates that are used inside the matrix_blas_bindings.h + file to bind various matrix expressions to optimized code for carrying them out. + */ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace blas_bindings + { + + // ------------------------------------------------------------------------------------ + + template <typename T> + void zero_matrix ( + T& m + ) + { + for (long r = 0; r < m.nr(); ++r) + { + for (long c = 0; c < m.nc(); ++c) + { + m(r,c) = 0; + } + } + } + + // ------------------------------------------------------------------------------------ + + // This template struct is used to tell us if a matrix expression contains a matrix multiply. + template <typename T> + struct has_matrix_multiply + { + const static bool value = false; + }; + + template <typename T, typename U> + struct has_matrix_multiply<matrix_multiply_exp<T,U> > + { const static bool value = true; }; + + template <typename T, typename U> + struct has_matrix_multiply<matrix_add_exp<T,U> > + { const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; }; + + template <typename T, typename U> + struct has_matrix_multiply<matrix_subtract_exp<T,U> > + { const static bool value = has_matrix_multiply<T>::value || has_matrix_multiply<U>::value; }; + + template <typename T, bool Tb> + struct has_matrix_multiply<matrix_mul_scal_exp<T,Tb> > + { const static bool value = true; }; + + template <typename T> + struct has_matrix_multiply<matrix_div_scal_exp<T> > + { const static bool value = has_matrix_multiply<T>::value; }; + + template <typename T> + struct has_matrix_multiply<matrix_op<T> > + { const static bool value = has_matrix_multiply<T>::value; }; + + template <typename T> + struct has_matrix_multiply<op_trans<T> > + { const static bool value = has_matrix_multiply<T>::value; }; + + template <typename T> + struct has_matrix_multiply<op_conj_trans<T> > + { const static bool value = has_matrix_multiply<T>::value; }; + + template <typename T> + struct has_matrix_multiply<op_conj<T> > + { const static bool value = has_matrix_multiply<T>::value; }; + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + const int unknown_matrix = 0; + const int general_matrix = 1; + const int row_matrix = 2; + const int column_matrix = 3; + + // ------------------------------------------------------------------------------------ + + template <typename T> + struct matrix_type_id + { + const static int value = unknown_matrix; + }; + + template <typename T, long NR, long NC, typename MM, typename L> + struct matrix_type_id<matrix<T,NR,NC,MM,L> > + { + const static int value = general_matrix; + }; + + template <typename T, long NR, typename MM, typename L> + struct matrix_type_id<matrix<T,NR,1,MM,L> > + { + const static int value = column_matrix; + }; + + template <typename T, typename MM, typename L> + struct matrix_type_id<matrix<T,1,1,MM,L> > + { + const static int value = column_matrix; + }; + + template <typename T, long NC, typename MM, typename L> + struct matrix_type_id<matrix<T,1,NC,MM,L> > + { + const static int value = row_matrix; + }; + + // ------------------------------------------------------------------------------------ + + template <typename T, long NR, long NC, typename MM, typename L> + struct matrix_type_id<matrix_op<op_colm<matrix<T,NR,NC,MM,L> > > > + { + const static int value = column_matrix; + }; + + template <typename T, long NR, long NC, typename MM, typename L> + struct matrix_type_id<matrix_op<op_rowm<matrix<T,NR,NC,MM,L> > > > + { + const static int value = row_matrix; + }; + + template <typename T, long NR, long NC, typename MM, typename L> + struct matrix_type_id<matrix_op<op_colm2<matrix<T,NR,NC,MM,L> > > > + { + const static int value = column_matrix; + }; + + template <typename T, long NR, long NC, typename MM, typename L> + struct matrix_type_id<matrix_op<op_rowm2<matrix<T,NR,NC,MM,L> > > > + { + const static int value = row_matrix; + }; + + template <typename T, long NR, long NC, typename MM, typename L> + struct matrix_type_id<matrix_op<op_subm<matrix<T,NR,NC,MM,L> > > > + { + const static int value = general_matrix; + }; + + template < typename T, typename MM > + struct matrix_type_id<matrix_op<op_array2d_to_mat<array2d<T,MM> > > > + { const static int value = general_matrix; }; + + template < typename T, typename MM > + struct matrix_type_id<matrix_op<op_array_to_mat<array<T,MM> > > > + { const static int value = column_matrix; }; + + template < typename value_type, typename alloc > + struct matrix_type_id<matrix_op<op_std_vect_to_mat<std::vector<value_type,alloc> > > > + { const static int value = column_matrix; }; + + template < typename value_type, typename alloc > + struct matrix_type_id<matrix_op<op_std_vect_to_mat<std_vector_c<value_type,alloc> > > > + { const static int value = column_matrix; }; + + template < typename T > + struct matrix_type_id<matrix_op<op_pointer_to_col_vect<T> > > + { const static int value = column_matrix; }; + template < typename T > + struct matrix_type_id<matrix_op<op_pointer_to_mat<T> > > + { const static int value = general_matrix; }; + + // ------------------------------------------------------------------------------------ + + template <typename T, typename U> + struct same_matrix + { + const static int T_id = matrix_type_id<T>::value; + const static int U_id = matrix_type_id<U>::value; + // The check for unknown_matrix is here so that we can be sure that matrix types + // other than the ones specifically enumerated above never get pushed into + // any of the BLAS bindings. So saying they are never the same as anything + // else prevents them from matching any of the BLAS bindings. + const static bool value = (T_id == U_id) && (T_id != unknown_matrix); + }; + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + // This template struct is used to tell us if two matrix expressions both contain the same + // sequence of operators, expressions. It also only has a value of true if the T expression + // contains only matrices with the given layout. + template <typename T, typename U, typename layout> + struct same_exp + { + const static bool value = (is_same_type<typename T::exp_type, typename U::exp_type>::value || + same_matrix<typename T::exp_type, typename U::exp_type>::value) && + is_same_type<typename T::layout_type,layout>::value; + + }; + + // Used only below. They help strip off the const and & qualifiers that can show up + // in the LHS_ref_type and RHS_ref_type typedefs. + template <typename T> struct noref{ typedef T type;}; + template <typename T> struct noref<T&>{ typedef T type;}; + template <typename T> struct noref<const T&>{ typedef T type;}; + template <typename T> struct noref<const T>{ typedef T type;}; + + template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs, typename layout> + struct same_exp<matrix_multiply_exp<Tlhs,Trhs>, matrix_multiply_exp<Ulhs,Urhs>,layout > + { + // The reason this case is more complex than the others is because the matrix_multiply_exp + // will use a temporary matrix instead of Tlhs or Trhs in the event that one of these + // types corresponds to an expensive expression. So we have to use the type that really + // gets used. The following typedefs are here to pick out that true type. + typedef typename matrix_multiply_exp<Tlhs,Trhs>::LHS_ref_type T_LHS_ref_type; + typedef typename matrix_multiply_exp<Tlhs,Trhs>::RHS_ref_type T_RHS_ref_type; + typedef typename noref<T_LHS_ref_type>::type T_lhs_type; + typedef typename noref<T_RHS_ref_type>::type T_rhs_type; + + typedef typename matrix_multiply_exp<Ulhs,Urhs>::LHS_ref_type U_LHS_ref_type; + typedef typename matrix_multiply_exp<Ulhs,Urhs>::RHS_ref_type U_RHS_ref_type; + typedef typename noref<U_LHS_ref_type>::type U_lhs_type; + typedef typename noref<U_RHS_ref_type>::type U_rhs_type; + + const static bool value = same_exp<T_lhs_type,U_lhs_type,layout>::value && + same_exp<T_rhs_type,U_rhs_type,layout>::value; + }; + + template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs, typename layout> + struct same_exp<matrix_add_exp<Tlhs,Trhs>, matrix_add_exp<Ulhs,Urhs>, layout > + { const static bool value = same_exp<Tlhs,Ulhs,layout>::value && same_exp<Trhs,Urhs,layout>::value; }; + + template <typename Tlhs, typename Ulhs, typename Trhs, typename Urhs, typename layout> + struct same_exp<matrix_subtract_exp<Tlhs,Trhs>, matrix_subtract_exp<Ulhs,Urhs>, layout > + { const static bool value = same_exp<Tlhs,Ulhs,layout>::value && same_exp<Trhs,Urhs,layout>::value; }; + + template <typename T, typename U, bool Tb, bool Ub, typename layout> + struct same_exp<matrix_mul_scal_exp<T,Tb>, matrix_mul_scal_exp<U,Ub>, layout > + { const static bool value = same_exp<T,U,layout>::value; }; + + template <typename T, typename U, typename layout> + struct same_exp<matrix_div_scal_exp<T>, matrix_div_scal_exp<U>, layout > + { const static bool value = same_exp<T,U,layout>::value; }; + + template <typename T, typename U, typename layout> + struct same_exp<matrix_op<op_trans<T> >, matrix_op<op_trans<U> >, layout > + { const static bool value = same_exp<T,U,layout>::value; }; + + template <typename T, typename U, typename layout> + struct same_exp<matrix_op<op_conj<T> >, matrix_op<op_conj<U> >, layout > + { const static bool value = same_exp<T,U,layout>::value; }; + + template <typename T, typename U, typename layout> + struct same_exp<matrix_op<op_conj_trans<T> >, matrix_op<op_conj_trans<U> >, layout > + { const static bool value = same_exp<T,U,layout>::value; }; + + // ------------------------------------------------------------------------------------ + + struct yes_type + { + char ch; + }; + struct no_type + { + yes_type a, b; + }; + + // This is a helper that is used below to apply the same_exp template to matrix expressions. + template <typename T, typename layout, typename U> + typename enable_if<same_exp<T,U,layout>,yes_type>::type test(U); + template <typename T, typename layout, typename U> + typename disable_if<same_exp<T,U,layout>,no_type>::type test(U); + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp, + typename enabled = void + > + struct matrix_assign_blas_helper + { + // We are in the default version of the blas helper so this + // means there wasn't any more specific overload. So just + // let the default matrix assignment happen. + template <typename EXP> + static void assign ( + dest_exp& dest, + const EXP& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + if (transpose == false) + matrix_assign_default(dest,src,alpha,add_to); + else + matrix_assign_default(dest,trans(src),alpha,add_to); + } + + // If we know this is a matrix multiply then apply the + // default dlib matrix multiply to speed things up a bit more + // than the above default function would. + template <typename EXP1, typename EXP2> + static void assign ( + dest_exp& dest, + const matrix_multiply_exp<EXP1,EXP2>& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + // At some point I need to improve the default (i.e. non BLAS) matrix + // multiplication algorithm... + + if (alpha == static_cast<typename src_exp::type>(1)) + { + if (add_to == false) + { + zero_matrix(dest); + } + + if (transpose == false) + default_matrix_multiply(dest, src.lhs, src.rhs); + else + default_matrix_multiply(dest, trans(src.rhs), trans(src.lhs)); + } + else + { + if (add_to) + { + typename dest_exp::matrix_type temp(dest.nr(),dest.nc()); + zero_matrix(temp); + + if (transpose == false) + default_matrix_multiply(temp, src.lhs, src.rhs); + else + default_matrix_multiply(temp, trans(src.rhs), trans(src.lhs)); + + matrix_assign_default(dest,temp, alpha,true); + } + else + { + zero_matrix(dest); + + if (transpose == false) + default_matrix_multiply(dest, src.lhs, src.rhs); + else + default_matrix_multiply(dest, trans(src.rhs), trans(src.lhs)); + + matrix_assign_default(dest,dest, alpha, false); + } + } + } + }; + + // This is a macro to help us add overloads for the matrix_assign_blas_helper template. + // Using this macro it is easy to add overloads for arbitrary matrix expressions. +#define DLIB_ADD_BLAS_BINDING(src_expression) \ + template <typename T, typename L> struct BOOST_JOIN(blas,__LINE__) \ + { const static bool value = sizeof(yes_type) == sizeof(test<T,L>(src_expression)); }; \ + \ + template < typename dest_exp, typename src_exp > \ + struct matrix_assign_blas_helper<dest_exp, src_exp, \ + typename enable_if<BOOST_JOIN(blas,__LINE__)<src_exp,typename dest_exp::layout_type> >::type > { \ + static void assign ( \ + dest_exp& dest, \ + const src_exp& src, \ + typename src_exp::type alpha, \ + bool add_to, \ + bool DLIB_NO_WARN_UNUSED transpose \ + ) { \ + DLIB_NO_WARN_UNUSED typedef typename dest_exp::type T; + +#define DLIB_END_BLAS_BINDING }}; + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + // ------------------- Forward Declarations ------------------- + + template < + typename dest_exp, + typename src_exp + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const src_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + template < + typename dest_exp, + typename src_exp, typename src_exp2 + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_add_exp<src_exp, src_exp2>& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + template < + typename dest_exp, + typename src_exp, bool Sb + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_mul_scal_exp<src_exp,Sb>& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + template < + typename dest_exp, + typename src_exp + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_op<op_trans<src_exp> >& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + template < + typename dest_exp, + typename src_exp, typename src_exp2 + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_subtract_exp<src_exp, src_exp2>& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ); + /*! + requires + - src.aliases(dest) == false + - dest.nr() == src.nr() + - dest.nc() == src.nc() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix<T,NR,NC,MM,L>& dest, + const src_exp& src + ); + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix<T,NR,NC,MM,L>& dest, + const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src + ); + /*! + This function catches the expressions of the form: + M = M + exp; + and converts them into the appropriate matrix_assign_blas() call. + This is an important case to catch because it is the expression used + to represent the += matrix operator. + !*/ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix<T,NR,NC,MM,L>& dest, + const matrix_add_exp<src_exp, matrix<T,NR,NC,MM,L> >& src + ); + /*! + This function catches the expressions of the form: + M = exp + M; + and converts them into the appropriate matrix_assign_blas() call. + This is an important case to catch because it is the expression used + to represent the += matrix operator. + !*/ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix<T,NR,NC,MM,L>& dest, + const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src + ); + /*! + This function catches the expressions of the form: + M = M - exp; + and converts them into the appropriate matrix_assign_blas() call. + This is an important case to catch because it is the expression used + to represent the -= matrix operator. + !*/ + + + // End of forward declarations for overloaded matrix_assign_blas functions + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const src_exp& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + matrix_assign_blas_helper<dest_exp,src_exp>::assign(dest,src,alpha,add_to, transpose); + } + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp, typename src_exp2 + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_add_exp<src_exp, src_exp2>& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + if (has_matrix_multiply<src_exp>::value || has_matrix_multiply<src_exp2>::value) + { + matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to, transpose); + matrix_assign_blas_proxy(dest, src.rhs, alpha, true, transpose); + } + else + { + if (transpose == false) + matrix_assign_default(dest, src, alpha, add_to); + else + matrix_assign_default(dest, trans(src), alpha, add_to); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp, bool Sb + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_mul_scal_exp<src_exp,Sb>& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + matrix_assign_blas_proxy(dest, src.m, alpha*src.s, add_to, transpose); + } + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_op<op_trans<src_exp> >& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + matrix_assign_blas_proxy(dest, src.op.m, alpha, add_to, !transpose); + } + + // ------------------------------------------------------------------------------------ + + template < + typename dest_exp, + typename src_exp, typename src_exp2 + > + void matrix_assign_blas_proxy ( + dest_exp& dest, + const matrix_subtract_exp<src_exp, src_exp2>& src, + typename src_exp::type alpha, + bool add_to, + bool transpose + ) + { + + if (has_matrix_multiply<src_exp>::value || has_matrix_multiply<src_exp2>::value) + { + matrix_assign_blas_proxy(dest, src.lhs, alpha, add_to, transpose); + matrix_assign_blas_proxy(dest, src.rhs, -alpha, true, transpose); + } + else + { + if (transpose == false) + matrix_assign_default(dest, src, alpha, add_to); + else + matrix_assign_default(dest, trans(src), alpha, add_to); + } + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + // Once we get into this function it means that we are dealing with a matrix of float, + // double, complex<float>, or complex<double> and the src_exp contains at least one + // matrix multiply. + + template < + typename T, long NR, long NC, typename MM, typename L, + long NR2, long NC2, bool Sb + > + void matrix_assign_blas ( + matrix<T,NR,NC,MM,L>& dest, + const matrix_mul_scal_exp<matrix<T,NR2,NC2,MM,L>,Sb>& src + ) + { + // It's ok that we don't check for aliasing in this case because there isn't + // any complex unrolling of successive + or - operators in this expression. + matrix_assign_blas_proxy(dest,src.m,src.s,false, false); + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix<T,NR,NC,MM,L>& dest, + const src_exp& src + ) + { + if (src.aliases(dest)) + { + matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + temp.swap(dest); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + assignable_sub_matrix<T,NR,NC,MM,L>& dest, + const src_exp& src + ) + { + if (src.aliases(dest.m)) + { + matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + matrix_assign_default(dest,temp); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, + typename src_exp + > + void matrix_assign_blas ( + assignable_ptr_matrix<T>& dest, + const src_exp& src + ) + { + if (src.aliases(mat(dest.ptr,dest.height,dest.width))) + { + matrix<T> temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + matrix_assign_default(dest,temp); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + assignable_row_matrix<T,NR,NC,MM,L>& dest, + const src_exp& src + ) + { + if (src.aliases(dest.m)) + { + matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + matrix_assign_default(dest,temp); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + assignable_col_matrix<T,NR,NC,MM,L>& dest, + const src_exp& src + ) + { + if (src.aliases(dest.m)) + { + matrix<T,NR,NC,MM,L> temp(dest.nr(),dest.nc()); + matrix_assign_blas_proxy(temp,src,1,false, false); + matrix_assign_default(dest,temp); + } + else + { + matrix_assign_blas_proxy(dest,src,1,false, false); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix<T,NR,NC,MM,L>& dest, + const matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src + ) + { + if (src.rhs.aliases(dest) == false) + { + if (&src.lhs != &dest) + { + dest = src.lhs; + } + + matrix_assign_blas_proxy(dest, src.rhs, 1, true, false); + } + else + { + matrix<T,NR,NC,MM,L> temp(src.lhs); + matrix_assign_blas_proxy(temp, src.rhs, 1, true, false); + temp.swap(dest); + } + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix<T,NR,NC,MM,L>& dest, + const matrix_add_exp<src_exp, matrix<T,NR,NC,MM,L> >& src + ) + { + // Just switch around the left and right hand sides of the incoming + // add expression and pass it back into matrix_assign_blas() so that + // the above function will be called. + typedef matrix_add_exp<matrix<T,NR,NC,MM,L> ,src_exp> swapped_add_exp; + matrix_assign_blas(dest, swapped_add_exp(src.rhs, src.lhs)); + } + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + void matrix_assign_blas ( + matrix<T,NR,NC,MM,L>& dest, + const matrix_subtract_exp<matrix<T,NR,NC,MM,L> ,src_exp>& src + ) + { + if (src.rhs.aliases(dest) == false) + { + if (&src.lhs != &dest) + { + dest = src.lhs; + } + + matrix_assign_blas_proxy(dest, src.rhs, -1, true, false); + } + else + { + matrix<T,NR,NC,MM,L> temp(src.lhs); + matrix_assign_blas_proxy(temp, src.rhs, -1, true, false); + temp.swap(dest); + } + } + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + } // end of namespace blas_bindings + + // ------------------------------------------------------------------------------------ + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + inline typename enable_if_c<(is_same_type<T,float>::value || + is_same_type<T,double>::value || + is_same_type<T,std::complex<float> >::value || + is_same_type<T,std::complex<double> >::value) && + blas_bindings::has_matrix_multiply<src_exp>::value + >::type matrix_assign_big ( + matrix<T,NR,NC,MM,L>& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + inline typename enable_if_c<(is_same_type<T,float>::value || + is_same_type<T,double>::value || + is_same_type<T,std::complex<float> >::value || + is_same_type<T,std::complex<double> >::value) && + blas_bindings::has_matrix_multiply<src_exp>::value + >::type matrix_assign_big ( + assignable_sub_matrix<T,NR,NC,MM,L>& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename src_exp + > + inline typename enable_if_c<(is_same_type<T,float>::value || + is_same_type<T,double>::value || + is_same_type<T,std::complex<float> >::value || + is_same_type<T,std::complex<double> >::value) && + blas_bindings::has_matrix_multiply<src_exp>::value + >::type matrix_assign_big ( + assignable_ptr_matrix<T>& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + inline typename enable_if_c<(is_same_type<T,float>::value || + is_same_type<T,double>::value || + is_same_type<T,std::complex<float> >::value || + is_same_type<T,std::complex<double> >::value) && + blas_bindings::has_matrix_multiply<src_exp>::value + >::type matrix_assign_big ( + assignable_row_matrix<T,NR,NC,MM,L>& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, long NR, long NC, typename MM, typename L, + typename src_exp + > + inline typename enable_if_c<(is_same_type<T,float>::value || + is_same_type<T,double>::value || + is_same_type<T,std::complex<float> >::value || + is_same_type<T,std::complex<double> >::value) && + blas_bindings::has_matrix_multiply<src_exp>::value + >::type matrix_assign_big ( + assignable_col_matrix<T,NR,NC,MM,L>& dest, + const src_exp& src + ) + { + blas_bindings::matrix_assign_blas(dest,src); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MATRIx_ASSIGn_ + |