diff options
Diffstat (limited to 'src/ml/dlib/dlib/test/filtering.cpp')
-rw-r--r-- | src/ml/dlib/dlib/test/filtering.cpp | 166 |
1 files changed, 166 insertions, 0 deletions
diff --git a/src/ml/dlib/dlib/test/filtering.cpp b/src/ml/dlib/dlib/test/filtering.cpp new file mode 100644 index 000000000..61dc88440 --- /dev/null +++ b/src/ml/dlib/dlib/test/filtering.cpp @@ -0,0 +1,166 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + + +#include <dlib/filtering.h> +#include <sstream> +#include <string> +#include <cstdlib> +#include <ctime> +#include <dlib/matrix.h> +#include <dlib/rand.h> + +#include "tester.h" + +namespace +{ + + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.filtering"); + +// ---------------------------------------------------------------------------------------- + + template <typename filter_type> + double test_filter ( + filter_type kf, + int size + ) + { + // This test has a point moving in a circle around the origin. The point + // also gets a random bump in a random direction at each time step. + + running_stats<double> rs; + + dlib::rand rnd; + int count = 0; + const dlib::vector<double,3> z(0,0,1); + dlib::vector<double,2> p(10,10), temp; + for (int i = 0; i < size; ++i) + { + // move the point around in a circle + p += z.cross(p).normalize()/0.5; + // randomly drop measurements + if (rnd.get_random_double() < 0.7 || count < 4) + { + // make a random bump + dlib::vector<double,2> pp; + pp.x() = rnd.get_random_gaussian()/3; + pp.y() = rnd.get_random_gaussian()/3; + + ++count; + kf.update(p+pp); + } + else + { + kf.update(); + dlog << LTRACE << "MISSED MEASUREMENT"; + } + // figure out the next position + temp = (p+z.cross(p).normalize()/0.5); + const double error = length(temp - rowm(kf.get_predicted_next_state(),range(0,1))); + rs.add(error); + + dlog << LTRACE << temp << "("<< error << "): " << trans(kf.get_predicted_next_state()); + + // test the serialization a few times. + if (count < 10) + { + ostringstream sout; + serialize(kf, sout); + istringstream sin(sout.str()); + filter_type temp; + deserialize(temp, sin); + kf = temp; + } + } + + + return rs.mean(); + + } + +// ---------------------------------------------------------------------------------------- + + void test_kalman_filter() + { + matrix<double,2,2> R; + R = 0.3, 0, + 0, 0.3; + + // the variables in the state are + // x,y, x velocity, y velocity, x acceleration, and y acceleration + matrix<double,6,6> A; + A = 1, 0, 1, 0, 0, 0, + 0, 1, 0, 1, 0, 0, + 0, 0, 1, 0, 1, 0, + 0, 0, 0, 1, 0, 1, + 0, 0, 0, 0, 1, 0, + 0, 0, 0, 0, 0, 1; + + // the measurements only tell us the positions + matrix<double,2,6> H; + H = 1, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, 0; + + + kalman_filter<6,2> kf; + kf.set_measurement_noise(R); + matrix<double> pn = 0.01*identity_matrix<double,6>(); + kf.set_process_noise(pn); + kf.set_observation_model(H); + kf.set_transition_model(A); + + DLIB_TEST(equal(kf.get_observation_model() , H)); + DLIB_TEST(equal(kf.get_transition_model() , A)); + DLIB_TEST(equal(kf.get_measurement_noise() , R)); + DLIB_TEST(equal(kf.get_process_noise() , pn)); + DLIB_TEST(equal(kf.get_current_estimation_error_covariance() , identity_matrix(pn))); + + double kf_error = test_filter(kf, 300); + + dlog << LINFO << "kf error: "<< kf_error; + DLIB_TEST_MSG(kf_error < 0.75, kf_error); + } + +// ---------------------------------------------------------------------------------------- + + void test_rls_filter() + { + + rls_filter rls(10, 0.99, 0.1); + + DLIB_TEST(rls.get_window_size() == 10); + DLIB_TEST(rls.get_forget_factor() == 0.99); + DLIB_TEST(rls.get_c() == 0.1); + + double rls_error = test_filter(rls, 1000); + + dlog << LINFO << "rls error: "<< rls_error; + DLIB_TEST_MSG(rls_error < 0.75, rls_error); + } + +// ---------------------------------------------------------------------------------------- + + class filtering_tester : public tester + { + public: + filtering_tester ( + ) : + tester ("test_filtering", + "Runs tests on the filtering stuff (rls and kalman filters).") + {} + + void perform_test ( + ) + { + test_rls_filter(); + test_kalman_filter(); + } + } a; + +} + + |