diff options
Diffstat (limited to 'ml/dlib/dlib/test/lspi.cpp')
-rw-r--r-- | ml/dlib/dlib/test/lspi.cpp | 258 |
1 files changed, 258 insertions, 0 deletions
diff --git a/ml/dlib/dlib/test/lspi.cpp b/ml/dlib/dlib/test/lspi.cpp new file mode 100644 index 00000000..01388711 --- /dev/null +++ b/ml/dlib/dlib/test/lspi.cpp @@ -0,0 +1,258 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#include "tester.h" +#include <dlib/control.h> +#include <vector> +#include <sstream> +#include <ctime> + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + dlib::logger dlog("test.lspi"); + + template <bool have_prior> + struct chain_model + { + typedef int state_type; + typedef int action_type; // 0 is move left, 1 is move right + const static bool force_last_weight_to_1 = have_prior; + + + const static int num_states = 4; // not required in the model interface + + matrix<double,8,1> offset; + chain_model() + { + offset = + 2.048 , + 2.56 , + 2.048 , + 3.2 , + 2.56 , + 4 , + 3.2, + 5 ; + if (!have_prior) + offset = 0; + + } + + unsigned long num_features( + ) const + { + if (have_prior) + return num_states*2 + 1; + else + return num_states*2; + } + + action_type find_best_action ( + const state_type& state, + const matrix<double,0,1>& w + ) const + { + if (w(state*2)+offset(state*2) >= w(state*2+1)+offset(state*2+1)) + //if (w(state*2) >= w(state*2+1)) + return 0; + else + return 1; + } + + void get_features ( + const state_type& state, + const action_type& action, + matrix<double,0,1>& feats + ) const + { + feats.set_size(num_features()); + feats = 0; + feats(state*2 + action) = 1; + if (have_prior) + feats(num_features()-1) = offset(state*2+action); + } + + }; + + void test_lspi_prior1() + { + print_spinner(); + typedef process_sample<chain_model<true> > sample_type; + std::vector<sample_type> samples; + + samples.push_back(sample_type(0,0,0,0)); + samples.push_back(sample_type(0,1,1,0)); + + samples.push_back(sample_type(1,0,0,0)); + samples.push_back(sample_type(1,1,2,0)); + + samples.push_back(sample_type(2,0,1,0)); + samples.push_back(sample_type(2,1,3,0)); + + samples.push_back(sample_type(3,0,2,0)); + samples.push_back(sample_type(3,1,3,1)); + + + lspi<chain_model<true> > trainer; + //trainer.be_verbose(); + trainer.set_lambda(0); + policy<chain_model<true> > pol = trainer.train(samples); + + dlog << LINFO << pol.get_weights(); + + matrix<double,0,1> w = pol.get_weights(); + DLIB_TEST(pol.get_weights().size() == 9); + DLIB_TEST(w(w.size()-1) == 1); + w(w.size()-1) = 0; + DLIB_TEST_MSG(length(w) < 1e-12, length(w)); + + dlog << LINFO << "action: " << pol(0); + dlog << LINFO << "action: " << pol(1); + dlog << LINFO << "action: " << pol(2); + dlog << LINFO << "action: " << pol(3); + DLIB_TEST(pol(0) == 1); + DLIB_TEST(pol(1) == 1); + DLIB_TEST(pol(2) == 1); + DLIB_TEST(pol(3) == 1); + } + + void test_lspi_prior2() + { + print_spinner(); + typedef process_sample<chain_model<true> > sample_type; + std::vector<sample_type> samples; + + samples.push_back(sample_type(0,0,0,0)); + samples.push_back(sample_type(0,1,1,0)); + + samples.push_back(sample_type(1,0,0,0)); + samples.push_back(sample_type(1,1,2,0)); + + samples.push_back(sample_type(2,0,1,0)); + samples.push_back(sample_type(2,1,3,1)); + + samples.push_back(sample_type(3,0,2,0)); + samples.push_back(sample_type(3,1,3,0)); + + + lspi<chain_model<true> > trainer; + //trainer.be_verbose(); + trainer.set_lambda(0); + policy<chain_model<true> > pol = trainer.train(samples); + + + dlog << LINFO << "action: " << pol(0); + dlog << LINFO << "action: " << pol(1); + dlog << LINFO << "action: " << pol(2); + dlog << LINFO << "action: " << pol(3); + DLIB_TEST(pol(0) == 1); + DLIB_TEST(pol(1) == 1); + DLIB_TEST(pol(2) == 1); + DLIB_TEST(pol(3) == 0); + } + + void test_lspi_noprior1() + { + print_spinner(); + typedef process_sample<chain_model<false> > sample_type; + std::vector<sample_type> samples; + + samples.push_back(sample_type(0,0,0,0)); + samples.push_back(sample_type(0,1,1,0)); + + samples.push_back(sample_type(1,0,0,0)); + samples.push_back(sample_type(1,1,2,0)); + + samples.push_back(sample_type(2,0,1,0)); + samples.push_back(sample_type(2,1,3,0)); + + samples.push_back(sample_type(3,0,2,0)); + samples.push_back(sample_type(3,1,3,1)); + + + lspi<chain_model<false> > trainer; + //trainer.be_verbose(); + trainer.set_lambda(0.01); + policy<chain_model<false> > pol = trainer.train(samples); + + dlog << LINFO << pol.get_weights(); + DLIB_TEST(pol.get_weights().size() == 8); + + + dlog << LINFO << "action: " << pol(0); + dlog << LINFO << "action: " << pol(1); + dlog << LINFO << "action: " << pol(2); + dlog << LINFO << "action: " << pol(3); + DLIB_TEST(pol(0) == 1); + DLIB_TEST(pol(1) == 1); + DLIB_TEST(pol(2) == 1); + DLIB_TEST(pol(3) == 1); + } + void test_lspi_noprior2() + { + print_spinner(); + typedef process_sample<chain_model<false> > sample_type; + std::vector<sample_type> samples; + + samples.push_back(sample_type(0,0,0,0)); + samples.push_back(sample_type(0,1,1,0)); + + samples.push_back(sample_type(1,0,0,0)); + samples.push_back(sample_type(1,1,2,1)); + + samples.push_back(sample_type(2,0,1,0)); + samples.push_back(sample_type(2,1,3,0)); + + samples.push_back(sample_type(3,0,2,0)); + samples.push_back(sample_type(3,1,3,0)); + + + lspi<chain_model<false> > trainer; + //trainer.be_verbose(); + trainer.set_lambda(0.01); + policy<chain_model<false> > pol = trainer.train(samples); + + dlog << LINFO << pol.get_weights(); + DLIB_TEST(pol.get_weights().size() == 8); + + + dlog << LINFO << "action: " << pol(0); + dlog << LINFO << "action: " << pol(1); + dlog << LINFO << "action: " << pol(2); + dlog << LINFO << "action: " << pol(3); + DLIB_TEST(pol(0) == 1); + DLIB_TEST(pol(1) == 1); + DLIB_TEST(pol(2) == 0); + DLIB_TEST(pol(3) == 0); + } + + class lspi_tester : public tester + { + public: + lspi_tester ( + ) : + tester ( + "test_lspi", // the command line argument name for this test + "Run tests on the lspi object.", // the command line argument description + 0 // the number of command line arguments for this test + ) + { + } + + void perform_test ( + ) + { + test_lspi_prior1(); + test_lspi_prior2(); + + test_lspi_noprior1(); + test_lspi_noprior2(); + } + }; + + lspi_tester a; +} + |