diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-09 13:19:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-09 13:20:02 +0000 |
commit | 58daab21cd043e1dc37024a7f99b396788372918 (patch) | |
tree | 96771e43bb69f7c1c2b0b4f7374cb74d7866d0cb /ml/dlib/dlib/svm/roc_trainer.h | |
parent | Releasing debian version 1.43.2-1. (diff) | |
download | netdata-58daab21cd043e1dc37024a7f99b396788372918.tar.xz netdata-58daab21cd043e1dc37024a7f99b396788372918.zip |
Merging upstream version 1.44.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'ml/dlib/dlib/svm/roc_trainer.h')
-rw-r--r-- | ml/dlib/dlib/svm/roc_trainer.h | 149 |
1 files changed, 149 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/roc_trainer.h b/ml/dlib/dlib/svm/roc_trainer.h new file mode 100644 index 000000000..fa2c0ef9b --- /dev/null +++ b/ml/dlib/dlib/svm/roc_trainer.h @@ -0,0 +1,149 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ROC_TRAINEr_H_ +#define DLIB_ROC_TRAINEr_H_ + +#include "roc_trainer_abstract.h" +#include "../algs.h" +#include <limits> + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class roc_trainer_type + { + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + roc_trainer_type ( + ) : desired_accuracy(0), class_selection(0){} + + roc_trainer_type ( + const trainer_type& trainer_, + const scalar_type& desired_accuracy_, + const scalar_type& class_selection_ + ) : trainer(trainer_), desired_accuracy(desired_accuracy_), class_selection(class_selection_) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= desired_accuracy && desired_accuracy <= 1 && + (class_selection == -1 || class_selection == +1), + "\t roc_trainer_type::roc_trainer_type()" + << "\n\t invalid inputs were given to this function" + << "\n\t desired_accuracy: " << desired_accuracy + << "\n\t class_selection: " << class_selection + ); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const trained_function_type train ( + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels + ) const + /*! + requires + - is_binary_classification_problem(samples, labels) == true + !*/ + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(samples, labels), + "\t roc_trainer_type::train()" + << "\n\t invalid inputs were given to this function" + ); + + + return do_train(mat(samples), mat(labels)); + } + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const trained_function_type do_train ( + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels + ) const + { + trained_function_type df = trainer.train(samples, labels); + + // clear out the old bias + df.b = 0; + + // obtain all the scores from the df using all the class_selection labeled samples + std::vector<double> scores; + for (long i = 0; i < samples.size(); ++i) + { + if (labels(i) == class_selection) + scores.push_back(df(samples(i))); + } + + if (class_selection == +1) + std::sort(scores.rbegin(), scores.rend()); + else + std::sort(scores.begin(), scores.end()); + + // now pick out the index that gives us the desired accuracy with regards to selected class + unsigned long idx = static_cast<unsigned long>(desired_accuracy*scores.size() + 0.5); + if (idx >= scores.size()) + idx = scores.size()-1; + + df.b = scores[idx]; + + // In this case add a very small extra amount to the bias so that all the samples + // with the class_selection label are classified correctly. + if (desired_accuracy == 1) + { + if (class_selection == +1) + df.b -= std::numeric_limits<scalar_type>::epsilon()*df.b; + else + df.b += std::numeric_limits<scalar_type>::epsilon()*df.b; + } + + return df; + } + + trainer_type trainer; + scalar_type desired_accuracy; + scalar_type class_selection; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const roc_trainer_type<trainer_type> roc_c1_trainer ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& desired_accuracy + ) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, +1); } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const roc_trainer_type<trainer_type> roc_c2_trainer ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& desired_accuracy + ) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, -1); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ROC_TRAINEr_H_ + + |