summary refs log tree commit diff stats
path: root/ml/dlib/dlib/mlp/mlp_kernel_1.h
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/mlp/mlp_kernel_1.h')
-rw-r--r-- ml/dlib/dlib/mlp/mlp_kernel_1.h 394
1 files changed, 394 insertions, 0 deletions
diff --git a/ml/dlib/dlib/mlp/mlp_kernel_1.h b/ml/dlib/dlib/mlp/mlp_kernel_1.h
new file mode 100644
index 00000000..d420eea9
--- /dev/null
+++ b/ml/dlib/dlib/mlp/mlp_kernel_1.h
@@ -0,0 +1,394 @@
+// Copyright (C) 2007 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_MLp_KERNEL_1_
+#define DLIB_MLp_KERNEL_1_
+
+#include "../algs.h"
+#include "../serialize.h"
+#include "../matrix.h"
+#include "../rand.h"
+#include "mlp_kernel_abstract.h"
+#include <ctime>
+#include <sstream>
+
+namespace dlib
+{
+
+ class mlp_kernel_1 : noncopyable
+ {
+     /*!
+         INITIAL VALUE
+             The network is initially initialized with random weights
+
+         CONVENTION
+             - input_layer_nodes() == input_nodes
+             - first_hidden_layer_nodes() == first_hidden_nodes
+             - second_hidden_layer_nodes() == second_hidden_nodes
+             - output_layer_nodes() == output_nodes
+             - get_alpha() == alpha
+             - get_momentum() == momentum
+
+             - the input layer and each hidden layer carry one extra trailing
+               "bias" node whose value is forced to -1 on every forward pass.
+               This is where the +1 in the sizes of z, w1, w2 and in the column
+               count of w3 comes from.
+
+             - if (second_hidden_nodes == 0) then
+                 - for all i and j:
+                     - w1(i,j) == the weight on the link from node i in the first hidden layer
+                       to input node j
+                     - w3(i,j) == the weight on the link from node i in the output layer
+                       to first hidden layer node j
+                 - for all i and j:
+                     - w1m == the momentum terms for w1 from the previous update
+                     - w3m == the momentum terms for w3 from the previous update
+             - else
+                 - for all i and j:
+                     - w1(i,j) == the weight on the link from node i in the first hidden layer
+                       to input node j
+                     - w2(i,j) == the weight on the link from node i in the second hidden layer
+                       to first hidden layer node j
+                     - w3(i,j) == the weight on the link from node i in the output layer
+                       to second hidden layer node j
+                 - for all i and j:
+                     - w1m == the momentum terms for w1 from the previous update
+                     - w2m == the momentum terms for w2 from the previous update
+                     - w3m == the momentum terms for w3 from the previous update
+     !*/
+
+ public:
+
+     // Builds a network with the given layer sizes, learning rate (alpha_) and
+     // momentum coefficient (momentum_), sizes every weight/momentum matrix and
+     // then calls reset() to randomize the weights.
+     // Passing nodes_in_second_hidden_layer == 0 selects the single hidden
+     // layer topology; in that case w2 and w2m are left empty.
+     mlp_kernel_1 (
+         long nodes_in_input_layer,
+         long nodes_in_first_hidden_layer,
+         long nodes_in_second_hidden_layer = 0,
+         long nodes_in_output_layer = 1,
+         double alpha_ = 0.1,
+         double momentum_ = 0.8
+     ) :
+         input_nodes(nodes_in_input_layer),
+         first_hidden_nodes(nodes_in_first_hidden_layer),
+         second_hidden_nodes(nodes_in_second_hidden_layer),
+         output_nodes(nodes_in_output_layer),
+         alpha(alpha_),
+         momentum(momentum_)
+     {
+
+         // seed the random number generator from the wall clock.
+         // NOTE(review): the seed has one second resolution, so two networks
+         // built within the same second presumably start from the same weights.
+         std::ostringstream sout;
+         sout << std::time(0);
+         rand_nums.set_seed(sout.str());
+
+         // the +1 in each dimension is the bias node of that layer
+         w1.set_size(first_hidden_nodes+1, input_nodes+1);
+         w1m.set_size(first_hidden_nodes+1, input_nodes+1);
+         z.set_size(input_nodes+1,1);
+
+         if (second_hidden_nodes != 0)
+         {
+             w2.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
+             w3.set_size(output_nodes, second_hidden_nodes+1);
+
+             w2m.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
+             w3m.set_size(output_nodes, second_hidden_nodes+1);
+         }
+         else
+         {
+             // no second hidden layer: the output layer reads the first hidden layer
+             w3.set_size(output_nodes, first_hidden_nodes+1);
+
+             w3m.set_size(output_nodes, first_hidden_nodes+1);
+         }
+
+         reset();
+     }
+
+     virtual ~mlp_kernel_1 (
+     ) {}
+
+     // Discards everything learned so far: every weight is redrawn from
+     // rand_nums and every momentum term is set to zero.  The layer sizes,
+     // alpha and momentum are left untouched.
+     void reset (
+     )
+     {
+         // randomize the weights for the first layer
+         for (long r = 0; r < w1.nr(); ++r)
+             for (long c = 0; c < w1.nc(); ++c)
+                 w1(r,c) = rand_nums.get_random_double();
+
+         // randomize the weights for the second layer
+         // (w2 is empty when there is no second hidden layer, so this is a no-op then)
+         for (long r = 0; r < w2.nr(); ++r)
+             for (long c = 0; c < w2.nc(); ++c)
+                 w2(r,c) = rand_nums.get_random_double();
+
+         // randomize the weights for the third layer
+         for (long r = 0; r < w3.nr(); ++r)
+             for (long c = 0; c < w3.nc(); ++c)
+                 w3(r,c) = rand_nums.get_random_double();
+
+         // zero all the momentum terms
+         set_all_elements(w1m,0);
+         set_all_elements(w2m,0);
+         set_all_elements(w3m,0);
+     }
+
+     // Number of input nodes (not counting the bias node).
+     long input_layer_nodes (
+     ) const { return input_nodes; }
+
+     // Number of nodes in the first hidden layer (not counting the bias node).
+     long first_hidden_layer_nodes (
+     ) const { return first_hidden_nodes; }
+
+     // Number of nodes in the second hidden layer, or 0 if there is none.
+     long second_hidden_layer_nodes (
+     ) const { return second_hidden_nodes; }
+
+     // Number of output nodes.
+     long output_layer_nodes (
+     ) const { return output_nodes; }
+
+     // Learning rate applied to each new gradient term in train().
+     double get_alpha (
+     ) const { return alpha; }
+
+     // Factor applied to the previous weight update in train().
+     double get_momentum (
+     ) const { return momentum; }
+
+     // Runs a forward pass and returns the output layer activations as a
+     // column vector with output_nodes rows.
+     //
+     // Elements in(0) .. in(in.nr()-1) are copied into z without any bounds
+     // check; this assumes in.nr() == input_nodes -- callers must guarantee it.
+     //
+     // Although const, this function writes the mutable scratch members z,
+     // tmp1 and tmp2.  NOTE(review): concurrent calls on the same object
+     // therefore look unsafe without external locking.
+     template <typename EXP>
+     const matrix<double> operator() (
+         const matrix_exp<EXP>& in
+     ) const
+     {
+         for (long i = 0; i < in.nr(); ++i)
+             z(i) = in(i);
+         // insert the bias
+         z(z.nr()-1) = -1;
+
+         tmp1 = sigmoid(w1*z);
+         // insert the bias (overwrites whatever the last row of w1 produced)
+         tmp1(tmp1.nr()-1) = -1;
+
+         if (second_hidden_nodes == 0)
+         {
+             return sigmoid(w3*tmp1);
+         }
+         else
+         {
+             tmp2 = sigmoid(w2*tmp1);
+             // insert the bias
+             tmp2(tmp2.nr()-1) = -1;
+
+             return sigmoid(w3*tmp2);
+         }
+     }
+
+     // Performs one back-propagation step with momentum on a single training
+     // pair: a forward pass on example_in, then an update of all the weights
+     // that moves the output towards example_out.
+     //
+     // example_in is read exactly like the argument of operator() (no bounds
+     // check; assumes example_in.nr() == input_nodes).  example_out is
+     // subtracted from the output vector, so it is assumed to be a column
+     // vector with output_nodes rows -- callers must guarantee it.
+     template <typename EXP1, typename EXP2>
+     void train (
+         const matrix_exp<EXP1>& example_in,
+         const matrix_exp<EXP2>& example_out
+     )
+     {
+         for (long i = 0; i < example_in.nr(); ++i)
+             z(i) = example_in(i);
+         // insert the bias
+         z(z.nr()-1) = -1;
+
+         tmp1 = sigmoid(w1*z);
+         // insert the bias
+         tmp1(tmp1.nr()-1) = -1;
+
+
+         if (second_hidden_nodes == 0)
+         {
+             o = sigmoid(w3*tmp1);
+
+             // now compute the errors and propagate them backwards through the network.
+             // Each error is (downstream error) * y * (1 - y), the sigmoid derivative
+             // written in terms of the activation y.
+             e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
+             // The bias entry of e1 is not meaningful, but it only feeds the last
+             // row of w1, whose output is always overwritten by the bias, so it
+             // cannot influence the network's output.
+             e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w3)*e3 );
+
+             // compute the new weight updates
+             w3m = alpha * e3*trans(tmp1) + w3m*momentum;
+             w1m = alpha * e1*trans(z) + w1m*momentum;
+
+             // now update the weights
+             w1 += w1m;
+             w3 += w3m;
+         }
+         else
+         {
+             tmp2 = sigmoid(w2*tmp1);
+             // insert the bias
+             tmp2(tmp2.nr()-1) = -1;
+
+             o = sigmoid(w3*tmp2);
+
+
+             // now compute the errors and propagate them backwards through the network
+             e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
+             e2 = pointwise_multiply(tmp2, uniform_matrix<double>(second_hidden_nodes+1,1,1.0) - tmp2, trans(w3)*e3 );
+             // The bias node of the second hidden layer is a constant (-1) that does
+             // not depend on the first hidden layer, so it carries no error.  Without
+             // this the formula above yields (-1)*(1-(-1))*(...) for it, and that
+             // bogus value would leak into e1 through trans(w2)*e2 and distort the
+             // update of w1.
+             e2(e2.nr()-1) = 0;
+             e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w2)*e2 );
+
+             // compute the new weight updates
+             w3m = alpha * e3*trans(tmp2) + w3m*momentum;
+             w2m = alpha * e2*trans(tmp1) + w2m*momentum;
+             w1m = alpha * e1*trans(z) + w1m*momentum;
+
+             // now update the weights
+             w1 += w1m;
+             w2 += w2m;
+             w3 += w3m;
+         }
+     }
+
+     // Convenience overload for a scalar target: wraps example_out in a 1x1
+     // matrix and forwards to the general train().  Presumably only meaningful
+     // when output_nodes == 1.
+     template <typename EXP>
+     void train (
+         const matrix_exp<EXP>& example_in,
+         double example_out
+     )
+     {
+         matrix<double,1,1> e_out;
+         e_out(0) = example_out;
+         train(example_in,e_out);
+     }
+
+     // Returns the mean absolute value of the most recent weight updates
+     // (the entries of w1m, w2m and w3m), i.e. how much the last call to
+     // train() moved the weights on average.
+     double get_average_change (
+     ) const
+     {
+         // sum up all the weight changes
+         double delta = sum(abs(w1m)) + sum(abs(w2m)) + sum(abs(w3m));
+
+         // divide by the number of weights
+         delta /= w1m.nr()*w1m.nc() +
+                  w2m.nr()*w2m.nc() +
+                  w3m.nr()*w3m.nc();
+
+         return delta;
+     }
+
+     // Exchanges the complete network state (sizes, parameters, weights,
+     // momentum terms and scratch matrices) with item.  The random number
+     // generator rand_nums is not exchanged.
+     void swap (
+         mlp_kernel_1& item
+     )
+     {
+         exchange(input_nodes, item.input_nodes);
+         exchange(first_hidden_nodes, item.first_hidden_nodes);
+         exchange(second_hidden_nodes, item.second_hidden_nodes);
+         exchange(output_nodes, item.output_nodes);
+         exchange(alpha, item.alpha);
+         exchange(momentum, item.momentum);
+
+         w1.swap(item.w1);
+         w2.swap(item.w2);
+         w3.swap(item.w3);
+
+         w1m.swap(item.w1m);
+         w2m.swap(item.w2m);
+         w3m.swap(item.w3m);
+
+         // even swap the temporary matrices because this may ultimately result in
+         // fewer calls to new and delete.
+         e1.swap(item.e1);
+         e2.swap(item.e2);
+         e3.swap(item.e3);
+         z.swap(item.z);
+         tmp1.swap(item.tmp1);
+         tmp2.swap(item.tmp2);
+         o.swap(item.o);
+     }
+
+
+     // The serialization routines read and write the private members directly.
+     friend void serialize (
+         const mlp_kernel_1& item,
+         std::ostream& out
+     );
+
+     friend void deserialize (
+         mlp_kernel_1& item,
+         std::istream& in
+     );
+
+ private:
+
+     // layer sizes, none of them counting the bias node
+     long input_nodes;
+     long first_hidden_nodes;
+     long second_hidden_nodes;
+     long output_nodes;
+     // learning rate and momentum coefficient used by train()
+     double alpha;
+     double momentum;
+
+     // weights (see CONVENTION above); w2 is empty when second_hidden_nodes == 0
+     matrix<double> w1;
+     matrix<double> w2;
+     matrix<double> w3;
+
+     // the previous update applied to each weight matrix
+     matrix<double> w1m;
+     matrix<double> w2m;
+     matrix<double> w3m;
+
+
+     // source of the random initial weights
+     rand rand_nums;
+
+     // temporary storage; mutable so the const operator() can reuse it
+     mutable matrix<double> e1, e2, e3;
+     mutable matrix<double> z, tmp1, tmp2, o;
+ };
+
+ // Free-function form of mlp_kernel_1::swap so that unqualified swap(a, b)
+ // calls pick up the member implementation.
+ inline void swap (
+     mlp_kernel_1& a,
+     mlp_kernel_1& b
+ )
+ {
+     a.swap(b);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Writes item to out.  The stream layout is, in this order: the four layer
+ // sizes, alpha, momentum, the weight matrices w1..w3 and the momentum
+ // matrices w1m..w3m.  deserialize() reads the fields back in the same order.
+ // The scratch matrices and the random number generator are not written.
+ inline void serialize (
+     const mlp_kernel_1& item,
+     std::ostream& out
+ )
+ {
+     try
+     {
+         // network shape and training parameters
+         serialize(item.input_nodes, out);
+         serialize(item.first_hidden_nodes, out);
+         serialize(item.second_hidden_nodes, out);
+         serialize(item.output_nodes, out);
+         serialize(item.alpha, out);
+         serialize(item.momentum, out);
+
+         // weights
+         serialize(item.w1, out);
+         serialize(item.w2, out);
+         serialize(item.w3, out);
+
+         // momentum terms
+         serialize(item.w1m, out);
+         serialize(item.w2m, out);
+         serialize(item.w3m, out);
+     }
+     catch (serialization_error& failure)
+     {
+         // re-throw with a note saying which object was being written
+         throw serialization_error(failure.info + "\n while serializing object of type mlp_kernel_1");
+     }
+ }
+
+ // Restores item from in, reading the fields in the order serialize() wrote
+ // them.  If a serialization_error is raised part way through, item is
+ // replaced by a freshly constructed 1-input, 1-hidden-node network before
+ // the error is re-thrown with extra context.
+ inline void deserialize (
+     mlp_kernel_1& item,
+     std::istream& in
+ )
+ {
+     try
+     {
+         // network shape and training parameters
+         deserialize(item.input_nodes, in);
+         deserialize(item.first_hidden_nodes, in);
+         deserialize(item.second_hidden_nodes, in);
+         deserialize(item.output_nodes, in);
+         deserialize(item.alpha, in);
+         deserialize(item.momentum, in);
+
+         // weights
+         deserialize(item.w1, in);
+         deserialize(item.w2, in);
+         deserialize(item.w3, in);
+
+         // momentum terms
+         deserialize(item.w1m, in);
+         deserialize(item.w2m, in);
+         deserialize(item.w3m, in);
+
+         // z is written element by element on the forward pass, so it must
+         // already have the right size; the other scratch matrices are
+         // assigned whole and need no resizing here.
+         item.z.set_size(item.input_nodes+1,1);
+     }
+     catch (serialization_error& failure)
+     {
+         // give item a reasonable value since the deserialization failed
+         mlp_kernel_1(1,1).swap(item);
+         throw serialization_error(failure.info + "\n while deserializing object of type mlp_kernel_1");
+     }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MLp_KERNEL_1_
+