diff options
Diffstat (limited to 'ml/dlib/dlib/test/matrix3.cpp')
-rw-r--r-- | ml/dlib/dlib/test/matrix3.cpp | 1134 |
1 files changed, 1134 insertions, 0 deletions
diff --git a/ml/dlib/dlib/test/matrix3.cpp b/ml/dlib/dlib/test/matrix3.cpp new file mode 100644 index 000000000..b66af638c --- /dev/null +++ b/ml/dlib/dlib/test/matrix3.cpp @@ -0,0 +1,1134 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include <dlib/matrix.h> +#include <sstream> +#include <string> +#include <cstdlib> +#include <ctime> +#include <vector> +#include "../stl_checked.h" +#include "../array.h" +#include "../rand.h" + +#include "tester.h" +#include <dlib/memory_manager_stateless.h> +#include <dlib/array2d.h> + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.matrix3"); + + + const double eps_mul = 200; + + template <typename T, typename U> + void check_equal ( + const T& a, + const U& b + ) + { + DLIB_TEST(a.nr() == b.nr()); + DLIB_TEST(a.nc() == b.nc()); + typedef typename T::type type; + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + type error = std::abs(a(r,c) - b(r,c)); + DLIB_TEST_MSG(error < std::sqrt(std::numeric_limits<type>::epsilon())*eps_mul, "error: " << error << + " eps: " << std::sqrt(std::numeric_limits<type>::epsilon())*eps_mul); + } + } + } + + template <typename T, typename U> + void c_check_equal ( + const T& a, + const U& b + ) + { + DLIB_TEST(a.nr() == b.nr()); + DLIB_TEST(a.nc() == b.nc()); + typedef typename T::type type; + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + typename type::value_type error = std::abs(a(r,c) - b(r,c)); + DLIB_TEST_MSG(error < std::sqrt(std::numeric_limits<typename type::value_type>::epsilon())*eps_mul, "error: " << error << + " eps: " << std::sqrt(std::numeric_limits<typename type::value_type>::epsilon())*eps_mul); + } + } + } + + template <typename T, typename U> + void assign_no_blas ( + const T& a_, + const U& b + ) + { + T& a = const_cast<T&>(a_); + DLIB_TEST(a.nr() == b.nr()); + DLIB_TEST(a.nc() == b.nc()); + for (long r = 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = b(r,c); + } + } + } + + template <typename type> + type rnd_num (dlib::rand& rnd) + { + return static_cast<type>(10*rnd.get_random_double()); + } + + template <typename type> + void test_blas( long rows, long cols) + { + // The tests in this function exercise the BLAS bindings located in the matrix/matrix_blas_bindings.h file. + // It does this by performing an assignment that is subject to BLAS bindings and comparing the + // results directly to an unevaluated matrix_exp that should be equal. + + dlib::rand rnd; + + matrix<type> a(rows,cols), temp, temp2, temp3; + + for (int k = 0; k < 6; ++k) + { + for (long r= 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + a(r,c) = rnd_num<type>(rnd); + } + } + matrix<type> at; + at = trans(a); + + matrix<complex<type> > c_a(rows,cols), c_at, c_sqr; + for (long r= 0; r < a.nr(); ++r) + { + for (long c = 0; c < a.nc(); ++c) + { + c_a(r,c) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd)); + } + } + c_at = trans(c_a); + const int size = max(rows,cols); + c_sqr = 10*matrix_cast<complex<type> >(complex_matrix(randm(size,size,rnd), randm(size,size,rnd))); + + + matrix<complex<type> > c_temp(cols,cols), c_temp2(cols,cols); + const complex<type> i(0,1); + + const type one = 1; + const type two = 1; + const type num1 = static_cast<type>(3.6); + const type num2 = static_cast<type>(6.6); + const type num3 = static_cast<type>(8.6); + + matrix<complex<type>,0,1> c_cv4(cols), c_cv3(rows); + matrix<complex<type>,1,0> c_rv4(cols), c_rv3(rows); + + matrix<type,0,1> cv4(cols); + + for (long idx = 0; idx < cv4.size(); ++idx) + cv4(idx) = rnd_num<type>(rnd); + + for (long idx = 0; idx < c_cv4.size(); ++idx) + c_cv4(idx) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd)); + + matrix<type,1,0> rv3(rows); + + for (long idx = 0; idx < rv3.size(); ++idx) + rv3(idx) = rnd_num<type>(rnd); + + for (long idx = 0; idx < c_rv3.size(); ++idx) + c_rv3(idx) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd)); + + matrix<type,0,1> cv3(rows); + + for (long idx = 0; idx < cv3.size(); ++idx) + cv3(idx) = rnd_num<type>(rnd); + + for (long idx = 0; idx < c_cv3.size(); ++idx) + c_cv3(idx) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd)); + + matrix<type,1,0> rv4(cols); + for (long idx = 0; idx < rv4.size(); ++idx) + rv4(idx) = rnd_num<type>(rnd); + + for (long idx = 0; idx < c_rv4.size(); ++idx) + c_rv4(idx) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd)); + + + + // GEMM tests + dlog << LTRACE << "1.1"; + check_equal(tmp(at*a), at*a); + check_equal(tmp(trans(at*a)), trans(at*a)); + check_equal(tmp(2.4*trans(4*trans(at*a) + at*3*a)), 2.4*trans(4*trans(at*a) + at*3*a)); + dlog << LTRACE << "1.2"; + check_equal(tmp(trans(a)*a), trans(a)*a); + check_equal(tmp(trans(trans(a)*a)), trans(trans(a)*a)); + dlog << LTRACE << "1.3"; + check_equal(tmp(at*trans(at)), at*trans(at)); + check_equal(tmp(trans(at*trans(at))), trans(at*trans(at))); + dlog << LTRACE << "1.4"; + check_equal(tmp(trans(at)*trans(a)), a*at); + check_equal(tmp(trans(trans(at)*trans(a))), trans(a*at)); + dlog << LTRACE << "1.5"; + + print_spinner(); + c_check_equal(tmp(conj(trans(c_a))*c_a), trans(conj(c_a))*c_a); + dlog << LTRACE << "1.5.1"; + c_check_equal(tmp(trans(conj(trans(c_a))*c_a)), trans(trans(conj(c_a))*c_a)); + dlog << LTRACE << "1.5.2"; + c_check_equal(tmp((conj(trans(c_sqr))*trans(c_sqr))), (trans(conj(c_sqr))*trans(c_sqr))); + dlog << LTRACE << "1.5.3"; + c_check_equal(tmp(trans(conj(trans(c_sqr))*trans(c_sqr))), trans(trans(conj(c_sqr))*trans(c_sqr))); + dlog << LTRACE << "1.6"; + c_check_equal(tmp(c_at*trans(conj(c_at))), c_at*conj(trans(c_at))); + dlog << LTRACE << "1.6.1"; + c_check_equal(tmp(trans(c_at*trans(conj(c_at)))), trans(c_at*conj(trans(c_at)))); + dlog << LTRACE << "1.6.2"; + c_check_equal(tmp((c_sqr)*trans(conj(c_sqr))), (c_sqr)*conj(trans(c_sqr))); + dlog << LTRACE << "1.6.2.1"; + c_check_equal(tmp(trans(c_sqr)*trans(conj(c_sqr))), trans(c_sqr)*conj(trans(c_sqr))); + dlog << LTRACE << "1.6.3"; + c_check_equal(tmp(trans(trans(c_sqr)*trans(conj(c_sqr)))), trans(trans(c_sqr)*conj(trans(c_sqr)))); + dlog << LTRACE << "1.7"; + c_check_equal(tmp(conj(trans(c_at))*trans(conj(c_a))), conj(trans(c_at))*trans(conj(c_a))); + c_check_equal(tmp(trans(conj(trans(c_at))*trans(conj(c_a)))), trans(conj(trans(c_at))*trans(conj(c_a)))); + dlog << LTRACE << "1.8"; + + check_equal(tmp(a*trans(rowm(a,1))) , a*trans(rowm(a,1))); + check_equal(tmp(a*colm(at,1)) , a*colm(at,1)); + check_equal(tmp(subm(a,1,1,2,2)*subm(a,1,2,2,2)), subm(a,1,1,2,2)*subm(a,1,2,2,2)); + + dlog << LTRACE << "1.9"; + check_equal(tmp(trans(a*trans(rowm(a,1)))) , trans(a*trans(rowm(a,1)))); + dlog << LTRACE << "1.10"; + check_equal(tmp(trans(a*colm(at,1))) , trans(a*colm(at,1))); + dlog << LTRACE << "1.11"; + check_equal(tmp(trans(subm(a,1,1,2,2)*subm(a,1,2,2,2))), trans(subm(a,1,1,2,2)*subm(a,1,2,2,2))); + dlog << LTRACE << "1.12"; + + { + temp = at*a; + temp2 = temp; + + temp += 3.5*at*a; + assign_no_blas(temp2, temp2 + 3.5*at*a); + check_equal(temp, temp2); + + temp -= at*3.5*a; + assign_no_blas(temp2, temp2 - at*3.5*a); + check_equal(temp, temp2); + + temp = temp + 4*at*a; + assign_no_blas(temp2, temp2 + 4*at*a); + check_equal(temp, temp2); + + temp = temp - 2.4*at*a; + assign_no_blas(temp2, temp2 - 2.4*at*a); + check_equal(temp, temp2); + } + dlog << LTRACE << "1.13"; + { + temp = trans(at*a); + temp2 = temp; + temp3 = temp; + + dlog << LTRACE << "1.14"; + temp += trans(3.5*at*a); + assign_no_blas(temp2, temp2 + trans(3.5*at*a)); + check_equal(temp, temp2); + + dlog << LTRACE << "1.15"; + temp -= trans(at*3.5*a); + assign_no_blas(temp2, temp2 - trans(at*3.5*a)); + check_equal(temp, temp2); + + dlog << LTRACE << "1.16"; + temp = trans(temp + 4*at*a); + assign_no_blas(temp3, trans(temp2 + 4*at*a)); + check_equal(temp, temp3); + + temp2 = temp; + dlog << LTRACE << "1.17"; + temp = trans(temp - 2.4*at*a); + assign_no_blas(temp3, trans(temp2 - 2.4*at*a)); + check_equal(temp, temp3); + } + + dlog << LTRACE << "1.17.1"; + { + matrix<type> m1, m2; + + m1 = matrix_cast<type>(randm(rows, cols, rnd)); + m2 = matrix_cast<type>(randm(cols, rows + 8, rnd)); + check_equal(tmp(m1*m2), m1*m2); + check_equal(tmp(trans(m1*m2)), trans(m1*m2)); + + m1 = trans(m1); + check_equal(tmp(trans(m1)*m2), trans(m1)*m2); + check_equal(tmp(trans(trans(m1)*m2)), trans(trans(m1)*m2)); + + m2 = trans(m2); + check_equal(tmp(trans(m1)*trans(m2)), trans(m1)*trans(m2)); + check_equal(tmp(trans(trans(m1)*trans(m2))), trans(trans(m1)*trans(m2))); + + m1 = trans(m1); + check_equal(tmp(m1*trans(m2)), m1*trans(m2)); + check_equal(tmp(trans(m1*trans(m2))), trans(m1*trans(m2))); + } + + dlog << LTRACE << "1.17.5"; + { + matrix<type,1,0> r; + matrix<type,0,1> c; + + r = matrix_cast<type>(randm(1, rows+9, rnd)); + c = matrix_cast<type>(randm(rows, 1, rnd)); + + check_equal(tmp(c*r), c*r); + check_equal(tmp(trans(c*r)), trans(c*r)); + + check_equal(tmp(trans(r)*trans(c)), trans(r)*trans(c)); + check_equal(tmp(trans(trans(r)*trans(c))), trans(trans(r)*trans(c))); + } + + dlog << LTRACE << "1.18"; + + // GEMV tests + check_equal(tmp(a*cv4), a*cv4); + check_equal(tmp(trans(a*cv4)), trans(a*cv4)); + check_equal(tmp(rv3*a), rv3*a); + check_equal(tmp(trans(cv4)*at), trans(cv4)*at); + check_equal(tmp(a*trans(rv4)), a*trans(rv4)); + check_equal(tmp(trans(a*trans(rv4))), trans(a*trans(rv4))); + + check_equal(tmp(trans(a)*cv3), trans(a)*cv3); + check_equal(tmp(rv4*trans(a)), rv4*trans(a)); + check_equal(tmp(trans(cv3)*trans(at)), trans(cv3)*trans(at)); + check_equal(tmp(trans(cv3)*a), trans(cv3)*a); + check_equal(tmp(trans(a)*trans(rv3)), trans(a)*trans(rv3)); + + + c_check_equal(tmp(trans(conj(c_a))*c_cv3), trans(conj(c_a))*c_cv3); + c_check_equal(tmp(c_rv4*trans(conj(c_a))), c_rv4*trans(conj(c_a))); + c_check_equal(tmp(trans(c_cv3)*trans(conj(c_at))), trans(c_cv3)*trans(conj(c_at))); + c_check_equal(tmp(conj(trans(c_a))*trans(c_rv3)), trans(conj(c_a))*trans(c_rv3)); + c_check_equal(tmp(c_rv4*conj(c_at)), c_rv4*conj(c_at)); + c_check_equal(tmp(trans(c_cv4)*conj(c_at)), trans(c_cv4)*conj(c_at)); + + dlog << LTRACE << "2.00"; + + c_check_equal(tmp(trans(trans(conj(c_a))*c_cv3)), trans(trans(conj(c_a))*c_cv3)); + c_check_equal(tmp(trans(c_rv4*trans(conj(c_a)))), trans(c_rv4*trans(conj(c_a)))); + c_check_equal(tmp(trans(trans(c_cv3)*trans(conj(c_at)))), trans(trans(c_cv3)*trans(conj(c_at)))); + dlog << LTRACE << "2.20"; + c_check_equal(tmp(trans(conj(trans(c_a))*trans(c_rv3))), trans(trans(conj(c_a))*trans(c_rv3))); + c_check_equal(tmp(trans(c_rv4*conj(c_at))), trans(c_rv4*conj(c_at))); + c_check_equal(tmp(trans(trans(c_cv4)*conj(c_at))), trans(trans(c_cv4)*conj(c_at))); + + + + dlog << LTRACE << "6"; + temp = a*at; + check_equal(temp, a*at); + temp = temp + a*at + trans(at)*at + trans(at)*sin(at); + check_equal(temp, a*at + a*at+ trans(at)*at + trans(at)*sin(at)); + + dlog << LTRACE << "6.1"; + temp = a*at; + check_equal(temp, a*at); + temp = a*at + temp; + check_equal(temp, a*at + a*at); + + print_spinner(); + dlog << LTRACE << "6.2"; + temp = a*at; + check_equal(temp, a*at); + dlog << LTRACE << "6.2.3"; + temp = temp - a*at; + dlog << LTRACE << "6.2.4"; + check_equal(temp, a*at-a*at); + + dlog << LTRACE << "6.3"; + temp = a*at; + dlog << LTRACE << "6.3.5"; + check_equal(temp, a*at); + dlog << LTRACE << "6.3.6"; + temp = a*at - temp; + dlog << LTRACE << "6.4"; + check_equal(temp, a*at-a*at); + + + + const long d = min(rows,cols); + rectangle rect(1,1,d,d); + temp.set_size(max(rows,cols)+4,max(rows,cols)+4); + set_all_elements(temp,4); + temp2 = temp; + + dlog << LTRACE << "7"; + set_subm(temp,rect) = a*at; + assign_no_blas( set_subm(temp2,rect) , a*at); + check_equal(temp, temp2); + + temp = a; + temp2 = a; + + set_colm(temp,1) = a*cv4; + assign_no_blas( set_colm(temp2,1) , a*cv4); + check_equal(temp, temp2); + + set_rowm(temp,1) = rv3*a; + assign_no_blas( set_rowm(temp2,1) , rv3*a); + check_equal(temp, temp2); + + + // Test BLAS GER + { + temp.set_size(cols,cols); + set_all_elements(temp,3); + temp2 = temp; + + + dlog << LTRACE << "8"; + temp += cv4*rv4; + assign_no_blas(temp2, temp2 + cv4*rv4); + check_equal(temp, temp2); + + dlog << LTRACE << "8.3"; + temp = temp + cv4*rv4; + assign_no_blas(temp2, temp2 + cv4*rv4); + check_equal(temp, temp2); + dlog << LTRACE << "8.9"; + } + { + temp.set_size(cols,cols); + set_all_elements(temp,3); + temp2 = temp; + temp3 = 0; + + dlog << LTRACE << "8.10"; + + temp += trans(cv4*rv4); + assign_no_blas(temp3, temp2 + trans(cv4*rv4)); + check_equal(temp, temp3); + temp3 = 0; + + dlog << LTRACE << "8.11"; + temp2 = temp; + temp = trans(temp + cv4*rv4); + assign_no_blas(temp3, trans(temp2 + cv4*rv4)); + check_equal(temp, temp3); + dlog << LTRACE << "8.12"; + } + { + matrix<complex<type> > temp, temp2, temp3; + matrix<complex<type>,0,1 > cv4; + matrix<complex<type>,1,0 > rv4; + cv4.set_size(cols); + rv4.set_size(cols); + temp.set_size(cols,cols); + set_all_elements(temp,complex<type>(3,5)); + temp(cols-1, cols-4) = 9; + temp2 = temp; + temp3.set_size(cols,cols); + temp3 = 0; + + for (long i = 0; i < rv4.size(); ++i) + { + rv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd)); + cv4(i) = complex<type>(rnd_num<type>(rnd),rnd_num<type>(rnd)); + } + + dlog << LTRACE << "8.13"; + + temp += trans(cv4*rv4); + assign_no_blas(temp3, temp2 + trans(cv4*rv4)); + c_check_equal(temp, temp3); + temp3 = 0; + + dlog << LTRACE << "8.14"; + temp2 = temp; + temp = trans(temp + cv4*rv4); + assign_no_blas(temp3, trans(temp2 + cv4*rv4)); + c_check_equal(temp, temp3); + dlog << LTRACE << "8.15"; + } + + + + + set_all_elements(c_temp, one + num1*i); + c_temp2 = c_temp; + set_all_elements(c_rv4, one + num2*i); + set_all_elements(c_cv4, two + num3*i); + + + dlog << LTRACE << "9"; + c_temp += c_cv4*c_rv4; + assign_no_blas(c_temp2, c_temp2 + c_cv4*c_rv4); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.1"; + c_temp += c_cv4*conj(c_rv4); + assign_no_blas(c_temp2, c_temp2 + c_cv4*conj(c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.2"; + c_temp = c_cv4*conj(c_rv4) + c_temp; + assign_no_blas(c_temp2, c_temp2 + c_cv4*conj(c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.3"; + c_temp = trans(c_rv4)*trans(conj(c_cv4)) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(c_rv4)*trans(conj(c_cv4))); + c_check_equal(c_temp, c_temp2); + + + dlog << LTRACE << "9.4"; + c_temp += conj(c_cv4)*c_rv4; + assign_no_blas(c_temp2, c_temp2 + conj(c_cv4)*c_rv4); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.5"; + c_temp += conj(c_cv4)*conj(c_rv4); + assign_no_blas(c_temp2, c_temp2 + conj(c_cv4)*conj(c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.6"; + c_temp = conj(c_cv4)*conj(c_rv4) + c_temp; + assign_no_blas(c_temp2, c_temp2 + conj(c_cv4)*conj(c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "9.7"; + c_temp = conj(trans(c_rv4))*trans(conj(c_cv4)) + c_temp; + assign_no_blas(c_temp2, c_temp2 + conj(trans(c_rv4))*trans(conj(c_cv4))); + c_check_equal(c_temp, c_temp2); + + + dlog << LTRACE << "10"; + c_temp += trans(c_cv4*c_rv4); + assign_no_blas(c_temp2, c_temp2 + trans(c_cv4*c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.1"; + c_temp += trans(c_cv4*conj(c_rv4)); + assign_no_blas(c_temp2, c_temp2 + trans(c_cv4*conj(c_rv4))); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.2"; + c_temp = trans(c_cv4*conj(c_rv4)) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(c_cv4*conj(c_rv4))); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.3"; + c_temp = trans(trans(c_rv4)*trans(conj(c_cv4))) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(trans(c_rv4)*trans(conj(c_cv4)))); + c_check_equal(c_temp, c_temp2); + + + dlog << LTRACE << "10.4"; + c_temp += trans(conj(c_cv4)*c_rv4); + assign_no_blas(c_temp2, c_temp2 + trans(conj(c_cv4)*c_rv4)); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.5"; + c_temp += trans(conj(c_cv4)*conj(c_rv4)); + assign_no_blas(c_temp2, c_temp2 + trans(conj(c_cv4)*conj(c_rv4))); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.6"; + c_temp = trans(conj(c_cv4)*conj(c_rv4)) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(conj(c_cv4)*conj(c_rv4))); + c_check_equal(c_temp, c_temp2); + dlog << LTRACE << "10.7"; + c_temp = trans(conj(trans(c_rv4))*trans(conj(c_cv4))) + c_temp; + assign_no_blas(c_temp2, c_temp2 + trans(conj(trans(c_rv4))*trans(conj(c_cv4)))); + c_check_equal(c_temp, c_temp2); + + dlog << LTRACE << "10.8"; + + + print_spinner(); + + // Test DOT + check_equal( tmp(rv4*cv4), rv4*cv4); + check_equal( tmp(trans(rv4*cv4)), trans(rv4*cv4)); + check_equal( tmp(trans(cv4)*trans(rv4)), trans(cv4)*trans(rv4)); + check_equal( tmp(rv4*3.9*cv4), rv4*3.9*cv4); + check_equal( tmp(trans(cv4)*3.9*trans(rv4)), trans(cv4)*3.9*trans(rv4)); + check_equal( tmp(rv4*cv4*3.9), rv4*3.9*cv4); + check_equal( tmp(trans(cv4)*trans(rv4)*3.9), trans(cv4)*3.9*trans(rv4)); + + + check_equal( tmp(trans(rv4*cv4)), trans(rv4*cv4)); + check_equal( tmp(trans(trans(rv4*cv4))), trans(trans(rv4*cv4))); + check_equal( tmp(trans(trans(cv4)*trans(rv4))), trans(trans(cv4)*trans(rv4))); + check_equal( tmp(trans(rv4*3.9*cv4)), trans(rv4*3.9*cv4)); + check_equal( tmp(trans(trans(cv4)*3.9*trans(rv4))), trans(trans(cv4)*3.9*trans(rv4))); + check_equal( tmp(trans(rv4*cv4*3.9)), trans(rv4*3.9*cv4)); + check_equal( tmp(trans(trans(cv4)*trans(rv4)*3.9)), trans(trans(cv4)*3.9*trans(rv4))); + + + temp.set_size(1,1); + temp = 4; + check_equal( tmp(temp + rv4*cv4), temp + rv4*cv4); + check_equal( tmp(temp + trans(cv4)*trans(rv4)), temp + trans(cv4)*trans(rv4)); + + dlog << LTRACE << "11"; + + + + c_check_equal( tmp(conj(c_rv4)*c_cv4), conj(c_rv4)*c_cv4); + c_check_equal( tmp(conj(trans(c_cv4))*trans(c_rv4)), trans(conj(c_cv4))*trans(c_rv4)); + + c_check_equal( tmp(conj(c_rv4)*i*c_cv4), conj(c_rv4)*i*c_cv4); + c_check_equal( tmp(conj(trans(c_cv4))*i*trans(c_rv4)), trans(conj(c_cv4))*i*trans(c_rv4)); + + c_temp.set_size(1,1); + c_temp = 4; + c_check_equal( tmp(c_temp + conj(c_rv4)*c_cv4), c_temp + conj(c_rv4)*c_cv4); + c_check_equal( tmp(c_temp + trans(conj(c_cv4))*trans(c_rv4)), c_temp + trans(conj(c_cv4))*trans(c_rv4)); + + complex<type> tmp = c_rv4*c_cv4; + DLIB_TEST(abs((tmp + i) - ((c_rv4*c_cv4)(0) + i)) < std::sqrt(std::numeric_limits<type>::epsilon())*eps_mul ); + DLIB_TEST(max(abs((rv4*cv4 + 1.0) - ((rv4*cv4)(0) + 1.0))) < std::sqrt(std::numeric_limits<type>::epsilon())*eps_mul); + + } + + { + matrix<int> m(2,3), m2(6,1); + + m = 1,2,3, + 4,5,6; + + m2 = 1,2,3,4,5,6; + + DLIB_TEST(reshape_to_column_vector(m) == m2); + DLIB_TEST(reshape_to_column_vector(m+m) == m2+m2); + + } + { + matrix<int,2,3> m(2,3); + matrix<int> m2(6,1); + + m = 1,2,3, + 4,5,6; + + m2 = 1,2,3,4,5,6; + + DLIB_TEST(reshape_to_column_vector(m) == m2); + DLIB_TEST(reshape_to_column_vector(m+m) == m2+m2); + + } + } + + + void matrix_test ( + ) + /*! + ensures + - runs tests on the matrix stuff compliance with the specs + !*/ + { + print_spinner(); + + + { + matrix<long> m1(2,2), m2(2,2); + + m1 = 1, 2, + 3, 4; + + m2 = 4, 5, + 6, 7; + + + DLIB_TEST(subm(tensor_product(m1,m2),range(0,1), range(0,1)) == 1*m2); + DLIB_TEST(subm(tensor_product(m1,m2),range(0,1), range(2,3)) == 2*m2); + DLIB_TEST(subm(tensor_product(m1,m2),range(2,3), range(0,1)) == 3*m2); + DLIB_TEST(subm(tensor_product(m1,m2),range(2,3), range(2,3)) == 4*m2); + } + + { + print_spinner(); + dlog << LTRACE << "testing blas stuff"; + dlog << LTRACE << " \nsmall double"; + test_blas<double>(3,4); + print_spinner(); + dlog << LTRACE << " \nsmall float"; + test_blas<float>(3,4); + print_spinner(); + dlog << LTRACE << " \nbig double"; + test_blas<double>(120,131); + print_spinner(); + dlog << LTRACE << " \nbig float"; + test_blas<float>(120,131); + print_spinner(); + dlog << LTRACE << "testing done"; + } + + + { + matrix<long> m(3,4), ml(3,4), mu(3,4); + m = 1,2,3,4, + 4,5,6,7, + 7,8,9,0; + + ml = 1,0,0,0, + 4,5,0,0, + 7,8,9,0; + + mu = 1,2,3,4, + 0,5,6,7, + 0,0,9,0; + + + DLIB_TEST(lowerm(m) == ml); + DLIB_TEST(upperm(m) == mu); + + ml = 3,0,0,0, + 4,3,0,0, + 7,8,3,0; + + mu = 4,2,3,4, + 0,4,6,7, + 0,0,4,0; + + DLIB_TEST(lowerm(m,3) == ml); + DLIB_TEST(upperm(m,4) == mu); + + } + + { + matrix<long> m(3,4), row(1,3), col(2,1); + m = 1,2,3,4, + 4,5,6,7, + 7,8,9,0; + + row = 4,5,6; + col = 3,6; + + DLIB_TEST(rowm(m, 1, 3) == row); + DLIB_TEST(colm(m, 2, 2) == col); + + } + + + { + std::vector<double> v(34, 8); + std::vector<double> v2(34, 9); + + DLIB_TEST(mat(&v[0], v.size()) == mat(v)); + DLIB_TEST(mat(&v2[0], v.size()) != mat(v)); + } + + { + std::vector<long> v(1, 3); + std::vector<long> v2(1, 2); + + DLIB_TEST(mat(&v[0], v.size()) == mat(v)); + DLIB_TEST(mat(&v2[0], v.size()) != mat(v)); + } + + { + matrix<double> a(3,3), b(3,3); + a = 1, 2.5, 1, + 3, 4, 5, + 0.5, 2.2, 3; + + b = 0, 1, 0, + 1, 1, 1, + 0, 1, 1; + + DLIB_TEST((a>1) == b); + DLIB_TEST((1<a) == b); + + b = 1, 1, 1, + 1, 1, 1, + 0, 1, 1; + + DLIB_TEST((a>=1) == b); + DLIB_TEST((1<=a) == b); + + b = 0, 0, 0, + 0, 0, 0, + 0, 1, 0; + DLIB_TEST((a==2.2) == b); + DLIB_TEST((a!=2.2) == (b==0)); + DLIB_TEST((2.2==a) == b); + DLIB_TEST((2.2!=a) == (0==b)); + + b = 0, 0, 0, + 0, 0, 0, + 1, 0, 0; + DLIB_TEST((a<1) == b); + DLIB_TEST((1>a) == b); + + b = 1, 0, 1, + 0, 0, 0, + 1, 0, 0; + DLIB_TEST((a<=1) == b); + DLIB_TEST((1>=a) == b); + } + + { + matrix<double> a, b, c; + a = randm(4,2); + + b += a; + c -= a; + + DLIB_TEST(equal(a, b)); + DLIB_TEST(equal(-a, c)); + + b += a; + c -= a; + + DLIB_TEST(equal(2*a, b)); + DLIB_TEST(equal(-2*a, c)); + + b += a + a; + c -= a + a; + + DLIB_TEST(equal(4*a, b)); + DLIB_TEST(equal(-4*a, c)); + + b.set_size(0,0); + c.set_size(0,0); + + + b += a + a; + c -= a + a; + + DLIB_TEST(equal(2*a, b)); + DLIB_TEST(equal(-2*a, c)); + } + + { + matrix<int> a, b, c; + + a.set_size(2, 3); + b.set_size(2, 6); + c.set_size(4, 3); + + a = 1, 2, 3, + 4, 5, 6; + + b = 1, 2, 3, 1, 2, 3, + 4, 5, 6, 4, 5, 6; + + c = 1, 2, 3, + 4, 5, 6, + 1, 2, 3, + 4, 5, 6; + + DLIB_TEST(join_rows(a,a) == b); + DLIB_TEST(join_rows(a,abs(a)) == b); + DLIB_TEST(join_cols(trans(a), trans(a)) == trans(b)); + DLIB_TEST(join_cols(a,a) == c); + DLIB_TEST(join_cols(a,abs(a)) == c); + DLIB_TEST(join_rows(trans(a),trans(a)) == trans(c)); + } + + { + matrix<int, 2, 3> a; + matrix<int, 2, 6> b; + matrix<int, 4, 3> c; + + a = 1, 2, 3, + 4, 5, 6; + + b = 1, 2, 3, 1, 2, 3, + 4, 5, 6, 4, 5, 6; + + c = 1, 2, 3, + 4, 5, 6, + 1, 2, 3, + 4, 5, 6; + + DLIB_TEST(join_rows(a,a) == b); + DLIB_TEST(join_rows(a,abs(a)) == b); + DLIB_TEST(join_cols(trans(a), trans(a)) == trans(b)); + DLIB_TEST(join_cols(a,a) == c); + DLIB_TEST(join_cols(a,abs(a)) == c); + DLIB_TEST(join_rows(trans(a),trans(a)) == trans(c)); + } + + { + matrix<int, 2, 3> a; + matrix<int> a2; + matrix<int, 2, 6> b; + matrix<int, 4, 3> c; + + a = 1, 2, 3, + 4, 5, 6; + + a2 = a; + + b = 1, 2, 3, 1, 2, 3, + 4, 5, 6, 4, 5, 6; + + c = 1, 2, 3, + 4, 5, 6, + 1, 2, 3, + 4, 5, 6; + + DLIB_TEST(join_rows(a,a2) == b); + DLIB_TEST(join_rows(a2,a) == b); + DLIB_TEST(join_cols(trans(a2), trans(a)) == trans(b)); + DLIB_TEST(join_cols(a2,a) == c); + DLIB_TEST(join_cols(a,a2) == c); + DLIB_TEST(join_rows(trans(a2),trans(a)) == trans(c)); + } + + { + matrix<int> a, b; + + a.set_size(2,3); + + a = 1, 2, 3, + 4, 5, 6; + + b.set_size(3,2); + b = 1, 2, + 3, 4, + 5, 6; + + DLIB_TEST(reshape(a, 3, 2) == b); + + b.set_size(2,3); + b = 1, 4, 2, + 5, 3, 6; + + DLIB_TEST(reshape(trans(a), 2, 3) == b); + + } + + { + matrix<int,2,3> a; + matrix<int> b; + + a = 1, 2, 3, + 4, 5, 6; + + b.set_size(3,2); + b = 1, 2, + 3, 4, + 5, 6; + + DLIB_TEST(reshape(a, 3, 2) == b); + + b.set_size(2,3); + b = 1, 4, 2, + 5, 3, 6; + + DLIB_TEST(reshape(trans(a), 2, 3) == b); + + } + + { + std::vector<int> v(6); + for (unsigned long i = 0; i < v.size(); ++i) + v[i] = i; + + matrix<int,2,3> a; + a = 0, 1, 2, + 3, 4, 5; + + DLIB_TEST(mat(&v[0], 2, 3) == a); + } + + { + matrix<int> a(3,4); + matrix<int> b(3,1), c(1,4); + + a = 1, 2, 3, 6, + 4, 5, 6, 9, + 1, 1, 1, 3; + + b(0) = sum(rowm(a,0)); + b(1) = sum(rowm(a,1)); + b(2) = sum(rowm(a,2)); + + c(0) = sum(colm(a,0)); + c(1) = sum(colm(a,1)); + c(2) = sum(colm(a,2)); + c(3) = sum(colm(a,3)); + + DLIB_TEST(sum_cols(a) == b); + DLIB_TEST(sum_rows(a) == c); + + } + + { + matrix<int> m(3,3); + + m = 1, 2, 3, + 4, 5, 6, + 7, 8, 9; + + DLIB_TEST(make_symmetric(m) == trans(make_symmetric(m))); + DLIB_TEST(lowerm(make_symmetric(m)) == lowerm(m)); + DLIB_TEST(upperm(make_symmetric(m)) == trans(lowerm(m))); + } + + { + matrix<int,3,4> a; + matrix<int> b(3,1), c(1,4); + + a = 1, 2, 3, 6, + 4, 5, 6, 9, + 1, 1, 1, 3; + + b(0) = sum(rowm(a,0)); + b(1) = sum(rowm(a,1)); + b(2) = sum(rowm(a,2)); + + c(0) = sum(colm(a,0)); + c(1) = sum(colm(a,1)); + c(2) = sum(colm(a,2)); + c(3) = sum(colm(a,3)); + + DLIB_TEST(sum_cols(a) == b); + DLIB_TEST(sum_rows(a) == c); + + } + + { + matrix<int> m(3,4), s(3,4); + m = -2, 1, 5, -5, + 5, 5, 5, 5, + 9, 0, -4, -2; + + s = -1, 1, 1, -1, + 1, 1, 1, 1, + 1, 1, -1, -1; + + DLIB_TEST(sign(m) == s); + DLIB_TEST(sign(matrix_cast<double>(m)) == matrix_cast<double>(s)); + } + + } + + + void test_matrix_IO() + { + dlib::rand rnd; + print_spinner(); + + for (int i = 0; i < 400; ++i) + { + ostringstream sout; + sout.precision(20); + + matrix<double> m1, m2, m3; + + const long r = rnd.get_random_32bit_number()%7+1; + const long c = rnd.get_random_32bit_number()%7+1; + const long num = rnd.get_random_32bit_number()%2+1; + + m1 = randm(r,c,rnd); + sout << m1; + if (num != 1) + sout << "\n" << m1; + + if (rnd.get_random_double() < 0.3) + sout << " \n"; + else if (rnd.get_random_double() < 0.3) + sout << " \n\n 3 3 3 3"; + else if (rnd.get_random_double() < 0.3) + sout << " \n \n v 3 3 3 3 3"; + + istringstream sin(sout.str()); + sin >> m2; + DLIB_TEST_MSG(equal(m1,m2), m1 << "\n***********\n" << m2); + + if (num != 1) + { + sin >> m3; + DLIB_TEST_MSG(equal(m1,m3), m1 << "\n***********\n" << m3); + } + } + + + { + istringstream sin(" 1 2\n3"); + matrix<double> m; + DLIB_TEST(sin.good()); + sin >> m; + DLIB_TEST(!sin.good()); + } + { + istringstream sin(""); + matrix<double> m; + DLIB_TEST(sin.good()); + sin >> m; + DLIB_TEST(!sin.good()); + } + } + + + void test_axpy() + { + const int n = 4; + matrix<double> B = dlib::randm(n,n); + + matrix<double> g = dlib::uniform_matrix<double>(n,1,0.0); + + const double tau = 1; + + matrix<double> p = g + tau*dlib::colm(B,0); + matrix<double> q = dlib::colm(B,0); + DLIB_TEST(max(abs(p-q)) < 1e-14); + + p = tau*dlib::colm(B,0); + q = dlib::colm(B,0); + DLIB_TEST(max(abs(p-q)) < 1e-14); + + + + + g = dlib::uniform_matrix<double>(n,n,0.0); + p = g + tau*B; + DLIB_TEST(max(abs(p-B)) < 1e-14); + + p = g + tau*subm(B,get_rect(B)); + DLIB_TEST(max(abs(p-B)) < 1e-14); + + g = dlib::uniform_matrix<double>(2,2,0.0); + p = g + tau*subm(B,1,1,2,2); + DLIB_TEST(max(abs(p-subm(B,1,1,2,2))) < 1e-14); + + set_subm(p,0,0,2,2) = g + tau*subm(B,1,1,2,2); + DLIB_TEST(max(abs(p-subm(B,1,1,2,2))) < 1e-14); + } + + + class matrix_tester : public tester + { + public: + matrix_tester ( + ) : + tester ("test_matrix3", + "Runs tests on the matrix component.") + {} + + void perform_test ( + ) + { + test_axpy(); + test_matrix_IO(); + matrix_test(); + } + } a; + +} + + |