diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-09 13:19:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-09 13:20:02 +0000 |
commit | 58daab21cd043e1dc37024a7f99b396788372918 (patch) | |
tree | 96771e43bb69f7c1c2b0b4f7374cb74d7866d0cb /ml/dlib/dlib/filtering/rls_filter.h | |
parent | Releasing debian version 1.43.2-1. (diff) | |
download | netdata-58daab21cd043e1dc37024a7f99b396788372918.tar.xz netdata-58daab21cd043e1dc37024a7f99b396788372918.zip |
Merging upstream version 1.44.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'ml/dlib/dlib/filtering/rls_filter.h')
-rw-r--r-- | ml/dlib/dlib/filtering/rls_filter.h | 198 |
1 files changed, 198 insertions, 0 deletions
diff --git a/ml/dlib/dlib/filtering/rls_filter.h b/ml/dlib/dlib/filtering/rls_filter.h new file mode 100644 index 000000000..4481ab3f4 --- /dev/null +++ b/ml/dlib/dlib/filtering/rls_filter.h @@ -0,0 +1,198 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_RLS_FiLTER_Hh_ +#define DLIB_RLS_FiLTER_Hh_ + +#include "rls_filter_abstract.h" +#include "../svm/rls.h" +#include <vector> +#include "../matrix.h" +#include "../sliding_buffer.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class rls_filter + { + /*! + CONVENTION + - data.size() == the number of variables in a measurement + - data[i].size() == data[j].size() for all i and j. + - data[i].size() == get_window_size() + - data[i][0] == most recent measurement of i-th variable given to update. + - data[i].back() == oldest measurement of i-th variable given to update + (or zero if we haven't seen this much data yet). + + - if (count <= 2) then + - count == number of times update(z) has been called + !*/ + public: + + rls_filter() + { + size = 5; + count = 0; + filter = rls(0.8, 100); + } + + explicit rls_filter ( + unsigned long size_, + double forget_factor = 0.8, + double C = 100 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 < forget_factor && forget_factor <= 1 && + 0 < C && size_ >= 2, + "\t rls_filter::rls_filter()" + << "\n\t invalid arguments were given to this function" + << "\n\t forget_factor: " << forget_factor + << "\n\t C: " << C + << "\n\t size_: " << size_ + << "\n\t this: " << this + ); + + size = size_; + count = 0; + filter = rls(forget_factor, C); + } + + double get_c( + ) const + { + return filter.get_c(); + } + + double get_forget_factor( + ) const + { + return filter.get_forget_factor(); + } + + unsigned long get_window_size ( + ) const + { + return size; + } + + void update ( + ) + { + if (filter.get_w().size() == 0) + return; + + for (unsigned long i = 0; i < data.size(); ++i) + { + // Put old predicted value into the circular buffer as if it was + // the measurement we just observed. But don't update the rls filter. + data[i].push_front(next(i)); + } + + // predict next state + for (long i = 0; i < next.size(); ++i) + next(i) = filter(mat(data[i])); + } + + template <typename EXP> + void update ( + const matrix_exp<EXP>& z + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_col_vector(z) == true && + z.size() != 0 && + (get_predicted_next_state().size()==0 || z.size()==get_predicted_next_state().size()), + "\t void rls_filter::update(z)" + << "\n\t invalid arguments were given to this function" + << "\n\t is_col_vector(z): " << is_col_vector(z) + << "\n\t z.size(): " << z.size() + << "\n\t get_predicted_next_state().size(): " << get_predicted_next_state().size() + << "\n\t this: " << this + ); + + // initialize data if necessary + if (data.size() == 0) + { + data.resize(z.size()); + for (long i = 0; i < z.size(); ++i) + data[i].assign(size, 0); + } + + + for (unsigned long i = 0; i < data.size(); ++i) + { + // Once there is some stuff in the circular buffer, start + // showing it to the rls filter so it can do its thing. + if (count >= 2) + { + filter.train(mat(data[i]), z(i)); + } + + // keep track of the measurements in our circular buffer + data[i].push_front(z(i)); + } + + // Don't bother with the filter until we have seen two samples + if (count >= 2) + { + // predict next state + for (long i = 0; i < z.size(); ++i) + next(i) = filter(mat(data[i])); + } + else + { + // Use current measurement as the next state prediction + // since we don't know any better at this point. + ++count; + next = matrix_cast<double>(z); + } + } + + const matrix<double,0,1>& get_predicted_next_state( + ) const + { + return next; + } + + friend inline void serialize(const rls_filter& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.count, out); + serialize(item.size, out); + serialize(item.filter, out); + serialize(item.next, out); + serialize(item.data, out); + } + + friend inline void deserialize(rls_filter& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw dlib::serialization_error("Unknown version number found while deserializing rls_filter object."); + + deserialize(item.count, in); + deserialize(item.size, in); + deserialize(item.filter, in); + deserialize(item.next, in); + deserialize(item.data, in); + } + + private: + + unsigned long count; + unsigned long size; + rls filter; + matrix<double,0,1> next; + std::vector<circular_buffer<double> > data; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RLS_FiLTER_Hh_ + |