summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/ranking_tools.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/svm/ranking_tools.h')
-rw-r--r--ml/dlib/dlib/svm/ranking_tools.h448
1 files changed, 448 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/ranking_tools.h b/ml/dlib/dlib/svm/ranking_tools.h
new file mode 100644
index 00000000..3c77b41a
--- /dev/null
+++ b/ml/dlib/dlib/svm/ranking_tools.h
@@ -0,0 +1,448 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_RANKING_ToOLS_Hh_
+#define DLIB_RANKING_ToOLS_Hh_
+
+#include "ranking_tools_abstract.h"
+
+#include "../algs.h"
+#include "../matrix.h"
+#include <vector>
+#include <utility>
+#include <algorithm>
+#include "sparse_vector.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ struct ranking_pair
+ {
+ ranking_pair() {}
+
+ ranking_pair(
+ const std::vector<T>& r,
+ const std::vector<T>& nr
+ ) :
+ relevant(r), nonrelevant(nr)
+ {}
+
+ std::vector<T> relevant;
+ std::vector<T> nonrelevant;
+ };
+
+ template <
+ typename T
+ >
+ void serialize (
+ const ranking_pair<T>& item,
+ std::ostream& out
+ )
+ {
+ int version = 1;
+ serialize(version, out);
+ serialize(item.relevant, out);
+ serialize(item.nonrelevant, out);
+ }
+
+
+ template <
+ typename T
+ >
+ void deserialize (
+ ranking_pair<T>& item,
+ std::istream& in
+ )
+ {
+ int version = 0;
+ deserialize(version, in);
+ if (version != 1)
+ throw dlib::serialization_error("Wrong version found while deserializing dlib::ranking_pair");
+
+ deserialize(item.relevant, in);
+ deserialize(item.nonrelevant, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ typename disable_if<is_matrix<T>,bool>::type is_ranking_problem (
+ const std::vector<ranking_pair<T> >& samples
+ )
+ {
+ if (samples.size() == 0)
+ return false;
+
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (samples[i].relevant.size() == 0)
+ return false;
+ if (samples[i].nonrelevant.size() == 0)
+ return false;
+ }
+
+ return true;
+ }
+
+ template <
+ typename T
+ >
+ typename enable_if<is_matrix<T>,bool>::type is_ranking_problem (
+ const std::vector<ranking_pair<T> >& samples
+ )
+ {
+ if (samples.size() == 0)
+ return false;
+
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ if (samples[i].relevant.size() == 0)
+ return false;
+ if (samples[i].nonrelevant.size() == 0)
+ return false;
+ }
+
+ // If these are dense vectors then they must all have the same dimensionality.
+ const long dims = max_index_plus_one(samples[0].relevant);
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ for (unsigned long j = 0; j < samples[i].relevant.size(); ++j)
+ {
+ if (is_vector(samples[i].relevant[j]) == false)
+ return false;
+
+ if (samples[i].relevant[j].size() != dims)
+ return false;
+ }
+ for (unsigned long j = 0; j < samples[i].nonrelevant.size(); ++j)
+ {
+ if (is_vector(samples[i].nonrelevant[j]) == false)
+ return false;
+
+ if (samples[i].nonrelevant[j].size() != dims)
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename T
+ >
+ unsigned long max_index_plus_one (
+ const ranking_pair<T>& item
+ )
+ {
+ return std::max(max_index_plus_one(item.relevant), max_index_plus_one(item.nonrelevant));
+ }
+
+ template <
+ typename T
+ >
+ unsigned long max_index_plus_one (
+ const std::vector<ranking_pair<T> >& samples
+ )
+ {
+ unsigned long dims = 0;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ dims = std::max(dims, max_index_plus_one(samples[i]));
+ }
+ return dims;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename T>
+ void count_ranking_inversions (
+ const std::vector<T>& x,
+ const std::vector<T>& y,
+ std::vector<unsigned long>& x_count,
+ std::vector<unsigned long>& y_count
+ )
+ {
+ x_count.assign(x.size(),0);
+ y_count.assign(y.size(),0);
+
+ if (x.size() == 0 || y.size() == 0)
+ return;
+
+ std::vector<std::pair<T,unsigned long> > xsort(x.size());
+ std::vector<std::pair<T,unsigned long> > ysort(y.size());
+ for (unsigned long i = 0; i < x.size(); ++i)
+ xsort[i] = std::make_pair(x[i], i);
+ for (unsigned long j = 0; j < y.size(); ++j)
+ ysort[j] = std::make_pair(y[j], j);
+
+ std::sort(xsort.begin(), xsort.end());
+ std::sort(ysort.begin(), ysort.end());
+
+
+ unsigned long i, j;
+
+ // Do the counting for the x values.
+ for (i = 0, j = 0; i < x_count.size(); ++i)
+ {
+ // Skip past y values that are in the correct order with respect to xsort[i].
+ while (j < ysort.size() && ysort[j].first < xsort[i].first)
+ ++j;
+
+ x_count[xsort[i].second] = ysort.size() - j;
+ }
+
+
+ // Now do the counting for the y values.
+ for (i = 0, j = 0; j < y_count.size(); ++j)
+ {
+ // Skip past x values that are in the incorrect order with respect to ysort[j].
+ while (i < xsort.size() && !(ysort[j].first < xsort[i].first))
+ ++i;
+
+ y_count[ysort[j].second] = i;
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ inline bool compare_first_reverse_second (
+ const std::pair<double,bool>& a,
+ const std::pair<double,bool>& b
+ )
+ {
+ if (a.first < b.first)
+ return true;
+ else if (a.first > b.first)
+ return false;
+ else if (a.second && !b.second)
+ return true;
+ else
+ return false;
+ }
+ }
+
+ template <
+ typename ranking_function,
+ typename T
+ >
+ matrix<double,1,2> test_ranking_function (
+ const ranking_function& funct,
+ const std::vector<ranking_pair<T> >& samples
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_ranking_problem(samples),
+ "\t double test_ranking_function()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples)
+ );
+
+ unsigned long total_pairs = 0;
+ unsigned long total_wrong = 0;
+
+ std::vector<double> rel_scores;
+ std::vector<double> nonrel_scores;
+ std::vector<unsigned long> rel_counts;
+ std::vector<unsigned long> nonrel_counts;
+
+ running_stats<double> rs;
+ std::vector<std::pair<double,bool> > total_scores;
+ std::vector<bool> total_ranking;
+
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ rel_scores.resize(samples[i].relevant.size());
+ nonrel_scores.resize(samples[i].nonrelevant.size());
+ total_scores.clear();
+
+ for (unsigned long k = 0; k < rel_scores.size(); ++k)
+ {
+ rel_scores[k] = funct(samples[i].relevant[k]);
+ total_scores.push_back(std::make_pair(rel_scores[k], true));
+ }
+
+ for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
+ {
+ nonrel_scores[k] = funct(samples[i].nonrelevant[k]);
+ total_scores.push_back(std::make_pair(nonrel_scores[k], false));
+ }
+
+ // Now compute the average precision for this sample. We need to sort the
+ // results and the back them into total_ranking. Note that we sort them so
+ // that, if you get a block of ranking values that are all equal, the elements
+ // marked as true will come last. This prevents a ranking from outputting a
+ // constant value for everything and still getting a good MAP score.
+ std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second);
+ total_ranking.clear();
+ for (unsigned long i = 0; i < total_scores.size(); ++i)
+ total_ranking.push_back(total_scores[i].second);
+ rs.add(average_precision(total_ranking));
+
+
+ count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
+
+ total_pairs += rel_scores.size()*nonrel_scores.size();
+
+ // Note that we don't need to look at nonrel_counts since it is redundant with
+ // the information in rel_counts in this case.
+ total_wrong += sum(mat(rel_counts));
+ }
+
+ const double rank_swaps = static_cast<double>(total_pairs - total_wrong) / total_pairs;
+ const double mean_average_precision = rs.mean();
+ matrix<double,1,2> res;
+ res = rank_swaps, mean_average_precision;
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename ranking_function,
+ typename T
+ >
+ matrix<double,1,2> test_ranking_function (
+ const ranking_function& funct,
+ const ranking_pair<T>& sample
+ )
+ {
+ return test_ranking_function(funct, std::vector<ranking_pair<T> >(1,sample));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename trainer_type,
+ typename T
+ >
+ matrix<double,1,2> cross_validate_ranking_trainer (
+ const trainer_type& trainer,
+ const std::vector<ranking_pair<T> >& samples,
+ const long folds
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_ranking_problem(samples) &&
+ 1 < folds && folds <= static_cast<long>(samples.size()),
+ "\t double cross_validate_ranking_trainer()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t samples.size(): " << samples.size()
+ << "\n\t folds: " << folds
+ << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples)
+ );
+
+
+ const long num_in_test = samples.size()/folds;
+ const long num_in_train = samples.size() - num_in_test;
+
+
+ std::vector<ranking_pair<T> > samples_test, samples_train;
+
+
+ long next_test_idx = 0;
+
+ unsigned long total_pairs = 0;
+ unsigned long total_wrong = 0;
+
+ std::vector<double> rel_scores;
+ std::vector<double> nonrel_scores;
+ std::vector<unsigned long> rel_counts;
+ std::vector<unsigned long> nonrel_counts;
+
+ running_stats<double> rs;
+ std::vector<std::pair<double,bool> > total_scores;
+ std::vector<bool> total_ranking;
+
+ for (long i = 0; i < folds; ++i)
+ {
+ samples_test.clear();
+ samples_train.clear();
+
+ // load up the test samples
+ for (long cnt = 0; cnt < num_in_test; ++cnt)
+ {
+ samples_test.push_back(samples[next_test_idx]);
+ next_test_idx = (next_test_idx + 1)%samples.size();
+ }
+
+ // load up the training samples
+ long next = next_test_idx;
+ for (long cnt = 0; cnt < num_in_train; ++cnt)
+ {
+ samples_train.push_back(samples[next]);
+ next = (next + 1)%samples.size();
+ }
+
+
+ const typename trainer_type::trained_function_type& df = trainer.train(samples_train);
+
+ // check how good df is on the test data
+ for (unsigned long i = 0; i < samples_test.size(); ++i)
+ {
+ rel_scores.resize(samples_test[i].relevant.size());
+ nonrel_scores.resize(samples_test[i].nonrelevant.size());
+
+ total_scores.clear();
+
+ for (unsigned long k = 0; k < rel_scores.size(); ++k)
+ {
+ rel_scores[k] = df(samples_test[i].relevant[k]);
+ total_scores.push_back(std::make_pair(rel_scores[k], true));
+ }
+
+ for (unsigned long k = 0; k < nonrel_scores.size(); ++k)
+ {
+ nonrel_scores[k] = df(samples_test[i].nonrelevant[k]);
+ total_scores.push_back(std::make_pair(nonrel_scores[k], false));
+ }
+
+ // Now compute the average precision for this sample. We need to sort the
+ // results and the back them into total_ranking. Note that we sort them so
+ // that, if you get a block of ranking values that are all equal, the elements
+ // marked as true will come last. This prevents a ranking from outputting a
+ // constant value for everything and still getting a good MAP score.
+ std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second);
+ total_ranking.clear();
+ for (unsigned long i = 0; i < total_scores.size(); ++i)
+ total_ranking.push_back(total_scores[i].second);
+ rs.add(average_precision(total_ranking));
+
+
+ count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts);
+
+ total_pairs += rel_scores.size()*nonrel_scores.size();
+
+ // Note that we don't need to look at nonrel_counts since it is redundant with
+ // the information in rel_counts in this case.
+ total_wrong += sum(mat(rel_counts));
+ }
+
+ } // for (long i = 0; i < folds; ++i)
+
+ const double rank_swaps = static_cast<double>(total_pairs - total_wrong) / total_pairs;
+ const double mean_average_precision = rs.mean();
+ matrix<double,1,2> res;
+ res = rank_swaps, mean_average_precision;
+ return res;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RANKING_ToOLS_Hh_
+