summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/sockets/sockets_extensions.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/sockets/sockets_extensions.cpp')
-rw-r--r--ml/dlib/dlib/sockets/sockets_extensions.cpp341
1 files changed, 341 insertions, 0 deletions
diff --git a/ml/dlib/dlib/sockets/sockets_extensions.cpp b/ml/dlib/dlib/sockets/sockets_extensions.cpp
new file mode 100644
index 000000000..be08c1998
--- /dev/null
+++ b/ml/dlib/dlib/sockets/sockets_extensions.cpp
@@ -0,0 +1,341 @@
+// Copyright (C) 2006 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_SOCKETS_EXTENSIONs_CPP
+#define DLIB_SOCKETS_EXTENSIONs_CPP
+
+#include <string>
+#include <sstream>
+#include "../sockets.h"
+#include "../error.h"
+#include "sockets_extensions.h"
+#include "../timer.h"
+#include "../algs.h"
+#include "../timeout.h"
+#include "../misc_api.h"
+#include "../serialize.h"
+#include "../string.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ network_address::
+ network_address(
+ const std::string& full_address
+ )
+ {
+ std::istringstream sin(full_address);
+ sin >> *this;
+ if (!sin || sin.peek() != EOF)
+ throw invalid_network_address("invalid network address: " + full_address);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void serialize(
+ const network_address& item,
+ std::ostream& out
+ )
+ {
+ serialize(item.host_address, out);
+ serialize(item.port, out);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void deserialize(
+ network_address& item,
+ std::istream& in
+ )
+ {
+ deserialize(item.host_address, in);
+ deserialize(item.port, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::ostream& operator<< (
+ std::ostream& out,
+ const network_address& item
+ )
+ {
+ out << item.host_address << ":" << item.port;
+ return out;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ std::istream& operator>> (
+ std::istream& in,
+ network_address& item
+ )
+ {
+ std::string temp;
+ in >> temp;
+
+ std::string::size_type pos = temp.find_last_of(":");
+ if (pos == std::string::npos)
+ {
+ in.setstate(std::ios::badbit);
+ return in;
+ }
+
+ item.host_address = temp.substr(0, pos);
+ try
+ {
+ item.port = sa = temp.substr(pos+1);
+ } catch (std::exception& )
+ {
+ in.setstate(std::ios::badbit);
+ return in;
+ }
+
+
+ return in;
+ }
+
+// ----------------------------------------------------------------------------------------
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const std::string& host_or_ip,
+ unsigned short port
+ )
+ {
+ std::string ip;
+ connection* con;
+ if (is_ip_address(host_or_ip))
+ {
+ ip = host_or_ip;
+ }
+ else
+ {
+ if( hostname_to_ip(host_or_ip,ip))
+ throw socket_error(ERESOLVE,"unable to resolve '" + host_or_ip + "' in connect()");
+ }
+
+ if(create_connection(con,port,ip))
+ {
+ std::ostringstream sout;
+ sout << "unable to connect to '" << host_or_ip << ":" << port << "'";
+ throw socket_error(sout.str());
+ }
+
+ return con;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ connection* connect (
+ const network_address& addr
+ )
+ {
+ return connect(addr.host_address, addr.port);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ namespace connect_timeout_helpers
+ {
+ mutex connect_mutex;
+ signaler connect_signaler(connect_mutex);
+ timestamper ts;
+ long outstanding_connects = 0;
+
+ struct thread_data
+ {
+ std::string host_or_ip;
+ unsigned short port;
+ connection* con;
+ bool connect_ended;
+ bool error_occurred;
+ };
+
+ void thread(void* param)
+ {
+ thread_data p = *static_cast<thread_data*>(param);
+ try
+ {
+ p.con = connect(p.host_or_ip, p.port);
+ }
+ catch (...)
+ {
+ p.error_occurred = true;
+ }
+
+ auto_mutex M(connect_mutex);
+ // report the results back to the connect() call that spawned this
+ // thread.
+ static_cast<thread_data*>(param)->con = p.con;
+ static_cast<thread_data*>(param)->error_occurred = p.error_occurred;
+ connect_signaler.broadcast();
+
+ // wait for the call to connect() that spawned this thread to terminate
+ // before we delete the thread_data struct.
+ while (static_cast<thread_data*>(param)->connect_ended == false)
+ connect_signaler.wait();
+
+ connect_signaler.broadcast();
+ --outstanding_connects;
+ delete static_cast<thread_data*>(param);
+ }
+ }
+
+ connection* connect (
+ const std::string& host_or_ip,
+ unsigned short port,
+ unsigned long timeout
+ )
+ {
+ using namespace connect_timeout_helpers;
+
+ auto_mutex M(connect_mutex);
+
+ const uint64 end_time = ts.get_timestamp() + timeout*1000;
+
+
+ // wait until there are less than 100 outstanding connections
+ while (outstanding_connects > 100)
+ {
+ uint64 cur_time = ts.get_timestamp();
+ if (end_time > cur_time)
+ {
+ timeout = static_cast<unsigned long>((end_time - cur_time)/1000);
+ }
+ else
+ {
+ throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out");
+ }
+
+ connect_signaler.wait_or_timeout(timeout);
+ }
+
+
+ thread_data* data = new thread_data;
+ data->host_or_ip = host_or_ip.c_str();
+ data->port = port;
+ data->con = 0;
+ data->connect_ended = false;
+ data->error_occurred = false;
+
+
+ if (create_new_thread(thread, data) == false)
+ {
+ delete data;
+ throw socket_error("unable to connect to '" + host_or_ip);
+ }
+
+ ++outstanding_connects;
+
+ // wait until we have a connection object
+ while (data->con == 0)
+ {
+ uint64 cur_time = ts.get_timestamp();
+ if (end_time > cur_time && data->error_occurred == false)
+ {
+ timeout = static_cast<unsigned long>((end_time - cur_time)/1000);
+ }
+ else
+ {
+ // let the thread know that it should terminate
+ data->connect_ended = true;
+ connect_signaler.broadcast();
+ if (data->error_occurred)
+ throw socket_error("unable to connect to '" + host_or_ip);
+ else
+ throw socket_error("unable to connect to '" + host_or_ip + "' because connect timed out");
+ }
+
+ connect_signaler.wait_or_timeout(timeout);
+ }
+
+ // let the thread know that it should terminate
+ data->connect_ended = true;
+ connect_signaler.broadcast();
+ return data->con;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ bool is_ip_address (
+ std::string ip
+ )
+ {
+ for (std::string::size_type i = 0; i < ip.size(); ++i)
+ {
+ if (ip[i] == '.')
+ ip[i] = ' ';
+ }
+ std::istringstream sin(ip);
+
+ bool bad = false;
+ int num;
+ for (int i = 0; i < 4; ++i)
+ {
+ sin >> num;
+ if (!sin || num < 0 || num > 255)
+ {
+ bad = true;
+ break;
+ }
+ }
+
+ if (sin.get() != EOF)
+ bad = true;
+
+ return !bad;
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void close_gracefully (
+ connection* con,
+ unsigned long timeout
+ )
+ {
+ std::unique_ptr<connection> ptr(con);
+ close_gracefully(ptr,timeout);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ void close_gracefully (
+ std::unique_ptr<connection>& con,
+ unsigned long timeout
+ )
+ {
+ if (!con)
+ return;
+
+ if(con->shutdown_outgoing())
+ {
+ // there was an error so just close it now and return
+ con.reset();
+ return;
+ }
+
+ try
+ {
+ dlib::timeout t(*con,&connection::shutdown,timeout);
+
+ char junk[100];
+ // wait for the other end to close their side
+ while (con->read(junk,sizeof(junk)) > 0) ;
+ }
+ catch (...)
+ {
+ con.reset();
+ throw;
+ }
+
+ con.reset();
+ }
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_SOCKETS_EXTENSIONs_CPP
+
+