summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/test/lspi.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/test/lspi.cpp')
-rw-r--r--ml/dlib/dlib/test/lspi.cpp258
1 files changed, 258 insertions, 0 deletions
diff --git a/ml/dlib/dlib/test/lspi.cpp b/ml/dlib/dlib/test/lspi.cpp
new file mode 100644
index 00000000..01388711
--- /dev/null
+++ b/ml/dlib/dlib/test/lspi.cpp
@@ -0,0 +1,258 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+#include "tester.h"
+#include <dlib/control.h>
+#include <vector>
+#include <sstream>
+#include <ctime>
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+ dlib::logger dlog("test.lspi");
+
+ template <bool have_prior>
+ struct chain_model
+ {
+ typedef int state_type;
+ typedef int action_type; // 0 is move left, 1 is move right
+ const static bool force_last_weight_to_1 = have_prior;
+
+
+ const static int num_states = 4; // not required in the model interface
+
+ matrix<double,8,1> offset;
+ chain_model()
+ {
+ offset =
+ 2.048 ,
+ 2.56 ,
+ 2.048 ,
+ 3.2 ,
+ 2.56 ,
+ 4 ,
+ 3.2,
+ 5 ;
+ if (!have_prior)
+ offset = 0;
+
+ }
+
+ unsigned long num_features(
+ ) const
+ {
+ if (have_prior)
+ return num_states*2 + 1;
+ else
+ return num_states*2;
+ }
+
+ action_type find_best_action (
+ const state_type& state,
+ const matrix<double,0,1>& w
+ ) const
+ {
+ if (w(state*2)+offset(state*2) >= w(state*2+1)+offset(state*2+1))
+ //if (w(state*2) >= w(state*2+1))
+ return 0;
+ else
+ return 1;
+ }
+
+ void get_features (
+ const state_type& state,
+ const action_type& action,
+ matrix<double,0,1>& feats
+ ) const
+ {
+ feats.set_size(num_features());
+ feats = 0;
+ feats(state*2 + action) = 1;
+ if (have_prior)
+ feats(num_features()-1) = offset(state*2+action);
+ }
+
+ };
+
+ void test_lspi_prior1()
+ {
+ print_spinner();
+ typedef process_sample<chain_model<true> > sample_type;
+ std::vector<sample_type> samples;
+
+ samples.push_back(sample_type(0,0,0,0));
+ samples.push_back(sample_type(0,1,1,0));
+
+ samples.push_back(sample_type(1,0,0,0));
+ samples.push_back(sample_type(1,1,2,0));
+
+ samples.push_back(sample_type(2,0,1,0));
+ samples.push_back(sample_type(2,1,3,0));
+
+ samples.push_back(sample_type(3,0,2,0));
+ samples.push_back(sample_type(3,1,3,1));
+
+
+ lspi<chain_model<true> > trainer;
+ //trainer.be_verbose();
+ trainer.set_lambda(0);
+ policy<chain_model<true> > pol = trainer.train(samples);
+
+ dlog << LINFO << pol.get_weights();
+
+ matrix<double,0,1> w = pol.get_weights();
+ DLIB_TEST(pol.get_weights().size() == 9);
+ DLIB_TEST(w(w.size()-1) == 1);
+ w(w.size()-1) = 0;
+ DLIB_TEST_MSG(length(w) < 1e-12, length(w));
+
+ dlog << LINFO << "action: " << pol(0);
+ dlog << LINFO << "action: " << pol(1);
+ dlog << LINFO << "action: " << pol(2);
+ dlog << LINFO << "action: " << pol(3);
+ DLIB_TEST(pol(0) == 1);
+ DLIB_TEST(pol(1) == 1);
+ DLIB_TEST(pol(2) == 1);
+ DLIB_TEST(pol(3) == 1);
+ }
+
+ void test_lspi_prior2()
+ {
+ print_spinner();
+ typedef process_sample<chain_model<true> > sample_type;
+ std::vector<sample_type> samples;
+
+ samples.push_back(sample_type(0,0,0,0));
+ samples.push_back(sample_type(0,1,1,0));
+
+ samples.push_back(sample_type(1,0,0,0));
+ samples.push_back(sample_type(1,1,2,0));
+
+ samples.push_back(sample_type(2,0,1,0));
+ samples.push_back(sample_type(2,1,3,1));
+
+ samples.push_back(sample_type(3,0,2,0));
+ samples.push_back(sample_type(3,1,3,0));
+
+
+ lspi<chain_model<true> > trainer;
+ //trainer.be_verbose();
+ trainer.set_lambda(0);
+ policy<chain_model<true> > pol = trainer.train(samples);
+
+
+ dlog << LINFO << "action: " << pol(0);
+ dlog << LINFO << "action: " << pol(1);
+ dlog << LINFO << "action: " << pol(2);
+ dlog << LINFO << "action: " << pol(3);
+ DLIB_TEST(pol(0) == 1);
+ DLIB_TEST(pol(1) == 1);
+ DLIB_TEST(pol(2) == 1);
+ DLIB_TEST(pol(3) == 0);
+ }
+
+ void test_lspi_noprior1()
+ {
+ print_spinner();
+ typedef process_sample<chain_model<false> > sample_type;
+ std::vector<sample_type> samples;
+
+ samples.push_back(sample_type(0,0,0,0));
+ samples.push_back(sample_type(0,1,1,0));
+
+ samples.push_back(sample_type(1,0,0,0));
+ samples.push_back(sample_type(1,1,2,0));
+
+ samples.push_back(sample_type(2,0,1,0));
+ samples.push_back(sample_type(2,1,3,0));
+
+ samples.push_back(sample_type(3,0,2,0));
+ samples.push_back(sample_type(3,1,3,1));
+
+
+ lspi<chain_model<false> > trainer;
+ //trainer.be_verbose();
+ trainer.set_lambda(0.01);
+ policy<chain_model<false> > pol = trainer.train(samples);
+
+ dlog << LINFO << pol.get_weights();
+ DLIB_TEST(pol.get_weights().size() == 8);
+
+
+ dlog << LINFO << "action: " << pol(0);
+ dlog << LINFO << "action: " << pol(1);
+ dlog << LINFO << "action: " << pol(2);
+ dlog << LINFO << "action: " << pol(3);
+ DLIB_TEST(pol(0) == 1);
+ DLIB_TEST(pol(1) == 1);
+ DLIB_TEST(pol(2) == 1);
+ DLIB_TEST(pol(3) == 1);
+ }
+ void test_lspi_noprior2()
+ {
+ print_spinner();
+ typedef process_sample<chain_model<false> > sample_type;
+ std::vector<sample_type> samples;
+
+ samples.push_back(sample_type(0,0,0,0));
+ samples.push_back(sample_type(0,1,1,0));
+
+ samples.push_back(sample_type(1,0,0,0));
+ samples.push_back(sample_type(1,1,2,1));
+
+ samples.push_back(sample_type(2,0,1,0));
+ samples.push_back(sample_type(2,1,3,0));
+
+ samples.push_back(sample_type(3,0,2,0));
+ samples.push_back(sample_type(3,1,3,0));
+
+
+ lspi<chain_model<false> > trainer;
+ //trainer.be_verbose();
+ trainer.set_lambda(0.01);
+ policy<chain_model<false> > pol = trainer.train(samples);
+
+ dlog << LINFO << pol.get_weights();
+ DLIB_TEST(pol.get_weights().size() == 8);
+
+
+ dlog << LINFO << "action: " << pol(0);
+ dlog << LINFO << "action: " << pol(1);
+ dlog << LINFO << "action: " << pol(2);
+ dlog << LINFO << "action: " << pol(3);
+ DLIB_TEST(pol(0) == 1);
+ DLIB_TEST(pol(1) == 1);
+ DLIB_TEST(pol(2) == 0);
+ DLIB_TEST(pol(3) == 0);
+ }
+
+ class lspi_tester : public tester
+ {
+ public:
+ lspi_tester (
+ ) :
+ tester (
+ "test_lspi", // the command line argument name for this test
+ "Run tests on the lspi object.", // the command line argument description
+ 0 // the number of command line arguments for this test
+ )
+ {
+ }
+
+ void perform_test (
+ )
+ {
+ test_lspi_prior1();
+ test_lspi_prior2();
+
+ test_lspi_noprior1();
+ test_lspi_noprior2();
+ }
+ };
+
+ lspi_tester a;
+}
+