diff options
Diffstat (limited to 'src/ml/dlib/dlib/svm/roc_trainer_abstract.h')
-rw-r--r-- | src/ml/dlib/dlib/svm/roc_trainer_abstract.h | 135 |
1 files changed, 135 insertions, 0 deletions
diff --git a/src/ml/dlib/dlib/svm/roc_trainer_abstract.h b/src/ml/dlib/dlib/svm/roc_trainer_abstract.h new file mode 100644 index 000000000..74e6f9b65 --- /dev/null +++ b/src/ml/dlib/dlib/svm/roc_trainer_abstract.h @@ -0,0 +1,135 @@ +// Copyright (C) 2009 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ROC_TRAINEr_ABSTRACT_ +#ifdef DLIB_ROC_TRAINEr_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + class roc_trainer_type + { + /*! + REQUIREMENTS ON trainer_type + - trainer_type == some kind of batch trainer object (e.g. svm_nu_trainer) + + WHAT THIS OBJECT REPRESENTS + This object is a simple trainer post processor that allows you to + easily adjust the bias term in a trained decision_function object. + That is, this object lets you pick a point on the ROC curve and + it will adjust the bias term appropriately. + + So for example, suppose you wanted to set the bias term so that + the accuracy of your decision function on +1 labeled samples was 99%. + To do this you would use an instance of this object declared as follows: + roc_trainer_type<trainer_type>(your_trainer, 0.99, +1); + !*/ + + public: + typedef typename trainer_type::kernel_type kernel_type; + typedef typename trainer_type::scalar_type scalar_type; + typedef typename trainer_type::sample_type sample_type; + typedef typename trainer_type::mem_manager_type mem_manager_type; + typedef typename trainer_type::trained_function_type trained_function_type; + + roc_trainer_type ( + ); + /*! + ensures + - This object is in an uninitialized state. You must + construct a real one with the other constructor and assign it + to this instance before you use this object. + !*/ + + roc_trainer_type ( + const trainer_type& trainer_, + const scalar_type& desired_accuracy_, + const scalar_type& class_selection_ + ); + /*! + requires + - 0 <= desired_accuracy_ <= 1 + - class_selection_ == +1 or -1 + ensures + - when training is performed using this object it will automatically + adjust the bias term in the returned decision function so that it + achieves the desired accuracy on the selected class type. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const trained_function_type train ( + const in_sample_vector_type& samples, + const in_scalar_vector_type& labels + ) const + /*! + requires + - is_binary_classification_problem(samples, labels) == true + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + ensures + - performs training using the trainer object given to this object's + constructor, then modifies the bias term in the returned decision function + as discussed above, and finally returns the decision function. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const roc_trainer_type<trainer_type> roc_c1_trainer ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& desired_accuracy + ) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, +1); } + /*! + requires + - 0 <= desired_accuracy <= 1 + - trainer_type == some kind of batch trainer object that creates decision_function + objects (e.g. svm_nu_trainer) + ensures + - returns a roc_trainer_type object that has been instantiated with the given + arguments. The returned roc trainer will select the decision function + bias that gives the desired accuracy with respect to the +1 class. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type + > + const roc_trainer_type<trainer_type> roc_c2_trainer ( + const trainer_type& trainer, + const typename trainer_type::scalar_type& desired_accuracy + ) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, -1); } + /*! + requires + - 0 <= desired_accuracy <= 1 + - trainer_type == some kind of batch trainer object that creates decision_function + objects (e.g. svm_nu_trainer) + ensures + - returns a roc_trainer_type object that has been instantiated with the given + arguments. The returned roc trainer will select the decision function + bias that gives the desired accuracy with respect to the -1 class. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ROC_TRAINEr_ABSTRACT_ + + + |