diff options
Diffstat (limited to 'ml/dlib/dlib/filtering/rls_filter.h')
-rw-r--r-- | ml/dlib/dlib/filtering/rls_filter.h | 198 |
1 files changed, 0 insertions, 198 deletions
diff --git a/ml/dlib/dlib/filtering/rls_filter.h b/ml/dlib/dlib/filtering/rls_filter.h deleted file mode 100644 index 4481ab3f4..000000000 --- a/ml/dlib/dlib/filtering/rls_filter.h +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (C) 2012 Davis E. King (davis@dlib.net) -// License: Boost Software License See LICENSE.txt for the full license. -#ifndef DLIB_RLS_FiLTER_Hh_ -#define DLIB_RLS_FiLTER_Hh_ - -#include "rls_filter_abstract.h" -#include "../svm/rls.h" -#include <vector> -#include "../matrix.h" -#include "../sliding_buffer.h" - -namespace dlib -{ - -// ---------------------------------------------------------------------------------------- - - class rls_filter - { - /*! - CONVENTION - - data.size() == the number of variables in a measurement - - data[i].size() == data[j].size() for all i and j. - - data[i].size() == get_window_size() - - data[i][0] == most recent measurement of i-th variable given to update. - - data[i].back() == oldest measurement of i-th variable given to update - (or zero if we haven't seen this much data yet). - - - if (count <= 2) then - - count == number of times update(z) has been called - !*/ - public: - - rls_filter() - { - size = 5; - count = 0; - filter = rls(0.8, 100); - } - - explicit rls_filter ( - unsigned long size_, - double forget_factor = 0.8, - double C = 100 - ) - { - // make sure requires clause is not broken - DLIB_ASSERT(0 < forget_factor && forget_factor <= 1 && - 0 < C && size_ >= 2, - "\t rls_filter::rls_filter()" - << "\n\t invalid arguments were given to this function" - << "\n\t forget_factor: " << forget_factor - << "\n\t C: " << C - << "\n\t size_: " << size_ - << "\n\t this: " << this - ); - - size = size_; - count = 0; - filter = rls(forget_factor, C); - } - - double get_c( - ) const - { - return filter.get_c(); - } - - double get_forget_factor( - ) const - { - return filter.get_forget_factor(); - } - - unsigned long get_window_size ( - ) const - { - return size; - } - - void update ( - ) - { - if (filter.get_w().size() == 0) - return; - - for (unsigned long i = 0; i < data.size(); ++i) - { - // Put old predicted value into the circular buffer as if it was - // the measurement we just observed. But don't update the rls filter. - data[i].push_front(next(i)); - } - - // predict next state - for (long i = 0; i < next.size(); ++i) - next(i) = filter(mat(data[i])); - } - - template <typename EXP> - void update ( - const matrix_exp<EXP>& z - ) - { - // make sure requires clause is not broken - DLIB_ASSERT(is_col_vector(z) == true && - z.size() != 0 && - (get_predicted_next_state().size()==0 || z.size()==get_predicted_next_state().size()), - "\t void rls_filter::update(z)" - << "\n\t invalid arguments were given to this function" - << "\n\t is_col_vector(z): " << is_col_vector(z) - << "\n\t z.size(): " << z.size() - << "\n\t get_predicted_next_state().size(): " << get_predicted_next_state().size() - << "\n\t this: " << this - ); - - // initialize data if necessary - if (data.size() == 0) - { - data.resize(z.size()); - for (long i = 0; i < z.size(); ++i) - data[i].assign(size, 0); - } - - - for (unsigned long i = 0; i < data.size(); ++i) - { - // Once there is some stuff in the circular buffer, start - // showing it to the rls filter so it can do its thing. - if (count >= 2) - { - filter.train(mat(data[i]), z(i)); - } - - // keep track of the measurements in our circular buffer - data[i].push_front(z(i)); - } - - // Don't bother with the filter until we have seen two samples - if (count >= 2) - { - // predict next state - for (long i = 0; i < z.size(); ++i) - next(i) = filter(mat(data[i])); - } - else - { - // Use current measurement as the next state prediction - // since we don't know any better at this point. - ++count; - next = matrix_cast<double>(z); - } - } - - const matrix<double,0,1>& get_predicted_next_state( - ) const - { - return next; - } - - friend inline void serialize(const rls_filter& item, std::ostream& out) - { - int version = 1; - serialize(version, out); - serialize(item.count, out); - serialize(item.size, out); - serialize(item.filter, out); - serialize(item.next, out); - serialize(item.data, out); - } - - friend inline void deserialize(rls_filter& item, std::istream& in) - { - int version = 0; - deserialize(version, in); - if (version != 1) - throw dlib::serialization_error("Unknown version number found while deserializing rls_filter object."); - - deserialize(item.count, in); - deserialize(item.size, in); - deserialize(item.filter, in); - deserialize(item.next, in); - deserialize(item.data, in); - } - - private: - - unsigned long count; - unsigned long size; - rls filter; - matrix<double,0,1> next; - std::vector<circular_buffer<double> > data; - }; - -// ---------------------------------------------------------------------------------------- - -} - -#endif // DLIB_RLS_FiLTER_Hh_ - |