summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/filtering/rls_filter.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/filtering/rls_filter.h')
-rw-r--r--ml/dlib/dlib/filtering/rls_filter.h198
1 files changed, 0 insertions, 198 deletions
diff --git a/ml/dlib/dlib/filtering/rls_filter.h b/ml/dlib/dlib/filtering/rls_filter.h
deleted file mode 100644
index 4481ab3f4..000000000
--- a/ml/dlib/dlib/filtering/rls_filter.h
+++ /dev/null
@@ -1,198 +0,0 @@
-// Copyright (C) 2012 Davis E. King (davis@dlib.net)
-// License: Boost Software License See LICENSE.txt for the full license.
-#ifndef DLIB_RLS_FiLTER_Hh_
-#define DLIB_RLS_FiLTER_Hh_
-
-#include "rls_filter_abstract.h"
-#include "../svm/rls.h"
-#include <vector>
-#include "../matrix.h"
-#include "../sliding_buffer.h"
-
-namespace dlib
-{
-
-// ----------------------------------------------------------------------------------------
-
- class rls_filter
- {
- /*!
- CONVENTION
- - data.size() == the number of variables in a measurement
- - data[i].size() == data[j].size() for all i and j.
- - data[i].size() == get_window_size()
- - data[i][0] == most recent measurement of i-th variable given to update.
- - data[i].back() == oldest measurement of i-th variable given to update
- (or zero if we haven't seen this much data yet).
-
- - if (count <= 2) then
- - count == number of times update(z) has been called
- !*/
- public:
-
- rls_filter()
- {
- size = 5;
- count = 0;
- filter = rls(0.8, 100);
- }
-
- explicit rls_filter (
- unsigned long size_,
- double forget_factor = 0.8,
- double C = 100
- )
- {
- // make sure requires clause is not broken
- DLIB_ASSERT(0 < forget_factor && forget_factor <= 1 &&
- 0 < C && size_ >= 2,
- "\t rls_filter::rls_filter()"
- << "\n\t invalid arguments were given to this function"
- << "\n\t forget_factor: " << forget_factor
- << "\n\t C: " << C
- << "\n\t size_: " << size_
- << "\n\t this: " << this
- );
-
- size = size_;
- count = 0;
- filter = rls(forget_factor, C);
- }
-
- double get_c(
- ) const
- {
- return filter.get_c();
- }
-
- double get_forget_factor(
- ) const
- {
- return filter.get_forget_factor();
- }
-
- unsigned long get_window_size (
- ) const
- {
- return size;
- }
-
- void update (
- )
- {
- if (filter.get_w().size() == 0)
- return;
-
- for (unsigned long i = 0; i < data.size(); ++i)
- {
- // Put old predicted value into the circular buffer as if it was
- // the measurement we just observed. But don't update the rls filter.
- data[i].push_front(next(i));
- }
-
- // predict next state
- for (long i = 0; i < next.size(); ++i)
- next(i) = filter(mat(data[i]));
- }
-
- template <typename EXP>
- void update (
- const matrix_exp<EXP>& z
- )
- {
- // make sure requires clause is not broken
- DLIB_ASSERT(is_col_vector(z) == true &&
- z.size() != 0 &&
- (get_predicted_next_state().size()==0 || z.size()==get_predicted_next_state().size()),
- "\t void rls_filter::update(z)"
- << "\n\t invalid arguments were given to this function"
- << "\n\t is_col_vector(z): " << is_col_vector(z)
- << "\n\t z.size(): " << z.size()
- << "\n\t get_predicted_next_state().size(): " << get_predicted_next_state().size()
- << "\n\t this: " << this
- );
-
- // initialize data if necessary
- if (data.size() == 0)
- {
- data.resize(z.size());
- for (long i = 0; i < z.size(); ++i)
- data[i].assign(size, 0);
- }
-
-
- for (unsigned long i = 0; i < data.size(); ++i)
- {
- // Once there is some stuff in the circular buffer, start
- // showing it to the rls filter so it can do its thing.
- if (count >= 2)
- {
- filter.train(mat(data[i]), z(i));
- }
-
- // keep track of the measurements in our circular buffer
- data[i].push_front(z(i));
- }
-
- // Don't bother with the filter until we have seen two samples
- if (count >= 2)
- {
- // predict next state
- for (long i = 0; i < z.size(); ++i)
- next(i) = filter(mat(data[i]));
- }
- else
- {
- // Use current measurement as the next state prediction
- // since we don't know any better at this point.
- ++count;
- next = matrix_cast<double>(z);
- }
- }
-
- const matrix<double,0,1>& get_predicted_next_state(
- ) const
- {
- return next;
- }
-
- friend inline void serialize(const rls_filter& item, std::ostream& out)
- {
- int version = 1;
- serialize(version, out);
- serialize(item.count, out);
- serialize(item.size, out);
- serialize(item.filter, out);
- serialize(item.next, out);
- serialize(item.data, out);
- }
-
- friend inline void deserialize(rls_filter& item, std::istream& in)
- {
- int version = 0;
- deserialize(version, in);
- if (version != 1)
- throw dlib::serialization_error("Unknown version number found while deserializing rls_filter object.");
-
- deserialize(item.count, in);
- deserialize(item.size, in);
- deserialize(item.filter, in);
- deserialize(item.next, in);
- deserialize(item.data, in);
- }
-
- private:
-
- unsigned long count;
- unsigned long size;
- rls filter;
- matrix<double,0,1> next;
- std::vector<circular_buffer<double> > data;
- };
-
-// ----------------------------------------------------------------------------------------
-
-}
-
-#endif // DLIB_RLS_FiLTER_Hh_
-