summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/rbf_network.h
blob: 23a2c742485c62e1b32bbd48d2eb617de15ab3d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Copyright (C) 2008  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_RBf_NETWORK_
#define DLIB_RBf_NETWORK_

#include "../matrix.h"
#include "rbf_network_abstract.h"
#include "kernel.h"
#include "linearly_independent_subset_finder.h"
#include "function.h"
#include "../algs.h"

namespace dlib
{

// ------------------------------------------------------------------------------

    template <
        typename Kern 
        >
    class rbf_network_trainer 
    {
        /*!
            This is an implementation of an RBF network trainer that follows
            the directions right off Wikipedia basically.  So nothing 
            particularly fancy.  Although the way the centers are selected
            is somewhat unique.
        !*/

    public:
        typedef Kern kernel_type;
        typedef typename kernel_type::scalar_type scalar_type;
        typedef typename kernel_type::sample_type sample_type;
        typedef typename kernel_type::mem_manager_type mem_manager_type;
        typedef decision_function<kernel_type> trained_function_type;

        rbf_network_trainer (
        ) :
            num_centers(10)
        {
        }

        void set_kernel (
            const kernel_type& k
        )
        {
            kernel = k;
        }

        const kernel_type& get_kernel (
        ) const
        {
            return kernel;
        }

        void set_num_centers (
            const unsigned long num 
        )
        {
            num_centers = num;
        }

        unsigned long get_num_centers (
        ) const
        {
            return num_centers;
        }

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> train (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y
        ) const
        {
            return do_train(mat(x), mat(y));
        }

        void swap (
            rbf_network_trainer& item
        )
        {
            exchange(kernel, item.kernel);
            exchange(num_centers, item.num_centers);
        }

    private:

    // ------------------------------------------------------------------------------------

        template <
            typename in_sample_vector_type,
            typename in_scalar_vector_type
            >
        const decision_function<kernel_type> do_train (
            const in_sample_vector_type& x,
            const in_scalar_vector_type& y
        ) const
        {
            typedef typename decision_function<kernel_type>::scalar_vector_type scalar_vector_type;

            // make sure requires clause is not broken
            DLIB_ASSERT(is_learning_problem(x,y),
                "\tdecision_function rbf_network_trainer::train(x,y)"
                << "\n\t invalid inputs were given to this function"
                << "\n\t x.nr(): " << x.nr() 
                << "\n\t y.nr(): " << y.nr() 
                << "\n\t x.nc(): " << x.nc() 
                << "\n\t y.nc(): " << y.nc() 
                );

            // use the linearly_independent_subset_finder object to select the centers.  So here
            // we show it all the data samples so it can find the best centers.
            linearly_independent_subset_finder<kernel_type> lisf(kernel, num_centers);
            fill_lisf(lisf, x);

            const long num_centers = lisf.size();

            // fill the K matrix with the output of the kernel for all the center and sample point pairs
            matrix<scalar_type,0,0,mem_manager_type> K(x.nr(), num_centers+1);
            for (long r = 0; r < x.nr(); ++r)
            {
                for (long c = 0; c < num_centers; ++c)
                {
                    K(r,c) = kernel(x(r), lisf[c]);
                }
                // This last column of the K matrix takes care of the bias term
                K(r,num_centers) = 1;
            }

            // compute the best weights by using the pseudo-inverse
            scalar_vector_type weights(pinv(K)*y);

            // now put everything into a decision_function object and return it
            return decision_function<kernel_type> (remove_row(weights,num_centers),
                                                   -weights(num_centers),
                                                   kernel,
                                                   lisf.get_dictionary());

        }

        kernel_type kernel;
        unsigned long num_centers;

    }; // end of class rbf_network_trainer 

// ----------------------------------------------------------------------------------------

    template <typename sample_type>
    void swap (
        rbf_network_trainer<sample_type>& a,
        rbf_network_trainer<sample_type>& b
    ) { a.swap(b); }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_RBf_NETWORK_