summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/sort_basis_vectors.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/svm/sort_basis_vectors.h')
-rw-r--r--ml/dlib/dlib/svm/sort_basis_vectors.h224
1 files changed, 224 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/sort_basis_vectors.h b/ml/dlib/dlib/svm/sort_basis_vectors.h
new file mode 100644
index 000000000..1d4605b41
--- /dev/null
+++ b/ml/dlib/dlib/svm/sort_basis_vectors.h
@@ -0,0 +1,224 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SORT_BASIS_VECTORs_Hh_
+#define DLIB_SORT_BASIS_VECTORs_Hh_
+
+#include <vector>
+
+#include "sort_basis_vectors_abstract.h"
+#include "../matrix.h"
+#include "../statistics.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace bs_impl
+ {
+ template <typename EXP>
+ typename EXP::matrix_type invert (
+     const matrix_exp<EXP>& m
+ )
+ {
+     typedef typename EXP::type scalar_type;
+     typedef typename EXP::mem_manager_type mm_type;
+
+     // Decompose the symmetrized input and take a modifiable copy of its eigenvalues.
+     eigenvalue_decomposition<EXP> decomp(make_symmetric(m));
+     matrix<scalar_type,0,1,mm_type> spectrum = decomp.get_real_eigenvalues();
+
+     // Anything smaller in magnitude than sqrt(machine epsilon) times the largest
+     // eigenvalue magnitude is treated as numerically zero.
+     const scalar_type largest = max(abs(spectrum));
+     const scalar_type cutoff = largest*std::sqrt(std::numeric_limits<scalar_type>::epsilon());
+
+     // m may be singular or nearly so.  A textbook pseudo-inverse would zero out the tiny
+     // eigenvalues, but the result has to stay full rank, so each tiny eigenvalue is
+     // replaced by the largest magnitude instead, which keeps the inversion stable.
+     for (long idx = 0; idx != spectrum.size(); ++idx)
+     {
+         if (std::abs(spectrum(idx)) < cutoff)
+             spectrum(idx) = largest;
+     }
+
+     // Rebuild the inverse from the eigenvectors and the reciprocal eigenvalues.
+     return make_symmetric(decomp.get_pseudo_v()*diagm(reciprocal(spectrum))*trans(decomp.get_pseudo_v()));
+ }
+
+// ----------------------------------------------------------------------------------------
+ // Ranks the candidate basis vectors by their contribution to a Fisher linear discriminant between the two label classes and returns the best_size highest-ranked ones, most important first.
+ template <
+ typename kernel_type, // must provide scalar_type, mem_manager_type and sample_type
+ typename vect1_type, // type of samples; used through size() and operator()
+ typename vect2_type, // type of labels; used through operator() and comparisons against 0
+ typename vect3_type // type of basis; used through size() and operator()
+ >
+ const std::vector<typename kernel_type::sample_type> sort_basis_vectors_impl (
+ const kernel_type& kern, // kernel handed to kernel_matrix() to project each sample onto the basis
+ const vect1_type& samples, // training samples, read as samples(i)
+ const vect2_type& labels, // labels(i) > 0 selects class 1, anything else class 2
+ const vect3_type& basis, // candidate basis vectors; asserted to be non-empty
+ double eps // asserted 0 < eps <= 1; share of the squared weight length the kept features must exceed
+ )
+ {
+ DLIB_ASSERT(is_binary_classification_problem(samples, labels) &&
+ 0 < eps && eps <= 1 &&
+ basis.size() > 0,
+ "\t void sort_basis_vectors()"
+ << "\n\t Invalid arguments were given to this function."
+ << "\n\t is_binary_classification_problem(samples, labels): " << is_binary_classification_problem(samples, labels)
+ << "\n\t basis.size(): " << basis.size()
+ << "\n\t eps: " << eps
+ );
+
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::mem_manager_type mm_type;
+
+ typedef matrix<scalar_type,0,1,mm_type> col_matrix;
+ typedef matrix<scalar_type,0,0,mm_type> gen_matrix;
+
+ col_matrix c1_mean, c2_mean, temp, delta;
+ // NOTE(review): c1_mean and c2_mean are never sized explicitly; the += below presumably copes with the empty first accumulation -- confirm.
+
+ col_matrix weights;
+
+ running_covariance<gen_matrix> cov;
+
+ // Project every sample onto the basis, accumulating the overall covariance and the per-class sums.
+ for (long i = 0; i < samples.size(); ++i)
+ {
+ temp = kernel_matrix(kern, basis, samples(i));
+ cov.add(temp);
+ if (labels(i) > 0)
+ c1_mean += temp;
+ else
+ c2_mean += temp;
+ }
+ // Turn the per-class sums into means.  NOTE(review): assumes both counts are nonzero, presumably guaranteed by is_binary_classification_problem() -- confirm.
+ c1_mean /= sum(labels > 0);
+ c2_mean /= sum(labels < 0);
+
+ delta = c1_mean - c2_mean;
+
+ gen_matrix cov_inv = bs_impl::invert(cov.covariance());
+
+ // total_perm accumulates every reordering applied so far; perm holds the reordering of the current round only.
+ matrix<long,0,1,mm_type> total_perm = trans(range(0, delta.size()-1));
+ matrix<long,0,1,mm_type> perm = total_perm;
+
+ std::vector<std::pair<scalar_type,long> > sorted_feats(delta.size()); // (|weight|, position) pairs
+
+ long best_size = delta.size(); // fewest features seen so far that pass the eps test
+ long misses = 0; // rounds since best_size last improved
+ matrix<long,0,1,mm_type> best_total_perm = perm; // total_perm as of the round that set best_size
+
+ // Now we basically find fisher's linear discriminant over and over. Each
+ // time sorting the features so that the most important ones pile up together.
+ weights = trans(chol(cov_inv))*delta;
+ while (true)
+ {
+
+ for (unsigned long i = 0; i < sorted_feats.size(); ++i)
+ sorted_feats[i] = make_pair(std::abs(weights(i)), i);
+ // std::sort orders the pairs ascending, so the positions with the largest |weight| end up last.
+ std::sort(sorted_feats.begin(), sorted_feats.end());
+
+ // make a permutation vector according to the sorting
+ for (long i = 0; i < perm.size(); ++i)
+ perm(i) = sorted_feats[i].second;
+
+ // NOTE(review): cov_inv, delta and total_perm appear on both sides of the assignments below; presumably dlib's matrix assignment copes with the aliasing -- confirm.
+ // Apply the permutation. Doing this gives the same result as permuting all the
+ // features and then recomputing the delta and cov_inv from scratch.
+ cov_inv = subm(cov_inv,perm,perm);
+ delta = rowm(delta,perm);
+
+ // Record all the permutations we have done so we will know how the final
+ // weights match up with the original basis vectors when we are done.
+ total_perm = rowm(total_perm, perm);
+
+ // compute new Fisher weights for sorted features.
+ weights = trans(chol(cov_inv))*delta;
+
+ // Measure how many features it takes to account for more than a fraction eps of the squared length of the weights vector.
+ const scalar_type total_weight = length_squared(weights);
+ scalar_type weight_accum = 0;
+ long size = 0;
+ // Walk from the back, where the sort placed the highest-ranked features, and stop once the accumulated share exceeds eps.
+ for (long i = weights.size()-1; i >= 0; --i)
+ {
+ ++size;
+ weight_accum += weights(i)*weights(i);
+ if (weight_accum/total_weight > eps)
+ break;
+ }
+
+ // loop until the best_size stops dropping
+ if (size < best_size)
+ {
+ misses = 0;
+ best_size = size;
+ best_total_perm = total_perm;
+ }
+ else
+ {
+ ++misses;
+
+ // Give up once we have had 10 rounds where we didn't find a weights vector with
+ // a smaller concentration of good features.
+ if (misses >= 10)
+ break;
+ }
+
+ }
+
+ // make sure best_size isn't zero
+ if (best_size == 0)
+ best_size = 1;
+
+ std::vector<typename kernel_type::sample_type> sorted_basis;
+
+ // permute the basis so that it matches up with the contents of the best weights
+ sorted_basis.resize(best_size);
+ for (unsigned long i = 0; i < sorted_basis.size(); ++i)
+ {
+ // Note that we load sorted_basis backwards so that the most important
+ // basis elements come first.
+ sorted_basis[i] = basis(best_total_perm(basis.size()-i-1));
+ }
+
+ return sorted_basis;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+     typename kernel_type,
+     typename vect1_type,
+     typename vect2_type,
+     typename vect3_type
+     >
+ const std::vector<typename kernel_type::sample_type> sort_basis_vectors (
+     const kernel_type& kern,
+     const vect1_type& samples,
+     const vect2_type& labels,
+     const vect3_type& basis,
+     double eps = 0.99
+ )
+ {
+     // Public entry point.  Each of the three containers is wrapped with mat() and
+     // then handed, together with the kernel and eps, to the implementation in
+     // bs_impl, which reads them through size() and operator().  The contract is
+     // presumably documented in sort_basis_vectors_abstract.h -- confirm there.
+     return bs_impl::sort_basis_vectors_impl(kern, mat(samples), mat(labels), mat(basis), eps);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SORT_BASIS_VECTORs_Hh_
+