diff options
Diffstat (limited to 'ml/dlib/dlib/svm/one_vs_one_trainer.h')
-rw-r--r-- | ml/dlib/dlib/svm/one_vs_one_trainer.h | 249 |
1 files changed, 249 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/one_vs_one_trainer.h b/ml/dlib/dlib/svm/one_vs_one_trainer.h new file mode 100644 index 000000000..2beec8f67 --- /dev/null +++ b/ml/dlib/dlib/svm/one_vs_one_trainer.h @@ -0,0 +1,249 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ONE_VS_ONE_TRAiNER_Hh_ +#define DLIB_ONE_VS_ONE_TRAiNER_Hh_ + +#include "one_vs_one_trainer_abstract.h" + +#include "one_vs_one_decision_function.h" +#include <vector> + +#include "../unordered_pair.h" +#include "multiclass_tools.h" + +#include <sstream> +#include <iostream> + +#include "../any.h" +#include <map> +#include <set> +#include "../threads.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename any_trainer, + typename label_type_ = double + > + class one_vs_one_trainer + { + public: + typedef label_type_ label_type; + + typedef typename any_trainer::sample_type sample_type; + typedef typename any_trainer::scalar_type scalar_type; + typedef typename any_trainer::mem_manager_type mem_manager_type; + + typedef one_vs_one_decision_function<one_vs_one_trainer> trained_function_type; + + one_vs_one_trainer ( + ) : + verbose(false), + num_threads(4) + {} + + void set_trainer ( + const any_trainer& trainer + ) + { + default_trainer = trainer; + trainers.clear(); + } + + void set_trainer ( + const any_trainer& trainer, + const label_type& l1, + const label_type& l2 + ) + { + trainers[make_unordered_pair(l1,l2)] = trainer; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + struct invalid_label : public dlib::error + { + invalid_label(const std::string& msg, const label_type& l1_, const label_type& l2_ + ) : dlib::error(msg), l1(l1_), l2(l2_) {}; + + virtual ~invalid_label( + ) throw() {} + + label_type l1, l2; + }; + + trained_function_type train ( + const std::vector<sample_type>& all_samples, + const std::vector<label_type>& all_labels + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_learning_problem(all_samples,all_labels), + "\t trained_function_type one_vs_one_trainer::train(all_samples,all_labels)" + << "\n\t invalid inputs were given to this function" + << "\n\t all_samples.size(): " << all_samples.size() + << "\n\t all_labels.size(): " << all_labels.size() + ); + + const std::vector<label_type> distinct_labels = select_all_distinct_labels(all_labels); + + + // fill pairs with all the pairs of labels. + std::vector<unordered_pair<label_type> > pairs; + for (unsigned long i = 0; i < distinct_labels.size(); ++i) + { + for (unsigned long j = i+1; j < distinct_labels.size(); ++j) + { + pairs.push_back(unordered_pair<label_type>(distinct_labels[i], distinct_labels[j])); + + // make sure we have a trainer for this pair + const typename binary_function_table::const_iterator itr = trainers.find(pairs.back()); + if (itr == trainers.end() && default_trainer.is_empty()) + { + std::ostringstream sout; + sout << "In one_vs_one_trainer, no trainer registered for the (" + << pairs.back().first << ", " << pairs.back().second << ") label pair."; + throw invalid_label(sout.str(), pairs.back().first, pairs.back().second); + } + } + } + + + + // Now train on all the label pairs. + parallel_for_helper helper(all_samples,all_labels,default_trainer,trainers,verbose,pairs); + parallel_for(num_threads, 0, pairs.size(), helper, 500); + + if (helper.error_message.size() != 0) + { + throw dlib::error("binary trainer threw while training one vs. one classifier. Error was: " + helper.error_message); + } + return trained_function_type(helper.dfs); + } + + private: + + typedef std::map<unordered_pair<label_type>, any_trainer> binary_function_table; + + struct parallel_for_helper + { + parallel_for_helper( + const std::vector<sample_type>& all_samples_, + const std::vector<label_type>& all_labels_, + const any_trainer& default_trainer_, + const binary_function_table& trainers_, + const bool verbose_, + const std::vector<unordered_pair<label_type> >& pairs_ + ) : + all_samples(all_samples_), + all_labels(all_labels_), + default_trainer(default_trainer_), + trainers(trainers_), + verbose(verbose_), + pairs(pairs_) + {} + + void operator()(long i) const + { + try + { + std::vector<sample_type> samples; + std::vector<scalar_type> labels; + + const unordered_pair<label_type> p = pairs[i]; + + // pick out the samples corresponding to these two classes + for (unsigned long k = 0; k < all_samples.size(); ++k) + { + if (all_labels[k] == p.first) + { + samples.push_back(all_samples[k]); + labels.push_back(+1); + } + else if (all_labels[k] == p.second) + { + samples.push_back(all_samples[k]); + labels.push_back(-1); + } + } + + if (verbose) + { + auto_mutex lock(class_mutex); + std::cout << "Training classifier for " << p.first << " vs. " << p.second << std::endl; + } + + any_trainer trainer; + // now train a binary classifier using the samples we selected + { auto_mutex lock(class_mutex); + const typename binary_function_table::const_iterator itr = trainers.find(p); + if (itr != trainers.end()) + trainer = itr->second; + else + trainer = default_trainer; + } + + any_decision_function<sample_type,scalar_type> binary_df = trainer.train(samples, labels); + + auto_mutex lock(class_mutex); + dfs[p] = binary_df; + } + catch (std::exception& e) + { + auto_mutex lock(class_mutex); + error_message = e.what(); + } + } + + mutable typename trained_function_type::binary_function_table dfs; + mutex class_mutex; + mutable std::string error_message; + + const std::vector<sample_type>& all_samples; + const std::vector<label_type>& all_labels; + const any_trainer& default_trainer; + const binary_function_table& trainers; + const bool verbose; + const std::vector<unordered_pair<label_type> >& pairs; + }; + + + any_trainer default_trainer; + binary_function_table trainers; + bool verbose; + unsigned long num_threads; + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ONE_VS_ONE_TRAiNER_Hh_ + |