diff options
Diffstat (limited to 'third_party/libwebrtc/p2p/stunprober/stun_prober.cc')
-rw-r--r-- | third_party/libwebrtc/p2p/stunprober/stun_prober.cc | 610 |
1 files changed, 610 insertions, 0 deletions
diff --git a/third_party/libwebrtc/p2p/stunprober/stun_prober.cc b/third_party/libwebrtc/p2p/stunprober/stun_prober.cc new file mode 100644 index 0000000000..977ead4d72 --- /dev/null +++ b/third_party/libwebrtc/p2p/stunprober/stun_prober.cc @@ -0,0 +1,610 @@ +/* + * Copyright 2015 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "p2p/stunprober/stun_prober.h" + +#include <map> +#include <memory> +#include <set> +#include <string> +#include <utility> + +#include "api/packet_socket_factory.h" +#include "api/task_queue/pending_task_safety_flag.h" +#include "api/transport/stun.h" +#include "api/units/time_delta.h" +#include "rtc_base/async_packet_socket.h" +#include "rtc_base/async_resolver_interface.h" +#include "rtc_base/checks.h" +#include "rtc_base/helpers.h" +#include "rtc_base/logging.h" +#include "rtc_base/thread.h" +#include "rtc_base/time_utils.h" + +namespace stunprober { + +namespace { +using ::webrtc::SafeTask; +using ::webrtc::TimeDelta; + +const int THREAD_WAKE_UP_INTERVAL_MS = 5; + +template <typename T> +void IncrementCounterByAddress(std::map<T, int>* counter_per_ip, const T& ip) { + counter_per_ip->insert(std::make_pair(ip, 0)).first->second++; +} + +} // namespace + +// A requester tracks the requests and responses from a single socket to many +// STUN servers +class StunProber::Requester : public sigslot::has_slots<> { + public: + // Each Request maps to a request and response. + struct Request { + // Actual time the STUN bind request was sent. + int64_t sent_time_ms = 0; + // Time the response was received. + int64_t received_time_ms = 0; + + // Server reflexive address from STUN response for this given request. + rtc::SocketAddress srflx_addr; + + rtc::IPAddress server_addr; + + int64_t rtt() { return received_time_ms - sent_time_ms; } + void ProcessResponse(const char* buf, size_t buf_len); + }; + + // StunProber provides `server_ips` for Requester to probe. For shared + // socket mode, it'll be all the resolved IP addresses. For non-shared mode, + // it'll just be a single address. + Requester(StunProber* prober, + rtc::AsyncPacketSocket* socket, + const std::vector<rtc::SocketAddress>& server_ips); + ~Requester() override; + + Requester(const Requester&) = delete; + Requester& operator=(const Requester&) = delete; + + // There is no callback for SendStunRequest as the underneath socket send is + // expected to be completed immediately. Otherwise, it'll skip this request + // and move to the next one. + void SendStunRequest(); + + void OnStunResponseReceived(rtc::AsyncPacketSocket* socket, + const char* buf, + size_t size, + const rtc::SocketAddress& addr, + const int64_t& packet_time_us); + + const std::vector<Request*>& requests() { return requests_; } + + // Whether this Requester has completed all requests. + bool Done() { + return static_cast<size_t>(num_request_sent_) == server_ips_.size(); + } + + private: + Request* GetRequestByAddress(const rtc::IPAddress& ip); + + StunProber* prober_; + + // The socket for this session. + std::unique_ptr<rtc::AsyncPacketSocket> socket_; + + // Temporary SocketAddress and buffer for RecvFrom. + rtc::SocketAddress addr_; + std::unique_ptr<rtc::ByteBufferWriter> response_packet_; + + std::vector<Request*> requests_; + std::vector<rtc::SocketAddress> server_ips_; + int16_t num_request_sent_ = 0; + int16_t num_response_received_ = 0; + + webrtc::SequenceChecker& thread_checker_; +}; + +StunProber::Requester::Requester( + StunProber* prober, + rtc::AsyncPacketSocket* socket, + const std::vector<rtc::SocketAddress>& server_ips) + : prober_(prober), + socket_(socket), + response_packet_(new rtc::ByteBufferWriter(nullptr, kMaxUdpBufferSize)), + server_ips_(server_ips), + thread_checker_(prober->thread_checker_) { + socket_->SignalReadPacket.connect( + this, &StunProber::Requester::OnStunResponseReceived); +} + +StunProber::Requester::~Requester() { + if (socket_) { + socket_->Close(); + } + for (auto* req : requests_) { + if (req) { + delete req; + } + } +} + +void StunProber::Requester::SendStunRequest() { + RTC_DCHECK(thread_checker_.IsCurrent()); + requests_.push_back(new Request()); + Request& request = *(requests_.back()); + // Random transaction ID, STUN_BINDING_REQUEST + cricket::StunMessage message(cricket::STUN_BINDING_REQUEST); + + std::unique_ptr<rtc::ByteBufferWriter> request_packet( + new rtc::ByteBufferWriter(nullptr, kMaxUdpBufferSize)); + if (!message.Write(request_packet.get())) { + prober_->ReportOnFinished(WRITE_FAILED); + return; + } + + auto addr = server_ips_[num_request_sent_]; + request.server_addr = addr.ipaddr(); + + // The write must succeed immediately. Otherwise, the calculating of the STUN + // request timing could become too complicated. Callback is ignored by passing + // empty AsyncCallback. + rtc::PacketOptions options; + int rv = socket_->SendTo(const_cast<char*>(request_packet->Data()), + request_packet->Length(), addr, options); + if (rv < 0) { + prober_->ReportOnFinished(WRITE_FAILED); + return; + } + + request.sent_time_ms = rtc::TimeMillis(); + + num_request_sent_++; + RTC_DCHECK(static_cast<size_t>(num_request_sent_) <= server_ips_.size()); +} + +void StunProber::Requester::Request::ProcessResponse(const char* buf, + size_t buf_len) { + int64_t now = rtc::TimeMillis(); + rtc::ByteBufferReader message(buf, buf_len); + cricket::StunMessage stun_response; + if (!stun_response.Read(&message)) { + // Invalid or incomplete STUN packet. + received_time_ms = 0; + return; + } + + // Get external address of the socket. + const cricket::StunAddressAttribute* addr_attr = + stun_response.GetAddress(cricket::STUN_ATTR_MAPPED_ADDRESS); + if (addr_attr == nullptr) { + // Addresses not available to detect whether or not behind a NAT. + return; + } + + if (addr_attr->family() != cricket::STUN_ADDRESS_IPV4 && + addr_attr->family() != cricket::STUN_ADDRESS_IPV6) { + return; + } + + received_time_ms = now; + + srflx_addr = addr_attr->GetAddress(); +} + +void StunProber::Requester::OnStunResponseReceived( + rtc::AsyncPacketSocket* socket, + const char* buf, + size_t size, + const rtc::SocketAddress& addr, + const int64_t& /* packet_time_us */) { + RTC_DCHECK(thread_checker_.IsCurrent()); + RTC_DCHECK(socket_); + Request* request = GetRequestByAddress(addr.ipaddr()); + if (!request) { + // Something is wrong, finish the test. + prober_->ReportOnFinished(GENERIC_FAILURE); + return; + } + + num_response_received_++; + request->ProcessResponse(buf, size); +} + +StunProber::Requester::Request* StunProber::Requester::GetRequestByAddress( + const rtc::IPAddress& ipaddr) { + RTC_DCHECK(thread_checker_.IsCurrent()); + for (auto* request : requests_) { + if (request->server_addr == ipaddr) { + return request; + } + } + + return nullptr; +} + +StunProber::Stats::Stats() = default; + +StunProber::Stats::~Stats() = default; + +StunProber::ObserverAdapter::ObserverAdapter() = default; + +StunProber::ObserverAdapter::~ObserverAdapter() = default; + +void StunProber::ObserverAdapter::OnPrepared(StunProber* stunprober, + Status status) { + if (status == SUCCESS) { + stunprober->Start(this); + } else { + callback_(stunprober, status); + } +} + +void StunProber::ObserverAdapter::OnFinished(StunProber* stunprober, + Status status) { + callback_(stunprober, status); +} + +StunProber::StunProber(rtc::PacketSocketFactory* socket_factory, + rtc::Thread* thread, + std::vector<const rtc::Network*> networks) + : interval_ms_(0), + socket_factory_(socket_factory), + thread_(thread), + networks_(std::move(networks)) {} + +StunProber::~StunProber() { + RTC_DCHECK(thread_checker_.IsCurrent()); + for (auto* req : requesters_) { + if (req) { + delete req; + } + } + for (auto* s : sockets_) { + if (s) { + delete s; + } + } +} + +bool StunProber::Start(const std::vector<rtc::SocketAddress>& servers, + bool shared_socket_mode, + int interval_ms, + int num_request_per_ip, + int timeout_ms, + const AsyncCallback callback) { + observer_adapter_.set_callback(callback); + return Prepare(servers, shared_socket_mode, interval_ms, num_request_per_ip, + timeout_ms, &observer_adapter_); +} + +bool StunProber::Prepare(const std::vector<rtc::SocketAddress>& servers, + bool shared_socket_mode, + int interval_ms, + int num_request_per_ip, + int timeout_ms, + StunProber::Observer* observer) { + RTC_DCHECK(thread_checker_.IsCurrent()); + interval_ms_ = interval_ms; + shared_socket_mode_ = shared_socket_mode; + + requests_per_ip_ = num_request_per_ip; + if (requests_per_ip_ == 0 || servers.size() == 0) { + return false; + } + + timeout_ms_ = timeout_ms; + servers_ = servers; + observer_ = observer; + // Remove addresses that are already resolved. + for (auto it = servers_.begin(); it != servers_.end();) { + if (it->ipaddr().family() != AF_UNSPEC) { + all_servers_addrs_.push_back(*it); + it = servers_.erase(it); + } else { + ++it; + } + } + if (servers_.empty()) { + CreateSockets(); + return true; + } + return ResolveServerName(servers_.back()); +} + +bool StunProber::Start(StunProber::Observer* observer) { + observer_ = observer; + if (total_ready_sockets_ != total_socket_required()) { + return false; + } + MaybeScheduleStunRequests(); + return true; +} + +bool StunProber::ResolveServerName(const rtc::SocketAddress& addr) { + rtc::AsyncResolverInterface* resolver = + socket_factory_->CreateAsyncResolver(); + if (!resolver) { + return false; + } + resolver->SignalDone.connect(this, &StunProber::OnServerResolved); + resolver->Start(addr); + return true; +} + +void StunProber::OnSocketReady(rtc::AsyncPacketSocket* socket, + const rtc::SocketAddress& addr) { + total_ready_sockets_++; + if (total_ready_sockets_ == total_socket_required()) { + ReportOnPrepared(SUCCESS); + } +} + +void StunProber::OnServerResolved(rtc::AsyncResolverInterface* resolver) { + RTC_DCHECK(thread_checker_.IsCurrent()); + + if (resolver->GetError() == 0) { + rtc::SocketAddress addr(resolver->address().ipaddr(), + resolver->address().port()); + all_servers_addrs_.push_back(addr); + } + + // Deletion of AsyncResolverInterface can't be done in OnResolveResult which + // handles SignalDone. + thread_->PostTask([resolver] { resolver->Destroy(false); }); + servers_.pop_back(); + + if (servers_.size()) { + if (!ResolveServerName(servers_.back())) { + ReportOnPrepared(RESOLVE_FAILED); + } + return; + } + + if (all_servers_addrs_.size() == 0) { + ReportOnPrepared(RESOLVE_FAILED); + return; + } + + CreateSockets(); +} + +void StunProber::CreateSockets() { + // Dedupe. + std::set<rtc::SocketAddress> addrs(all_servers_addrs_.begin(), + all_servers_addrs_.end()); + all_servers_addrs_.assign(addrs.begin(), addrs.end()); + + // Prepare all the sockets beforehand. All of them will bind to "any" address. + while (sockets_.size() < total_socket_required()) { + std::unique_ptr<rtc::AsyncPacketSocket> socket( + socket_factory_->CreateUdpSocket(rtc::SocketAddress(INADDR_ANY, 0), 0, + 0)); + if (!socket) { + ReportOnPrepared(GENERIC_FAILURE); + return; + } + // Chrome and WebRTC behave differently in terms of the state of a socket + // once returned from PacketSocketFactory::CreateUdpSocket. + if (socket->GetState() == rtc::AsyncPacketSocket::STATE_BINDING) { + socket->SignalAddressReady.connect(this, &StunProber::OnSocketReady); + } else { + OnSocketReady(socket.get(), rtc::SocketAddress(INADDR_ANY, 0)); + } + sockets_.push_back(socket.release()); + } +} + +StunProber::Requester* StunProber::CreateRequester() { + RTC_DCHECK(thread_checker_.IsCurrent()); + if (!sockets_.size()) { + return nullptr; + } + StunProber::Requester* requester; + if (shared_socket_mode_) { + requester = new Requester(this, sockets_.back(), all_servers_addrs_); + } else { + std::vector<rtc::SocketAddress> server_ip; + server_ip.push_back( + all_servers_addrs_[(num_request_sent_ % all_servers_addrs_.size())]); + requester = new Requester(this, sockets_.back(), server_ip); + } + + sockets_.pop_back(); + return requester; +} + +bool StunProber::SendNextRequest() { + if (!current_requester_ || current_requester_->Done()) { + current_requester_ = CreateRequester(); + requesters_.push_back(current_requester_); + } + if (!current_requester_) { + return false; + } + current_requester_->SendStunRequest(); + num_request_sent_++; + return true; +} + +bool StunProber::should_send_next_request(int64_t now) { + if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) { + return now >= next_request_time_ms_; + } else { + return (now + (THREAD_WAKE_UP_INTERVAL_MS / 2)) >= next_request_time_ms_; + } +} + +int StunProber::get_wake_up_interval_ms() { + if (interval_ms_ < THREAD_WAKE_UP_INTERVAL_MS) { + return 1; + } else { + return THREAD_WAKE_UP_INTERVAL_MS; + } +} + +void StunProber::MaybeScheduleStunRequests() { + RTC_DCHECK_RUN_ON(thread_); + int64_t now = rtc::TimeMillis(); + + if (Done()) { + thread_->PostDelayedTask( + SafeTask(task_safety_.flag(), [this] { ReportOnFinished(SUCCESS); }), + TimeDelta::Millis(timeout_ms_)); + return; + } + if (should_send_next_request(now)) { + if (!SendNextRequest()) { + ReportOnFinished(GENERIC_FAILURE); + return; + } + next_request_time_ms_ = now + interval_ms_; + } + thread_->PostDelayedTask( + SafeTask(task_safety_.flag(), [this] { MaybeScheduleStunRequests(); }), + TimeDelta::Millis(get_wake_up_interval_ms())); +} + +bool StunProber::GetStats(StunProber::Stats* prob_stats) const { + // No need to be on the same thread. + if (!prob_stats) { + return false; + } + + StunProber::Stats stats; + + int rtt_sum = 0; + int64_t first_sent_time = 0; + int64_t last_sent_time = 0; + NatType nat_type = NATTYPE_INVALID; + + // Track of how many srflx IP that we have seen. + std::set<rtc::IPAddress> srflx_ips; + + // If we're not receiving any response on a given IP, all requests sent to + // that IP should be ignored as this could just be an DNS error. + std::map<rtc::IPAddress, int> num_response_per_server; + std::map<rtc::IPAddress, int> num_request_per_server; + + for (auto* requester : requesters_) { + std::map<rtc::SocketAddress, int> num_response_per_srflx_addr; + for (auto* request : requester->requests()) { + if (request->sent_time_ms <= 0) { + continue; + } + + ++stats.raw_num_request_sent; + IncrementCounterByAddress(&num_request_per_server, request->server_addr); + + if (!first_sent_time) { + first_sent_time = request->sent_time_ms; + } + last_sent_time = request->sent_time_ms; + + if (request->received_time_ms < request->sent_time_ms) { + continue; + } + + IncrementCounterByAddress(&num_response_per_server, request->server_addr); + IncrementCounterByAddress(&num_response_per_srflx_addr, + request->srflx_addr); + rtt_sum += request->rtt(); + stats.srflx_addrs.insert(request->srflx_addr.ToString()); + srflx_ips.insert(request->srflx_addr.ipaddr()); + } + + // If we're using shared mode and seeing >1 srflx addresses for a single + // requester, it's symmetric NAT. + if (shared_socket_mode_ && num_response_per_srflx_addr.size() > 1) { + nat_type = NATTYPE_SYMMETRIC; + } + } + + // We're probably not behind a regular NAT. We have more than 1 distinct + // server reflexive IPs. + if (srflx_ips.size() > 1) { + return false; + } + + int num_sent = 0; + int num_received = 0; + int num_server_ip_with_response = 0; + + for (const auto& kv : num_response_per_server) { + RTC_DCHECK_GT(kv.second, 0); + num_server_ip_with_response++; + num_received += kv.second; + num_sent += num_request_per_server[kv.first]; + } + + // Shared mode is only true if we use the shared socket and there are more + // than 1 responding servers. + stats.shared_socket_mode = + shared_socket_mode_ && (num_server_ip_with_response > 1); + + if (stats.shared_socket_mode && nat_type == NATTYPE_INVALID) { + nat_type = NATTYPE_NON_SYMMETRIC; + } + + // If we could find a local IP matching srflx, we're not behind a NAT. + rtc::SocketAddress srflx_addr; + if (stats.srflx_addrs.size() && + !srflx_addr.FromString(*(stats.srflx_addrs.begin()))) { + return false; + } + for (const auto* net : networks_) { + if (srflx_addr.ipaddr() == net->GetBestIP()) { + nat_type = stunprober::NATTYPE_NONE; + stats.host_ip = net->GetBestIP().ToString(); + break; + } + } + + // Finally, we know we're behind a NAT but can't determine which type it is. + if (nat_type == NATTYPE_INVALID) { + nat_type = NATTYPE_UNKNOWN; + } + + stats.nat_type = nat_type; + stats.num_request_sent = num_sent; + stats.num_response_received = num_received; + stats.target_request_interval_ns = interval_ms_ * 1000; + + if (num_sent) { + stats.success_percent = static_cast<int>(100 * num_received / num_sent); + } + + if (stats.raw_num_request_sent > 1) { + stats.actual_request_interval_ns = + (1000 * (last_sent_time - first_sent_time)) / + (stats.raw_num_request_sent - 1); + } + + if (num_received) { + stats.average_rtt_ms = static_cast<int>((rtt_sum / num_received)); + } + + *prob_stats = stats; + return true; +} + +void StunProber::ReportOnPrepared(StunProber::Status status) { + if (observer_) { + observer_->OnPrepared(this, status); + } +} + +void StunProber::ReportOnFinished(StunProber::Status status) { + if (observer_) { + observer_->OnFinished(this, status); + } +} + +} // namespace stunprober |