summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/rbf_network_abstract.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/svm/rbf_network_abstract.h')
-rw-r--r--ml/dlib/dlib/svm/rbf_network_abstract.h132
1 files changed, 132 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/rbf_network_abstract.h b/ml/dlib/dlib/svm/rbf_network_abstract.h
new file mode 100644
index 000000000..782a4bdbd
--- /dev/null
+++ b/ml/dlib/dlib/svm/rbf_network_abstract.h
@@ -0,0 +1,132 @@
+// Copyright (C) 2008 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_RBf_NETWORK_ABSTRACT_
+#ifdef DLIB_RBf_NETWORK_ABSTRACT_
+
+#include "../algs.h"
+#include "function_abstract.h"
+#include "kernel_abstract.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename K
+ >
+ class rbf_network_trainer
+ {
+ /*!
+ REQUIREMENTS ON K
+ is a kernel function object as defined in dlib/svm/kernel_abstract.h
+ (since this is supposed to be a RBF network it is probably reasonable
+ to use some sort of radial basis kernel)
+
+ INITIAL VALUE
+ - get_num_centers() == 10
+
+ WHAT THIS OBJECT REPRESENTS
+ This object implements a trainer for a radial basis function network.
+
+ The implementation of this algorithm follows the normal RBF training
+ process. For more details see the code or the Wikipedia article
+ about RBF networks.
+ !*/
+
+ public:
+ typedef K kernel_type;
+ typedef typename kernel_type::scalar_type scalar_type;
+ typedef typename kernel_type::sample_type sample_type;
+ typedef typename kernel_type::mem_manager_type mem_manager_type;
+ typedef decision_function<kernel_type> trained_function_type;
+
+ rbf_network_trainer (
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ !*/
+
+ void set_kernel (
+ const kernel_type& k
+ );
+ /*!
+ ensures
+ - #get_kernel() == k
+ !*/
+
+ const kernel_type& get_kernel (
+ ) const;
+ /*!
+ ensures
+ - returns a copy of the kernel function in use by this object
+ !*/
+
+ void set_num_centers (
+ const unsigned long num_centers
+ );
+ /*!
+ ensures
+ - #get_num_centers() == num_centers
+ !*/
+
+ const unsigned long get_num_centers (
+ ) const;
+ /*!
+ ensures
+ - returns the maximum number of centers (a.k.a. basis_vectors in the
+ trained decision_function) you will get when you train this object on data.
+ !*/
+
+ template <
+ typename in_sample_vector_type,
+ typename in_scalar_vector_type
+ >
+ const decision_function<kernel_type> train (
+ const in_sample_vector_type& x,
+ const in_scalar_vector_type& y
+ ) const
+ /*!
+ requires
+ - x == a matrix or something convertible to a matrix via mat().
+ Also, x should contain sample_type objects.
+ - y == a matrix or something convertible to a matrix via mat().
+ Also, y should contain scalar_type objects.
+ - is_learning_problem(x,y) == true
+ ensures
+ - trains a RBF network given the training samples in x and
+ labels in y and returns the resulting decision_function
+ throws
+ - std::bad_alloc
+ !*/
+
+ void swap (
+ rbf_network_trainer& item
+ );
+ /*!
+ ensures
+ - swaps *this and item
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename K>
+ void swap (
+ rbf_network_trainer<K>& a,
+ rbf_network_trainer<K>& b
+ ) { a.swap(b); }
+ /*!
+ provides a global swap
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_RBf_NETWORK_ABSTRACT_
+
+
+