summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/simplify_linear_decision_function.h
blob: 4f5bef6f393d2da3a478c5a8e0703f99e0521c22 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_Hh_
#define DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_Hh_

#include "simplify_linear_decision_function_abstract.h"
#include "../algs.h"
#include "function.h"
#include "sparse_kernel.h"
#include "kernel.h"
#include <map>
#include <vector>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename T
        >
    decision_function<sparse_linear_kernel<T> > simplify_linear_decision_function (
        const decision_function<sparse_linear_kernel<T> >& df
    )
    {
        // Collapses a sparse linear decision function into an equivalent one that has
        // exactly one basis vector: the alpha-weighted sum of all of df's basis vectors.
        // The returned function has alpha(0) == 1 and the same bias b as df.  A df with
        // zero or one basis vectors is returned unchanged.
        typedef decision_function<sparse_linear_kernel<T> > df_type;

        // Zero or one basis vectors is already as simple as it gets.
        if (df.basis_vectors.size() < 2)
            return df;

        // Element type of the sparse vector T is a (index, value) pair.
        typedef typename T::value_type              element_type;
        typedef typename element_type::first_type   index_type;
        typedef typename element_type::second_type  weight_type;

        // Accumulate sum_k alpha(k)*basis_vectors(k), keyed by sparse index.  The map
        // keeps the indices sorted, so T can be built straight from its iterator range.
        std::map<index_type, weight_type> weights;
        for (long k = 0; k < df.basis_vectors.size(); ++k)
        {
            const T& vect = df.basis_vectors(k);
            for (typename T::const_iterator it = vect.begin(); it != vect.end(); ++it)
                weights[it->first] += df.alpha(k) * (it->second);
        }

        df_type simplified;
        simplified.b = df.b;
        simplified.alpha.set_size(1);
        simplified.alpha(0) = 1;
        simplified.basis_vectors.set_size(1);
        simplified.basis_vectors(0) = T(weights.begin(), weights.end());

        return simplified;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename T
        >
    decision_function<linear_kernel<T> > simplify_linear_decision_function (
        const decision_function<linear_kernel<T> >& df
    )
    {
        // Collapses a dense linear decision function into an equivalent one that has
        // exactly one basis vector: w = sum_i alpha(i)*basis_vectors(i).  The returned
        // function has alpha(0) == 1 and the same bias b as df.  A df with zero or one
        // basis vectors is returned unchanged (so its alpha(0) need not be 1).

        // don't do anything if we don't have to
        if (df.basis_vectors.size() <= 1)
            return df;

        decision_function<linear_kernel<T> > new_df;

        new_df.b = df.b;
        new_df.basis_vectors.set_size(1);
        new_df.alpha.set_size(1);
        new_df.alpha(0) = 1;

        // Compute the weighted sum of all the basis_vectors in df.  Seed the accumulator
        // with the first weighted term instead of assigning 0.  For a dynamically sized T
        // the default-constructed basis vector is empty, so "= 0" would not zero anything
        // and the sum would rely on += into an empty matrix.  Seeding gives the
        // accumulator the right dimensions for any T.
        new_df.basis_vectors(0) = df.alpha(0) * df.basis_vectors(0);
        for (long i = 1; i < df.basis_vectors.size(); ++i)
        {
            new_df.basis_vectors(0) += df.alpha(i) * df.basis_vectors(i);
        }

        return new_df;
    }

// ----------------------------------------------------------------------------------------

    template <
        typename T
        >
    decision_function<linear_kernel<T> > simplify_linear_decision_function (
        const normalized_function<decision_function<linear_kernel<T> >, vector_normalizer<T> >& df
    )
    {
        // Converts a linear decision function that is applied to normalized inputs into
        // an equivalent single-basis-vector decision function that takes raw inputs.  The
        // normalization is folded into the weight vector and the bias.
        //
        // NOTE(review): the math below assumes the normalizer computes
        // pointwise_multiply(x - means(), std_devs()) and that a decision_function
        // evaluates to sum_i alpha(i)*dot(basis_vectors(i), x) - b.  This matches the
        // sign of the bias adjustment used here; confirm against vector_normalizer and
        // decision_function.
        decision_function<linear_kernel<T> > new_df = simplify_linear_decision_function(df.function);

        // An empty decision function outputs -b whatever the input, so there is nothing
        // to fold the normalizer into.  Indexing basis_vectors(0) would be out of range.
        if (new_df.basis_vectors.size() == 0)
            return new_df;

        // Now incorporate the normalization stuff into new_df:
        //   alpha*dot(w, (x-m).*sd) - b == alpha*dot(w.*sd, x) - (b + alpha*dot(w.*sd, m))
        new_df.basis_vectors(0) = pointwise_multiply(new_df.basis_vectors(0), df.normalizer.std_devs());

        // A df with a single basis vector comes back from the simplify call unchanged, so
        // alpha(0) need not be 1 here.  The bias correction must be scaled by it.
        new_df.b += new_df.alpha(0) * dot(new_df.basis_vectors(0), df.normalizer.means());

        return new_df;
    }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_SIMPLIFY_LINEAR_DECiSION_FUNCTION_Hh_