diff options
Diffstat (limited to 'ml/dlib/dlib/svm/svm_c_ekm_trainer.h')
-rw-r--r-- | ml/dlib/dlib/svm/svm_c_ekm_trainer.h | 636 |
1 file changed, 636 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/svm_c_ekm_trainer.h b/ml/dlib/dlib/svm/svm_c_ekm_trainer.h new file mode 100644 index 000000000..735e0f22e --- /dev/null +++ b/ml/dlib/dlib/svm/svm_c_ekm_trainer.h @@ -0,0 +1,636 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SVM_C_EKm_TRAINER_Hh_ +#define DLIB_SVM_C_EKm_TRAINER_Hh_ + +#include "../algs.h" +#include "function.h" +#include "kernel.h" +#include "empirical_kernel_map.h" +#include "svm_c_linear_trainer.h" +#include "svm_c_ekm_trainer_abstract.h" +#include "../statistics.h" +#include "../rand.h" +#include <vector> + +namespace dlib +{ + template < + typename K + > + class svm_c_ekm_trainer + { + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function<kernel_type> trained_function_type; + + svm_c_ekm_trainer ( + ) + { + verbose = false; + ekm_stale = true; + + initial_basis_size = 10; + basis_size_increment = 50; + max_basis_size = 300; + } + + explicit svm_c_ekm_trainer ( + const scalar_type& C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t svm_c_ekm_trainer::svm_c_ekm_trainer()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + + ocas.set_c(C); + verbose = false; + ekm_stale = true; + + initial_basis_size = 10; + basis_size_increment = 50; + max_basis_size = 300; + } + + void set_epsilon ( + scalar_type eps + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps > 0, + "\t void svm_c_ekm_trainer::set_epsilon()" + << "\n\t eps must be greater than 0" + << "\n\t eps: " << eps + << "\n\t this: " << this + ); + + ocas.set_epsilon(eps); + } + + const scalar_type get_epsilon ( + ) const + { + return ocas.get_epsilon(); + } + + void set_max_iterations ( + 
unsigned long max_iter + ) + { + ocas.set_max_iterations(max_iter); + } + + unsigned long get_max_iterations ( + ) + { + return ocas.get_max_iterations(); + } + + void be_verbose ( + ) + { + verbose = true; + ocas.be_quiet(); + } + + void be_very_verbose ( + ) + { + verbose = true; + ocas.be_verbose(); + } + + void be_quiet ( + ) + { + verbose = false; + ocas.be_quiet(); + } + + void set_oca ( + const oca& item + ) + { + ocas.set_oca(item); + } + + const oca get_oca ( + ) const + { + return ocas.get_oca(); + } + + const kernel_type get_kernel ( + ) const + { + return kern; + } + + void set_kernel ( + const kernel_type& k + ) + { + kern = k; + ekm_stale = true; + } + + template <typename T> + void set_basis ( + const T& basis_samples + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(basis_samples.size() > 0 && is_vector(mat(basis_samples)), + "\tvoid svm_c_ekm_trainer::set_basis(basis_samples)" + << "\n\t You have to give a non-empty set of basis_samples and it must be a vector" + << "\n\t basis_samples.size(): " << basis_samples.size() + << "\n\t is_vector(mat(basis_samples)): " << is_vector(mat(basis_samples)) + << "\n\t this: " << this + ); + + basis = mat(basis_samples); + ekm_stale = true; + } + + bool basis_loaded( + ) const + { + return (basis.size() != 0); + } + + void clear_basis ( + ) + { + basis.set_size(0); + ekm.clear(); + ekm_stale = true; + } + + unsigned long get_max_basis_size ( + ) const + { + return max_basis_size; + } + + void set_max_basis_size ( + unsigned long max_basis_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(max_basis_size_ > 0, + "\t void svm_c_ekm_trainer::set_max_basis_size()" + << "\n\t max_basis_size_ must be greater than 0" + << "\n\t max_basis_size_: " << max_basis_size_ + << "\n\t this: " << this + ); + + max_basis_size = max_basis_size_; + if (initial_basis_size > max_basis_size) + initial_basis_size = max_basis_size; + } + + unsigned long get_initial_basis_size ( + ) const + { + 
return initial_basis_size; + } + + void set_initial_basis_size ( + unsigned long initial_basis_size_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(initial_basis_size_ > 0, + "\t void svm_c_ekm_trainer::set_initial_basis_size()" + << "\n\t initial_basis_size_ must be greater than 0" + << "\n\t initial_basis_size_: " << initial_basis_size_ + << "\n\t this: " << this + ); + + initial_basis_size = initial_basis_size_; + + if (initial_basis_size > max_basis_size) + max_basis_size = initial_basis_size; + } + + unsigned long get_basis_size_increment ( + ) const + { + return basis_size_increment; + } + + void set_basis_size_increment ( + unsigned long basis_size_increment_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(basis_size_increment_ > 0, + "\t void svm_c_ekm_trainer::set_basis_size_increment()" + << "\n\t basis_size_increment_ must be greater than 0" + << "\n\t basis_size_increment_: " << basis_size_increment_ + << "\n\t this: " << this + ); + + basis_size_increment = basis_size_increment_; + } + + void set_c ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_ekm_trainer::set_c()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + ocas.set_c(C); + } + + const scalar_type get_c_class1 ( + ) const + { + return ocas.get_c_class1(); + } + + const scalar_type get_c_class2 ( + ) const + { + return ocas.get_c_class2(); + } + + void set_c_class1 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_ekm_trainer::set_c_class1()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << this + ); + + ocas.set_c_class1(C); + } + + void set_c_class2 ( + scalar_type C + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C > 0, + "\t void svm_c_ekm_trainer::set_c_class2()" + << "\n\t C must be greater than 0" + << "\n\t C: " << C + << "\n\t this: " << 
this + ); + + ocas.set_c_class2(C); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function<kernel_type> train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + { + scalar_type obj; + if (basis_loaded()) + return do_train_user_basis(mat(x),mat(y),obj); + else + return do_train_auto_basis(mat(x),mat(y),obj); + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function<kernel_type> train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const + { + if (basis_loaded()) + return do_train_user_basis(mat(x),mat(y),svm_objective); + else + return do_train_auto_basis(mat(x),mat(y),svm_objective); + } + + + private: + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function<kernel_type> do_train_user_basis ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const + /*! 
+ requires + - basis_loaded() == true + ensures + - trains an SVM with the user supplied basis + !*/ + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true, + "\t decision_function svm_c_ekm_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y) + ); + + if (ekm_stale) + { + ekm.load(kern, basis); + ekm_stale = false; + } + + // project all the samples with the ekm + running_stats<scalar_type> rs; + std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples; + proj_samples.reserve(x.size()); + for (long i = 0; i < x.size(); ++i) + { + if (verbose) + { + scalar_type err; + proj_samples.push_back(ekm.project(x(i), err)); + rs.add(err); + } + else + { + proj_samples.push_back(ekm.project(x(i))); + } + } + + if (verbose) + { + std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl; + std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl; + } + + // now do the training + decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df; + df = ocas.train(proj_samples, y, svm_objective); + + if (verbose) + { + std::cout << "Final svm objective: " << svm_objective << std::endl; + } + + decision_function<kernel_type> final_df; + final_df = ekm.convert_to_decision_function(df.basis_vectors(0)); + final_df.b = df.b; + return final_df; + } + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function<kernel_type> do_train_auto_basis ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y, + scalar_type& svm_objective + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_binary_classification_problem(x,y) == true, + "\t 
decision_function svm_c_ekm_trainer::train(x,y)" + << "\n\t invalid inputs were given to this function" + << "\n\t x.nr(): " << x.nr() + << "\n\t y.nr(): " << y.nr() + << "\n\t x.nc(): " << x.nc() + << "\n\t y.nc(): " << y.nc() + << "\n\t is_binary_classification_problem(x,y): " << is_binary_classification_problem(x,y) + ); + + + std::vector<matrix<scalar_type,0,1, mem_manager_type> > proj_samples(x.size()); + decision_function<linear_kernel<matrix<scalar_type,0,1, mem_manager_type> > > df; + + // we will use a linearly_independent_subset_finder to store our basis set. + linearly_independent_subset_finder<kernel_type> lisf(get_kernel(), max_basis_size); + + dlib::rand rnd; + + // first pick the initial basis set randomly + for (unsigned long i = 0; i < 10*initial_basis_size && lisf.size() < initial_basis_size; ++i) + { + lisf.add(x(rnd.get_random_32bit_number()%x.size())); + } + + ekm.load(lisf); + + // first project all samples into the span of the current basis + for (long i = 0; i < x.size(); ++i) + { + proj_samples[i] = ekm.project(x(i)); + } + + + svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > trainer(ocas); + + const scalar_type min_epsilon = trainer.get_epsilon(); + // while we are determining what the basis set will be we are going to use a very + // lose stopping condition. We will tighten it back up before producing the + // final decision_function. + trainer.set_epsilon(0.2); + + scalar_type prev_svm_objective = std::numeric_limits<scalar_type>::max(); + + empirical_kernel_map<kernel_type> prev_ekm; + + // This loop is where we try to generate a basis for SVM training. We will + // do this by repeatedly training the SVM and adding a few points which violate the + // margin to the basis in each iteration. + while (true) + { + // if the basis is already as big as it's going to get then just do the most + // accurate training right now. 
+ if (lisf.size() == max_basis_size) + trainer.set_epsilon(min_epsilon); + + while (true) + { + // now do the training. + df = trainer.train(proj_samples, y, svm_objective); + + if (svm_objective < prev_svm_objective) + break; + + // If the training didn't reduce the objective more than last time then + // try lowering the epsilon and doing it again. + if (trainer.get_epsilon() > min_epsilon) + { + trainer.set_epsilon(std::max(trainer.get_epsilon()*0.5, min_epsilon)); + if (verbose) + std::cout << " *** Reducing epsilon to " << trainer.get_epsilon() << std::endl; + } + else + break; + } + + if (verbose) + { + std::cout << "svm objective: " << svm_objective << std::endl; + std::cout << "basis size: " << lisf.size() << std::endl; + } + + // if we failed to make progress on this iteration then we are done + if (svm_objective >= prev_svm_objective) + break; + + prev_svm_objective = svm_objective; + + // now add more elements to the basis + unsigned long count = 0; + for (unsigned long j = 0; + (j < 100*basis_size_increment) && (count < basis_size_increment) && (lisf.size() < max_basis_size); + ++j) + { + // pick a random sample + const unsigned long idx = rnd.get_random_32bit_number()%x.size(); + // If it is a margin violator then it is useful to add it into the basis set. + if (df(proj_samples[idx])*y(idx) < 1) + { + // Add the sample into the basis set if it is linearly independent of all the + // vectors already in the basis set. + if (lisf.add(x(idx))) + { + ++count; + } + } + } + // if we couldn't add any more basis vectors then stop + if (count == 0) + { + if (verbose) + std::cout << "Stopping, couldn't add more basis vectors." << std::endl; + break; + } + + + // Project all the samples into the span of our newly enlarged basis. We will do this + // using the special transformation in the EKM that lets us project from a smaller + // basis set to a larger without needing to reevaluate kernel functions we have already + // computed. 
+ ekm.swap(prev_ekm); + ekm.load(lisf); + projection_function<kernel_type> proj_part; + matrix<double> prev_to_new; + prev_ekm.get_transformation_to(ekm, prev_to_new, proj_part); + + + matrix<scalar_type,0,1, mem_manager_type> temp; + for (long i = 0; i < x.size(); ++i) + { + // assign to temporary to avoid memory allocation that would result if we + // assigned this expression straight into proj_samples[i] + temp = prev_to_new*proj_samples[i] + proj_part(x(i)); + proj_samples[i] = temp; + + } + } + + // Reproject all the data samples using the final basis. We could just use what we + // already have but the recursive thing done above to compute the proj_samples + // might have accumulated a little numerical error. So lets just be safe. + running_stats<scalar_type> rs, rs_margin; + for (long i = 0; i < x.size(); ++i) + { + if (verbose) + { + scalar_type err; + proj_samples[i] = ekm.project(x(i),err); + rs.add(err); + // if this point is within the margin + if (df(proj_samples[i])*y(i) < 1) + rs_margin.add(err); + } + else + { + proj_samples[i] = ekm.project(x(i)); + } + } + + // do the final training + trainer.set_epsilon(min_epsilon); + df = trainer.train(proj_samples, y, svm_objective); + + + if (verbose) + { + std::cout << "\nMean EKM projection error: " << rs.mean() << std::endl; + std::cout << "Standard deviation of EKM projection error: " << rs.stddev() << std::endl; + std::cout << "Mean EKM projection error for margin violators: " << rs_margin.mean() << std::endl; + std::cout << "Standard deviation of EKM projection error for margin violators: " << ((rs_margin.current_n()>1)?rs_margin.stddev():0) << std::endl; + + std::cout << "Final svm objective: " << svm_objective << std::endl; + } + + + decision_function<kernel_type> final_df; + final_df = ekm.convert_to_decision_function(df.basis_vectors(0)); + final_df.b = df.b; + + // we don't need the ekm anymore so clear it out + ekm.clear(); + + return final_df; + } + + + + + /*! 
+ CONVENTION + - if (ekm_stale) then + - kern or basis have changed since the last time + they were loaded into the ekm + !*/ + + svm_c_linear_trainer<linear_kernel<matrix<scalar_type,0,1,mem_manager_type> > > ocas; + bool verbose; + + kernel_type kern; + unsigned long max_basis_size; + unsigned long basis_size_increment; + unsigned long initial_basis_size; + + + matrix<sample_type,0,1,mem_manager_type> basis; + mutable empirical_kernel_map<kernel_type> ekm; + mutable bool ekm_stale; + + }; + +} + +#endif // DLIB_SVM_C_EKm_TRAINER_Hh_ + + + |