diff options
Diffstat (limited to 'ml/dlib/dlib/svm/ranking_tools.h')
-rw-r--r-- | ml/dlib/dlib/svm/ranking_tools.h | 448 |
1 files changed, 448 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/ranking_tools.h b/ml/dlib/dlib/svm/ranking_tools.h new file mode 100644 index 00000000..3c77b41a --- /dev/null +++ b/ml/dlib/dlib/svm/ranking_tools.h @@ -0,0 +1,448 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RANKING_ToOLS_Hh_ +#define DLIB_RANKING_ToOLS_Hh_ + +#include "ranking_tools_abstract.h" + +#include "../algs.h" +#include "../matrix.h" +#include <vector> +#include <utility> +#include <algorithm> +#include "sparse_vector.h" +#include "../statistics.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + struct ranking_pair + { + ranking_pair() {} + + ranking_pair( + const std::vector<T>& r, + const std::vector<T>& nr + ) : + relevant(r), nonrelevant(nr) + {} + + std::vector<T> relevant; + std::vector<T> nonrelevant; + }; + + template < + typename T + > + void serialize ( + const ranking_pair<T>& item, + std::ostream& out + ) + { + int version = 1; + serialize(version, out); + serialize(item.relevant, out); + serialize(item.nonrelevant, out); + } + + + template < + typename T + > + void deserialize ( + ranking_pair<T>& item, + std::istream& in + ) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw dlib::serialization_error("Wrong version found while deserializing dlib::ranking_pair"); + + deserialize(item.relevant, in); + deserialize(item.nonrelevant, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + typename disable_if<is_matrix<T>,bool>::type is_ranking_problem ( + const std::vector<ranking_pair<T> >& samples + ) + { + if (samples.size() == 0) + return false; + + + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (samples[i].relevant.size() == 0) + return false; + if (samples[i].nonrelevant.size() == 0) + return false; + } + + return true; + } + + template < + typename T + > + typename enable_if<is_matrix<T>,bool>::type is_ranking_problem ( + const std::vector<ranking_pair<T> >& samples + ) + { + if (samples.size() == 0) + return false; + + + for (unsigned long i = 0; i < samples.size(); ++i) + { + if (samples[i].relevant.size() == 0) + return false; + if (samples[i].nonrelevant.size() == 0) + return false; + } + + // If these are dense vectors then they must all have the same dimensionality. + const long dims = max_index_plus_one(samples[0].relevant); + for (unsigned long i = 0; i < samples.size(); ++i) + { + for (unsigned long j = 0; j < samples[i].relevant.size(); ++j) + { + if (is_vector(samples[i].relevant[j]) == false) + return false; + + if (samples[i].relevant[j].size() != dims) + return false; + } + for (unsigned long j = 0; j < samples[i].nonrelevant.size(); ++j) + { + if (is_vector(samples[i].nonrelevant[j]) == false) + return false; + + if (samples[i].nonrelevant[j].size() != dims) + return false; + } + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + unsigned long max_index_plus_one ( + const ranking_pair<T>& item + ) + { + return std::max(max_index_plus_one(item.relevant), max_index_plus_one(item.nonrelevant)); + } + + template < + typename T + > + unsigned long max_index_plus_one ( + const std::vector<ranking_pair<T> >& samples + ) + { + unsigned long dims = 0; + for (unsigned long i = 0; i < samples.size(); ++i) + { + dims = std::max(dims, max_index_plus_one(samples[i])); + } + return dims; + } + +// ---------------------------------------------------------------------------------------- + + template <typename T> + void count_ranking_inversions ( + const std::vector<T>& x, + const std::vector<T>& y, + std::vector<unsigned long>& x_count, + std::vector<unsigned long>& y_count + ) + { + x_count.assign(x.size(),0); + y_count.assign(y.size(),0); + + if (x.size() == 0 || y.size() == 0) + return; + + std::vector<std::pair<T,unsigned long> > xsort(x.size()); + std::vector<std::pair<T,unsigned long> > ysort(y.size()); + for (unsigned long i = 0; i < x.size(); ++i) + xsort[i] = std::make_pair(x[i], i); + for (unsigned long j = 0; j < y.size(); ++j) + ysort[j] = std::make_pair(y[j], j); + + std::sort(xsort.begin(), xsort.end()); + std::sort(ysort.begin(), ysort.end()); + + + unsigned long i, j; + + // Do the counting for the x values. + for (i = 0, j = 0; i < x_count.size(); ++i) + { + // Skip past y values that are in the correct order with respect to xsort[i]. + while (j < ysort.size() && ysort[j].first < xsort[i].first) + ++j; + + x_count[xsort[i].second] = ysort.size() - j; + } + + + // Now do the counting for the y values. + for (i = 0, j = 0; j < y_count.size(); ++j) + { + // Skip past x values that are in the incorrect order with respect to ysort[j]. + while (i < xsort.size() && !(ysort[j].first < xsort[i].first)) + ++i; + + y_count[ysort[j].second] = i; + } + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline bool compare_first_reverse_second ( + const std::pair<double,bool>& a, + const std::pair<double,bool>& b + ) + { + if (a.first < b.first) + return true; + else if (a.first > b.first) + return false; + else if (a.second && !b.second) + return true; + else + return false; + } + } + + template < + typename ranking_function, + typename T + > + matrix<double,1,2> test_ranking_function ( + const ranking_function& funct, + const std::vector<ranking_pair<T> >& samples + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ranking_problem(samples), + "\t double test_ranking_function()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples) + ); + + unsigned long total_pairs = 0; + unsigned long total_wrong = 0; + + std::vector<double> rel_scores; + std::vector<double> nonrel_scores; + std::vector<unsigned long> rel_counts; + std::vector<unsigned long> nonrel_counts; + + running_stats<double> rs; + std::vector<std::pair<double,bool> > total_scores; + std::vector<bool> total_ranking; + + for (unsigned long i = 0; i < samples.size(); ++i) + { + rel_scores.resize(samples[i].relevant.size()); + nonrel_scores.resize(samples[i].nonrelevant.size()); + total_scores.clear(); + + for (unsigned long k = 0; k < rel_scores.size(); ++k) + { + rel_scores[k] = funct(samples[i].relevant[k]); + total_scores.push_back(std::make_pair(rel_scores[k], true)); + } + + for (unsigned long k = 0; k < nonrel_scores.size(); ++k) + { + nonrel_scores[k] = funct(samples[i].nonrelevant[k]); + total_scores.push_back(std::make_pair(nonrel_scores[k], false)); + } + + // Now compute the average precision for this sample. We need to sort the + // results and the back them into total_ranking. Note that we sort them so + // that, if you get a block of ranking values that are all equal, the elements + // marked as true will come last. This prevents a ranking from outputting a + // constant value for everything and still getting a good MAP score. + std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second); + total_ranking.clear(); + for (unsigned long i = 0; i < total_scores.size(); ++i) + total_ranking.push_back(total_scores[i].second); + rs.add(average_precision(total_ranking)); + + + count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts); + + total_pairs += rel_scores.size()*nonrel_scores.size(); + + // Note that we don't need to look at nonrel_counts since it is redundant with + // the information in rel_counts in this case. + total_wrong += sum(mat(rel_counts)); + } + + const double rank_swaps = static_cast<double>(total_pairs - total_wrong) / total_pairs; + const double mean_average_precision = rs.mean(); + matrix<double,1,2> res; + res = rank_swaps, mean_average_precision; + return res; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename ranking_function, + typename T + > + matrix<double,1,2> test_ranking_function ( + const ranking_function& funct, + const ranking_pair<T>& sample + ) + { + return test_ranking_function(funct, std::vector<ranking_pair<T> >(1,sample)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename trainer_type, + typename T + > + matrix<double,1,2> cross_validate_ranking_trainer ( + const trainer_type& trainer, + const std::vector<ranking_pair<T> >& samples, + const long folds + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ranking_problem(samples) && + 1 < folds && folds <= static_cast<long>(samples.size()), + "\t double cross_validate_ranking_trainer()" + << "\n\t invalid inputs were given to this function" + << "\n\t samples.size(): " << samples.size() + << "\n\t folds: " << folds + << "\n\t is_ranking_problem(samples): " << is_ranking_problem(samples) + ); + + + const long num_in_test = samples.size()/folds; + const long num_in_train = samples.size() - num_in_test; + + + std::vector<ranking_pair<T> > samples_test, samples_train; + + + long next_test_idx = 0; + + unsigned long total_pairs = 0; + unsigned long total_wrong = 0; + + std::vector<double> rel_scores; + std::vector<double> nonrel_scores; + std::vector<unsigned long> rel_counts; + std::vector<unsigned long> nonrel_counts; + + running_stats<double> rs; + std::vector<std::pair<double,bool> > total_scores; + std::vector<bool> total_ranking; + + for (long i = 0; i < folds; ++i) + { + samples_test.clear(); + samples_train.clear(); + + // load up the test samples + for (long cnt = 0; cnt < num_in_test; ++cnt) + { + samples_test.push_back(samples[next_test_idx]); + next_test_idx = (next_test_idx + 1)%samples.size(); + } + + // load up the training samples + long next = next_test_idx; + for (long cnt = 0; cnt < num_in_train; ++cnt) + { + samples_train.push_back(samples[next]); + next = (next + 1)%samples.size(); + } + + + const typename trainer_type::trained_function_type& df = trainer.train(samples_train); + + // check how good df is on the test data + for (unsigned long i = 0; i < samples_test.size(); ++i) + { + rel_scores.resize(samples_test[i].relevant.size()); + nonrel_scores.resize(samples_test[i].nonrelevant.size()); + + total_scores.clear(); + + for (unsigned long k = 0; k < rel_scores.size(); ++k) + { + rel_scores[k] = df(samples_test[i].relevant[k]); + total_scores.push_back(std::make_pair(rel_scores[k], true)); + } + + for (unsigned long k = 0; k < nonrel_scores.size(); ++k) + { + nonrel_scores[k] = df(samples_test[i].nonrelevant[k]); + total_scores.push_back(std::make_pair(nonrel_scores[k], false)); + } + + // Now compute the average precision for this sample. We need to sort the + // results and the back them into total_ranking. Note that we sort them so + // that, if you get a block of ranking values that are all equal, the elements + // marked as true will come last. This prevents a ranking from outputting a + // constant value for everything and still getting a good MAP score. + std::sort(total_scores.rbegin(), total_scores.rend(), impl::compare_first_reverse_second); + total_ranking.clear(); + for (unsigned long i = 0; i < total_scores.size(); ++i) + total_ranking.push_back(total_scores[i].second); + rs.add(average_precision(total_ranking)); + + + count_ranking_inversions(rel_scores, nonrel_scores, rel_counts, nonrel_counts); + + total_pairs += rel_scores.size()*nonrel_scores.size(); + + // Note that we don't need to look at nonrel_counts since it is redundant with + // the information in rel_counts in this case. + total_wrong += sum(mat(rel_counts)); + } + + } // for (long i = 0; i < folds; ++i) + + const double rank_swaps = static_cast<double>(total_pairs - total_wrong) / total_pairs; + const double mean_average_precision = rs.mean(); + matrix<double,1,2> res; + res = rank_swaps, mean_average_precision; + return res; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RANKING_ToOLS_Hh_ + |