summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h185
1 files changed, 185 insertions, 0 deletions
diff --git a/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h b/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h
new file mode 100644
index 000000000..a0821b696
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/graph_labeler_abstract.h
@@ -0,0 +1,185 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#undef DLIB_GRAPH_LaBELER_ABSTRACT_Hh_
+#ifdef DLIB_GRAPH_LaBELER_ABSTRACT_Hh_
+
+#include "find_max_factor_graph_potts_abstract.h"
+#include "../graph/graph_kernel_abstract.h"
+#include "../matrix/matrix_abstract.h"
+#include <vector>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ class graph_labeler
+ {
+ /*!
+ REQUIREMENTS ON vector_type
+ - vector_type is a dlib::matrix capable of representing column
+ vectors or it is a sparse vector type as defined in dlib/svm/sparse_vector_abstract.h.
+
+ WHAT THIS OBJECT REPRESENTS
+ This object is a tool for labeling each node in a graph with a value
+ of true or false, subject to a labeling consistency constraint between
+ nodes that share an edge. In particular, this object is useful for
+ representing a graph labeling model learned via some machine learning
+ method.
+
+ To elaborate, suppose we have a graph we want to label. Moreover,
+ suppose we can assign a score to each node which represents how much
+ we want to label the node as true, and we also have scores for each
+ edge which represent how much we wanted the nodes sharing the edge to
+ have the same label. If we could do this then we could find the optimal
+ labeling using the find_max_factor_graph_potts() routine. Therefore,
+ the graph_labeler is just an object which contains the necessary data
+ to compute these score functions and then call find_max_factor_graph_potts().
+ Additionally, this object uses linear functions to represent these score
+ functions.
+
+ THREAD SAFETY
+ It is always safe to use distinct instances of this object in different
+ threads. However, when a single instance is shared between threads then
+ the following rules apply:
+ It is safe to call the const members of this object from multiple
+ threads. This is because the const members are purely read-only
+ operations. However, any operation that modifies a graph_labeler is
+ not threadsafe.
+ !*/
+
+ public:
+
+ typedef std::vector<bool> label_type;
+ typedef label_type result_type;
+
+ graph_labeler(
+ );
+ /*!
+ ensures
+ - this object is properly initialized
+ - #get_node_weights() == an initial value of type vector_type.
+ - #get_edge_weights() == an initial value of type vector_type.
+ !*/
+
+ graph_labeler(
+ const vector_type& edge_weights,
+ const vector_type& node_weights
+ );
+ /*!
+ requires
+ - min(edge_weights) >= 0
+ ensures
+ - #get_edge_weights() == edge_weights
+ - #get_node_weights() == node_weights
+ !*/
+
+ const vector_type& get_edge_weights (
+ ) const;
+ /*!
+ ensures
+ - Recall that the score function for an edge is a linear function of
+ the vector stored at that edge. This means there is some vector, E,
+ which we dot product with the vector in the graph to compute the
+ score. Therefore, this function returns that E vector which defines
+ the edge score function.
+ !*/
+
+ const vector_type& get_node_weights (
+ ) const;
+ /*!
+ ensures
+ - Recall that the score function for a node is a linear function of
+ the vector stored in that node. This means there is some vector, W,
+ which we dot product with the vector in the graph to compute the score.
+ Therefore, this function returns that W vector which defines the node
+ score function.
+ !*/
+
+ template <typename graph_type>
+ void operator() (
+ const graph_type& sample,
+ std::vector<bool>& labels
+ ) const;
+ /*!
+ requires
+ - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
+ - graph_type::type and graph_type::edge_type must be either matrix objects
+ capable of representing column vectors or some kind of sparse vector
+ type as defined in dlib/svm/sparse_vector_abstract.h.
+ - graph_contains_length_one_cycle(sample) == false
+ - for all valid i and j:
+ - min(edge(sample,i,j)) >= 0
+ - it must be legal to call dot(edge(sample,i,j), get_edge_weights())
+ - it must be legal to call dot(sample.node(i).data, get_node_weights())
+ ensures
+ - Computes a labeling for each node in the given graph and stores the result
+ in #labels.
+ - #labels.size() == sample.number_of_nodes()
+ - for all valid i:
+ - #labels[i] == the label of the node sample.node(i).
+ - The labels are computed by creating a graph, G, with scalar values on each node
+ and edge. The scalar values are calculated according to the following:
+ - for all valid i:
+ - G.node(i).data == dot(get_node_weights(), sample.node(i).data)
+ - for all valid i and j:
+ - edge(G,i,j) == dot(get_edge_weights(), edge(sample,i,j))
+ Then the labels are computed by calling find_max_factor_graph_potts(G,#labels).
+ !*/
+
+ template <typename graph_type>
+ std::vector<bool> operator() (
+ const graph_type& sample
+ ) const;
+ /*!
+ requires
+ - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
+ - graph_contains_length_one_cycle(sample) == false
+ - for all valid i and j:
+ - min(edge(sample,i,j)) >= 0
+ - it must be legal to call dot(edge(sample,i,j), get_edge_weights())
+ - it must be legal to call dot(sample.node(i).data, get_node_weights())
+ ensures
+ - Performs (*this)(sample, labels); return labels;
+ (i.e. This is just another version of the above operator() routine
+ but instead of returning the labels via the second argument, it
+ returns them as the normal return value).
+ !*/
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void serialize (
+ const graph_labeler<vector_type>& item,
+ std::ostream& out
+ );
+ /*!
+ provides serialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename vector_type
+ >
+ void deserialize (
+ graph_labeler<vector_type>& item,
+ std::istream& in
+ );
+ /*!
+ provides deserialization support
+ !*/
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GRAPH_LaBELER_ABSTRACT_Hh_
+