path: root/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h
Diffstat (limited to 'ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h')
-rw-r--r--  ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h  208
1 file changed, 208 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h
new file mode 100644
index 000000000..be8fa3f3f
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_multiclass_trainer.h
@@ -0,0 +1,208 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License.  See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
+#define DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
+
+#include <vector>
+#include <map>
+#include "../matrix.h"
+#include "cross_validate_multiclass_trainer_abstract.h"
+#include <sstream>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename dec_funct_type,
+ typename sample_type,
+ typename label_type
+ >
+ const matrix<double> test_multiclass_decision_function (
+ const dec_funct_type& dec_funct,
+ const std::vector<sample_type>& x_test,
+ const std::vector<label_type>& y_test
+ )
+ {
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT( is_learning_problem(x_test,y_test) == true,
+ "\tmatrix test_multiclass_decision_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_learning_problem(x_test,y_test): "
+ << is_learning_problem(x_test,y_test));
+
+
+ const std::vector<label_type> all_labels = dec_funct.get_labels();
+
+ // make a lookup table that maps from labels to their index in all_labels
+ std::map<label_type,unsigned long> label_to_int;
+ for (unsigned long i = 0; i < all_labels.size(); ++i)
+ label_to_int[all_labels[i]] = i;
+
+ matrix<double, 0, 0, typename dec_funct_type::mem_manager_type> res;
+ res.set_size(all_labels.size(), all_labels.size());
+
+ res = 0;
+
+ typename std::map<label_type,unsigned long>::const_iterator iter;
+
+ // now test this trained object
+ for (unsigned long i = 0; i < x_test.size(); ++i)
+ {
+ iter = label_to_int.find(y_test[i]);
+ // ignore samples with labels that the decision function doesn't know about.
+ if (iter == label_to_int.end())
+ continue;
+
+ const unsigned long truth = iter->second;
+ const unsigned long pred = label_to_int[dec_funct(x_test[i])];
+
+ res(truth,pred) += 1;
+ }
+
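+        // Note: res is a confusion matrix.  res(r,c) counts the test samples whose
+        // true label is all_labels[r] and whose predicted label is all_labels[c],
+        // so correct predictions sit on the diagonal.  For example, an overall
+        // accuracy could be read off as sum(diag(res))/sum(res).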
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class cross_validation_error : public dlib::error
+ {
+ public:
+ cross_validation_error(const std::string& msg) : dlib::error(msg){};
+ };
+
+ template <
+ typename trainer_type,
+ typename sample_type,
+ typename label_type
+ >
+ const matrix<double> cross_validate_multiclass_trainer (
+ const trainer_type& trainer,
+ const std::vector<sample_type>& x,
+ const std::vector<label_type>& y,
+ const long folds
+ )
+ {
+ typedef typename trainer_type::mem_manager_type mem_manager_type;
+
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_learning_problem(x,y) == true &&
+ 1 < folds && folds <= static_cast<long>(x.size()),
+ "\tmatrix cross_validate_multiclass_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t x.size(): " << x.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_learning_problem(x,y): " << is_learning_problem(x,y)
+ );
+
+ const std::vector<label_type> all_labels = select_all_distinct_labels(y);
+
+ // count the number of times each label shows up
+ std::map<label_type,long> label_counts;
+ for (unsigned long i = 0; i < y.size(); ++i)
+ label_counts[y[i]] += 1;
+
+
+ // figure out how many samples from each class will be in the test and train splits
+ std::map<label_type,long> num_in_test, num_in_train;
+ for (typename std::map<label_type,long>::iterator i = label_counts.begin(); i != label_counts.end(); ++i)
+ {
+ const long in_test = i->second/folds;
+ if (in_test == 0)
+ {
+ std::ostringstream sout;
+ sout << "In dlib::cross_validate_multiclass_trainer(), the number of folds was larger" << std::endl;
+ sout << "than the number of elements of one of the training classes." << std::endl;
+ sout << " folds: "<< folds << std::endl;
+ sout << " size of class " << i->first << ": "<< i->second << std::endl;
+ throw cross_validation_error(sout.str());
+ }
+ num_in_test[i->first] = in_test;
+ num_in_train[i->first] = i->second - in_test;
+ }
+
+
+
+ std::vector<sample_type> x_test, x_train;
+ std::vector<label_type> y_test, y_train;
+
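+        // the confusion matrices from the individual folds are summed into res
+        // and returned once all folds have been processed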
+ matrix<double, 0, 0, mem_manager_type> res;
+
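+        // next_test_idx[label] records where in x/y the next fold should begin
+        // drawing test samples of that class; it advances (wrapping around the
+        // data set via the modulo arithmetic below) as successive folds consume
+        // their share of each class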
+ std::map<label_type,long> next_test_idx;
+ for (unsigned long i = 0; i < all_labels.size(); ++i)
+ next_test_idx[all_labels[i]] = 0;
+
+ label_type label;
+
+ for (long i = 0; i < folds; ++i)
+ {
+ x_test.clear();
+ y_test.clear();
+ x_train.clear();
+ y_train.clear();
+
+ // load up the test samples
+ for (unsigned long j = 0; j < all_labels.size(); ++j)
+ {
+ label = all_labels[j];
+ long next = next_test_idx[label];
+
+ long cur = 0;
+ const long num_needed = num_in_test[label];
+ while (cur < num_needed)
+ {
+ if (y[next] == label)
+ {
+ x_test.push_back(x[next]);
+ y_test.push_back(label);
+ ++cur;
+ }
+ next = (next + 1)%x.size();
+ }
+
+ next_test_idx[label] = next;
+ }
+
+ // load up the training samples
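+            // (these start just past the current fold's test block, so each class
+            //  contributes exactly the samples that were not placed into x_test)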
+ for (unsigned long j = 0; j < all_labels.size(); ++j)
+ {
+ label = all_labels[j];
+ long next = next_test_idx[label];
+
+ long cur = 0;
+ const long num_needed = num_in_train[label];
+ while (cur < num_needed)
+ {
+ if (y[next] == label)
+ {
+ x_train.push_back(x[next]);
+ y_train.push_back(label);
+ ++cur;
+ }
+ next = (next + 1)%x.size();
+ }
+ }
+
+
+ try
+ {
+ // do the training and testing
+ res += test_multiclass_decision_function(trainer.train(x_train,y_train),x_test,y_test);
+ }
+ catch (invalid_nu_error&)
+ {
+                // just ignore folds which result in an invalid nu (e.g. when a
+                // nu based binary trainer cannot satisfy its nu value on this split)
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+ return res;
+ }
+
+}
+
+// ----------------------------------------------------------------------------------------
+
+#endif // DLIB_CROSS_VALIDATE_MULTICLASS_TRaINER_Hh_
+
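For orientation, the following is a minimal sketch of how this header is typically driven, following the pattern of dlib's multiclass classification example: a one_vs_one_trainer wrapping a kernel ridge regression trainer is cross validated on a small synthetic three class problem. The krr_trainer, the radial_basis_kernel width of 0.1, and the synthetic data are illustrative choices here, not anything mandated by this header.

#include <dlib/svm.h>
#include <iostream>
#include <vector>

using namespace dlib;

int main()
{
    typedef matrix<double,2,1> sample_type;
    typedef radial_basis_kernel<sample_type> kernel_type;
    typedef one_vs_one_trainer<any_trainer<sample_type> > ovo_trainer;

    // Build a tiny synthetic 3-class problem with clusters around (0,0), (5,5), (10,10).
    std::vector<sample_type> samples;
    std::vector<double> labels;
    for (int c = 0; c < 3; ++c)
    {
        for (int i = 0; i < 20; ++i)
        {
            sample_type s;
            s = 5.0*c + 0.1*i, 5.0*c - 0.1*i;
            samples.push_back(s);
            labels.push_back(c);
        }
    }

    // A kernel ridge regression trainer handles each pairwise binary subproblem.
    krr_trainer<kernel_type> binary_trainer;
    binary_trainer.set_kernel(kernel_type(0.1));

    ovo_trainer trainer;
    trainer.set_trainer(binary_trainer);

    // 5-fold cross validation.  The result is a confusion matrix: rows index the
    // true labels and columns index the predicted labels.
    const matrix<double> cm = cross_validate_multiclass_trainer(trainer, samples, labels, 5);
    std::cout << "confusion matrix:\n" << cm;
    std::cout << "accuracy: " << sum(diag(cm))/sum(cm) << std::endl;

    return 0;
}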