summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/rtc_base/nat_socket_factory.cc
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
commit26a029d407be480d791972afb5975cf62c9360a6 (patch)
treef435a8308119effd964b339f76abb83a57c29483 /third_party/libwebrtc/rtc_base/nat_socket_factory.cc
parentInitial commit. (diff)
downloadfirefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz
firefox-26a029d407be480d791972afb5975cf62c9360a6.zip
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/libwebrtc/rtc_base/nat_socket_factory.cc')
-rw-r--r--third_party/libwebrtc/rtc_base/nat_socket_factory.cc515
1 files changed, 515 insertions, 0 deletions
diff --git a/third_party/libwebrtc/rtc_base/nat_socket_factory.cc b/third_party/libwebrtc/rtc_base/nat_socket_factory.cc
new file mode 100644
index 0000000000..fe021b95ff
--- /dev/null
+++ b/third_party/libwebrtc/rtc_base/nat_socket_factory.cc
@@ -0,0 +1,515 @@
+/*
+ * Copyright 2004 The WebRTC Project Authors. All rights reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include "rtc_base/nat_socket_factory.h"
+
+#include "rtc_base/arraysize.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/logging.h"
+#include "rtc_base/nat_server.h"
+#include "rtc_base/virtual_socket_server.h"
+
+namespace rtc {
+
+// Packs the given socketaddress into the buffer in buf, in the quasi-STUN
+// format that the natserver uses.
+// Returns 0 if an invalid address is passed.
+size_t PackAddressForNAT(char* buf,
+ size_t buf_size,
+ const SocketAddress& remote_addr) {
+ const IPAddress& ip = remote_addr.ipaddr();
+ int family = ip.family();
+ buf[0] = 0;
+ buf[1] = family;
+ // Writes the port.
+ *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
+ if (family == AF_INET) {
+ RTC_DCHECK(buf_size >= kNATEncodedIPv4AddressSize);
+ in_addr v4addr = ip.ipv4_address();
+ memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
+ return kNATEncodedIPv4AddressSize;
+ } else if (family == AF_INET6) {
+ RTC_DCHECK(buf_size >= kNATEncodedIPv6AddressSize);
+ in6_addr v6addr = ip.ipv6_address();
+ memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
+ return kNATEncodedIPv6AddressSize;
+ }
+ return 0U;
+}
+
+// Decodes the remote address from a packet that has been encoded with the nat's
+// quasi-STUN format. Returns the length of the address (i.e., the offset into
+// data where the original packet starts).
+size_t UnpackAddressFromNAT(const char* buf,
+ size_t buf_size,
+ SocketAddress* remote_addr) {
+ RTC_DCHECK(buf_size >= 8);
+ RTC_DCHECK(buf[0] == 0);
+ int family = buf[1];
+ uint16_t port =
+ NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
+ if (family == AF_INET) {
+ const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
+ *remote_addr = SocketAddress(IPAddress(*v4addr), port);
+ return kNATEncodedIPv4AddressSize;
+ } else if (family == AF_INET6) {
+ RTC_DCHECK(buf_size >= 20);
+ const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
+ *remote_addr = SocketAddress(IPAddress(*v6addr), port);
+ return kNATEncodedIPv6AddressSize;
+ }
+ return 0U;
+}
+
+// NATSocket
+class NATSocket : public Socket, public sigslot::has_slots<> {
+ public:
+ explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
+ : sf_(sf),
+ family_(family),
+ type_(type),
+ connected_(false),
+ socket_(nullptr),
+ buf_(nullptr),
+ size_(0) {}
+
+ ~NATSocket() override {
+ delete socket_;
+ delete[] buf_;
+ }
+
+ SocketAddress GetLocalAddress() const override {
+ return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
+ }
+
+ SocketAddress GetRemoteAddress() const override {
+ return remote_addr_; // will be NIL if not connected
+ }
+
+ int Bind(const SocketAddress& addr) override {
+ if (socket_) { // already bound, bubble up error
+ return -1;
+ }
+
+ return BindInternal(addr);
+ }
+
+ int Connect(const SocketAddress& addr) override {
+ int result = 0;
+ // If we're not already bound (meaning `socket_` is null), bind to ANY
+ // address.
+ if (!socket_) {
+ result = BindInternal(SocketAddress(GetAnyIP(family_), 0));
+ if (result < 0) {
+ return result;
+ }
+ }
+
+ if (type_ == SOCK_STREAM) {
+ result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
+ } else {
+ connected_ = true;
+ }
+
+ if (result >= 0) {
+ remote_addr_ = addr;
+ }
+
+ return result;
+ }
+
+ int Send(const void* data, size_t size) override {
+ RTC_DCHECK(connected_);
+ return SendTo(data, size, remote_addr_);
+ }
+
+ int SendTo(const void* data,
+ size_t size,
+ const SocketAddress& addr) override {
+ RTC_DCHECK(!connected_ || addr == remote_addr_);
+ if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
+ return socket_->SendTo(data, size, addr);
+ }
+ // This array will be too large for IPv4 packets, but only by 12 bytes.
+ std::unique_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
+ size_t addrlength =
+ PackAddressForNAT(buf.get(), size + kNATEncodedIPv6AddressSize, addr);
+ size_t encoded_size = size + addrlength;
+ memcpy(buf.get() + addrlength, data, size);
+ int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
+ if (result >= 0) {
+ RTC_DCHECK(result == static_cast<int>(encoded_size));
+ result = result - static_cast<int>(addrlength);
+ }
+ return result;
+ }
+
+ int Recv(void* data, size_t size, int64_t* timestamp) override {
+ SocketAddress addr;
+ return RecvFrom(data, size, &addr, timestamp);
+ }
+
+ int RecvFrom(void* data,
+ size_t size,
+ SocketAddress* out_addr,
+ int64_t* timestamp) override {
+ if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
+ return socket_->RecvFrom(data, size, out_addr, timestamp);
+ }
+ // Make sure we have enough room to read the requested amount plus the
+ // largest possible header address.
+ SocketAddress remote_addr;
+ Grow(size + kNATEncodedIPv6AddressSize);
+
+ // Read the packet from the socket.
+ int result = socket_->RecvFrom(buf_, size_, &remote_addr, timestamp);
+ if (result >= 0) {
+ RTC_DCHECK(remote_addr == server_addr_);
+
+ // TODO: we need better framing so we know how many bytes we can
+ // return before we need to read the next address. For UDP, this will be
+ // fine as long as the reader always reads everything in the packet.
+ RTC_DCHECK((size_t)result < size_);
+
+ // Decode the wire packet into the actual results.
+ SocketAddress real_remote_addr;
+ size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr);
+ memcpy(data, buf_ + addrlength, result - addrlength);
+
+ // Make sure this packet should be delivered before returning it.
+ if (!connected_ || (real_remote_addr == remote_addr_)) {
+ if (out_addr)
+ *out_addr = real_remote_addr;
+ result = result - static_cast<int>(addrlength);
+ } else {
+ RTC_LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
+ << real_remote_addr.ToString();
+ result = 0; // Tell the caller we didn't read anything
+ }
+ }
+
+ return result;
+ }
+
+ int Close() override {
+ int result = 0;
+ if (socket_) {
+ result = socket_->Close();
+ if (result >= 0) {
+ connected_ = false;
+ remote_addr_ = SocketAddress();
+ delete socket_;
+ socket_ = nullptr;
+ }
+ }
+ return result;
+ }
+
+ int Listen(int backlog) override { return socket_->Listen(backlog); }
+ Socket* Accept(SocketAddress* paddr) override {
+ return socket_->Accept(paddr);
+ }
+ int GetError() const override {
+ return socket_ ? socket_->GetError() : error_;
+ }
+ void SetError(int error) override {
+ if (socket_) {
+ socket_->SetError(error);
+ } else {
+ error_ = error;
+ }
+ }
+ ConnState GetState() const override {
+ return connected_ ? CS_CONNECTED : CS_CLOSED;
+ }
+ int GetOption(Option opt, int* value) override {
+ return socket_ ? socket_->GetOption(opt, value) : -1;
+ }
+ int SetOption(Option opt, int value) override {
+ return socket_ ? socket_->SetOption(opt, value) : -1;
+ }
+
+ void OnConnectEvent(Socket* socket) {
+ // If we're NATed, we need to send a message with the real addr to use.
+ RTC_DCHECK(socket == socket_);
+ if (server_addr_.IsNil()) {
+ connected_ = true;
+ SignalConnectEvent(this);
+ } else {
+ SendConnectRequest();
+ }
+ }
+ void OnReadEvent(Socket* socket) {
+ // If we're NATed, we need to process the connect reply.
+ RTC_DCHECK(socket == socket_);
+ if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
+ HandleConnectReply();
+ } else {
+ SignalReadEvent(this);
+ }
+ }
+ void OnWriteEvent(Socket* socket) {
+ RTC_DCHECK(socket == socket_);
+ SignalWriteEvent(this);
+ }
+ void OnCloseEvent(Socket* socket, int error) {
+ RTC_DCHECK(socket == socket_);
+ SignalCloseEvent(this, error);
+ }
+
+ private:
+ int BindInternal(const SocketAddress& addr) {
+ RTC_DCHECK(!socket_);
+
+ int result;
+ socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
+ result = (socket_) ? socket_->Bind(addr) : -1;
+ if (result >= 0) {
+ socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
+ socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
+ socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
+ socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
+ } else {
+ server_addr_.Clear();
+ delete socket_;
+ socket_ = nullptr;
+ }
+
+ return result;
+ }
+
+ // Makes sure the buffer is at least the given size.
+ void Grow(size_t new_size) {
+ if (size_ < new_size) {
+ delete[] buf_;
+ size_ = new_size;
+ buf_ = new char[size_];
+ }
+ }
+
+ // Sends the destination address to the server to tell it to connect.
+ void SendConnectRequest() {
+ char buf[kNATEncodedIPv6AddressSize];
+ size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
+ socket_->Send(buf, length);
+ }
+
+ // Handles the byte sent back from the server and fires the appropriate event.
+ void HandleConnectReply() {
+ char code;
+ socket_->Recv(&code, sizeof(code), nullptr);
+ if (code == 0) {
+ connected_ = true;
+ SignalConnectEvent(this);
+ } else {
+ Close();
+ SignalCloseEvent(this, code);
+ }
+ }
+
+ NATInternalSocketFactory* sf_;
+ int family_;
+ int type_;
+ bool connected_;
+ SocketAddress remote_addr_;
+ SocketAddress server_addr_; // address of the NAT server
+ Socket* socket_;
+ // Need to hold error in case it occurs before the socket is created.
+ int error_ = 0;
+ char* buf_;
+ size_t size_;
+};
+
+// NATSocketFactory
+NATSocketFactory::NATSocketFactory(SocketFactory* factory,
+ const SocketAddress& nat_udp_addr,
+ const SocketAddress& nat_tcp_addr)
+ : factory_(factory),
+ nat_udp_addr_(nat_udp_addr),
+ nat_tcp_addr_(nat_tcp_addr) {}
+
+Socket* NATSocketFactory::CreateSocket(int family, int type) {
+ return new NATSocket(this, family, type);
+}
+
+Socket* NATSocketFactory::CreateInternalSocket(int family,
+ int type,
+ const SocketAddress& local_addr,
+ SocketAddress* nat_addr) {
+ if (type == SOCK_STREAM) {
+ *nat_addr = nat_tcp_addr_;
+ } else {
+ *nat_addr = nat_udp_addr_;
+ }
+ return factory_->CreateSocket(family, type);
+}
+
+// NATSocketServer
+NATSocketServer::NATSocketServer(SocketServer* server)
+ : server_(server), msg_queue_(nullptr) {}
+
+NATSocketServer::Translator* NATSocketServer::GetTranslator(
+ const SocketAddress& ext_ip) {
+ return nats_.Get(ext_ip);
+}
+
+NATSocketServer::Translator* NATSocketServer::AddTranslator(
+ const SocketAddress& ext_ip,
+ const SocketAddress& int_ip,
+ NATType type) {
+ // Fail if a translator already exists with this extternal address.
+ if (nats_.Get(ext_ip))
+ return nullptr;
+
+ return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
+}
+
+void NATSocketServer::RemoveTranslator(const SocketAddress& ext_ip) {
+ nats_.Remove(ext_ip);
+}
+
+Socket* NATSocketServer::CreateSocket(int family, int type) {
+ return new NATSocket(this, family, type);
+}
+
+void NATSocketServer::SetMessageQueue(Thread* queue) {
+ msg_queue_ = queue;
+ server_->SetMessageQueue(queue);
+}
+
+bool NATSocketServer::Wait(webrtc::TimeDelta max_wait_duration,
+ bool process_io) {
+ return server_->Wait(max_wait_duration, process_io);
+}
+
+void NATSocketServer::WakeUp() {
+ server_->WakeUp();
+}
+
+Socket* NATSocketServer::CreateInternalSocket(int family,
+ int type,
+ const SocketAddress& local_addr,
+ SocketAddress* nat_addr) {
+ Socket* socket = nullptr;
+ Translator* nat = nats_.FindClient(local_addr);
+ if (nat) {
+ socket = nat->internal_factory()->CreateSocket(family, type);
+ *nat_addr = (type == SOCK_STREAM) ? nat->internal_tcp_address()
+ : nat->internal_udp_address();
+ } else {
+ socket = server_->CreateSocket(family, type);
+ }
+ return socket;
+}
+
+// NATSocketServer::Translator
+NATSocketServer::Translator::Translator(NATSocketServer* server,
+ NATType type,
+ const SocketAddress& int_ip,
+ SocketFactory* ext_factory,
+ const SocketAddress& ext_ip)
+ : server_(server) {
+ // Create a new private network, and a NATServer running on the private
+ // network that bridges to the external network. Also tell the private
+ // network to use the same message queue as us.
+ internal_server_ = std::make_unique<VirtualSocketServer>();
+ internal_server_->SetMessageQueue(server_->queue());
+ nat_server_ = std::make_unique<NATServer>(
+ type, internal_server_.get(), int_ip, int_ip, ext_factory, ext_ip);
+}
+
+NATSocketServer::Translator::~Translator() {
+ internal_server_->SetMessageQueue(nullptr);
+}
+
+NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
+ const SocketAddress& ext_ip) {
+ return nats_.Get(ext_ip);
+}
+
+NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
+ const SocketAddress& ext_ip,
+ const SocketAddress& int_ip,
+ NATType type) {
+ // Fail if a translator already exists with this extternal address.
+ if (nats_.Get(ext_ip))
+ return nullptr;
+
+ AddClient(ext_ip);
+ return nats_.Add(ext_ip,
+ new Translator(server_, type, int_ip, server_, ext_ip));
+}
+void NATSocketServer::Translator::RemoveTranslator(
+ const SocketAddress& ext_ip) {
+ nats_.Remove(ext_ip);
+ RemoveClient(ext_ip);
+}
+
+bool NATSocketServer::Translator::AddClient(const SocketAddress& int_ip) {
+ // Fail if a client already exists with this internal address.
+ if (clients_.find(int_ip) != clients_.end())
+ return false;
+
+ clients_.insert(int_ip);
+ return true;
+}
+
+void NATSocketServer::Translator::RemoveClient(const SocketAddress& int_ip) {
+ std::set<SocketAddress>::iterator it = clients_.find(int_ip);
+ if (it != clients_.end()) {
+ clients_.erase(it);
+ }
+}
+
+NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
+ const SocketAddress& int_ip) {
+ // See if we have the requested IP, or any of our children do.
+ return (clients_.find(int_ip) != clients_.end()) ? this
+ : nats_.FindClient(int_ip);
+}
+
+// NATSocketServer::TranslatorMap
+NATSocketServer::TranslatorMap::~TranslatorMap() {
+ for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
+ delete it->second;
+ }
+}
+
+NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
+ const SocketAddress& ext_ip) {
+ TranslatorMap::iterator it = find(ext_ip);
+ return (it != end()) ? it->second : nullptr;
+}
+
+NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
+ const SocketAddress& ext_ip,
+ Translator* nat) {
+ (*this)[ext_ip] = nat;
+ return nat;
+}
+
+void NATSocketServer::TranslatorMap::Remove(const SocketAddress& ext_ip) {
+ TranslatorMap::iterator it = find(ext_ip);
+ if (it != end()) {
+ delete it->second;
+ erase(it);
+ }
+}
+
+NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
+ const SocketAddress& int_ip) {
+ Translator* nat = nullptr;
+ for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
+ nat = it->second->FindClient(int_ip);
+ }
+ return nat;
+}
+
+} // namespace rtc