diff options
Diffstat (limited to 'ml/dlib/dlib/statistics/lda.h')
-rw-r--r-- | ml/dlib/dlib/statistics/lda.h | 237 |
1 files changed, 237 insertions, 0 deletions
diff --git a/ml/dlib/dlib/statistics/lda.h b/ml/dlib/dlib/statistics/lda.h new file mode 100644 index 000000000..38de3fd1e --- /dev/null +++ b/ml/dlib/dlib/statistics/lda.h @@ -0,0 +1,237 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_LDA_Hh_ +#define DLIB_LDA_Hh_ + +#include "lda_abstract.h" +#include "../algs.h" +#include <map> +#include "../matrix.h" +#include <vector> + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + + inline std::map<unsigned long,unsigned long> make_class_labels( + const std::vector<unsigned long>& row_labels + ) + { + std::map<unsigned long,unsigned long> class_labels; + for (unsigned long i = 0; i < row_labels.size(); ++i) + { + const unsigned long next = class_labels.size(); + if (class_labels.count(row_labels[i]) == 0) + class_labels[row_labels[i]] = next; + } + return class_labels; + } + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + matrix<T,0,1> center_matrix ( + matrix<T>& X + ) + { + matrix<T,1> mean; + for (long r = 0; r < X.nr(); ++r) + mean += rowm(X,r); + mean /= X.nr(); + + for (long r = 0; r < X.nr(); ++r) + set_rowm(X,r) -= mean; + + return trans(mean); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void compute_lda_transform ( + matrix<T>& X, + matrix<T,0,1>& mean, + const std::vector<unsigned long>& row_labels, + unsigned long lda_dims = 500, + unsigned long extra_pca_dims = 200 + ) + { + std::map<unsigned long,unsigned long> class_labels = impl::make_class_labels(row_labels); + // LDA can only give out at most class_labels.size()-1 dimensions so don't try to + // compute more than that. + lda_dims = std::min<unsigned long>(lda_dims, class_labels.size()-1); + + // make sure requires clause is not broken + DLIB_CASSERT(class_labels.size() > 1, + "\t void compute_lda_transform()" + << "\n\t You can't call this function if the number of distinct class labels is less than 2." + ); + DLIB_CASSERT(X.size() != 0 && (long)row_labels.size() == X.nr() && lda_dims != 0, + "\t void compute_lda_transform()" + << "\n\t Invalid inputs were given to this function." + << "\n\t X.size(): " << X.size() + << "\n\t row_labels.size(): " << row_labels.size() + << "\n\t lda_dims: " << lda_dims + ); + + + mean = impl::center_matrix(X); + // Do PCA to reduce dims + matrix<T> pu,pw,pv; + svd_fast(X, pu, pw, pv, lda_dims+extra_pca_dims, 4); + pu.set_size(0,0); // free RAM, we don't need pu. + X = X*pv; + + + matrix<T> class_means(class_labels.size(), X.nc()); + class_means = 0; + matrix<T,0,1> class_counts(class_labels.size()); + class_counts = 0; + + // First compute the means of each class + for (unsigned long i = 0; i < row_labels.size(); ++i) + { + const unsigned long class_idx = class_labels[row_labels[i]]; + set_rowm(class_means,class_idx) += rowm(X,i); + class_counts(class_idx)++; + } + class_means = inv(diagm(class_counts))*class_means; + // subtract means from the data + for (unsigned long i = 0; i < row_labels.size(); ++i) + { + const unsigned long class_idx = class_labels[row_labels[i]]; + set_rowm(X,i) -= rowm(class_means,class_idx); + } + + // Note that we are using the formulas from the paper Using Discriminant + // Eigenfeatures for Image Retrieval by Swets and Weng. + matrix<T> Sw = trans(X)*X; + matrix<T> Sb = trans(class_means)*class_means; + matrix<T> A, H; + matrix<T,0,1> W; + svd3(Sw, A, W, H); + W = sqrt(W); + W = reciprocal(lowerbound(W,max(W)*1e-5)); + A = trans(H*diagm(W))*Sb*H*diagm(W); + matrix<T> v,s,u; + svd3(A, v, s, u); + matrix<T> tform = H*diagm(W)*u; + // pick out only the number of dimensions we are supposed to for the output, unless + // we should just keep them all, then don't do anything. + if ((long)lda_dims <= tform.nc()) + { + rsort_columns(tform, s); + tform = colm(tform, range(0, lda_dims-1)); + } + + X = trans(pv*tform); + mean = X*mean; + } + +// ---------------------------------------------------------------------------------------- + + inline std::pair<double,double> equal_error_rate ( + const std::vector<double>& low_vals, + const std::vector<double>& high_vals + ) + { + std::vector<std::pair<double,int> > temp; + temp.reserve(low_vals.size()+high_vals.size()); + for (unsigned long i = 0; i < low_vals.size(); ++i) + temp.push_back(std::make_pair(low_vals[i], -1)); + for (unsigned long i = 0; i < high_vals.size(); ++i) + temp.push_back(std::make_pair(high_vals[i], +1)); + + std::sort(temp.begin(), temp.end()); + + if (temp.size() == 0) + return std::make_pair(0,0); + + double thresh = temp[0].first; + + unsigned long num_low_wrong = low_vals.size(); + unsigned long num_high_wrong = 0; + double low_error = num_low_wrong/(double)low_vals.size(); + double high_error = num_high_wrong/(double)high_vals.size(); + for (unsigned long i = 0; i < temp.size() && high_error < low_error; ++i) + { + thresh = temp[i].first; + if (temp[i].second > 0) + { + num_high_wrong++; + high_error = num_high_wrong/(double)high_vals.size(); + } + else + { + num_low_wrong--; + low_error = num_low_wrong/(double)low_vals.size(); + } + } + + return std::make_pair((low_error+high_error)/2, thresh); + } + +// ---------------------------------------------------------------------------------------- + + struct roc_point + { + double true_positive_rate; + double false_positive_rate; + double detection_threshold; + }; + + inline std::vector<roc_point> compute_roc_curve ( + const std::vector<double>& true_detections, + const std::vector<double>& false_detections + ) + { + DLIB_CASSERT(true_detections.size() != 0); + DLIB_CASSERT(false_detections.size() != 0); + + std::vector<std::pair<double,int> > temp; + temp.reserve(true_detections.size()+false_detections.size()); + for (unsigned long i = 0; i < true_detections.size(); ++i) + temp.push_back(std::make_pair(true_detections[i], +1)); + for (unsigned long i = 0; i < false_detections.size(); ++i) + temp.push_back(std::make_pair(false_detections[i], -1)); + + std::sort(temp.rbegin(), temp.rend()); + + + std::vector<roc_point> roc_curve; + roc_curve.reserve(temp.size()); + + double num_false_included = 0; + double num_true_included = 0; + for (unsigned long i = 0; i < temp.size(); ++i) + { + if (temp[i].second > 0) + num_true_included++; + else + num_false_included++; + + roc_point p; + p.true_positive_rate = num_true_included/true_detections.size(); + p.false_positive_rate = num_false_included/false_detections.size(); + p.detection_threshold = temp[i].first; + roc_curve.push_back(p); + } + + return roc_curve; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_LDA_Hh_ + |