diff options
Diffstat (limited to '')
-rw-r--r-- | ml/dlib/dlib/svm/svm_threaded_abstract.h | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/svm_threaded_abstract.h b/ml/dlib/dlib/svm/svm_threaded_abstract.h new file mode 100644 index 000000000..f9973fb5c --- /dev/null +++ b/ml/dlib/dlib/svm/svm_threaded_abstract.h @@ -0,0 +1,62 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SVm_THREADED_ABSTRACT_ +#ifdef DLIB_SVm_THREADED_ABSTRACT_ + +#include "../matrix/matrix_abstract.h" +#include "../algs.h" +#include "../svm.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const matrix<double, 1, 2, typename trainer_type::mem_manager_type> + cross_validate_trainer_threaded ( + const trainer_type& trainer, + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + const long folds, + const long num_threads + ); + /*! + requires + - is_binary_classification_problem(x,y) == true + - 1 < folds <= std::min(sum(y>0),sum(y<0)) + (e.g. There must be at least as many examples of each class as there are folds) + - trainer_type == some kind of trainer object (e.g. svm_nu_trainer) + - num_threads > 0 + - It must be safe for multiple trainer objects to access the elements of x from + multiple threads at the same time. Note that all trainers and kernels in + dlib are thread safe in this regard since they do not mutate the elements of x. + ensures + - performs k-fold cross validation by using the given trainer to solve the + given binary classification problem for the given number of folds. + Each fold is tested using the output of the trainer and the average + classification accuracy from all folds is returned. + - uses num_threads threads of execution in doing the cross validation. + - The accuracy is returned in a row vector, let us call it R. Both + quantities in R are numbers between 0 and 1 which represent the fraction + of examples correctly classified. R(0) is the fraction of +1 examples + correctly classified and R(1) is the fraction of -1 examples correctly + classified. + - The number of folds used is given by the folds argument. + throws + - any exceptions thrown by trainer.train() + - std::bad_alloc + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SVm_THREADED_ABSTRACT_ + + + |