diff options
Diffstat (limited to 'ml/dlib/dlib/svm/rbf_network_abstract.h')
-rw-r--r-- | ml/dlib/dlib/svm/rbf_network_abstract.h | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/rbf_network_abstract.h b/ml/dlib/dlib/svm/rbf_network_abstract.h new file mode 100644 index 000000000..782a4bdbd --- /dev/null +++ b/ml/dlib/dlib/svm/rbf_network_abstract.h @@ -0,0 +1,132 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_RBf_NETWORK_ABSTRACT_ +#ifdef DLIB_RBf_NETWORK_ABSTRACT_ + +#include "../algs.h" +#include "function_abstract.h" +#include "kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename K + > + class rbf_network_trainer + { + /*! + REQUIREMENTS ON K + is a kernel function object as defined in dlib/svm/kernel_abstract.h + (since this is supposed to be a RBF network it is probably reasonable + to use some sort of radial basis kernel) + + INITIAL VALUE + - get_num_centers() == 10 + + WHAT THIS OBJECT REPRESENTS + This object implements a trainer for a radial basis function network. + + The implementation of this algorithm follows the normal RBF training + process. For more details see the code or the Wikipedia article + about RBF networks. + !*/ + + public: + typedef K kernel_type; + typedef typename kernel_type::scalar_type scalar_type; + typedef typename kernel_type::sample_type sample_type; + typedef typename kernel_type::mem_manager_type mem_manager_type; + typedef decision_function<kernel_type> trained_function_type; + + rbf_network_trainer ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void set_kernel ( + const kernel_type& k + ); + /*! + ensures + - #get_kernel() == k + !*/ + + const kernel_type& get_kernel ( + ) const; + /*! + ensures + - returns a copy of the kernel function in use by this object + !*/ + + void set_num_centers ( + const unsigned long num_centers + ); + /*! + ensures + - #get_num_centers() == num_centers + !*/ + + const unsigned long get_num_centers ( + ) const; + /*! + ensures + - returns the maximum number of centers (a.k.a. basis_vectors in the + trained decision_function) you will get when you train this object on data. + !*/ + + template < + typename in_sample_vector_type, + typename in_scalar_vector_type + > + const decision_function<kernel_type> train ( + const in_sample_vector_type& x, + const in_scalar_vector_type& y + ) const + /*! + requires + - x == a matrix or something convertible to a matrix via mat(). + Also, x should contain sample_type objects. + - y == a matrix or something convertible to a matrix via mat(). + Also, y should contain scalar_type objects. + - is_learning_problem(x,y) == true + ensures + - trains a RBF network given the training samples in x and + labels in y and returns the resulting decision_function + throws + - std::bad_alloc + !*/ + + void swap ( + rbf_network_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template <typename K> + void swap ( + rbf_network_trainer<K>& a, + rbf_network_trainer<K>& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_RBf_NETWORK_ABSTRACT_ + + + |