Diffstat (limited to 'ml/dlib/dlib/mlp/mlp_kernel_1.h')
-rw-r--r-- | ml/dlib/dlib/mlp/mlp_kernel_1.h | 394 |
1 files changed, 394 insertions, 0 deletions
diff --git a/ml/dlib/dlib/mlp/mlp_kernel_1.h b/ml/dlib/dlib/mlp/mlp_kernel_1.h
new file mode 100644
index 00000000..d420eea9
--- /dev/null
+++ b/ml/dlib/dlib/mlp/mlp_kernel_1.h
@@ -0,0 +1,394 @@
+// Copyright (C) 2007  Davis E. King (davis@dlib.net)
+// License: Boost Software License   See LICENSE.txt for the full license.
+#ifndef DLIB_MLp_KERNEL_1_
+#define DLIB_MLp_KERNEL_1_
+
+#include "../algs.h"
+#include "../serialize.h"
+#include "../matrix.h"
+#include "../rand.h"
+#include "mlp_kernel_abstract.h"
+#include <ctime>
+#include <sstream>
+
+namespace dlib
+{
+
+    class mlp_kernel_1 : noncopyable
+    {
+        /*!
+            INITIAL VALUE
+                The network is initially initialized with random weights
+
+            CONVENTION
+                - input_layer_nodes() == input_nodes
+                - first_hidden_layer_nodes() == first_hidden_nodes
+                - second_hidden_layer_nodes() == second_hidden_nodes
+                - output_layer_nodes() == output_nodes
+                - get_alpha() == alpha
+                - get_momentum() == momentum
+
+                - if (second_hidden_nodes == 0) then
+                    - for all i and j:
+                        - w1(i,j) == the weight on the link from node i in the first hidden layer
+                          to input node j
+                        - w3(i,j) == the weight on the link from node i in the output layer
+                          to first hidden layer node j
+                    - for all i and j:
+                        - w1m == the momentum terms for w1 from the previous update
+                        - w3m == the momentum terms for w3 from the previous update
+                - else
+                    - for all i and j:
+                        - w1(i,j) == the weight on the link from node i in the first hidden layer
+                          to input node j
+                        - w2(i,j) == the weight on the link from node i in the second hidden layer
+                          to first hidden layer node j
+                        - w3(i,j) == the weight on the link from node i in the output layer
+                          to second hidden layer node j
+                    - for all i and j:
+                        - w1m == the momentum terms for w1 from the previous update
+                        - w2m == the momentum terms for w2 from the previous update
+                        - w3m == the momentum terms for w3 from the previous update
+        !*/
+
+    public:
+
+        mlp_kernel_1 (
+            long nodes_in_input_layer,
+            long nodes_in_first_hidden_layer,
+            long nodes_in_second_hidden_layer = 0,
+            long nodes_in_output_layer = 1,
+            double alpha_ = 0.1,
+            double momentum_ = 0.8
+        ) :
+            input_nodes(nodes_in_input_layer),
+            first_hidden_nodes(nodes_in_first_hidden_layer),
+            second_hidden_nodes(nodes_in_second_hidden_layer),
+            output_nodes(nodes_in_output_layer),
+            alpha(alpha_),
+            momentum(momentum_)
+        {
+            // seed the random number generator
+            std::ostringstream sout;
+            sout << time(0);
+            rand_nums.set_seed(sout.str());
+
+            w1.set_size(first_hidden_nodes+1, input_nodes+1);
+            w1m.set_size(first_hidden_nodes+1, input_nodes+1);
+            z.set_size(input_nodes+1,1);
+
+            if (second_hidden_nodes != 0)
+            {
+                w2.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
+                w3.set_size(output_nodes, second_hidden_nodes+1);
+
+                w2m.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
+                w3m.set_size(output_nodes, second_hidden_nodes+1);
+            }
+            else
+            {
+                w3.set_size(output_nodes, first_hidden_nodes+1);
+
+                w3m.set_size(output_nodes, first_hidden_nodes+1);
+            }
+
+            reset();
+        }
+
+        virtual ~mlp_kernel_1 (
+        ) {}
+
+        void reset (
+        )
+        {
+            // randomize the weights for the first layer
+            for (long r = 0; r < w1.nr(); ++r)
+                for (long c = 0; c < w1.nc(); ++c)
+                    w1(r,c) = rand_nums.get_random_double();
+
+            // randomize the weights for the second layer
+            for (long r = 0; r < w2.nr(); ++r)
+                for (long c = 0; c < w2.nc(); ++c)
+                    w2(r,c) = rand_nums.get_random_double();
+
+            // randomize the weights for the third layer
+            for (long r = 0; r < w3.nr(); ++r)
+                for (long c = 0; c < w3.nc(); ++c)
+                    w3(r,c) = rand_nums.get_random_double();
+
+            // zero all the momentum terms
+            set_all_elements(w1m,0);
+            set_all_elements(w2m,0);
+            set_all_elements(w3m,0);
+        }
+
+        long input_layer_nodes (
+        ) const { return input_nodes; }
+
+        long first_hidden_layer_nodes (
+        ) const { return first_hidden_nodes; }
+
+        long second_hidden_layer_nodes (
+        ) const { return second_hidden_nodes; }
+
+        long output_layer_nodes (
+        ) const { return output_nodes; }
+
+        double get_alpha (
+        ) const { return alpha; }
+
+        double get_momentum (
+        ) const { return momentum; }
+
+        template <typename EXP>
+        const matrix<double> operator() (
+            const matrix_exp<EXP>& in
+        ) const
+        {
+            for (long i = 0; i < in.nr(); ++i)
+                z(i) = in(i);
+            // insert the bias
+            z(z.nr()-1) = -1;
+
+            tmp1 = sigmoid(w1*z);
+            // insert the bias
+            tmp1(tmp1.nr()-1) = -1;
+
+            if (second_hidden_nodes == 0)
+            {
+                return sigmoid(w3*tmp1);
+            }
+            else
+            {
+                tmp2 = sigmoid(w2*tmp1);
+                // insert the bias
+                tmp2(tmp2.nr()-1) = -1;
+
+                return sigmoid(w3*tmp2);
+            }
+        }
+
+        template <typename EXP1, typename EXP2>
+        void train (
+            const matrix_exp<EXP1>& example_in,
+            const matrix_exp<EXP2>& example_out
+        )
+        {
+            for (long i = 0; i < example_in.nr(); ++i)
+                z(i) = example_in(i);
+            // insert the bias
+            z(z.nr()-1) = -1;
+
+            tmp1 = sigmoid(w1*z);
+            // insert the bias
+            tmp1(tmp1.nr()-1) = -1;
+
+            if (second_hidden_nodes == 0)
+            {
+                o = sigmoid(w3*tmp1);
+
+                // now compute the errors and propagate them backwards through the network
+                e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
+                e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w3)*e3 );
+
+                // compute the new weight updates
+                w3m = alpha * e3*trans(tmp1) + w3m*momentum;
+                w1m = alpha * e1*trans(z) + w1m*momentum;
+
+                // now update the weights
+                w1 += w1m;
+                w3 += w3m;
+            }
+            else
+            {
+                tmp2 = sigmoid(w2*tmp1);
+                // insert the bias
+                tmp2(tmp2.nr()-1) = -1;
+
+                o = sigmoid(w3*tmp2);
+
+                // now compute the errors and propagate them backwards through the network
+                e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
+                e2 = pointwise_multiply(tmp2, uniform_matrix<double>(second_hidden_nodes+1,1,1.0) - tmp2, trans(w3)*e3 );
+                e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w2)*e2 );
+
+                // compute the new weight updates
+                w3m = alpha * e3*trans(tmp2) + w3m*momentum;
+                w2m = alpha * e2*trans(tmp1) + w2m*momentum;
+                w1m = alpha * e1*trans(z) + w1m*momentum;
+
+                // now update the weights
+                w1 += w1m;
+                w2 += w2m;
+                w3 += w3m;
+            }
+        }
+
+        template <typename EXP>
+        void train (
+            const matrix_exp<EXP>& example_in,
+            double example_out
+        )
+        {
+            matrix<double,1,1> e_out;
+            e_out(0) = example_out;
+            train(example_in,e_out);
+        }
+
+        double get_average_change (
+        ) const
+        {
+            // sum up all the weight changes
+            double delta = sum(abs(w1m)) + sum(abs(w2m)) + sum(abs(w3m));
+
+            // divide by the number of weights
+            delta /= w1m.nr()*w1m.nc() +
+                     w2m.nr()*w2m.nc() +
+                     w3m.nr()*w3m.nc();
+
+            return delta;
+        }
+
+        void swap (
+            mlp_kernel_1& item
+        )
+        {
+            exchange(input_nodes, item.input_nodes);
+            exchange(first_hidden_nodes, item.first_hidden_nodes);
+            exchange(second_hidden_nodes, item.second_hidden_nodes);
+            exchange(output_nodes, item.output_nodes);
+            exchange(alpha, item.alpha);
+            exchange(momentum, item.momentum);
+
+            w1.swap(item.w1);
+            w2.swap(item.w2);
+            w3.swap(item.w3);
+
+            w1m.swap(item.w1m);
+            w2m.swap(item.w2m);
+            w3m.swap(item.w3m);
+
+            // even swap the temporary matrices because this may ultimately result in
+            // fewer calls to new and delete.
+            e1.swap(item.e1);
+            e2.swap(item.e2);
+            e3.swap(item.e3);
+            z.swap(item.z);
+            tmp1.swap(item.tmp1);
+            tmp2.swap(item.tmp2);
+            o.swap(item.o);
+        }
+
+        friend void serialize (
+            const mlp_kernel_1& item,
+            std::ostream& out
+        );
+
+        friend void deserialize (
+            mlp_kernel_1& item,
+            std::istream& in
+        );
+
+    private:
+
+        long input_nodes;
+        long first_hidden_nodes;
+        long second_hidden_nodes;
+        long output_nodes;
+        double alpha;
+        double momentum;
+
+        matrix<double> w1;
+        matrix<double> w2;
+        matrix<double> w3;
+
+        matrix<double> w1m;
+        matrix<double> w2m;
+        matrix<double> w3m;
+
+        rand rand_nums;
+
+        // temporary storage
+        mutable matrix<double> e1, e2, e3;
+        mutable matrix<double> z, tmp1, tmp2, o;
+    };
+
+    inline void swap (
+        mlp_kernel_1& a,
+        mlp_kernel_1& b
+    ) { a.swap(b); }
+
+// ----------------------------------------------------------------------------------------
+
+    inline void serialize (
+        const mlp_kernel_1& item,
+        std::ostream& out
+    )
+    {
+        try
+        {
+            serialize(item.input_nodes, out);
+            serialize(item.first_hidden_nodes, out);
+            serialize(item.second_hidden_nodes, out);
+            serialize(item.output_nodes, out);
+            serialize(item.alpha, out);
+            serialize(item.momentum, out);
+
+            serialize(item.w1, out);
+            serialize(item.w2, out);
+            serialize(item.w3, out);
+
+            serialize(item.w1m, out);
+            serialize(item.w2m, out);
+            serialize(item.w3m, out);
+        }
+        catch (serialization_error& e)
+        {
+            throw serialization_error(e.info + "\n   while serializing object of type mlp_kernel_1");
+        }
+    }
+
+    inline void deserialize (
+        mlp_kernel_1& item,
+        std::istream& in
+    )
+    {
+        try
+        {
+            deserialize(item.input_nodes, in);
+            deserialize(item.first_hidden_nodes, in);
+            deserialize(item.second_hidden_nodes, in);
+            deserialize(item.output_nodes, in);
+            deserialize(item.alpha, in);
+            deserialize(item.momentum, in);
+
+            deserialize(item.w1, in);
+            deserialize(item.w2, in);
+            deserialize(item.w3, in);
+
+            deserialize(item.w1m, in);
+            deserialize(item.w2m, in);
+            deserialize(item.w3m, in);
+
+            item.z.set_size(item.input_nodes+1,1);
+        }
+        catch (serialization_error& e)
+        {
+            // give item a reasonable value since the deserialization failed
+            mlp_kernel_1(1,1).swap(item);
+            throw serialization_error(e.info + "\n   while deserializing object of type mlp_kernel_1");
+        }
+    }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_MLp_KERNEL_1_
+
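For reference, below is a minimal usage sketch of the class this patch adds. It is not part of the commit: the include path, the layer sizes, and the toy training task are illustrative assumptions (dlib's own example programs normally reach this class through the dlib/mlp.h front-end header rather than including the kernel header directly). The header pulls in dlib's matrix and rand types via its own relative includes, so including it alone is enough for the sketch.

// usage_sketch.cpp -- hypothetical example, not part of the patch
#include <iostream>
#include "ml/dlib/dlib/mlp/mlp_kernel_1.h"   // assumed project-relative path

int main()
{
    using namespace dlib;

    // 2 input nodes, 5 first-hidden-layer nodes, no second hidden layer, 1 output
    mlp_kernel_1 net(2, 5);

    rand rnd;
    matrix<double, 2, 1> sample;

    // train on a simple 2D task: label 1 inside the lower-left quarter of the
    // unit square, label 0 elsewhere
    for (int i = 0; i < 10000; ++i)
    {
        sample(0) = rnd.get_random_double();
        sample(1) = rnd.get_random_double();
        const double label = (sample(0) < 0.5 && sample(1) < 0.5) ? 1.0 : 0.0;
        net.train(sample, label);
    }

    // operator() returns a column matrix with one row per output node
    sample(0) = 0.25; sample(1) = 0.25;
    std::cout << net(sample)(0) << "\n";   // should be close to 1 after training

    sample(0) = 0.90; sample(1) = 0.90;
    std::cout << net(sample)(0) << "\n";   // should be close to 0 after training

    return 0;
}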