diff options
Diffstat (limited to 'ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h')
-rw-r--r-- | ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h | 208 |
1 files changed, 208 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h new file mode 100644 index 000000000..be8fa3f3f --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h @@ -0,0 +1,208 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_ +#define DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_ + +#include <vector> +#include "../matrix.h" +#include "cross_validate_multiclass_trainer_abstract.h" +#include <sstream> + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename dec_funct_type, + typename sample_type, + typename label_type + > + const matrix<double> test_multiclass_decision_function ( + const dec_funct_type& dec_funct, + const std::vector<sample_type>& x_test, + const std::vector<label_type>& y_test + ) + { + + // make sure requires clause is not broken + DLIB_ASSERT( is_learning_problem(x_test,y_test) == true, + "\tmatrix test_multiclass_decision_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_learning_problem(x_test,y_test): " + << is_learning_problem(x_test,y_test)); + + + const std::vector<label_type> all_labels = dec_funct.get_labels(); + + // make a lookup table that maps from labels to their index in all_labels + std::map<label_type,unsigned long> label_to_int; + for (unsigned long i = 0; i < all_labels.size(); ++i) + label_to_int[all_labels[i]] = i; + + matrix<double, 0, 0, typename dec_funct_type::mem_manager_type> res; + res.set_size(all_labels.size(), all_labels.size()); + + res = 0; + + typename std::map<label_type,unsigned long>::const_iterator iter; + + // now test this trained object + for (unsigned long i = 0; i < x_test.size(); ++i) + { + iter = label_to_int.find(y_test[i]); + // ignore samples with labels that the decision function doesn't know about. + if (iter == label_to_int.end()) + continue; + + const unsigned long truth = iter->second; + const unsigned long pred = label_to_int[dec_funct(x_test[i])]; + + res(truth,pred) += 1; + } + + return res; + } + +// ---------------------------------------------------------------------------------------- + + class cross_validation_error : public dlib::error + { + public: + cross_validation_error(const std::string& msg) : dlib::error(msg){}; + }; + + template < + typename trainer_type, + typename sample_type, + typename label_type + > + const matrix<double> cross_validate_multiclass_trainer ( + const trainer_type& trainer, + const std::vector<sample_type>& x, + const std::vector<label_type>& y, + const long folds + ) + { + typedef typename trainer_type::mem_manager_type mem_manager_type; + + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(x,y) == true && + 1 < folds && folds <= static_cast<long>(x.size()), + "\tmatrix cross_validate_multiclass_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t x.size(): " << x.size() + << "\n\t folds: " << folds + << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y) + ); + + const std::vector<label_type> all_labels = select_all_distinct_labels(y); + + // count the number of times each label shows up + std::map<label_type,long> label_counts; + for (unsigned long i = 0; i < y.size(); ++i) + label_counts[y[i]] += 1; + + + // figure out how many samples from each class will be in the test and train splits + std::map<label_type,long> num_in_test, num_in_train; + for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i) + { + const long in_test = i->second/folds; + if (in_test == 0) + { + std::ostringstream sout; + sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl; + sout << "than the number of elements of one of the training classes." << std::endl; + sout << " folds: "<< folds << std::endl; + sout << " size of class " << i->first << ": "<< i->second << std::endl; + throw cross_validation_error(sout.str()); + } + num_in_test[i->first] = in_test; + num_in_train[i->first] = i->second - in_test; + } + + + + std::vector<sample_type> x_test, x_train; + std::vector<label_type> y_test, y_train; + + matrix<double, 0, 0, mem_manager_type> res; + + std::map<label_type,long> next_test_idx; + for (unsigned long i = 0; i < all_labels.size(); ++i) + next_test_idx[all_labels[i]] = 0; + + label_type label; + + for (long i = 0; i < folds; ++i) + { + x_test.clear(); + y_test.clear(); + x_train.clear(); + y_train.clear(); + + // load up the test samples + for (unsigned long j = 0; j < all_labels.size(); ++j) + { + label = all_labels[j]; + long next = next_test_idx[label]; + + long cur = 0; + const long num_needed = num_in_test[label]; + while (cur < num_needed) + { + if (y[next] == label) + { + x_test.push_back(x[next]); + y_test.push_back(label); + ++cur; + } + next = (next + 1)%x.size(); + } + + next_test_idx[label] = next; + } + + // load up the training samples + for (unsigned long j = 0; j < all_labels.size(); ++j) + { + label = all_labels[j]; + long next = next_test_idx[label]; + + long cur = 0; + const long num_needed = num_in_train[label]; + while (cur < num_needed) + { + if (y[next] == label) + { + x_train.push_back(x[next]); + y_train.push_back(label); + ++cur; + } + next = (next + 1)%x.size(); + } + } + + + try + { + // do the training and testing + res += test_multiclass_decision_function(trainer.train(x_train,y_train),x_test,y_test); + } + catch (invalid_nu_error&) + { + // just ignore cases which result in an invalid nu + } + + } // for (long i = 0; i < folds; ++i) + + return res; + } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_ + |