summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/svm_threaded_abstract.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/svm/svm_threaded_abstract.h')
-rw-r--r--ml/dlib/dlib/svm/svm_threaded_abstract.h62
1 files changed, 62 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/svm_threaded_abstract.h b/ml/dlib/dlib/svm/svm_threaded_abstract.h
new file mode 100644
index 000000000..f9973fb5c
--- /dev/null
+++ b/ml/dlib/dlib/svm/svm_threaded_abstract.h
@@ -0,0 +1,62 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_SVm_THREADED_ABSTRACT_
+#ifdef DLIB_SVm_THREADED_ABSTRACT_
+
+#include "../matrix/matrix_abstract.h"
+#include "../algs.h"
+#include "../svm.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const matrix<double, 1, 2, typename trainer_type::mem_manager_type>
+ cross_validate_trainer_threaded (
+ const trainer_type& trainer,
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y,
+ const long folds,
+ const long num_threads
+ );
+ /*!
+ requires
+ - is_binary_classification_problem(x,y) == true
+ - 1 < folds <= std::min(sum(y>0),sum(y<0))
+ (e.g. There must be at least as many examples of each class as there are folds)
+ - trainer_type == some kind of trainer object (e.g. svm_nu_trainer)
+ - num_threads > 0
+ - It must be safe for multiple trainer objects to access the elements of x from
+ multiple threads at the same time. Note that all trainers and kernels in
+ dlib are thread safe in this regard since they do not mutate the elements of x.
+ ensures
+ - performs k-fold cross validation by using the given trainer to solve the
+ given binary classification problem for the given number of folds.
+ Each fold is tested using the output of the trainer and the average
+ classification accuracy from all folds is returned.
+ - uses num_threads threads of execution in doing the cross validation.
+ - The accuracy is returned in a row vector, let us call it R. Both
+ quantities in R are numbers between 0 and 1 which represent the fraction
+ of examples correctly classified. R(0) is the fraction of +1 examples
+ correctly classified and R(1) is the fraction of -1 examples correctly
+ classified.
+ - The number of folds used is given by the folds argument.
+ throws
+ - any exceptions thrown by trainer.train()
+ - std::bad_alloc
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SVm_THREADED_ABSTRACT_
+
+
+