summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp')
-rw-r--r--ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp787
1 files changed, 787 insertions, 0 deletions
diff --git a/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp b/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp
new file mode 100644
index 000000000..2260e92a1
--- /dev/null
+++ b/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp
@@ -0,0 +1,787 @@
+// Copyright (C) 2011 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#include <sstream>
+#include <string>
+#include <cstdlib>
+#include <ctime>
+#include <dlib/optimization.h>
+#include <dlib/unordered_pair.h>
+#include <dlib/rand.h>
+
+#include "tester.h"
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.find_max_factor_graph_nmplp");
+
+// ----------------------------------------------------------------------------------------
+
+ dlib::rand rnd;
+
+ template <bool fully_connected>
+ class map_problem
+ {
+ /*
+ This is a simple 8 node problem with two cycles in it unless fully_connected is true
+ and then it's a fully connected 8 note graph.
+ */
+
+ public:
+
+ mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;
+ map_problem()
+ {
+ for (int i = 0; i < 8; ++i)
+ {
+ for (int j = i; j < 8; ++j)
+ {
+ weights[make_unordered_pair(i,j)][make_pair(0,0)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,j)][make_pair(0,1)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,j)][make_pair(1,0)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,j)][make_pair(1,1)] = rnd.get_random_gaussian();
+ }
+ }
+ }
+
+ struct node_iterator
+ {
+ node_iterator() {}
+ node_iterator(unsigned long nid_): nid(nid_) {}
+ bool operator== (const node_iterator& item) const { return item.nid == nid; }
+ bool operator!= (const node_iterator& item) const { return item.nid != nid; }
+
+ node_iterator& operator++()
+ {
+ ++nid;
+ return *this;
+ }
+
+ unsigned long nid;
+ };
+
+ struct neighbor_iterator
+ {
+ neighbor_iterator() : count(0) {}
+
+ bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
+ bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
+ neighbor_iterator& operator++()
+ {
+ ++count;
+ return *this;
+ }
+
+ unsigned long node_id () const
+ {
+ if (fully_connected)
+ {
+ if (count < home_node)
+ return count;
+ else
+ return count+1;
+ }
+
+ if (home_node < 4)
+ {
+ if (count == 0)
+ return (home_node + 4 + 1)%4;
+ else if (count == 1)
+ return (home_node + 4 - 1)%4;
+ else
+ return 8; // one past the end
+ }
+ else
+ {
+ if (count == 0)
+ return (home_node + 4 + 1)%4 + 4;
+ else if (count == 1)
+ return (home_node + 4 - 1)%4 + 4;
+ else
+ return 8; // one past the end
+ }
+ }
+
+ unsigned long home_node;
+ unsigned long count;
+ };
+
+ unsigned long number_of_nodes (
+ ) const
+ {
+ return 8;
+ }
+
+ node_iterator begin(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 0;
+ return temp;
+ }
+
+ node_iterator end(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 8;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const node_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = it.nid;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const neighbor_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = it.node_id();
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const node_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = 9;
+ temp.count = 8;
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const neighbor_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = 9;
+ temp.count = 8;
+ return temp;
+ }
+
+
+ unsigned long node_id (
+ const node_iterator& it
+ ) const
+ {
+ return it.nid;
+ }
+
+ unsigned long node_id (
+ const neighbor_iterator& it
+ ) const
+ {
+ return it.node_id();
+ }
+
+
+ unsigned long num_states (
+ const node_iterator&
+ ) const
+ {
+ return 2;
+ }
+
+ unsigned long num_states (
+ const neighbor_iterator&
+ ) const
+ {
+ return 2;
+ }
+
+ double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
+ double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }
+
+ private:
+
+ double basic_factor_value (
+ unsigned long n1,
+ unsigned long n2,
+ unsigned long s1,
+ unsigned long s2
+ ) const
+ {
+ if (n1 > n2)
+ {
+ swap(n1,n2);
+ swap(s1,s2);
+ }
+ return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class map_problem_chain
+ {
+ /*
+ This is a chain structured 8 node graph (so no cycles).
+ */
+
+ public:
+
+ mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights;
+ map_problem_chain()
+ {
+ for (int i = 0; i < 7; ++i)
+ {
+ weights[make_unordered_pair(i,i+1)][make_pair(0,0)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,i+1)][make_pair(0,1)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,i+1)][make_pair(1,0)] = rnd.get_random_gaussian();
+ weights[make_unordered_pair(i,i+1)][make_pair(1,1)] = rnd.get_random_gaussian();
+ }
+ }
+
+ struct node_iterator
+ {
+ node_iterator() {}
+ node_iterator(unsigned long nid_): nid(nid_) {}
+ bool operator== (const node_iterator& item) const { return item.nid == nid; }
+ bool operator!= (const node_iterator& item) const { return item.nid != nid; }
+
+ node_iterator& operator++()
+ {
+ ++nid;
+ return *this;
+ }
+
+ unsigned long nid;
+ };
+
+ struct neighbor_iterator
+ {
+ neighbor_iterator() : count(0) {}
+
+ bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
+ bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
+ neighbor_iterator& operator++()
+ {
+ ++count;
+ return *this;
+ }
+
+ unsigned long node_id () const
+ {
+ if (count >= 2)
+ return 8;
+ return nid[count];
+ }
+
+ unsigned long nid[2];
+ unsigned int count;
+ };
+
+ unsigned long number_of_nodes (
+ ) const
+ {
+ return 8;
+ }
+
+ node_iterator begin(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 0;
+ return temp;
+ }
+
+ node_iterator end(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 8;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const node_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ if (it.nid == 0)
+ {
+ temp.nid[0] = it.nid+1;
+ temp.nid[1] = 8;
+ }
+ else if (it.nid == 7)
+ {
+ temp.nid[0] = it.nid-1;
+ temp.nid[1] = 8;
+ }
+ else
+ {
+ temp.nid[0] = it.nid-1;
+ temp.nid[1] = it.nid+1;
+ }
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const neighbor_iterator& it
+ ) const
+ {
+ const unsigned long nid = it.node_id();
+ neighbor_iterator temp;
+ if (nid == 0)
+ {
+ temp.nid[0] = nid+1;
+ temp.nid[1] = 8;
+ }
+ else if (nid == 7)
+ {
+ temp.nid[0] = nid-1;
+ temp.nid[1] = 8;
+ }
+ else
+ {
+ temp.nid[0] = nid-1;
+ temp.nid[1] = nid+1;
+ }
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const node_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.nid[0] = 8;
+ temp.nid[1] = 8;
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const neighbor_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.nid[0] = 8;
+ temp.nid[1] = 8;
+ return temp;
+ }
+
+
+ unsigned long node_id (
+ const node_iterator& it
+ ) const
+ {
+ return it.nid;
+ }
+
+ unsigned long node_id (
+ const neighbor_iterator& it
+ ) const
+ {
+ return it.node_id();
+ }
+
+
+ unsigned long num_states (
+ const node_iterator&
+ ) const
+ {
+ return 2;
+ }
+
+ unsigned long num_states (
+ const neighbor_iterator&
+ ) const
+ {
+ return 2;
+ }
+
+ double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
+ double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }
+
+ private:
+
+ double basic_factor_value (
+ unsigned long n1,
+ unsigned long n2,
+ unsigned long s1,
+ unsigned long s2
+ ) const
+ {
+ if (n1 > n2)
+ {
+ swap(n1,n2);
+ swap(s1,s2);
+ }
+ return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)];
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+
+ class map_problem2
+ {
+ /*
+ This is a simple tree structured graph. In particular, it is a star made
+ up of 6 nodes.
+ */
+ public:
+ matrix<double> numbers;
+
+ map_problem2()
+ {
+ numbers = randm(5,3,rnd);
+ }
+
+ struct node_iterator
+ {
+ node_iterator() {}
+ node_iterator(unsigned long nid_): nid(nid_) {}
+ bool operator== (const node_iterator& item) const { return item.nid == nid; }
+ bool operator!= (const node_iterator& item) const { return item.nid != nid; }
+
+ node_iterator& operator++()
+ {
+ ++nid;
+ return *this;
+ }
+
+ unsigned long nid;
+ };
+
+ struct neighbor_iterator
+ {
+ neighbor_iterator() : count(0) {}
+
+ bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); }
+ bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); }
+ neighbor_iterator& operator++()
+ {
+ ++count;
+ return *this;
+ }
+
+ unsigned long node_id () const
+ {
+ if (home_node == 6)
+ return 6;
+
+ if (home_node < 5)
+ {
+ // all the nodes are connected to node 5 and nothing else
+ if (count == 0)
+ return 5;
+ else
+ return 6; // the number returned by the end() functions.
+ }
+ else if (count < 5)
+ {
+ return count;
+ }
+ else
+ {
+ return 6;
+ }
+
+ }
+
+ unsigned long home_node;
+ unsigned long count;
+ };
+
+ unsigned long number_of_nodes (
+ ) const
+ {
+ return 6;
+ }
+
+ node_iterator begin(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 0;
+ return temp;
+ }
+
+ node_iterator end(
+ ) const
+ {
+ node_iterator temp;
+ temp.nid = 6;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const node_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = it.nid;
+ return temp;
+ }
+
+ neighbor_iterator begin(
+ const neighbor_iterator& it
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = it.node_id();
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const node_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = 6;
+ return temp;
+ }
+
+ neighbor_iterator end(
+ const neighbor_iterator&
+ ) const
+ {
+ neighbor_iterator temp;
+ temp.home_node = 6;
+ return temp;
+ }
+
+
+ unsigned long node_id (
+ const node_iterator& it
+ ) const
+ {
+ return it.nid;
+ }
+
+ unsigned long node_id (
+ const neighbor_iterator& it
+ ) const
+ {
+ return it.node_id();
+ }
+
+
+ unsigned long num_states (
+ const node_iterator&
+ ) const
+ {
+ return 3;
+ }
+
+ unsigned long num_states (
+ const neighbor_iterator&
+ ) const
+ {
+ return 3;
+ }
+
+ double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.nid, s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); }
+ double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); }
+ double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const
+ { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); }
+
+ private:
+
+ double basic_factor_value (
+ unsigned long n1,
+ unsigned long n2,
+ unsigned long s1,
+ unsigned long s2
+ ) const
+ {
+ if (n1 > n2)
+ {
+ swap(n1,n2);
+ swap(s1,s2);
+ }
+
+
+ // basically ignore the other node in this factor. The node we
+ // are ignoring is the center node of this star graph. So we basically
+ // let it always have a value of 1.
+ if (s2 == 1)
+ return numbers(n1,s1) + 1;
+ else
+ return numbers(n1,s1);
+ }
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename map_problem>
+ double find_total_score (
+ const map_problem& prob,
+ const std::vector<unsigned long>& map_assignment
+ )
+ {
+ typedef typename map_problem::node_iterator node_iterator;
+ typedef typename map_problem::neighbor_iterator neighbor_iterator;
+
+ double score = 0;
+ for (node_iterator i = prob.begin(); i != prob.end(); ++i)
+ {
+ const unsigned long id_i = prob.node_id(i);
+ for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j)
+ {
+ const unsigned long id_j = prob.node_id(j);
+ score += prob.factor_value(i,j, map_assignment[id_i], map_assignment[id_j]);
+ }
+ }
+
+ return score;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ template <
+ typename map_problem
+ >
+ void brute_force_find_max_factor_graph_nmplp (
+ const map_problem& prob,
+ std::vector<unsigned long>& map_assignment
+ )
+ {
+ std::vector<unsigned long> temp_assignment;
+ temp_assignment.resize(prob.number_of_nodes(),0);
+
+ double best_score = -std::numeric_limits<double>::infinity();
+
+ for (unsigned long i = 0; i < 255; ++i)
+ {
+ temp_assignment[0] = (i&0x01)!=0;
+ temp_assignment[1] = (i&0x02)!=0;
+ temp_assignment[2] = (i&0x04)!=0;
+ temp_assignment[3] = (i&0x08)!=0;
+ temp_assignment[4] = (i&0x10)!=0;
+ temp_assignment[5] = (i&0x20)!=0;
+ temp_assignment[6] = (i&0x40)!=0;
+ temp_assignment[7] = (i&0x80)!=0;
+
+ double score = find_total_score(prob,temp_assignment);
+ if (score > best_score)
+ {
+ best_score = score;
+ map_assignment = temp_assignment;
+ }
+ }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename map_problem>
+ void do_test(
+ )
+ {
+ print_spinner();
+ std::vector<unsigned long> map_assignment1, map_assignment2;
+ map_problem prob;
+ find_max_factor_graph_nmplp(prob, map_assignment1, 1000, 1e-8);
+
+ const double score1 = find_total_score(prob, map_assignment1);
+
+ brute_force_find_max_factor_graph_nmplp(prob, map_assignment2);
+ const double score2 = find_total_score(prob, map_assignment2);
+
+ dlog << LINFO << "score NMPLP: " << score1;
+ dlog << LINFO << "score MAP: " << score2;
+
+ DLIB_TEST(std::abs(score1 - score2) < 1e-10);
+ DLIB_TEST(mat(map_assignment1) == mat(map_assignment2));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ template <typename map_problem>
+ void do_test2(
+ )
+ {
+ print_spinner();
+ std::vector<unsigned long> map_assignment1, map_assignment2;
+ map_problem prob;
+ find_max_factor_graph_nmplp(prob, map_assignment1, 10, 1e-8);
+
+ const double score1 = find_total_score(prob, map_assignment1);
+
+ map_assignment2.resize(6);
+ map_assignment2[0] = index_of_max(rowm(prob.numbers,0));
+ map_assignment2[1] = index_of_max(rowm(prob.numbers,1));
+ map_assignment2[2] = index_of_max(rowm(prob.numbers,2));
+ map_assignment2[3] = index_of_max(rowm(prob.numbers,3));
+ map_assignment2[4] = index_of_max(rowm(prob.numbers,4));
+ map_assignment2[5] = 1;
+ const double score2 = find_total_score(prob, map_assignment2);
+
+ dlog << LINFO << "score NMPLP: " << score1;
+ dlog << LINFO << "score MAP: " << score2;
+ dlog << LINFO << "MAP assignment: "<< trans(mat(map_assignment1));
+
+ DLIB_TEST(std::abs(score1 - score2) < 1e-10);
+ DLIB_TEST(mat(map_assignment1) == mat(map_assignment2));
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class test_find_max_factor_graph_nmplp : public tester
+ {
+ public:
+ test_find_max_factor_graph_nmplp (
+ ) :
+ tester ("test_find_max_factor_graph_nmplp",
+ "Runs tests on the find_max_factor_graph_nmplp routine.")
+ {}
+
+ void perform_test (
+ )
+ {
+ rnd.clear();
+
+ dlog << LINFO << "test on a chain structured graph";
+ for (int i = 0; i < 30; ++i)
+ do_test<map_problem_chain>();
+
+ dlog << LINFO << "test on a 2 cycle graph";
+ for (int i = 0; i < 30; ++i)
+ do_test<map_problem<false> >();
+
+ dlog << LINFO << "test on a fully connected graph";
+ for (int i = 0; i < 5; ++i)
+ do_test<map_problem<true> >();
+
+ dlog << LINFO << "test on a tree structured graph";
+ for (int i = 0; i < 10; ++i)
+ do_test2<map_problem2>();
+ }
+ } a;
+
+}
+
+
+
+