summary | refs | log | tree | commit | diff | stats
path: root/ml/dlib/dlib/svm/roc_trainer.h
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-03-09 13:19:48 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-03-09 13:20:02 +0000
commit: 58daab21cd043e1dc37024a7f99b396788372918 (patch)
tree: 96771e43bb69f7c1c2b0b4f7374cb74d7866d0cb /ml/dlib/dlib/svm/roc_trainer.h
parent: Releasing debian version 1.43.2-1. (diff)
download: netdata-58daab21cd043e1dc37024a7f99b396788372918.tar.xz
download: netdata-58daab21cd043e1dc37024a7f99b396788372918.zip
Merging upstream version 1.44.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'ml/dlib/dlib/svm/roc_trainer.h')
-rw-r--r-- ml/dlib/dlib/svm/roc_trainer.h 149
1 files changed, 149 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/roc_trainer.h b/ml/dlib/dlib/svm/roc_trainer.h
new file mode 100644
index 000000000..fa2c0ef9b
--- /dev/null
+++ b/ml/dlib/dlib/svm/roc_trainer.h
@@ -0,0 +1,149 @@
+// Copyright (C) 2009 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_ROC_TRAINEr_H_
+#define DLIB_ROC_TRAINEr_H_
+
+#include "roc_trainer_abstract.h"
+#include "../algs.h"
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+     typename trainer_type
+     >
+ class roc_trainer_type
+ {
+     /*!
+         WHAT THIS OBJECT REPRESENTS
+             A wrapper around another trainer object.  train() runs the wrapped
+             trainer and then replaces the bias term (the .b member) of the
+             resulting decision function with a value picked from the scores of
+             the training samples whose label equals class_selection.  The value
+             is chosen at the position in the sorted scores that corresponds to
+             desired_accuracy.
+
+             NOTE(review): the bias adjustment assumes the decision function
+             subtracts .b from the raw score (lowering .b favours +1) -- confirm
+             for any custom trained_function_type.
+     !*/
+ public:
+     typedef typename trainer_type::kernel_type kernel_type;
+     typedef typename trainer_type::scalar_type scalar_type;
+     typedef typename trainer_type::sample_type sample_type;
+     typedef typename trainer_type::mem_manager_type mem_manager_type;
+     typedef typename trainer_type::trained_function_type trained_function_type;
+
+     // Default construction leaves class_selection at 0, which matches neither the
+     // +1 nor the -1 label.  train() on such an object finds no samples to calibrate
+     // against and returns the wrapped trainer's decision function as is.
+     roc_trainer_type (
+     ) : desired_accuracy(0), class_selection(0){}
+
+     roc_trainer_type (
+         const trainer_type& trainer_,
+         const scalar_type& desired_accuracy_,
+         const scalar_type& class_selection_
+     ) : trainer(trainer_), desired_accuracy(desired_accuracy_), class_selection(class_selection_)
+     /*!
+         requires
+             - 0 <= desired_accuracy_ <= 1
+             - class_selection_ == +1 or -1
+     !*/
+     {
+         // make sure requires clause is not broken
+         DLIB_ASSERT(0 <= desired_accuracy && desired_accuracy <= 1 &&
+                     (class_selection == -1 || class_selection == +1),
+             "\t roc_trainer_type::roc_trainer_type()"
+             << "\n\t invalid inputs were given to this function"
+             << "\n\t desired_accuracy: " << desired_accuracy
+             << "\n\t class_selection: " << class_selection
+             );
+     }
+
+     template <
+         typename in_sample_vector_type,
+         typename in_scalar_vector_type
+         >
+     const trained_function_type train (
+         const in_sample_vector_type& samples,
+         const in_scalar_vector_type& labels
+     ) const
+     /*!
+         requires
+             - is_binary_classification_problem(samples, labels) == true
+         ensures
+             - returns the decision function produced by the wrapped trainer with
+               its bias re-calibrated as described in do_train() below.
+     !*/
+     {
+         // make sure requires clause is not broken
+         DLIB_ASSERT(is_binary_classification_problem(samples, labels),
+             "\t roc_trainer_type::train()"
+             << "\n\t invalid inputs were given to this function"
+             );
+
+
+         return do_train(mat(samples), mat(labels));
+     }
+
+ private:
+
+     template <
+         typename in_sample_vector_type,
+         typename in_scalar_vector_type
+         >
+     const trained_function_type do_train (
+         const in_sample_vector_type& samples,
+         const in_scalar_vector_type& labels
+     ) const
+     /*!
+         ensures
+             - trains with the wrapped trainer, scores every sample labeled
+               class_selection with the bias cleared, sorts those scores and
+               installs the score found at position desired_accuracy*count
+               (rounded, clamped to the last element) as the new bias.
+             - if no sample is labeled class_selection then the wrapped trainer's
+               decision function is returned with its original bias.
+     !*/
+     {
+         trained_function_type df = trainer.train(samples, labels);
+
+         // remember the bias chosen by the wrapped trainer, then clear it so the
+         // scores gathered below do not include it
+         const scalar_type trained_bias = df.b;
+         df.b = 0;
+
+         // obtain all the scores from the df using all the class_selection labeled samples
+         std::vector<double> scores;
+         for (long i = 0; i < samples.size(); ++i)
+         {
+             if (labels(i) == class_selection)
+                 scores.push_back(df(samples(i)));
+         }
+
+         // With nothing to calibrate against, scores.size()-1 below would wrap around
+         // and index out of bounds.  Hand back the function exactly as it was trained.
+         if (scores.empty())
+         {
+             df.b = trained_bias;
+             return df;
+         }
+
+         // descending order for the +1 class, ascending otherwise
+         if (class_selection == +1)
+             std::sort(scores.rbegin(), scores.rend());
+         else
+             std::sort(scores.begin(), scores.end());
+
+         // now pick out the index that gives us the desired accuracy with regards to selected class
+         unsigned long idx = static_cast<unsigned long>(desired_accuracy*scores.size() + 0.5);
+         if (idx >= scores.size())
+             idx = scores.size()-1;
+
+         df.b = scores[idx];
+
+         // In this case add a very small extra amount to the bias so that all the samples
+         // with the class_selection label are classified correctly.
+         if (desired_accuracy == 1)
+         {
+             // The size of the step must not depend on the sign of the bias, otherwise
+             // a negative bias would be moved in the wrong direction.  A bias of exactly
+             // zero (or one so small the product underflows) still needs a non-zero step.
+             const scalar_type bias = df.b;
+             scalar_type nudge = std::numeric_limits<scalar_type>::epsilon()*std::abs(bias);
+             if (nudge == 0)
+                 nudge = std::numeric_limits<scalar_type>::min();
+
+             if (class_selection == +1)
+                 df.b -= nudge;
+             else
+                 df.b += nudge;
+         }
+
+         return df;
+     }
+
+     trainer_type trainer;           // the wrapped trainer that does the actual learning
+     scalar_type desired_accuracy;   // fraction in [0,1] used to pick the bias
+     scalar_type class_selection;    // label (+1 or -1) whose samples are used for calibration
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+     Convenience factory: returns a roc_trainer_type that wraps trainer and is
+     constructed with the given desired_accuracy and a class_selection of +1.
+ !*/
+ template <typename trainer_type>
+ const roc_trainer_type<trainer_type> roc_c1_trainer (
+     const trainer_type& trainer,
+     const typename trainer_type::scalar_type& desired_accuracy
+ )
+ {
+     typedef roc_trainer_type<trainer_type> wrapped_trainer;
+     return wrapped_trainer(trainer, desired_accuracy, +1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ /*!
+     Convenience factory: returns a roc_trainer_type that wraps trainer and is
+     constructed with the given desired_accuracy and a class_selection of -1.
+ !*/
+ template <typename trainer_type>
+ const roc_trainer_type<trainer_type> roc_c2_trainer (
+     const trainer_type& trainer,
+     const typename trainer_type::scalar_type& desired_accuracy
+ )
+ {
+     typedef roc_trainer_type<trainer_type> wrapped_trainer;
+     return wrapped_trainer(trainer, desired_accuracy, -1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_ROC_TRAINEr_H_
+
+