diff options
Diffstat (limited to 'ml/dlib/dlib/dnn/solvers_abstract.h')
-rw-r--r-- | ml/dlib/dlib/dnn/solvers_abstract.h | 204 |
1 files changed, 204 insertions, 0 deletions
diff --git a/ml/dlib/dlib/dnn/solvers_abstract.h b/ml/dlib/dlib/dnn/solvers_abstract.h new file mode 100644 index 000000000..d10ef163a --- /dev/null +++ b/ml/dlib/dlib/dnn/solvers_abstract.h @@ -0,0 +1,204 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_DNn_SOLVERS_ABSTRACT_H_ +#ifdef DLIB_DNn_SOLVERS_ABSTRACT_H_ + +#include "tensor_abstract.h" +#include <iostream> + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class EXAMPLE_SOLVER + { + /*! + WHAT THIS OBJECT REPRESENTS + A solver defines the parameter update rule for a single layer in a deep + neural network. It takes a parameter gradient vector and the layer's + parameters and tells you how the parameters should be updated. + Importantly, each solver instance is used with only one layer in a network. + This allows us to define solvers that have per layer state, for example, a + solver may keep a momentum term and apply it to its update rule. + + Note that there is no dlib::EXAMPLE_SOLVER type. It is shown here purely + to document the interface a solver object must implement. + !*/ + + public: + + EXAMPLE_SOLVER( + ); + + template <typename layer_type> + const tensor& operator() ( + const float learning_rate, + const layer_type& l, + const tensor& params_grad + ) + /*! + requires + - l.get_layer_params().size() != 0 + - have_same_dimensions(l.get_layer_params(), params_grad) == true. + - When this function is invoked on a particular solver instance, it is + always supplied with the same layer instance, l. That is, the solver is + allowed to remember things from one invocation to another and to assume + that it is being serially applied to optimize the same layer's + parameters. + ensures + - Returns a step vector V that is intended to be used to update the + parameters by adding V to l.get_layer_params(). + - This function will use the given "learning rate" to compute V. How the + learning rate is used is solver dependent. But in general the learning + rate should be used to select the step size, i.e. to somehow determine + the magnitude of V. + !*/ + }; + + void serialize(const EXAMPLE_SOLVER& item, std::ostream& out); + void deserialize(EXAMPLE_SOLVER& item, std::istream& in); + /*! + provides serialization support + !*/ + + std::ostream& operator<< (std::ostream& out, const EXAMPLE_SOLVER& item); + /*! + Prints the solver's name and parameters to out. + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class sgd + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the EXAMPLE_SOLVER interface defined above. It is a + basic stochastic gradient descent solver which uses momentum and weight + decay. In particular, it computes the update vector V according to: + V = momentum*V - weight_decay*learning_rate*l.get_layer_params() - learning_rate*params_grad; + Here V is a momentum term that is remembered by the solver from one + invocation of operator() to the next. + + + Note that the actual learning rate and weight decay used by the solver are + multiplied by the per layer multipliers. That is, the solver will call + get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and + multiply these values with the nominal learning rate and weight decay, + respectively, to determine the values it will use during each step. It is + also overloaded to allow additional learning rate multipliers to be applied + to fc_ and con_ bias parameters. + !*/ + public: + + sgd( + ); + /*! + ensures + - #get_weight_decay() == 0.0005 + - #get_momentum() == 0.9 + !*/ + + explicit sgd( + float weight_decay, + float momentum = 0.9 + ); + /*! + requires + - weight_decay >= 0 + - momentum >= 0 + ensures + - #get_weight_decay() == weight_decay + - #get_momentum() == momentum + !*/ + + float get_weight_decay () const; + float get_momentum () const; + }; + + void serialize(const sgd& item, std::ostream& out); + void deserialize(sgd& item, std::istream& in); + /*! + provides serialization support + !*/ + + std::ostream& operator<< (std::ostream& out, const sgd& item); + /*! + Prints the solver's name and parameters to out. + !*/ + +// ---------------------------------------------------------------------------------------- + + class adam + { + /*! + WHAT THIS OBJECT REPRESENTS + This object implements the EXAMPLE_SOLVER interface defined above. In + particular, it implements the ADAM parameter update method described in the + paper: + Kingma, Diederik P., and Jimmy Ba Adam. "A method for stochastic + optimization." International Conference on Learning Representation. 2015. + + + Note that the actual learning rate and weight decay used by the solver are + multiplied by the per layer multipliers. That is, the solver will call + get_learning_rate_multiplier(l) and get_weight_decay_multiplier(l) and + multiply these values with the nominal learning rate and weight decay, + respectively, to determine the values it will use during each step. It is + also overloaded to allow additional learning rate multipliers to be applied + to fc_ and con_ bias parameters. + !*/ + + public: + + adam( + ); + /*! + ensures + - #get_weight_decay() == 0.0005 + - #get_momentum1() == 0.9 + - #get_momentum2() == 0.999 + !*/ + + adam( + float weight_decay, + float momentum1, + float momentum2 + ); + /*! + requires + - weight_decay >= 0 + - 0 <= momentum1 < 1 + - 0 <= momentum2 < 1 + ensures + - #get_weight_decay() == weight_decay + - #get_momentum1() == momentum1 + - #get_momentum2() == momentum2 + !*/ + + float get_weight_decay () const; + float get_momentum1 () const; + float get_momentum2 () const; + }; + + void serialize(const adam& item, std::ostream& out); + void deserialize(adam& item, std::istream& in); + /*! + provides serialization support + !*/ + + std::ostream& operator<< (std::ostream& out, const adam& item); + /*! + Prints the solver's name and parameters to out. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_DNn_SOLVERS_ABSTRACT_H_ + |