diff options
Diffstat (limited to 'ml/dlib/dlib/control/approximate_linear_models.h')
-rw-r--r-- | ml/dlib/dlib/control/approximate_linear_models.h | 128 |
1 files changed, 128 insertions, 0 deletions
diff --git a/ml/dlib/dlib/control/approximate_linear_models.h b/ml/dlib/dlib/control/approximate_linear_models.h new file mode 100644 index 000000000..9732d71e9 --- /dev/null +++ b/ml/dlib/dlib/control/approximate_linear_models.h @@ -0,0 +1,128 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ +#define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ + +#include "approximate_linear_models_abstract.h" +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + struct process_sample + { + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::state_type state_type; + typedef typename feature_extractor::action_type action_type; + + process_sample(){} + + process_sample( + const state_type& s, + const action_type& a, + const state_type& n, + const double& r + ) : state(s), action(a), next_state(n), reward(r) {} + + state_type state; + action_type action; + state_type next_state; + double reward; + }; + + template < typename feature_extractor > + void serialize (const process_sample<feature_extractor>& item, std::ostream& out) + { + serialize(item.state, out); + serialize(item.action, out); + serialize(item.next_state, out); + serialize(item.reward, out); + } + + template < typename feature_extractor > + void deserialize (process_sample<feature_extractor>& item, std::istream& in) + { + deserialize(item.state, in); + deserialize(item.action, in); + deserialize(item.next_state, in); + deserialize(item.reward, in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename feature_extractor + > + class policy + { + public: + + typedef feature_extractor feature_extractor_type; + typedef typename feature_extractor::state_type state_type; + typedef typename feature_extractor::action_type action_type; + + + policy ( + ) + { + w.set_size(fe.num_features()); + w = 0; + } + + policy ( + const matrix<double,0,1>& weights_, + const feature_extractor& fe_ + ) : w(weights_), fe(fe_) {} + + action_type operator() ( + const state_type& state + ) const + { + return fe.find_best_action(state,w); + } + + const feature_extractor& get_feature_extractor ( + ) const { return fe; } + + const matrix<double,0,1>& get_weights ( + ) const { return w; } + + + private: + matrix<double,0,1> w; + feature_extractor fe; + }; + + template < typename feature_extractor > + inline void serialize(const policy<feature_extractor>& item, std::ostream& out) + { + int version = 1; + serialize(version, out); + serialize(item.get_feature_extractor(), out); + serialize(item.get_weights(), out); + } + template < typename feature_extractor > + inline void deserialize(policy<feature_extractor>& item, std::istream& in) + { + int version = 0; + deserialize(version, in); + if (version != 1) + throw serialization_error("Unexpected version found while deserializing dlib::policy object."); + feature_extractor fe; + matrix<double,0,1> w; + deserialize(fe, in); + deserialize(w, in); + item = policy<feature_extractor>(w,fe); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_ + |