summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/svm/cross_validate_assignment_trainer.h')
-rw-r--r--ml/dlib/dlib/svm/cross_validate_assignment_trainer.h181
1 files changed, 181 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h b/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h
new file mode 100644
index 000000000..8166e1c82
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h
@@ -0,0 +1,181 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
+#define DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
+
+#include "cross_validate_assignment_trainer_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include "svm.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename assignment_function
+ >
+ double test_assignment_function (
+ const assignment_function& assigner,
+ const std::vector<typename assignment_function::sample_type>& samples,
+ const std::vector<typename assignment_function::label_type>& labels
+ )
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ if (assigner.forces_assignment())
+ {
+ DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
+ "\t double test_assignment_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+ else
+ {
+ DLIB_ASSERT(is_assignment_problem(samples, labels),
+ "\t double test_assignment_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+#endif
+ double total_right = 0;
+ double total = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ const std::vector<long>& out = assigner(samples[i]);
+ for (unsigned long j = 0; j < out.size(); ++j)
+ {
+ if (out[j] == labels[i][j])
+ ++total_right;
+
+ ++total;
+ }
+ }
+
+ if (total != 0)
+ return total_right/total;
+ else
+ return 1;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type
+ >
+ double cross_validate_assignment_trainer (
+ const trainer_type& trainer,
+ const std::vector<typename trainer_type::sample_type>& samples,
+ const std::vector<typename trainer_type::label_type>& labels,
+ const long folds
+ )
+ {
+ // make sure requires clause is not broken
+#ifdef ENABLE_ASSERTS
+ if (trainer.forces_assignment())
+ {
+ DLIB_ASSERT(is_forced_assignment_problem(samples, labels) &&
+ 1 < folds && folds <= static_cast<long>(samples.size()),
+ "\t double cross_validate_assignment_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+ else
+ {
+ DLIB_ASSERT(is_assignment_problem(samples, labels) &&
+ 1 < folds && folds <= static_cast<long>(samples.size()),
+ "\t double cross_validate_assignment_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
+ << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
+ );
+ }
+#endif
+
+
+
+ typedef typename trainer_type::sample_type sample_type;
+ typedef typename trainer_type::label_type label_type;
+
+ const long num_in_test = samples.size()/folds;
+ const long num_in_train = samples.size() - num_in_test;
+
+
+ std::vector<sample_type> samples_test, samples_train;
+ std::vector<label_type> labels_test, labels_train;
+
+
+ long next_test_idx = 0;
+ double total_right = 0;
+ double total = 0;
+
+
+ for (long i = 0; i < folds; ++i)
+ {
+ samples_test.clear();
+ labels_test.clear();
+ samples_train.clear();
+ labels_train.clear();
+
+ // load up the test samples
+ for (long cnt = 0; cnt < num_in_test; ++cnt)
+ {
+ samples_test.push_back(samples[next_test_idx]);
+ labels_test.push_back(labels[next_test_idx]);
+ next_test_idx = (next_test_idx + 1)%samples.size();
+ }
+
+ // load up the training samples
+ long next = next_test_idx;
+ for (long cnt = 0; cnt < num_in_train; ++cnt)
+ {
+ samples_train.push_back(samples[next]);
+ labels_train.push_back(labels[next]);
+ next = (next + 1)%samples.size();
+ }
+
+
+ const typename trainer_type::trained_function_type& df = trainer.train(samples_train,labels_train);
+
+ // check how good df is on the test data
+ for (unsigned long i = 0; i < samples_test.size(); ++i)
+ {
+ const std::vector<long>& out = df(samples_test[i]);
+ for (unsigned long j = 0; j < out.size(); ++j)
+ {
+ if (out[j] == labels_test[i][j])
+ ++total_right;
+
+ ++total;
+ }
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+ if (total != 0)
+ return total_right/total;
+ else
+ return 1;
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_
+