// Source: ml/dlib/dlib/dnn/curand_dlibapi.cpp (blob 67828e664038b198dd244ac4555894b87ec9dd12)
// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuRAND_CPP_
#define DLIB_DNN_CuRAND_CPP_

#ifdef DLIB_USE_CUDA

#include "curand_dlibapi.h"
#include <curand.h>
#include "../string.h"

// Translate a cuRAND status code into a short human readable description.
// Every return value is a string literal.  Status codes without a dedicated
// message (including CURAND_STATUS_SUCCESS) yield the generic fallback text.
static const char* curand_get_error_string(curandStatus_t s)
{
    switch(s)
    {
        case CURAND_STATUS_VERSION_MISMATCH:
            return "The cuRAND header file and linked library versions do not match.";
        case CURAND_STATUS_NOT_INITIALIZED:
            return "The cuRAND generator was not initialized.";
        case CURAND_STATUS_ALLOCATION_FAILED:
            return "Memory allocation failed.";
        case CURAND_STATUS_TYPE_ERROR:
            return "The generator is of the wrong type for this call.";
        case CURAND_STATUS_OUT_OF_RANGE:
            return "An argument was out of range.";
        case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
            return "The requested length must be a multiple of two.";
        case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
            return "The GPU does not have the double precision support required by this call.";
        case CURAND_STATUS_LAUNCH_FAILURE:
            return "A kernel launch failed.";
        case CURAND_STATUS_PREEXISTING_FAILURE:
            return "A failure was already pending before the cuRAND call was made.";
        case CURAND_STATUS_INITIALIZATION_FAILED:
            return "CUDA Runtime API initialization failed.";
        case CURAND_STATUS_ARCH_MISMATCH:
            return "The GPU architecture does not support the requested feature.";
        case CURAND_STATUS_INTERNAL_ERROR:
            return "An internal cuRAND library error occurred.";
        default:
            return "A call to cuRAND failed";
    }
}

// Check the return value of a call to the cuRAND library for an error condition.
// The call expression is evaluated exactly once.  If it does not yield
// CURAND_STATUS_SUCCESS, a dlib::curand_error is thrown whose message contains
// the text of the call, the file and line of the invocation, the numeric status
// code, and the description given by curand_get_error_string().
// The local names carry a dlib_curand_ prefix so they cannot capture an
// identifier used inside the call expression, and the argument is parenthesized
// so it is always treated as a single expression.
#define CHECK_CURAND(call)                                                      \
do{                                                                             \
    const curandStatus_t dlib_curand_status_ = (call);                          \
    if (dlib_curand_status_ != CURAND_STATUS_SUCCESS)                           \
    {                                                                           \
        std::ostringstream dlib_curand_sout_;                                   \
        dlib_curand_sout_ << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
        dlib_curand_sout_ << "code: " << dlib_curand_status_ << ", reason: " << curand_get_error_string(dlib_curand_status_);\
        throw dlib::curand_error(dlib_curand_sout_.str());                      \
    }                                                                           \
}while(false)

namespace dlib
{
    namespace cuda 
    {

    // ----------------------------------------------------------------------------------------

        // Creates a cuRAND generator of type CURAND_RNG_PSEUDO_DEFAULT and seeds
        // it with the given seed.  Throws dlib::curand_error if either cuRAND
        // call fails; in that case no generator is left allocated and handle
        // remains nullptr.
        curand_generator::
        curand_generator(
            unsigned long long seed
        ) : handle(nullptr)
        {
            curandGenerator_t gen;
            CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));

            // A destructor does not run for an object whose constructor throws,
            // so if seeding fails the freshly created generator has to be
            // released here or it would leak.
            try
            {
                CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, seed));
            }
            catch (...)
            {
                curandDestroyGenerator(gen);
                throw;
            }

            // Only take ownership once the generator is fully set up.
            handle = gen;
        }

        // Destroys the cuRAND generator referred to by handle, if there is one.
        // The status returned by curandDestroyGenerator() is not inspected.
        curand_generator::
        ~curand_generator()
        {
            // Nothing to release when no generator was ever stored.
            if (!handle)
                return;

            curandDestroyGenerator((curandGenerator_t)handle);
        }

        // Asks cuRAND to write data.size() normally distributed floats, with the
        // given mean and standard deviation, into the memory returned by
        // data.device().  An empty tensor is left untouched and cuRAND is not
        // called.  Throws dlib::curand_error if the cuRAND call fails.
        // NOTE(review): cuRAND documents CURAND_STATUS_LENGTH_NOT_MULTIPLE for odd
        // lengths with pseudo-random generators, so callers presumably need to
        // pass a tensor with an even number of elements -- confirm.
        void curand_generator::
        fill_gaussian (
            tensor& data,
            float mean,
            float stddev
        )
        {
            if (data.size() != 0)
            {
                CHECK_CURAND(curandGenerateNormal((curandGenerator_t)handle, data.device(), data.size(), mean, stddev));
            }
        }

        // Asks cuRAND to write data.size() uniformly distributed floats into the
        // memory returned by data.device().  An empty tensor is left untouched
        // and cuRAND is not called.  Throws dlib::curand_error if the cuRAND
        // call fails.
        void curand_generator::
        fill_uniform (
            tensor& data
        )
        {
            if (data.size() != 0)
            {
                CHECK_CURAND(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size()));
            }
        }

        // Asks cuRAND to write data.size() random unsigned integers into the
        // buffer referred to by data.  An empty buffer is left untouched and
        // cuRAND is not called.  Throws dlib::curand_error if the cuRAND call
        // fails.
        void curand_generator::
        fill (
            cuda_data_ptr<unsigned int>& data
        )
        {
            if (data.size() != 0)
            {
                CHECK_CURAND(curandGenerate((curandGenerator_t)handle, data, data.size()));
            }
        }

    // -----------------------------------------------------------------------------------

    }  
}

#endif // DLIB_USE_CUDA

#endif // DLIB_DNN_CuRAND_CPP_