diff options
Diffstat (limited to 'ml/dlib/dlib/bsp/bsp.h')
-rw-r--r-- | ml/dlib/dlib/bsp/bsp.h | 1043 |
1 files changed, 1043 insertions, 0 deletions
diff --git a/ml/dlib/dlib/bsp/bsp.h b/ml/dlib/dlib/bsp/bsp.h new file mode 100644 index 000000000..f0732c153 --- /dev/null +++ b/ml/dlib/dlib/bsp/bsp.h @@ -0,0 +1,1043 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BsP_Hh_ +#define DLIB_BsP_Hh_ + +#include "bsp_abstract.h" + +#include <memory> +#include <queue> +#include <vector> + +#include "../sockets.h" +#include "../array.h" +#include "../sockstreambuf.h" +#include "../string.h" +#include "../serialize.h" +#include "../map.h" +#include "../ref.h" +#include "../vectorstream.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl1 + { + inline void null_notify( + unsigned short + ) {} + + struct bsp_con + { + bsp_con( + const network_address& dest + ) : + con(connect(dest)), + buf(con), + stream(&buf), + terminated(false) + { + con->disable_nagle(); + } + + bsp_con( + std::unique_ptr<connection>& conptr + ) : + buf(conptr), + stream(&buf), + terminated(false) + { + // make sure we own the connection + conptr.swap(con); + + con->disable_nagle(); + } + + std::unique_ptr<connection> con; + sockstreambuf buf; + std::iostream stream; + bool terminated; + }; + + typedef dlib::map<unsigned long, std::unique_ptr<bsp_con> >::kernel_1a_c map_id_to_con; + + void connect_all ( + map_id_to_con& cons, + const std::vector<network_address>& hosts, + unsigned long node_id + ); + /*! + ensures + - creates connections to all the given hosts and stores them into cons + !*/ + + void send_out_connection_orders ( + map_id_to_con& cons, + const std::vector<network_address>& hosts + ); + + // ------------------------------------------------------------------------------------ + + struct hostinfo + { + hostinfo() {} + hostinfo ( + const network_address& addr_, + unsigned long node_id_ + ) : + addr(addr_), + node_id(node_id_) + { + } + + network_address addr; + unsigned long node_id; + }; + + inline void serialize ( + const hostinfo& item, + std::ostream& out + ) + { + dlib::serialize(item.addr, out); + dlib::serialize(item.node_id, out); + } + + inline void deserialize ( + hostinfo& item, + std::istream& in + ) + { + dlib::deserialize(item.addr, in); + dlib::deserialize(item.node_id, in); + } + + // ------------------------------------------------------------------------------------ + + void connect_all_hostinfo ( + map_id_to_con& cons, + const std::vector<hostinfo>& hosts, + unsigned long node_id, + std::string& error_string + ); + + // ------------------------------------------------------------------------------------ + + template < + typename port_notify_function_type + > + void listen_and_connect_all( + unsigned long& node_id, + map_id_to_con& cons, + unsigned short port, + port_notify_function_type port_notify_function + ) + { + cons.clear(); + std::unique_ptr<listener> list; + const int status = create_listener(list, port); + if (status == PORTINUSE) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) + + ". The port is already in use"); + } + else if (status != 0) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) ); + } + + port_notify_function(list->get_listening_port()); + + std::unique_ptr<connection> con; + if (list->accept(con)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + std::unique_ptr<bsp_con> temp(new bsp_con(con)); + + unsigned long remote_node_id; + dlib::deserialize(remote_node_id, temp->stream); + dlib::deserialize(node_id, temp->stream); + std::vector<hostinfo> targets; + dlib::deserialize(targets, temp->stream); + unsigned long num_incoming_connections; + dlib::deserialize(num_incoming_connections, temp->stream); + + cons.add(remote_node_id,temp); + + // make a thread that will connect to all the targets + map_id_to_con cons2; + std::string error_string; + thread_function thread(connect_all_hostinfo, dlib::ref(cons2), dlib::ref(targets), node_id, dlib::ref(error_string)); + if (error_string.size() != 0) + throw socket_error(error_string); + + // accept any incoming connections + for (unsigned long i = 0; i < num_incoming_connections; ++i) + { + // If it takes more than 10 seconds for the other nodes to connect to us + // then something has gone horribly wrong and it almost certainly will + // never connect at all. So just give up if that happens. + const unsigned long timeout_milliseconds = 10000; + if (list->accept(con, timeout_milliseconds)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + temp.reset(new bsp_con(con)); + + dlib::deserialize(remote_node_id, temp->stream); + cons.add(remote_node_id,temp); + } + + + // put all the connections created by the thread into cons + thread.wait(); + while (cons2.size() > 0) + { + unsigned long id; + std::unique_ptr<bsp_con> temp; + cons2.remove_any(id,temp); + cons.add(id,temp); + } + } + + // ------------------------------------------------------------------------------------ + + struct msg_data + { + std::shared_ptr<std::vector<char> > data; + unsigned long sender_id; + char msg_type; + dlib::uint64 epoch; + + msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {} + + std::string data_to_string() const + { + if (data && data->size() != 0) + return std::string(&(*data)[0], data->size()); + else + return ""; + } + }; + + // ------------------------------------------------------------------------------------ + + class thread_safe_message_queue : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple message queue for msg_data objects. Note that it + has the special property that, while messages will generally leave + the queue in the order they are inserted, any message with a smaller + epoch value will always be popped out first. But for all messages + with equal epoch values the queue functions as a normal FIFO queue. + !*/ + private: + struct msg_wrap + { + msg_wrap( + const msg_data& data_, + const dlib::uint64& sequence_number_ + ) : data(data_), sequence_number(sequence_number_) {} + + msg_wrap() : sequence_number(0){} + + msg_data data; + dlib::uint64 sequence_number; + + // Make it so that when msg_wrap objects are in a std::priority_queue, + // messages with a smaller epoch number always come first. Then, within an + // epoch, messages are ordered by their sequence number (so smaller first + // there as well). + bool operator<(const msg_wrap& item) const + { + if (data.epoch < item.data.epoch) + { + return false; + } + else if (data.epoch > item.data.epoch) + { + return true; + } + else + { + if (sequence_number < item.sequence_number) + return false; + else + return true; + } + } + }; + + public: + thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {} + + ~thread_safe_message_queue() + { + disable(); + } + + void disable() + { + auto_mutex lock(class_mutex); + disabled = true; + sig.broadcast(); + } + + unsigned long size() const + { + auto_mutex lock(class_mutex); + return data.size(); + } + + void push_and_consume( msg_data& item) + { + auto_mutex lock(class_mutex); + data.push(msg_wrap(item, next_seq_num++)); + // do this here so that we don't have to worry about different threads touching the shared_ptr. + item.data.reset(); + sig.signal(); + } + + bool pop ( + msg_data& item + ) + /*! + ensures + - if (this function returns true) then + - #item == the next thing from the queue + - else + - this object is disabled + !*/ + { + auto_mutex lock(class_mutex); + while (data.size() == 0 && !disabled) + sig.wait(); + + if (disabled) + return false; + + item = data.top().data; + data.pop(); + + return true; + } + + bool pop ( + msg_data& item, + const dlib::uint64& max_epoch + ) + /*! + ensures + - if (this function returns true) then + - #item == the next thing from the queue that has an epoch <= max_epoch + - else + - this object is disabled + !*/ + { + auto_mutex lock(class_mutex); + while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled) + sig.wait(); + + if (disabled) + return false; + + item = data.top().data; + data.pop(); + + return true; + } + + private: + std::priority_queue<msg_wrap> data; + dlib::mutex class_mutex; + dlib::signaler sig; + bool disabled; + dlib::uint64 next_seq_num; + }; + + + } + +// ---------------------------------------------------------------------------------------- + + class bsp_context : noncopyable + { + + public: + + template <typename T> + void send( + const T& item, + unsigned long target_node_id + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(target_node_id < number_of_nodes() && + target_node_id != node_id(), + "\t void bsp_context::send()" + << "\n\t Invalid arguments were given to this function." + << "\n\t target_node_id: " << target_node_id + << "\n\t node_id(): " << node_id() + << "\n\t number_of_nodes(): " << number_of_nodes() + << "\n\t this: " << this + ); + + std::vector<char> buf; + vectorstream sout(buf); + serialize(item, sout); + send_data(buf, target_node_id); + } + + template <typename T> + void broadcast ( + const T& item + ) + { + std::vector<char> buf; + vectorstream sout(buf); + serialize(item, sout); + for (unsigned long i = 0; i < number_of_nodes(); ++i) + { + // Don't send to yourself. + if (i == node_id()) + continue; + + send_data(buf, i); + } + } + + unsigned long node_id ( + ) const { return _node_id; } + + unsigned long number_of_nodes ( + ) const { return _cons.size()+1; } + + void receive ( + ) + { + unsigned long id; + std::shared_ptr<std::vector<char> > temp; + if (receive_data(temp,id)) + throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message."); + } + + template <typename T> + void receive ( + T& item + ) + { + if(!try_receive(item)) + throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); + } + + template <typename T> + bool try_receive ( + T& item + ) + { + unsigned long sending_node_id; + return try_receive(item, sending_node_id); + } + + template <typename T> + void receive ( + T& item, + unsigned long& sending_node_id + ) + { + if(!try_receive(item, sending_node_id)) + throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); + } + + template <typename T> + bool try_receive ( + T& item, + unsigned long& sending_node_id + ) + { + std::shared_ptr<std::vector<char> > temp; + if (receive_data(temp, sending_node_id)) + { + vectorstream sin(*temp); + deserialize(item, sin); + if (sin.peek() != EOF) + throw serialization_error("deserialize() did not consume all bytes produced by serialize(). " + "This probably means you are calling a receive method with a different type " + "of object than the one which was sent."); + return true; + } + else + { + return false; + } + } + + ~bsp_context(); + + private: + + bsp_context(); + + bsp_context( + unsigned long node_id_, + impl1::map_id_to_con& cons_ + ); + + void close_all_connections_gracefully(); + /*! + ensures + - closes all the connections to other nodes and lets them know that + we are terminating normally rather than as the result of some kind + of error. + !*/ + + bool receive_data ( + std::shared_ptr<std::vector<char> >& item, + unsigned long& sending_node_id + ); + + + void notify_control_node ( + char val + ); + + void broadcast_byte ( + char val + ); + + void send_data( + const std::vector<char>& item, + unsigned long target_node_id + ); + /*! + requires + - target_node_id < number_of_nodes() + - target_node_id != node_id() + ensures + - sends a copy of item to the node with the given id. + !*/ + + + + + unsigned long outstanding_messages; + unsigned long num_waiting_nodes; + unsigned long num_terminated_nodes; + dlib::uint64 current_epoch; + + impl1::thread_safe_message_queue msg_buffer; + + impl1::map_id_to_con& _cons; + const unsigned long _node_id; + array<std::unique_ptr<thread_function> > threads; + + // ----------------------------------- + + template < + typename funct_type + > + friend void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct + ); + + template < + typename funct_type, + typename ARG1 + > + friend void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct, + ARG1 arg1 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + friend void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + friend void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + friend void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + + // ----------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + + // ----------------------------------- + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct, + ARG1 arg1 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_connect ( + const std::vector<network_address>& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3,arg4); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3, arg4); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3,arg4); + obj.close_all_connections_gracefully(); + } +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "bsp.cpp" +#endif + +#endif // DLIB_BsP_Hh_ + |