diff options
Diffstat (limited to 'ml/dlib/dlib/test/elastic_net.cpp')
-rw-r--r-- | ml/dlib/dlib/test/elastic_net.cpp | 122 |
1 files changed, 122 insertions, 0 deletions
diff --git a/ml/dlib/dlib/test/elastic_net.cpp b/ml/dlib/dlib/test/elastic_net.cpp new file mode 100644 index 000000000..0e0501639 --- /dev/null +++ b/ml/dlib/dlib/test/elastic_net.cpp @@ -0,0 +1,122 @@ +// Copyright (C) 2016 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include <dlib/optimization/elastic_net.h> +#include "tester.h" +#include <dlib/svm.h> +#include <dlib/rand.h> +#include <dlib/string.h> +#include <vector> +#include <sstream> +#include <ctime> + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.elastic_net"); + +// ---------------------------------------------------------------------------------------- + + matrix<double,0,1> basic_elastic_net( + const matrix<double>& X, + const matrix<double,0,1>& Y, + double ridge_lambda, + double lasso_budget, + double eps + ) + { + DLIB_CASSERT(X.nc() == Y.nr(),""); + + + typedef matrix<double,0,1> sample_type; + typedef linear_kernel<sample_type> kernel_type; + + svm_c_linear_dcd_trainer<kernel_type> trainer; + trainer.solve_svm_l2_problem(true); + const double C = 1/(2*ridge_lambda); + trainer.set_c(C); + trainer.set_epsilon(eps); + trainer.enable_shrinking(true); + trainer.include_bias(false); + + + std::vector<sample_type> samples; + std::vector<double> labels; + for (long r = 0; r < X.nr(); ++r) + { + sample_type temp = trans(rowm(X,r)); + + const double xmul = (1/lasso_budget); + samples.push_back(temp - xmul*Y); + labels.push_back(+1); + samples.push_back(temp + xmul*Y); + labels.push_back(-1); + } + + svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state; + auto df = trainer.train(samples, labels, state); + auto&& alpha = state.get_alpha(); + + matrix<double,0,1> betas(alpha.size()/2); + for (long i = 0; i < betas.size(); ++i) + betas(i) = lasso_budget*(alpha[2*i] - alpha[2*i+1]); + betas /= sum(mat(alpha)); + return betas; + } + +// ---------------------------------------------------------------------------------------- + + class test_elastic_net : public tester + { + public: + test_elastic_net ( + ) : + tester ( + "test_elastic_net", + "Run tests on the elastic_net object.", + 0 + ) + { + } + + void perform_test ( + ) + { + matrix<double> w = {1,2,0,4, 0,0,0,0,0, 6, 7,8,0, 9, 0}; + + matrix<double> X = randm(w.size(),1000); + matrix<double> Y = trans(X)*w; + Y += 0.1*(randm(Y.nr(), Y.nc())-0.5); + + + double ridge_lambda = 0.1; + double lasso_budget = sum(abs(w)); + double eps = 0.0000001; + + dlib::elastic_net solver(X*trans(X),X*Y); + solver.set_epsilon(eps); + + + matrix<double,0,1> results; + matrix<double,0,1> results2; + for (double s = 1.2; s > 0.10; s *= 0.9) + { + print_spinner(); + dlog << LINFO << "s: "<< s; + // make sure the two solvers agree. + results = basic_elastic_net(X, Y, ridge_lambda, lasso_budget*s, eps); + results2 = solver(ridge_lambda, lasso_budget*s); + dlog << LINFO << "error: "<< max(abs(results - results2)); + DLIB_TEST(max(abs(results - results2)) < 1e-3); + } + } + } a; + +// ---------------------------------------------------------------------------------------- + +} + + + |