// Source path: root/ml/dlib/dlib/svm/structural_sequence_labeling_trainer_abstract.h
// Blob: 43e5f5131a960976df2cab84f7a0a75d9c643f61
// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#undef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_ABSTRACT_Hh_
#ifdef DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_ABSTRACT_Hh_

#include "../algs.h"
#include "../optimization.h"
#include "structural_svm_sequence_labeling_problem_abstract.h"
#include "sequence_labeler_abstract.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename feature_extractor
        >
    class structural_sequence_labeling_trainer
    {
        /*!
            REQUIREMENTS ON feature_extractor
                It must be an object that implements an interface compatible with 
                the example_feature_extractor defined in dlib/svm/sequence_labeler_abstract.h.

            WHAT THIS OBJECT REPRESENTS
                This object is a tool for learning to do sequence labeling based
                on a set of training data.  The training procedure produces a
                sequence_labeler object which can be used to predict the labels of
                new data sequences.

                Note that this is just a convenience wrapper around the 
                structural_svm_sequence_labeling_problem to make it look 
                similar to all the other trainers in dlib.  
        !*/

    public:
        // The type of a single input sequence, as defined by the feature_extractor.
        typedef typename feature_extractor::sequence_type sample_sequence_type;
        // The labels for one input sequence: one unsigned long label per sequence
        // element, each of which must be < num_labels() when given to train().
        typedef std::vector<unsigned long> labeled_sequence_type;
        // The type of the object produced by train().
        typedef sequence_labeler<feature_extractor> trained_function_type;

        structural_sequence_labeling_trainer (
        );
        /*!
            ensures
                - #get_c() == 100
                - this object isn't verbose
                - #get_epsilon() == 0.1
                - #get_max_iterations() == 10000
                - #get_num_threads() == 2
                - #get_max_cache_size() == 5
                - #get_feature_extractor() == a default initialized feature_extractor
        !*/

        explicit structural_sequence_labeling_trainer (
            const feature_extractor& fe
        );
        /*!
            ensures
                - #get_c() == 100
                - this object isn't verbose
                - #get_epsilon() == 0.1
                - #get_max_iterations() == 10000
                - #get_num_threads() == 2
                - #get_max_cache_size() == 5
                - #get_feature_extractor() == fe 
        !*/

        const feature_extractor& get_feature_extractor (
        ) const;
        /*!
            ensures
                - returns the feature extractor used by this object
        !*/

        unsigned long num_labels (
        ) const; 
        /*!
            ensures
                - returns get_feature_extractor().num_labels()
                  (i.e. returns the number of possible output labels for each 
                  element of a sequence)
        !*/

        void set_num_threads (
            unsigned long num
        );
        /*!
            ensures
                - #get_num_threads() == num
        !*/

        unsigned long get_num_threads (
        ) const;
        /*!
            ensures
                - returns the number of threads used during training.  You should 
                  usually set this equal to the number of processing cores on your
                  machine.
        !*/

        void set_epsilon (
            double eps
        );
        /*!
            requires
                - eps > 0
            ensures
                - #get_epsilon() == eps
        !*/

        const double get_epsilon (
        ) const;
        /*!
            ensures
                - returns the error epsilon that determines when training should stop.
                  Smaller values may result in a more accurate solution but take longer 
                  to train.  You can think of this epsilon value as saying "solve the 
                  optimization problem until the average number of labeling mistakes per 
                  training sample is within epsilon of its optimal value".
        !*/

        void set_max_iterations (
            unsigned long max_iter
        );
        /*!
            ensures
                - #get_max_iterations() == max_iter
        !*/

        unsigned long get_max_iterations (
        ) const; 
        /*!
            ensures
                - returns the maximum number of iterations the SVM optimizer is allowed to
                  run before it is required to stop and return a result.
        !*/

        void set_max_cache_size (
            unsigned long max_size
        );
        /*!
            ensures
                - #get_max_cache_size() == max_size
        !*/

        unsigned long get_max_cache_size (
        ) const;
        /*!
            ensures
                - During training, this object basically runs the sequence_labeler on 
                  each training sample, over and over.  To speed this up, it is possible to 
                  cache the results of these labeler invocations.  This function returns the 
                  number of cache elements per training sample kept in the cache.  Note 
                  that a value of 0 means caching is not used at all.  
        !*/

        void be_verbose (
        );
        /*!
            ensures
                - This object will print status messages to standard out so that a 
                  user can observe the progress of the algorithm.
        !*/

        void be_quiet (
        );
        /*!
            ensures
                - this object will not print anything to standard out
        !*/

        void set_oca (
            const oca& item
        );
        /*!
            ensures
                - #get_oca() == item 
        !*/

        const oca get_oca (
        ) const;
        /*!
            ensures
                - returns a copy of the optimizer used to solve the structural SVM problem.  
        !*/

        void set_c (
            double C
        );
        /*!
            requires
                - C > 0
            ensures
                - #get_c() == C
        !*/

        double get_c (
        ) const;
        /*!
            ensures
                - returns the SVM regularization parameter.  It is the parameter 
                  that determines the trade-off between trying to fit the training 
                  data (i.e. minimize the loss) or allowing more errors but hopefully 
                  improving the generalization of the resulting sequence labeler.  Larger 
                  values encourage exact fitting while smaller values of C may encourage 
                  better generalization. 
        !*/

        double get_loss (
            unsigned long label
        ) const;
        /*!
            requires
                - label < num_labels()
            ensures
                - returns the loss incurred when a sequence element with the given
                  label is misclassified.  This value controls how much we care about
                  correctly classifying this type of label.  Larger loss values indicate
                  that we care more strongly than smaller values.
        !*/

        void set_loss (
            unsigned long label,
            double value
        );
        /*!
            requires
                - label < num_labels()
                - value >= 0
            ensures
                - #get_loss(label) == value
        !*/

        const sequence_labeler<feature_extractor> train(
            const std::vector<sample_sequence_type>& x,
            const std::vector<labeled_sequence_type>& y
        ) const;
        /*!
            requires
                - is_sequence_labeling_problem(x, y) == true
                - contains_invalid_labeling(get_feature_extractor(), x, y) == false
                - for all valid i and j: y[i][j] < num_labels()
            ensures
                - Uses the structural_svm_sequence_labeling_problem to train a 
                  sequence_labeler on the given x/y training pairs.  The idea is 
                  to learn to predict a y given an input x.
                - returns a function F with the following properties:
                    - F(new_x) == A sequence of predicted labels for the elements of new_x.  
                    - F(new_x).size() == new_x.size()
                    - for all valid i:
                        - F(new_x)[i] == the predicted label of new_x[i]
        !*/

    };

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_STRUCTURAL_SEQUENCE_LABELING_TRAiNER_ABSTRACT_Hh_