// Copyright (C) 2009 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include "../tester.h" #include #ifndef DLIB_USE_BLAS #error "BLAS bindings must be used for this test to make any sense" #endif namespace dlib { namespace blas_bindings { // This is a little screwy. This function is used inside the BLAS // bindings to count how many times each of the BLAS functions get called. #ifdef DLIB_TEST_BLAS_BINDINGS int& counter_gemm() { static int counter = 0; return counter; } #endif } } namespace { using namespace test; using namespace std; // Declare the logger we will use in this test. The name of the logger // should start with "test." dlib::logger dlog("test.gemm"); class blas_bindings_gemm_tester : public tester { public: blas_bindings_gemm_tester ( ) : tester ( "test_gemm", // the command line argument name for this test "Run tests for GEMM routines.", // the command line argument description 0 // the number of command line arguments for this test ) {} template void test_gemm_stuff( const matrix_type& c ) const { using namespace dlib; using namespace dlib::blas_bindings; matrix_type b, a; a = c; counter_gemm() = 0; b = a*a; DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = a/2*a; DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = a*trans(a) + a; DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = (a+a)*(a+a); DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = a*(a-a); DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = trans(a)*trans(a) + a; DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = trans(trans(trans(a)*a + a)); DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = a*a*a*a; DLIB_TEST(counter_gemm() == 3); b = c; counter_gemm() = 0; a = a*a*a*a; DLIB_TEST(counter_gemm() == 3); a = c; counter_gemm() = 0; a = (b + a*trans(a)*a*3*a)*trans(b); DLIB_TEST(counter_gemm() == 4); a = c; counter_gemm() = 0; a = trans((trans(b) + trans(a)*trans(a)*a*3*a)*trans(b)); DLIB_TEST(counter_gemm() == 4); a = c; counter_gemm() = 0; a = trans((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(b)); DLIB_TEST(counter_gemm() == 4); a = c; counter_gemm() = 0; a = trans((trans(b) + trans(a)*(a + b)*trans(a)*3*a)*trans(b)); DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); a = c; counter_gemm() = 0; a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*a)*trans(b)); DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); a = c; } template void test_gemm_stuff_conj( const matrix_type& c ) const { using namespace dlib; using namespace dlib::blas_bindings; matrix_type b, a; a = c; counter_gemm() = 0; b = a*conj(a); DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = a*trans(conj(a)) + a; DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = conj(trans(a))*trans(a) + a; DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = trans(trans(trans(a)*conj(a) + conj(a))); DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; b = a*a*conj(a)*a; DLIB_TEST(counter_gemm() == 3); b = c; counter_gemm() = 0; a = a*trans(conj(a))*a*a; DLIB_TEST(counter_gemm() == 3); a = c; counter_gemm() = 0; a = (b + a*trans(conj(a))*a*3*a)*trans(b); DLIB_TEST(counter_gemm() == 4); a = c; counter_gemm() = 0; a = (trans((conj(trans(b)) + trans(a)*conj(trans(a))*a*3*a)*trans(b))); DLIB_TEST(counter_gemm() == 4); a = c; counter_gemm() = 0; a = ((trans(b) + trans(a)*(a)*trans(a)*3*a)*trans(conj(b))); DLIB_TEST(counter_gemm() == 4); a = c; counter_gemm() = 0; a = trans((trans(b) + trans(a)*conj(a + b)*trans(a)*3*a)*trans(b)); DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); a = c; counter_gemm() = 0; a = trans((trans(b) + trans(a)*(a*8 + b+b+b+b)*trans(a)*3*conj(a))*trans(b)); DLIB_TEST_MSG(counter_gemm() == 4, counter_gemm()); a = c; } void perform_test ( ) { using namespace dlib; typedef dlib::memory_manager::kernel_1a mm; print_spinner(); dlog << dlib::LINFO << "test double"; { matrix a = randm(4,4); test_gemm_stuff(a); } print_spinner(); dlog << dlib::LINFO << "test float"; { matrix a = matrix_cast(randm(4,4)); test_gemm_stuff(a); } print_spinner(); dlog << dlib::LINFO << "test complex"; { matrix a = matrix_cast(randm(4,4)); matrix b = matrix_cast(randm(4,4)); matrix > c = complex_matrix(a,b); test_gemm_stuff(c); test_gemm_stuff_conj(c); } print_spinner(); dlog << dlib::LINFO << "test complex"; { matrix a = matrix_cast(randm(4,4)); matrix b = matrix_cast(randm(4,4)); matrix > c = complex_matrix(a,b); test_gemm_stuff(c); test_gemm_stuff_conj(c); } print_spinner(); dlog << dlib::LINFO << "test double, column major"; { matrix a = randm(100,100); test_gemm_stuff(a); } print_spinner(); dlog << dlib::LINFO << "test float, column major"; { matrix a = matrix_cast(randm(100,100)); test_gemm_stuff(a); } print_spinner(); dlog << dlib::LINFO << "test complex, column major"; { matrix a = matrix_cast(randm(100,100)); matrix b = matrix_cast(randm(100,100)); matrix,100,100,mm,column_major_layout > c = complex_matrix(a,b); test_gemm_stuff(c); test_gemm_stuff_conj(c); } print_spinner(); dlog << dlib::LINFO << "test complex, column major"; { matrix a = matrix_cast(randm(100,100)); matrix b = matrix_cast(randm(100,100)); matrix,100,100,mm,column_major_layout > c = complex_matrix(a,b); test_gemm_stuff(c); test_gemm_stuff_conj(c); } { using namespace dlib; using namespace dlib::blas_bindings; array2d a(100,100); array2d b(100,100); matrix c; counter_gemm() = 0; c = mat(a)*mat(b); DLIB_TEST(counter_gemm() == 1); counter_gemm() = 0; c = trans(2*mat(a)*mat(b)); DLIB_TEST(counter_gemm() == 1); } { using namespace dlib; using namespace dlib::blas_bindings; array2d a(100,100); array2d b(100,100); matrix aa(100,100); matrix bb(100,100); matrix c; counter_gemm() = 0; c = mat(&a[0][0],100,100)*mat(&b[0][0],100,100); DLIB_TEST(counter_gemm() == 1); set_ptrm(&c(0,0),100,100) = mat(&a[0][0],100,100)*mat(&b[0][0],100,100); DLIB_TEST(counter_gemm() == 2); set_ptrm(&c(0,0),100,100) = aa*bb; DLIB_TEST(counter_gemm() == 3); counter_gemm() = 0; c = trans(2*mat(&a[0][0],100,100)*mat(&b[0][0],100,100)); DLIB_TEST(counter_gemm() == 1); set_ptrm(&c(0,0),100,100) = trans(2*mat(&a[0][0],100,100)*mat(&b[0][0],100,100)); DLIB_TEST(counter_gemm() == 2); set_ptrm(&c(0,0),100,100) = trans(2*mat(a)*mat(b)); DLIB_TEST(counter_gemm() == 3); } print_spinner(); } }; blas_bindings_gemm_tester a; }