summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h')
-rw-r--r--ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h83
1 files changed, 83 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h b/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h
new file mode 100644
index 000000000..3d2409b28
--- /dev/null
+++ b/ml/dlib/dlib/svm/cross_validate_sequence_labeler_abstract.h
@@ -0,0 +1,83 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_ABSTRACT_Hh_
+#ifdef DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_ABSTRACT_Hh_
+
+#include <vector>
+#include "../matrix.h"
+#include "svm.h"
+
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename sequence_labeler_type,
+ typename sequence_type
+ >
+ const matrix<double> test_sequence_labeler (
+ const sequence_labeler_type& labeler,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<unsigned long> >& labels
+ );
+ /*!
+ requires
+ - is_sequence_labeling_problem(samples, labels)
+ - sequence_labeler_type == dlib::sequence_labeler or an object with a
+ compatible interface.
+ ensures
+ - Tests labeler against the given samples and labels and returns a confusion
+ matrix summarizing the results.
+ - The confusion matrix C returned by this function has the following properties.
+ - C.nc() == labeler.num_labels()
+ - C.nr() == labeler.num_labels()
+ - C(T,P) == the number of times a sequence element with label T was predicted
+ to have a label of P.
+ - Any samples with a label value >= labeler.num_labels() are ignored. That
+ is, samples with labels the labeler hasn't ever seen before are ignored.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename sequence_type
+ >
+ const matrix<double> cross_validate_sequence_labeler (
+ const trainer_type& trainer,
+ const std::vector<sequence_type>& samples,
+ const std::vector<std::vector<unsigned long> >& labels,
+ const long folds
+ );
+ /*!
+ requires
+ - is_sequence_labeling_problem(samples, labels)
+ - 1 < folds <= samples.size()
+ - for all valid i and j: labels[i][j] < trainer.num_labels()
+ - trainer_type == dlib::structural_sequence_labeling_trainer or an object
+ with a compatible interface.
+ ensures
+ - performs k-fold cross validation by using the given trainer to solve the
+ given sequence labeling problem for the given number of folds. Each fold
+ is tested using the output of the trainer and the confusion matrix from all
+ folds is summed and returned.
+ - The total confusion matrix is computed by running test_sequence_labeler()
+ on each fold and summing its output.
+ - The number of folds used is given by the folds argument.
+ - The confusion matrix C returned by this function has the following properties.
+ - C.nc() == trainer.num_labels()
+ - C.nr() == trainer.num_labels()
+ - C(T,P) == the number of times a sequence element with label T was predicted
+ to have a label of P.
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_CROSS_VALIDATE_SEQUENCE_LABeLER_ABSTRACT_Hh_
+
+
+