summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/krr_trainer.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/svm/krr_trainer.h')
-rw-r--r--ml/dlib/dlib/svm/krr_trainer.h368
1 files changed, 368 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/krr_trainer.h b/ml/dlib/dlib/svm/krr_trainer.h
new file mode 100644
index 000000000..a43431169
--- /dev/null
+++ b/ml/dlib/dlib/svm/krr_trainer.h
@@ -0,0 +1,368 @@
+// Copyright (C) 2010 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_KRR_TRAInER_Hh_
+#define DLIB_KRR_TRAInER_Hh_
+
+#include "../algs.h"
+#include "function.h"
+#include "kernel.h"
+#include "empirical_kernel_map.h"
+#include "linearly_independent_subset_finder.h"
+#include "../statistics.h"
+#include "rr_trainer.h"
+#include "krr_trainer_abstract.h"
+#include <vector>
+#include <iostream>
+
+namespace dlib
+{
+ // krr_trainer<K>
+ //
+ // Trains a decision_function<K> in three steps (see do_train()):
+ //   1. build an empirical_kernel_map (ekm) from either a user supplied basis
+ //      or a basis picked by a linearly_independent_subset_finder,
+ //   2. project every training sample through that ekm,
+ //   3. hand the projected samples to an rr_trainer that works on a linear
+ //      kernel, then convert the resulting linear function back into a
+ //      kernelized decision function.
+ //
+ // The lambda, search-lambda and leave-one-out loss settings are not stored
+ // here; they are forwarded to, and read back from, the wrapped rr_trainer.
+ //
+ // NOTE(review): the train() overloads are const but update the mutable
+ // ekm / ekm_stale members, so calling train() concurrently on one object
+ // looks unsafe -- confirm before sharing an instance between threads.
+ template <
+     typename K
+     >
+ class krr_trainer
+ {
+
+ public:
+     typedef K kernel_type;
+     typedef typename kernel_type::scalar_type scalar_type;
+     typedef typename kernel_type::sample_type sample_type;
+     typedef typename kernel_type::mem_manager_type mem_manager_type;
+     typedef decision_function<kernel_type> trained_function_type;
+
+     // Starts out quiet, with a maximum automatic basis size of 400, no user
+     // supplied basis, and an ekm that is considered stale.
+     krr_trainer (
+     ) :
+         verbose(false),
+         max_basis_size(400),
+         ekm_stale(true)
+     {
+     }
+
+     // Turns on progress output, both here and in the wrapped rr_trainer.
+     void be_verbose (
+     )
+     {
+         verbose = true;
+         trainer.be_verbose();
+     }
+
+     // Turns off progress output, both here and in the wrapped rr_trainer.
+     void be_quiet (
+     )
+     {
+         verbose = false;
+         trainer.be_quiet();
+     }
+
+     // Forwarded to the wrapped rr_trainer.
+     void use_regression_loss_for_loo_cv (
+     )
+     {
+         trainer.use_regression_loss_for_loo_cv();
+     }
+
+     // Forwarded to the wrapped rr_trainer.
+     void use_classification_loss_for_loo_cv (
+     )
+     {
+         trainer.use_classification_loss_for_loo_cv();
+     }
+
+     // Reports the wrapped rr_trainer's current leave-one-out loss setting.
+     bool will_use_regression_loss_for_loo_cv (
+     ) const
+     {
+         return trainer.will_use_regression_loss_for_loo_cv();
+     }
+
+     // Returns a copy of the kernel currently in use.
+     const kernel_type get_kernel (
+     ) const
+     {
+         return kern;
+     }
+
+     // Replaces the kernel used to build the empirical kernel map.
+     void set_kernel (
+         const kernel_type& k
+     )
+     {
+         kern = k;
+         // The CONVENTION at the bottom of this class requires ekm_stale to be
+         // true whenever kern has changed since it was loaded into the ekm.
+         // Without this, training with a user supplied basis after changing
+         // the kernel would silently keep using an ekm built from the old one.
+         ekm_stale = true;
+     }
+
+     // Installs a user supplied set of basis samples.  While a basis is
+     // loaded, do_train() builds the ekm from it instead of selecting a basis
+     // automatically.
+     //
+     // requires (checked by DLIB_ASSERT)
+     //   - basis_samples.size() > 0
+     //   - is_vector(mat(basis_samples)) == true
+     template <typename T>
+     void set_basis (
+         const T& basis_samples
+     )
+     {
+         // make sure requires clause is not broken
+         DLIB_ASSERT(basis_samples.size() > 0 && is_vector(mat(basis_samples)),
+             "\tvoid krr_trainer::set_basis(basis_samples)"
+             << "\n\t You have to give a non-empty set of basis_samples and it must be a vector"
+             << "\n\t basis_samples.size(): " << basis_samples.size()
+             << "\n\t is_vector(mat(basis_samples)): " << is_vector(mat(basis_samples))
+             << "\n\t this: " << this
+             );
+
+         basis = mat(basis_samples);
+         // The ekm must be rebuilt from the new basis on the next train().
+         ekm_stale = true;
+     }
+
+     // True if a user supplied basis is currently stored.
+     bool basis_loaded (
+     ) const
+     {
+         return (basis.size() != 0);
+     }
+
+     // Drops the user supplied basis and the ekm built from it.  Subsequent
+     // training falls back to automatic basis selection.
+     void clear_basis (
+     )
+     {
+         basis.set_size(0);
+         ekm.clear();
+         ekm_stale = true;
+     }
+
+     // Upper bound passed to the linearly_independent_subset_finder when the
+     // basis is selected automatically.
+     unsigned long get_max_basis_size (
+     ) const
+     {
+         return max_basis_size;
+     }
+
+     // Sets the upper bound used for automatic basis selection.
+     //
+     // requires (checked by DLIB_ASSERT)
+     //   - max_basis_size_ > 0
+     void set_max_basis_size (
+         unsigned long max_basis_size_
+     )
+     {
+         // make sure requires clause is not broken
+         DLIB_ASSERT(max_basis_size_ > 0,
+             "\t void krr_trainer::set_max_basis_size()"
+             << "\n\t max_basis_size_ must be greater than 0"
+             << "\n\t max_basis_size_: " << max_basis_size_
+             << "\n\t this: " << this
+             );
+
+         max_basis_size = max_basis_size_;
+     }
+
+     // Forwards lambda_ to the wrapped rr_trainer.
+     //
+     // requires (checked by DLIB_ASSERT)
+     //   - lambda_ >= 0
+     void set_lambda (
+         scalar_type lambda_
+     )
+     {
+         // make sure requires clause is not broken
+         DLIB_ASSERT(lambda_ >= 0,
+             "\t void krr_trainer::set_lambda()"
+             << "\n\t lambda must be greater than or equal to 0"
+             << "\n\t lambda_: " << lambda_
+             << "\n\t this: " << this
+             );
+
+         trainer.set_lambda(lambda_);
+     }
+
+     // Returns the lambda currently held by the wrapped rr_trainer.
+     const scalar_type get_lambda (
+     ) const
+     {
+         return trainer.get_lambda();
+     }
+
+     // Forwards the candidate lambda values to the wrapped rr_trainer.
+     //
+     // requires (checked by DLIB_ASSERT)
+     //   - is_vector(lambdas) == true
+     //   - lambdas.size() > 0
+     //   - min(lambdas) > 0
+     template <typename EXP>
+     void set_search_lambdas (
+         const matrix_exp<EXP>& lambdas
+     )
+     {
+         // make sure requires clause is not broken
+         DLIB_ASSERT(is_vector(lambdas) && lambdas.size() > 0 && min(lambdas) > 0,
+             "\t void krr_trainer::set_search_lambdas()"
+             << "\n\t lambdas must be a non-empty vector of values"
+             << "\n\t is_vector(lambdas): " << is_vector(lambdas)
+             << "\n\t lambdas.size(): " << lambdas.size()
+             << "\n\t min(lambdas): " << min(lambdas)
+             << "\n\t this: " << this
+             );
+
+         trainer.set_search_lambdas(lambdas);
+     }
+
+     // Returns the candidate lambda values held by the wrapped rr_trainer.
+     const matrix<scalar_type,0,0,mem_manager_type>& get_search_lambdas (
+     ) const
+     {
+         return trainer.get_search_lambdas();
+     }
+
+     // Trains on (x, y) without reporting leave-one-out values.
+     template <
+         typename in_sample_vector_type,
+         typename in_scalar_vector_type
+         >
+     const decision_function<kernel_type> train (
+         const in_sample_vector_type& x,
+         const in_scalar_vector_type& y
+     ) const
+     {
+         // Placeholders only; do_train() ignores them when the LOO flag is false.
+         std::vector<scalar_type> temp;
+         scalar_type temp2;
+         return do_train(mat(x), mat(y), false, temp, temp2);
+     }
+
+     // Trains on (x, y) and lets the wrapped rr_trainer fill loo_values.
+     template <
+         typename in_sample_vector_type,
+         typename in_scalar_vector_type
+         >
+     const decision_function<kernel_type> train (
+         const in_sample_vector_type& x,
+         const in_scalar_vector_type& y,
+         std::vector<scalar_type>& loo_values
+     ) const
+     {
+         // The lambda reported by the trainer is not wanted by this overload.
+         scalar_type temp;
+         return do_train(mat(x), mat(y), true, loo_values, temp);
+     }
+
+     // Trains on (x, y) and lets the wrapped rr_trainer fill both loo_values
+     // and lambda_used.
+     template <
+         typename in_sample_vector_type,
+         typename in_scalar_vector_type
+         >
+     const decision_function<kernel_type> train (
+         const in_sample_vector_type& x,
+         const in_scalar_vector_type& y,
+         std::vector<scalar_type>& loo_values,
+         scalar_type& lambda_used
+     ) const
+     {
+         return do_train(mat(x), mat(y), true, loo_values, lambda_used);
+     }
+
+
+ private:
+
+     // Shared implementation behind all train() overloads.
+     //
+     //   x, y              - training samples and their targets
+     //   output_loo_values - when true, loo_values and the_lambda are passed on
+     //                       to the wrapped rr_trainer to be filled in; when
+     //                       false they are left untouched
+     //
+     // requires (checked by DLIB_ASSERT)
+     //   - is_learning_problem(x,y) == true
+     //   - if get_lambda() == 0 and the classification LOO loss is selected
+     //     then is_binary_classification_problem(x,y) == true
+     template <
+         typename in_sample_vector_type,
+         typename in_scalar_vector_type
+         >
+     const decision_function<kernel_type> do_train (
+         const in_sample_vector_type& x,
+         const in_scalar_vector_type& y,
+         const bool output_loo_values,
+         std::vector<scalar_type>& loo_values,
+         scalar_type& the_lambda
+     ) const
+     {
+         // make sure requires clause is not broken
+         DLIB_ASSERT(is_learning_problem(x,y),
+             "\t decision_function krr_trainer::train(x,y)"
+             << "\n\t invalid inputs were given to this function"
+             << "\n\t is_vector(x): " << is_vector(x)
+             << "\n\t is_vector(y): " << is_vector(y)
+             << "\n\t x.size(): " << x.size()
+             << "\n\t y.size(): " << y.size()
+             );
+
+ #ifdef ENABLE_ASSERTS
+         if (get_lambda() == 0 && will_use_regression_loss_for_loo_cv() == false)
+         {
+             // make sure requires clause is not broken
+             DLIB_ASSERT(is_binary_classification_problem(x,y),
+                 "\t decision_function krr_trainer::train(x,y)"
+                 << "\n\t invalid inputs were given to this function"
+                 );
+         }
+ #endif
+
+         // The first thing we do is make sure we have an appropriate ekm ready for use below.
+         if (basis_loaded())
+         {
+             // A user supplied basis only needs to be (re)loaded when the kernel
+             // or the basis changed since the last load.
+             if (ekm_stale)
+             {
+                 ekm.load(kern, basis);
+                 ekm_stale = false;
+             }
+         }
+         else
+         {
+             // No user basis: select one from the training samples, holding at
+             // most max_basis_size vectors.
+             linearly_independent_subset_finder<kernel_type> lisf(kern, max_basis_size);
+             fill_lisf(lisf, x);
+             ekm.load(lisf);
+         }
+
+         if (verbose)
+         {
+             std::cout << "\nNumber of basis vectors used: " << ekm.out_vector_size() << std::endl;
+         }
+
+         typedef matrix<scalar_type,0,1,mem_manager_type> column_matrix_type;
+
+         running_stats<scalar_type> rs;
+
+         // Now we project all the x samples into kernel space using our EKM
+         matrix<column_matrix_type,0,1,mem_manager_type > proj_x;
+         proj_x.set_size(x.size());
+         for (long i = 0; i < proj_x.size(); ++i)
+         {
+             scalar_type err;
+             // The projection error is only requested in verbose mode, where it
+             // is accumulated so its statistics can be printed below.
+             if (verbose == false)
+             {
+                 proj_x(i) = ekm.project(x(i));
+             }
+             else
+             {
+                 proj_x(i) = ekm.project(x(i),err);
+                 rs.add(err);
+             }
+         }
+
+         if (verbose)
+         {
+             std::cout << "Mean EKM projection error: " << rs.mean() << std::endl;
+             // A standard deviation needs more than one sample, so skip the
+             // line instead of querying it for a single-sample training set.
+             if (proj_x.size() > 1)
+             {
+                 std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl;
+             }
+         }
+
+
+         decision_function<linear_kernel<matrix<scalar_type,0,0,mem_manager_type> > > lin_df;
+
+         if (output_loo_values)
+             lin_df = trainer.train(proj_x,y, loo_values, the_lambda);
+         else
+             lin_df = trainer.train(proj_x,y);
+
+         // convert the linear decision function into a kernelized one.
+         decision_function<kernel_type> df;
+         df = ekm.convert_to_decision_function(lin_df.basis_vectors(0));
+         df.b = lin_df.b;
+
+         // If we used an automatically derived basis then there isn't any point in
+         // keeping the ekm around. So free its memory.
+         if (basis_loaded() == false)
+         {
+             ekm.clear();
+         }
+
+         return df;
+     }
+
+
+     /*!
+         CONVENTION
+             - if (ekm_stale) then
+                 - kern or basis have changed since the last time
+                   they were loaded into the ekm
+
+             - get_lambda() == trainer.get_lambda()
+             - get_kernel() == kern
+             - get_max_basis_size() == max_basis_size
+             - will_use_regression_loss_for_loo_cv() == trainer.will_use_regression_loss_for_loo_cv()
+             - get_search_lambdas() == trainer.get_search_lambdas()
+
+             - basis_loaded() == (basis.size() != 0)
+     !*/
+
+     // Linear trainer that does the actual fitting on the projected samples.
+     rr_trainer<linear_kernel<matrix<scalar_type,0,0,mem_manager_type> > > trainer;
+
+     // When true, do_train() prints basis size and projection error statistics.
+     bool verbose;
+
+
+     kernel_type kern;
+     unsigned long max_basis_size;
+
+     // User supplied basis; empty means "select a basis automatically".
+     matrix<sample_type,0,1,mem_manager_type> basis;
+     // Mutable so the const train() overloads can (re)build and clear the map.
+     mutable empirical_kernel_map<kernel_type> ekm;
+     mutable bool ekm_stale;
+
+ };
+
+}
+
+#endif // DLIB_KRR_TRAInER_Hh_
+
+