summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-09 13:19:48 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-09 13:20:02 +0000
commit58daab21cd043e1dc37024a7f99b396788372918 (patch)
tree96771e43bb69f7c1c2b0b4f7374cb74d7866d0cb /ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h
parentReleasing debian version 1.43.2-1. (diff)
downloadnetdata-58daab21cd043e1dc37024a7f99b396788372918.tar.xz
netdata-58daab21cd043e1dc37024a7f99b396788372918.zip
Merging upstream version 1.44.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h')
-rw-r--r--ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h959
1 files changed, 959 insertions, 0 deletions
diff --git a/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h
new file mode 100644
index 000000000..f035442bf
--- /dev/null
+++ b/ml/dlib/dlib/graph_cuts/find_max_factor_graph_potts.h
@@ -0,0 +1,959 @@
+// Copyright (C) 2012 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_
+#define DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_
+
+#include "find_max_factor_graph_potts_abstract.h"
+#include "../matrix.h"
+#include "min_cut.h"
+#include "general_potts_problem.h"
+#include "../algs.h"
+#include "../graph_utils.h"
+#include "../array2d.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+
+ template <
+ typename potts_problem,
+ typename T = void
+ >
+ class flows_container
+ {
+ /*
+ This object notionally represents a matrix of flow values. It's
+ overloaded to represent this matrix efficiently though. In this case
+ it represents the matrix using a sparse representation.
+ */
+
+ typedef typename potts_problem::value_type edge_type;
+ std::vector<std::vector<edge_type> > flows;
+ public:
+
+ void setup(
+ const potts_problem& p
+ )
+ {
+ flows.resize(p.number_of_nodes());
+ for (unsigned long i = 0; i < flows.size(); ++i)
+ {
+ flows[i].resize(p.number_of_neighbors(i));
+ }
+ }
+
+ edge_type& operator() (
+ const long r,
+ const long c
+ ) { return flows[r][c]; }
+
+ const edge_type& operator() (
+ const long r,
+ const long c
+ ) const { return flows[r][c]; }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_problem
+ >
+ class flows_container<potts_problem,
+ typename enable_if_c<potts_problem::max_number_of_neighbors!=0>::type>
+ {
+ /*
+ This object notionally represents a matrix of flow values. It's
+ overloaded to represent this matrix efficiently though. In this case
+ it represents the matrix using a dense representation.
+
+ */
+ typedef typename potts_problem::value_type edge_type;
+ const static unsigned long max_number_of_neighbors = potts_problem::max_number_of_neighbors;
+ matrix<edge_type,0,max_number_of_neighbors> flows;
+ public:
+
+ void setup(
+ const potts_problem& p
+ )
+ {
+ flows.set_size(p.number_of_nodes(), max_number_of_neighbors);
+ }
+
+ edge_type& operator() (
+ const long r,
+ const long c
+ ) { return flows(r,c); }
+
+ const edge_type& operator() (
+ const long r,
+ const long c
+ ) const { return flows(r,c); }
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_problem
+ >
+ class potts_flow_graph
+ {
+ public:
+ typedef typename potts_problem::value_type edge_type;
+ private:
+ /*!
+ This is a utility class used by dlib::min_cut to convert a potts_problem
+ into the kind of flow graph expected by the min_cut object's main block
+ of code.
+
+ Within this object, we will use the convention that one past
+ potts_problem::number_of_nodes() is the source node and two past is
+ the sink node.
+ !*/
+
+ potts_problem& g;
+
+ // flows(i,j) == the flow from node id i to it's jth neighbor
+ flows_container<potts_problem> flows;
+ // source_flows(i,0) == flow from source to node i,
+ // source_flows(i,1) == flow from node i to source
+ matrix<edge_type,0,2> source_flows;
+
+ // sink_flows(i,0) == flow from sink to node i,
+ // sink_flows(i,1) == flow from node i to sink
+ matrix<edge_type,0,2> sink_flows;
+
+ node_label source_label, sink_label;
+ public:
+
+ potts_flow_graph(
+ potts_problem& g_
+ ) : g(g_)
+ {
+ flows.setup(g);
+
+ source_flows.set_size(g.number_of_nodes(), 2);
+ sink_flows.set_size(g.number_of_nodes(), 2);
+ source_flows = 0;
+ sink_flows = 0;
+
+ source_label = FREE_NODE;
+ sink_label = FREE_NODE;
+
+ // setup flows based on factor potentials
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ const edge_type temp = g.factor_value(i);
+ if (temp < 0)
+ source_flows(i,0) = -temp;
+ else
+ sink_flows(i,1) = temp;
+
+ for (unsigned long j = 0; j < g.number_of_neighbors(i); ++j)
+ {
+ flows(i,j) = g.factor_value_disagreement(i, g.get_neighbor(i,j));
+ }
+ }
+ }
+
+ class out_edge_iterator
+ {
+ friend class potts_flow_graph;
+ unsigned long idx; // base node idx
+ unsigned long cnt; // count over the neighbors of idx
+ public:
+
+ out_edge_iterator(
+ ):idx(0),cnt(0){}
+
+ out_edge_iterator(
+ unsigned long idx_,
+ unsigned long cnt_
+ ):idx(idx_),cnt(cnt_)
+ {}
+
+ bool operator!= (
+ const out_edge_iterator& item
+ ) const { return cnt != item.cnt; }
+
+ out_edge_iterator& operator++(
+ )
+ {
+ ++cnt;
+ return *this;
+ }
+ };
+
+ class in_edge_iterator
+ {
+ friend class potts_flow_graph;
+ unsigned long idx; // base node idx
+ unsigned long cnt; // count over the neighbors of idx
+ public:
+
+ in_edge_iterator(
+ ):idx(0),cnt(0)
+ {}
+
+
+ in_edge_iterator(
+ unsigned long idx_,
+ unsigned long cnt_
+ ):idx(idx_),cnt(cnt_)
+ {}
+
+ bool operator!= (
+ const in_edge_iterator& item
+ ) const { return cnt != item.cnt; }
+
+ in_edge_iterator& operator++(
+ )
+ {
+ ++cnt;
+ return *this;
+ }
+ };
+
+ unsigned long number_of_nodes (
+ ) const { return g.number_of_nodes() + 2; }
+
+ out_edge_iterator out_begin(
+ const unsigned long& it
+ ) const { return out_edge_iterator(it, 0); }
+
+ in_edge_iterator in_begin(
+ const unsigned long& it
+ ) const { return in_edge_iterator(it, 0); }
+
+ out_edge_iterator out_end(
+ const unsigned long& it
+ ) const
+ {
+ if (it >= g.number_of_nodes())
+ return out_edge_iterator(it, g.number_of_nodes());
+ else
+ return out_edge_iterator(it, g.number_of_neighbors(it)+2);
+ }
+
+ in_edge_iterator in_end(
+ const unsigned long& it
+ ) const
+ {
+ if (it >= g.number_of_nodes())
+ return in_edge_iterator(it, g.number_of_nodes());
+ else
+ return in_edge_iterator(it, g.number_of_neighbors(it)+2);
+ }
+
+
+ template <typename iterator_type>
+ unsigned long node_id (
+ const iterator_type& it
+ ) const
+ {
+ // if this isn't an iterator over the source or sink nodes
+ if (it.idx < g.number_of_nodes())
+ {
+ const unsigned long num = g.number_of_neighbors(it.idx);
+ if (it.cnt < num)
+ return g.get_neighbor(it.idx, it.cnt);
+ else if (it.cnt == num)
+ return g.number_of_nodes();
+ else
+ return g.number_of_nodes()+1;
+ }
+ else
+ {
+ return it.cnt;
+ }
+ }
+
+
+ edge_type get_flow (
+ const unsigned long& it1,
+ const unsigned long& it2
+ ) const
+ {
+ if (it1 >= g.number_of_nodes())
+ {
+ // if it1 is the source
+ if (it1 == g.number_of_nodes())
+ {
+ return source_flows(it2,0);
+ }
+ else // if it1 is the sink
+ {
+ return sink_flows(it2,0);
+ }
+ }
+ else if (it2 >= g.number_of_nodes())
+ {
+ // if it2 is the source
+ if (it2 == g.number_of_nodes())
+ {
+ return source_flows(it1,1);
+ }
+ else // if it2 is the sink
+ {
+ return sink_flows(it1,1);
+ }
+ }
+ else
+ {
+ return flows(it1, g.get_neighbor_idx(it1, it2));
+ }
+
+ }
+
+ edge_type get_flow (
+ const out_edge_iterator& it
+ ) const
+ {
+ if (it.idx < g.number_of_nodes())
+ {
+ const unsigned long num = g.number_of_neighbors(it.idx);
+ if (it.cnt < num)
+ return flows(it.idx, it.cnt);
+ else if (it.cnt == num)
+ return source_flows(it.idx,1);
+ else
+ return sink_flows(it.idx,1);
+ }
+ else
+ {
+ // if it.idx is the source
+ if (it.idx == g.number_of_nodes())
+ {
+ return source_flows(it.cnt,0);
+ }
+ else // if it.idx is the sink
+ {
+ return sink_flows(it.cnt,0);
+ }
+ }
+ }
+
+ edge_type get_flow (
+ const in_edge_iterator& it
+ ) const
+ {
+ return get_flow(node_id(it), it.idx);
+ }
+
+ void adjust_flow (
+ const unsigned long& it1,
+ const unsigned long& it2,
+ const edge_type& value
+ )
+ {
+ if (it1 >= g.number_of_nodes())
+ {
+ // if it1 is the source
+ if (it1 == g.number_of_nodes())
+ {
+ source_flows(it2,0) += value;
+ source_flows(it2,1) -= value;
+ }
+ else // if it1 is the sink
+ {
+ sink_flows(it2,0) += value;
+ sink_flows(it2,1) -= value;
+ }
+ }
+ else if (it2 >= g.number_of_nodes())
+ {
+ // if it2 is the source
+ if (it2 == g.number_of_nodes())
+ {
+ source_flows(it1,1) += value;
+ source_flows(it1,0) -= value;
+ }
+ else // if it2 is the sink
+ {
+ sink_flows(it1,1) += value;
+ sink_flows(it1,0) -= value;
+ }
+ }
+ else
+ {
+ flows(it1, g.get_neighbor_idx(it1, it2)) += value;
+ flows(it2, g.get_neighbor_idx(it2, it1)) -= value;
+ }
+
+ }
+
+ void set_label (
+ const unsigned long& it,
+ node_label value
+ )
+ {
+ if (it < g.number_of_nodes())
+ g.set_label(it, value);
+ else if (it == g.number_of_nodes())
+ source_label = value;
+ else
+ sink_label = value;
+ }
+
+ node_label get_label (
+ const unsigned long& it
+ ) const
+ {
+ if (it < g.number_of_nodes())
+ return g.get_label(it);
+ if (it == g.number_of_nodes())
+ return source_label;
+ else
+ return sink_label;
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename label_image_type,
+ typename image_potts_model
+ >
+ class potts_grid_problem
+ {
+ label_image_type& label_img;
+ long nc;
+ long num_nodes;
+ unsigned char* labels;
+ const image_potts_model& model;
+
+ public:
+ const static unsigned long max_number_of_neighbors = 4;
+
+ potts_grid_problem (
+ label_image_type& label_img_,
+ const image_potts_model& image_potts_model_
+ ) :
+ label_img(label_img_),
+ model(image_potts_model_)
+ {
+ num_nodes = model.nr()*model.nc();
+ nc = model.nc();
+ labels = &label_img[0][0];
+ }
+
+ unsigned long number_of_nodes (
+ ) const { return num_nodes; }
+
+ unsigned long number_of_neighbors (
+ unsigned long
+ ) const
+ {
+ return 4;
+ }
+
+ unsigned long get_neighbor_idx (
+ long node_id1,
+ long node_id2
+ ) const
+ {
+ long diff = node_id2-node_id1;
+ if (diff > nc)
+ diff -= (long)number_of_nodes();
+ else if (diff < -nc)
+ diff += (long)number_of_nodes();
+
+ if (diff == 1)
+ return 0;
+ else if (diff == -1)
+ return 1;
+ else if (diff == nc)
+ return 2;
+ else
+ return 3;
+ }
+
+ unsigned long get_neighbor (
+ long node_id,
+ long idx
+ ) const
+ {
+ switch(idx)
+ {
+ case 0:
+ {
+ long temp = node_id+1;
+ if (temp < (long)number_of_nodes())
+ return temp;
+ else
+ return temp - (long)number_of_nodes();
+ }
+ case 1:
+ {
+ long temp = node_id-1;
+ if (node_id >= 1)
+ return temp;
+ else
+ return temp + (long)number_of_nodes();
+ }
+ case 2:
+ {
+ long temp = node_id+nc;
+ if (temp < (long)number_of_nodes())
+ return temp;
+ else
+ return temp - (long)number_of_nodes();
+ }
+ case 3:
+ {
+ long temp = node_id-nc;
+ if (node_id >= nc)
+ return temp;
+ else
+ return temp + (long)number_of_nodes();
+ }
+ }
+ return 0;
+ }
+
+ void set_label (
+ const unsigned long& idx,
+ node_label value
+ )
+ {
+ *(labels+idx) = value;
+ }
+
+ node_label get_label (
+ const unsigned long& idx
+ ) const
+ {
+ return *(labels+idx);
+ }
+
+ typedef typename image_potts_model::value_type value_type;
+
+ value_type factor_value (unsigned long idx) const
+ {
+ return model.factor_value(idx);
+ }
+
+ value_type factor_value_disagreement (unsigned long idx1, unsigned long idx2) const
+ {
+ return model.factor_value_disagreement(idx1,idx2);
+ }
+
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_model
+ >
+ typename potts_model::value_type potts_model_score (
+ const potts_model& prob
+ )
+ {
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
+ {
+ for (unsigned long jj = 0; jj < prob.number_of_neighbors(i); ++jj)
+ {
+ unsigned long j = prob.get_neighbor(i,jj);
+ DLIB_ASSERT(prob.factor_value_disagreement(i,j) >= 0,
+ "\t value_type potts_model_score(prob)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t prob.factor_value_disagreement(i,j): " << prob.factor_value_disagreement(i,j)
+ );
+ DLIB_ASSERT(prob.factor_value_disagreement(i,j) == prob.factor_value_disagreement(j,i),
+ "\t value_type potts_model_score(prob)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t prob.factor_value_disagreement(i,j): " << prob.factor_value_disagreement(i,j)
+ << "\n\t prob.factor_value_disagreement(j,i): " << prob.factor_value_disagreement(j,i)
+ );
+ }
+ }
+#endif
+
+ typename potts_model::value_type score = 0;
+ for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
+ {
+ const bool label = (prob.get_label(i)!=0);
+ if (label)
+ score += prob.factor_value(i);
+ }
+
+ for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
+ {
+ for (unsigned long n = 0; n < prob.number_of_neighbors(i); ++n)
+ {
+ const unsigned long idx2 = prob.get_neighbor(i,n);
+ const bool label_i = (prob.get_label(i)!=0);
+ const bool label_idx2 = (prob.get_label(idx2)!=0);
+ if (label_i != label_idx2 && i < idx2)
+ score -= prob.factor_value_disagreement(i, idx2);
+ }
+ }
+
+ return score;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ typename graph_type::edge_type potts_model_score (
+ const graph_type& g,
+ const std::vector<node_label>& labels
+ )
+ {
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\t edge_type potts_model_score(g,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ );
+ typedef typename graph_type::edge_type edge_type;
+ typedef typename graph_type::type type;
+
+ // The edges and node's have to use the same type to represent factor weights!
+ COMPILE_TIME_ASSERT((is_same_type<edge_type, type>::value == true));
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ for (unsigned long jj = 0; jj < g.node(i).number_of_neighbors(); ++jj)
+ {
+ unsigned long j = g.node(i).neighbor(jj).index();
+ DLIB_ASSERT(edge(g,i,j) >= 0,
+ "\t edge_type potts_model_score(g,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t edge(g,i,j): " << edge(g,i,j)
+ );
+ }
+ }
+#endif
+
+ typename graph_type::edge_type score = 0;
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ const bool label = (labels[i]!=0);
+ if (label)
+ score += g.node(i).data;
+ }
+
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
+ {
+ const unsigned long idx2 = g.node(i).neighbor(n).index();
+ const bool label_i = (labels[i]!=0);
+ const bool label_idx2 = (labels[idx2]!=0);
+ if (label_i != label_idx2 && i < idx2)
+ score -= g.node(i).edge(n);
+ }
+ }
+
+ return score;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_grid_problem,
+ typename mem_manager
+ >
+ typename potts_grid_problem::value_type potts_model_score (
+ const potts_grid_problem& prob,
+ const array2d<node_label,mem_manager>& labels
+ )
+ {
+ DLIB_ASSERT(prob.nr() == labels.nr() && prob.nc() == labels.nc(),
+ "\t value_type potts_model_score(prob,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t prob.nr(): " << labels.nr()
+ << "\n\t prob.nc(): " << labels.nc()
+ );
+ typedef array2d<node_label,mem_manager> image_type;
+ // This const_cast is ok because the model object won't actually modify labels
+ dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(const_cast<image_type&>(labels),prob);
+ return potts_model_score(model);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_model
+ >
+ void find_max_factor_graph_potts (
+ potts_model& prob
+ )
+ {
+#ifdef ENABLE_ASSERTS
+ for (unsigned long node_i = 0; node_i < prob.number_of_nodes(); ++node_i)
+ {
+ for (unsigned long jj = 0; jj < prob.number_of_neighbors(node_i); ++jj)
+ {
+ unsigned long node_j = prob.get_neighbor(node_i,jj);
+ DLIB_ASSERT(prob.get_neighbor_idx(node_j,node_i) < prob.number_of_neighbors(node_j),
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t The supplied potts problem defines an invalid graph."
+ << "\n\t node_i: " << node_i
+ << "\n\t node_j: " << node_j
+ << "\n\t prob.get_neighbor_idx(node_j,node_i): " << prob.get_neighbor_idx(node_j,node_i)
+ << "\n\t prob.number_of_neighbors(node_j): " << prob.number_of_neighbors(node_j)
+ );
+
+ DLIB_ASSERT(prob.get_neighbor_idx(node_i,prob.get_neighbor(node_i,jj)) == jj,
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t The get_neighbor_idx() and get_neighbor() functions must be inverses of each other."
+ << "\n\t node_i: " << node_i
+ << "\n\t jj: " << jj
+ << "\n\t prob.get_neighbor(node_i,jj): " << prob.get_neighbor(node_i,jj)
+ << "\n\t prob.get_neighbor_idx(node_i,prob.get_neighbor(node_i,jj)): " << prob.get_neighbor_idx(node_i,node_j)
+ );
+
+ DLIB_ASSERT(prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i))==node_i,
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t The get_neighbor_idx() and get_neighbor() functions must be inverses of each other."
+ << "\n\t node_i: " << node_i
+ << "\n\t node_j: " << node_j
+ << "\n\t prob.get_neighbor_idx(node_j,node_i): " << prob.get_neighbor_idx(node_j,node_i)
+ << "\n\t prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i)): " << prob.get_neighbor(node_j,prob.get_neighbor_idx(node_j,node_i))
+ );
+
+ DLIB_ASSERT(prob.factor_value_disagreement(node_i,node_j) >= 0,
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t node_i: " << node_i
+ << "\n\t node_j: " << node_j
+ << "\n\t prob.factor_value_disagreement(node_i,node_j): " << prob.factor_value_disagreement(node_i,node_j)
+ );
+ DLIB_ASSERT(prob.factor_value_disagreement(node_i,node_j) == prob.factor_value_disagreement(node_j,node_i),
+ "\t void find_max_factor_graph_potts(prob)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t node_i: " << node_i
+ << "\n\t node_j: " << node_j
+ << "\n\t prob.factor_value_disagreement(node_i,node_j): " << prob.factor_value_disagreement(node_i,node_j)
+ << "\n\t prob.factor_value_disagreement(node_j,node_i): " << prob.factor_value_disagreement(node_j,node_i)
+ );
+ }
+ }
+#endif
+ COMPILE_TIME_ASSERT(is_signed_type<typename potts_model::value_type>::value);
+ min_cut mc;
+ dlib::impl::potts_flow_graph<potts_model> pfg(prob);
+ mc(pfg, prob.number_of_nodes(), prob.number_of_nodes()+1);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename graph_type
+ >
+ void find_max_factor_graph_potts (
+ const graph_type& g,
+ std::vector<node_label>& labels
+ )
+ {
+ DLIB_ASSERT(graph_contains_length_one_cycle(g) == false,
+ "\t void find_max_factor_graph_potts(g,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ );
+ typedef typename graph_type::edge_type edge_type;
+ typedef typename graph_type::type type;
+
+ // The edges and node's have to use the same type to represent factor weights!
+ COMPILE_TIME_ASSERT((is_same_type<edge_type, type>::value == true));
+ COMPILE_TIME_ASSERT(is_signed_type<edge_type>::value);
+
+#ifdef ENABLE_ASSERTS
+ for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
+ {
+ for (unsigned long jj = 0; jj < g.node(i).number_of_neighbors(); ++jj)
+ {
+ unsigned long j = g.node(i).neighbor(jj).index();
+ DLIB_ASSERT(edge(g,i,j) >= 0,
+ "\t void find_max_factor_graph_potts(g,labels)"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t i: " << i
+ << "\n\t j: " << j
+ << "\n\t edge(g,i,j): " << edge(g,i,j)
+ );
+ }
+ }
+#endif
+
+ dlib::impl::general_potts_problem<graph_type> gg(g, labels);
+ find_max_factor_graph_potts(gg);
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename potts_grid_problem,
+ typename mem_manager
+ >
+ void find_max_factor_graph_potts (
+ const potts_grid_problem& prob,
+ array2d<node_label,mem_manager>& labels
+ )
+ {
+ typedef array2d<node_label,mem_manager> image_type;
+ labels.set_size(prob.nr(), prob.nc());
+ dlib::impl::potts_grid_problem<image_type,potts_grid_problem> model(labels,prob);
+ find_max_factor_graph_potts(model);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename pixel_type1,
+ typename pixel_type2,
+ typename model_type
+ >
+ struct potts_grid_image_pair_model
+ {
+ const pixel_type1* data1;
+ const pixel_type2* data2;
+ const model_type& model;
+ const long nr_;
+ const long nc_;
+ template <typename image_type1, typename image_type2>
+ potts_grid_image_pair_model(
+ const model_type& model_,
+ const image_type1& img1,
+ const image_type2& img2
+ ) :
+ model(model_),
+ nr_(img1.nr()),
+ nc_(img1.nc())
+ {
+ data1 = &img1[0][0];
+ data2 = &img2[0][0];
+ }
+
+ typedef typename model_type::value_type value_type;
+
+ long nr() const { return nr_; }
+ long nc() const { return nc_; }
+
+ value_type factor_value (
+ unsigned long idx
+ ) const
+ {
+ return model.factor_value(*(data1 + idx), *(data2 + idx));
+ }
+
+ value_type factor_value_disagreement (
+ unsigned long idx1,
+ unsigned long idx2
+ ) const
+ {
+ return model.factor_value_disagreement(*(data1 + idx1), *(data1 + idx2));
+ }
+ };
+
+ // ----------------------------------------------------------------------------------------
+
+ template <
+ typename image_type,
+ typename model_type
+ >
+ struct potts_grid_image_single_model
+ {
+ const typename image_type::type* data1;
+ const model_type& model;
+ const long nr_;
+ const long nc_;
+ potts_grid_image_single_model(
+ const model_type& model_,
+ const image_type& img1
+ ) :
+ model(model_),
+ nr_(img1.nr()),
+ nc_(img1.nc())
+ {
+ data1 = &img1[0][0];
+ }
+
+ typedef typename model_type::value_type value_type;
+
+ long nr() const { return nr_; }
+ long nc() const { return nc_; }
+
+ value_type factor_value (
+ unsigned long idx
+ ) const
+ {
+ return model.factor_value(*(data1 + idx));
+ }
+
+ value_type factor_value_disagreement (
+ unsigned long idx1,
+ unsigned long idx2
+ ) const
+ {
+ return model.factor_value_disagreement(*(data1 + idx1), *(data1 + idx2));
+ }
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename pair_image_model,
+ typename pixel_type1,
+ typename pixel_type2,
+ typename mem_manager
+ >
+ impl::potts_grid_image_pair_model<pixel_type1, pixel_type2, pair_image_model> make_potts_grid_problem (
+ const pair_image_model& model,
+ const array2d<pixel_type1,mem_manager>& img1,
+ const array2d<pixel_type2,mem_manager>& img2
+ )
+ {
+ DLIB_ASSERT(get_rect(img1) == get_rect(img2),
+ "\t potts_grid_problem make_potts_grid_problem()"
+ << "\n\t Invalid inputs were given to this function."
+ << "\n\t get_rect(img1): " << get_rect(img1)
+ << "\n\t get_rect(img2): " << get_rect(img2)
+ );
+ typedef impl::potts_grid_image_pair_model<pixel_type1, pixel_type2, pair_image_model> potts_type;
+ return potts_type(model,img1,img2);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <
+ typename single_image_model,
+ typename pixel_type,
+ typename mem_manager
+ >
+ impl::potts_grid_image_single_model<array2d<pixel_type,mem_manager>, single_image_model> make_potts_grid_problem (
+ const single_image_model& model,
+ const array2d<pixel_type,mem_manager>& img
+ )
+ {
+ typedef impl::potts_grid_image_single_model<array2d<pixel_type,mem_manager>, single_image_model> potts_type;
+ return potts_type(model,img);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_FIND_MAX_FACTOR_GRAPH_PoTTS_Hh_
+