diff options
Diffstat (limited to 'ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp')
-rw-r--r-- | ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp | 787 |
1 files changed, 787 insertions, 0 deletions
diff --git a/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp b/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp new file mode 100644 index 000000000..2260e92a1 --- /dev/null +++ b/ml/dlib/dlib/test/find_max_factor_graph_nmplp.cpp @@ -0,0 +1,787 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#include <sstream> +#include <string> +#include <cstdlib> +#include <ctime> +#include <dlib/optimization.h> +#include <dlib/unordered_pair.h> +#include <dlib/rand.h> + +#include "tester.h" + +namespace +{ + using namespace test; + using namespace dlib; + using namespace std; + + logger dlog("test.find_max_factor_graph_nmplp"); + +// ---------------------------------------------------------------------------------------- + + dlib::rand rnd; + + template <bool fully_connected> + class map_problem + { + /* + This is a simple 8 node problem with two cycles in it unless fully_connected is true + and then it's a fully connected 8 note graph. + */ + + public: + + mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights; + map_problem() + { + for (int i = 0; i < 8; ++i) + { + for (int j = i; j < 8; ++j) + { + weights[make_unordered_pair(i,j)][make_pair(0,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,j)][make_pair(0,1)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,j)][make_pair(1,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,j)][make_pair(1,1)] = rnd.get_random_gaussian(); + } + } + } + + struct node_iterator + { + node_iterator() {} + node_iterator(unsigned long nid_): nid(nid_) {} + bool operator== (const node_iterator& item) const { return item.nid == nid; } + bool operator!= (const node_iterator& item) const { return item.nid != nid; } + + node_iterator& operator++() + { + ++nid; + return *this; + } + + unsigned long nid; + }; + + struct neighbor_iterator + { + neighbor_iterator() : count(0) {} + + bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); } + bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); } + neighbor_iterator& operator++() + { + ++count; + return *this; + } + + unsigned long node_id () const + { + if (fully_connected) + { + if (count < home_node) + return count; + else + return count+1; + } + + if (home_node < 4) + { + if (count == 0) + return (home_node + 4 + 1)%4; + else if (count == 1) + return (home_node + 4 - 1)%4; + else + return 8; // one past the end + } + else + { + if (count == 0) + return (home_node + 4 + 1)%4 + 4; + else if (count == 1) + return (home_node + 4 - 1)%4 + 4; + else + return 8; // one past the end + } + } + + unsigned long home_node; + unsigned long count; + }; + + unsigned long number_of_nodes ( + ) const + { + return 8; + } + + node_iterator begin( + ) const + { + node_iterator temp; + temp.nid = 0; + return temp; + } + + node_iterator end( + ) const + { + node_iterator temp; + temp.nid = 8; + return temp; + } + + neighbor_iterator begin( + const node_iterator& it + ) const + { + neighbor_iterator temp; + temp.home_node = it.nid; + return temp; + } + + neighbor_iterator begin( + const neighbor_iterator& it + ) const + { + neighbor_iterator temp; + temp.home_node = it.node_id(); + return temp; + } + + neighbor_iterator end( + const node_iterator& + ) const + { + neighbor_iterator temp; + temp.home_node = 9; + temp.count = 8; + return temp; + } + + neighbor_iterator end( + const neighbor_iterator& + ) const + { + neighbor_iterator temp; + temp.home_node = 9; + temp.count = 8; + return temp; + } + + + unsigned long node_id ( + const node_iterator& it + ) const + { + return it.nid; + } + + unsigned long node_id ( + const neighbor_iterator& it + ) const + { + return it.node_id(); + } + + + unsigned long num_states ( + const node_iterator& + ) const + { + return 2; + } + + unsigned long num_states ( + const neighbor_iterator& + ) const + { + return 2; + } + + double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.nid, s1, s2); } + double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); } + double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); } + double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); } + + private: + + double basic_factor_value ( + unsigned long n1, + unsigned long n2, + unsigned long s1, + unsigned long s2 + ) const + { + if (n1 > n2) + { + swap(n1,n2); + swap(s1,s2); + } + return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)]; + } + + }; + +// ---------------------------------------------------------------------------------------- + + class map_problem_chain + { + /* + This is a chain structured 8 node graph (so no cycles). + */ + + public: + + mutable std::map<unordered_pair<int>,std::map<std::pair<int,int>,double> > weights; + map_problem_chain() + { + for (int i = 0; i < 7; ++i) + { + weights[make_unordered_pair(i,i+1)][make_pair(0,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,i+1)][make_pair(0,1)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,i+1)][make_pair(1,0)] = rnd.get_random_gaussian(); + weights[make_unordered_pair(i,i+1)][make_pair(1,1)] = rnd.get_random_gaussian(); + } + } + + struct node_iterator + { + node_iterator() {} + node_iterator(unsigned long nid_): nid(nid_) {} + bool operator== (const node_iterator& item) const { return item.nid == nid; } + bool operator!= (const node_iterator& item) const { return item.nid != nid; } + + node_iterator& operator++() + { + ++nid; + return *this; + } + + unsigned long nid; + }; + + struct neighbor_iterator + { + neighbor_iterator() : count(0) {} + + bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); } + bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); } + neighbor_iterator& operator++() + { + ++count; + return *this; + } + + unsigned long node_id () const + { + if (count >= 2) + return 8; + return nid[count]; + } + + unsigned long nid[2]; + unsigned int count; + }; + + unsigned long number_of_nodes ( + ) const + { + return 8; + } + + node_iterator begin( + ) const + { + node_iterator temp; + temp.nid = 0; + return temp; + } + + node_iterator end( + ) const + { + node_iterator temp; + temp.nid = 8; + return temp; + } + + neighbor_iterator begin( + const node_iterator& it + ) const + { + neighbor_iterator temp; + if (it.nid == 0) + { + temp.nid[0] = it.nid+1; + temp.nid[1] = 8; + } + else if (it.nid == 7) + { + temp.nid[0] = it.nid-1; + temp.nid[1] = 8; + } + else + { + temp.nid[0] = it.nid-1; + temp.nid[1] = it.nid+1; + } + return temp; + } + + neighbor_iterator begin( + const neighbor_iterator& it + ) const + { + const unsigned long nid = it.node_id(); + neighbor_iterator temp; + if (nid == 0) + { + temp.nid[0] = nid+1; + temp.nid[1] = 8; + } + else if (nid == 7) + { + temp.nid[0] = nid-1; + temp.nid[1] = 8; + } + else + { + temp.nid[0] = nid-1; + temp.nid[1] = nid+1; + } + return temp; + } + + neighbor_iterator end( + const node_iterator& + ) const + { + neighbor_iterator temp; + temp.nid[0] = 8; + temp.nid[1] = 8; + return temp; + } + + neighbor_iterator end( + const neighbor_iterator& + ) const + { + neighbor_iterator temp; + temp.nid[0] = 8; + temp.nid[1] = 8; + return temp; + } + + + unsigned long node_id ( + const node_iterator& it + ) const + { + return it.nid; + } + + unsigned long node_id ( + const neighbor_iterator& it + ) const + { + return it.node_id(); + } + + + unsigned long num_states ( + const node_iterator& + ) const + { + return 2; + } + + unsigned long num_states ( + const neighbor_iterator& + ) const + { + return 2; + } + + double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.nid, s1, s2); } + double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); } + double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); } + double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); } + + private: + + double basic_factor_value ( + unsigned long n1, + unsigned long n2, + unsigned long s1, + unsigned long s2 + ) const + { + if (n1 > n2) + { + swap(n1,n2); + swap(s1,s2); + } + return weights[make_unordered_pair(n1,n2)][make_pair(s1,s2)]; + } + + }; + +// ---------------------------------------------------------------------------------------- + + + class map_problem2 + { + /* + This is a simple tree structured graph. In particular, it is a star made + up of 6 nodes. + */ + public: + matrix<double> numbers; + + map_problem2() + { + numbers = randm(5,3,rnd); + } + + struct node_iterator + { + node_iterator() {} + node_iterator(unsigned long nid_): nid(nid_) {} + bool operator== (const node_iterator& item) const { return item.nid == nid; } + bool operator!= (const node_iterator& item) const { return item.nid != nid; } + + node_iterator& operator++() + { + ++nid; + return *this; + } + + unsigned long nid; + }; + + struct neighbor_iterator + { + neighbor_iterator() : count(0) {} + + bool operator== (const neighbor_iterator& item) const { return item.node_id() == node_id(); } + bool operator!= (const neighbor_iterator& item) const { return item.node_id() != node_id(); } + neighbor_iterator& operator++() + { + ++count; + return *this; + } + + unsigned long node_id () const + { + if (home_node == 6) + return 6; + + if (home_node < 5) + { + // all the nodes are connected to node 5 and nothing else + if (count == 0) + return 5; + else + return 6; // the number returned by the end() functions. + } + else if (count < 5) + { + return count; + } + else + { + return 6; + } + + } + + unsigned long home_node; + unsigned long count; + }; + + unsigned long number_of_nodes ( + ) const + { + return 6; + } + + node_iterator begin( + ) const + { + node_iterator temp; + temp.nid = 0; + return temp; + } + + node_iterator end( + ) const + { + node_iterator temp; + temp.nid = 6; + return temp; + } + + neighbor_iterator begin( + const node_iterator& it + ) const + { + neighbor_iterator temp; + temp.home_node = it.nid; + return temp; + } + + neighbor_iterator begin( + const neighbor_iterator& it + ) const + { + neighbor_iterator temp; + temp.home_node = it.node_id(); + return temp; + } + + neighbor_iterator end( + const node_iterator& + ) const + { + neighbor_iterator temp; + temp.home_node = 6; + return temp; + } + + neighbor_iterator end( + const neighbor_iterator& + ) const + { + neighbor_iterator temp; + temp.home_node = 6; + return temp; + } + + + unsigned long node_id ( + const node_iterator& it + ) const + { + return it.nid; + } + + unsigned long node_id ( + const neighbor_iterator& it + ) const + { + return it.node_id(); + } + + + unsigned long num_states ( + const node_iterator& + ) const + { + return 3; + } + + unsigned long num_states ( + const neighbor_iterator& + ) const + { + return 3; + } + + double factor_value (const node_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.nid, s1, s2); } + double factor_value (const neighbor_iterator& it1, const node_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.nid, s1, s2); } + double factor_value (const node_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.nid, it2.node_id(), s1, s2); } + double factor_value (const neighbor_iterator& it1, const neighbor_iterator& it2, unsigned long s1, unsigned long s2) const + { return basic_factor_value(it1.node_id(), it2.node_id(), s1, s2); } + + private: + + double basic_factor_value ( + unsigned long n1, + unsigned long n2, + unsigned long s1, + unsigned long s2 + ) const + { + if (n1 > n2) + { + swap(n1,n2); + swap(s1,s2); + } + + + // basically ignore the other node in this factor. The node we + // are ignoring is the center node of this star graph. So we basically + // let it always have a value of 1. + if (s2 == 1) + return numbers(n1,s1) + 1; + else + return numbers(n1,s1); + } + + }; + +// ---------------------------------------------------------------------------------------- + + template <typename map_problem> + double find_total_score ( + const map_problem& prob, + const std::vector<unsigned long>& map_assignment + ) + { + typedef typename map_problem::node_iterator node_iterator; + typedef typename map_problem::neighbor_iterator neighbor_iterator; + + double score = 0; + for (node_iterator i = prob.begin(); i != prob.end(); ++i) + { + const unsigned long id_i = prob.node_id(i); + for (neighbor_iterator j = prob.begin(i); j != prob.end(i); ++j) + { + const unsigned long id_j = prob.node_id(j); + score += prob.factor_value(i,j, map_assignment[id_i], map_assignment[id_j]); + } + } + + return score; + } + +// ---------------------------------------------------------------------------------------- + + + template < + typename map_problem + > + void brute_force_find_max_factor_graph_nmplp ( + const map_problem& prob, + std::vector<unsigned long>& map_assignment + ) + { + std::vector<unsigned long> temp_assignment; + temp_assignment.resize(prob.number_of_nodes(),0); + + double best_score = -std::numeric_limits<double>::infinity(); + + for (unsigned long i = 0; i < 255; ++i) + { + temp_assignment[0] = (i&0x01)!=0; + temp_assignment[1] = (i&0x02)!=0; + temp_assignment[2] = (i&0x04)!=0; + temp_assignment[3] = (i&0x08)!=0; + temp_assignment[4] = (i&0x10)!=0; + temp_assignment[5] = (i&0x20)!=0; + temp_assignment[6] = (i&0x40)!=0; + temp_assignment[7] = (i&0x80)!=0; + + double score = find_total_score(prob,temp_assignment); + if (score > best_score) + { + best_score = score; + map_assignment = temp_assignment; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template <typename map_problem> + void do_test( + ) + { + print_spinner(); + std::vector<unsigned long> map_assignment1, map_assignment2; + map_problem prob; + find_max_factor_graph_nmplp(prob, map_assignment1, 1000, 1e-8); + + const double score1 = find_total_score(prob, map_assignment1); + + brute_force_find_max_factor_graph_nmplp(prob, map_assignment2); + const double score2 = find_total_score(prob, map_assignment2); + + dlog << LINFO << "score NMPLP: " << score1; + dlog << LINFO << "score MAP: " << score2; + + DLIB_TEST(std::abs(score1 - score2) < 1e-10); + DLIB_TEST(mat(map_assignment1) == mat(map_assignment2)); + } + +// ---------------------------------------------------------------------------------------- + + template <typename map_problem> + void do_test2( + ) + { + print_spinner(); + std::vector<unsigned long> map_assignment1, map_assignment2; + map_problem prob; + find_max_factor_graph_nmplp(prob, map_assignment1, 10, 1e-8); + + const double score1 = find_total_score(prob, map_assignment1); + + map_assignment2.resize(6); + map_assignment2[0] = index_of_max(rowm(prob.numbers,0)); + map_assignment2[1] = index_of_max(rowm(prob.numbers,1)); + map_assignment2[2] = index_of_max(rowm(prob.numbers,2)); + map_assignment2[3] = index_of_max(rowm(prob.numbers,3)); + map_assignment2[4] = index_of_max(rowm(prob.numbers,4)); + map_assignment2[5] = 1; + const double score2 = find_total_score(prob, map_assignment2); + + dlog << LINFO << "score NMPLP: " << score1; + dlog << LINFO << "score MAP: " << score2; + dlog << LINFO << "MAP assignment: "<< trans(mat(map_assignment1)); + + DLIB_TEST(std::abs(score1 - score2) < 1e-10); + DLIB_TEST(mat(map_assignment1) == mat(map_assignment2)); + } + +// ---------------------------------------------------------------------------------------- + + class test_find_max_factor_graph_nmplp : public tester + { + public: + test_find_max_factor_graph_nmplp ( + ) : + tester ("test_find_max_factor_graph_nmplp", + "Runs tests on the find_max_factor_graph_nmplp routine.") + {} + + void perform_test ( + ) + { + rnd.clear(); + + dlog << LINFO << "test on a chain structured graph"; + for (int i = 0; i < 30; ++i) + do_test<map_problem_chain>(); + + dlog << LINFO << "test on a 2 cycle graph"; + for (int i = 0; i < 30; ++i) + do_test<map_problem<false> >(); + + dlog << LINFO << "test on a fully connected graph"; + for (int i = 0; i < 5; ++i) + do_test<map_problem<true> >(); + + dlog << LINFO << "test on a tree structured graph"; + for (int i = 0; i < 10; ++i) + do_test2<map_problem2>(); + } + } a; + +} + + + + |