summary refs log tree commit diff stats
path: root/ml/dlib/dlib/control/approximate_linear_models.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/control/approximate_linear_models.h')
-rw-r--r-- ml/dlib/dlib/control/approximate_linear_models.h 128
1 files changed, 128 insertions, 0 deletions
diff --git a/ml/dlib/dlib/control/approximate_linear_models.h b/ml/dlib/dlib/control/approximate_linear_models.h
new file mode 100644
index 000000000..9732d71e9
--- /dev/null
+++ b/ml/dlib/dlib/control/approximate_linear_models.h
@@ -0,0 +1,128 @@
+// Copyright (C) 2015 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
+#define DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
+
+#include "approximate_linear_models_abstract.h"
+#include "../matrix.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ // Plain record holding a state, an action, a next state and a scalar reward
+ // (presumably one observed transition of the process being controlled).
+ template <typename feature_extractor>
+ struct process_sample
+ {
+     typedef feature_extractor feature_extractor_type;
+     typedef typename feature_extractor::state_type state_type;
+     typedef typename feature_extractor::action_type action_type;
+
+     // Value-initializes every member, so a default constructed sample never
+     // carries an indeterminate reward (or indeterminate scalar state/action).
+     process_sample() : state(), action(), next_state(), reward(0) {}
+
+     // Stores copies of the given state, action, next state and reward.
+     process_sample(const state_type& s, const action_type& a,
+                    const state_type& n, const double& r)
+         : state(s), action(a), next_state(n), reward(r) {}
+
+     state_type state;
+     action_type action;
+     state_type next_state;
+     double reward;
+ };
+ // Writes the four fields of item to out; no version tag precedes them.
+ template < typename feature_extractor >
+ void serialize (const process_sample<feature_extractor>& item, std::ostream& out)
+ {
+ serialize(item.state, out); // the call order here is the on-stream field order,
+ serialize(item.action, out); // which deserialize() reads back identically
+ serialize(item.next_state, out);
+ serialize(item.reward, out);
+ }
+ // Fills item's four fields from in, in the order serialize() wrote them.
+ template < typename feature_extractor >
+ void deserialize (process_sample<feature_extractor>& item, std::istream& in)
+ {
+ deserialize(item.state, in); // fields are overwritten in place, so if a later
+ deserialize(item.action, in); // read throws, item may be left partially updated
+ deserialize(item.next_state, in);
+ deserialize(item.reward, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Pairs a weight vector w with a feature_extractor fe.  Calling the object on
+ // a state returns whatever fe.find_best_action(state, w) selects.
+ template <typename feature_extractor>
+ class policy
+ {
+ public:
+
+     typedef feature_extractor feature_extractor_type;
+     typedef typename feature_extractor::state_type state_type;
+     typedef typename feature_extractor::action_type action_type;
+
+     // Default policy: fe is default constructed, and w is sized to
+     // fe.num_features() and then assigned 0.
+     policy()
+     {
+         w.set_size(fe.num_features());
+         w = 0;
+     }
+
+     // Stores copies of the given weights and feature extractor.
+     // NOTE(review): nothing here checks weights_ against fe_.num_features().
+     policy(const matrix<double,0,1>& weights_, const feature_extractor& fe_)
+         : w(weights_), fe(fe_) {}
+
+     // Returns the action fe picks for state given the current weights.
+     action_type operator() (const state_type& state) const
+     {
+         return fe.find_best_action(state, w);
+     }
+
+     // Read-only access to the stored feature extractor.
+     const feature_extractor& get_feature_extractor() const { return fe; }
+
+     // Read-only access to the stored weight vector.
+     const matrix<double,0,1>& get_weights() const { return w; }
+
+ private:
+     // Members are initialized in declaration order (w, then fe), which is
+     // the order the two-argument constructor lists them in.
+     matrix<double,0,1> w;
+     feature_extractor fe;
+ };
+ // Writes a policy to out as: version tag (1), feature extractor, weights.
+ template < typename feature_extractor >
+ inline void serialize(const policy<feature_extractor>& item, std::ostream& out)
+ {
+ int version = 1; // format tag; deserialize() rejects any other value
+ serialize(version, out);
+ serialize(item.get_feature_extractor(), out);
+ serialize(item.get_weights(), out);
+ }
+ // Reads a policy from in: version tag (must be 1), feature extractor, weights.
+ template < typename feature_extractor >
+ inline void deserialize(policy<feature_extractor>& item, std::istream& in)
+ {
+     int version = 0;
+     deserialize(version, in);
+     if (version != 1)
+         throw serialization_error("Unexpected version found while deserializing dlib::policy object.");
+     feature_extractor loaded_fe;       // read into temporaries first, so that item
+     matrix<double,0,1> loaded_weights; // is replaced only after both reads succeed
+     deserialize(loaded_fe, in);
+     deserialize(loaded_weights, in);
+     item = policy<feature_extractor>(loaded_weights, loaded_fe);
+ }
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_APPROXIMATE_LINEAR_MODELS_Hh_
+