summary refs log tree commit diff stats
path: root/ml/dlib/dlib/test/elastic_net.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/test/elastic_net.cpp')
-rw-r--r-- ml/dlib/dlib/test/elastic_net.cpp | 122
1 files changed, 122 insertions, 0 deletions
diff --git a/ml/dlib/dlib/test/elastic_net.cpp b/ml/dlib/dlib/test/elastic_net.cpp
new file mode 100644
index 000000000..0e0501639
--- /dev/null
+++ b/ml/dlib/dlib/test/elastic_net.cpp
@@ -0,0 +1,122 @@
+// Copyright (C) 2016 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include <dlib/optimization/elastic_net.h>
+#include "tester.h"
+#include <dlib/svm.h>
+#include <dlib/rand.h>
+#include <dlib/string.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.elastic_net");
+
+// ----------------------------------------------------------------------------------------
+
+    // Reference implementation of the elastic net fit, used only as a
+    // slow-but-simple cross-check against the dedicated dlib::elastic_net
+    // solver in perform_test() below.  It reduces the constrained problem
+    // to an L2 SVM and runs dlib's dual coordinate descent SVM trainer.
+    // NOTE(review): the exact objective being minimized is inferred from
+    // the construction below — confirm against the dlib::elastic_net docs.
+    //
+    //  - X:            data matrix; each ROW is one variable/feature and each
+    //                  COLUMN one observation, so X.nc() must equal Y.nr().
+    //  - Y:            target value for each observation (column vector).
+    //  - ridge_lambda: strength of the L2 (ridge) penalty; mapped onto the
+    //                  SVM's C parameter below, so it must be > 0.
+    //  - lasso_budget: upper bound on the L1 norm of the returned weights.
+    //  - eps:          stopping tolerance handed to the SVM solver.
+    //
+    // Returns the estimated weight vector, one element per row of X.
+    matrix<double,0,1> basic_elastic_net(
+        const matrix<double>& X,
+        const matrix<double,0,1>& Y,
+        double ridge_lambda,
+        double lasso_budget,
+        double eps
+    )
+    {
+        // One target per observation (observations are the columns of X).
+        DLIB_CASSERT(X.nc() == Y.nr(),"");
+
+
+        typedef matrix<double,0,1> sample_type;
+        typedef linear_kernel<sample_type> kernel_type;
+
+        // Configure the dual coordinate descent SVM solver.  The ridge
+        // penalty becomes the SVM regularization via C = 1/(2*lambda),
+        // and no bias term is used since the elastic net has none.
+        svm_c_linear_dcd_trainer<kernel_type> trainer;
+        trainer.solve_svm_l2_problem(true);
+        const double C = 1/(2*ridge_lambda);
+        trainer.set_c(C);
+        trainer.set_epsilon(eps);
+        trainer.enable_shrinking(true);
+        trainer.include_bias(false);
+
+
+        // Build the SVM training set: each variable (row of X) contributes a
+        // +1/-1 pair of samples, temp -/+ (1/lasso_budget)*Y.  The paired
+        // samples encode the positive and negative sign a coefficient of
+        // that variable may take under the L1 budget constraint.
+        std::vector<sample_type> samples;
+        std::vector<double> labels;
+        for (long r = 0; r < X.nr(); ++r)
+        {
+            sample_type temp = trans(rowm(X,r));
+
+            const double xmul = (1/lasso_budget);
+            samples.push_back(temp - xmul*Y);
+            labels.push_back(+1);
+            samples.push_back(temp + xmul*Y);
+            labels.push_back(-1);
+        }
+
+        // Train and keep the optimizer state so we can read out the dual
+        // variables (alpha), which is where the fitted weights live.
+        svm_c_linear_dcd_trainer<kernel_type>::optimizer_state state;
+        auto df = trainer.train(samples, labels, state);
+        auto&& alpha = state.get_alpha();
+
+        // Recover each weight from the difference of its sample pair's dual
+        // variables, scaled back by the lasso budget and normalized by the
+        // total dual mass.
+        matrix<double,0,1> betas(alpha.size()/2);
+        for (long i = 0; i < betas.size(); ++i)
+            betas(i) = lasso_budget*(alpha[2*i] - alpha[2*i+1]);
+        betas /= sum(mat(alpha));
+        return betas;
+    }
+
+// ----------------------------------------------------------------------------------------
+
+    // Unit test for dlib::elastic_net.  Registers itself with the dlib test
+    // framework (via the global instance 'a' below) under the command line
+    // name "test_elastic_net".
+    class test_elastic_net : public tester
+    {
+    public:
+        test_elastic_net (
+        ) :
+            tester (
+                "test_elastic_net",
+                "Run tests on the elastic_net object.",
+                0  // number of command line arguments the test takes
+            )
+        {
+        }
+
+        // Generates a synthetic regression problem with a known sparse
+        // weight vector and checks that the dedicated dlib::elastic_net
+        // solver agrees with the independent reference implementation
+        // (basic_elastic_net) over a range of shrinking lasso budgets.
+        void perform_test (
+        )
+        {
+            // Ground-truth weights: deliberately sparse (many exact zeros)
+            // so the lasso constraint is actually exercised.
+            matrix<double> w = {1,2,0,4, 0,0,0,0,0, 6, 7,8,0, 9, 0};
+
+            // 1000 random observations; each column of X is one sample.
+            // Targets are trans(X)*w plus small uniform noise in [-0.05,0.05].
+            matrix<double> X = randm(w.size(),1000);
+            matrix<double> Y = trans(X)*w;
+            Y += 0.1*(randm(Y.nr(), Y.nc())-0.5);
+
+
+            double ridge_lambda = 0.1;
+            // Start the budget exactly at ||w||_1 and scan downward from there.
+            double lasso_budget = sum(abs(w));
+            double eps = 0.0000001;
+
+            // dlib::elastic_net is constructed from the Gram matrix X*X^T
+            // and the correlation vector X*Y rather than the raw data.
+            dlib::elastic_net solver(X*trans(X),X*Y);
+            solver.set_epsilon(eps);
+
+
+            matrix<double,0,1> results;
+            matrix<double,0,1> results2;
+            // Sweep the lasso budget from 1.2x down to just above 0.1x of
+            // ||w||_1, shrinking by 10% per step.
+            for (double s = 1.2; s > 0.10; s *= 0.9)
+            {
+                print_spinner();
+                dlog << LINFO << "s: "<< s;
+                // make sure the two solvers agree.
+                results = basic_elastic_net(X, Y, ridge_lambda, lasso_budget*s, eps);
+                results2 = solver(ridge_lambda, lasso_budget*s);
+                dlog << LINFO << "error: "<< max(abs(results - results2));
+                DLIB_TEST(max(abs(results - results2)) < 1e-3);
+            }
+        }
+    } a;  // global instance: constructing it registers the test with the framework
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+
+