// Copyright (C) 2009 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_ROC_TRAINEr_H_ #define DLIB_ROC_TRAINEr_H_ #include "roc_trainer_abstract.h" #include "../algs.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- template < typename trainer_type > class roc_trainer_type { public: typedef typename trainer_type::kernel_type kernel_type; typedef typename trainer_type::scalar_type scalar_type; typedef typename trainer_type::sample_type sample_type; typedef typename trainer_type::mem_manager_type mem_manager_type; typedef typename trainer_type::trained_function_type trained_function_type; roc_trainer_type ( ) : desired_accuracy(0), class_selection(0){} roc_trainer_type ( const trainer_type& trainer_, const scalar_type& desired_accuracy_, const scalar_type& class_selection_ ) : trainer(trainer_), desired_accuracy(desired_accuracy_), class_selection(class_selection_) { // make sure requires clause is not broken DLIB_ASSERT(0 <= desired_accuracy && desired_accuracy <= 1 && (class_selection == -1 || class_selection == +1), "\t roc_trainer_type::roc_trainer_type()" << "\n\t invalid inputs were given to this function" << "\n\t desired_accuracy: " << desired_accuracy << "\n\t class_selection: " << class_selection ); } template < typename in_sample_vector_type, typename in_scalar_vector_type > const trained_function_type train ( const in_sample_vector_type& samples, const in_scalar_vector_type& labels ) const /*! requires - is_binary_classification_problem(samples, labels) == true !*/ { // make sure requires clause is not broken DLIB_ASSERT(is_binary_classification_problem(samples, labels), "\t roc_trainer_type::train()" << "\n\t invalid inputs were given to this function" ); return do_train(mat(samples), mat(labels)); } private: template < typename in_sample_vector_type, typename in_scalar_vector_type > const trained_function_type do_train ( const in_sample_vector_type& samples, const in_scalar_vector_type& labels ) const { trained_function_type df = trainer.train(samples, labels); // clear out the old bias df.b = 0; // obtain all the scores from the df using all the class_selection labeled samples std::vector scores; for (long i = 0; i < samples.size(); ++i) { if (labels(i) == class_selection) scores.push_back(df(samples(i))); } if (class_selection == +1) std::sort(scores.rbegin(), scores.rend()); else std::sort(scores.begin(), scores.end()); // now pick out the index that gives us the desired accuracy with regards to selected class unsigned long idx = static_cast(desired_accuracy*scores.size() + 0.5); if (idx >= scores.size()) idx = scores.size()-1; df.b = scores[idx]; // In this case add a very small extra amount to the bias so that all the samples // with the class_selection label are classified correctly. if (desired_accuracy == 1) { if (class_selection == +1) df.b -= std::numeric_limits::epsilon()*df.b; else df.b += std::numeric_limits::epsilon()*df.b; } return df; } trainer_type trainer; scalar_type desired_accuracy; scalar_type class_selection; }; // ---------------------------------------------------------------------------------------- template < typename trainer_type > const roc_trainer_type roc_c1_trainer ( const trainer_type& trainer, const typename trainer_type::scalar_type& desired_accuracy ) { return roc_trainer_type(trainer, desired_accuracy, +1); } // ---------------------------------------------------------------------------------------- template < typename trainer_type > const roc_trainer_type roc_c2_trainer ( const trainer_type& trainer, const typename trainer_type::scalar_type& desired_accuracy ) { return roc_trainer_type(trainer, desired_accuracy, -1); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_ROC_TRAINEr_H_