diff options
Diffstat (limited to 'ml/dlib/dlib/svm/one_vs_all_trainer.h')
-rw-r--r-- | ml/dlib/dlib/svm/one_vs_all_trainer.h | 234 |
1 files changed, 0 insertions, 234 deletions
diff --git a/ml/dlib/dlib/svm/one_vs_all_trainer.h b/ml/dlib/dlib/svm/one_vs_all_trainer.h deleted file mode 100644 index bcb006a41..000000000 --- a/ml/dlib/dlib/svm/one_vs_all_trainer.h +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright (C) 2010 Davis E. King (davis@dlib.net) -// License: Boost Software License See LICENSE.txt for the full license. -#ifndef DLIB_ONE_VS_ALL_TRAiNER_Hh_ -#define DLIB_ONE_VS_ALL_TRAiNER_Hh_ - -#include "one_vs_all_trainer_abstract.h" - -#include "one_vs_all_decision_function.h" -#include <vector> - -#include "multiclass_tools.h" - -#include <sstream> -#include <iostream> - -#include "../any.h" -#include <map> -#include <set> -#include "../threads.h" - -namespace dlib -{ - -// ---------------------------------------------------------------------------------------- - - template < - typename any_trainer, - typename label_type_ = double - > - class one_vs_all_trainer - { - public: - typedef label_type_ label_type; - - typedef typename any_trainer::sample_type sample_type; - typedef typename any_trainer::scalar_type scalar_type; - typedef typename any_trainer::mem_manager_type mem_manager_type; - - typedef one_vs_all_decision_function<one_vs_all_trainer> trained_function_type; - - one_vs_all_trainer ( - ) : - verbose(false), - num_threads(4) - {} - - void set_trainer ( - const any_trainer& trainer - ) - { - default_trainer = trainer; - trainers.clear(); - } - - void set_trainer ( - const any_trainer& trainer, - const label_type& l - ) - { - trainers[l] = trainer; - } - - void be_verbose ( - ) - { - verbose = true; - } - - void be_quiet ( - ) - { - verbose = false; - } - - void set_num_threads ( - unsigned long num - ) - { - num_threads = num; - } - - unsigned long get_num_threads ( - ) const - { - return num_threads; - } - - struct invalid_label : public dlib::error - { - invalid_label(const std::string& msg, const label_type& l_ - ) : dlib::error(msg), l(l_) {}; - - virtual ~invalid_label( - ) throw() {} - - label_type l; - }; - - trained_function_type train ( - const std::vector<sample_type>& all_samples, - const std::vector<label_type>& all_labels - ) const - { - // make sure requires clause is not broken - DLIB_ASSERT(is_learning_problem(all_samples,all_labels), - "\t trained_function_type one_vs_all_trainer::train(all_samples,all_labels)" - << "\n\t invalid inputs were given to this function" - << "\n\t all_samples.size(): " << all_samples.size() - << "\n\t all_labels.size(): " << all_labels.size() - ); - - const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels); - - // make sure we have a trainer object for each of the label types. - for (unsigned long i = 0; i < distinct_labels.size(); ++i) - { - const label_type l = distinct_labels[i]; - const typename binary_function_table::const_iterator itr = trainers.find(l); - - if (itr == trainers.end() && default_trainer.is_empty()) - { - std::ostringstream sout; - sout << "In one_vs_all_trainer, no trainer registered for the " << l << " label."; - throw invalid_label(sout.str(), l); - } - } - - - // now do the training - parallel_for_helper helper(all_samples,all_labels,default_trainer,trainers,verbose,distinct_labels); - parallel_for(num_threads, 0, distinct_labels.size(), helper, 500); - - if (helper.error_message.size() != 0) - { - throw dlib::error("binary trainer threw while training one vs. all classifier. Error was: " + helper.error_message); - } - return trained_function_type(helper.dfs); - } - - private: - - typedef std::map<label_type, any_trainer> binary_function_table; - struct parallel_for_helper - { - parallel_for_helper( - const std::vector<sample_type>& all_samples_, - const std::vector<label_type>& all_labels_, - const any_trainer& default_trainer_, - const binary_function_table& trainers_, - const bool verbose_, - const std::vector<label_type>& distinct_labels_ - ) : - all_samples(all_samples_), - all_labels(all_labels_), - default_trainer(default_trainer_), - trainers(trainers_), - verbose(verbose_), - distinct_labels(distinct_labels_) - {} - - void operator()(long i) const - { - try - { - std::vector<scalar_type> labels; - - const label_type l = distinct_labels[i]; - - // setup one of the one vs all training sets - for (unsigned long k = 0; k < all_samples.size(); ++k) - { - if (all_labels[k] == l) - labels.push_back(+1); - else - labels.push_back(-1); - } - - - if (verbose) - { - auto_mutex lock(class_mutex); - std::cout << "Training classifier for " << l << " vs. all" << std::endl; - } - - any_trainer trainer; - // now train a binary classifier using the samples we selected - { auto_mutex lock(class_mutex); - const typename binary_function_table::const_iterator itr = trainers.find(l); - if (itr != trainers.end()) - trainer = itr->second; - else - trainer = default_trainer; - } - - any_decision_function<sample_type,scalar_type> binary_df = trainer.train(all_samples, labels); - - auto_mutex lock(class_mutex); - dfs[l] = binary_df; - } - catch (std::exception& e) - { - auto_mutex lock(class_mutex); - error_message = e.what(); - } - } - - mutable typename trained_function_type::binary_function_table dfs; - mutex class_mutex; - mutable std::string error_message; - - const std::vector<sample_type>& all_samples; - const std::vector<label_type>& all_labels; - const any_trainer& default_trainer; - const binary_function_table& trainers; - const bool verbose; - const std::vector<label_type>& distinct_labels; - }; - - any_trainer default_trainer; - - binary_function_table trainers; - - bool verbose; - unsigned long num_threads; - - }; - -// ---------------------------------------------------------------------------------------- - -} - -#endif // DLIB_ONE_VS_ALL_TRAiNER_Hh_ - - |