Diffstat (limited to 'ml/dlib/dlib/dnn/cublas_dlibapi.cpp')
-rw-r--r--   ml/dlib/dlib/dnn/cublas_dlibapi.cpp   165
1 file changed, 165 insertions(+), 0 deletions(-)
diff --git a/ml/dlib/dlib/dnn/cublas_dlibapi.cpp b/ml/dlib/dlib/dnn/cublas_dlibapi.cpp
new file mode 100644
index 000000000..376cc9f00
--- /dev/null
+++ b/ml/dlib/dlib/dnn/cublas_dlibapi.cpp
@@ -0,0 +1,165 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuBLAS_CPP_
+#define DLIB_DNN_CuBLAS_CPP_
+
+#ifdef DLIB_USE_CUDA
+
+#include "cublas_dlibapi.h"
+#include "cuda_utils.h"
+
+#include <cublas_v2.h>
+#include <vector>
+
+static const char* cublas_get_error_string(cublasStatus_t s)
+{
+    switch(s)
+    {
+        case CUBLAS_STATUS_NOT_INITIALIZED:
+            return "CUDA Runtime API initialization failed.";
+        case CUBLAS_STATUS_ALLOC_FAILED:
+            return "CUDA Resources could not be allocated.";
+        default:
+            return "A call to cuBLAS failed";
+    }
+}
+
+// Check the return value of a call to the cuBLAS runtime for an error condition.
+#define CHECK_CUBLAS(call)                                                     \
+do{                                                                            \
+    const cublasStatus_t error = call;                                         \
+    if (error != CUBLAS_STATUS_SUCCESS)                                        \
+    {                                                                          \
+        std::ostringstream sout;                                               \
+        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+        sout << "code: " << error << ", reason: " << cublas_get_error_string(error);\
+        throw dlib::cublas_error(sout.str());                                  \
+    }                                                                          \
+}while(false)
+
+namespace dlib
+{
+    namespace cuda
+    {
+
+    // -----------------------------------------------------------------------------------
+
+        class cublas_context
+        {
+        public:
+            // not copyable
+            cublas_context(const cublas_context&) = delete;
+            cublas_context& operator=(const cublas_context&) = delete;
+
+            cublas_context()
+            {
+                handles.resize(16);
+            }
+            ~cublas_context()
+            {
+                for (auto h : handles)
+                {
+                    if (h)
+                        cublasDestroy(h);
+                }
+            }
+
+            cublasHandle_t get_handle (
+            )
+            {
+                int new_device_id;
+                CHECK_CUDA(cudaGetDevice(&new_device_id));
+                // make room for more devices if needed
+                if (new_device_id >= (long)handles.size())
+                    handles.resize(new_device_id+16);
+
+                // If we don't have a handle already for this device then make one
+                if (!handles[new_device_id])
+                    CHECK_CUBLAS(cublasCreate(&handles[new_device_id]));
+
+                // Finally, return the handle for the current device
+                return handles[new_device_id];
+            }
+
+        private:
+
+            std::vector<cublasHandle_t> handles;
+        };
+
+        static cublasHandle_t context()
+        {
+            thread_local cublas_context c;
+            return c.get_handle();
+        }
+
+    // -----------------------------------------------------------------------------------
+
+        void gemm (
+            float beta,
+            tensor& dest,
+            float alpha,
+            const tensor& lhs,
+            bool trans_lhs,
+            const tensor& rhs,
+            bool trans_rhs
+        )
+        {
+            // Recall that BLAS uses column major order so to deal with that we flip the
+            // order of the lhs and rhs arguments.
+            const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
+            const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+            const int dest_nr = dest.num_samples();
+            const int dest_nc = dest.size()/dest_nr;
+            const int lhs_nr = lhs.num_samples();
+            const int lhs_nc = lhs.size()/lhs_nr;
+            const int rhs_nr = rhs.num_samples();
+            const int rhs_nc = rhs.size()/rhs_nr;
+            if (trans_lhs && trans_rhs)
+            {
+                DLIB_ASSERT( dest_nr == lhs_nc &&
+                             dest_nc == rhs_nr &&
+                             lhs_nr == rhs_nc)
+            }
+            else if (!trans_lhs && trans_rhs)
+            {
+                DLIB_ASSERT( dest_nr == lhs_nr &&
+                             dest_nc == rhs_nr &&
+                             lhs_nc == rhs_nc)
+            }
+            else if (trans_lhs && !trans_rhs)
+            {
+                DLIB_ASSERT( dest_nr == lhs_nc &&
+                             dest_nc == rhs_nc &&
+                             lhs_nr == rhs_nr)
+            }
+            else
+            {
+                DLIB_ASSERT( dest_nr == lhs_nr &&
+                             dest_nc == rhs_nc &&
+                             lhs_nc == rhs_nr)
+            }
+
+            const int k = trans_rhs ? rhs_nc : rhs_nr;
+            CHECK_CUBLAS(cublasSgemm(context(),
+                                     transb,
+                                     transa,
+                                     dest_nc, dest_nr, k,
+                                     &alpha,
+                                     rhs.device(), rhs_nc,
+                                     lhs.device(), lhs_nc,
+                                     &beta,
+                                     dest.device(), dest_nc));
+        }
+
+    // ------------------------------------------------------------------------------------
+
+    }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuBLAS_CPP_
+
+
+
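Note on the argument swap in gemm(): dlib tensors are stored in row major order while cuBLAS expects column major. A row-major M x N buffer, reinterpreted as column-major data with leading dimension N, is the transpose of that matrix, so computing dest = alpha*lhs*rhs + beta*dest on row-major buffers is the same as asking cuBLAS for dest^T = alpha*rhs^T*lhs^T + beta*dest^T. That is why rhs is passed before lhs and why dest_nc comes before dest_nr in the cublasSgemm call. Below is a minimal CPU sketch, not part of dlib, that demonstrates this: the helper cm_gemm and the 2x3/3x2 sizes are made up for illustration, and the call mirrors the wrapper's swapped argument order against a plain column-major GEMM, then checks that the row-major result equals lhs*rhs.

// cm_gemm: plain column-major GEMM, C = alpha*A*B + beta*C, with A m x k,
// B k x n, C m x n, each stored column by column with the given leading dimension.
#include <cassert>
#include <vector>

static void cm_gemm(int m, int n, int k, float alpha,
                    const float* A, int lda,
                    const float* B, int ldb,
                    float beta, float* C, int ldc)
{
    for (int j = 0; j < n; ++j)
        for (int i = 0; i < m; ++i)
        {
            float sum = 0;
            for (int p = 0; p < k; ++p)
                sum += A[i + p*lda] * B[p + j*ldb];
            C[i + j*ldc] = alpha*sum + beta*C[i + j*ldc];
        }
}

int main()
{
    // Row-major buffers: lhs is 2x3, rhs is 3x2, dest is 2x2.
    std::vector<float> lhs  = {1,2,3, 4,5,6};
    std::vector<float> rhs  = {7,8, 9,10, 11,12};
    std::vector<float> dest = {0,0, 0,0};
    const int lhs_nc = 3, rhs_nr = 3, rhs_nc = 2;
    const int dest_nr = 2, dest_nc = 2;
    const int k = rhs_nr; // inner dimension, equals lhs_nc here

    // Same argument order as the wrapper above: rhs before lhs, dest_nc before
    // dest_nr, and each leading dimension is the row-major row length.
    cm_gemm(dest_nc, dest_nr, k, 1.0f,
            rhs.data(), rhs_nc,
            lhs.data(), lhs_nc,
            0.0f, dest.data(), dest_nc);

    // dest, read back in row major order, now holds lhs*rhs = [[58,64],[139,154]].
    assert(dest[0] == 58  && dest[1] == 64);
    assert(dest[2] == 139 && dest[3] == 154);
    return 0;
}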
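For completeness, a hedged usage sketch of the wrapper itself. It assumes dlib is built with DLIB_USE_CUDA and a CUDA device is available; dlib::cuda::gemm is an internal helper used by the dnn layers rather than a documented public API, and I am assuming its declaration is reachable through <dlib/dnn.h> (via tensor_tools.h), which may vary between dlib versions.

// Hedged sketch: multiply two small row-major matrices through dlib::cuda::gemm.
// Requires a dlib build with DLIB_USE_CUDA; error handling is omitted.
#include <dlib/dnn.h>   // assumed to declare dlib::cuda::gemm when CUDA support is enabled
#include <iostream>

int main()
{
    // num_samples x k tensors: lhs is 2x3, rhs is 3x2, dest is 2x2.
    dlib::resizable_tensor lhs(2,3), rhs(3,2), dest(2,2);

    // Fill on the host; the data is copied to the device the first time
    // device() is called inside gemm().
    float* l = lhs.host(); for (int i = 0; i < 6; ++i) l[i] = i + 1;  // 1..6
    float* r = rhs.host(); for (int i = 0; i < 6; ++i) r[i] = i + 7;  // 7..12
    dest = 0;

    // dest = 1*lhs*rhs + 0*dest
    dlib::cuda::gemm(0, dest, 1, lhs, false, rhs, false);

    const float* d = dest.host();  // copies the result back to the host
    std::cout << d[0] << ' ' << d[1] << '\n'    // 58 64
              << d[2] << ' ' << d[3] << '\n';   // 139 154
    return 0;
}

In practice dlib's layers reach this code through the dlib::tt::gemm dispatcher, which falls back to a CPU implementation when CUDA support is disabled.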