diff options
Diffstat (limited to 'ml/dlib/examples/custom_trainer_ex.cpp')
-rw-r--r-- | ml/dlib/examples/custom_trainer_ex.cpp | 277 |
1 file changed, 277 insertions, 0 deletions
diff --git a/ml/dlib/examples/custom_trainer_ex.cpp b/ml/dlib/examples/custom_trainer_ex.cpp new file mode 100644 index 000000000..39af53f39 --- /dev/null +++ b/ml/dlib/examples/custom_trainer_ex.cpp @@ -0,0 +1,277 @@ +// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt +/* + This example program shows you how to create your own custom binary classification + trainer object and use it with the multiclass classification tools in the dlib C++ + library. This example assumes you have already become familiar with the concepts + introduced in the multiclass_classification_ex.cpp example program. + + + In this example we will create a very simple trainer object that takes a binary + classification problem and produces a decision rule which says a test point has the + same class as whichever centroid it is closest to. + + The multiclass training dataset will consist of four classes. Each class will be a blob + of points in one of the quadrants of the cartesian plane. For fun, we will use + std::string labels and therefore the labels of these classes will be the following: + "upper_left", + "upper_right", + "lower_left", + "lower_right" +*/ + +#include <dlib/svm_threaded.h> + +#include <iostream> +#include <vector> + +#include <dlib/rand.h> + +using namespace std; +using namespace dlib; + +// Our data will be 2-dimensional data. So declare an appropriate type to contain these points. +typedef matrix<double,2,1> sample_type; + +// ---------------------------------------------------------------------------------------- + +struct custom_decision_function +{ + /*! + WHAT THIS OBJECT REPRESENTS + This object is the representation of our binary decision rule. 
+ !*/ + + // centers of the two classes + sample_type positive_center, negative_center; + + double operator() ( + const sample_type& x + ) const + { + // if x is closer to the positive class then return +1 + if (length(positive_center - x) < length(negative_center - x)) + return +1; + else + return -1; + } +}; + +// Later on in this example we will save our decision functions to disk. This +// pair of routines is needed for this functionality. +void serialize (const custom_decision_function& item, std::ostream& out) +{ + // write the state of item to the output stream + serialize(item.positive_center, out); + serialize(item.negative_center, out); +} + +void deserialize (custom_decision_function& item, std::istream& in) +{ + // read the data from the input stream and store it in item + deserialize(item.positive_center, in); + deserialize(item.negative_center, in); +} + +// ---------------------------------------------------------------------------------------- + +class simple_custom_trainer +{ + /*! + WHAT THIS OBJECT REPRESENTS + This is our example custom binary classifier trainer object. It simply + computes the means of the +1 and -1 classes, puts them into our + custom_decision_function, and returns the results. + + Below we define the train() function. I have also included the + requires/ensures definition for a generic binary classifier's train() + !*/ +public: + + + custom_decision_function train ( + const std::vector<sample_type>& samples, + const std::vector<double>& labels + ) const + /*! + requires + - is_binary_classification_problem(samples, labels) == true + (e.g. 
labels consists of only +1 and -1 values, samples.size() == labels.size()) + ensures + - returns a decision function F with the following properties: + - if (new_x is a sample predicted have +1 label) then + - F(new_x) >= 0 + - else + - F(new_x) < 0 + !*/ + { + sample_type positive_center, negative_center; + + // compute sums of each class + positive_center = 0; + negative_center = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (labels[i] == +1) + positive_center += samples[i]; + else // this is a -1 sample + negative_center += samples[i]; + } + + // divide by number of +1 samples + positive_center /= sum(mat(labels) == +1); + // divide by number of -1 samples + negative_center /= sum(mat(labels) == -1); + + custom_decision_function df; + df.positive_center = positive_center; + df.negative_center = negative_center; + + return df; + } +}; + +// ---------------------------------------------------------------------------------------- + +void generate_data ( + std::vector<sample_type>& samples, + std::vector<string>& labels +); +/*! + ensures + - make some four class data as described above. + - each class will have 50 samples in it +!*/ + +// ---------------------------------------------------------------------------------------- + +int main() +{ + std::vector<sample_type> samples; + std::vector<string> labels; + + // First, get our labeled set of training data + generate_data(samples, labels); + + cout << "samples.size(): "<< samples.size() << endl; + + // Define the trainer we will use. The second template argument specifies the type + // of label used, which is string in this case. + typedef one_vs_one_trainer<any_trainer<sample_type>, string> ovo_trainer; + + + ovo_trainer trainer; + + // Now tell the one_vs_one_trainer that, by default, it should use the simple_custom_trainer + // to solve the individual binary classification subproblems. 
+ trainer.set_trainer(simple_custom_trainer()); + + // Next, to make things a little more interesting, we will setup the one_vs_one_trainer + // to use kernel ridge regression to solve the upper_left vs lower_right binary classification + // subproblem. + typedef radial_basis_kernel<sample_type> rbf_kernel; + krr_trainer<rbf_kernel> rbf_trainer; + rbf_trainer.set_kernel(rbf_kernel(0.1)); + trainer.set_trainer(rbf_trainer, "upper_left", "lower_right"); + + + // Now let's do 5-fold cross-validation using the one_vs_one_trainer we just setup. + // As an aside, always shuffle the order of the samples before doing cross validation. + // For a discussion of why this is a good idea see the svm_ex.cpp example. + randomize_samples(samples, labels); + cout << "cross validation: \n" << cross_validate_multiclass_trainer(trainer, samples, labels, 5) << endl; + // This dataset is very easy and everything is correctly classified. Therefore, the output of + // cross validation is the following confusion matrix. + /* + 50 0 0 0 + 0 50 0 0 + 0 0 50 0 + 0 0 0 50 + */ + + + // We can also obtain the decision rule as always. + one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels); + + cout << "predicted label: "<< df(samples[0]) << ", true label: "<< labels[0] << endl; + cout << "predicted label: "<< df(samples[90]) << ", true label: "<< labels[90] << endl; + // The output is: + /* + predicted label: upper_right, true label: upper_right + predicted label: lower_left, true label: lower_left + */ + + + // Finally, let's save our multiclass decision rule to disk. Remember that we have + // to specify the types of binary decision function used inside the one_vs_one_decision_function. 
+ one_vs_one_decision_function<ovo_trainer, + custom_decision_function, // This is the output of the simple_custom_trainer + decision_function<radial_basis_kernel<sample_type> > // This is the output of the rbf_trainer + > df2, df3; + + df2 = df; + // save to a file called df.dat + serialize("df.dat") << df2; + + // load the function back in from disk and store it in df3. + deserialize("df.dat") >> df3; + + + // Test df3 to see that this worked. + cout << endl; + cout << "predicted label: "<< df3(samples[0]) << ", true label: "<< labels[0] << endl; + cout << "predicted label: "<< df3(samples[90]) << ", true label: "<< labels[90] << endl; + // Test df3 on the samples and labels and print the confusion matrix. + cout << "test deserialized function: \n" << test_multiclass_decision_function(df3, samples, labels) << endl; + +} + +// ---------------------------------------------------------------------------------------- + +void generate_data ( + std::vector<sample_type>& samples, + std::vector<string>& labels +) +{ + const long num = 50; + + sample_type m; + + dlib::rand rnd; + + + // add some points in the upper right quadrant + m = 10, 10; + for (long i = 0; i < num; ++i) + { + samples.push_back(m + randm(2,1,rnd)); + labels.push_back("upper_right"); + } + + // add some points in the upper left quadrant + m = -10, 10; + for (long i = 0; i < num; ++i) + { + samples.push_back(m + randm(2,1,rnd)); + labels.push_back("upper_left"); + } + + // add some points in the lower right quadrant + m = 10, -10; + for (long i = 0; i < num; ++i) + { + samples.push_back(m + randm(2,1,rnd)); + labels.push_back("lower_right"); + } + + // add some points in the lower left quadrant + m = -10, -10; + for (long i = 0; i < num; ++i) + { + samples.push_back(m + randm(2,1,rnd)); + labels.push_back("lower_left"); + } + +} + +// ---------------------------------------------------------------------------------------- + |