diff options
Diffstat (limited to 'third_party/libwebrtc/rtc_base/nat_unittest.cc')
-rw-r--r-- | third_party/libwebrtc/rtc_base/nat_unittest.cc | 407 |
1 files changed, 407 insertions, 0 deletions
diff --git a/third_party/libwebrtc/rtc_base/nat_unittest.cc b/third_party/libwebrtc/rtc_base/nat_unittest.cc new file mode 100644 index 0000000000..f6dd83cadb --- /dev/null +++ b/third_party/libwebrtc/rtc_base/nat_unittest.cc @@ -0,0 +1,407 @@ +/* + * Copyright 2004 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include <string.h> + +#include <algorithm> +#include <memory> +#include <string> +#include <vector> + +#include "absl/memory/memory.h" +#include "rtc_base/async_packet_socket.h" +#include "rtc_base/async_tcp_socket.h" +#include "rtc_base/async_udp_socket.h" +#include "rtc_base/gunit.h" +#include "rtc_base/ip_address.h" +#include "rtc_base/logging.h" +#include "rtc_base/nat_server.h" +#include "rtc_base/nat_socket_factory.h" +#include "rtc_base/nat_types.h" +#include "rtc_base/net_helpers.h" +#include "rtc_base/network.h" +#include "rtc_base/physical_socket_server.h" +#include "rtc_base/socket.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/socket_factory.h" +#include "rtc_base/socket_server.h" +#include "rtc_base/test_client.h" +#include "rtc_base/third_party/sigslot/sigslot.h" +#include "rtc_base/thread.h" +#include "rtc_base/virtual_socket_server.h" +#include "test/gtest.h" +#include "test/scoped_key_value_config.h" + +namespace rtc { +namespace { + +bool CheckReceive(TestClient* client, + bool should_receive, + const char* buf, + size_t size) { + return (should_receive) ? client->CheckNextPacket(buf, size, 0) + : client->CheckNoPacket(); +} + +TestClient* CreateTestClient(SocketFactory* factory, + const SocketAddress& local_addr) { + return new TestClient( + absl::WrapUnique(AsyncUDPSocket::Create(factory, local_addr))); +} + +TestClient* CreateTCPTestClient(Socket* socket) { + return new TestClient(std::make_unique<AsyncTCPSocket>(socket)); +} + +// Tests that when sending from internal_addr to external_addrs through the +// NAT type specified by nat_type, all external addrs receive the sent packet +// and, if exp_same is true, all use the same mapped-address on the NAT. +void TestSend(SocketServer* internal, + const SocketAddress& internal_addr, + SocketServer* external, + const SocketAddress external_addrs[4], + NATType nat_type, + bool exp_same) { + Thread th_int(internal); + Thread th_ext(external); + + SocketAddress server_addr = internal_addr; + server_addr.SetPort(0); // Auto-select a port + NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr, + external, external_addrs[0]); + NATSocketFactory* natsf = new NATSocketFactory( + internal, nat->internal_udp_address(), nat->internal_tcp_address()); + + TestClient* in = CreateTestClient(natsf, internal_addr); + TestClient* out[4]; + for (int i = 0; i < 4; i++) + out[i] = CreateTestClient(external, external_addrs[i]); + + th_int.Start(); + th_ext.Start(); + + const char* buf = "filter_test"; + size_t len = strlen(buf); + + in->SendTo(buf, len, out[0]->address()); + SocketAddress trans_addr; + EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr)); + + for (int i = 1; i < 4; i++) { + in->SendTo(buf, len, out[i]->address()); + SocketAddress trans_addr2; + EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2)); + bool are_same = (trans_addr == trans_addr2); + ASSERT_EQ(are_same, exp_same) << "same translated address"; + ASSERT_NE(AF_UNSPEC, trans_addr.family()); + ASSERT_NE(AF_UNSPEC, trans_addr2.family()); + } + + th_int.Stop(); + th_ext.Stop(); + + delete nat; + delete natsf; + delete in; + for (int i = 0; i < 4; i++) + delete out[i]; +} + +// Tests that when sending from external_addrs to internal_addr, the packet +// is delivered according to the specified filter_ip and filter_port rules. +void TestRecv(SocketServer* internal, + const SocketAddress& internal_addr, + SocketServer* external, + const SocketAddress external_addrs[4], + NATType nat_type, + bool filter_ip, + bool filter_port) { + Thread th_int(internal); + Thread th_ext(external); + + SocketAddress server_addr = internal_addr; + server_addr.SetPort(0); // Auto-select a port + NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr, + external, external_addrs[0]); + NATSocketFactory* natsf = new NATSocketFactory( + internal, nat->internal_udp_address(), nat->internal_tcp_address()); + + TestClient* in = CreateTestClient(natsf, internal_addr); + TestClient* out[4]; + for (int i = 0; i < 4; i++) + out[i] = CreateTestClient(external, external_addrs[i]); + + th_int.Start(); + th_ext.Start(); + + const char* buf = "filter_test"; + size_t len = strlen(buf); + + in->SendTo(buf, len, out[0]->address()); + SocketAddress trans_addr; + EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr)); + + out[1]->SendTo(buf, len, trans_addr); + EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len)); + + out[2]->SendTo(buf, len, trans_addr); + EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len)); + + out[3]->SendTo(buf, len, trans_addr); + EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len)); + + th_int.Stop(); + th_ext.Stop(); + + delete nat; + delete natsf; + delete in; + for (int i = 0; i < 4; i++) + delete out[i]; +} + +// Tests that NATServer allocates bindings properly. +void TestBindings(SocketServer* internal, + const SocketAddress& internal_addr, + SocketServer* external, + const SocketAddress external_addrs[4]) { + TestSend(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE, + true); + TestSend(internal, internal_addr, external, external_addrs, + NAT_ADDR_RESTRICTED, true); + TestSend(internal, internal_addr, external, external_addrs, + NAT_PORT_RESTRICTED, true); + TestSend(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC, + false); +} + +// Tests that NATServer filters packets properly. +void TestFilters(SocketServer* internal, + const SocketAddress& internal_addr, + SocketServer* external, + const SocketAddress external_addrs[4]) { + TestRecv(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE, + false, false); + TestRecv(internal, internal_addr, external, external_addrs, + NAT_ADDR_RESTRICTED, true, false); + TestRecv(internal, internal_addr, external, external_addrs, + NAT_PORT_RESTRICTED, true, true); + TestRecv(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC, + true, true); +} + +bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) { + // The physical NAT tests require connectivity to the selected ip from the + // internal address used for the NAT. Things like firewalls can break that, so + // check to see if it's worth even trying with this ip. + std::unique_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer()); + std::unique_ptr<Socket> client(pss->CreateSocket(src.family(), SOCK_DGRAM)); + std::unique_ptr<Socket> server(pss->CreateSocket(src.family(), SOCK_DGRAM)); + if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 || + server->Bind(SocketAddress(dst, 0)) != 0) { + return false; + } + const char* buf = "hello other socket"; + size_t len = strlen(buf); + int sent = client->SendTo(buf, len, server->GetLocalAddress()); + SocketAddress addr; + const size_t kRecvBufSize = 64; + char recvbuf[kRecvBufSize]; + Thread::Current()->SleepMs(100); + int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr, nullptr); + return received == sent && ::memcmp(buf, recvbuf, len) == 0; +} + +void TestPhysicalInternal(const SocketAddress& int_addr) { + webrtc::test::ScopedKeyValueConfig field_trials; + rtc::AutoThread main_thread; + PhysicalSocketServer socket_server; + BasicNetworkManager network_manager(nullptr, &socket_server, &field_trials); + network_manager.StartUpdating(); + // Process pending messages so the network list is updated. + Thread::Current()->ProcessMessages(0); + + std::vector<const Network*> networks = network_manager.GetNetworks(); + networks.erase(std::remove_if(networks.begin(), networks.end(), + [](const rtc::Network* network) { + return rtc::kDefaultNetworkIgnoreMask & + network->type(); + }), + networks.end()); + if (networks.empty()) { + RTC_LOG(LS_WARNING) << "Not enough network adapters for test."; + return; + } + + SocketAddress ext_addr1(int_addr); + SocketAddress ext_addr2; + // Find an available IP with matching family. The test breaks if int_addr + // can't talk to ip, so check for connectivity as well. + for (const Network* const network : networks) { + const IPAddress& ip = network->GetBestIP(); + if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) { + ext_addr2.SetIP(ip); + break; + } + } + if (ext_addr2.IsNil()) { + RTC_LOG(LS_WARNING) << "No available IP of same family as " + << int_addr.ToString(); + return; + } + + RTC_LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr().ToString(); + + SocketAddress ext_addrs[4] = { + SocketAddress(ext_addr1), SocketAddress(ext_addr2), + SocketAddress(ext_addr1), SocketAddress(ext_addr2)}; + + std::unique_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer()); + std::unique_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer()); + + TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs); + TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs); +} + +TEST(NatTest, TestPhysicalIPv4) { + TestPhysicalInternal(SocketAddress("127.0.0.1", 0)); +} + +TEST(NatTest, TestPhysicalIPv6) { + if (HasIPv6Enabled()) { + TestPhysicalInternal(SocketAddress("::1", 0)); + } else { + RTC_LOG(LS_WARNING) << "No IPv6, skipping"; + } +} + +namespace { + +class TestVirtualSocketServer : public VirtualSocketServer { + public: + // Expose this publicly + IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); } +}; + +} // namespace + +void TestVirtualInternal(int family) { + rtc::AutoThread main_thread; + std::unique_ptr<TestVirtualSocketServer> int_vss( + new TestVirtualSocketServer()); + std::unique_ptr<TestVirtualSocketServer> ext_vss( + new TestVirtualSocketServer()); + + SocketAddress int_addr; + SocketAddress ext_addrs[4]; + int_addr.SetIP(int_vss->GetNextIP(family)); + ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family())); + ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family())); + ext_addrs[2].SetIP(ext_addrs[0].ipaddr()); + ext_addrs[3].SetIP(ext_addrs[1].ipaddr()); + + TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs); + TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs); +} + +TEST(NatTest, TestVirtualIPv4) { + TestVirtualInternal(AF_INET); +} + +TEST(NatTest, TestVirtualIPv6) { + if (HasIPv6Enabled()) { + TestVirtualInternal(AF_INET6); + } else { + RTC_LOG(LS_WARNING) << "No IPv6, skipping"; + } +} + +class NatTcpTest : public ::testing::Test, public sigslot::has_slots<> { + public: + NatTcpTest() + : int_addr_("192.168.0.1", 0), + ext_addr_("10.0.0.1", 0), + connected_(false), + int_vss_(new TestVirtualSocketServer()), + ext_vss_(new TestVirtualSocketServer()), + int_thread_(new Thread(int_vss_.get())), + ext_thread_(new Thread(ext_vss_.get())), + nat_(new NATServer(NAT_OPEN_CONE, + int_vss_.get(), + int_addr_, + int_addr_, + ext_vss_.get(), + ext_addr_)), + natsf_(new NATSocketFactory(int_vss_.get(), + nat_->internal_udp_address(), + nat_->internal_tcp_address())) { + int_thread_->Start(); + ext_thread_->Start(); + } + + void OnConnectEvent(Socket* socket) { connected_ = true; } + + void OnAcceptEvent(Socket* socket) { + accepted_.reset(server_->Accept(nullptr)); + } + + void OnCloseEvent(Socket* socket, int error) {} + + void ConnectEvents() { + server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent); + client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent); + } + + SocketAddress int_addr_; + SocketAddress ext_addr_; + bool connected_; + std::unique_ptr<TestVirtualSocketServer> int_vss_; + std::unique_ptr<TestVirtualSocketServer> ext_vss_; + std::unique_ptr<Thread> int_thread_; + std::unique_ptr<Thread> ext_thread_; + std::unique_ptr<NATServer> nat_; + std::unique_ptr<NATSocketFactory> natsf_; + std::unique_ptr<Socket> client_; + std::unique_ptr<Socket> server_; + std::unique_ptr<Socket> accepted_; +}; + +TEST_F(NatTcpTest, DISABLED_TestConnectOut) { + server_.reset(ext_vss_->CreateSocket(AF_INET, SOCK_STREAM)); + server_->Bind(ext_addr_); + server_->Listen(5); + + client_.reset(natsf_->CreateSocket(AF_INET, SOCK_STREAM)); + EXPECT_GE(0, client_->Bind(int_addr_)); + EXPECT_GE(0, client_->Connect(server_->GetLocalAddress())); + + ConnectEvents(); + + EXPECT_TRUE_WAIT(connected_, 1000); + EXPECT_EQ(client_->GetRemoteAddress(), server_->GetLocalAddress()); + EXPECT_EQ(accepted_->GetRemoteAddress().ipaddr(), ext_addr_.ipaddr()); + + std::unique_ptr<rtc::TestClient> in(CreateTCPTestClient(client_.release())); + std::unique_ptr<rtc::TestClient> out( + CreateTCPTestClient(accepted_.release())); + + const char* buf = "test_packet"; + size_t len = strlen(buf); + + in->Send(buf, len); + SocketAddress trans_addr; + EXPECT_TRUE(out->CheckNextPacket(buf, len, &trans_addr)); + + out->Send(buf, len); + EXPECT_TRUE(in->CheckNextPacket(buf, len, &trans_addr)); +} + +} // namespace +} // namespace rtc |