diff options
Diffstat (limited to 'ml/dlib/dlib/svm/cross_validate_assignment_trainer.h')
-rw-r--r-- | ml/dlib/dlib/svm/cross_validate_assignment_trainer.h | 181 |
1 files changed, 181 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h b/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h new file mode 100644 index 000000000..8166e1c82 --- /dev/null +++ b/ml/dlib/dlib/svm/cross_validate_assignment_trainer.h @@ -0,0 +1,181 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_ +#define DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_ + +#include "cross_validate_assignment_trainer_abstract.h" +#include <vector> +#include "../matrix.h" +#include "svm.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename assignment_function + > + double test_assignment_function ( + const assignment_function& assigner, + const std::vector<typename assignment_function::sample_type>& samples, + const std::vector<typename assignment_function::label_type>& labels + ) + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + if (assigner.forces_assignment()) + { + DLIB_ASSERT(is_forced_assignment_problem(samples, labels), + "\t double test_assignment_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels) + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } + else + { + DLIB_ASSERT(is_assignment_problem(samples, labels), + "\t double test_assignment_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } +#endif + double total_right = 0; + double total = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + const std::vector<long>& out = assigner(samples[i]); + for (unsigned long j = 0; j < out.size(); ++j) + { + if (out[j] == labels[i][j]) + ++total_right; + + ++total; + } + } + + if (total != 0) + return total_right/total; + else + return 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + double cross_validate_assignment_trainer ( + const trainer_type& trainer, + const std::vector<typename trainer_type::sample_type>& samples, + const std::vector<typename trainer_type::label_type>& labels, + const long folds + ) + { + // make sure requires clause is not broken +#ifdef ENABLE_ASSERTS + if (trainer.forces_assignment()) + { + DLIB_ASSERT(is_forced_assignment_problem(samples, labels) && + 1 < folds && folds <= static_cast<long>(samples.size()), + "\t double cross_validate_assignment_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t folds: " << folds + << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels) + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } + else + { + DLIB_ASSERT(is_assignment_problem(samples, labels) && + 1 < folds && folds <= static_cast<long>(samples.size()), + "\t double cross_validate_assignment_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t folds: " << folds + << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels) + << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels) + ); + } +#endif + + + + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::label_type label_type; + + const long num_in_test = samples.size()/folds; + const long num_in_train = samples.size() - num_in_test; + + + std::vector<sample_type> samples_test, samples_train; + std::vector<label_type> labels_test, labels_train; + + + long next_test_idx = 0; + double total_right = 0; + double total = 0; + + + for (long i = 0; i < folds; ++i) + { + samples_test.clear(); + labels_test.clear(); + samples_train.clear(); + labels_train.clear(); + + // load up the test samples + for (long cnt = 0; cnt < num_in_test; ++cnt) + { + samples_test.push_back(samples[next_test_idx]); + labels_test.push_back(labels[next_test_idx]); + next_test_idx = (next_test_idx + 1)%samples.size(); + } + + // load up the training samples + long next = next_test_idx; + for (long cnt = 0; cnt < num_in_train; ++cnt) + { + samples_train.push_back(samples[next]); + labels_train.push_back(labels[next]); + next = (next + 1)%samples.size(); + } + + + const typename trainer_type::trained_function_type& df = trainer.train(samples_train,labels_train); + + // check how good df is on the test data + for (unsigned long i = 0; i < samples_test.size(); ++i) + { + const std::vector<long>& out = df(samples_test[i]); + for (unsigned long j = 0; j < out.size(); ++j) + { + if (out[j] == labels_test[i][j]) + ++total_right; + + ++total; + } + } + + } // for (long i = 0; i < folds; ++i) + + if (total != 0) + return total_right/total; + else + return 1; + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CROSS_VALIDATE_ASSiGNEMNT_TRAINER_Hh_ + |