diff options
Diffstat (limited to 'ml/dlib/dlib/sockets/sockets_extensions.cpp')
-rw-r--r-- | ml/dlib/dlib/sockets/sockets_extensions.cpp | 341 |
1 files changed, 341 insertions, 0 deletions
diff --git a/ml/dlib/dlib/sockets/sockets_extensions.cpp b/ml/dlib/dlib/sockets/sockets_extensions.cpp new file mode 100644 index 000000000..be08c1998 --- /dev/null +++ b/ml/dlib/dlib/sockets/sockets_extensions.cpp @@ -0,0 +1,341 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SOCKETS_EXTENSIONs_CPP +#define DLIB_SOCKETS_EXTENSIONs_CPP + +#include <string> +#include <sstream> +#include "../sockets.h" +#include "../error.h" +#include "sockets_extensions.h" +#include "../timer.h" +#include "../algs.h" +#include "../timeout.h" +#include "../misc_api.h" +#include "../serialize.h" +#include "../string.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + network_address:: + network_address( + const std::string& full_address + ) + { + std::istringstream sin(full_address); + sin >> *this; + if (!sin || sin.peek() != EOF) + throw invalid_network_address("invalid network address: " + full_address); + } + +// ---------------------------------------------------------------------------------------- + + void serialize( + const network_address& item, + std::ostream& out + ) + { + serialize(item.host_address, out); + serialize(item.port, out); + } + +// ---------------------------------------------------------------------------------------- + + void deserialize( + network_address& item, + std::istream& in + ) + { + deserialize(item.host_address, in); + deserialize(item.port, in); + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out, + const network_address& item + ) + { + out << item.host_address << ":" << item.port; + return out; + } + +// ---------------------------------------------------------------------------------------- + + std::istream& operator>> ( + std::istream& in, + network_address& item + ) + { + std::string temp; + in >> temp; + + std::string::size_type pos = temp.find_last_of(":"); + if (pos == std::string::npos) + { + in.setstate(std::ios::badbit); + return in; + } + + item.host_address = temp.substr(0, pos); + try + { + item.port = sa = temp.substr(pos+1); + } catch (std::exception& ) + { + in.setstate(std::ios::badbit); + return in; + } + + + return in; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const std::string& host_or_ip, + unsigned short port + ) + { + std::string ip; + connection* con; + if (is_ip_address(host_or_ip)) + { + ip = host_or_ip; + } + else + { + if( hostname_to_ip(host_or_ip,ip)) + throw socket_error(ERESOLVE,"unable to resolve '" + host_or_ip + "' in connect()"); + } + + if(create_connection(con,port,ip)) + { + std::ostringstream sout; + sout << "unable to connect to '" << host_or_ip << ":" << port << "'"; + throw socket_error(sout.str()); + } + + return con; + } + +// ---------------------------------------------------------------------------------------- + + connection* connect ( + const network_address& addr + ) + { + return connect(addr.host_address, addr.port); + } + +// ---------------------------------------------------------------------------------------- + + namespace connect_timeout_helpers + { + mutex connect_mutex; + signaler connect_signaler(connect_mutex); + timestamper ts; + long outstanding_connects = 0; + + struct thread_data + { + std::string host_or_ip; + unsigned short port; + connection* con; + bool connect_ended; + bool error_occurred; + }; + + void thread(void* param) + { + thread_data p = *static_cast<thread_data*>(param); + try + { + p.con = connect(p.host_or_ip, p.port); + } + catch (...) + { + p.error_occurred = true; + } + + auto_mutex M(connect_mutex); + // report the results back to the connect() call that spawned this + // thread. + static_cast<thread_data*>(param)->con = p.con; + static_cast<thread_data*>(param)->error_occurred = p.error_occurred; + connect_signaler.broadcast(); + + // wait for the call to connect() that spawned this thread to terminate + // before we delete the thread_data struct. + while (static_cast<thread_data*>(param)->connect_ended == false) + connect_signaler.wait(); + + connect_signaler.broadcast(); + --outstanding_connects; + delete static_cast<thread_data*>(param); + } + } + + connection* connect ( + const std::string& host_or_ip, + unsigned short port, + unsigned long timeout + ) + { + using namespace connect_timeout_helpers; + + auto_mutex M(connect_mutex); + + const uint64 end_time = ts.get_timestamp() + timeout*1000; + + + // wait until there are less than 100 outstanding connections + while (outstanding_connects > 100) + { + uint64 cur_time = ts.get_timestamp(); + if (end_time > cur_time) + { + timeout = static_cast<unsigned long>((end_time - cur_time)/1000); + } + else + { + throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out"); + } + + connect_signaler.wait_or_timeout(timeout); + } + + + thread_data* data = new thread_data; + data->host_or_ip = host_or_ip.c_str(); + data->port = port; + data->con = 0; + data->connect_ended = false; + data->error_occurred = false; + + + if (create_new_thread(thread, data) == false) + { + delete data; + throw socket_error("unable to connect to '" + host_or_ip); + } + + ++outstanding_connects; + + // wait until we have a connection object + while (data->con == 0) + { + uint64 cur_time = ts.get_timestamp(); + if (end_time > cur_time && data->error_occurred == false) + { + timeout = static_cast<unsigned long>((end_time - cur_time)/1000); + } + else + { + // let the thread know that it should terminate + data->connect_ended = true; + connect_signaler.broadcast(); + if (data->error_occurred) + throw socket_error("unable to connect to '" + host_or_ip); + else + throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out"); + } + + connect_signaler.wait_or_timeout(timeout); + } + + // let the thread know that it should terminate + data->connect_ended = true; + connect_signaler.broadcast(); + return data->con; + } + +// ---------------------------------------------------------------------------------------- + + bool is_ip_address ( + std::string ip + ) + { + for (std::string::size_type i = 0; i < ip.size(); ++i) + { + if (ip[i] == '.') + ip[i] = ' '; + } + std::istringstream sin(ip); + + bool bad = false; + int num; + for (int i = 0; i < 4; ++i) + { + sin >> num; + if (!sin || num < 0 || num > 255) + { + bad = true; + break; + } + } + + if (sin.get() != EOF) + bad = true; + + return !bad; + } + +// ---------------------------------------------------------------------------------------- + + void close_gracefully ( + connection* con, + unsigned long timeout + ) + { + std::unique_ptr<connection> ptr(con); + close_gracefully(ptr,timeout); + } + +// ---------------------------------------------------------------------------------------- + + void close_gracefully ( + std::unique_ptr<connection>& con, + unsigned long timeout + ) + { + if (!con) + return; + + if(con->shutdown_outgoing()) + { + // there was an error so just close it now and return + con.reset(); + return; + } + + try + { + dlib::timeout t(*con,&connection::shutdown,timeout); + + char junk[100]; + // wait for the other end to close their side + while (con->read(junk,sizeof(junk)) > 0) ; + } + catch (...) + { + con.reset(); + throw; + } + + con.reset(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_SOCKETS_EXTENSIONs_CPP + + |