summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/dnn/curand_dlibapi.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/dnn/curand_dlibapi.cpp')
-rw-r--r--ml/dlib/dlib/dnn/curand_dlibapi.cpp113
1 files changed, 113 insertions, 0 deletions
diff --git a/ml/dlib/dlib/dnn/curand_dlibapi.cpp b/ml/dlib/dlib/dnn/curand_dlibapi.cpp
new file mode 100644
index 000000000..67828e664
--- /dev/null
+++ b/ml/dlib/dlib/dnn/curand_dlibapi.cpp
@@ -0,0 +1,113 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_DNN_CuRAND_CPP_
+#define DLIB_DNN_CuRAND_CPP_
+
+#ifdef DLIB_USE_CUDA
+
+#include "curand_dlibapi.h"
+#include <curand.h>
+#include "../string.h"
+
+static const char* curand_get_error_string(curandStatus_t s)
+{
+ switch(s)
+ {
+ case CURAND_STATUS_NOT_INITIALIZED:
+ return "CUDA Runtime API initialization failed.";
+ case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
+ return "The requested length must be a multiple of two.";
+ default:
+ return "A call to cuRAND failed";
+ }
+}
+
+// Check the return value of a call to the cuDNN runtime for an error condition.
+#define CHECK_CURAND(call) \
+do{ \
+ const curandStatus_t error = call; \
+ if (error != CURAND_STATUS_SUCCESS) \
+ { \
+ std::ostringstream sout; \
+ sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+ sout << "code: " << error << ", reason: " << curand_get_error_string(error);\
+ throw dlib::curand_error(sout.str()); \
+ } \
+}while(false)
+
+namespace dlib
+{
+ namespace cuda
+ {
+
+ // ----------------------------------------------------------------------------------------
+
+ curand_generator::
+ curand_generator(
+ unsigned long long seed
+ ) : handle(nullptr)
+ {
+ curandGenerator_t gen;
+ CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
+ handle = gen;
+
+ CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, seed));
+ }
+
+ curand_generator::
+ ~curand_generator()
+ {
+ if (handle)
+ {
+ curandDestroyGenerator((curandGenerator_t)handle);
+ }
+ }
+
+ void curand_generator::
+ fill_gaussian (
+ tensor& data,
+ float mean,
+ float stddev
+ )
+ {
+ if (data.size() == 0)
+ return;
+
+ CHECK_CURAND(curandGenerateNormal((curandGenerator_t)handle,
+ data.device(),
+ data.size(),
+ mean,
+ stddev));
+ }
+
+ void curand_generator::
+ fill_uniform (
+ tensor& data
+ )
+ {
+ if (data.size() == 0)
+ return;
+
+ CHECK_CURAND(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size()));
+ }
+
+ void curand_generator::
+ fill (
+ cuda_data_ptr<unsigned int>& data
+ )
+ {
+ if (data.size() == 0)
+ return;
+
+ CHECK_CURAND(curandGenerate((curandGenerator_t)handle, data, data.size()));
+ }
+
+ // -----------------------------------------------------------------------------------
+
+ }
+}
+
+#endif // DLIB_USE_CUDA
+
+#endif // DLIB_DNN_CuRAND_CPP_
+