diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /dom/media/webrtc/transport/test | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'dom/media/webrtc/transport/test')
23 files changed, 12113 insertions, 0 deletions
diff --git a/dom/media/webrtc/transport/test/TestSyncRunnable.cpp b/dom/media/webrtc/transport/test/TestSyncRunnable.cpp new file mode 100644 index 0000000000..ca671b4e79 --- /dev/null +++ b/dom/media/webrtc/transport/test/TestSyncRunnable.cpp @@ -0,0 +1,56 @@ +/* -*- Mode: C++; tab-width: 12; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsIThread.h" +#include "nsThreadUtils.h" +#include "mozilla/SyncRunnable.h" + +#include "gtest/gtest.h" + +using namespace mozilla; + +nsIThread* gThread = nullptr; + +class TestRunnable : public Runnable { + public: + TestRunnable() : Runnable("TestRunnable"), ran_(false) {} + + NS_IMETHOD Run() override { + ran_ = true; + + return NS_OK; + } + + bool ran() const { return ran_; } + + private: + bool ran_; +}; + +class TestSyncRunnable : public ::testing::Test { + public: + static void SetUpTestCase() { + nsresult rv = NS_NewNamedThread("thread", &gThread); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + } + + static void TearDownTestCase() { + if (gThread) gThread->Shutdown(); + } +}; + +TEST_F(TestSyncRunnable, TestDispatch) { + RefPtr<TestRunnable> r(new TestRunnable()); + RefPtr<SyncRunnable> s(new SyncRunnable(r)); + s->DispatchToThread(gThread); + + ASSERT_TRUE(r->ran()); +} + +TEST_F(TestSyncRunnable, TestDispatchStatic) { + RefPtr<TestRunnable> r(new TestRunnable()); + SyncRunnable::DispatchToThread(gThread, r); + ASSERT_TRUE(r->ran()); +} diff --git a/dom/media/webrtc/transport/test/buffered_stun_socket_unittest.cpp b/dom/media/webrtc/transport/test/buffered_stun_socket_unittest.cpp new file mode 100644 index 0000000000..e6a9cd38a2 --- /dev/null +++ b/dom/media/webrtc/transport/test/buffered_stun_socket_unittest.cpp @@ -0,0 +1,245 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com + +extern "C" { +#include "nr_api.h" +#include "nr_socket.h" +#include "nr_socket_buffered_stun.h" +#include "transport_addr.h" +} + +#include "stun_msg.h" + +#include "dummysocket.h" + +#include "nr_socket_prsock.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +using namespace mozilla; + +static uint8_t kStunMessage[] = {0x00, 0x01, 0x00, 0x08, 0x21, 0x12, 0xa4, + 0x42, 0x9b, 0x90, 0xbe, 0x2c, 0xae, 0x1a, + 0x0c, 0xa8, 0xa0, 0xd6, 0x8b, 0x08, 0x80, + 0x28, 0x00, 0x04, 0xdb, 0x35, 0x5f, 0xaa}; +static size_t kStunMessageLen = sizeof(kStunMessage); + +class BufferedStunSocketTest : public MtransportTest { + public: + BufferedStunSocketTest() + : MtransportTest(), dummy_(nullptr), test_socket_(nullptr) {} + + ~BufferedStunSocketTest() { nr_socket_destroy(&test_socket_); } + + void SetUp() override { + MtransportTest::SetUp(); + + RefPtr<DummySocket> dummy(new DummySocket()); + + int r = + nr_socket_buffered_stun_create(dummy->get_nr_socket(), kStunMessageLen, + TURN_TCP_FRAMING, &test_socket_); + ASSERT_EQ(0, r); + dummy_ = std::move(dummy); // Now owned by test_socket_. + + r = nr_str_port_to_transport_addr((char*)"192.0.2.133", 3333, IPPROTO_TCP, + &remote_addr_); + ASSERT_EQ(0, r); + + r = nr_socket_connect(test_socket_, &remote_addr_); + ASSERT_EQ(0, r); + } + + nr_socket* socket() { return test_socket_; } + + protected: + RefPtr<DummySocket> dummy_; + nr_socket* test_socket_; + nr_transport_addr remote_addr_; +}; + +TEST_F(BufferedStunSocketTest, TestCreate) {} + +TEST_F(BufferedStunSocketTest, TestSendTo) { + int r = nr_socket_sendto(test_socket_, kStunMessage, kStunMessageLen, 0, + &remote_addr_); + ASSERT_EQ(0, r); + + dummy_->CheckWriteBuffer(kStunMessage, kStunMessageLen); +} + +TEST_F(BufferedStunSocketTest, TestSendToBuffered) { + dummy_->SetWritable(0); + + int r = nr_socket_sendto(test_socket_, kStunMessage, kStunMessageLen, 0, + &remote_addr_); + ASSERT_EQ(0, r); + + dummy_->CheckWriteBuffer(nullptr, 0); + + dummy_->SetWritable(kStunMessageLen); + dummy_->FireWritableCb(); + dummy_->CheckWriteBuffer(kStunMessage, kStunMessageLen); +} + +TEST_F(BufferedStunSocketTest, TestSendFullThenDrain) { + dummy_->SetWritable(0); + + for (;;) { + int r = nr_socket_sendto(test_socket_, kStunMessage, kStunMessageLen, 0, + &remote_addr_); + if (r == R_WOULDBLOCK) break; + + ASSERT_EQ(0, r); + } + + // Nothing was written. + dummy_->CheckWriteBuffer(nullptr, 0); + + // Now flush. + dummy_->SetWritable(kStunMessageLen); + dummy_->FireWritableCb(); + dummy_->ClearWriteBuffer(); + + // Verify we can write something. + int r = nr_socket_sendto(test_socket_, kStunMessage, kStunMessageLen, 0, + &remote_addr_); + ASSERT_EQ(0, r); + + // And that it appears. + dummy_->CheckWriteBuffer(kStunMessage, kStunMessageLen); +} + +TEST_F(BufferedStunSocketTest, TestSendToPartialBuffered) { + dummy_->SetWritable(10); + + int r = nr_socket_sendto(test_socket_, kStunMessage, kStunMessageLen, 0, + &remote_addr_); + ASSERT_EQ(0, r); + + dummy_->CheckWriteBuffer(kStunMessage, 10); + dummy_->ClearWriteBuffer(); + + dummy_->SetWritable(kStunMessageLen); + dummy_->FireWritableCb(); + dummy_->CheckWriteBuffer(kStunMessage + 10, kStunMessageLen - 10); +} + +TEST_F(BufferedStunSocketTest, TestSendToReject) { + dummy_->SetWritable(0); + + int r = nr_socket_sendto(test_socket_, kStunMessage, kStunMessageLen, 0, + &remote_addr_); + ASSERT_EQ(0, r); + + dummy_->CheckWriteBuffer(nullptr, 0); + + r = nr_socket_sendto(test_socket_, kStunMessage, kStunMessageLen, 0, + &remote_addr_); + ASSERT_EQ(R_WOULDBLOCK, r); + + dummy_->CheckWriteBuffer(nullptr, 0); +} + +TEST_F(BufferedStunSocketTest, TestSendToWrongAddr) { + nr_transport_addr addr; + + int r = nr_str_port_to_transport_addr((char*)"192.0.2.134", 3333, IPPROTO_TCP, + &addr); + ASSERT_EQ(0, r); + + r = nr_socket_sendto(test_socket_, kStunMessage, kStunMessageLen, 0, &addr); + ASSERT_EQ(R_BAD_DATA, r); +} + +TEST_F(BufferedStunSocketTest, TestReceiveRecvFrom) { + dummy_->SetReadBuffer(kStunMessage, kStunMessageLen); + + unsigned char tmp[2048]; + size_t len; + nr_transport_addr addr; + + int r = nr_socket_recvfrom(test_socket_, tmp, sizeof(tmp), &len, 0, &addr); + ASSERT_EQ(0, r); + ASSERT_EQ(kStunMessageLen, len); + ASSERT_EQ(0, memcmp(kStunMessage, tmp, kStunMessageLen)); + ASSERT_EQ(0, nr_transport_addr_cmp(&addr, &remote_addr_, + NR_TRANSPORT_ADDR_CMP_MODE_ALL)); +} + +TEST_F(BufferedStunSocketTest, TestReceiveRecvFromPartial) { + dummy_->SetReadBuffer(kStunMessage, 15); + + unsigned char tmp[2048]; + size_t len; + nr_transport_addr addr; + + int r = nr_socket_recvfrom(test_socket_, tmp, sizeof(tmp), &len, 0, &addr); + ASSERT_EQ(R_WOULDBLOCK, r); + + dummy_->SetReadBuffer(kStunMessage + 15, kStunMessageLen - 15); + + r = nr_socket_recvfrom(test_socket_, tmp, sizeof(tmp), &len, 0, &addr); + ASSERT_EQ(0, r); + ASSERT_EQ(kStunMessageLen, len); + ASSERT_EQ(0, memcmp(kStunMessage, tmp, kStunMessageLen)); + ASSERT_EQ(0, nr_transport_addr_cmp(&addr, &remote_addr_, + NR_TRANSPORT_ADDR_CMP_MODE_ALL)); + + r = nr_socket_recvfrom(test_socket_, tmp, sizeof(tmp), &len, 0, &addr); + ASSERT_EQ(R_WOULDBLOCK, r); +} + +TEST_F(BufferedStunSocketTest, TestReceiveRecvFromGarbage) { + uint8_t garbage[50]; + memset(garbage, 0xff, sizeof(garbage)); + + dummy_->SetReadBuffer(garbage, sizeof(garbage)); + + unsigned char tmp[2048]; + size_t len; + nr_transport_addr addr; + int r = nr_socket_recvfrom(test_socket_, tmp, sizeof(tmp), &len, 0, &addr); + ASSERT_EQ(R_BAD_DATA, r); + + r = nr_socket_recvfrom(test_socket_, tmp, sizeof(tmp), &len, 0, &addr); + ASSERT_EQ(R_FAILED, r); +} + +TEST_F(BufferedStunSocketTest, TestReceiveRecvFromTooShort) { + dummy_->SetReadBuffer(kStunMessage, kStunMessageLen); + + unsigned char tmp[2048]; + size_t len; + nr_transport_addr addr; + + int r = nr_socket_recvfrom(test_socket_, tmp, kStunMessageLen - 1, &len, 0, + &addr); + ASSERT_EQ(R_BAD_ARGS, r); +} + +TEST_F(BufferedStunSocketTest, TestReceiveRecvFromReallyLong) { + uint8_t garbage[4096]; + memset(garbage, 0xff, sizeof(garbage)); + memcpy(garbage, kStunMessage, kStunMessageLen); + nr_stun_message_header* hdr = + reinterpret_cast<nr_stun_message_header*>(garbage); + hdr->length = htons(3000); + + dummy_->SetReadBuffer(garbage, sizeof(garbage)); + + unsigned char tmp[4096]; + size_t len; + nr_transport_addr addr; + + int r = nr_socket_recvfrom(test_socket_, tmp, kStunMessageLen - 1, &len, 0, + &addr); + ASSERT_EQ(R_BAD_DATA, r); +} diff --git a/dom/media/webrtc/transport/test/dummysocket.h b/dom/media/webrtc/transport/test/dummysocket.h new file mode 100644 index 0000000000..6e20a1f7e7 --- /dev/null +++ b/dom/media/webrtc/transport/test/dummysocket.h @@ -0,0 +1,217 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original authors: ekr@rtfm.com; ryan@tokbox.com + +#ifndef MTRANSPORT_DUMMY_SOCKET_H_ +#define MTRANSPORT_DUMMY_SOCKET_H_ + +#include "nr_socket_prsock.h" + +extern "C" { +#include "transport_addr.h" +} + +#include "mediapacket.h" +#include "mozilla/UniquePtr.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +namespace mozilla { + +static UniquePtr<MediaPacket> merge(UniquePtr<MediaPacket> a, + UniquePtr<MediaPacket> b) { + if (a && a->len() && b && b->len()) { + UniquePtr<uint8_t[]> data(new uint8_t[a->len() + b->len()]); + memcpy(data.get(), a->data(), a->len()); + memcpy(data.get() + a->len(), b->data(), b->len()); + + UniquePtr<MediaPacket> merged(new MediaPacket); + merged->Take(std::move(data), a->len() + b->len()); + return merged; + } + + if (a && a->len()) { + return a; + } + + if (b && b->len()) { + return b; + } + + return nullptr; +} + +class DummySocket : public NrSocketBase { + public: + DummySocket() + : writable_(UINT_MAX), + write_buffer_(nullptr), + readable_(UINT_MAX), + read_buffer_(nullptr), + cb_(nullptr), + cb_arg_(nullptr), + self_(nullptr) {} + + // the nr_socket APIs + virtual int create(nr_transport_addr* addr) override { return 0; } + + virtual int sendto(const void* msg, size_t len, int flags, + const nr_transport_addr* to) override { + MOZ_CRASH(); + return 0; + } + + virtual int recvfrom(void* buf, size_t maxlen, size_t* len, int flags, + nr_transport_addr* from) override { + MOZ_CRASH(); + return 0; + } + + virtual int getaddr(nr_transport_addr* addrp) override { + MOZ_CRASH(); + return 0; + } + + virtual void close() override {} + + virtual int connect(const nr_transport_addr* addr) override { + nr_transport_addr_copy(&connect_addr_, addr); + return 0; + } + + virtual int listen(int backlog) override { return 0; } + + virtual int accept(nr_transport_addr* addrp, nr_socket** sockp) override { + return 0; + } + + virtual int write(const void* msg, size_t len, size_t* written) override { + size_t to_write = std::min(len, writable_); + + if (to_write) { + UniquePtr<MediaPacket> msgbuf(new MediaPacket); + msgbuf->Copy(static_cast<const uint8_t*>(msg), to_write); + write_buffer_ = merge(std::move(write_buffer_), std::move(msgbuf)); + } + + *written = to_write; + + return 0; + } + + virtual int read(void* buf, size_t maxlen, size_t* len) override { + if (!read_buffer_.get()) { + return R_WOULDBLOCK; + } + + size_t to_read = std::min(read_buffer_->len(), std::min(maxlen, readable_)); + + memcpy(buf, read_buffer_->data(), to_read); + *len = to_read; + + if (to_read < read_buffer_->len()) { + MediaPacket* newPacket = new MediaPacket; + newPacket->Copy(read_buffer_->data() + to_read, + read_buffer_->len() - to_read); + read_buffer_.reset(newPacket); + } else { + read_buffer_.reset(); + } + + return 0; + } + + // Implementations of the async_event APIs. + // These are no-ops because we handle scheduling manually + // for test purposes. + virtual int async_wait(int how, NR_async_cb cb, void* cb_arg, char* function, + int line) override { + EXPECT_EQ(nullptr, cb_); + cb_ = cb; + cb_arg_ = cb_arg; + + return 0; + } + + virtual int cancel(int how) override { + cb_ = nullptr; + cb_arg_ = nullptr; + + return 0; + } + + // Read/Manipulate the current state. + void CheckWriteBuffer(const uint8_t* data, size_t len) { + if (!len) { + EXPECT_EQ(nullptr, write_buffer_.get()); + } else { + EXPECT_NE(nullptr, write_buffer_.get()); + ASSERT_EQ(len, write_buffer_->len()); + ASSERT_EQ(0, memcmp(data, write_buffer_->data(), len)); + } + } + + void ClearWriteBuffer() { write_buffer_.reset(); } + + void SetWritable(size_t val) { writable_ = val; } + + void FireWritableCb() { + NR_async_cb cb = cb_; + void* cb_arg = cb_arg_; + + cb_ = nullptr; + cb_arg_ = nullptr; + + cb(this, NR_ASYNC_WAIT_WRITE, cb_arg); + } + + void SetReadBuffer(const uint8_t* data, size_t len) { + EXPECT_EQ(nullptr, write_buffer_.get()); + read_buffer_.reset(new MediaPacket); + read_buffer_->Copy(data, len); + } + + void ClearReadBuffer() { read_buffer_.reset(); } + + void SetReadable(size_t val) { readable_ = val; } + + nr_socket* get_nr_socket() { + if (!self_) { + int r = nr_socket_create_int(this, vtbl(), &self_); + AddRef(); + if (r) return nullptr; + } + + return self_; + } + + nr_transport_addr* get_connect_addr() { return &connect_addr_; } + + NS_INLINE_DECL_THREADSAFE_REFCOUNTING(DummySocket, override); + + private: + ~DummySocket() = default; + + DISALLOW_COPY_ASSIGN(DummySocket); + + size_t writable_; // Amount we allow someone to write. + UniquePtr<MediaPacket> write_buffer_; + size_t readable_; // Amount we allow someone to read. + UniquePtr<MediaPacket> read_buffer_; + + NR_async_cb cb_; + void* cb_arg_; + nr_socket* self_; + + nr_transport_addr connect_addr_; +}; + +} // namespace mozilla + +#endif diff --git a/dom/media/webrtc/transport/test/gtest_ringbuffer_dumper.h b/dom/media/webrtc/transport/test/gtest_ringbuffer_dumper.h new file mode 100644 index 0000000000..25e85c2155 --- /dev/null +++ b/dom/media/webrtc/transport/test/gtest_ringbuffer_dumper.h @@ -0,0 +1,78 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: bcampen@mozilla.com + +#ifndef gtest_ringbuffer_dumper_h__ +#define gtest_ringbuffer_dumper_h__ + +#include "mozilla/SyncRunnable.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +#include "mtransport_test_utils.h" +#include "runnable_utils.h" +#include "rlogconnector.h" + +using mozilla::RLogConnector; +using mozilla::WrapRunnable; + +namespace test { +class RingbufferDumper : public ::testing::EmptyTestEventListener { + public: + explicit RingbufferDumper(MtransportTestUtils* test_utils) + : test_utils_(test_utils) {} + + void ClearRingBuffer_s() { + RLogConnector::CreateInstance(); + // Set limit to zero to clear the ringbuffer + RLogConnector::GetInstance()->SetLogLimit(0); + RLogConnector::GetInstance()->SetLogLimit(UINT32_MAX); + } + + void DestroyRingBuffer_s() { RLogConnector::DestroyInstance(); } + + void DumpRingBuffer_s() { + std::deque<std::string> logs; + // Get an unlimited number of log lines, with no filter + RLogConnector::GetInstance()->GetAny(0, &logs); + for (auto l = logs.begin(); l != logs.end(); ++l) { + std::cout << *l << std::endl; + } + ClearRingBuffer_s(); + } + + virtual void OnTestStart(const ::testing::TestInfo& testInfo) override { + mozilla::SyncRunnable::DispatchToThread( + test_utils_->sts_target(), + WrapRunnable(this, &RingbufferDumper::ClearRingBuffer_s)); + } + + virtual void OnTestEnd(const ::testing::TestInfo& testInfo) override { + mozilla::SyncRunnable::DispatchToThread( + test_utils_->sts_target(), + WrapRunnable(this, &RingbufferDumper::DestroyRingBuffer_s)); + } + + // Called after a failed assertion or a SUCCEED() invocation. + virtual void OnTestPartResult( + const ::testing::TestPartResult& testResult) override { + if (testResult.failed()) { + // Dump (and empty) the RLogConnector + mozilla::SyncRunnable::DispatchToThread( + test_utils_->sts_target(), + WrapRunnable(this, &RingbufferDumper::DumpRingBuffer_s)); + } + } + + private: + MtransportTestUtils* test_utils_; +}; + +} // namespace test + +#endif // gtest_ringbuffer_dumper_h__ diff --git a/dom/media/webrtc/transport/test/gtest_utils.h b/dom/media/webrtc/transport/test/gtest_utils.h new file mode 100644 index 0000000000..40c2570ea1 --- /dev/null +++ b/dom/media/webrtc/transport/test/gtest_utils.h @@ -0,0 +1,201 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Utilities to wrap gtest, based on libjingle's gunit + +// Some sections of this code are under the following license: + +/* + * libjingle + * Copyright 2004--2008, Google Inc. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * 3. The name of the author may not be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO + * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +// Original author: ekr@rtfm.com +#ifndef gtest_utils__h__ +#define gtest_utils__h__ + +#include <iostream> + +#include "nspr.h" +#include "prinrval.h" +#include "prthread.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +#include "gtest_ringbuffer_dumper.h" +#include "mtransport_test_utils.h" +#include "nss.h" +#include "ssl.h" + +extern "C" { +#include "registry.h" +#include "transport_addr.h" +} + +// Wait up to timeout seconds for expression to be true +#define WAIT(expression, timeout) \ + do { \ + for (PRIntervalTime start = PR_IntervalNow(); \ + !(expression) && !((PR_IntervalNow() - start) > \ + PR_MillisecondsToInterval(timeout));) { \ + PR_Sleep(10); \ + } \ + } while (0) + +// Same as GTEST_WAIT, but stores the result in res. Used when +// you also want the result of expression but wish to avoid +// double evaluation. +#define WAIT_(expression, timeout, res) \ + do { \ + for (PRIntervalTime start = PR_IntervalNow(); \ + !(res = (expression)) && !((PR_IntervalNow() - start) > \ + PR_MillisecondsToInterval(timeout));) { \ + PR_Sleep(10); \ + } \ + } while (0) + +#define ASSERT_TRUE_WAIT(expression, timeout) \ + do { \ + bool res; \ + WAIT_(expression, timeout, res); \ + ASSERT_TRUE(res); \ + } while (0) + +#define EXPECT_TRUE_WAIT(expression, timeout) \ + do { \ + bool res; \ + WAIT_(expression, timeout, res); \ + EXPECT_TRUE(res); \ + } while (0) + +#define ASSERT_EQ_WAIT(expected, actual, timeout) \ + do { \ + WAIT(expected == actual, timeout); \ + ASSERT_EQ(expected, actual); \ + } while (0) + +using test::RingbufferDumper; + +class MtransportTest : public ::testing::Test { + public: + MtransportTest() : test_utils_(nullptr), dumper_(nullptr) {} + + void SetUp() override { + test_utils_ = new MtransportTestUtils(); + NSS_NoDB_Init(nullptr); + NSS_SetDomesticPolicy(); + + NR_reg_init(NR_REG_MODE_LOCAL); + + // Attempt to load env vars used by tests. + GetEnvironment("TURN_SERVER_ADDRESS", turn_server_); + GetEnvironment("TURN_SERVER_USER", turn_user_); + GetEnvironment("TURN_SERVER_PASSWORD", turn_password_); + GetEnvironment("STUN_SERVER_ADDRESS", stun_server_address_); + GetEnvironment("STUN_SERVER_HOSTNAME", stun_server_hostname_); + + std::string disable_non_local; + GetEnvironment("MOZ_DISABLE_NONLOCAL_CONNECTIONS", disable_non_local); + std::string upload_dir; + GetEnvironment("MOZ_UPLOAD_DIR", upload_dir); + + if ((!disable_non_local.empty() && disable_non_local != "0") || + !upload_dir.empty()) { + // We're assuming that MOZ_UPLOAD_DIR is only set on tbpl; + // MOZ_DISABLE_NONLOCAL_CONNECTIONS probably should be set when running + // the cpp unit-tests, but is not presently. + stun_server_address_ = ""; + stun_server_hostname_ = ""; + turn_server_ = ""; + } + + // Some tests are flaky and need to check if they're supposed to run. + webrtc_enabled_ = CheckEnvironmentFlag("MOZ_WEBRTC_TESTS"); + + ::testing::TestEventListeners& listeners = + ::testing::UnitTest::GetInstance()->listeners(); + + dumper_ = new RingbufferDumper(test_utils_); + listeners.Append(dumper_); + } + + void TearDown() override { + ::testing::UnitTest::GetInstance()->listeners().Release(dumper_); + delete dumper_; + delete test_utils_; + } + + void GetEnvironment(const char* aVar, std::string& out) { + char* value = getenv(aVar); + if (value) { + out = value; + } + } + + bool CheckEnvironmentFlag(const char* aVar) { + std::string value; + GetEnvironment(aVar, value); + return value == "1"; + } + + bool WarnIfTurnNotConfigured() const { + bool configured = + !turn_server_.empty() && !turn_user_.empty() && !turn_password_.empty(); + + if (configured) { + nr_transport_addr addr; + if (nr_str_port_to_transport_addr(turn_server_.c_str(), 3478, IPPROTO_UDP, + &addr)) { + printf( + "Invalid TURN_SERVER_ADDRESS \"%s\". Only IP numbers supported.\n", + turn_server_.c_str()); + configured = false; + } + } else { + printf( + "Set TURN_SERVER_ADDRESS, TURN_SERVER_USER, and " + "TURN_SERVER_PASSWORD\n" + "environment variables to run this test\n"); + } + + return !configured; + } + + MtransportTestUtils* test_utils_; + RingbufferDumper* dumper_; + + std::string turn_server_; + std::string turn_user_; + std::string turn_password_; + std::string stun_server_address_; + std::string stun_server_hostname_; + + bool webrtc_enabled_; +}; +#endif diff --git a/dom/media/webrtc/transport/test/ice_unittest.cpp b/dom/media/webrtc/transport/test/ice_unittest.cpp new file mode 100644 index 0000000000..d87fa0b0da --- /dev/null +++ b/dom/media/webrtc/transport/test/ice_unittest.cpp @@ -0,0 +1,4400 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com + +#include <algorithm> +#include <deque> +#include <iostream> +#include <limits> +#include <map> +#include <string> +#include <vector> + +#include "sigslot.h" + +#include "logging.h" +#include "ssl.h" + +#include "mozilla/Preferences.h" +#include "nsThreadUtils.h" +#include "nsXPCOM.h" + +extern "C" { +#include "r_types.h" +#include "async_wait.h" +#include "async_timer.h" +#include "r_data.h" +#include "util.h" +#include "r_time.h" +} + +#include "ice_ctx.h" +#include "ice_peer_ctx.h" +#include "ice_media_stream.h" + +#include "nricectx.h" +#include "nricemediastream.h" +#include "nriceresolverfake.h" +#include "nriceresolver.h" +#include "nrinterfaceprioritizer.h" +#include "gtest_ringbuffer_dumper.h" +#include "rlogconnector.h" +#include "runnable_utils.h" +#include "stunserver.h" +#include "nr_socket_prsock.h" +#include "test_nr_socket.h" +#include "nsISocketFilter.h" +#include "mozilla/net/DNS.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +using namespace mozilla; + +static unsigned int kDefaultTimeout = 7000; + +// TODO(nils@mozilla.com): This should get replaced with some non-external +// solution like discussed in bug 860775. +const std::string kDefaultStunServerHostname((char*)"stun.l.google.com"); +const std::string kBogusStunServerHostname( + (char*)"stun-server-nonexistent.invalid"); +const uint16_t kDefaultStunServerPort = 19305; +const std::string kBogusIceCandidate( + (char*)"candidate:0 2 UDP 2113601790 192.168.178.20 50769 typ"); + +const std::string kUnreachableHostIceCandidate( + (char*)"candidate:0 1 UDP 2113601790 192.168.178.20 50769 typ host"); + +namespace { + +// DNS resolution helper code +static std::string Resolve(const std::string& fqdn, int address_family) { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = address_family; + hints.ai_protocol = IPPROTO_UDP; + struct addrinfo* res; + int err = getaddrinfo(fqdn.c_str(), nullptr, &hints, &res); + if (err) { + std::cerr << "Error in getaddrinfo: " << err << std::endl; + return ""; + } + + char str_addr[64] = {0}; + switch (res->ai_family) { + case AF_INET: + inet_ntop(AF_INET, + &reinterpret_cast<struct sockaddr_in*>(res->ai_addr)->sin_addr, + str_addr, sizeof(str_addr)); + break; + case AF_INET6: + inet_ntop( + AF_INET6, + &reinterpret_cast<struct sockaddr_in6*>(res->ai_addr)->sin6_addr, + str_addr, sizeof(str_addr)); + break; + default: + std::cerr << "Got unexpected address family in DNS lookup: " + << res->ai_family << std::endl; + freeaddrinfo(res); + return ""; + } + + if (!strlen(str_addr)) { + std::cerr << "inet_ntop failed" << std::endl; + } + + freeaddrinfo(res); + return str_addr; +} + +class StunTest : public MtransportTest { + public: + StunTest() : MtransportTest() {} + + void SetUp() override { + MtransportTest::SetUp(); + + stun_server_hostname_ = kDefaultStunServerHostname; + // If only a STUN server FQDN was provided, look up its IP address for the + // address-only tests. + if (stun_server_address_.empty() && !stun_server_hostname_.empty()) { + stun_server_address_ = Resolve(stun_server_hostname_, AF_INET); + ASSERT_TRUE(!stun_server_address_.empty()); + } + + // Make sure NrIceCtx is in a testable state. + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&NrIceCtx::internal_DeinitializeGlobal)); + + // NB: NrIceCtx::internal_DeinitializeGlobal destroys the RLogConnector + // singleton. + RLogConnector::CreateInstance(); + + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&TestStunServer::GetInstance, AF_INET)); + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&TestStunServer::GetInstance, AF_INET6)); + + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&TestStunTcpServer::GetInstance, AF_INET)); + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&TestStunTcpServer::GetInstance, AF_INET6)); + } + + void TearDown() override { + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&NrIceCtx::internal_DeinitializeGlobal)); + + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&TestStunServer::ShutdownInstance)); + + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&TestStunTcpServer::ShutdownInstance)); + + RLogConnector::DestroyInstance(); + + MtransportTest::TearDown(); + } +}; + +enum TrickleMode { TRICKLE_NONE, TRICKLE_SIMULATE, TRICKLE_REAL }; + +enum ConsentStatus { CONSENT_FRESH, CONSENT_STALE, CONSENT_EXPIRED }; + +typedef std::string (*CandidateFilter)(const std::string& candidate); + +std::vector<std::string> split(const std::string& s, char delim) { + std::vector<std::string> elems; + std::stringstream ss(s); + std::string item; + while (std::getline(ss, item, delim)) { + elems.push_back(item); + } + return elems; +} + +static std::string IsSrflxCandidate(const std::string& candidate) { + std::vector<std::string> tokens = split(candidate, ' '); + if ((tokens.at(6) == "typ") && (tokens.at(7) == "srflx")) { + return candidate; + } + return std::string(); +} + +static std::string IsRelayCandidate(const std::string& candidate) { + if (candidate.find("typ relay") != std::string::npos) { + return candidate; + } + return std::string(); +} + +static std::string IsTcpCandidate(const std::string& candidate) { + if (candidate.find("TCP") != std::string::npos) { + return candidate; + } + return std::string(); +} + +static std::string IsTcpSoCandidate(const std::string& candidate) { + if (candidate.find("tcptype so") != std::string::npos) { + return candidate; + } + return std::string(); +} + +static std::string IsLoopbackCandidate(const std::string& candidate) { + if (candidate.find("127.0.0.") != std::string::npos) { + return candidate; + } + return std::string(); +} + +static std::string IsIpv4Candidate(const std::string& candidate) { + std::vector<std::string> tokens = split(candidate, ' '); + if (tokens.at(4).find(':') == std::string::npos) { + return candidate; + } + return std::string(); +} + +static std::string SabotageHostCandidateAndDropReflexive( + const std::string& candidate) { + if (candidate.find("typ srflx") != std::string::npos) { + return std::string(); + } + + if (candidate.find("typ host") != std::string::npos) { + return kUnreachableHostIceCandidate; + } + + return candidate; +} + +bool ContainsSucceededPair(const std::vector<NrIceCandidatePair>& pairs) { + for (const auto& pair : pairs) { + if (pair.state == NrIceCandidatePair::STATE_SUCCEEDED) { + return true; + } + } + return false; +} + +// Note: Does not correspond to any notion of prioritization; this is just +// so we can use stl containers/algorithms that need a comparator +bool operator<(const NrIceCandidate& lhs, const NrIceCandidate& rhs) { + if (lhs.cand_addr.host == rhs.cand_addr.host) { + if (lhs.cand_addr.port == rhs.cand_addr.port) { + if (lhs.cand_addr.transport == rhs.cand_addr.transport) { + if (lhs.type == rhs.type) { + return lhs.tcp_type < rhs.tcp_type; + } + return lhs.type < rhs.type; + } + return lhs.cand_addr.transport < rhs.cand_addr.transport; + } + return lhs.cand_addr.port < rhs.cand_addr.port; + } + return lhs.cand_addr.host < rhs.cand_addr.host; +} + +bool operator==(const NrIceCandidate& lhs, const NrIceCandidate& rhs) { + return !((lhs < rhs) || (rhs < lhs)); +} + +class IceCandidatePairCompare { + public: + bool operator()(const NrIceCandidatePair& lhs, + const NrIceCandidatePair& rhs) const { + if (lhs.priority == rhs.priority) { + if (lhs.local == rhs.local) { + if (lhs.remote == rhs.remote) { + return lhs.codeword < rhs.codeword; + } + return lhs.remote < rhs.remote; + } + return lhs.local < rhs.local; + } + return lhs.priority < rhs.priority; + } +}; + +class IceTestPeer; + +class SchedulableTrickleCandidate { + public: + SchedulableTrickleCandidate(IceTestPeer* peer, size_t stream, + const std::string& candidate, + const std::string& ufrag, + MtransportTestUtils* utils) + : peer_(peer), + stream_(stream), + candidate_(candidate), + ufrag_(ufrag), + timer_handle_(nullptr), + test_utils_(utils) {} + + ~SchedulableTrickleCandidate() { + if (timer_handle_) NR_async_timer_cancel(timer_handle_); + } + + void Schedule(unsigned int ms) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &SchedulableTrickleCandidate::Schedule_s, ms)); + } + + void Schedule_s(unsigned int ms) { + MOZ_ASSERT(!timer_handle_); + NR_ASYNC_TIMER_SET(ms, Trickle_cb, this, &timer_handle_); + } + + static void Trickle_cb(NR_SOCKET s, int how, void* cb_arg) { + static_cast<SchedulableTrickleCandidate*>(cb_arg)->Trickle(); + } + + void Trickle(); + + std::string& Candidate() { return candidate_; } + + const std::string& Candidate() const { return candidate_; } + + bool IsHost() const { + return candidate_.find("typ host") != std::string::npos; + } + + bool IsReflexive() const { + return candidate_.find("typ srflx") != std::string::npos; + } + + bool IsRelay() const { + return candidate_.find("typ relay") != std::string::npos; + } + + private: + IceTestPeer* peer_; + size_t stream_; + std::string candidate_; + std::string ufrag_; + void* timer_handle_; + MtransportTestUtils* test_utils_; + + DISALLOW_COPY_ASSIGN(SchedulableTrickleCandidate); +}; + +class IceTestPeer : public sigslot::has_slots<> { + public: + IceTestPeer(const std::string& name, MtransportTestUtils* utils, bool offerer, + const NrIceCtx::Config& config) + : name_(name), + ice_ctx_(NrIceCtx::Create(name)), + offerer_(offerer), + candidates_(), + stream_counter_(0), + shutting_down_(false), + gathering_complete_(false), + ready_ct_(0), + ice_connected_(false), + ice_failed_(false), + ice_reached_checking_(false), + received_(0), + sent_(0), + fake_resolver_(), + dns_resolver_(new NrIceResolver()), + remote_(nullptr), + candidate_filter_(nullptr), + expected_local_type_(NrIceCandidate::ICE_HOST), + expected_local_transport_(kNrIceTransportUdp), + expected_remote_type_(NrIceCandidate::ICE_HOST), + trickle_mode_(TRICKLE_NONE), + simulate_ice_lite_(false), + nat_(new TestNat), + test_utils_(utils) { + ice_ctx_->SignalGatheringStateChange.connect( + this, &IceTestPeer::GatheringStateChange); + ice_ctx_->SignalConnectionStateChange.connect( + this, &IceTestPeer::ConnectionStateChange); + + ice_ctx_->SetIceConfig(config); + + consent_timestamp_.tv_sec = 0; + consent_timestamp_.tv_usec = 0; + int r = ice_ctx_->SetNat(nat_); + (void)r; + MOZ_ASSERT(!r); + } + + ~IceTestPeer() { + test_utils_->SyncDispatchToSTS(WrapRunnable(this, &IceTestPeer::Shutdown)); + + // Give the ICE destruction callback time to fire before + // we destroy the resolver. + PR_Sleep(1000); + } + + std::string MakeTransportId(size_t index) const { + char id[100]; + snprintf(id, sizeof(id), "%s:stream%d", name_.c_str(), (int)index); + return id; + } + + void SetIceCredentials_s(NrIceMediaStream& stream) { + static size_t counter = 0; + std::ostringstream prefix; + prefix << name_ << "-" << counter++; + std::string ufrag = prefix.str() + "-ufrag"; + std::string pwd = prefix.str() + "-pwd"; + if (mIceCredentials.count(stream.GetId())) { + mOldIceCredentials[stream.GetId()] = mIceCredentials[stream.GetId()]; + } + mIceCredentials[stream.GetId()] = std::make_pair(ufrag, pwd); + stream.SetIceCredentials(ufrag, pwd); + } + + void AddStream_s(int components) { + std::string id = MakeTransportId(stream_counter_++); + + RefPtr<NrIceMediaStream> stream = + ice_ctx_->CreateStream(id, id, components); + + ASSERT_TRUE(stream); + SetIceCredentials_s(*stream); + + stream->SignalCandidate.connect(this, &IceTestPeer::CandidateInitialized); + stream->SignalReady.connect(this, &IceTestPeer::StreamReady); + stream->SignalFailed.connect(this, &IceTestPeer::StreamFailed); + stream->SignalPacketReceived.connect(this, &IceTestPeer::PacketReceived); + } + + void AddStream(int components) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IceTestPeer::AddStream_s, components)); + } + + void RemoveStream_s(size_t index) { + ice_ctx_->DestroyStream(MakeTransportId(index)); + } + + void RemoveStream(size_t index) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IceTestPeer::RemoveStream_s, index)); + } + + RefPtr<NrIceMediaStream> GetStream_s(size_t index) { + std::string id = MakeTransportId(index); + return ice_ctx_->GetStream(id); + } + + void SetStunServer(const std::string addr, uint16_t port, + const char* transport = kNrIceTransportUdp) { + if (addr.empty()) { + // Happens when MOZ_DISABLE_NONLOCAL_CONNECTIONS is set + return; + } + + std::vector<NrIceStunServer> stun_servers; + UniquePtr<NrIceStunServer> server( + NrIceStunServer::Create(addr, port, transport)); + stun_servers.push_back(*server); + SetStunServers(stun_servers); + } + + void SetStunServers(const std::vector<NrIceStunServer>& servers) { + ASSERT_TRUE(NS_SUCCEEDED(ice_ctx_->SetStunServers(servers))); + } + + void UseTestStunServer() { + SetStunServer(TestStunServer::GetInstance(AF_INET)->addr(), + TestStunServer::GetInstance(AF_INET)->port()); + } + + void SetTurnServer(const std::string addr, uint16_t port, + const std::string username, const std::string password, + const char* transport) { + std::vector<unsigned char> password_vec(password.begin(), password.end()); + SetTurnServer(addr, port, username, password_vec, transport); + } + + void SetTurnServer(const std::string addr, uint16_t port, + const std::string username, + const std::vector<unsigned char> password, + const char* transport) { + std::vector<NrIceTurnServer> turn_servers; + UniquePtr<NrIceTurnServer> server( + NrIceTurnServer::Create(addr, port, username, password, transport)); + turn_servers.push_back(*server); + ASSERT_TRUE(NS_SUCCEEDED(ice_ctx_->SetTurnServers(turn_servers))); + } + + void SetTurnServers(const std::vector<NrIceTurnServer> servers) { + ASSERT_TRUE(NS_SUCCEEDED(ice_ctx_->SetTurnServers(servers))); + } + + void SetFakeResolver(const std::string& ip, const std::string& fqdn) { + ASSERT_TRUE(NS_SUCCEEDED(dns_resolver_->Init())); + if (!ip.empty() && !fqdn.empty()) { + PRNetAddr addr; + PRStatus status = PR_StringToNetAddr(ip.c_str(), &addr); + addr.inet.port = kDefaultStunServerPort; + ASSERT_EQ(PR_SUCCESS, status); + fake_resolver_.SetAddr(fqdn, addr); + } + ASSERT_TRUE( + NS_SUCCEEDED(ice_ctx_->SetResolver(fake_resolver_.AllocateResolver()))); + } + + void SetDNSResolver() { + ASSERT_TRUE(NS_SUCCEEDED(dns_resolver_->Init())); + ASSERT_TRUE( + NS_SUCCEEDED(ice_ctx_->SetResolver(dns_resolver_->AllocateResolver()))); + } + + void Gather(bool default_route_only = false, + bool obfuscate_host_addresses = false) { + nsresult res; + + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&res, ice_ctx_, &NrIceCtx::StartGathering, + default_route_only, obfuscate_host_addresses)); + + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + + void SetCtxFlags(bool default_route_only) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(ice_ctx_, &NrIceCtx::SetCtxFlags, default_route_only)); + } + + nsTArray<NrIceStunAddr> GetStunAddrs() { return ice_ctx_->GetStunAddrs(); } + + void SetStunAddrs(const nsTArray<NrIceStunAddr>& addrs) { + ice_ctx_->SetStunAddrs(addrs); + } + + void UseNat() { nat_->enabled_ = true; } + + void SetTimerDivider(int div) { ice_ctx_->internal_SetTimerAccelarator(div); } + + void SetStunResponseDelay(uint32_t delay) { + nat_->delay_stun_resp_ms_ = delay; + } + + void SetFilteringType(TestNat::NatBehavior type) { + MOZ_ASSERT(!nat_->has_port_mappings()); + nat_->filtering_type_ = type; + } + + void SetMappingType(TestNat::NatBehavior type) { + MOZ_ASSERT(!nat_->has_port_mappings()); + nat_->mapping_type_ = type; + } + + void SetBlockUdp(bool block) { + MOZ_ASSERT(!nat_->has_port_mappings()); + nat_->block_udp_ = block; + } + + void SetBlockStun(bool block) { nat_->block_stun_ = block; } + + // Get various pieces of state + std::vector<std::string> GetGlobalAttributes() { + std::vector<std::string> attrs(ice_ctx_->GetGlobalAttributes()); + if (simulate_ice_lite_) { + attrs.push_back("ice-lite"); + } + return attrs; + } + + std::vector<std::string> GetAttributes(size_t stream) { + std::vector<std::string> v; + + RUN_ON_THREAD( + test_utils_->sts_target(), + WrapRunnableRet(&v, this, &IceTestPeer::GetAttributes_s, stream)); + + return v; + } + + std::string FilterCandidate(const std::string& candidate) { + if (candidate_filter_) { + return candidate_filter_(candidate); + } + return candidate; + } + + std::vector<std::string> GetAttributes_s(size_t index) { + std::vector<std::string> attributes; + + auto stream = GetStream_s(index); + if (!stream) { + EXPECT_TRUE(false) << "No such stream " << index; + return attributes; + } + + std::vector<std::string> attributes_in = stream->GetAttributes(); + + for (const auto& attribute : attributes_in) { + if (attribute.find("candidate:") != std::string::npos) { + std::string candidate(FilterCandidate(attribute)); + if (!candidate.empty()) { + std::cerr << name_ << " Returning candidate: " << candidate + << std::endl; + attributes.push_back(candidate); + } + } else { + attributes.push_back(attribute); + } + } + + return attributes; + } + + void SetExpectedTypes(NrIceCandidate::Type local, NrIceCandidate::Type remote, + std::string local_transport = kNrIceTransportUdp) { + expected_local_type_ = local; + expected_local_transport_ = local_transport; + expected_remote_type_ = remote; + } + + void SetExpectedRemoteCandidateAddr(const std::string& addr) { + expected_remote_addr_ = addr; + } + + int GetCandidatesPrivateIpv4Range(size_t stream) { + std::vector<std::string> attributes = GetAttributes(stream); + + int host_net = 0; + for (const auto& a : attributes) { + if (a.find("typ host") != std::string::npos) { + nr_transport_addr addr; + std::vector<std::string> tokens = split(a, ' '); + int r = nr_str_port_to_transport_addr(tokens.at(4).c_str(), 0, + IPPROTO_UDP, &addr); + MOZ_ASSERT(!r); + if (!r && (addr.ip_version == NR_IPV4)) { + int n = nr_transport_addr_get_private_addr_range(&addr); + if (n) { + if (host_net) { + // TODO: add support for multiple private interfaces + std::cerr + << "This test doesn't support multiple private interfaces"; + return -1; + } + host_net = n; + } + } + } + } + return host_net; + } + + bool gathering_complete() { return gathering_complete_; } + int ready_ct() { return ready_ct_; } + bool is_ready_s(size_t index) { + auto media_stream = GetStream_s(index); + if (!media_stream) { + EXPECT_TRUE(false) << "No such stream " << index; + return false; + } + return media_stream->state() == NrIceMediaStream::ICE_OPEN; + } + bool is_ready(size_t stream) { + bool result; + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&result, this, &IceTestPeer::is_ready_s, stream)); + return result; + } + bool ice_connected() { return ice_connected_; } + bool ice_failed() { return ice_failed_; } + bool ice_reached_checking() { return ice_reached_checking_; } + size_t received() { return received_; } + size_t sent() { return sent_; } + + void RestartIce() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IceTestPeer::RestartIce_s)); + } + + void RestartIce_s() { + for (auto& stream : ice_ctx_->GetStreams()) { + SetIceCredentials_s(*stream); + } + // take care of some local bookkeeping + ready_ct_ = 0; + gathering_complete_ = false; + ice_connected_ = false; + ice_failed_ = false; + ice_reached_checking_ = false; + remote_ = nullptr; + } + + void RollbackIceRestart() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IceTestPeer::RollbackIceRestart_s)); + } + + void RollbackIceRestart_s() { + for (auto& stream : ice_ctx_->GetStreams()) { + mIceCredentials[stream->GetId()] = mOldIceCredentials[stream->GetId()]; + } + } + + // Start connecting to another peer + void Connect_s(IceTestPeer* remote, TrickleMode trickle_mode, + bool start = true) { + nsresult res; + + remote_ = remote; + + trickle_mode_ = trickle_mode; + ice_connected_ = false; + ice_failed_ = false; + ice_reached_checking_ = false; + res = ice_ctx_->ParseGlobalAttributes(remote->GetGlobalAttributes()); + ASSERT_FALSE(remote->simulate_ice_lite_ && + (ice_ctx_->GetControlling() == NrIceCtx::ICE_CONTROLLED)); + ASSERT_TRUE(NS_SUCCEEDED(res)); + + for (size_t i = 0; i < stream_counter_; ++i) { + auto aStream = GetStream_s(i); + if (aStream) { + std::vector<std::string> attributes = remote->GetAttributes(i); + + for (auto it = attributes.begin(); it != attributes.end();) { + if (trickle_mode == TRICKLE_SIMULATE && + it->find("candidate:") != std::string::npos) { + std::cerr << name_ << " Deferring remote candidate: " << *it + << std::endl; + attributes.erase(it); + } else { + std::cerr << name_ << " Adding remote attribute: " + *it + << std::endl; + ++it; + } + } + auto credentials = mIceCredentials[aStream->GetId()]; + res = aStream->ConnectToPeer(credentials.first, credentials.second, + attributes); + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + } + + if (start) { + ice_ctx_->SetControlling(offerer_ ? NrIceCtx::ICE_CONTROLLING + : NrIceCtx::ICE_CONTROLLED); + // Now start checks + res = ice_ctx_->StartChecks(); + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + } + + void Connect(IceTestPeer* remote, TrickleMode trickle_mode, + bool start = true) { + test_utils_->SyncDispatchToSTS(WrapRunnable(this, &IceTestPeer::Connect_s, + remote, trickle_mode, start)); + } + + void SimulateTrickle(size_t stream) { + std::cerr << name_ << " Doing trickle for stream " << stream << std::endl; + // If we are in trickle deferred mode, now trickle in the candidates + // for |stream| + + std::vector<SchedulableTrickleCandidate*>& candidates = + ControlTrickle(stream); + + for (auto& candidate : candidates) { + candidate->Schedule(0); + } + } + + // Allows test case to completely control when/if candidates are trickled + // (test could also do things like insert extra trickle candidates, or + // change existing ones, or insert duplicates, really anything is fair game) + std::vector<SchedulableTrickleCandidate*>& ControlTrickle(size_t stream) { + std::cerr << "Doing controlled trickle for stream " << stream << std::endl; + + std::vector<std::string> attributes = remote_->GetAttributes(stream); + + for (const auto& attribute : attributes) { + if (attribute.find("candidate:") != std::string::npos) { + controlled_trickle_candidates_[stream].push_back( + new SchedulableTrickleCandidate(this, stream, attribute, "", + test_utils_)); + } + } + + return controlled_trickle_candidates_[stream]; + } + + nsresult TrickleCandidate_s(const std::string& candidate, + const std::string& ufrag, size_t index) { + auto stream = GetStream_s(index); + if (!stream) { + // stream might have gone away before the trickle timer popped + return NS_OK; + } + return stream->ParseTrickleCandidate(candidate, ufrag, ""); + } + + void DumpCandidate(std::string which, const NrIceCandidate& cand) { + std::string type; + std::string tcp_type; + + std::string addr; + int port; + + if (which.find("Remote") != std::string::npos) { + addr = cand.cand_addr.host; + port = cand.cand_addr.port; + } else { + addr = cand.local_addr.host; + port = cand.local_addr.port; + } + switch (cand.type) { + case NrIceCandidate::ICE_HOST: + type = "host"; + break; + case NrIceCandidate::ICE_SERVER_REFLEXIVE: + type = "srflx"; + break; + case NrIceCandidate::ICE_PEER_REFLEXIVE: + type = "prflx"; + break; + case NrIceCandidate::ICE_RELAYED: + type = "relay"; + if (which.find("Local") != std::string::npos) { + type += "(" + cand.local_addr.transport + ")"; + } + break; + default: + FAIL(); + }; + + switch (cand.tcp_type) { + case NrIceCandidate::ICE_NONE: + break; + case NrIceCandidate::ICE_ACTIVE: + tcp_type = " tcptype=active"; + break; + case NrIceCandidate::ICE_PASSIVE: + tcp_type = " tcptype=passive"; + break; + case NrIceCandidate::ICE_SO: + tcp_type = " tcptype=so"; + break; + default: + FAIL(); + }; + + std::cerr << which << " --> " << type << " " << addr << ":" << port << "/" + << cand.cand_addr.transport << tcp_type + << " codeword=" << cand.codeword << std::endl; + } + + void DumpAndCheckActiveCandidates_s() { + std::cerr << name_ << " Active candidates:" << std::endl; + for (const auto& stream : ice_ctx_->GetStreams()) { + for (size_t j = 0; j < stream->components(); ++j) { + std::cerr << name_ << " Stream " << stream->GetId() << " component " + << j + 1 << std::endl; + + UniquePtr<NrIceCandidate> local; + UniquePtr<NrIceCandidate> remote; + + nsresult res = stream->GetActivePair(j + 1, &local, &remote); + if (res == NS_ERROR_NOT_AVAILABLE) { + std::cerr << "Component unpaired or disabled." << std::endl; + } else { + ASSERT_TRUE(NS_SUCCEEDED(res)); + DumpCandidate("Local ", *local); + /* Depending on timing, and the whims of the network + * stack/configuration we're running on top of, prflx is always a + * possibility. */ + if (expected_local_type_ == NrIceCandidate::ICE_HOST) { + ASSERT_NE(NrIceCandidate::ICE_SERVER_REFLEXIVE, local->type); + ASSERT_NE(NrIceCandidate::ICE_RELAYED, local->type); + } else { + ASSERT_EQ(expected_local_type_, local->type); + } + ASSERT_EQ(expected_local_transport_, local->local_addr.transport); + DumpCandidate("Remote ", *remote); + /* Depending on timing, and the whims of the network + * stack/configuration we're running on top of, prflx is always a + * possibility. */ + if (expected_remote_type_ == NrIceCandidate::ICE_HOST) { + ASSERT_NE(NrIceCandidate::ICE_SERVER_REFLEXIVE, remote->type); + ASSERT_NE(NrIceCandidate::ICE_RELAYED, remote->type); + } else { + ASSERT_EQ(expected_remote_type_, remote->type); + } + if (!expected_remote_addr_.empty()) { + ASSERT_EQ(expected_remote_addr_, remote->cand_addr.host); + } + } + } + } + } + + void DumpAndCheckActiveCandidates() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IceTestPeer::DumpAndCheckActiveCandidates_s)); + } + + void Close() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(ice_ctx_, &NrIceCtx::destroy_peer_ctx)); + } + + void Shutdown() { + std::cerr << name_ << " Shutdown" << std::endl; + shutting_down_ = true; + for (auto& controlled_trickle_candidate : controlled_trickle_candidates_) { + for (auto& cand : controlled_trickle_candidate.second) { + delete cand; + } + } + + ice_ctx_->Destroy(); + ice_ctx_ = nullptr; + + if (remote_) { + remote_->UnsetRemote(); + remote_ = nullptr; + } + } + + void UnsetRemote() { remote_ = nullptr; } + + void StartChecks() { + nsresult res; + + test_utils_->SyncDispatchToSTS(WrapRunnableRet( + &res, ice_ctx_, &NrIceCtx::SetControlling, + offerer_ ? NrIceCtx::ICE_CONTROLLING : NrIceCtx::ICE_CONTROLLED)); + // Now start checks + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&res, ice_ctx_, &NrIceCtx::StartChecks)); + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + + // Handle events + void GatheringStateChange(NrIceCtx* ctx, NrIceCtx::GatheringState state) { + if (shutting_down_) { + return; + } + if (state != NrIceCtx::ICE_CTX_GATHER_COMPLETE) { + return; + } + + std::cerr << name_ << " Gathering complete" << std::endl; + gathering_complete_ = true; + + std::cerr << name_ << " ATTRIBUTES:" << std::endl; + for (const auto& stream : ice_ctx_->GetStreams()) { + std::cerr << "Stream " << stream->GetId() << std::endl; + + std::vector<std::string> attributes = stream->GetAttributes(); + + for (const auto& attribute : attributes) { + std::cerr << attribute << std::endl; + } + } + std::cerr << std::endl; + } + + void CandidateInitialized(NrIceMediaStream* stream, + const std::string& raw_candidate, + const std::string& ufrag, + const std::string& mdns_addr, + const std::string& actual_addr) { + std::string candidate(FilterCandidate(raw_candidate)); + if (candidate.empty()) { + return; + } + std::cerr << "Candidate for stream " << stream->name() + << " initialized: " << candidate << std::endl; + candidates_[stream->name()].push_back(candidate); + + // If we are connected, then try to trickle to the other side. + if (remote_ && remote_->remote_ && (trickle_mode_ != TRICKLE_SIMULATE)) { + // first, find the index of the stream we've been given so + // we can get the corresponding stream on the remote side + for (size_t i = 0; i < stream_counter_; ++i) { + if (GetStream_s(i) == stream) { + ASSERT_GT(remote_->stream_counter_, i); + nsresult res = remote_->GetStream_s(i)->ParseTrickleCandidate( + candidate, ufrag, ""); + ASSERT_TRUE(NS_SUCCEEDED(res)); + return; + } + } + ADD_FAILURE() << "No matching stream found for " << stream; + } + } + + nsresult GetCandidatePairs_s(size_t stream_index, + std::vector<NrIceCandidatePair>* pairs) { + MOZ_ASSERT(pairs); + auto stream = GetStream_s(stream_index); + if (!stream) { + // Is there a better error for "no such index"? + ADD_FAILURE() << "No such media stream index: " << stream_index; + return NS_ERROR_INVALID_ARG; + } + + return stream->GetCandidatePairs(pairs); + } + + nsresult GetCandidatePairs(size_t stream_index, + std::vector<NrIceCandidatePair>* pairs) { + nsresult v; + test_utils_->SyncDispatchToSTS(WrapRunnableRet( + &v, this, &IceTestPeer::GetCandidatePairs_s, stream_index, pairs)); + return v; + } + + void DumpCandidatePair(const NrIceCandidatePair& pair) { + std::cerr << std::endl; + DumpCandidate("Local", pair.local); + DumpCandidate("Remote", pair.remote); + std::cerr << "state = " << pair.state << " priority = " << pair.priority + << " nominated = " << pair.nominated + << " selected = " << pair.selected + << " codeword = " << pair.codeword << std::endl; + } + + void DumpCandidatePairs_s(NrIceMediaStream* stream) { + std::vector<NrIceCandidatePair> pairs; + nsresult res = stream->GetCandidatePairs(&pairs); + ASSERT_TRUE(NS_SUCCEEDED(res)); + + std::cerr << "Begin list of candidate pairs [" << std::endl; + + for (auto& pair : pairs) { + DumpCandidatePair(pair); + } + std::cerr << "]" << std::endl; + } + + void DumpCandidatePairs_s() { + std::cerr << "Dumping candidate pairs for all streams [" << std::endl; + for (const auto& stream : ice_ctx_->GetStreams()) { + DumpCandidatePairs_s(stream.get()); + } + std::cerr << "]" << std::endl; + } + + bool CandidatePairsPriorityDescending( + const std::vector<NrIceCandidatePair>& pairs) { + // Verify that priority is descending + uint64_t priority = std::numeric_limits<uint64_t>::max(); + + for (size_t p = 0; p < pairs.size(); ++p) { + if (priority < pairs[p].priority) { + std::cerr << "Priority increased in subsequent pairs:" << std::endl; + DumpCandidatePair(pairs[p - 1]); + DumpCandidatePair(pairs[p]); + return false; + } + if (priority == pairs[p].priority) { + if (!IceCandidatePairCompare()(pairs[p], pairs[p - 1]) && + !IceCandidatePairCompare()(pairs[p - 1], pairs[p])) { + std::cerr << "Ignoring identical pair from trigger check" + << std::endl; + } else { + std::cerr << "Duplicate priority in subseqent pairs:" << std::endl; + DumpCandidatePair(pairs[p - 1]); + DumpCandidatePair(pairs[p]); + return false; + } + } + priority = pairs[p].priority; + } + return true; + } + + void UpdateAndValidateCandidatePairs( + size_t stream_index, std::vector<NrIceCandidatePair>* new_pairs) { + std::vector<NrIceCandidatePair> old_pairs = *new_pairs; + GetCandidatePairs(stream_index, new_pairs); + ASSERT_TRUE(CandidatePairsPriorityDescending(*new_pairs)) + << "New list of " + "candidate pairs is either not sorted in priority order, or has " + "duplicate priorities."; + ASSERT_TRUE(CandidatePairsPriorityDescending(old_pairs)) + << "Old list of " + "candidate pairs is either not sorted in priority order, or has " + "duplicate priorities. This indicates some bug in the test case."; + std::vector<NrIceCandidatePair> added_pairs; + std::vector<NrIceCandidatePair> removed_pairs; + + // set_difference computes the set of elements that are present in the + // first set, but not the second + // NrIceCandidatePair::operator< compares based on the priority, local + // candidate, and remote candidate in that order. This means this will + // catch cases where the priority has remained the same, but one of the + // candidates has changed. + std::set_difference((*new_pairs).begin(), (*new_pairs).end(), + old_pairs.begin(), old_pairs.end(), + std::inserter(added_pairs, added_pairs.begin()), + IceCandidatePairCompare()); + + std::set_difference(old_pairs.begin(), old_pairs.end(), + (*new_pairs).begin(), (*new_pairs).end(), + std::inserter(removed_pairs, removed_pairs.begin()), + IceCandidatePairCompare()); + + for (auto& added_pair : added_pairs) { + std::cerr << "Found new candidate pair." << std::endl; + DumpCandidatePair(added_pair); + } + + for (auto& removed_pair : removed_pairs) { + std::cerr << "Pre-existing candidate pair is now missing:" << std::endl; + DumpCandidatePair(removed_pair); + } + + ASSERT_TRUE(removed_pairs.empty()) + << "At least one candidate pair has " + "gone missing."; + } + + void StreamReady(NrIceMediaStream* stream) { + ++ready_ct_; + std::cerr << name_ << " Stream ready for " << stream->name() + << " ct=" << ready_ct_ << std::endl; + DumpCandidatePairs_s(stream); + } + void StreamFailed(NrIceMediaStream* stream) { + std::cerr << name_ << " Stream failed for " << stream->name() + << " ct=" << ready_ct_ << std::endl; + DumpCandidatePairs_s(stream); + } + + void ConnectionStateChange(NrIceCtx* ctx, NrIceCtx::ConnectionState state) { + (void)ctx; + switch (state) { + case NrIceCtx::ICE_CTX_INIT: + break; + case NrIceCtx::ICE_CTX_CHECKING: + std::cerr << name_ << " ICE reached checking" << std::endl; + ice_reached_checking_ = true; + break; + case NrIceCtx::ICE_CTX_CONNECTED: + std::cerr << name_ << " ICE connected" << std::endl; + ice_connected_ = true; + break; + case NrIceCtx::ICE_CTX_COMPLETED: + std::cerr << name_ << " ICE completed" << std::endl; + break; + case NrIceCtx::ICE_CTX_FAILED: + std::cerr << name_ << " ICE failed" << std::endl; + ice_failed_ = true; + break; + case NrIceCtx::ICE_CTX_DISCONNECTED: + std::cerr << name_ << " ICE disconnected" << std::endl; + ice_connected_ = false; + break; + default: + MOZ_CRASH(); + } + } + + void PacketReceived(NrIceMediaStream* stream, int component, + const unsigned char* data, int len) { + std::cerr << name_ << ": received " << len << " bytes" << std::endl; + ++received_; + } + + void SendPacket(int stream, int component, const unsigned char* data, + int len) { + auto media_stream = GetStream_s(stream); + if (!media_stream) { + ADD_FAILURE() << "No such stream " << stream; + return; + } + + ASSERT_TRUE(NS_SUCCEEDED(media_stream->SendPacket(component, data, len))); + + ++sent_; + std::cerr << name_ << ": sent " << len << " bytes" << std::endl; + } + + void SendFailure(int stream, int component) { + auto media_stream = GetStream_s(stream); + if (!media_stream) { + ADD_FAILURE() << "No such stream " << stream; + return; + } + + const std::string d("FAIL"); + ASSERT_TRUE(NS_FAILED(media_stream->SendPacket( + component, reinterpret_cast<const unsigned char*>(d.c_str()), + d.length()))); + + std::cerr << name_ << ": send failed as expected" << std::endl; + } + + void SetCandidateFilter(CandidateFilter filter) { + candidate_filter_ = filter; + } + + void ParseCandidate_s(size_t i, const std::string& candidate, + const std::string& mdns_addr) { + auto media_stream = GetStream_s(i); + ASSERT_TRUE(media_stream.get()) + << "No such stream " << i; + media_stream->ParseTrickleCandidate(candidate, "", mdns_addr); + } + + void ParseCandidate(size_t i, const std::string& candidate, + const std::string& mdns_addr) { + test_utils_->SyncDispatchToSTS(WrapRunnable( + this, &IceTestPeer::ParseCandidate_s, i, candidate, mdns_addr)); + } + + void DisableComponent_s(size_t index, int component_id) { + ASSERT_LT(index, stream_counter_); + auto stream = GetStream_s(index); + ASSERT_TRUE(stream.get()) + << "No such stream " << index; + nsresult res = stream->DisableComponent(component_id); + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + + void DisableComponent(size_t stream, int component_id) { + test_utils_->SyncDispatchToSTS(WrapRunnable( + this, &IceTestPeer::DisableComponent_s, stream, component_id)); + } + + void AssertConsentRefresh_s(size_t index, int component_id, + ConsentStatus status) { + ASSERT_LT(index, stream_counter_); + auto stream = GetStream_s(index); + ASSERT_TRUE(stream.get()) + << "No such stream " << index; + bool can_send; + struct timeval timestamp; + nsresult res = + stream->GetConsentStatus(component_id, &can_send, ×tamp); + ASSERT_TRUE(NS_SUCCEEDED(res)); + if (status == CONSENT_EXPIRED) { + ASSERT_EQ(can_send, 0); + } else { + ASSERT_EQ(can_send, 1); + } + if (consent_timestamp_.tv_sec) { + if (status == CONSENT_FRESH) { + ASSERT_EQ(r_timeval_cmp(×tamp, &consent_timestamp_), 1); + } else { + ASSERT_EQ(r_timeval_cmp(×tamp, &consent_timestamp_), 0); + } + } + consent_timestamp_.tv_sec = timestamp.tv_sec; + consent_timestamp_.tv_usec = timestamp.tv_usec; + std::cerr << name_ + << ": new consent timestamp = " << consent_timestamp_.tv_sec + << "." << consent_timestamp_.tv_usec << std::endl; + } + + void AssertConsentRefresh(ConsentStatus status) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IceTestPeer::AssertConsentRefresh_s, 0, 1, status)); + } + + void ChangeNetworkState_s(bool online) { + ice_ctx_->UpdateNetworkState(online); + } + + void ChangeNetworkStateToOffline() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IceTestPeer::ChangeNetworkState_s, false)); + } + + void ChangeNetworkStateToOnline() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IceTestPeer::ChangeNetworkState_s, true)); + } + + void SetControlling(NrIceCtx::Controlling controlling) { + nsresult res; + test_utils_->SyncDispatchToSTS(WrapRunnableRet( + &res, ice_ctx_, &NrIceCtx::SetControlling, controlling)); + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + + NrIceCtx::Controlling GetControlling() { return ice_ctx_->GetControlling(); } + + void SetTiebreaker(uint64_t tiebreaker) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IceTestPeer::SetTiebreaker_s, tiebreaker)); + } + + void SetTiebreaker_s(uint64_t tiebreaker) { + ice_ctx_->peer()->tiebreaker = tiebreaker; + } + + void SimulateIceLite() { + simulate_ice_lite_ = true; + SetControlling(NrIceCtx::ICE_CONTROLLED); + } + + nsresult GetDefaultCandidate(unsigned int stream, NrIceCandidate* cand) { + nsresult rv; + + test_utils_->SyncDispatchToSTS(WrapRunnableRet( + &rv, this, &IceTestPeer::GetDefaultCandidate_s, stream, cand)); + + return rv; + } + + nsresult GetDefaultCandidate_s(unsigned int index, NrIceCandidate* cand) { + return GetStream_s(index)->GetDefaultCandidate(1, cand); + } + + private: + std::string name_; + RefPtr<NrIceCtx> ice_ctx_; + bool offerer_; + std::map<std::string, std::vector<std::string>> candidates_; + // Maps from stream id to list of remote trickle candidates + std::map<size_t, std::vector<SchedulableTrickleCandidate*>> + controlled_trickle_candidates_; + std::map<std::string, std::pair<std::string, std::string>> mIceCredentials; + std::map<std::string, std::pair<std::string, std::string>> mOldIceCredentials; + size_t stream_counter_; + bool shutting_down_; + bool gathering_complete_; + int ready_ct_; + bool ice_connected_; + bool ice_failed_; + bool ice_reached_checking_; + size_t received_; + size_t sent_; + struct timeval consent_timestamp_; + NrIceResolverFake fake_resolver_; + RefPtr<NrIceResolver> dns_resolver_; + IceTestPeer* remote_; + CandidateFilter candidate_filter_; + NrIceCandidate::Type expected_local_type_; + std::string expected_local_transport_; + NrIceCandidate::Type expected_remote_type_; + std::string expected_remote_addr_; + TrickleMode trickle_mode_; + bool simulate_ice_lite_; + RefPtr<mozilla::TestNat> nat_; + MtransportTestUtils* test_utils_; +}; + +void SchedulableTrickleCandidate::Trickle() { + timer_handle_ = nullptr; + nsresult res = peer_->TrickleCandidate_s(candidate_, ufrag_, stream_); + ASSERT_TRUE(NS_SUCCEEDED(res)); +} + +class WebRtcIceGatherTest : public StunTest { + public: + void SetUp() override { + StunTest::SetUp(); + + Preferences::SetInt("media.peerconnection.ice.tcp_so_sock_count", 3); + + test_utils_->SyncDispatchToSTS(WrapRunnable( + TestStunServer::GetInstance(AF_INET), &TestStunServer::Reset)); + if (TestStunServer::GetInstance(AF_INET6)) { + test_utils_->SyncDispatchToSTS(WrapRunnable( + TestStunServer::GetInstance(AF_INET6), &TestStunServer::Reset)); + } + } + + void TearDown() override { + peer_ = nullptr; + StunTest::TearDown(); + } + + void EnsurePeer() { + if (!peer_) { + peer_ = + MakeUnique<IceTestPeer>("P1", test_utils_, true, NrIceCtx::Config()); + } + } + + void Gather(unsigned int waitTime = kDefaultTimeout, + bool default_route_only = false, + bool obfuscate_host_addresses = false) { + EnsurePeer(); + peer_->Gather(default_route_only, obfuscate_host_addresses); + + if (waitTime) { + WaitForGather(waitTime); + } + } + + void WaitForGather(unsigned int waitTime = kDefaultTimeout) { + ASSERT_TRUE_WAIT(peer_->gathering_complete(), waitTime); + } + + void AddStunServerWithResponse(const std::string& fake_addr, + uint16_t fake_port, const std::string& fqdn, + const std::string& proto, + std::vector<NrIceStunServer>* stun_servers) { + int family; + if (fake_addr.find(':') != std::string::npos) { + family = AF_INET6; + } else { + family = AF_INET; + } + + std::string stun_addr; + uint16_t stun_port; + if (proto == kNrIceTransportUdp) { + TestStunServer::GetInstance(family)->SetResponseAddr(fake_addr, + fake_port); + stun_addr = TestStunServer::GetInstance(family)->addr(); + stun_port = TestStunServer::GetInstance(family)->port(); + } else if (proto == kNrIceTransportTcp) { + TestStunTcpServer::GetInstance(family)->SetResponseAddr(fake_addr, + fake_port); + stun_addr = TestStunTcpServer::GetInstance(family)->addr(); + stun_port = TestStunTcpServer::GetInstance(family)->port(); + } else { + MOZ_CRASH(); + } + + if (!fqdn.empty()) { + peer_->SetFakeResolver(stun_addr, fqdn); + stun_addr = fqdn; + } + + stun_servers->push_back( + *NrIceStunServer::Create(stun_addr, stun_port, proto.c_str())); + + if (family == AF_INET6 && !fqdn.empty()) { + stun_servers->back().SetUseIPv6IfFqdn(); + } + } + + void UseFakeStunUdpServerWithResponse( + const std::string& fake_addr, uint16_t fake_port, + const std::string& fqdn = std::string()) { + EnsurePeer(); + std::vector<NrIceStunServer> stun_servers; + AddStunServerWithResponse(fake_addr, fake_port, fqdn, "udp", &stun_servers); + peer_->SetStunServers(stun_servers); + } + + void UseFakeStunTcpServerWithResponse( + const std::string& fake_addr, uint16_t fake_port, + const std::string& fqdn = std::string()) { + EnsurePeer(); + std::vector<NrIceStunServer> stun_servers; + AddStunServerWithResponse(fake_addr, fake_port, fqdn, "tcp", &stun_servers); + peer_->SetStunServers(stun_servers); + } + + void UseFakeStunUdpTcpServersWithResponse(const std::string& fake_udp_addr, + uint16_t fake_udp_port, + const std::string& fake_tcp_addr, + uint16_t fake_tcp_port) { + EnsurePeer(); + std::vector<NrIceStunServer> stun_servers; + AddStunServerWithResponse(fake_udp_addr, fake_udp_port, + "", // no fqdn + "udp", &stun_servers); + AddStunServerWithResponse(fake_tcp_addr, fake_tcp_port, + "", // no fqdn + "tcp", &stun_servers); + + peer_->SetStunServers(stun_servers); + } + + void UseTestStunServer() { + TestStunServer::GetInstance(AF_INET)->Reset(); + peer_->SetStunServer(TestStunServer::GetInstance(AF_INET)->addr(), + TestStunServer::GetInstance(AF_INET)->port()); + } + + // NB: Only does substring matching, watch out for stuff like "1.2.3.4" + // matching "21.2.3.47". " 1.2.3.4 " should not have false positives. + bool StreamHasMatchingCandidate(unsigned int stream, const std::string& match, + const std::string& match2 = "") { + std::vector<std::string> attributes = peer_->GetAttributes(stream); + for (auto& attribute : attributes) { + if (std::string::npos != attribute.find(match)) { + if (!match2.length() || std::string::npos != attribute.find(match2)) { + return true; + } + } + } + return false; + } + + void DumpAttributes(unsigned int stream) { + std::vector<std::string> attributes = peer_->GetAttributes(stream); + + std::cerr << "Attributes for stream " << stream << "->" << attributes.size() + << std::endl; + + for (const auto& a : attributes) { + std::cerr << "Attribute: " << a << std::endl; + } + } + + protected: + mozilla::UniquePtr<IceTestPeer> peer_; +}; + +class WebRtcIceConnectTest : public StunTest { + public: + WebRtcIceConnectTest() + : initted_(false), + test_stun_server_inited_(false), + use_nat_(false), + filtering_type_(TestNat::ENDPOINT_INDEPENDENT), + mapping_type_(TestNat::ENDPOINT_INDEPENDENT), + block_udp_(false) {} + + void SetUp() override { + StunTest::SetUp(); + + nsresult rv; + target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + } + + void TearDown() override { + p1_ = nullptr; + p2_ = nullptr; + + StunTest::TearDown(); + } + + void AddStream(int components) { + Init(); + p1_->AddStream(components); + p2_->AddStream(components); + } + + void RemoveStream(size_t index) { + p1_->RemoveStream(index); + p2_->RemoveStream(index); + } + + void Init(bool setup_stun_servers = true, + NrIceCtx::Policy ice_policy = NrIceCtx::ICE_POLICY_ALL) { + if (initted_) { + return; + } + + NrIceCtx::Config config; + config.mPolicy = ice_policy; + + p1_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, config); + p2_ = MakeUnique<IceTestPeer>("P2", test_utils_, false, config); + InitPeer(p1_.get(), setup_stun_servers); + InitPeer(p2_.get(), setup_stun_servers); + + initted_ = true; + } + + void InitPeer(IceTestPeer* peer, bool setup_stun_servers = true) { + if (use_nat_) { + // If we enable nat simulation, but still use a real STUN server somewhere + // on the internet, we will see failures if there is a real NAT in + // addition to our simulated one, particularly if it disallows + // hairpinning. + if (setup_stun_servers) { + InitTestStunServer(); + peer->UseTestStunServer(); + } + peer->UseNat(); + peer->SetFilteringType(filtering_type_); + peer->SetMappingType(mapping_type_); + peer->SetBlockUdp(block_udp_); + } else if (setup_stun_servers) { + std::vector<NrIceStunServer> stun_servers; + + stun_servers.push_back(*NrIceStunServer::Create( + stun_server_address_, kDefaultStunServerPort, kNrIceTransportUdp)); + + peer->SetStunServers(stun_servers); + } + } + + bool Gather(unsigned int waitTime = kDefaultTimeout, + bool default_route_only = false) { + Init(); + + return GatherCallerAndCallee(p1_.get(), p2_.get(), waitTime, + default_route_only); + } + + bool GatherCallerAndCallee(IceTestPeer* caller, IceTestPeer* callee, + unsigned int waitTime = kDefaultTimeout, + bool default_route_only = false) { + caller->Gather(default_route_only); + callee->Gather(default_route_only); + + if (waitTime) { + EXPECT_TRUE_WAIT(caller->gathering_complete(), waitTime); + if (!caller->gathering_complete()) return false; + EXPECT_TRUE_WAIT(callee->gathering_complete(), waitTime); + if (!callee->gathering_complete()) return false; + } + return true; + } + + void UseNat() { + // to be useful, this method should be called before Init + ASSERT_FALSE(initted_); + use_nat_ = true; + } + + void SetFilteringType(TestNat::NatBehavior type) { + // to be useful, this method should be called before Init + ASSERT_FALSE(initted_); + filtering_type_ = type; + } + + void SetMappingType(TestNat::NatBehavior type) { + // to be useful, this method should be called before Init + ASSERT_FALSE(initted_); + mapping_type_ = type; + } + + void BlockUdp() { + // note: |block_udp_| is used only in InitPeer. + // Use IceTestPeer::SetBlockUdp to act on the peer directly. + block_udp_ = true; + } + + void SetupAndCheckConsent() { + p1_->SetTimerDivider(10); + p2_->SetTimerDivider(10); + ASSERT_TRUE(Gather()); + Connect(); + p1_->AssertConsentRefresh(CONSENT_FRESH); + p2_->AssertConsentRefresh(CONSENT_FRESH); + SendReceive(); + } + + void AssertConsentRefresh(ConsentStatus status = CONSENT_FRESH) { + p1_->AssertConsentRefresh(status); + p2_->AssertConsentRefresh(status); + } + + void InitTestStunServer() { + if (test_stun_server_inited_) { + return; + } + + std::cerr << "Resetting TestStunServer" << std::endl; + TestStunServer::GetInstance(AF_INET)->Reset(); + test_stun_server_inited_ = true; + } + + void UseTestStunServer() { + InitTestStunServer(); + p1_->UseTestStunServer(); + p2_->UseTestStunServer(); + } + + void SetTurnServer(const std::string addr, uint16_t port, + const std::string username, const std::string password, + const char* transport = kNrIceTransportUdp) { + p1_->SetTurnServer(addr, port, username, password, transport); + p2_->SetTurnServer(addr, port, username, password, transport); + } + + void SetTurnServers(const std::vector<NrIceTurnServer>& servers) { + p1_->SetTurnServers(servers); + p2_->SetTurnServers(servers); + } + + void SetCandidateFilter(CandidateFilter filter, bool both = true) { + p1_->SetCandidateFilter(filter); + if (both) { + p2_->SetCandidateFilter(filter); + } + } + + void Connect() { ConnectCallerAndCallee(p1_.get(), p2_.get()); } + + void ConnectCallerAndCallee(IceTestPeer* caller, IceTestPeer* callee, + TrickleMode mode = TRICKLE_NONE) { + ASSERT_TRUE(caller->ready_ct() == 0); + ASSERT_TRUE(caller->ice_connected() == 0); + ASSERT_TRUE(caller->ice_reached_checking() == 0); + ASSERT_TRUE(callee->ready_ct() == 0); + ASSERT_TRUE(callee->ice_connected() == 0); + ASSERT_TRUE(callee->ice_reached_checking() == 0); + + // IceTestPeer::Connect grabs attributes from the first arg, and + // gives them to |this|, meaning that callee->Connect(caller, ...) + // simulates caller sending an offer to callee. Order matters here + // because it determines which peer is controlling. + callee->Connect(caller, mode); + caller->Connect(callee, mode); + + if (mode != TRICKLE_SIMULATE) { + ASSERT_TRUE_WAIT(caller->ice_connected() && callee->ice_connected(), + kDefaultTimeout); + ASSERT_TRUE(caller->ready_ct() >= 1 && callee->ready_ct() >= 1); + ASSERT_TRUE(caller->ice_reached_checking()); + ASSERT_TRUE(callee->ice_reached_checking()); + + caller->DumpAndCheckActiveCandidates(); + callee->DumpAndCheckActiveCandidates(); + } + } + + void SetExpectedTypes(NrIceCandidate::Type local, NrIceCandidate::Type remote, + std::string transport = kNrIceTransportUdp) { + p1_->SetExpectedTypes(local, remote, transport); + p2_->SetExpectedTypes(local, remote, transport); + } + + void SetExpectedRemoteCandidateAddr(const std::string& addr) { + p1_->SetExpectedRemoteCandidateAddr(addr); + p2_->SetExpectedRemoteCandidateAddr(addr); + } + + void ConnectP1(TrickleMode mode = TRICKLE_NONE) { + p1_->Connect(p2_.get(), mode); + } + + void ConnectP2(TrickleMode mode = TRICKLE_NONE) { + p2_->Connect(p1_.get(), mode); + } + + void WaitForConnectedStreams(int expected_streams = 1) { + ASSERT_TRUE_WAIT(p1_->ready_ct() == expected_streams && + p2_->ready_ct() == expected_streams, + kDefaultTimeout); + ASSERT_TRUE_WAIT(p1_->ice_connected() && p2_->ice_connected(), + kDefaultTimeout); + } + + void AssertCheckingReached() { + ASSERT_TRUE(p1_->ice_reached_checking()); + ASSERT_TRUE(p2_->ice_reached_checking()); + } + + void WaitForConnected(unsigned int timeout = kDefaultTimeout) { + ASSERT_TRUE_WAIT(p1_->ice_connected(), timeout); + ASSERT_TRUE_WAIT(p2_->ice_connected(), timeout); + } + + void WaitForGather() { + ASSERT_TRUE_WAIT(p1_->gathering_complete(), kDefaultTimeout); + ASSERT_TRUE_WAIT(p2_->gathering_complete(), kDefaultTimeout); + } + + void WaitForDisconnected(unsigned int timeout = kDefaultTimeout) { + ASSERT_TRUE(p1_->ice_connected()); + ASSERT_TRUE(p2_->ice_connected()); + ASSERT_TRUE_WAIT(p1_->ice_connected() == 0 && p2_->ice_connected() == 0, + timeout); + } + + void WaitForFailed(unsigned int timeout = kDefaultTimeout) { + ASSERT_TRUE_WAIT(p1_->ice_failed() && p2_->ice_failed(), timeout); + } + + void ConnectTrickle(TrickleMode trickle = TRICKLE_SIMULATE) { + p2_->Connect(p1_.get(), trickle); + p1_->Connect(p2_.get(), trickle); + } + + void SimulateTrickle(size_t stream) { + p1_->SimulateTrickle(stream); + p2_->SimulateTrickle(stream); + ASSERT_TRUE_WAIT(p1_->is_ready(stream), kDefaultTimeout); + ASSERT_TRUE_WAIT(p2_->is_ready(stream), kDefaultTimeout); + } + + void SimulateTrickleP1(size_t stream) { p1_->SimulateTrickle(stream); } + + void SimulateTrickleP2(size_t stream) { p2_->SimulateTrickle(stream); } + + void CloseP1() { p1_->Close(); } + + void ConnectThenDelete() { + p2_->Connect(p1_.get(), TRICKLE_NONE, false); + p1_->Connect(p2_.get(), TRICKLE_NONE, true); + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &WebRtcIceConnectTest::CloseP1)); + p2_->StartChecks(); + + // Wait to see if we crash + PR_Sleep(PR_MillisecondsToInterval(kDefaultTimeout)); + } + + // default is p1_ sending to p2_ + void SendReceive() { SendReceive(p1_.get(), p2_.get()); } + + void SendReceive(IceTestPeer* p1, IceTestPeer* p2, + bool expect_tx_failure = false, + bool expect_rx_failure = false) { + size_t previousSent = p1->sent(); + size_t previousReceived = p2->received(); + + if (expect_tx_failure) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(p1, &IceTestPeer::SendFailure, 0, 1)); + ASSERT_EQ(previousSent, p1->sent()); + } else { + test_utils_->SyncDispatchToSTS( + WrapRunnable(p1, &IceTestPeer::SendPacket, 0, 1, + reinterpret_cast<const unsigned char*>("TEST"), 4)); + ASSERT_EQ(previousSent + 1, p1->sent()); + } + if (expect_rx_failure) { + usleep(1000); + ASSERT_EQ(previousReceived, p2->received()); + } else { + ASSERT_TRUE_WAIT(p2->received() == previousReceived + 1, 1000); + } + } + + void SendFailure() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(p1_.get(), &IceTestPeer::SendFailure, 0, 1)); + } + + protected: + bool initted_; + bool test_stun_server_inited_; + nsCOMPtr<nsIEventTarget> target_; + mozilla::UniquePtr<IceTestPeer> p1_; + mozilla::UniquePtr<IceTestPeer> p2_; + bool use_nat_; + TestNat::NatBehavior filtering_type_; + TestNat::NatBehavior mapping_type_; + bool block_udp_; +}; + +class WebRtcIcePrioritizerTest : public StunTest { + public: + WebRtcIcePrioritizerTest() : prioritizer_(nullptr) {} + + ~WebRtcIcePrioritizerTest() { + if (prioritizer_) { + nr_interface_prioritizer_destroy(&prioritizer_); + } + } + + void SetPriorizer(nr_interface_prioritizer* prioritizer) { + prioritizer_ = prioritizer; + } + + void AddInterface(const std::string& num, int type, int estimated_speed) { + std::string str_addr = "10.0.0." + num; + std::string ifname = "eth" + num; + nr_local_addr local_addr; + local_addr.interface.type = type; + local_addr.interface.estimated_speed = estimated_speed; + + int r = nr_str_port_to_transport_addr(str_addr.c_str(), 0, IPPROTO_UDP, + &(local_addr.addr)); + ASSERT_EQ(0, r); + strncpy(local_addr.addr.ifname, ifname.c_str(), MAXIFNAME - 1); + local_addr.addr.ifname[MAXIFNAME - 1] = '\0'; + + r = nr_interface_prioritizer_add_interface(prioritizer_, &local_addr); + ASSERT_EQ(0, r); + r = nr_interface_prioritizer_sort_preference(prioritizer_); + ASSERT_EQ(0, r); + } + + void HasLowerPreference(const std::string& num1, const std::string& num2) { + std::string key1 = "eth" + num1 + ":10.0.0." + num1; + std::string key2 = "eth" + num2 + ":10.0.0." + num2; + UCHAR pref1, pref2; + int r = nr_interface_prioritizer_get_priority(prioritizer_, key1.c_str(), + &pref1); + ASSERT_EQ(0, r); + r = nr_interface_prioritizer_get_priority(prioritizer_, key2.c_str(), + &pref2); + ASSERT_EQ(0, r); + ASSERT_LE(pref1, pref2); + } + + private: + nr_interface_prioritizer* prioritizer_; +}; + +class WebRtcIcePacketFilterTest : public StunTest { + public: + WebRtcIcePacketFilterTest() : udp_filter_(nullptr), tcp_filter_(nullptr) {} + + void SetUp() { + StunTest::SetUp(); + + NrIceCtx::InitializeGlobals(NrIceCtx::GlobalConfig()); + + // Set up enough of the ICE ctx to allow the packet filter to work + ice_ctx_ = NrIceCtx::Create("test"); + + nsCOMPtr<nsISocketFilterHandler> udp_handler = + do_GetService(NS_STUN_UDP_SOCKET_FILTER_HANDLER_CONTRACTID); + ASSERT_TRUE(udp_handler); + udp_handler->NewFilter(getter_AddRefs(udp_filter_)); + + nsCOMPtr<nsISocketFilterHandler> tcp_handler = + do_GetService(NS_STUN_TCP_SOCKET_FILTER_HANDLER_CONTRACTID); + ASSERT_TRUE(tcp_handler); + tcp_handler->NewFilter(getter_AddRefs(tcp_filter_)); + } + + void TearDown() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &WebRtcIcePacketFilterTest::TearDown_s)); + StunTest::TearDown(); + } + + void TearDown_s() { ice_ctx_ = nullptr; } + + void TestIncoming(const uint8_t* data, uint32_t len, uint8_t from_addr, + int from_port, bool expected_result) { + mozilla::net::NetAddr addr; + MakeNetAddr(&addr, from_addr, from_port); + bool result; + nsresult rv = udp_filter_->FilterPacket( + &addr, data, len, nsISocketFilter::SF_INCOMING, &result); + ASSERT_EQ(NS_OK, rv); + ASSERT_EQ(expected_result, result); + } + + void TestIncomingTcp(const uint8_t* data, uint32_t len, + bool expected_result) { + mozilla::net::NetAddr addr; + bool result; + nsresult rv = tcp_filter_->FilterPacket( + &addr, data, len, nsISocketFilter::SF_INCOMING, &result); + ASSERT_EQ(NS_OK, rv); + ASSERT_EQ(expected_result, result); + } + + void TestIncomingTcpFramed(const uint8_t* data, uint32_t len, + bool expected_result) { + mozilla::net::NetAddr addr; + bool result; + uint8_t* framed_data = new uint8_t[len + 2]; + framed_data[0] = htons(len); + memcpy(&framed_data[2], data, len); + nsresult rv = tcp_filter_->FilterPacket( + &addr, framed_data, len + 2, nsISocketFilter::SF_INCOMING, &result); + ASSERT_EQ(NS_OK, rv); + ASSERT_EQ(expected_result, result); + delete[] framed_data; + } + + void TestOutgoing(const uint8_t* data, uint32_t len, uint8_t to_addr, + int to_port, bool expected_result) { + mozilla::net::NetAddr addr; + MakeNetAddr(&addr, to_addr, to_port); + bool result; + nsresult rv = udp_filter_->FilterPacket( + &addr, data, len, nsISocketFilter::SF_OUTGOING, &result); + ASSERT_EQ(NS_OK, rv); + ASSERT_EQ(expected_result, result); + } + + void TestOutgoingTcp(const uint8_t* data, uint32_t len, + bool expected_result) { + mozilla::net::NetAddr addr; + bool result; + nsresult rv = tcp_filter_->FilterPacket( + &addr, data, len, nsISocketFilter::SF_OUTGOING, &result); + ASSERT_EQ(NS_OK, rv); + ASSERT_EQ(expected_result, result); + } + + void TestOutgoingTcpFramed(const uint8_t* data, uint32_t len, + bool expected_result) { + mozilla::net::NetAddr addr; + bool result; + uint8_t* framed_data = new uint8_t[len + 2]; + framed_data[0] = htons(len); + memcpy(&framed_data[2], data, len); + nsresult rv = tcp_filter_->FilterPacket( + &addr, framed_data, len + 2, nsISocketFilter::SF_OUTGOING, &result); + ASSERT_EQ(NS_OK, rv); + ASSERT_EQ(expected_result, result); + delete[] framed_data; + } + + private: + void MakeNetAddr(mozilla::net::NetAddr* net_addr, uint8_t last_digit, + uint16_t port) { + net_addr->inet.family = AF_INET; + net_addr->inet.ip = 192 << 24 | 168 << 16 | 1 << 8 | last_digit; + net_addr->inet.port = port; + } + + nsCOMPtr<nsISocketFilter> udp_filter_; + nsCOMPtr<nsISocketFilter> tcp_filter_; + RefPtr<NrIceCtx> ice_ctx_; +}; +} // end namespace + +TEST_F(WebRtcIceGatherTest, TestGatherFakeStunServerHostnameNoResolver) { + if (stun_server_hostname_.empty()) { + return; + } + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(stun_server_hostname_, kDefaultStunServerPort); + peer_->AddStream(1); + Gather(); +} + +// Disabled because google isn't running any TCP stun servers right now +TEST_F(WebRtcIceGatherTest, + DISABLED_TestGatherFakeStunServerTcpHostnameNoResolver) { + if (stun_server_hostname_.empty()) { + return; + } + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(stun_server_hostname_, kDefaultStunServerPort, + kNrIceTransportTcp); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " TCP ")); +} + +TEST_F(WebRtcIceGatherTest, TestGatherFakeStunServerIpAddress) { + if (stun_server_address_.empty()) { + return; + } + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(stun_server_address_, kDefaultStunServerPort); + peer_->SetFakeResolver(stun_server_address_, stun_server_hostname_); + peer_->AddStream(1); + Gather(); +} + +TEST_F(WebRtcIceGatherTest, TestGatherStunServerIpAddressNoHost) { + if (stun_server_address_.empty()) { + return; + } + + { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + } + + NrIceCtx::Config config; + config.mPolicy = NrIceCtx::ICE_POLICY_NO_HOST; + peer_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, config); + peer_->AddStream(1); + peer_->SetStunServer(stun_server_address_, kDefaultStunServerPort); + peer_->SetFakeResolver(stun_server_address_, stun_server_hostname_); + Gather(); + ASSERT_FALSE(StreamHasMatchingCandidate(0, " host ")); +} + +TEST_F(WebRtcIceGatherTest, TestGatherFakeStunServerHostname) { + if (stun_server_hostname_.empty()) { + return; + } + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(stun_server_hostname_, kDefaultStunServerPort); + peer_->SetFakeResolver(stun_server_address_, stun_server_hostname_); + peer_->AddStream(1); + Gather(); +} + +TEST_F(WebRtcIceGatherTest, TestGatherFakeStunBogusHostname) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(kBogusStunServerHostname, kDefaultStunServerPort); + peer_->SetFakeResolver(stun_server_address_, stun_server_hostname_); + peer_->AddStream(1); + Gather(); +} + +TEST_F(WebRtcIceGatherTest, TestGatherDNSStunServerIpAddress) { + if (stun_server_address_.empty()) { + return; + } + + { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + } + + // A srflx candidate is considered redundant and discarded if its address + // equals that of a host candidate. (Frequently, a srflx candidate and a host + // candidate have equal addresses when the agent is not behind a NAT.) So set + // ICE_POLICY_NO_HOST here to ensure that a srflx candidate is not falsely + // discarded in this test. + NrIceCtx::Config config; + config.mPolicy = NrIceCtx::ICE_POLICY_NO_HOST; + peer_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, config); + + peer_->SetStunServer(stun_server_address_, kDefaultStunServerPort); + peer_->SetDNSResolver(); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " UDP ")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "typ srflx raddr")); +} + +// Disabled because google isn't running any TCP stun servers right now +TEST_F(WebRtcIceGatherTest, DISABLED_TestGatherDNSStunServerIpAddressTcp) { + if (stun_server_address_.empty()) { + return; + } + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(stun_server_address_, kDefaultStunServerPort, + kNrIceTransportTcp); + peer_->SetDNSResolver(); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "tcptype passive")); + ASSERT_FALSE(StreamHasMatchingCandidate(0, "tcptype passive", " 9 ")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "tcptype so")); + ASSERT_FALSE(StreamHasMatchingCandidate(0, "tcptype so", " 9 ")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "tcptype active", " 9 ")); +} + +TEST_F(WebRtcIceGatherTest, TestGatherDNSStunServerHostname) { + if (stun_server_hostname_.empty()) { + return; + } + + { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + } + + // A srflx candidate is considered redundant and discarded if its address + // equals that of a host candidate. (Frequently, a srflx candidate and a host + // candidate have equal addresses when the agent is not behind a NAT.) So set + // ICE_POLICY_NO_HOST here to ensure that a srflx candidate is not falsely + // discarded in this test. + NrIceCtx::Config config; + config.mPolicy = NrIceCtx::ICE_POLICY_NO_HOST; + peer_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, config); + + peer_->SetStunServer(stun_server_hostname_, kDefaultStunServerPort); + peer_->SetDNSResolver(); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " UDP ")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "typ srflx raddr")); +} + +// Disabled because google isn't running any TCP stun servers right now +TEST_F(WebRtcIceGatherTest, DISABLED_TestGatherDNSStunServerHostnameTcp) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(stun_server_hostname_, kDefaultStunServerPort, + kNrIceTransportTcp); + peer_->SetDNSResolver(); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "tcptype passive")); + ASSERT_FALSE(StreamHasMatchingCandidate(0, "tcptype passive", " 9 ")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "tcptype so")); + ASSERT_FALSE(StreamHasMatchingCandidate(0, "tcptype so", " 9 ")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "tcptype active", " 9 ")); +} + +// Disabled because google isn't running any TCP stun servers right now +TEST_F(WebRtcIceGatherTest, + DISABLED_TestGatherDNSStunServerHostnameBothUdpTcp) { + if (stun_server_hostname_.empty()) { + return; + } + + std::vector<NrIceStunServer> stun_servers; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + stun_servers.push_back(*NrIceStunServer::Create( + stun_server_hostname_, kDefaultStunServerPort, kNrIceTransportUdp)); + stun_servers.push_back(*NrIceStunServer::Create( + stun_server_hostname_, kDefaultStunServerPort, kNrIceTransportTcp)); + peer_->SetStunServers(stun_servers); + peer_->SetDNSResolver(); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " UDP ")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " TCP ")); +} + +// Disabled because google isn't running any TCP stun servers right now +TEST_F(WebRtcIceGatherTest, + DISABLED_TestGatherDNSStunServerIpAddressBothUdpTcp) { + if (stun_server_address_.empty()) { + return; + } + + std::vector<NrIceStunServer> stun_servers; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + stun_servers.push_back(*NrIceStunServer::Create( + stun_server_address_, kDefaultStunServerPort, kNrIceTransportUdp)); + stun_servers.push_back(*NrIceStunServer::Create( + stun_server_address_, kDefaultStunServerPort, kNrIceTransportTcp)); + peer_->SetStunServers(stun_servers); + peer_->SetDNSResolver(); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " UDP ")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " TCP ")); +} + +TEST_F(WebRtcIceGatherTest, TestGatherDNSStunBogusHostname) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(kBogusStunServerHostname, kDefaultStunServerPort); + peer_->SetDNSResolver(); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " UDP ")); +} + +// Disabled because google isn't running any TCP stun servers right now +TEST_F(WebRtcIceGatherTest, DISABLED_TestGatherDNSStunBogusHostnameTcp) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(kBogusStunServerHostname, kDefaultStunServerPort, + kNrIceTransportTcp); + peer_->SetDNSResolver(); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " TCP ")); +} + +TEST_F(WebRtcIceGatherTest, TestDefaultCandidate) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(stun_server_hostname_, kDefaultStunServerPort); + peer_->AddStream(1); + Gather(); + NrIceCandidate default_candidate; + ASSERT_TRUE(NS_SUCCEEDED(peer_->GetDefaultCandidate(0, &default_candidate))); +} + +TEST_F(WebRtcIceGatherTest, TestGatherTurn) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + if (turn_server_.empty()) return; + peer_->SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_, kNrIceTransportUdp); + peer_->AddStream(1); + Gather(); +} + +TEST_F(WebRtcIceGatherTest, TestGatherTurnTcp) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + if (turn_server_.empty()) return; + peer_->SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_, kNrIceTransportTcp); + peer_->AddStream(1); + Gather(); +} + +TEST_F(WebRtcIceGatherTest, TestGatherDisableComponent) { + if (stun_server_hostname_.empty()) { + return; + } + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->SetStunServer(stun_server_hostname_, kDefaultStunServerPort); + peer_->AddStream(1); + peer_->AddStream(2); + peer_->DisableComponent(1, 2); + Gather(); + std::vector<std::string> attributes = peer_->GetAttributes(1); + + for (auto& attribute : attributes) { + if (attribute.find("candidate:") != std::string::npos) { + size_t sp1 = attribute.find(' '); + ASSERT_EQ(0, attribute.compare(sp1 + 1, 1, "1", 1)); + } + } +} + +TEST_F(WebRtcIceGatherTest, TestGatherVerifyNoLoopback) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->AddStream(1); + Gather(); + ASSERT_FALSE(StreamHasMatchingCandidate(0, "127.0.0.1")); +} + +TEST_F(WebRtcIceGatherTest, TestGatherAllowLoopback) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + config.mAllowLoopback = true; + NrIceCtx::InitializeGlobals(config); + + // Set up peer with loopback allowed. + peer_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, NrIceCtx::Config()); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "127.0.0.1")); +} + +TEST_F(WebRtcIceGatherTest, TestGatherTcpDisabledNoStun) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->AddStream(1); + Gather(); + ASSERT_FALSE(StreamHasMatchingCandidate(0, " TCP ")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " UDP ")); +} + +TEST_F(WebRtcIceGatherTest, VerifyTestStunServer) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("192.0.2.133", 3333); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " 192.0.2.133 3333 ")); +} + +TEST_F(WebRtcIceGatherTest, VerifyTestStunTcpServer) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + UseFakeStunTcpServerWithResponse("192.0.2.233", 3333); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " 192.0.2.233 3333 typ srflx", + " tcptype ")); +} + +TEST_F(WebRtcIceGatherTest, VerifyTestStunServerV6) { + if (!TestStunServer::GetInstance(AF_INET6)) { + // No V6 addresses + return; + } + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("beef::", 3333); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " beef:: 3333 ")); +} + +TEST_F(WebRtcIceGatherTest, VerifyTestStunServerFQDN) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("192.0.2.133", 3333, "stun.example.com"); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " 192.0.2.133 3333 ")); +} + +TEST_F(WebRtcIceGatherTest, VerifyTestStunServerV6FQDN) { + if (!TestStunServer::GetInstance(AF_INET6)) { + // No V6 addresses + return; + } + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("beef::", 3333, "stun.example.com"); + peer_->AddStream(1); + Gather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " beef:: 3333 ")); +} + +TEST_F(WebRtcIceGatherTest, TestStunServerReturnsWildcardAddr) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("0.0.0.0", 3333); + peer_->AddStream(1); + Gather(kDefaultTimeout * 3); + ASSERT_FALSE(StreamHasMatchingCandidate(0, " 0.0.0.0 ")); +} + +TEST_F(WebRtcIceGatherTest, TestStunServerReturnsWildcardAddrV6) { + if (!TestStunServer::GetInstance(AF_INET6)) { + // No V6 addresses + return; + } + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("::", 3333); + peer_->AddStream(1); + Gather(kDefaultTimeout * 3); + ASSERT_FALSE(StreamHasMatchingCandidate(0, " :: ")); +} + +TEST_F(WebRtcIceGatherTest, TestStunServerReturnsPort0) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("192.0.2.133", 0); + peer_->AddStream(1); + Gather(kDefaultTimeout * 3); + ASSERT_FALSE(StreamHasMatchingCandidate(0, " 192.0.2.133 0 ")); +} + +TEST_F(WebRtcIceGatherTest, TestStunServerReturnsLoopbackAddr) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("127.0.0.133", 3333); + peer_->AddStream(1); + Gather(kDefaultTimeout * 3); + ASSERT_FALSE(StreamHasMatchingCandidate(0, " 127.0.0.133 ")); +} + +TEST_F(WebRtcIceGatherTest, TestStunServerReturnsLoopbackAddrV6) { + if (!TestStunServer::GetInstance(AF_INET6)) { + // No V6 addresses + return; + } + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("::1", 3333); + peer_->AddStream(1); + Gather(kDefaultTimeout * 3); + ASSERT_FALSE(StreamHasMatchingCandidate(0, " ::1 ")); +} + +TEST_F(WebRtcIceGatherTest, TestStunServerTrickle) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpServerWithResponse("192.0.2.1", 3333); + peer_->AddStream(1); + TestStunServer::GetInstance(AF_INET)->SetDropInitialPackets(3); + Gather(0); + ASSERT_FALSE(StreamHasMatchingCandidate(0, "192.0.2.1")); + WaitForGather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "192.0.2.1")); +} + +// Test no host with our fake STUN server and apparently NATted. +TEST_F(WebRtcIceGatherTest, TestFakeStunServerNatedNoHost) { + { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + } + + NrIceCtx::Config config; + config.mPolicy = NrIceCtx::ICE_POLICY_NO_HOST; + peer_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, config); + UseFakeStunUdpServerWithResponse("192.0.2.1", 3333); + peer_->AddStream(1); + Gather(0); + WaitForGather(); + DumpAttributes(0); + ASSERT_FALSE(StreamHasMatchingCandidate(0, "host")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "srflx")); + NrIceCandidate default_candidate; + nsresult rv = peer_->GetDefaultCandidate(0, &default_candidate); + if (NS_SUCCEEDED(rv)) { + ASSERT_NE(NrIceCandidate::ICE_HOST, default_candidate.type); + } +} + +// Test no host with our fake STUN server and apparently non-NATted. +TEST_F(WebRtcIceGatherTest, TestFakeStunServerNoNatNoHost) { + { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + } + + NrIceCtx::Config config; + config.mPolicy = NrIceCtx::ICE_POLICY_NO_HOST; + peer_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, config); + UseTestStunServer(); + peer_->AddStream(1); + Gather(0); + WaitForGather(); + DumpAttributes(0); + ASSERT_FALSE(StreamHasMatchingCandidate(0, "host")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "srflx")); +} + +// Test that srflx candidate is discarded in non-NATted environment if host +// address obfuscation is not enabled. +TEST_F(WebRtcIceGatherTest, + TestSrflxCandidateDiscardedWithObfuscateHostAddressesNotEnabled) { + { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + } + + NrIceCtx::Config config; + peer_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, config); + UseTestStunServer(); + peer_->AddStream(1); + Gather(0, false, false); + WaitForGather(); + DumpAttributes(0); + EXPECT_TRUE(StreamHasMatchingCandidate(0, "host")); + EXPECT_FALSE(StreamHasMatchingCandidate(0, "srflx")); +} + +// Test that srflx candidate is generated in non-NATted environment if host +// address obfuscation is enabled. +TEST_F(WebRtcIceGatherTest, + TestSrflxCandidateGeneratedWithObfuscateHostAddressesEnabled) { + { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + } + + NrIceCtx::Config config; + peer_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, config); + UseTestStunServer(); + peer_->AddStream(1); + Gather(0, false, true); + WaitForGather(); + DumpAttributes(0); + EXPECT_TRUE(StreamHasMatchingCandidate(0, "host")); + EXPECT_TRUE(StreamHasMatchingCandidate(0, "srflx")); +} + +TEST_F(WebRtcIceGatherTest, TestStunTcpServerTrickle) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + UseFakeStunTcpServerWithResponse("192.0.3.1", 3333); + TestStunTcpServer::GetInstance(AF_INET)->SetDelay(500); + peer_->AddStream(1); + Gather(0); + ASSERT_FALSE(StreamHasMatchingCandidate(0, " 192.0.3.1 ", " tcptype ")); + WaitForGather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " 192.0.3.1 ", " tcptype ")); +} + +TEST_F(WebRtcIceGatherTest, TestStunTcpAndUdpServerTrickle) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + UseFakeStunUdpTcpServersWithResponse("192.0.2.1", 3333, "192.0.3.1", 3333); + TestStunServer::GetInstance(AF_INET)->SetDropInitialPackets(3); + TestStunTcpServer::GetInstance(AF_INET)->SetDelay(500); + peer_->AddStream(1); + Gather(0); + ASSERT_FALSE(StreamHasMatchingCandidate(0, "192.0.2.1", "UDP")); + ASSERT_FALSE(StreamHasMatchingCandidate(0, " 192.0.3.1 ", " tcptype ")); + WaitForGather(); + ASSERT_TRUE(StreamHasMatchingCandidate(0, "192.0.2.1", "UDP")); + ASSERT_TRUE(StreamHasMatchingCandidate(0, " 192.0.3.1 ", " tcptype ")); +} + +TEST_F(WebRtcIceGatherTest, TestSetIceControlling) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->AddStream(1); + peer_->SetControlling(NrIceCtx::ICE_CONTROLLING); + NrIceCtx::Controlling controlling = peer_->GetControlling(); + ASSERT_EQ(NrIceCtx::ICE_CONTROLLING, controlling); + // SetControlling should only allow setting this once + peer_->SetControlling(NrIceCtx::ICE_CONTROLLED); + controlling = peer_->GetControlling(); + ASSERT_EQ(NrIceCtx::ICE_CONTROLLING, controlling); +} + +TEST_F(WebRtcIceGatherTest, TestSetIceControlled) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + EnsurePeer(); + peer_->AddStream(1); + peer_->SetControlling(NrIceCtx::ICE_CONTROLLED); + NrIceCtx::Controlling controlling = peer_->GetControlling(); + ASSERT_EQ(NrIceCtx::ICE_CONTROLLED, controlling); + // SetControlling should only allow setting this once + peer_->SetControlling(NrIceCtx::ICE_CONTROLLING); + controlling = peer_->GetControlling(); + ASSERT_EQ(NrIceCtx::ICE_CONTROLLED, controlling); +} + +TEST_F(WebRtcIceConnectTest, TestGather) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); +} + +TEST_F(WebRtcIceConnectTest, TestGatherTcp) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + Init(); + AddStream(1); + ASSERT_TRUE(Gather()); +} + +TEST_F(WebRtcIceConnectTest, TestGatherAutoPrioritize) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + Init(); + AddStream(1); + ASSERT_TRUE(Gather()); +} + +TEST_F(WebRtcIceConnectTest, TestConnect) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectRestartIce) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); + SendReceive(p1_.get(), p2_.get()); + + p2_->RestartIce(); + ASSERT_FALSE(p2_->gathering_complete()); + + // verify p1 and p2 streams are still connected after restarting ice on p2 + SendReceive(p1_.get(), p2_.get()); + + mozilla::UniquePtr<IceTestPeer> p3_; + p3_ = MakeUnique<IceTestPeer>("P3", test_utils_, true, NrIceCtx::Config()); + InitPeer(p3_.get()); + p3_->AddStream(1); + + ASSERT_TRUE(GatherCallerAndCallee(p2_.get(), p3_.get())); + std::cout << "-------------------------------------------------" << std::endl; + ConnectCallerAndCallee(p3_.get(), p2_.get(), TRICKLE_SIMULATE); + SendReceive(p1_.get(), p2_.get()); // p1 and p2 are still connected + SendReceive(p3_.get(), p2_.get(), true, true); // p3 and p2 not yet connected + p2_->SimulateTrickle(0); + p3_->SimulateTrickle(0); + ASSERT_TRUE_WAIT(p3_->is_ready(0), kDefaultTimeout); + ASSERT_TRUE_WAIT(p2_->is_ready(0), kDefaultTimeout); + SendReceive(p1_.get(), p2_.get(), false, true); // p1 and p2 not connected + SendReceive(p3_.get(), p2_.get()); // p3 and p2 are now connected + + p3_ = nullptr; +} + +TEST_F(WebRtcIceConnectTest, TestConnectRestartIceThenAbort) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); + SendReceive(p1_.get(), p2_.get()); + + p2_->RestartIce(); + ASSERT_FALSE(p2_->gathering_complete()); + + // verify p1 and p2 streams are still connected after restarting ice on p2 + SendReceive(p1_.get(), p2_.get()); + + mozilla::UniquePtr<IceTestPeer> p3_; + p3_ = MakeUnique<IceTestPeer>("P3", test_utils_, true, NrIceCtx::Config()); + InitPeer(p3_.get()); + p3_->AddStream(1); + + ASSERT_TRUE(GatherCallerAndCallee(p2_.get(), p3_.get())); + std::cout << "-------------------------------------------------" << std::endl; + p2_->RollbackIceRestart(); + p2_->Connect(p1_.get(), TRICKLE_NONE); + SendReceive(p1_.get(), p2_.get()); + p3_ = nullptr; +} + +TEST_F(WebRtcIceConnectTest, TestConnectIceRestartRoleConflict) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + // Just for fun lets do this with switched rolls + p1_->SetControlling(NrIceCtx::ICE_CONTROLLED); + p2_->SetControlling(NrIceCtx::ICE_CONTROLLING); + Connect(); + SendReceive(p1_.get(), p2_.get()); + // Set rolls should not switch by connecting + ASSERT_EQ(NrIceCtx::ICE_CONTROLLED, p1_->GetControlling()); + ASSERT_EQ(NrIceCtx::ICE_CONTROLLING, p2_->GetControlling()); + + p2_->RestartIce(); + ASSERT_FALSE(p2_->gathering_complete()); + p2_->SetControlling(NrIceCtx::ICE_CONTROLLED); + ASSERT_EQ(NrIceCtx::ICE_CONTROLLING, p2_->GetControlling()) + << "ICE restart should not allow role to change, unless ice-lite happens"; + + mozilla::UniquePtr<IceTestPeer> p3_; + p3_ = MakeUnique<IceTestPeer>("P3", test_utils_, true, NrIceCtx::Config()); + InitPeer(p3_.get()); + p3_->AddStream(1); + // Set control role for p3 accordingly (with role conflict) + p3_->SetControlling(NrIceCtx::ICE_CONTROLLING); + ASSERT_EQ(NrIceCtx::ICE_CONTROLLING, p3_->GetControlling()); + + ASSERT_TRUE(GatherCallerAndCallee(p2_.get(), p3_.get())); + std::cout << "-------------------------------------------------" << std::endl; + ConnectCallerAndCallee(p3_.get(), p2_.get()); + auto p2role = p2_->GetControlling(); + ASSERT_NE(p2role, p3_->GetControlling()) << "Conflict should be resolved"; + ASSERT_EQ(NrIceCtx::ICE_CONTROLLED, p1_->GetControlling()) + << "P1 should be unaffected by role conflict"; + + // And again we are not allowed to switch roles at this point any more + p1_->SetControlling(NrIceCtx::ICE_CONTROLLING); + ASSERT_EQ(NrIceCtx::ICE_CONTROLLED, p1_->GetControlling()); + p3_->SetControlling(p2role); + ASSERT_NE(p2role, p3_->GetControlling()); + + p3_ = nullptr; +} + +TEST_F(WebRtcIceConnectTest, + TestIceRestartWithMultipleInterfacesAndUserStartingScreenSharing) { + const char* FAKE_WIFI_ADDR = "10.0.0.1"; + const char* FAKE_WIFI_IF_NAME = "wlan9"; + + // prepare a fake wifi interface + nr_local_addr wifi_addr; + wifi_addr.interface.type = NR_INTERFACE_TYPE_WIFI; + wifi_addr.interface.estimated_speed = 1000; + + int r = nr_str_port_to_transport_addr(FAKE_WIFI_ADDR, 0, IPPROTO_UDP, + &(wifi_addr.addr)); + ASSERT_EQ(0, r); + strncpy(wifi_addr.addr.ifname, FAKE_WIFI_IF_NAME, MAXIFNAME); + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + // setup initial ICE connection between p1_ and p2_ + UseNat(); + AddStream(1); + SetExpectedTypes(NrIceCandidate::Type::ICE_SERVER_REFLEXIVE, + NrIceCandidate::Type::ICE_SERVER_REFLEXIVE); + ASSERT_TRUE(Gather(kDefaultTimeout, true)); + Connect(); + + // verify the connection is working + SendReceive(p1_.get(), p2_.get()); + + // simulate user accepting permissions for screen sharing + p2_->SetCtxFlags(false); + + // and having an additional non-default interface + nsTArray<NrIceStunAddr> stunAddr = p2_->GetStunAddrs(); + stunAddr.InsertElementAt(0, NrIceStunAddr(&wifi_addr)); + p2_->SetStunAddrs(stunAddr); + + std::cout << "-------------------------------------------------" << std::endl; + + // now restart ICE + p2_->RestartIce(); + ASSERT_FALSE(p2_->gathering_complete()); + + // verify that we can successfully gather candidates + p2_->Gather(); + EXPECT_TRUE_WAIT(p2_->gathering_complete(), kDefaultTimeout); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTcp) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + Init(); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsTcpCandidate); + SetExpectedTypes(NrIceCandidate::Type::ICE_HOST, + NrIceCandidate::Type::ICE_HOST, kNrIceTransportTcp); + Connect(); +} + +// TCP SO tests works on localhost only with delay applied: +// tc qdisc add dev lo root netem delay 10ms +TEST_F(WebRtcIceConnectTest, DISABLED_TestConnectTcpSo) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + Init(); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsTcpSoCandidate); + SetExpectedTypes(NrIceCandidate::Type::ICE_HOST, + NrIceCandidate::Type::ICE_HOST, kNrIceTransportTcp); + Connect(); +} + +// Disabled because this breaks with hairpinning. +TEST_F(WebRtcIceConnectTest, DISABLED_TestConnectNoHost) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + Init(false, NrIceCtx::ICE_POLICY_NO_HOST); + AddStream(1); + ASSERT_TRUE(Gather()); + SetExpectedTypes(NrIceCandidate::Type::ICE_SERVER_REFLEXIVE, + NrIceCandidate::Type::ICE_SERVER_REFLEXIVE, + kNrIceTransportTcp); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestLoopbackOnlySortOf) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + config.mAllowLoopback = true; + NrIceCtx::InitializeGlobals(config); + Init(false); + AddStream(1); + SetCandidateFilter(IsLoopbackCandidate); + ASSERT_TRUE(Gather()); + SetExpectedRemoteCandidateAddr("127.0.0.1"); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectBothControllingP1Wins) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + p1_->SetTiebreaker(1); + p2_->SetTiebreaker(0); + ASSERT_TRUE(Gather()); + p1_->SetControlling(NrIceCtx::ICE_CONTROLLING); + p2_->SetControlling(NrIceCtx::ICE_CONTROLLING); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectBothControllingP2Wins) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + p1_->SetTiebreaker(0); + p2_->SetTiebreaker(1); + ASSERT_TRUE(Gather()); + p1_->SetControlling(NrIceCtx::ICE_CONTROLLING); + p2_->SetControlling(NrIceCtx::ICE_CONTROLLING); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectIceLiteOfferer) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + p1_->SimulateIceLite(); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestTrickleBothControllingP1Wins) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + p1_->SetTiebreaker(1); + p2_->SetTiebreaker(0); + ASSERT_TRUE(Gather()); + p1_->SetControlling(NrIceCtx::ICE_CONTROLLING); + p2_->SetControlling(NrIceCtx::ICE_CONTROLLING); + ConnectTrickle(); + SimulateTrickle(0); + WaitForConnected(1000); + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, TestTrickleBothControllingP2Wins) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + p1_->SetTiebreaker(0); + p2_->SetTiebreaker(1); + ASSERT_TRUE(Gather()); + p1_->SetControlling(NrIceCtx::ICE_CONTROLLING); + p2_->SetControlling(NrIceCtx::ICE_CONTROLLING); + ConnectTrickle(); + SimulateTrickle(0); + WaitForConnected(1000); + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, TestTrickleIceLiteOfferer) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + p1_->SimulateIceLite(); + ConnectTrickle(); + SimulateTrickle(0); + WaitForConnected(1000); + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, TestGatherFullCone) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + AddStream(1); + ASSERT_TRUE(Gather()); +} + +TEST_F(WebRtcIceConnectTest, TestGatherFullConeAutoPrioritize) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + Init(); + AddStream(1); + ASSERT_TRUE(Gather()); +} + +TEST_F(WebRtcIceConnectTest, TestConnectFullCone) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + AddStream(1); + SetExpectedTypes(NrIceCandidate::Type::ICE_SERVER_REFLEXIVE, + NrIceCandidate::Type::ICE_SERVER_REFLEXIVE); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectNoNatNoHost) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + Init(false, NrIceCtx::ICE_POLICY_NO_HOST); + UseTestStunServer(); + // Because we are connecting from our host candidate to the + // other side's apparent srflx (which is also their host) + // we see a host/srflx pair. + SetExpectedTypes(NrIceCandidate::Type::ICE_HOST, + NrIceCandidate::Type::ICE_SERVER_REFLEXIVE); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectFullConeNoHost) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + Init(false, NrIceCtx::ICE_POLICY_NO_HOST); + UseTestStunServer(); + SetExpectedTypes(NrIceCandidate::Type::ICE_SERVER_REFLEXIVE, + NrIceCandidate::Type::ICE_SERVER_REFLEXIVE); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestGatherAddressRestrictedCone) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + SetFilteringType(TestNat::ADDRESS_DEPENDENT); + SetMappingType(TestNat::ENDPOINT_INDEPENDENT); + AddStream(1); + ASSERT_TRUE(Gather()); +} + +TEST_F(WebRtcIceConnectTest, TestConnectAddressRestrictedCone) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + SetFilteringType(TestNat::ADDRESS_DEPENDENT); + SetMappingType(TestNat::ENDPOINT_INDEPENDENT); + AddStream(1); + SetExpectedTypes(NrIceCandidate::Type::ICE_SERVER_REFLEXIVE, + NrIceCandidate::Type::ICE_SERVER_REFLEXIVE); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestGatherPortRestrictedCone) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + SetFilteringType(TestNat::PORT_DEPENDENT); + SetMappingType(TestNat::ENDPOINT_INDEPENDENT); + AddStream(1); + ASSERT_TRUE(Gather()); +} + +TEST_F(WebRtcIceConnectTest, TestConnectPortRestrictedCone) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + SetFilteringType(TestNat::PORT_DEPENDENT); + SetMappingType(TestNat::ENDPOINT_INDEPENDENT); + AddStream(1); + SetExpectedTypes(NrIceCandidate::Type::ICE_SERVER_REFLEXIVE, + NrIceCandidate::Type::ICE_SERVER_REFLEXIVE); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestGatherSymmetricNat) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + SetFilteringType(TestNat::PORT_DEPENDENT); + SetMappingType(TestNat::PORT_DEPENDENT); + AddStream(1); + ASSERT_TRUE(Gather()); +} + +TEST_F(WebRtcIceConnectTest, TestConnectSymmetricNat) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + SetFilteringType(TestNat::PORT_DEPENDENT); + SetMappingType(TestNat::PORT_DEPENDENT); + p1_->SetExpectedTypes(NrIceCandidate::Type::ICE_RELAYED, + NrIceCandidate::Type::ICE_RELAYED); + p2_->SetExpectedTypes(NrIceCandidate::Type::ICE_RELAYED, + NrIceCandidate::Type::ICE_RELAYED); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectSymmetricNatAndNoNat) { + { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + } + + NrIceCtx::Config config; + p1_ = MakeUnique<IceTestPeer>("P1", test_utils_, true, config); + p1_->UseNat(); + p1_->SetFilteringType(TestNat::PORT_DEPENDENT); + p1_->SetMappingType(TestNat::PORT_DEPENDENT); + + p2_ = MakeUnique<IceTestPeer>("P2", test_utils_, false, config); + initted_ = true; + + AddStream(1); + p1_->SetExpectedTypes(NrIceCandidate::Type::ICE_PEER_REFLEXIVE, + NrIceCandidate::Type::ICE_HOST); + p2_->SetExpectedTypes(NrIceCandidate::Type::ICE_HOST, + NrIceCandidate::Type::ICE_PEER_REFLEXIVE); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestGatherNatBlocksUDP) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + BlockUdp(); + std::vector<NrIceTurnServer> turn_servers; + std::vector<unsigned char> password_vec(turn_password_.begin(), + turn_password_.end()); + turn_servers.push_back( + *NrIceTurnServer::Create(turn_server_, kDefaultStunServerPort, turn_user_, + password_vec, kNrIceTransportTcp)); + turn_servers.push_back( + *NrIceTurnServer::Create(turn_server_, kDefaultStunServerPort, turn_user_, + password_vec, kNrIceTransportUdp)); + SetTurnServers(turn_servers); + AddStream(1); + // We have to wait for the UDP-based stuff to time out. + ASSERT_TRUE(Gather(kDefaultTimeout * 3)); +} + +TEST_F(WebRtcIceConnectTest, TestConnectNatBlocksUDP) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + UseNat(); + BlockUdp(); + std::vector<NrIceTurnServer> turn_servers; + std::vector<unsigned char> password_vec(turn_password_.begin(), + turn_password_.end()); + turn_servers.push_back( + *NrIceTurnServer::Create(turn_server_, kDefaultStunServerPort, turn_user_, + password_vec, kNrIceTransportTcp)); + turn_servers.push_back( + *NrIceTurnServer::Create(turn_server_, kDefaultStunServerPort, turn_user_, + password_vec, kNrIceTransportUdp)); + SetTurnServers(turn_servers); + p1_->SetExpectedTypes(NrIceCandidate::Type::ICE_RELAYED, + NrIceCandidate::Type::ICE_RELAYED, kNrIceTransportTcp); + p2_->SetExpectedTypes(NrIceCandidate::Type::ICE_RELAYED, + NrIceCandidate::Type::ICE_RELAYED, kNrIceTransportTcp); + AddStream(1); + ASSERT_TRUE(Gather(kDefaultTimeout * 3)); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTwoComponents) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(2); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTwoComponentsDisableSecond) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(2); + ASSERT_TRUE(Gather()); + p1_->DisableComponent(0, 2); + p2_->DisableComponent(0, 2); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectP2ThenP1) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectP2(); + PR_Sleep(1000); + ConnectP1(); + WaitForConnectedStreams(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectP2ThenP1Trickle) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectP2(); + PR_Sleep(1000); + ConnectP1(TRICKLE_SIMULATE); + SimulateTrickleP1(0); + WaitForConnectedStreams(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectP2ThenP1TrickleTwoComponents) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + AddStream(2); + ASSERT_TRUE(Gather()); + ConnectP2(); + PR_Sleep(1000); + ConnectP1(TRICKLE_SIMULATE); + SimulateTrickleP1(0); + std::cerr << "Sleeping between trickle streams" << std::endl; + PR_Sleep(1000); // Give this some time to settle but not complete + // all of ICE. + SimulateTrickleP1(1); + WaitForConnectedStreams(2); +} + +TEST_F(WebRtcIceConnectTest, TestConnectAutoPrioritize) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + Init(); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTrickleOneStreamOneComponent) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + SimulateTrickle(0); + WaitForConnected(1000); + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTrickleTwoStreamsOneComponent) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + SimulateTrickle(0); + SimulateTrickle(1); + WaitForConnected(1000); + AssertCheckingReached(); +} + +void RealisticTrickleDelay( + std::vector<SchedulableTrickleCandidate*>& candidates) { + for (size_t i = 0; i < candidates.size(); ++i) { + SchedulableTrickleCandidate* cand = candidates[i]; + if (cand->IsHost()) { + cand->Schedule(i * 10); + } else if (cand->IsReflexive()) { + cand->Schedule(i * 10 + 100); + } else if (cand->IsRelay()) { + cand->Schedule(i * 10 + 200); + } + } +} + +void DelayRelayCandidates(std::vector<SchedulableTrickleCandidate*>& candidates, + unsigned int ms) { + for (auto& candidate : candidates) { + if (candidate->IsRelay()) { + candidate->Schedule(ms); + } else { + candidate->Schedule(0); + } + } +} + +void AddNonPairableCandidates( + std::vector<SchedulableTrickleCandidate*>& candidates, IceTestPeer* peer, + size_t stream, int net_type, MtransportTestUtils* test_utils_) { + for (int i = 1; i < 5; i++) { + if (net_type == i) continue; + switch (i) { + case 1: + candidates.push_back(new SchedulableTrickleCandidate( + peer, stream, + "candidate:0 1 UDP 2113601790 10.0.0.1 12345 typ host", "", + test_utils_)); + break; + case 2: + candidates.push_back(new SchedulableTrickleCandidate( + peer, stream, + "candidate:0 1 UDP 2113601791 172.16.1.1 12345 typ host", "", + test_utils_)); + break; + case 3: + candidates.push_back(new SchedulableTrickleCandidate( + peer, stream, + "candidate:0 1 UDP 2113601792 192.168.0.1 12345 typ host", "", + test_utils_)); + break; + case 4: + candidates.push_back(new SchedulableTrickleCandidate( + peer, stream, + "candidate:0 1 UDP 2113601793 100.64.1.1 12345 typ host", "", + test_utils_)); + break; + default: + NR_UNIMPLEMENTED; + } + } + + for (auto i = candidates.rbegin(); i != candidates.rend(); ++i) { + std::cerr << "Scheduling candidate: " << (*i)->Candidate().c_str() + << std::endl; + (*i)->Schedule(0); + } +} + +void DropTrickleCandidates( + std::vector<SchedulableTrickleCandidate*>& candidates) {} + +TEST_F(WebRtcIceConnectTest, TestConnectTrickleAddStreamDuringICE) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(0)); + RealisticTrickleDelay(p2_->ControlTrickle(0)); + AddStream(1); + RealisticTrickleDelay(p1_->ControlTrickle(1)); + RealisticTrickleDelay(p2_->ControlTrickle(1)); + WaitForConnected(1000); + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTrickleAddStreamAfterICE) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(0)); + RealisticTrickleDelay(p2_->ControlTrickle(0)); + WaitForConnected(1000); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(1)); + RealisticTrickleDelay(p2_->ControlTrickle(1)); + WaitForConnected(1000); + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, RemoveStream) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(0)); + RealisticTrickleDelay(p2_->ControlTrickle(0)); + RealisticTrickleDelay(p1_->ControlTrickle(1)); + RealisticTrickleDelay(p2_->ControlTrickle(1)); + WaitForConnected(1000); + + RemoveStream(0); + ASSERT_TRUE(Gather()); + ConnectTrickle(); +} + +TEST_F(WebRtcIceConnectTest, P1NoTrickle) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + DropTrickleCandidates(p1_->ControlTrickle(0)); + RealisticTrickleDelay(p2_->ControlTrickle(0)); + WaitForConnected(1000); +} + +TEST_F(WebRtcIceConnectTest, P2NoTrickle) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(0)); + DropTrickleCandidates(p2_->ControlTrickle(0)); + WaitForConnected(1000); +} + +TEST_F(WebRtcIceConnectTest, RemoveAndAddStream) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(0)); + RealisticTrickleDelay(p2_->ControlTrickle(0)); + RealisticTrickleDelay(p1_->ControlTrickle(1)); + RealisticTrickleDelay(p2_->ControlTrickle(1)); + WaitForConnected(1000); + + RemoveStream(0); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(2)); + RealisticTrickleDelay(p2_->ControlTrickle(2)); + WaitForConnected(1000); +} + +TEST_F(WebRtcIceConnectTest, RemoveStreamBeforeGather) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + AddStream(1); + ASSERT_TRUE(Gather(0)); + RemoveStream(0); + WaitForGather(); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(1)); + RealisticTrickleDelay(p2_->ControlTrickle(1)); + WaitForConnected(1000); +} + +TEST_F(WebRtcIceConnectTest, RemoveStreamDuringGather) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + AddStream(1); + RemoveStream(0); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(1)); + RealisticTrickleDelay(p2_->ControlTrickle(1)); + WaitForConnected(1000); +} + +TEST_F(WebRtcIceConnectTest, RemoveStreamDuringConnect) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(0)); + RealisticTrickleDelay(p2_->ControlTrickle(0)); + RealisticTrickleDelay(p1_->ControlTrickle(1)); + RealisticTrickleDelay(p2_->ControlTrickle(1)); + RemoveStream(0); + WaitForConnected(1000); +} + +TEST_F(WebRtcIceConnectTest, TestConnectRealTrickleOneStreamOneComponent) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + AddStream(1); + ASSERT_TRUE(Gather(0)); + ConnectTrickle(TRICKLE_REAL); + WaitForConnected(); + WaitForGather(); // ICE can complete before we finish gathering. + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, TestSendReceive) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); + SendReceive(); +} + +TEST_F(WebRtcIceConnectTest, TestSendReceiveTcp) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + Init(); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsTcpCandidate); + SetExpectedTypes(NrIceCandidate::Type::ICE_HOST, + NrIceCandidate::Type::ICE_HOST, kNrIceTransportTcp); + Connect(); + SendReceive(); +} + +// TCP SO tests works on localhost only with delay applied: +// tc qdisc add dev lo root netem delay 10ms +TEST_F(WebRtcIceConnectTest, DISABLED_TestSendReceiveTcpSo) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + Init(); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsTcpSoCandidate); + SetExpectedTypes(NrIceCandidate::Type::ICE_HOST, + NrIceCandidate::Type::ICE_HOST, kNrIceTransportTcp); + Connect(); + SendReceive(); +} + +TEST_F(WebRtcIceConnectTest, TestConsent) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + SetupAndCheckConsent(); + PR_Sleep(1500); + AssertConsentRefresh(); + SendReceive(); +} + +TEST_F(WebRtcIceConnectTest, TestConsentTcp) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = true; + NrIceCtx::InitializeGlobals(config); + Init(); + AddStream(1); + SetCandidateFilter(IsTcpCandidate); + SetExpectedTypes(NrIceCandidate::Type::ICE_HOST, + NrIceCandidate::Type::ICE_HOST, kNrIceTransportTcp); + SetupAndCheckConsent(); + PR_Sleep(1500); + AssertConsentRefresh(); + SendReceive(); +} + +TEST_F(WebRtcIceConnectTest, TestConsentIntermittent) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + SetupAndCheckConsent(); + p1_->SetBlockStun(true); + p2_->SetBlockStun(true); + WaitForDisconnected(); + AssertConsentRefresh(CONSENT_STALE); + SendReceive(); + p1_->SetBlockStun(false); + p2_->SetBlockStun(false); + WaitForConnected(); + AssertConsentRefresh(); + SendReceive(); + p1_->SetBlockStun(true); + p2_->SetBlockStun(true); + WaitForDisconnected(); + AssertConsentRefresh(CONSENT_STALE); + SendReceive(); + p1_->SetBlockStun(false); + p2_->SetBlockStun(false); + WaitForConnected(); + AssertConsentRefresh(); +} + +TEST_F(WebRtcIceConnectTest, TestConsentTimeout) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + SetupAndCheckConsent(); + p1_->SetBlockStun(true); + p2_->SetBlockStun(true); + WaitForDisconnected(); + AssertConsentRefresh(CONSENT_STALE); + SendReceive(); + WaitForFailed(); + AssertConsentRefresh(CONSENT_EXPIRED); + SendFailure(); +} + +TEST_F(WebRtcIceConnectTest, TestConsentDelayed) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + SetupAndCheckConsent(); + /* Note: We don't have a list of STUN transaction IDs of the previously timed + out consent requests. Thus responses after sending the next consent + request are ignored. */ + p1_->SetStunResponseDelay(200); + p2_->SetStunResponseDelay(200); + PR_Sleep(1000); + AssertConsentRefresh(); + SendReceive(); +} + +TEST_F(WebRtcIceConnectTest, TestNetworkForcedOfflineAndRecovery) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + SetupAndCheckConsent(); + p1_->ChangeNetworkStateToOffline(); + ASSERT_TRUE_WAIT(p1_->ice_connected() == 0, kDefaultTimeout); + // Next round of consent check should switch it back to online + ASSERT_TRUE_WAIT(p1_->ice_connected(), kDefaultTimeout); +} + +TEST_F(WebRtcIceConnectTest, TestNetworkForcedOfflineTwice) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + SetupAndCheckConsent(); + p2_->ChangeNetworkStateToOffline(); + ASSERT_TRUE_WAIT(p2_->ice_connected() == 0, kDefaultTimeout); + p2_->ChangeNetworkStateToOffline(); + ASSERT_TRUE_WAIT(p2_->ice_connected() == 0, kDefaultTimeout); +} + +TEST_F(WebRtcIceConnectTest, TestNetworkOnlineDoesntChangeState) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + SetupAndCheckConsent(); + p2_->ChangeNetworkStateToOnline(); + ASSERT_TRUE(p2_->ice_connected()); + PR_Sleep(1500); + p2_->ChangeNetworkStateToOnline(); + ASSERT_TRUE(p2_->ice_connected()); +} + +TEST_F(WebRtcIceConnectTest, TestNetworkOnlineTriggersConsent) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + // Let's emulate audio + video w/o rtcp-mux + AddStream(2); + AddStream(2); + SetupAndCheckConsent(); + p1_->ChangeNetworkStateToOffline(); + p1_->SetBlockStun(true); + ASSERT_TRUE_WAIT(p1_->ice_connected() == 0, kDefaultTimeout); + PR_Sleep(1500); + ASSERT_TRUE(p1_->ice_connected() == 0); + p1_->SetBlockStun(false); + p1_->ChangeNetworkStateToOnline(); + ASSERT_TRUE_WAIT(p1_->ice_connected(), 500); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTurn) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTurnWithDelay) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_); + SetCandidateFilter(SabotageHostCandidateAndDropReflexive); + AddStream(1); + p1_->Gather(); + PR_Sleep(500); + p2_->Gather(); + ConnectTrickle(TRICKLE_REAL); + WaitForGather(); + WaitForConnectedStreams(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTurnWithNormalTrickleDelay) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(0)); + RealisticTrickleDelay(p2_->ControlTrickle(0)); + + WaitForConnected(); + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTurnWithNormalTrickleDelayOneSided) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + RealisticTrickleDelay(p1_->ControlTrickle(0)); + p2_->SimulateTrickle(0); + + WaitForConnected(); + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTurnWithLargeTrickleDelay) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_); + SetCandidateFilter(SabotageHostCandidateAndDropReflexive); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectTrickle(); + // Trickle host candidates immediately, but delay relay candidates + DelayRelayCandidates(p1_->ControlTrickle(0), 3700); + DelayRelayCandidates(p2_->ControlTrickle(0), 3700); + + WaitForConnected(); + AssertCheckingReached(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTurnTcp) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_, kNrIceTransportTcp); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTurnOnly) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsRelayCandidate); + SetExpectedTypes(NrIceCandidate::Type::ICE_RELAYED, + NrIceCandidate::Type::ICE_RELAYED); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectTurnTcpOnly) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_, kNrIceTransportTcp); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsRelayCandidate); + SetExpectedTypes(NrIceCandidate::Type::ICE_RELAYED, + NrIceCandidate::Type::ICE_RELAYED, kNrIceTransportTcp); + Connect(); +} + +TEST_F(WebRtcIceConnectTest, TestSendReceiveTurnOnly) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsRelayCandidate); + SetExpectedTypes(NrIceCandidate::Type::ICE_RELAYED, + NrIceCandidate::Type::ICE_RELAYED); + Connect(); + SendReceive(); +} + +TEST_F(WebRtcIceConnectTest, TestSendReceiveTurnTcpOnly) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + SetTurnServer(turn_server_, kDefaultStunServerPort, turn_user_, + turn_password_, kNrIceTransportTcp); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsRelayCandidate); + SetExpectedTypes(NrIceCandidate::Type::ICE_RELAYED, + NrIceCandidate::Type::ICE_RELAYED, kNrIceTransportTcp); + Connect(); + SendReceive(); +} + +TEST_F(WebRtcIceConnectTest, TestSendReceiveTurnBothOnly) { + if (turn_server_.empty()) return; + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + std::vector<NrIceTurnServer> turn_servers; + std::vector<unsigned char> password_vec(turn_password_.begin(), + turn_password_.end()); + turn_servers.push_back( + *NrIceTurnServer::Create(turn_server_, kDefaultStunServerPort, turn_user_, + password_vec, kNrIceTransportTcp)); + turn_servers.push_back( + *NrIceTurnServer::Create(turn_server_, kDefaultStunServerPort, turn_user_, + password_vec, kNrIceTransportUdp)); + SetTurnServers(turn_servers); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsRelayCandidate); + // UDP is preferred. + SetExpectedTypes(NrIceCandidate::Type::ICE_RELAYED, + NrIceCandidate::Type::ICE_RELAYED, kNrIceTransportUdp); + Connect(); + SendReceive(); +} + +TEST_F(WebRtcIceConnectTest, TestConnectShutdownOneSide) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + ConnectThenDelete(); +} + +TEST_F(WebRtcIceConnectTest, TestPollCandPairsBeforeConnect) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + + std::vector<NrIceCandidatePair> pairs; + nsresult res = p1_->GetCandidatePairs(0, &pairs); + // There should be no candidate pairs prior to calling Connect() + ASSERT_EQ(NS_OK, res); + ASSERT_EQ(0U, pairs.size()); + + res = p2_->GetCandidatePairs(0, &pairs); + ASSERT_EQ(NS_OK, res); + ASSERT_EQ(0U, pairs.size()); +} + +TEST_F(WebRtcIceConnectTest, TestPollCandPairsAfterConnect) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + Connect(); + + std::vector<NrIceCandidatePair> pairs; + nsresult r = p1_->GetCandidatePairs(0, &pairs); + ASSERT_EQ(NS_OK, r); + // How detailed of a check do we want to do here? If the turn server is + // functioning, we'll get at least two pairs, but this is probably not + // something we should assume. + ASSERT_NE(0U, pairs.size()); + ASSERT_TRUE(p1_->CandidatePairsPriorityDescending(pairs)); + ASSERT_TRUE(ContainsSucceededPair(pairs)); + pairs.clear(); + + r = p2_->GetCandidatePairs(0, &pairs); + ASSERT_EQ(NS_OK, r); + ASSERT_NE(0U, pairs.size()); + ASSERT_TRUE(p2_->CandidatePairsPriorityDescending(pairs)); + ASSERT_TRUE(ContainsSucceededPair(pairs)); +} + +// TODO Bug 1259842 - disabled until we find a better way to handle two +// candidates from different RFC1918 ranges +TEST_F(WebRtcIceConnectTest, DISABLED_TestHostCandPairingFilter) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + Init(false); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsIpv4Candidate); + + int host_net = p1_->GetCandidatesPrivateIpv4Range(0); + if (host_net <= 0) { + // TODO bug 1226838: make this work with multiple private IPs + FAIL() << "This test needs exactly one private IPv4 host candidate to work" + << std::endl; + } + + ConnectTrickle(); + AddNonPairableCandidates(p1_->ControlTrickle(0), p1_.get(), 0, host_net, + test_utils_); + AddNonPairableCandidates(p2_->ControlTrickle(0), p2_.get(), 0, host_net, + test_utils_); + + std::vector<NrIceCandidatePair> pairs; + p1_->GetCandidatePairs(0, &pairs); + for (auto p : pairs) { + std::cerr << "Verifying pair:" << std::endl; + p1_->DumpCandidatePair(p); + nr_transport_addr addr; + nr_str_port_to_transport_addr(p.local.local_addr.host.c_str(), 0, + IPPROTO_UDP, &addr); + ASSERT_TRUE(nr_transport_addr_get_private_addr_range(&addr) == host_net); + nr_str_port_to_transport_addr(p.remote.cand_addr.host.c_str(), 0, + IPPROTO_UDP, &addr); + ASSERT_TRUE(nr_transport_addr_get_private_addr_range(&addr) == host_net); + } +} + +// TODO Bug 1226838 - See Comment 2 - this test can't work as written +TEST_F(WebRtcIceConnectTest, DISABLED_TestSrflxCandPairingFilter) { + if (stun_server_address_.empty()) { + return; + } + + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + Init(false); + AddStream(1); + ASSERT_TRUE(Gather()); + SetCandidateFilter(IsSrflxCandidate); + + if (p1_->GetCandidatesPrivateIpv4Range(0) <= 0) { + // TODO bug 1226838: make this work with public IP addresses + std::cerr << "Don't run this test at IETF meetings!" << std::endl; + FAIL() << "This test needs one private IPv4 host candidate to work" + << std::endl; + } + + ConnectTrickle(); + SimulateTrickleP1(0); + SimulateTrickleP2(0); + + std::vector<NrIceCandidatePair> pairs; + p1_->GetCandidatePairs(0, &pairs); + for (auto p : pairs) { + std::cerr << "Verifying P1 pair:" << std::endl; + p1_->DumpCandidatePair(p); + nr_transport_addr addr; + nr_str_port_to_transport_addr(p.local.local_addr.host.c_str(), 0, + IPPROTO_UDP, &addr); + ASSERT_TRUE(nr_transport_addr_get_private_addr_range(&addr) != 0); + nr_str_port_to_transport_addr(p.remote.cand_addr.host.c_str(), 0, + IPPROTO_UDP, &addr); + ASSERT_TRUE(nr_transport_addr_get_private_addr_range(&addr) == 0); + } + p2_->GetCandidatePairs(0, &pairs); + for (auto p : pairs) { + std::cerr << "Verifying P2 pair:" << std::endl; + p2_->DumpCandidatePair(p); + nr_transport_addr addr; + nr_str_port_to_transport_addr(p.local.local_addr.host.c_str(), 0, + IPPROTO_UDP, &addr); + ASSERT_TRUE(nr_transport_addr_get_private_addr_range(&addr) != 0); + nr_str_port_to_transport_addr(p.remote.cand_addr.host.c_str(), 0, + IPPROTO_UDP, &addr); + ASSERT_TRUE(nr_transport_addr_get_private_addr_range(&addr) == 0); + } +} + +TEST_F(WebRtcIceConnectTest, TestPollCandPairsDuringConnect) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + + p2_->Connect(p1_.get(), TRICKLE_NONE, false); + p1_->Connect(p2_.get(), TRICKLE_NONE, false); + + std::vector<NrIceCandidatePair> pairs1; + std::vector<NrIceCandidatePair> pairs2; + + p1_->StartChecks(); + p1_->UpdateAndValidateCandidatePairs(0, &pairs1); + p2_->UpdateAndValidateCandidatePairs(0, &pairs2); + + p2_->StartChecks(); + p1_->UpdateAndValidateCandidatePairs(0, &pairs1); + p2_->UpdateAndValidateCandidatePairs(0, &pairs2); + + WaitForConnectedStreams(); + p1_->UpdateAndValidateCandidatePairs(0, &pairs1); + p2_->UpdateAndValidateCandidatePairs(0, &pairs2); + ASSERT_TRUE(ContainsSucceededPair(pairs1)); + ASSERT_TRUE(ContainsSucceededPair(pairs2)); +} + +TEST_F(WebRtcIceConnectTest, TestRLogConnector) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + ASSERT_TRUE(Gather()); + + p2_->Connect(p1_.get(), TRICKLE_NONE, false); + p1_->Connect(p2_.get(), TRICKLE_NONE, false); + + std::vector<NrIceCandidatePair> pairs1; + std::vector<NrIceCandidatePair> pairs2; + + p1_->StartChecks(); + p1_->UpdateAndValidateCandidatePairs(0, &pairs1); + p2_->UpdateAndValidateCandidatePairs(0, &pairs2); + + p2_->StartChecks(); + p1_->UpdateAndValidateCandidatePairs(0, &pairs1); + p2_->UpdateAndValidateCandidatePairs(0, &pairs2); + + WaitForConnectedStreams(); + p1_->UpdateAndValidateCandidatePairs(0, &pairs1); + p2_->UpdateAndValidateCandidatePairs(0, &pairs2); + ASSERT_TRUE(ContainsSucceededPair(pairs1)); + ASSERT_TRUE(ContainsSucceededPair(pairs2)); + + for (auto& p : pairs1) { + std::deque<std::string> logs; + std::string substring("CAND-PAIR("); + substring += p.codeword; + RLogConnector::GetInstance()->Filter(substring, 0, &logs); + ASSERT_NE(0U, logs.size()); + } + + for (auto& p : pairs2) { + std::deque<std::string> logs; + std::string substring("CAND-PAIR("); + substring += p.codeword; + RLogConnector::GetInstance()->Filter(substring, 0, &logs); + ASSERT_NE(0U, logs.size()); + } +} + +// Verify that a bogus candidate doesn't cause crashes on the +// main thread. See bug 856433. +TEST_F(WebRtcIceConnectTest, TestBogusCandidate) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + Gather(); + ConnectTrickle(); + p1_->ParseCandidate(0, kBogusIceCandidate, ""); + + std::vector<NrIceCandidatePair> pairs; + nsresult res = p1_->GetCandidatePairs(0, &pairs); + ASSERT_EQ(NS_OK, res); + ASSERT_EQ(0U, pairs.size()); +} + +TEST_F(WebRtcIceConnectTest, TestNonMDNSCandidate) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + Gather(); + ConnectTrickle(); + p1_->ParseCandidate(0, kUnreachableHostIceCandidate, ""); + + std::vector<NrIceCandidatePair> pairs; + nsresult res = p1_->GetCandidatePairs(0, &pairs); + ASSERT_EQ(NS_OK, res); + ASSERT_EQ(1U, pairs.size()); + ASSERT_EQ(pairs[0].remote.mdns_addr, ""); +} + +TEST_F(WebRtcIceConnectTest, TestMDNSCandidate) { + NrIceCtx::GlobalConfig config; + config.mTcpEnabled = false; + NrIceCtx::InitializeGlobals(config); + AddStream(1); + Gather(); + ConnectTrickle(); + p1_->ParseCandidate(0, kUnreachableHostIceCandidate, "host.local"); + + std::vector<NrIceCandidatePair> pairs; + nsresult res = p1_->GetCandidatePairs(0, &pairs); + ASSERT_EQ(NS_OK, res); + ASSERT_EQ(1U, pairs.size()); + ASSERT_EQ(pairs[0].remote.mdns_addr, "host.local"); +} + +TEST_F(WebRtcIcePrioritizerTest, TestPrioritizer) { + SetPriorizer(::mozilla::CreateInterfacePrioritizer()); + + AddInterface("0", NR_INTERFACE_TYPE_VPN, 100); // unknown vpn + AddInterface("1", NR_INTERFACE_TYPE_VPN | NR_INTERFACE_TYPE_WIRED, + 100); // wired vpn + AddInterface("2", NR_INTERFACE_TYPE_VPN | NR_INTERFACE_TYPE_WIFI, + 100); // wifi vpn + AddInterface("3", NR_INTERFACE_TYPE_VPN | NR_INTERFACE_TYPE_MOBILE, + 100); // wifi vpn + AddInterface("4", NR_INTERFACE_TYPE_WIRED, 1000); // wired, high speed + AddInterface("5", NR_INTERFACE_TYPE_WIRED, 10); // wired, low speed + AddInterface("6", NR_INTERFACE_TYPE_WIFI, 10); // wifi, low speed + AddInterface("7", NR_INTERFACE_TYPE_WIFI, 1000); // wifi, high speed + AddInterface("8", NR_INTERFACE_TYPE_MOBILE, 10); // mobile, low speed + AddInterface("9", NR_INTERFACE_TYPE_MOBILE, 1000); // mobile, high speed + AddInterface("10", NR_INTERFACE_TYPE_UNKNOWN, 10); // unknown, low speed + AddInterface("11", NR_INTERFACE_TYPE_UNKNOWN, 1000); // unknown, high speed + + // expected preference "4" > "5" > "1" > "7" > "6" > "2" > "9" > "8" > "3" > + // "11" > "10" > "0" + + HasLowerPreference("0", "10"); + HasLowerPreference("10", "11"); + HasLowerPreference("11", "3"); + HasLowerPreference("3", "8"); + HasLowerPreference("8", "9"); + HasLowerPreference("9", "2"); + HasLowerPreference("2", "6"); + HasLowerPreference("6", "7"); + HasLowerPreference("7", "1"); + HasLowerPreference("1", "5"); + HasLowerPreference("5", "4"); +} + +TEST_F(WebRtcIcePacketFilterTest, TestSendNonStunPacket) { + const unsigned char data[] = "12345abcde"; + TestOutgoing(data, sizeof(data), 123, 45, false); + TestOutgoingTcp(data, sizeof(data), false); +} + +TEST_F(WebRtcIcePacketFilterTest, TestRecvNonStunPacket) { + const unsigned char data[] = "12345abcde"; + TestIncoming(data, sizeof(data), 123, 45, false); + TestIncomingTcp(data, sizeof(data), true); +} + +TEST_F(WebRtcIcePacketFilterTest, TestSendStunPacket) { + nr_stun_message* msg; + ASSERT_EQ(0, nr_stun_build_req_no_auth(nullptr, &msg)); + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, true); + TestOutgoingTcp(msg->buffer, msg->length, true); + TestOutgoingTcpFramed(msg->buffer, msg->length, true); + ASSERT_EQ(0, nr_stun_message_destroy(&msg)); +} + +TEST_F(WebRtcIcePacketFilterTest, TestRecvStunPacketWithoutAPendingId) { + nr_stun_message* msg; + ASSERT_EQ(0, nr_stun_build_req_no_auth(nullptr, &msg)); + + msg->header.id.octet[0] = 1; + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, true); + TestOutgoingTcp(msg->buffer, msg->length, true); + + msg->header.id.octet[0] = 0; + msg->header.type = NR_STUN_MSG_BINDING_RESPONSE; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestIncoming(msg->buffer, msg->length, 123, 45, true); + TestIncomingTcp(msg->buffer, msg->length, true); + + ASSERT_EQ(0, nr_stun_message_destroy(&msg)); +} + +TEST_F(WebRtcIcePacketFilterTest, TestRecvStunBindingRequestWithoutAPendingId) { + nr_stun_message* msg; + ASSERT_EQ(0, nr_stun_build_req_no_auth(nullptr, &msg)); + + msg->header.id.octet[0] = 1; + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestIncoming(msg->buffer, msg->length, 123, 45, true); + TestIncomingTcp(msg->buffer, msg->length, true); + + msg->header.id.octet[0] = 1; + msg->header.type = NR_STUN_MSG_BINDING_RESPONSE; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, true); + TestOutgoingTcp(msg->buffer, msg->length, true); + + ASSERT_EQ(0, nr_stun_message_destroy(&msg)); +} + +TEST_F(WebRtcIcePacketFilterTest, + TestRecvStunPacketWithoutAPendingIdTcpFramed) { + nr_stun_message* msg; + ASSERT_EQ(0, nr_stun_build_req_no_auth(nullptr, &msg)); + + msg->header.id.octet[0] = 1; + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoingTcpFramed(msg->buffer, msg->length, true); + + msg->header.id.octet[0] = 0; + msg->header.type = NR_STUN_MSG_BINDING_RESPONSE; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestIncomingTcpFramed(msg->buffer, msg->length, true); + + ASSERT_EQ(0, nr_stun_message_destroy(&msg)); +} + +TEST_F(WebRtcIcePacketFilterTest, TestRecvStunPacketWithoutAPendingAddress) { + nr_stun_message* msg; + ASSERT_EQ(0, nr_stun_build_req_no_auth(nullptr, &msg)); + + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, true); + // nothing to test here for the TCP filter + + msg->header.type = NR_STUN_MSG_BINDING_RESPONSE; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestIncoming(msg->buffer, msg->length, 123, 46, false); + TestIncoming(msg->buffer, msg->length, 124, 45, false); + + ASSERT_EQ(0, nr_stun_message_destroy(&msg)); +} + +TEST_F(WebRtcIcePacketFilterTest, TestRecvStunPacketWithPendingIdAndAddress) { + nr_stun_message* msg; + ASSERT_EQ(0, nr_stun_build_req_no_auth(nullptr, &msg)); + + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, true); + TestOutgoingTcp(msg->buffer, msg->length, true); + + msg->header.type = NR_STUN_MSG_BINDING_RESPONSE; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestIncoming(msg->buffer, msg->length, 123, 45, true); + TestIncomingTcp(msg->buffer, msg->length, true); + + // Test whitelist by filtering non-stun packets. + const unsigned char data[] = "12345abcde"; + + // 123:45 is white-listed. + TestOutgoing(data, sizeof(data), 123, 45, true); + TestOutgoingTcp(data, sizeof(data), true); + TestIncoming(data, sizeof(data), 123, 45, true); + TestIncomingTcp(data, sizeof(data), true); + + // Indications pass as well. + msg->header.type = NR_STUN_MSG_BINDING_INDICATION; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, true); + TestOutgoingTcp(msg->buffer, msg->length, true); + TestIncoming(msg->buffer, msg->length, 123, 45, true); + TestIncomingTcp(msg->buffer, msg->length, true); + + // Packets from and to other address are still disallowed. + // Note: this doesn't apply for TCP connections + TestOutgoing(data, sizeof(data), 123, 46, false); + TestIncoming(data, sizeof(data), 123, 46, false); + TestOutgoing(data, sizeof(data), 124, 45, false); + TestIncoming(data, sizeof(data), 124, 45, false); + + ASSERT_EQ(0, nr_stun_message_destroy(&msg)); +} + +TEST_F(WebRtcIcePacketFilterTest, TestRecvStunPacketWithPendingIdTcpFramed) { + nr_stun_message* msg; + ASSERT_EQ(0, nr_stun_build_req_no_auth(nullptr, &msg)); + + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoingTcpFramed(msg->buffer, msg->length, true); + + msg->header.type = NR_STUN_MSG_BINDING_RESPONSE; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestIncomingTcpFramed(msg->buffer, msg->length, true); + + // Test whitelist by filtering non-stun packets. + const unsigned char data[] = "12345abcde"; + + TestOutgoingTcpFramed(data, sizeof(data), true); + TestIncomingTcpFramed(data, sizeof(data), true); + + ASSERT_EQ(0, nr_stun_message_destroy(&msg)); +} + +TEST_F(WebRtcIcePacketFilterTest, TestSendNonRequestStunPacket) { + nr_stun_message* msg; + ASSERT_EQ(0, nr_stun_build_req_no_auth(nullptr, &msg)); + + msg->header.type = NR_STUN_MSG_BINDING_RESPONSE; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, false); + TestOutgoingTcp(msg->buffer, msg->length, false); + + // Send a packet so we allow the incoming request. + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, true); + TestOutgoingTcp(msg->buffer, msg->length, true); + + // This packet makes us able to send a response. + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestIncoming(msg->buffer, msg->length, 123, 45, true); + TestIncomingTcp(msg->buffer, msg->length, true); + + msg->header.type = NR_STUN_MSG_BINDING_RESPONSE; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, true); + TestOutgoingTcp(msg->buffer, msg->length, true); + + ASSERT_EQ(0, nr_stun_message_destroy(&msg)); +} + +TEST_F(WebRtcIcePacketFilterTest, TestRecvDataPacketWithAPendingAddress) { + nr_stun_message* msg; + ASSERT_EQ(0, nr_stun_build_req_no_auth(nullptr, &msg)); + + msg->header.type = NR_STUN_MSG_BINDING_REQUEST; + ASSERT_EQ(0, nr_stun_encode_message(msg)); + TestOutgoing(msg->buffer, msg->length, 123, 45, true); + TestOutgoingTcp(msg->buffer, msg->length, true); + + const unsigned char data[] = "12345abcde"; + TestIncoming(data, sizeof(data), 123, 45, true); + TestIncomingTcp(data, sizeof(data), true); + + ASSERT_EQ(0, nr_stun_message_destroy(&msg)); +} + +TEST(WebRtcIceInternalsTest, TestAddBogusAttribute) +{ + nr_stun_message* req; + ASSERT_EQ(0, nr_stun_message_create(&req)); + Data* data; + ASSERT_EQ(0, r_data_alloc(&data, 3000)); + memset(data->data, 'A', data->len); + ASSERT_TRUE(nr_stun_message_add_message_integrity_attribute(req, data)); + ASSERT_EQ(0, r_data_destroy(&data)); + ASSERT_EQ(0, nr_stun_message_destroy(&req)); +} diff --git a/dom/media/webrtc/transport/test/moz.build b/dom/media/webrtc/transport/test/moz.build new file mode 100644 index 0000000000..69d3a587a5 --- /dev/null +++ b/dom/media/webrtc/transport/test/moz.build @@ -0,0 +1,104 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +include("/ipc/chromium/chromium-config.mozbuild") + +if CONFIG["OS_TARGET"] != "WINNT": + + if CONFIG["OS_TARGET"] != "Android": + SOURCES += [ + "ice_unittest.cpp", + ] + + SOURCES += [ + "buffered_stun_socket_unittest.cpp", + "multi_tcp_socket_unittest.cpp", + "nrappkit_unittest.cpp", + "proxy_tunnel_socket_unittest.cpp", + "rlogconnector_unittest.cpp", + "runnable_utils_unittest.cpp", + "simpletokenbucket_unittest.cpp", + "sockettransportservice_unittest.cpp", + "stunserver.cpp", + "test_nr_socket_ice_unittest.cpp", + "test_nr_socket_unittest.cpp", + "TestSyncRunnable.cpp", + "transport_unittests.cpp", + "turn_unittest.cpp", + "webrtcproxychannel_unittest.cpp", + ] + + if CONFIG["MOZ_SCTP"]: + SOURCES += [ + "sctp_unittest.cpp", + ] + + +for var in ("HAVE_STRDUP", "NR_SOCKET_IS_VOID_PTR", "SCTP_DEBUG"): + DEFINES[var] = True + +if CONFIG["OS_TARGET"] == "Android": + DEFINES["LINUX"] = True + DEFINES["ANDROID"] = True + LOCAL_INCLUDES += [ + "/dom/media/webrtc/transport/third_party/nrappkit/src/port/android/include", + ] + +if CONFIG["OS_TARGET"] == "Linux": + DEFINES["LINUX"] = True + LOCAL_INCLUDES += [ + "/dom/media/webrtc/transport/third_party/nrappkit/src/port/linux/include", + ] + +if CONFIG["OS_TARGET"] == "Darwin": + LOCAL_INCLUDES += [ + "/dom/media/webrtc/transport/third_party/nrappkit/src/port/darwin/include", + ] + +if CONFIG["OS_TARGET"] in ("DragonFly", "FreeBSD", "NetBSD", "OpenBSD"): + if CONFIG["OS_TARGET"] == "Darwin": + DEFINES["DARWIN"] = True + else: + DEFINES["BSD"] = True + LOCAL_INCLUDES += [ + "/dom/media/webrtc/transport/third_party/nrappkit/src/port/darwin/include", + ] + +# SCTP DEFINES +if CONFIG["OS_TARGET"] == "WINNT": + DEFINES["WIN"] = True + # for stun.h + DEFINES["WIN32"] = True + DEFINES["__Userspace_os_Windows"] = 1 +else: + # Works for Darwin, Linux, Android. Probably doesn't work for others. + DEFINES["__Userspace_os_%s" % CONFIG["OS_TARGET"]] = 1 + +if CONFIG["OS_TARGET"] in ("Darwin", "Android"): + DEFINES["GTEST_USE_OWN_TR1_TUPLE"] = 1 + +LOCAL_INCLUDES += [ + "/dom/media/webrtc/transport/", + "/dom/media/webrtc/transport/third_party/", + "/dom/media/webrtc/transport/third_party/nICEr/src/crypto", + "/dom/media/webrtc/transport/third_party/nICEr/src/ice", + "/dom/media/webrtc/transport/third_party/nICEr/src/net", + "/dom/media/webrtc/transport/third_party/nICEr/src/stun", + "/dom/media/webrtc/transport/third_party/nICEr/src/util", + "/dom/media/webrtc/transport/third_party/nrappkit/src/event", + "/dom/media/webrtc/transport/third_party/nrappkit/src/log", + "/dom/media/webrtc/transport/third_party/nrappkit/src/plugin", + "/dom/media/webrtc/transport/third_party/nrappkit/src/port/generic/include", + "/dom/media/webrtc/transport/third_party/nrappkit/src/registry", + "/dom/media/webrtc/transport/third_party/nrappkit/src/share", + "/dom/media/webrtc/transport/third_party/nrappkit/src/stats", + "/dom/media/webrtc/transport/third_party/nrappkit/src/util/", + "/dom/media/webrtc/transport/third_party/nrappkit/src/util/libekr", + "/netwerk/sctp/src/", + "/xpcom/tests/", +] + +FINAL_LIBRARY = "xul-gtest" diff --git a/dom/media/webrtc/transport/test/mtransport_test_utils.h b/dom/media/webrtc/transport/test/mtransport_test_utils.h new file mode 100644 index 0000000000..04031c0dc2 --- /dev/null +++ b/dom/media/webrtc/transport/test/mtransport_test_utils.h @@ -0,0 +1,57 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com + +#ifndef mtransport_test_utils_h__ +#define mtransport_test_utils_h__ + +#include "nsCOMPtr.h" +#include "nsNetCID.h" + +#include "nsISerialEventTarget.h" +#include "nsPISocketTransportService.h" +#include "nsServiceManagerUtils.h" +#include "nsThreadUtils.h" + +class MtransportTestUtils { + public: + MtransportTestUtils() { InitServices(); } + + ~MtransportTestUtils() = default; + + void InitServices() { + nsresult rv; + sts_target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + sts_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } + + nsISerialEventTarget* sts_target() { return sts_target_; } + + nsresult SyncDispatchToSTS(nsIRunnable* aRunnable) { + return SyncDispatchToSTS(do_AddRef(aRunnable)); + } + nsresult SyncDispatchToSTS(already_AddRefed<nsIRunnable>&& aRunnable) { + return NS_DispatchAndSpinEventLoopUntilComplete( + "MtransportTestUtils::SyncDispatchToSts"_ns, sts_target_, + std::move(aRunnable)); + } + + private: + nsCOMPtr<nsISerialEventTarget> sts_target_; + nsCOMPtr<nsPISocketTransportService> sts_; +}; + +#define CHECK_ENVIRONMENT_FLAG(envname) \ + char* test_flag = getenv(envname); \ + if (!test_flag || strcmp(test_flag, "1")) { \ + printf("To run this test set %s=1 in your environment\n", envname); \ + exit(0); \ + } + +#endif diff --git a/dom/media/webrtc/transport/test/multi_tcp_socket_unittest.cpp b/dom/media/webrtc/transport/test/multi_tcp_socket_unittest.cpp new file mode 100644 index 0000000000..d0c3ae6e53 --- /dev/null +++ b/dom/media/webrtc/transport/test/multi_tcp_socket_unittest.cpp @@ -0,0 +1,501 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <iostream> +#include <vector> + +#include "mozilla/Atomics.h" +#include "runnable_utils.h" +#include "pk11pub.h" + +extern "C" { +#include "nr_api.h" +#include "nr_socket.h" +#include "transport_addr.h" +#include "nr_socket_multi_tcp.h" +} + +#include "stunserver.h" + +#include "nricectx.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +using namespace mozilla; + +namespace { + +class MultiTcpSocketTest : public MtransportTest { + public: + MultiTcpSocketTest() + : MtransportTest(), socks(3, nullptr), readable(false), ice_ctx_() {} + + void SetUp() { + MtransportTest::SetUp(); + + NrIceCtx::InitializeGlobals(NrIceCtx::GlobalConfig()); + ice_ctx_ = NrIceCtx::Create("stun"); + + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&TestStunTcpServer::GetInstance, AF_INET)); + test_utils_->SyncDispatchToSTS( + WrapRunnableNM(&TestStunTcpServer::GetInstance, AF_INET6)); + } + + void TearDown() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &MultiTcpSocketTest::Shutdown_s)); + + MtransportTest::TearDown(); + } + + DISALLOW_COPY_ASSIGN(MultiTcpSocketTest); + + static void SockReadable(NR_SOCKET s, int how, void* arg) { + MultiTcpSocketTest* obj = static_cast<MultiTcpSocketTest*>(arg); + obj->SetReadable(true); + } + + void Shutdown_s() { + ice_ctx_ = nullptr; + for (auto& sock : socks) { + nr_socket_destroy(&sock); + } + } + + static uint16_t GetRandomPort() { + uint16_t result; + if (PK11_GenerateRandom((unsigned char*)&result, 2) != SECSuccess) { + MOZ_ASSERT(false); + return 0; + } + return result; + } + + static uint16_t EnsureEphemeral(uint16_t port) { + // IANA ephemeral port range (49152 to 65535) + return port | 49152; + } + + void Create_s(nr_socket_tcp_type tcp_type, std::string stun_server_addr, + uint16_t stun_server_port, nr_socket** sock) { + nr_transport_addr local; + // Get start of port range for test + static unsigned short port_s = GetRandomPort(); + int r; + + if (!stun_server_addr.empty()) { + std::vector<NrIceStunServer> stun_servers; + UniquePtr<NrIceStunServer> server(NrIceStunServer::Create( + stun_server_addr, stun_server_port, kNrIceTransportTcp)); + stun_servers.push_back(*server); + + ASSERT_TRUE(NS_SUCCEEDED(ice_ctx_->SetStunServers(stun_servers))); + } + + r = 1; + for (int tries = 10; tries && r; --tries) { + r = nr_str_port_to_transport_addr( + (char*)"127.0.0.1", EnsureEphemeral(port_s++), IPPROTO_TCP, &local); + ASSERT_EQ(0, r); + + r = nr_socket_multi_tcp_create(ice_ctx_->ctx(), nullptr, &local, tcp_type, + 1, 2048, sock); + } + + ASSERT_EQ(0, r); + printf("Creating socket on %s\n", local.as_string); + r = nr_socket_multi_tcp_set_readable_cb( + *sock, &MultiTcpSocketTest::SockReadable, this); + ASSERT_EQ(0, r); + } + + nr_socket* Create(nr_socket_tcp_type tcp_type, + std::string stun_server_addr = "", + uint16_t stun_server_port = 0) { + nr_socket* sock = nullptr; + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &MultiTcpSocketTest::Create_s, tcp_type, + stun_server_addr, stun_server_port, &sock)); + return sock; + } + + void Listen_s(nr_socket* sock) { + nr_transport_addr addr; + int r = nr_socket_getaddr(sock, &addr); + ASSERT_EQ(0, r); + printf("Listening on %s\n", addr.as_string); + r = nr_socket_listen(sock, 5); + ASSERT_EQ(0, r); + } + + void Listen(nr_socket* sock) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &MultiTcpSocketTest::Listen_s, sock)); + } + + void Destroy_s(nr_socket* sock) { + int r = nr_socket_destroy(&sock); + ASSERT_EQ(0, r); + } + + void Destroy(nr_socket* sock) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &MultiTcpSocketTest::Destroy_s, sock)); + } + + void Connect_s(nr_socket* from, nr_socket* to) { + nr_transport_addr addr_to; + nr_transport_addr addr_from; + int r = nr_socket_getaddr(to, &addr_to); + ASSERT_EQ(0, r); + r = nr_socket_getaddr(from, &addr_from); + ASSERT_EQ(0, r); + printf("Connecting from %s to %s\n", addr_from.as_string, + addr_to.as_string); + r = nr_socket_connect(from, &addr_to); + ASSERT_EQ(0, r); + } + + void Connect(nr_socket* from, nr_socket* to) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &MultiTcpSocketTest::Connect_s, from, to)); + } + + void ConnectSo_s(nr_socket* so1, nr_socket* so2) { + nr_transport_addr addr_so1; + nr_transport_addr addr_so2; + int r = nr_socket_getaddr(so1, &addr_so1); + ASSERT_EQ(0, r); + r = nr_socket_getaddr(so2, &addr_so2); + ASSERT_EQ(0, r); + printf("Connecting SO %s <-> %s\n", addr_so1.as_string, addr_so2.as_string); + r = nr_socket_connect(so1, &addr_so2); + ASSERT_EQ(0, r); + r = nr_socket_connect(so2, &addr_so1); + ASSERT_EQ(0, r); + } + + void ConnectSo(nr_socket* from, nr_socket* to) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &MultiTcpSocketTest::ConnectSo_s, from, to)); + } + + void SendDataToAddress_s(nr_socket* from, nr_transport_addr* to, + const char* data, size_t len) { + nr_transport_addr addr_from; + + int r = nr_socket_getaddr(from, &addr_from); + ASSERT_EQ(0, r); + printf("Sending %lu bytes %s -> %s\n", (unsigned long)len, + addr_from.as_string, to->as_string); + r = nr_socket_sendto(from, data, len, 0, to); + ASSERT_EQ(0, r); + } + + void SendData(nr_socket* from, nr_transport_addr* to, const char* data, + size_t len) { + test_utils_->SyncDispatchToSTS(WrapRunnable( + this, &MultiTcpSocketTest::SendDataToAddress_s, from, to, data, len)); + } + + void SendDataToSocket_s(nr_socket* from, nr_socket* to, const char* data, + size_t len) { + nr_transport_addr addr_to; + + int r = nr_socket_getaddr(to, &addr_to); + ASSERT_EQ(0, r); + SendDataToAddress_s(from, &addr_to, data, len); + } + + void SendData(nr_socket* from, nr_socket* to, const char* data, size_t len) { + test_utils_->SyncDispatchToSTS(WrapRunnable( + this, &MultiTcpSocketTest::SendDataToSocket_s, from, to, data, len)); + } + + void RecvDataFromAddress_s(nr_transport_addr* expected_from, + nr_socket* sent_to, const char* expected_data, + size_t expected_len) { + SetReadable(false); + size_t buflen = expected_len ? expected_len + 1 : 100; + char received_data[buflen]; + nr_transport_addr addr_to; + nr_transport_addr retaddr; + size_t retlen; + + int r = nr_socket_getaddr(sent_to, &addr_to); + ASSERT_EQ(0, r); + printf("Receiving %lu bytes %s <- %s\n", (unsigned long)expected_len, + addr_to.as_string, expected_from->as_string); + r = nr_socket_recvfrom(sent_to, received_data, buflen, &retlen, 0, + &retaddr); + ASSERT_EQ(0, r); + r = nr_transport_addr_cmp(&retaddr, expected_from, + NR_TRANSPORT_ADDR_CMP_MODE_ALL); + ASSERT_EQ(0, r); + // expected_len == 0 means we just expected some data + if (expected_len == 0) { + ASSERT_GT(retlen, 0U); + } else { + ASSERT_EQ(expected_len, retlen); + r = memcmp(expected_data, received_data, retlen); + ASSERT_EQ(0, r); + } + } + + void RecvData(nr_transport_addr* expected_from, nr_socket* sent_to, + const char* expected_data = nullptr, size_t expected_len = 0) { + ASSERT_TRUE_WAIT(IsReadable(), 1000); + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &MultiTcpSocketTest::RecvDataFromAddress_s, + expected_from, sent_to, expected_data, expected_len)); + } + + void RecvDataFromSocket_s(nr_socket* expected_from, nr_socket* sent_to, + const char* expected_data, size_t expected_len) { + nr_transport_addr addr_from; + + int r = nr_socket_getaddr(expected_from, &addr_from); + ASSERT_EQ(0, r); + + RecvDataFromAddress_s(&addr_from, sent_to, expected_data, expected_len); + } + + void RecvData(nr_socket* expected_from, nr_socket* sent_to, + const char* expected_data, size_t expected_len) { + ASSERT_TRUE_WAIT(IsReadable(), 1000); + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &MultiTcpSocketTest::RecvDataFromSocket_s, + expected_from, sent_to, expected_data, expected_len)); + } + + void RecvDataFailed_s(nr_socket* sent_to, size_t expected_len, + int expected_err) { + SetReadable(false); + char received_data[expected_len + 1]; + nr_transport_addr addr_to; + nr_transport_addr retaddr; + size_t retlen; + + int r = nr_socket_getaddr(sent_to, &addr_to); + ASSERT_EQ(0, r); + r = nr_socket_recvfrom(sent_to, received_data, expected_len + 1, &retlen, 0, + &retaddr); + ASSERT_EQ(expected_err, r) << "Expecting receive failure " << expected_err + << " on " << addr_to.as_string; + } + + void RecvDataFailed(nr_socket* sent_to, size_t expected_len, + int expected_err) { + ASSERT_TRUE_WAIT(IsReadable(), 1000); + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &MultiTcpSocketTest::RecvDataFailed_s, sent_to, + expected_len, expected_err)); + } + + void TransferData(nr_socket* from, nr_socket* to, const char* data, + size_t len) { + SendData(from, to, data, len); + RecvData(from, to, data, len); + } + + protected: + bool IsReadable() const { return readable; } + void SetReadable(bool r) { readable = r; } + std::vector<nr_socket*> socks; + Atomic<bool> readable; + RefPtr<NrIceCtx> ice_ctx_; +}; +} // namespace + +TEST_F(MultiTcpSocketTest, TestListen) { + socks[0] = Create(TCP_TYPE_PASSIVE); + Listen(socks[0]); +} + +TEST_F(MultiTcpSocketTest, TestConnect) { + socks[0] = Create(TCP_TYPE_PASSIVE); + socks[1] = Create(TCP_TYPE_ACTIVE); + socks[2] = Create(TCP_TYPE_ACTIVE); + Listen(socks[0]); + Connect(socks[1], socks[0]); + Connect(socks[2], socks[0]); +} + +TEST_F(MultiTcpSocketTest, TestTransmit) { + const char data[] = "TestTransmit"; + socks[0] = Create(TCP_TYPE_ACTIVE); + socks[1] = Create(TCP_TYPE_PASSIVE); + Listen(socks[1]); + Connect(socks[0], socks[1]); + + TransferData(socks[0], socks[1], data, sizeof(data)); + TransferData(socks[1], socks[0], data, sizeof(data)); +} + +TEST_F(MultiTcpSocketTest, TestClosePassive) { + const char data[] = "TestClosePassive"; + socks[0] = Create(TCP_TYPE_ACTIVE); + socks[1] = Create(TCP_TYPE_PASSIVE); + Listen(socks[1]); + Connect(socks[0], socks[1]); + + TransferData(socks[0], socks[1], data, sizeof(data)); + TransferData(socks[1], socks[0], data, sizeof(data)); + + /* We have to destroy as only that calls PR_Close() */ + std::cerr << "Destructing socket" << std::endl; + Destroy(socks[1]); + + RecvDataFailed(socks[0], sizeof(data), R_EOD); + + socks[1] = nullptr; +} + +TEST_F(MultiTcpSocketTest, TestCloseActive) { + const char data[] = "TestCloseActive"; + socks[0] = Create(TCP_TYPE_ACTIVE); + socks[1] = Create(TCP_TYPE_PASSIVE); + Listen(socks[1]); + Connect(socks[0], socks[1]); + + TransferData(socks[0], socks[1], data, sizeof(data)); + TransferData(socks[1], socks[0], data, sizeof(data)); + + /* We have to destroy as only that calls PR_Close() */ + std::cerr << "Destructing socket" << std::endl; + Destroy(socks[0]); + + RecvDataFailed(socks[1], sizeof(data), R_EOD); + + socks[0] = nullptr; +} + +TEST_F(MultiTcpSocketTest, TestTwoSendsBeforeReceives) { + const char data1[] = "TestTwoSendsBeforeReceives"; + const char data2[] = "2nd data"; + socks[0] = Create(TCP_TYPE_ACTIVE); + socks[1] = Create(TCP_TYPE_PASSIVE); + Listen(socks[1]); + Connect(socks[0], socks[1]); + + SendData(socks[0], socks[1], data1, sizeof(data1)); + SendData(socks[0], socks[1], data2, sizeof(data2)); + RecvData(socks[0], socks[1], data1, sizeof(data1)); + /* ICE TCP framing turns TCP effectively into datagram mode */ + RecvData(socks[0], socks[1], data2, sizeof(data2)); +} + +TEST_F(MultiTcpSocketTest, TestTwoActiveBidirectionalTransmit) { + const char data1[] = "TestTwoActiveBidirectionalTransmit"; + const char data2[] = "ReplyToTheFirstSocket"; + const char data3[] = "TestMessageFromTheSecondSocket"; + const char data4[] = "ThisIsAReplyToTheSecondSocket"; + socks[0] = Create(TCP_TYPE_PASSIVE); + socks[1] = Create(TCP_TYPE_ACTIVE); + socks[2] = Create(TCP_TYPE_ACTIVE); + Listen(socks[0]); + Connect(socks[1], socks[0]); + Connect(socks[2], socks[0]); + + TransferData(socks[1], socks[0], data1, sizeof(data1)); + TransferData(socks[0], socks[1], data2, sizeof(data2)); + TransferData(socks[2], socks[0], data3, sizeof(data3)); + TransferData(socks[0], socks[2], data4, sizeof(data4)); +} + +TEST_F(MultiTcpSocketTest, TestTwoPassiveBidirectionalTransmit) { + const char data1[] = "TestTwoPassiveBidirectionalTransmit"; + const char data2[] = "FirstReply"; + const char data3[] = "TestTwoPassiveBidirectionalTransmitToTheSecondSock"; + const char data4[] = "SecondReply"; + socks[0] = Create(TCP_TYPE_PASSIVE); + socks[1] = Create(TCP_TYPE_PASSIVE); + socks[2] = Create(TCP_TYPE_ACTIVE); + Listen(socks[0]); + Listen(socks[1]); + Connect(socks[2], socks[0]); + Connect(socks[2], socks[1]); + + TransferData(socks[2], socks[0], data1, sizeof(data1)); + TransferData(socks[0], socks[2], data2, sizeof(data2)); + TransferData(socks[2], socks[1], data3, sizeof(data3)); + TransferData(socks[1], socks[2], data4, sizeof(data4)); +} + +TEST_F(MultiTcpSocketTest, TestActivePassiveWithStunServerMockup) { + /* Fake STUN message able to pass the nr_is_stun_msg check + used in nr_socket_buffered_stun */ + const char stunMessage[] = {'\x00', '\x01', '\x00', '\x04', '\x21', '\x12', + '\xa4', '\x42', '\x00', '\x00', '\x00', '\x00', + '\x00', '\x00', '\x0c', '\x00', '\x00', '\x00', + '\x00', '\x00', '\x1c', '\xed', '\xca', '\xfe'}; + const char data[] = "TestActivePassiveWithStunServerMockup"; + + nr_transport_addr stun_srv_addr; + std::string stun_addr; + uint16_t stun_port; + stun_addr = TestStunTcpServer::GetInstance(AF_INET)->addr(); + stun_port = TestStunTcpServer::GetInstance(AF_INET)->port(); + int r = nr_str_port_to_transport_addr(stun_addr.c_str(), stun_port, + IPPROTO_TCP, &stun_srv_addr); + ASSERT_EQ(0, r); + + socks[0] = Create(TCP_TYPE_PASSIVE, stun_addr, stun_port); + Listen(socks[0]); + socks[1] = Create(TCP_TYPE_ACTIVE, stun_addr, stun_port); + + /* Send a fake STUN request and expect a STUN error response */ + SendData(socks[0], &stun_srv_addr, stunMessage, sizeof(stunMessage)); + RecvData(&stun_srv_addr, socks[0]); + + Connect(socks[1], socks[0]); + TransferData(socks[1], socks[0], data, sizeof(data)); + TransferData(socks[0], socks[1], data, sizeof(data)); +} + +TEST_F(MultiTcpSocketTest, TestConnectTwoSo) { + socks[0] = Create(TCP_TYPE_SO); + socks[1] = Create(TCP_TYPE_SO); + ConnectSo(socks[0], socks[1]); +} + +// test works on localhost only with delay applied: +// tc qdisc add dev lo root netem delay 5ms +TEST_F(MultiTcpSocketTest, DISABLED_TestTwoSoBidirectionalTransmit) { + const char data[] = "TestTwoSoBidirectionalTransmit"; + socks[0] = Create(TCP_TYPE_SO); + socks[1] = Create(TCP_TYPE_SO); + ConnectSo(socks[0], socks[1]); + TransferData(socks[0], socks[1], data, sizeof(data)); + TransferData(socks[1], socks[0], data, sizeof(data)); +} + +TEST_F(MultiTcpSocketTest, TestBigData) { + char buf1[2048]; + char buf2[1024]; + + for (unsigned i = 0; i < sizeof(buf1); ++i) { + buf1[i] = i & 0xff; + } + for (unsigned i = 0; i < sizeof(buf2); ++i) { + buf2[i] = (i + 0x80) & 0xff; + } + socks[0] = Create(TCP_TYPE_ACTIVE); + socks[1] = Create(TCP_TYPE_PASSIVE); + Listen(socks[1]); + Connect(socks[0], socks[1]); + + TransferData(socks[0], socks[1], buf1, sizeof(buf1)); + TransferData(socks[0], socks[1], buf2, sizeof(buf2)); + // opposite dir + SendData(socks[1], socks[0], buf2, sizeof(buf2)); + SendData(socks[1], socks[0], buf1, sizeof(buf1)); + RecvData(socks[1], socks[0], buf2, sizeof(buf2)); + RecvData(socks[1], socks[0], buf1, sizeof(buf1)); +} diff --git a/dom/media/webrtc/transport/test/nrappkit_unittest.cpp b/dom/media/webrtc/transport/test/nrappkit_unittest.cpp new file mode 100644 index 0000000000..b6a63fb993 --- /dev/null +++ b/dom/media/webrtc/transport/test/nrappkit_unittest.cpp @@ -0,0 +1,123 @@ + +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com +#include <iostream> + +// nrappkit includes +extern "C" { +#include "nr_api.h" +#include "async_timer.h" +} + +#include "runnable_utils.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +using namespace mozilla; + +namespace { + +class TimerTest : public MtransportTest { + public: + TimerTest() : MtransportTest(), handle_(nullptr), fired_(false) {} + virtual ~TimerTest() = default; + + int ArmTimer(int timeout) { + int ret; + + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&ret, this, &TimerTest::ArmTimer_w, timeout)); + + return ret; + } + + int ArmCancelTimer(int timeout) { + int ret; + + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&ret, this, &TimerTest::ArmCancelTimer_w, timeout)); + + return ret; + } + + int ArmTimer_w(int timeout) { + return NR_ASYNC_TIMER_SET(timeout, cb, this, &handle_); + } + + int ArmCancelTimer_w(int timeout) { + int r; + r = ArmTimer_w(timeout); + if (r) return r; + + return CancelTimer_w(); + } + + int CancelTimer() { + int ret; + + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&ret, this, &TimerTest::CancelTimer_w)); + + return ret; + } + + int CancelTimer_w() { return NR_async_timer_cancel(handle_); } + + int Schedule() { + int ret; + + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&ret, this, &TimerTest::Schedule_w)); + + return ret; + } + + int Schedule_w() { + NR_ASYNC_SCHEDULE(cb, this); + + return 0; + } + + static void cb(NR_SOCKET r, int how, void* arg) { + std::cerr << "Timer fired " << std::endl; + + TimerTest* t = static_cast<TimerTest*>(arg); + + t->fired_ = true; + } + + protected: + void* handle_; + bool fired_; +}; +} // namespace + +TEST_F(TimerTest, SimpleTimer) { + ArmTimer(100); + ASSERT_TRUE_WAIT(fired_, 1000); +} + +TEST_F(TimerTest, CancelTimer) { + ArmTimer(1000); + CancelTimer(); + PR_Sleep(2000); + ASSERT_FALSE(fired_); +} + +TEST_F(TimerTest, CancelTimer0) { + ArmCancelTimer(0); + PR_Sleep(100); + ASSERT_FALSE(fired_); +} + +TEST_F(TimerTest, ScheduleTest) { + Schedule(); + ASSERT_TRUE_WAIT(fired_, 1000); +} diff --git a/dom/media/webrtc/transport/test/proxy_tunnel_socket_unittest.cpp b/dom/media/webrtc/transport/test/proxy_tunnel_socket_unittest.cpp new file mode 100644 index 0000000000..1b54126dd6 --- /dev/null +++ b/dom/media/webrtc/transport/test/proxy_tunnel_socket_unittest.cpp @@ -0,0 +1,277 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original authors: ekr@rtfm.com; ryan@tokbox.com + +#include <vector> +#include <numeric> + +#include "nr_socket_tcp.h" +#include "WebrtcTCPSocketWrapper.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +using namespace mozilla; + +// update TestReadMultipleSizes if you change this +const std::string kHelloMessage = "HELLO IS IT ME YOU'RE LOOKING FOR?"; + +class NrTcpSocketTest : public MtransportTest { + public: + NrTcpSocketTest() + : mSProxy(nullptr), + nr_socket_(nullptr), + mEmptyArray(0), + mReadChunkSize(0), + mReadChunkSizeIncrement(1), + mReadAllowance(-1), + mConnected(false) {} + + void SetUp() override { + mSProxy = new NrTcpSocket(nullptr); + int r = nr_socket_create_int((void*)mSProxy.get(), mSProxy->vtbl(), + &nr_socket_); + ASSERT_EQ(0, r); + + // fake calling AsyncOpen() due to IPC calls. must be non-null + mSProxy->AssignChannel_DoNotUse(new WebrtcTCPSocketWrapper(nullptr)); + } + + void TearDown() override { mSProxy->close(); } + + static void readable_cb(NR_SOCKET s, int how, void* cb_arg) { + NrTcpSocketTest* test = (NrTcpSocketTest*)cb_arg; + size_t capacity = std::min(test->mReadChunkSize, test->mReadAllowance); + nsTArray<uint8_t> array(capacity); + size_t read; + + nr_socket_read(test->nr_socket_, (char*)array.Elements(), array.Capacity(), + &read, 0); + + ASSERT_TRUE(read <= array.Capacity()); + ASSERT_TRUE(test->mReadAllowance >= read); + + array.SetLength(read); + test->mData.AppendElements(array); + test->mReadAllowance -= read; + + // We may read more bytes each time we're called. This way we can ensure we + // consume buffers partially and across multiple buffers. + test->mReadChunkSize += test->mReadChunkSizeIncrement; + + if (test->mReadAllowance > 0) { + NR_ASYNC_WAIT(s, how, &NrTcpSocketTest::readable_cb, cb_arg); + } + } + + static void writable_cb(NR_SOCKET s, int how, void* cb_arg) { + NrTcpSocketTest* test = (NrTcpSocketTest*)cb_arg; + test->mConnected = true; + } + + const std::string DataString() { + return std::string((char*)mData.Elements(), mData.Length()); + } + + protected: + RefPtr<NrTcpSocket> mSProxy; + nr_socket* nr_socket_; + + nsTArray<uint8_t> mData; + nsTArray<uint8_t> mEmptyArray; + + uint32_t mReadChunkSize; + uint32_t mReadChunkSizeIncrement; + uint32_t mReadAllowance; + + bool mConnected; +}; + +TEST_F(NrTcpSocketTest, TestCreate) {} + +TEST_F(NrTcpSocketTest, TestConnected) { + ASSERT_TRUE(!mConnected); + + NR_ASYNC_WAIT(mSProxy, NR_ASYNC_WAIT_WRITE, &NrTcpSocketTest::writable_cb, + this); + + // still not connected just registered for writes... + ASSERT_TRUE(!mConnected); + + mSProxy->OnConnected("http"_ns); + + ASSERT_TRUE(mConnected); +} + +TEST_F(NrTcpSocketTest, TestRead) { + nsTArray<uint8_t> array; + array.AppendElements(kHelloMessage.c_str(), kHelloMessage.length()); + + NR_ASYNC_WAIT(mSProxy, NR_ASYNC_WAIT_READ, &NrTcpSocketTest::readable_cb, + this); + // this will read 0 bytes here + mSProxy->OnRead(std::move(array)); + + ASSERT_EQ(kHelloMessage.length(), mSProxy->CountUnreadBytes()); + + // callback is still set but terminated due to 0 byte read + // start callbacks again (first read is 0 then 1,2,3,...) + mSProxy->OnRead(std::move(mEmptyArray)); + + ASSERT_EQ(kHelloMessage.length(), mData.Length()); + ASSERT_EQ(kHelloMessage, DataString()); +} + +TEST_F(NrTcpSocketTest, TestReadConstantConsumeSize) { + std::string data; + + // triangle number + const int kCount = 32; + + // ~17kb + // triangle number formula n*(n+1)/2 + for (int i = 0; i < kCount * (kCount + 1) / 2; ++i) { + data += kHelloMessage; + } + + // decreasing buffer sizes + for (int i = 0, start = 0; i < kCount; ++i) { + int length = (kCount - i) * kHelloMessage.length(); + + nsTArray<uint8_t> array; + array.AppendElements(data.c_str() + start, length); + start += length; + + mSProxy->OnRead(std::move(array)); + } + + ASSERT_EQ(data.length(), mSProxy->CountUnreadBytes()); + + // read same amount each callback + mReadChunkSize = 128; + mReadChunkSizeIncrement = 0; + NR_ASYNC_WAIT(mSProxy, NR_ASYNC_WAIT_READ, &NrTcpSocketTest::readable_cb, + this); + + ASSERT_EQ(data.length(), mSProxy->CountUnreadBytes()); + + // start callbacks + mSProxy->OnRead(std::move(mEmptyArray)); + + ASSERT_EQ(data.length(), mData.Length()); + ASSERT_EQ(data, DataString()); +} + +TEST_F(NrTcpSocketTest, TestReadNone) { + char buf[4096]; + size_t read = 0; + int r = nr_socket_read(nr_socket_, buf, sizeof(buf), &read, 0); + + ASSERT_EQ(R_WOULDBLOCK, r); + + nsTArray<uint8_t> array; + array.AppendElements(kHelloMessage.c_str(), kHelloMessage.length()); + mSProxy->OnRead(std::move(array)); + + ASSERT_EQ(kHelloMessage.length(), mSProxy->CountUnreadBytes()); + + r = nr_socket_read(nr_socket_, buf, sizeof(buf), &read, 0); + + ASSERT_EQ(0, r); + ASSERT_EQ(kHelloMessage.length(), read); + ASSERT_EQ(kHelloMessage, std::string(buf, read)); +} + +TEST_F(NrTcpSocketTest, TestReadMultipleSizes) { + using namespace std; + + string data; + // 515 * kHelloMessage.length() == 17510 + const size_t kCount = 515; + // randomly generated numbers, sums to 17510, 20 numbers + vector<int> varyingSizes = {404, 622, 1463, 1597, 1676, 389, 389, + 1272, 781, 81, 1030, 1450, 256, 812, + 1571, 29, 1045, 911, 643, 1089}; + + // changing varyingSizes or the test message breaks this so check here + ASSERT_EQ(kCount, 17510 / kHelloMessage.length()); + ASSERT_EQ(17510, accumulate(varyingSizes.begin(), varyingSizes.end(), 0)); + + // ~17kb + for (size_t i = 0; i < kCount; ++i) { + data += kHelloMessage; + } + + nsTArray<uint8_t> array; + array.AppendElements(data.c_str(), data.length()); + + for (int amountToRead : varyingSizes) { + nsTArray<uint8_t> buffer; + buffer.AppendElements(array.Elements(), amountToRead); + array.RemoveElementsAt(0, amountToRead); + mSProxy->OnRead(std::move(buffer)); + } + + ASSERT_EQ(data.length(), mSProxy->CountUnreadBytes()); + + // don't need to read 0 on the first read, so start at 1 and keep going + mReadChunkSize = 1; + NR_ASYNC_WAIT(mSProxy, NR_ASYNC_WAIT_READ, &NrTcpSocketTest::readable_cb, + this); + // start callbacks + mSProxy->OnRead(std::move(mEmptyArray)); + + ASSERT_EQ(data.length(), mData.Length()); + ASSERT_EQ(data, DataString()); +} + +TEST_F(NrTcpSocketTest, TestReadConsumeReadDrain) { + std::string data; + // ~26kb total; should be even + const int kCount = 512; + + // there's some division by 2 here so check that kCount is even + ASSERT_EQ(0, kCount % 2); + + for (int i = 0; i < kCount; ++i) { + data += kHelloMessage; + nsTArray<uint8_t> array; + array.AppendElements(kHelloMessage.c_str(), kHelloMessage.length()); + mSProxy->OnRead(std::move(array)); + } + + // read half at first + mReadAllowance = kCount / 2 * kHelloMessage.length(); + // start by reading 1 byte + mReadChunkSize = 1; + NR_ASYNC_WAIT(mSProxy, NR_ASYNC_WAIT_READ, &NrTcpSocketTest::readable_cb, + this); + mSProxy->OnRead(std::move(mEmptyArray)); + + ASSERT_EQ(data.length() / 2, mSProxy->CountUnreadBytes()); + ASSERT_EQ(data.length() / 2, mData.Length()); + + // fill read buffer back up + for (int i = 0; i < kCount / 2; ++i) { + data += kHelloMessage; + nsTArray<uint8_t> array; + array.AppendElements(kHelloMessage.c_str(), kHelloMessage.length()); + mSProxy->OnRead(std::move(array)); + } + + // remove read limit + mReadAllowance = -1; + // used entire read allowance so we need to setup a new await + NR_ASYNC_WAIT(mSProxy, NR_ASYNC_WAIT_READ, &NrTcpSocketTest::readable_cb, + this); + // start callbacks + mSProxy->OnRead(std::move(mEmptyArray)); + + ASSERT_EQ(data.length(), mData.Length()); + ASSERT_EQ(data, DataString()); +} diff --git a/dom/media/webrtc/transport/test/rlogconnector_unittest.cpp b/dom/media/webrtc/transport/test/rlogconnector_unittest.cpp new file mode 100644 index 0000000000..93fabae481 --- /dev/null +++ b/dom/media/webrtc/transport/test/rlogconnector_unittest.cpp @@ -0,0 +1,255 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +/* Original author: bcampen@mozilla.com */ + +#include "rlogconnector.h" + +extern "C" { +#include "registry.h" +#include "r_log.h" +} + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +#include <deque> +#include <string> +#include <vector> + +using mozilla::RLogConnector; + +int NR_LOG_TEST = 0; + +class RLogConnectorTest : public ::testing::Test { + public: + RLogConnectorTest() { Init(); } + + ~RLogConnectorTest() { Free(); } + + static void SetUpTestCase() { + NR_reg_init(NR_REG_MODE_LOCAL); + r_log_init(); + /* Would be nice to be able to unregister in the fixture */ + const char* facility = "rlogconnector_test"; + r_log_register(const_cast<char*>(facility), &NR_LOG_TEST); + } + + void Init() { RLogConnector::CreateInstance(); } + + void Free() { RLogConnector::DestroyInstance(); } + + void ReInit() { + Free(); + Init(); + } +}; + +TEST_F(RLogConnectorTest, TestGetFree) { + RLogConnector* instance = RLogConnector::GetInstance(); + ASSERT_NE(nullptr, instance); +} + +TEST_F(RLogConnectorTest, TestFilterEmpty) { + std::deque<std::string> logs; + RLogConnector::GetInstance()->GetAny(0, &logs); + ASSERT_EQ(0U, logs.size()); +} + +TEST_F(RLogConnectorTest, TestBasicFilter) { + r_log(NR_LOG_TEST, LOG_INFO, "Test"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->Filter("Test", 0, &logs); + ASSERT_EQ(1U, logs.size()); +} + +TEST_F(RLogConnectorTest, TestBasicFilterContent) { + r_log(NR_LOG_TEST, LOG_INFO, "Test"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->Filter("Test", 0, &logs); + ASSERT_EQ("Test", logs.back()); +} + +TEST_F(RLogConnectorTest, TestFilterAnyFrontMatch) { + r_log(NR_LOG_TEST, LOG_INFO, "Test"); + std::vector<std::string> substrings; + substrings.push_back("foo"); + substrings.push_back("Test"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->FilterAny(substrings, 0, &logs); + ASSERT_EQ("Test", logs.back()); +} + +TEST_F(RLogConnectorTest, TestFilterAnyBackMatch) { + r_log(NR_LOG_TEST, LOG_INFO, "Test"); + std::vector<std::string> substrings; + substrings.push_back("Test"); + substrings.push_back("foo"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->FilterAny(substrings, 0, &logs); + ASSERT_EQ("Test", logs.back()); +} + +TEST_F(RLogConnectorTest, TestFilterAnyBothMatch) { + r_log(NR_LOG_TEST, LOG_INFO, "Test"); + std::vector<std::string> substrings; + substrings.push_back("Tes"); + substrings.push_back("est"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->FilterAny(substrings, 0, &logs); + ASSERT_EQ("Test", logs.back()); +} + +TEST_F(RLogConnectorTest, TestFilterAnyNeitherMatch) { + r_log(NR_LOG_TEST, LOG_INFO, "Test"); + std::vector<std::string> substrings; + substrings.push_back("tes"); + substrings.push_back("esT"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->FilterAny(substrings, 0, &logs); + ASSERT_EQ(0U, logs.size()); +} + +TEST_F(RLogConnectorTest, TestAllMatch) { + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->GetAny(0, &logs); + ASSERT_EQ(2U, logs.size()); +} + +TEST_F(RLogConnectorTest, TestOrder) { + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->GetAny(0, &logs); + ASSERT_EQ("Test2", logs.back()); + ASSERT_EQ("Test1", logs.front()); +} + +TEST_F(RLogConnectorTest, TestNoMatch) { + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->Filter("foo", 0, &logs); + ASSERT_EQ(0U, logs.size()); +} + +TEST_F(RLogConnectorTest, TestSubstringFilter) { + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->Filter("t1", 0, &logs); + ASSERT_EQ(1U, logs.size()); + ASSERT_EQ("Test1", logs.back()); +} + +TEST_F(RLogConnectorTest, TestFilterLimit) { + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + r_log(NR_LOG_TEST, LOG_INFO, "Test3"); + r_log(NR_LOG_TEST, LOG_INFO, "Test4"); + r_log(NR_LOG_TEST, LOG_INFO, "Test5"); + r_log(NR_LOG_TEST, LOG_INFO, "Test6"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->Filter("Test", 2, &logs); + ASSERT_EQ(2U, logs.size()); + ASSERT_EQ("Test6", logs.back()); + ASSERT_EQ("Test5", logs.front()); +} + +TEST_F(RLogConnectorTest, TestFilterAnyLimit) { + r_log(NR_LOG_TEST, LOG_INFO, "TestOne"); + r_log(NR_LOG_TEST, LOG_INFO, "TestTwo"); + r_log(NR_LOG_TEST, LOG_INFO, "TestThree"); + r_log(NR_LOG_TEST, LOG_INFO, "TestFour"); + r_log(NR_LOG_TEST, LOG_INFO, "TestFive"); + r_log(NR_LOG_TEST, LOG_INFO, "TestSix"); + std::vector<std::string> substrings; + // Matches Two, Three, Four, and Six + substrings.push_back("tT"); + substrings.push_back("o"); + substrings.push_back("r"); + substrings.push_back("S"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->FilterAny(substrings, 2, &logs); + ASSERT_EQ(2U, logs.size()); + ASSERT_EQ("TestSix", logs.back()); + ASSERT_EQ("TestFour", logs.front()); +} + +TEST_F(RLogConnectorTest, TestLimit) { + RLogConnector::GetInstance()->SetLogLimit(3); + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + r_log(NR_LOG_TEST, LOG_INFO, "Test3"); + r_log(NR_LOG_TEST, LOG_INFO, "Test4"); + r_log(NR_LOG_TEST, LOG_INFO, "Test5"); + r_log(NR_LOG_TEST, LOG_INFO, "Test6"); + std::deque<std::string> logs; + RLogConnector::GetInstance()->GetAny(0, &logs); + ASSERT_EQ(3U, logs.size()); + ASSERT_EQ("Test6", logs.back()); + ASSERT_EQ("Test4", logs.front()); +} + +TEST_F(RLogConnectorTest, TestLimitBulkDiscard) { + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + r_log(NR_LOG_TEST, LOG_INFO, "Test3"); + r_log(NR_LOG_TEST, LOG_INFO, "Test4"); + r_log(NR_LOG_TEST, LOG_INFO, "Test5"); + r_log(NR_LOG_TEST, LOG_INFO, "Test6"); + RLogConnector::GetInstance()->SetLogLimit(3); + std::deque<std::string> logs; + RLogConnector::GetInstance()->GetAny(0, &logs); + ASSERT_EQ(3U, logs.size()); + ASSERT_EQ("Test6", logs.back()); + ASSERT_EQ("Test4", logs.front()); +} + +TEST_F(RLogConnectorTest, TestIncreaseLimit) { + RLogConnector::GetInstance()->SetLogLimit(3); + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + r_log(NR_LOG_TEST, LOG_INFO, "Test3"); + r_log(NR_LOG_TEST, LOG_INFO, "Test4"); + r_log(NR_LOG_TEST, LOG_INFO, "Test5"); + r_log(NR_LOG_TEST, LOG_INFO, "Test6"); + RLogConnector::GetInstance()->SetLogLimit(300); + std::deque<std::string> logs; + RLogConnector::GetInstance()->GetAny(0, &logs); + ASSERT_EQ(3U, logs.size()); + ASSERT_EQ("Test6", logs.back()); + ASSERT_EQ("Test4", logs.front()); +} + +TEST_F(RLogConnectorTest, TestClear) { + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + r_log(NR_LOG_TEST, LOG_INFO, "Test3"); + r_log(NR_LOG_TEST, LOG_INFO, "Test4"); + r_log(NR_LOG_TEST, LOG_INFO, "Test5"); + r_log(NR_LOG_TEST, LOG_INFO, "Test6"); + RLogConnector::GetInstance()->SetLogLimit(0); + RLogConnector::GetInstance()->SetLogLimit(4096); + std::deque<std::string> logs; + RLogConnector::GetInstance()->GetAny(0, &logs); + ASSERT_EQ(0U, logs.size()); +} + +TEST_F(RLogConnectorTest, TestReInit) { + r_log(NR_LOG_TEST, LOG_INFO, "Test1"); + r_log(NR_LOG_TEST, LOG_INFO, "Test2"); + r_log(NR_LOG_TEST, LOG_INFO, "Test3"); + r_log(NR_LOG_TEST, LOG_INFO, "Test4"); + r_log(NR_LOG_TEST, LOG_INFO, "Test5"); + r_log(NR_LOG_TEST, LOG_INFO, "Test6"); + ReInit(); + std::deque<std::string> logs; + RLogConnector::GetInstance()->GetAny(0, &logs); + ASSERT_EQ(0U, logs.size()); +} diff --git a/dom/media/webrtc/transport/test/runnable_utils_unittest.cpp b/dom/media/webrtc/transport/test/runnable_utils_unittest.cpp new file mode 100644 index 0000000000..70707b148f --- /dev/null +++ b/dom/media/webrtc/transport/test/runnable_utils_unittest.cpp @@ -0,0 +1,353 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com +#include <iostream> + +#include "nsCOMPtr.h" +#include "nsNetCID.h" + +#include "mozilla/RefPtr.h" +#include "mozilla/UniquePtr.h" + +#include "nsServiceManagerUtils.h" +#include "nsThreadUtils.h" + +#include "runnable_utils.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +using namespace mozilla; + +namespace { + +// Helper used to make sure args are properly copied and/or moved. +struct CtorDtorState { + enum class State { Empty, Copy, Explicit, Move, Moved }; + + static const char* ToStr(const State& state) { + switch (state) { + case State::Empty: + return "empty"; + case State::Copy: + return "copy"; + case State::Explicit: + return "explicit"; + case State::Move: + return "move"; + case State::Moved: + return "moved"; + default: + return "unknown"; + } + } + + void DumpState() const { std::cerr << ToStr(state_) << std::endl; } + + CtorDtorState() { DumpState(); } + + explicit CtorDtorState(int* destroyed) + : dtor_count_(destroyed), state_(State::Explicit) { + DumpState(); + } + + CtorDtorState(const CtorDtorState& other) + : dtor_count_(other.dtor_count_), state_(State::Copy) { + DumpState(); + } + + // Clear the other's dtor counter so it's not counted if moved. + CtorDtorState(CtorDtorState&& other) + : dtor_count_(std::exchange(other.dtor_count_, nullptr)), + state_(State::Move) { + other.state_ = State::Moved; + DumpState(); + } + + ~CtorDtorState() { + const char* const state = ToStr(state_); + std::cerr << "Destructor called with end state: " << state << std::endl; + + if (dtor_count_) { + ++*dtor_count_; + } + } + + int* dtor_count_ = nullptr; + State state_ = State::Empty; +}; + +class Destructor { + private: + ~Destructor() { + std::cerr << "Destructor called" << std::endl; + *destroyed_ = true; + } + + public: + explicit Destructor(bool* destroyed) : destroyed_(destroyed) {} + + NS_INLINE_DECL_THREADSAFE_REFCOUNTING(Destructor) + + private: + bool* destroyed_; +}; + +class TargetClass { + public: + explicit TargetClass(int* ran) : ran_(ran) {} + + void m1(int x) { + std::cerr << __FUNCTION__ << " " << x << std::endl; + *ran_ = 1; + } + + void m2(int x, int y) { + std::cerr << __FUNCTION__ << " " << x << " " << y << std::endl; + *ran_ = 2; + } + + void m1set(bool* z) { + std::cerr << __FUNCTION__ << std::endl; + *z = true; + } + int return_int(int x) { + std::cerr << __FUNCTION__ << std::endl; + return x; + } + + void destructor_target_ref(RefPtr<Destructor> destructor) {} + + int* ran_; +}; + +class RunnableArgsTest : public MtransportTest { + public: + RunnableArgsTest() : MtransportTest(), ran_(0), cl_(&ran_) {} + + void Test1Arg() { + Runnable* r = WrapRunnable(&cl_, &TargetClass::m1, 1); + r->Run(); + ASSERT_EQ(1, ran_); + } + + void Test2Args() { + Runnable* r = WrapRunnable(&cl_, &TargetClass::m2, 1, 2); + r->Run(); + ASSERT_EQ(2, ran_); + } + + private: + int ran_; + TargetClass cl_; +}; + +class DispatchTest : public MtransportTest { + public: + DispatchTest() : MtransportTest(), ran_(0), cl_(&ran_) {} + + void SetUp() { + MtransportTest::SetUp(); + + nsresult rv; + target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + } + + void Test1Arg() { + Runnable* r = WrapRunnable(&cl_, &TargetClass::m1, 1); + NS_DispatchAndSpinEventLoopUntilComplete("DispatchTest::Test1Arg"_ns, + target_, do_AddRef(r)); + ASSERT_EQ(1, ran_); + } + + void Test2Args() { + Runnable* r = WrapRunnable(&cl_, &TargetClass::m2, 1, 2); + NS_DispatchAndSpinEventLoopUntilComplete("DispatchTest::Test2Args"_ns, + target_, do_AddRef(r)); + ASSERT_EQ(2, ran_); + } + + void Test1Set() { + bool x = false; + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::Test1Set"_ns, target_, + do_AddRef(WrapRunnable(&cl_, &TargetClass::m1set, &x))); + ASSERT_TRUE(x); + } + + void TestRet() { + int z; + int x = 10; + + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::TestRet"_ns, target_, + do_AddRef(WrapRunnableRet(&z, &cl_, &TargetClass::return_int, x))); + ASSERT_EQ(10, z); + } + + protected: + int ran_; + TargetClass cl_; + nsCOMPtr<nsIEventTarget> target_; +}; + +TEST_F(RunnableArgsTest, OneArgument) { Test1Arg(); } + +TEST_F(RunnableArgsTest, TwoArguments) { Test2Args(); } + +TEST_F(DispatchTest, OneArgument) { Test1Arg(); } + +TEST_F(DispatchTest, TwoArguments) { Test2Args(); } + +TEST_F(DispatchTest, Test1Set) { Test1Set(); } + +TEST_F(DispatchTest, TestRet) { TestRet(); } + +void SetNonMethod(TargetClass* cl, int x) { cl->m1(x); } + +int SetNonMethodRet(TargetClass* cl, int x) { + cl->m1(x); + + return x; +} + +TEST_F(DispatchTest, TestNonMethod) { + test_utils_->SyncDispatchToSTS(WrapRunnableNM(SetNonMethod, &cl_, 10)); + + ASSERT_EQ(1, ran_); +} + +TEST_F(DispatchTest, TestNonMethodRet) { + int z; + + test_utils_->SyncDispatchToSTS( + WrapRunnableNMRet(&z, SetNonMethodRet, &cl_, 10)); + + ASSERT_EQ(1, ran_); + ASSERT_EQ(10, z); +} + +TEST_F(DispatchTest, TestDestructorRef) { + bool destroyed = false; + { + RefPtr<Destructor> destructor = new Destructor(&destroyed); + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::TestDestructorRef"_ns, target_, + do_AddRef(WrapRunnable(&cl_, &TargetClass::destructor_target_ref, + destructor))); + ASSERT_FALSE(destroyed); + } + ASSERT_TRUE(destroyed); + + // Now try with a move. + destroyed = false; + { + RefPtr<Destructor> destructor = new Destructor(&destroyed); + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::TestDestructorRef"_ns, target_, + do_AddRef(WrapRunnable(&cl_, &TargetClass::destructor_target_ref, + std::move(destructor)))); + ASSERT_TRUE(destroyed); + } +} + +TEST_F(DispatchTest, TestMove) { + int destroyed = 0; + { + CtorDtorState state(&destroyed); + + // Dispatch with: + // - moved arg + // - by-val capture in function should consume a move + // - expect destruction in the function scope + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::TestMove"_ns, target_, + do_AddRef(WrapRunnableNM([](CtorDtorState s) {}, std::move(state)))); + ASSERT_EQ(1, destroyed); + } + // Still shouldn't count when we go out of scope as it was moved. + ASSERT_EQ(1, destroyed); + + { + CtorDtorState state(&destroyed); + + // Dispatch with: + // - copied arg + // - by-val capture in function should consume a move + // - expect destruction in the function scope and call scope + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::TestMove"_ns, target_, + do_AddRef(WrapRunnableNM([](CtorDtorState s) {}, state))); + ASSERT_EQ(2, destroyed); + } + // Original state should be destroyed + ASSERT_EQ(3, destroyed); + + { + CtorDtorState state(&destroyed); + + // Dispatch with: + // - moved arg + // - by-ref in function should accept a moved arg + // - expect destruction in the wrapper invocation scope + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::TestMove"_ns, target_, + do_AddRef( + WrapRunnableNM([](const CtorDtorState& s) {}, std::move(state)))); + ASSERT_EQ(4, destroyed); + } + // Still shouldn't count when we go out of scope as it was moved. + ASSERT_EQ(4, destroyed); + + { + CtorDtorState state(&destroyed); + + // Dispatch with: + // - moved arg + // - r-value function should accept a moved arg + // - expect destruction in the wrapper invocation scope + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::TestMove"_ns, target_, + do_AddRef(WrapRunnableNM([](CtorDtorState&& s) {}, std::move(state)))); + ASSERT_EQ(5, destroyed); + } + // Still shouldn't count when we go out of scope as it was moved. + ASSERT_EQ(5, destroyed); +} + +TEST_F(DispatchTest, TestUniquePtr) { + // Test that holding the class in UniquePtr works + int ran = 0; + auto cl = MakeUnique<TargetClass>(&ran); + + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::TestUniquePtr"_ns, target_, + do_AddRef(WrapRunnable(std::move(cl), &TargetClass::m1, 1))); + ASSERT_EQ(1, ran); + + // Test that UniquePtr works as a param to the runnable + int destroyed = 0; + { + auto state = MakeUnique<CtorDtorState>(&destroyed); + + // Dispatch with: + // - moved arg + // - Function should move construct from arg + // - expect destruction in the wrapper invocation scope + NS_DispatchAndSpinEventLoopUntilComplete( + "DispatchTest::TestUniquePtr"_ns, target_, + do_AddRef(WrapRunnableNM([](UniquePtr<CtorDtorState> s) {}, + std::move(state)))); + ASSERT_EQ(1, destroyed); + } + // Still shouldn't count when we go out of scope as it was moved. + ASSERT_EQ(1, destroyed); +} + +} // end of namespace diff --git a/dom/media/webrtc/transport/test/sctp_unittest.cpp b/dom/media/webrtc/transport/test/sctp_unittest.cpp new file mode 100644 index 0000000000..ea32565fb2 --- /dev/null +++ b/dom/media/webrtc/transport/test/sctp_unittest.cpp @@ -0,0 +1,381 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com + +#include <iostream> +#include <string> + +#include "sigslot.h" + +#include "nsITimer.h" + +#include "transportflow.h" +#include "transportlayer.h" +#include "transportlayerloopback.h" + +#include "runnable_utils.h" +#include "usrsctp.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +using namespace mozilla; + +static bool sctp_logging = false; +static int port_number = 5000; + +namespace { + +class TransportTestPeer; + +class SendPeriodic : public nsITimerCallback, public nsINamed { + public: + SendPeriodic(TransportTestPeer* peer, int to_send) + : peer_(peer), to_send_(to_send) {} + + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSITIMERCALLBACK + NS_DECL_NSINAMED + + protected: + virtual ~SendPeriodic() = default; + + TransportTestPeer* peer_; + int to_send_; +}; + +NS_IMPL_ISUPPORTS(SendPeriodic, nsITimerCallback, nsINamed) + +class TransportTestPeer : public sigslot::has_slots<> { + public: + TransportTestPeer(std::string name, int local_port, int remote_port, + MtransportTestUtils* utils) + : name_(name), + connected_(false), + sent_(0), + received_(0), + flow_(new TransportFlow()), + loopback_(new TransportLayerLoopback()), + sctp_(usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, receive_cb, + nullptr, 0, nullptr)), + timer_(NS_NewTimer()), + periodic_(nullptr), + test_utils_(utils) { + std::cerr << "Creating TransportTestPeer; flow=" + << static_cast<void*>(flow_.get()) << " local=" << local_port + << " remote=" << remote_port << std::endl; + + usrsctp_register_address(static_cast<void*>(this)); + int r = usrsctp_set_non_blocking(sctp_, 1); + EXPECT_GE(r, 0); + + struct linger l; + l.l_onoff = 1; + l.l_linger = 0; + r = usrsctp_setsockopt(sctp_, SOL_SOCKET, SO_LINGER, &l, + (socklen_t)sizeof(l)); + EXPECT_GE(r, 0); + + struct sctp_event subscription; + memset(&subscription, 0, sizeof(subscription)); + subscription.se_assoc_id = SCTP_ALL_ASSOC; + subscription.se_on = 1; + subscription.se_type = SCTP_ASSOC_CHANGE; + r = usrsctp_setsockopt(sctp_, IPPROTO_SCTP, SCTP_EVENT, &subscription, + sizeof(subscription)); + EXPECT_GE(r, 0); + + memset(&local_addr_, 0, sizeof(local_addr_)); + local_addr_.sconn_family = AF_CONN; +#if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && \ + !defined(__Userspace_os_Android) + local_addr_.sconn_len = sizeof(struct sockaddr_conn); +#endif + local_addr_.sconn_port = htons(local_port); + local_addr_.sconn_addr = static_cast<void*>(this); + + memset(&remote_addr_, 0, sizeof(remote_addr_)); + remote_addr_.sconn_family = AF_CONN; +#if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && \ + !defined(__Userspace_os_Android) + remote_addr_.sconn_len = sizeof(struct sockaddr_conn); +#endif + remote_addr_.sconn_port = htons(remote_port); + remote_addr_.sconn_addr = static_cast<void*>(this); + + nsresult res; + res = loopback_->Init(); + EXPECT_EQ((nsresult)NS_OK, res); + } + + ~TransportTestPeer() { + std::cerr << "Destroying sctp connection flow=" + << static_cast<void*>(flow_.get()) << std::endl; + usrsctp_close(sctp_); + usrsctp_deregister_address(static_cast<void*>(this)); + + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &TransportTestPeer::Disconnect_s)); + + std::cerr << "~TransportTestPeer() completed" << std::endl; + } + + void ConnectSocket(TransportTestPeer* peer) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &TransportTestPeer::ConnectSocket_s, peer)); + } + + void ConnectSocket_s(TransportTestPeer* peer) { + loopback_->Connect(peer->loopback_); + ASSERT_EQ((nsresult)NS_OK, loopback_->Init()); + flow_->PushLayer(loopback_); + + loopback_->SignalPacketReceived.connect(this, + &TransportTestPeer::PacketReceived); + + // SCTP here! + ASSERT_TRUE(sctp_); + std::cerr << "Calling usrsctp_bind()" << std::endl; + int r = + usrsctp_bind(sctp_, reinterpret_cast<struct sockaddr*>(&local_addr_), + sizeof(local_addr_)); + ASSERT_GE(0, r); + + std::cerr << "Calling usrsctp_connect()" << std::endl; + r = usrsctp_connect(sctp_, + reinterpret_cast<struct sockaddr*>(&remote_addr_), + sizeof(remote_addr_)); + ASSERT_GE(0, r); + } + + void Disconnect_s() { + disconnect_all(); + if (flow_) { + flow_ = nullptr; + } + } + + void Disconnect() { loopback_->Disconnect(); } + + void StartTransfer(size_t to_send) { + periodic_ = new SendPeriodic(this, to_send); + timer_->SetTarget(test_utils_->sts_target()); + timer_->InitWithCallback(periodic_, 10, nsITimer::TYPE_REPEATING_SLACK); + } + + void SendOne() { + unsigned char buf[100]; + memset(buf, sent_ & 0xff, sizeof(buf)); + + struct sctp_sndinfo info; + info.snd_sid = 1; + info.snd_flags = 0; + info.snd_ppid = 50; // What the heck is this? + info.snd_context = 0; + info.snd_assoc_id = 0; + + int r = usrsctp_sendv(sctp_, buf, sizeof(buf), nullptr, 0, + static_cast<void*>(&info), sizeof(info), + SCTP_SENDV_SNDINFO, 0); + ASSERT_TRUE(r >= 0); + ASSERT_EQ(sizeof(buf), (size_t)r); + + ++sent_; + } + + int sent() const { return sent_; } + int received() const { return received_; } + bool connected() const { return connected_; } + + static TransportResult SendPacket_s(UniquePtr<MediaPacket> packet, + const RefPtr<TransportFlow>& flow, + TransportLayer* layer) { + return layer->SendPacket(*packet); + } + + TransportResult SendPacket(const unsigned char* data, size_t len) { + UniquePtr<MediaPacket> packet(new MediaPacket); + packet->Copy(data, len); + + // Uses DISPATCH_NORMAL to avoid possible deadlocks when we're called + // from MainThread especially during shutdown (same as DataChannels). + // RUN_ON_THREAD short-circuits if already on the STS thread, which is + // normal for most transfers outside of connect() and close(). Passes + // a refptr to flow_ to avoid any async deletion issues (since we can't + // make 'this' into a refptr as it isn't refcounted) + RUN_ON_THREAD(test_utils_->sts_target(), + WrapRunnableNM(&TransportTestPeer::SendPacket_s, + std::move(packet), flow_, loopback_), + NS_DISPATCH_NORMAL); + + return 0; + } + + void PacketReceived(TransportLayer* layer, MediaPacket& packet) { + std::cerr << "Received " << packet.len() << " bytes" << std::endl; + + // Pass the data to SCTP + + usrsctp_conninput(static_cast<void*>(this), packet.data(), packet.len(), 0); + } + + // Process SCTP notification + void Notification(union sctp_notification* msg, size_t len) { + ASSERT_EQ(msg->sn_header.sn_length, len); + + if (msg->sn_header.sn_type == SCTP_ASSOC_CHANGE) { + struct sctp_assoc_change* change = &msg->sn_assoc_change; + + if (change->sac_state == SCTP_COMM_UP) { + std::cerr << "Connection up" << std::endl; + SetConnected(true); + } else { + std::cerr << "Connection down" << std::endl; + SetConnected(false); + } + } + } + + void SetConnected(bool state) { connected_ = state; } + + static int conn_output(void* addr, void* buffer, size_t length, uint8_t tos, + uint8_t set_df) { + TransportTestPeer* peer = static_cast<TransportTestPeer*>(addr); + + peer->SendPacket(static_cast<unsigned char*>(buffer), length); + + return 0; + } + + static int receive_cb(struct socket* sock, union sctp_sockstore addr, + void* data, size_t datalen, struct sctp_rcvinfo rcv, + int flags, void* ulp_info) { + TransportTestPeer* me = + static_cast<TransportTestPeer*>(addr.sconn.sconn_addr); + MOZ_ASSERT(me); + + if (flags & MSG_NOTIFICATION) { + union sctp_notification* notif = + static_cast<union sctp_notification*>(data); + + me->Notification(notif, datalen); + return 0; + } + + me->received_ += datalen; + + std::cerr << "receive_cb: sock " << sock << " data " << data << "(" + << datalen << ") total received bytes = " << me->received_ + << std::endl; + + return 0; + } + + private: + std::string name_; + bool connected_; + size_t sent_; + size_t received_; + // Owns the TransportLayerLoopback, but basically does nothing else. + RefPtr<TransportFlow> flow_; + TransportLayerLoopback* loopback_; + + struct sockaddr_conn local_addr_; + struct sockaddr_conn remote_addr_; + struct socket* sctp_; + nsCOMPtr<nsITimer> timer_; + RefPtr<SendPeriodic> periodic_; + MtransportTestUtils* test_utils_; +}; + +// Implemented here because it calls a method of TransportTestPeer +NS_IMETHODIMP SendPeriodic::Notify(nsITimer* timer) { + peer_->SendOne(); + --to_send_; + if (!to_send_) { + timer->Cancel(); + } + return NS_OK; +} + +NS_IMETHODIMP +SendPeriodic::GetName(nsACString& aName) { + aName.AssignLiteral("SendPeriodic"); + return NS_OK; +} + +class SctpTransportTest : public MtransportTest { + public: + SctpTransportTest() = default; + + ~SctpTransportTest() = default; + + static void debug_printf(const char* format, ...) { + va_list ap; + + va_start(ap, format); + vprintf(format, ap); + va_end(ap); + } + + static void SetUpTestCase() { + if (sctp_logging) { + usrsctp_init(0, &TransportTestPeer::conn_output, debug_printf); + usrsctp_sysctl_set_sctp_debug_on(0xffffffff); + } else { + usrsctp_init(0, &TransportTestPeer::conn_output, nullptr); + } + } + + void TearDown() override { + if (p1_) p1_->Disconnect(); + if (p2_) p2_->Disconnect(); + delete p1_; + delete p2_; + + MtransportTest::TearDown(); + } + + void ConnectSocket(int p1port = 0, int p2port = 0) { + if (!p1port) p1port = port_number++; + if (!p2port) p2port = port_number++; + + p1_ = new TransportTestPeer("P1", p1port, p2port, test_utils_); + p2_ = new TransportTestPeer("P2", p2port, p1port, test_utils_); + + p1_->ConnectSocket(p2_); + p2_->ConnectSocket(p1_); + ASSERT_TRUE_WAIT(p1_->connected(), 2000); + ASSERT_TRUE_WAIT(p2_->connected(), 2000); + } + + void TestTransfer(int expected = 1) { + std::cerr << "Starting trasnsfer test" << std::endl; + p1_->StartTransfer(expected); + ASSERT_TRUE_WAIT(p1_->sent() == expected, 10000); + ASSERT_TRUE_WAIT(p2_->received() == (expected * 100), 10000); + std::cerr << "P2 received " << p2_->received() << std::endl; + } + + protected: + TransportTestPeer* p1_ = nullptr; + TransportTestPeer* p2_ = nullptr; +}; + +TEST_F(SctpTransportTest, TestConnect) { ConnectSocket(); } + +TEST_F(SctpTransportTest, TestConnectSymmetricalPorts) { + ConnectSocket(5002, 5002); +} + +TEST_F(SctpTransportTest, TestTransfer) { + ConnectSocket(); + TestTransfer(50); +} + +} // end namespace diff --git a/dom/media/webrtc/transport/test/simpletokenbucket_unittest.cpp b/dom/media/webrtc/transport/test/simpletokenbucket_unittest.cpp new file mode 100644 index 0000000000..66622d795b --- /dev/null +++ b/dom/media/webrtc/transport/test/simpletokenbucket_unittest.cpp @@ -0,0 +1,114 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +/* Original author: bcampen@mozilla.com */ + +#include "simpletokenbucket.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +using mozilla::SimpleTokenBucket; + +class TestSimpleTokenBucket : public SimpleTokenBucket { + public: + TestSimpleTokenBucket(size_t bucketSize, size_t tokensPerSecond) + : SimpleTokenBucket(bucketSize, tokensPerSecond) {} + + void fastForward(int32_t timeMilliSeconds) { + if (timeMilliSeconds >= 0) { + last_time_tokens_added_ -= PR_MillisecondsToInterval(timeMilliSeconds); + } else { + last_time_tokens_added_ += PR_MillisecondsToInterval(-timeMilliSeconds); + } + } +}; + +TEST(SimpleTokenBucketTest, TestConstruct) +{ TestSimpleTokenBucket b(10, 1); } + +TEST(SimpleTokenBucketTest, TestGet) +{ + TestSimpleTokenBucket b(10, 1); + ASSERT_EQ(5U, b.getTokens(5)); +} + +TEST(SimpleTokenBucketTest, TestGetAll) +{ + TestSimpleTokenBucket b(10, 1); + ASSERT_EQ(10U, b.getTokens(10)); +} + +TEST(SimpleTokenBucketTest, TestGetInsufficient) +{ + TestSimpleTokenBucket b(10, 1); + ASSERT_EQ(5U, b.getTokens(5)); + ASSERT_EQ(5U, b.getTokens(6)); +} + +TEST(SimpleTokenBucketTest, TestGetBucketCount) +{ + TestSimpleTokenBucket b(10, 1); + ASSERT_EQ(10U, b.getTokens(UINT32_MAX)); + ASSERT_EQ(5U, b.getTokens(5)); + ASSERT_EQ(5U, b.getTokens(UINT32_MAX)); +} + +TEST(SimpleTokenBucketTest, TestTokenRefill) +{ + TestSimpleTokenBucket b(10, 1); + ASSERT_EQ(5U, b.getTokens(5)); + b.fastForward(1000); + ASSERT_EQ(6U, b.getTokens(6)); +} + +TEST(SimpleTokenBucketTest, TestNoTimeWasted) +{ + // Makes sure that when the time elapsed is insufficient to add any + // tokens to the bucket, the internal timestamp that is used in this + // calculation is not updated (ie; two subsequent 0.5 second elapsed times + // counts as a full second) + TestSimpleTokenBucket b(10, 1); + ASSERT_EQ(5U, b.getTokens(5)); + b.fastForward(500); + ASSERT_EQ(5U, b.getTokens(6)); + b.fastForward(500); + ASSERT_EQ(6U, b.getTokens(6)); +} + +TEST(SimpleTokenBucketTest, TestNegativeTime) +{ + TestSimpleTokenBucket b(10, 1); + b.fastForward(-1000); + // Make sure we don't end up with an invalid number of tokens, but otherwise + // permit anything. + ASSERT_GT(11U, b.getTokens(100)); +} + +TEST(SimpleTokenBucketTest, TestEmptyBucket) +{ + TestSimpleTokenBucket b(10, 1); + ASSERT_EQ(10U, b.getTokens(10)); + ASSERT_EQ(0U, b.getTokens(10)); +} + +TEST(SimpleTokenBucketTest, TestEmptyThenFillBucket) +{ + TestSimpleTokenBucket b(10, 1); + ASSERT_EQ(10U, b.getTokens(10)); + ASSERT_EQ(0U, b.getTokens(1)); + b.fastForward(50000); + ASSERT_EQ(10U, b.getTokens(10)); +} + +TEST(SimpleTokenBucketTest, TestNoOverflow) +{ + TestSimpleTokenBucket b(10, 1); + ASSERT_EQ(10U, b.getTokens(10)); + ASSERT_EQ(0U, b.getTokens(1)); + b.fastForward(50000); + ASSERT_EQ(10U, b.getTokens(11)); +} diff --git a/dom/media/webrtc/transport/test/sockettransportservice_unittest.cpp b/dom/media/webrtc/transport/test/sockettransportservice_unittest.cpp new file mode 100644 index 0000000000..ffa87fe91f --- /dev/null +++ b/dom/media/webrtc/transport/test/sockettransportservice_unittest.cpp @@ -0,0 +1,181 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com +#include <iostream> + +#include "prio.h" + +#include "nsCOMPtr.h" +#include "nsNetCID.h" + +#include "nsISocketTransportService.h" + +#include "nsASocketHandler.h" +#include "nsServiceManagerUtils.h" +#include "nsThreadUtils.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +using namespace mozilla; + +namespace { +class SocketTransportServiceTest : public MtransportTest { + public: + SocketTransportServiceTest() + : MtransportTest(), + received_(0), + readpipe_(nullptr), + writepipe_(nullptr), + registered_(false) {} + + ~SocketTransportServiceTest() { + if (readpipe_) PR_Close(readpipe_); + if (writepipe_) PR_Close(writepipe_); + } + + void SetUp(); + void RegisterHandler(); + void SendEvent(); + void SendPacket(); + + void ReceivePacket() { ++received_; } + + void ReceiveEvent() { ++received_; } + + size_t Received() { return received_; } + + private: + nsCOMPtr<nsISocketTransportService> stservice_; + nsCOMPtr<nsIEventTarget> target_; + size_t received_; + PRFileDesc* readpipe_; + PRFileDesc* writepipe_; + bool registered_; +}; + +// Received an event. +class EventReceived : public Runnable { + public: + explicit EventReceived(SocketTransportServiceTest* test) + : Runnable("EventReceived"), test_(test) {} + + NS_IMETHOD Run() override { + test_->ReceiveEvent(); + return NS_OK; + } + + SocketTransportServiceTest* test_; +}; + +// Register our listener on the socket +class RegisterEvent : public Runnable { + public: + explicit RegisterEvent(SocketTransportServiceTest* test) + : Runnable("RegisterEvent"), test_(test) {} + + NS_IMETHOD Run() override { + test_->RegisterHandler(); + return NS_OK; + } + + SocketTransportServiceTest* test_; +}; + +class SocketHandler : public nsASocketHandler { + public: + explicit SocketHandler(SocketTransportServiceTest* test) : test_(test) {} + + void OnSocketReady(PRFileDesc* fd, int16_t outflags) override { + unsigned char buf[1600]; + + int32_t rv; + rv = PR_Recv(fd, buf, sizeof(buf), 0, PR_INTERVAL_NO_WAIT); + if (rv > 0) { + std::cerr << "Read " << rv << " bytes" << std::endl; + test_->ReceivePacket(); + } + } + + void OnSocketDetached(PRFileDesc* fd) override {} + + void IsLocal(bool* aIsLocal) override { + // TODO(jesup): better check? Does it matter? (likely no) + *aIsLocal = false; + } + + virtual uint64_t ByteCountSent() override { return 0; } + virtual uint64_t ByteCountReceived() override { return 0; } + + NS_DECL_ISUPPORTS + + protected: + virtual ~SocketHandler() = default; + + private: + SocketTransportServiceTest* test_; +}; + +NS_IMPL_ISUPPORTS0(SocketHandler) + +void SocketTransportServiceTest::SetUp() { + MtransportTest::SetUp(); + + // Get the transport service as a dispatch target + nsresult rv; + target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + + // Get the transport service as a transport service + stservice_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + + // Create a loopback pipe + PRStatus status = PR_CreatePipe(&readpipe_, &writepipe_); + ASSERT_EQ(status, PR_SUCCESS); + + // Register ourselves as a listener for the read side of the + // socket. The registration has to happen on the STS thread, + // hence this event stuff. + rv = target_->Dispatch(new RegisterEvent(this), 0); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + ASSERT_TRUE_WAIT(registered_, 10000); +} + +void SocketTransportServiceTest::RegisterHandler() { + nsresult rv; + + rv = stservice_->AttachSocket(readpipe_, new SocketHandler(this)); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + + registered_ = true; +} + +void SocketTransportServiceTest::SendEvent() { + nsresult rv; + + rv = target_->Dispatch(new EventReceived(this), 0); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + ASSERT_TRUE_WAIT(Received() == 1, 10000); +} + +void SocketTransportServiceTest::SendPacket() { + unsigned char buffer[1024]; + memset(buffer, 0, sizeof(buffer)); + + int32_t status = PR_Write(writepipe_, buffer, sizeof(buffer)); + uint32_t size = status & 0xffff; + ASSERT_EQ(sizeof(buffer), size); +} + +// The unit tests themselves +TEST_F(SocketTransportServiceTest, SendEvent) { SendEvent(); } + +TEST_F(SocketTransportServiceTest, SendPacket) { SendPacket(); } + +} // end namespace diff --git a/dom/media/webrtc/transport/test/stunserver.cpp b/dom/media/webrtc/transport/test/stunserver.cpp new file mode 100644 index 0000000000..b5fce21e19 --- /dev/null +++ b/dom/media/webrtc/transport/test/stunserver.cpp @@ -0,0 +1,652 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com + +/* +Original code from nICEr and nrappkit. + +nICEr copyright: + +Copyright (c) 2007, Adobe Systems, Incorporated +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of Adobe Systems, Network Resonance nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +nrappkit copyright: + + Copyright (C) 2001-2003, Network Resonance, Inc. + Copyright (C) 2006, Network Resonance, Inc. + All Rights Reserved + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + 3. Neither the name of Network Resonance, Inc. nor the name of any + contributors to this software may be used to endorse or promote + products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + POSSIBILITY OF SUCH DAMAGE. + + + ekr@rtfm.com Thu Dec 20 20:14:49 2001 +*/ +#include "logging.h" +#include "mozilla/UniquePtr.h" +#include "mozilla/Unused.h" +#include "mediapacket.h" + +// mozilla/utils.h defines this as well +#ifdef UNIMPLEMENTED +# undef UNIMPLEMENTED +#endif + +extern "C" { +#include "nr_api.h" +#include "async_wait.h" +#include "async_timer.h" +#include "nr_socket.h" +#include "nr_socket_local.h" +#include "transport_addr.h" +#include "stun_util.h" +#include "registry.h" +#include "nr_socket_buffered_stun.h" +} + +#include "stunserver.h" + +#include <string> + +MOZ_MTLOG_MODULE("stunserver"); + +namespace mozilla { + +// Wrapper nr_socket which allows us to lie to the stun server about the +// IP address. +struct nr_socket_wrapped { + nr_socket* sock_; + nr_transport_addr addr_; +}; + +static int nr_socket_wrapped_destroy(void** objp) { + if (!objp || !*objp) return 0; + + nr_socket_wrapped* wrapped = static_cast<nr_socket_wrapped*>(*objp); + *objp = nullptr; + + delete wrapped; + + return 0; +} + +static int nr_socket_wrapped_sendto(void* obj, const void* msg, size_t len, + int flags, const nr_transport_addr* addr) { + nr_socket_wrapped* wrapped = static_cast<nr_socket_wrapped*>(obj); + + return nr_socket_sendto(wrapped->sock_, msg, len, flags, &wrapped->addr_); +} + +static int nr_socket_wrapped_recvfrom(void* obj, void* restrict buf, + size_t maxlen, size_t* len, int flags, + nr_transport_addr* addr) { + nr_socket_wrapped* wrapped = static_cast<nr_socket_wrapped*>(obj); + + return nr_socket_recvfrom(wrapped->sock_, buf, maxlen, len, flags, addr); +} + +static int nr_socket_wrapped_getfd(void* obj, NR_SOCKET* fd) { + nr_socket_wrapped* wrapped = static_cast<nr_socket_wrapped*>(obj); + + return nr_socket_getfd(wrapped->sock_, fd); +} + +static int nr_socket_wrapped_getaddr(void* obj, nr_transport_addr* addrp) { + nr_socket_wrapped* wrapped = static_cast<nr_socket_wrapped*>(obj); + + return nr_socket_getaddr(wrapped->sock_, addrp); +} + +static int nr_socket_wrapped_close(void* obj) { MOZ_CRASH(); } + +static int nr_socket_wrapped_set_send_addr(nr_socket* sock, + nr_transport_addr* addr) { + nr_socket_wrapped* wrapped = static_cast<nr_socket_wrapped*>(sock->obj); + + return nr_transport_addr_copy(&wrapped->addr_, addr); +} + +static nr_socket_vtbl nr_socket_wrapped_vtbl = {2, + nr_socket_wrapped_destroy, + nr_socket_wrapped_sendto, + nr_socket_wrapped_recvfrom, + nr_socket_wrapped_getfd, + nr_socket_wrapped_getaddr, + nullptr, + nullptr, + nullptr, + nr_socket_wrapped_close, + nullptr, + nullptr}; + +int nr_socket_wrapped_create(nr_socket* inner, nr_socket** outp) { + auto wrapped = MakeUnique<nr_socket_wrapped>(); + + wrapped->sock_ = inner; + + int r = nr_socket_create_int(wrapped.get(), &nr_socket_wrapped_vtbl, outp); + if (r) return r; + + Unused << wrapped.release(); + return 0; +} + +// Instance static. +// Note: Calling Create() at static init time is not going to be safe, since +// we have no reason to expect this will be initted to a nullptr yet. +TestStunServer* TestStunServer::instance; +TestStunTcpServer* TestStunTcpServer::instance; +TestStunServer* TestStunServer::instance6; +TestStunTcpServer* TestStunTcpServer::instance6; +uint16_t TestStunServer::instance_port = 3478; +uint16_t TestStunTcpServer::instance_port = 3478; + +TestStunServer::~TestStunServer() { + // TODO(ekr@rtfm.com): Put this on the right thread. + + // Unhook callback from our listen socket. + if (listen_sock_) { + NR_SOCKET fd; + if (!nr_socket_getfd(listen_sock_, &fd)) { + NR_ASYNC_CANCEL(fd, NR_ASYNC_WAIT_READ); + } + } + + // Free up stun context and network resources + nr_stun_server_ctx_destroy(&stun_server_); + nr_socket_destroy(&listen_sock_); + nr_socket_destroy(&send_sock_); + + // Make sure we aren't still waiting on a deferred response timer to pop + if (timer_handle_) NR_async_timer_cancel(timer_handle_); + + delete response_addr_; +} + +int TestStunServer::SetInternalPort(nr_local_addr* addr, uint16_t port) { + if (nr_transport_addr_set_port(&addr->addr, port)) { + MOZ_MTLOG(ML_ERROR, "Couldn't set port"); + return R_INTERNAL; + } + + if (nr_transport_addr_fmt_addr_string(&addr->addr)) { + MOZ_MTLOG(ML_ERROR, "Couldn't re-set addr string"); + return R_INTERNAL; + } + + return 0; +} + +int TestStunServer::TryOpenListenSocket(nr_local_addr* addr, uint16_t port) { + int r = SetInternalPort(addr, port); + + if (r) return r; + + if (nr_socket_local_create(nullptr, &addr->addr, &listen_sock_)) { + MOZ_MTLOG(ML_ERROR, "Couldn't create listen socket"); + return R_ALREADY; + } + + return 0; +} + +static int addressFamilyToIpVersion(int address_family) { + switch (address_family) { + case AF_INET: + return NR_IPV4; + case AF_INET6: + return NR_IPV6; + default: + MOZ_CRASH(); + } + return NR_IPV4; +} + +int TestStunServer::Initialize(int address_family) { + static const size_t max_addrs = 100; + nr_local_addr addrs[max_addrs]; + int addr_ct; + int r; + int i; + + r = nr_stun_find_local_addresses(addrs, max_addrs, &addr_ct); + if (r) { + MOZ_MTLOG(ML_ERROR, "Couldn't retrieve addresses"); + return R_INTERNAL; + } + + // removes duplicates and, based on prefs, loopback and link_local addrs + r = nr_stun_filter_local_addresses(addrs, &addr_ct); + if (r) { + MOZ_MTLOG(ML_ERROR, "Couldn't filter addresses"); + return R_INTERNAL; + } + + if (addr_ct < 1) { + MOZ_MTLOG(ML_ERROR, "No local addresses"); + return R_INTERNAL; + } + + for (i = 0; i < addr_ct; ++i) { + if (addrs[i].addr.ip_version == addressFamilyToIpVersion(address_family)) { + break; + } + } + + if (i == addr_ct) { + MOZ_MTLOG(ML_ERROR, "No local addresses of the configured IP version"); + return R_INTERNAL; + } + + int tries = 100; + while (tries--) { + // Bind on configured port (default 3478) + r = TryOpenListenSocket(&addrs[i], instance_port); + // We interpret R_ALREADY to mean the addr is probably in use. Try another. + // Otherwise, it either worked or it didn't, and we check below. + if (r != R_ALREADY) { + break; + } + ++instance_port; + } + + if (r) { + return R_INTERNAL; + } + + r = nr_socket_wrapped_create(listen_sock_, &send_sock_); + if (r) { + MOZ_MTLOG(ML_ERROR, "Couldn't create send socket"); + return R_INTERNAL; + } + + r = nr_stun_server_ctx_create(const_cast<char*>("Test STUN server"), + &stun_server_); + if (r) { + MOZ_MTLOG(ML_ERROR, "Couldn't create STUN server"); + return R_INTERNAL; + } + + // Cache the address and port. + char addr_string[INET6_ADDRSTRLEN]; + r = nr_transport_addr_get_addrstring(&addrs[i].addr, addr_string, + sizeof(addr_string)); + if (r) { + MOZ_MTLOG(ML_ERROR, + "Failed to convert listen addr to a string representation"); + return R_INTERNAL; + } + + listen_addr_ = addr_string; + listen_port_ = instance_port; + + return 0; +} + +UniquePtr<TestStunServer> TestStunServer::Create(int address_family) { + NR_reg_init(NR_REG_MODE_LOCAL); + + UniquePtr<TestStunServer> server(new TestStunServer()); + + if (server->Initialize(address_family)) return nullptr; + + NR_SOCKET fd; + int r = nr_socket_getfd(server->listen_sock_, &fd); + if (r) { + MOZ_MTLOG(ML_ERROR, "Couldn't get fd"); + return nullptr; + } + + NR_ASYNC_WAIT(fd, NR_ASYNC_WAIT_READ, &TestStunServer::readable_cb, + server.get()); + + return server; +} + +void TestStunServer::ConfigurePort(uint16_t port) { instance_port = port; } + +TestStunServer* TestStunServer::GetInstance(int address_family) { + switch (address_family) { + case AF_INET: + if (!instance) instance = Create(address_family).release(); + + MOZ_ASSERT(instance); + return instance; + case AF_INET6: + if (!instance6) instance6 = Create(address_family).release(); + + return instance6; + default: + MOZ_CRASH(); + } +} + +void TestStunServer::ShutdownInstance() { + delete instance; + instance = nullptr; + delete instance6; + instance6 = nullptr; +} + +struct DeferredStunOperation { + DeferredStunOperation(TestStunServer* server, const char* data, size_t len, + nr_transport_addr* addr, nr_socket* sock) + : server_(server), buffer_(), sock_(sock) { + buffer_.Copy(reinterpret_cast<const uint8_t*>(data), len); + nr_transport_addr_copy(&addr_, addr); + } + + TestStunServer* server_; + MediaPacket buffer_; + nr_transport_addr addr_; + nr_socket* sock_; +}; + +void TestStunServer::Process(const uint8_t* msg, size_t len, + nr_transport_addr* addr, nr_socket* sock) { + if (!sock) { + sock = send_sock_; + } + + // Set the wrapped address so that the response goes to the right place. + nr_socket_wrapped_set_send_addr(sock, addr); + + nr_stun_server_process_request( + stun_server_, sock, const_cast<char*>(reinterpret_cast<const char*>(msg)), + len, response_addr_ ? response_addr_ : addr, NR_STUN_AUTH_RULE_OPTIONAL); +} + +void TestStunServer::process_cb(NR_SOCKET s, int how, void* cb_arg) { + DeferredStunOperation* op = static_cast<DeferredStunOperation*>(cb_arg); + op->server_->timer_handle_ = nullptr; + op->server_->Process(op->buffer_.data(), op->buffer_.len(), &op->addr_, + op->sock_); + + delete op; +} + +nr_socket* TestStunServer::GetReceivingSocket(NR_SOCKET s) { + return listen_sock_; +} + +nr_socket* TestStunServer::GetSendingSocket(nr_socket* sock) { + return send_sock_; +} + +void TestStunServer::readable_cb(NR_SOCKET s, int how, void* cb_arg) { + TestStunServer* server = static_cast<TestStunServer*>(cb_arg); + + char message[max_stun_message_size]; + size_t message_len; + nr_transport_addr addr; + nr_socket* recv_sock = server->GetReceivingSocket(s); + if (!recv_sock) { + MOZ_MTLOG(ML_ERROR, "Failed to lookup receiving socket"); + return; + } + nr_socket* send_sock = server->GetSendingSocket(recv_sock); + + /* Re-arm. */ + NR_ASYNC_WAIT(s, NR_ASYNC_WAIT_READ, &TestStunServer::readable_cb, server); + + if (nr_socket_recvfrom(recv_sock, message, sizeof(message), &message_len, 0, + &addr)) { + MOZ_MTLOG(ML_ERROR, "Couldn't read STUN message"); + return; + } + + MOZ_MTLOG(ML_DEBUG, "Received data of length " << message_len); + + // If we have initial dropping set, check at this point. + std::string key(addr.as_string); + + if (server->received_ct_.count(key) == 0) { + server->received_ct_[key] = 0; + } + + ++server->received_ct_[key]; + + if (!server->active_ || (server->received_ct_[key] <= server->initial_ct_)) { + MOZ_MTLOG(ML_DEBUG, "Dropping message #" << server->received_ct_[key] + << " from " << key); + return; + } + + if (server->delay_ms_) { + NR_ASYNC_TIMER_SET(server->delay_ms_, process_cb, + new DeferredStunOperation(server, message, message_len, + &addr, send_sock), + &server->timer_handle_); + } else { + server->Process(reinterpret_cast<const uint8_t*>(message), message_len, + &addr, send_sock); + } +} + +void TestStunServer::SetActive(bool active) { active_ = active; } + +void TestStunServer::SetDelay(uint32_t delay_ms) { delay_ms_ = delay_ms; } + +void TestStunServer::SetDropInitialPackets(uint32_t count) { + initial_ct_ = count; +} + +nsresult TestStunServer::SetResponseAddr(nr_transport_addr* addr) { + delete response_addr_; + + response_addr_ = new nr_transport_addr(); + + int r = nr_transport_addr_copy(response_addr_, addr); + if (r) return NS_ERROR_FAILURE; + + return NS_OK; +} + +nsresult TestStunServer::SetResponseAddr(const std::string& addr, + uint16_t port) { + nr_transport_addr addr2; + + int r = + nr_str_port_to_transport_addr(addr.c_str(), port, IPPROTO_UDP, &addr2); + if (r) return NS_ERROR_FAILURE; + + return SetResponseAddr(&addr2); +} + +void TestStunServer::Reset() { + delay_ms_ = 0; + if (timer_handle_) { + NR_async_timer_cancel(timer_handle_); + timer_handle_ = nullptr; + } + delete response_addr_; + response_addr_ = nullptr; + received_ct_.clear(); +} + +// TestStunTcpServer + +void TestStunTcpServer::ConfigurePort(uint16_t port) { instance_port = port; } + +TestStunTcpServer* TestStunTcpServer::GetInstance(int address_family) { + switch (address_family) { + case AF_INET: + if (!instance) instance = Create(address_family).release(); + + MOZ_ASSERT(instance); + return instance; + case AF_INET6: + if (!instance6) instance6 = Create(address_family).release(); + + return instance6; + default: + MOZ_CRASH(); + } +} + +void TestStunTcpServer::ShutdownInstance() { + delete instance; + instance = nullptr; + delete instance6; + instance6 = nullptr; +} + +int TestStunTcpServer::TryOpenListenSocket(nr_local_addr* addr, uint16_t port) { + addr->addr.protocol = IPPROTO_TCP; + + int r = SetInternalPort(addr, port); + + if (r) return r; + + nr_socket* sock; + if (nr_socket_local_create(nullptr, &addr->addr, &sock)) { + MOZ_MTLOG(ML_ERROR, "Couldn't create listen tcp socket"); + return R_ALREADY; + } + + if (nr_socket_buffered_stun_create(sock, 2048, TURN_TCP_FRAMING, + &listen_sock_)) { + MOZ_MTLOG(ML_ERROR, "Couldn't create listen tcp socket"); + return R_ALREADY; + } + + if (nr_socket_listen(listen_sock_, 10)) { + MOZ_MTLOG(ML_ERROR, "Couldn't listen on socket"); + return R_ALREADY; + } + + return 0; +} + +nr_socket* TestStunTcpServer::GetReceivingSocket(NR_SOCKET s) { + return connections_[s]; +} + +nr_socket* TestStunTcpServer::GetSendingSocket(nr_socket* sock) { return sock; } + +void TestStunTcpServer::accept_cb(NR_SOCKET s, int how, void* cb_arg) { + TestStunTcpServer* server = static_cast<TestStunTcpServer*>(cb_arg); + nr_socket *newsock, *bufsock, *wrapsock; + nr_transport_addr remote_addr; + NR_SOCKET fd; + + /* rearm */ + NR_ASYNC_WAIT(s, NR_ASYNC_WAIT_READ, &TestStunTcpServer::accept_cb, cb_arg); + + /* accept */ + if (nr_socket_accept(server->listen_sock_, &remote_addr, &newsock)) { + MOZ_MTLOG(ML_ERROR, "Couldn't accept incoming tcp connection"); + return; + } + + if (nr_socket_buffered_stun_create(newsock, 2048, TURN_TCP_FRAMING, + &bufsock)) { + MOZ_MTLOG(ML_ERROR, "Couldn't create connected tcp socket"); + nr_socket_destroy(&newsock); + return; + } + + nr_socket_buffered_set_connected_to(bufsock, &remote_addr); + + if (nr_socket_wrapped_create(bufsock, &wrapsock)) { + MOZ_MTLOG(ML_ERROR, "Couldn't wrap connected tcp socket"); + nr_socket_destroy(&bufsock); + return; + } + + if (nr_socket_getfd(wrapsock, &fd)) { + MOZ_MTLOG(ML_ERROR, "Couldn't get fd from connected tcp socket"); + nr_socket_destroy(&wrapsock); + return; + } + + server->connections_[fd] = wrapsock; + + NR_ASYNC_WAIT(fd, NR_ASYNC_WAIT_READ, &TestStunServer::readable_cb, server); +} + +UniquePtr<TestStunTcpServer> TestStunTcpServer::Create(int address_family) { + NR_reg_init(NR_REG_MODE_LOCAL); + + UniquePtr<TestStunTcpServer> server(new TestStunTcpServer()); + + if (server->Initialize(address_family)) { + return nullptr; + } + + NR_SOCKET fd; + if (nr_socket_getfd(server->listen_sock_, &fd)) { + MOZ_MTLOG(ML_ERROR, "Couldn't get tcp fd"); + return nullptr; + } + + NR_ASYNC_WAIT(fd, NR_ASYNC_WAIT_READ, &TestStunTcpServer::accept_cb, + server.get()); + + return server; +} + +TestStunTcpServer::~TestStunTcpServer() { + for (auto it = connections_.begin(); it != connections_.end();) { + NR_ASYNC_CANCEL(it->first, NR_ASYNC_WAIT_READ); + nr_socket_destroy(&it->second); + connections_.erase(it++); + } +} + +} // namespace mozilla diff --git a/dom/media/webrtc/transport/test/stunserver.h b/dom/media/webrtc/transport/test/stunserver.h new file mode 100644 index 0000000000..4903cb89ad --- /dev/null +++ b/dom/media/webrtc/transport/test/stunserver.h @@ -0,0 +1,123 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com + +#ifndef stunserver_h__ +#define stunserver_h__ + +#include <map> +#include <string> +#include "nsError.h" +#include "mozilla/UniquePtr.h" + +typedef struct nr_stun_server_ctx_ nr_stun_server_ctx; +typedef struct nr_socket_ nr_socket; +typedef struct nr_local_addr_ nr_local_addr; + +namespace mozilla { + +class TestStunServer { + public: + // Generally, you should only call API in this class from the same thread that + // the initial |GetInstance| call was made from. + static TestStunServer* GetInstance(int address_family = AF_INET); + static void ShutdownInstance(); + // |ConfigurePort| will only have an effect if called before the first call + // to |GetInstance| (possibly following a |ShutdownInstance| call) + static void ConfigurePort(uint16_t port); + // AF_INET, AF_INET6 + static UniquePtr<TestStunServer> Create(int address_family); + + virtual ~TestStunServer(); + + void SetActive(bool active); + void SetDelay(uint32_t delay_ms); + void SetDropInitialPackets(uint32_t count); + const std::string& addr() const { return listen_addr_; } + uint16_t port() const { return listen_port_; } + + // These should only be called from the same thread as the initial + // |GetInstance| call. + nsresult SetResponseAddr(nr_transport_addr* addr); + nsresult SetResponseAddr(const std::string& addr, uint16_t port); + + void Reset(); + + static const size_t max_stun_message_size = 4096; + + virtual nr_socket* GetReceivingSocket(NR_SOCKET s); + virtual nr_socket* GetSendingSocket(nr_socket* sock); + + protected: + TestStunServer() + : listen_port_(0), + listen_sock_(nullptr), + send_sock_(nullptr), + stun_server_(nullptr), + active_(true), + delay_ms_(0), + initial_ct_(0), + response_addr_(nullptr), + timer_handle_(nullptr) {} + + int SetInternalPort(nr_local_addr* addr, uint16_t port); + int Initialize(int address_family); + + static void readable_cb(NR_SOCKET sock, int how, void* cb_arg); + + private: + void Process(const uint8_t* msg, size_t len, nr_transport_addr* addr_in, + nr_socket* sock); + virtual int TryOpenListenSocket(nr_local_addr* addr, uint16_t port); + static void process_cb(NR_SOCKET sock, int how, void* cb_arg); + + protected: + std::string listen_addr_; + uint16_t listen_port_; + nr_socket* listen_sock_; + nr_socket* send_sock_; + nr_stun_server_ctx* stun_server_; + + private: + bool active_; + uint32_t delay_ms_; + uint32_t initial_ct_; + nr_transport_addr* response_addr_; + void* timer_handle_; + std::map<std::string, uint32_t> received_ct_; + + static TestStunServer* instance; + static TestStunServer* instance6; + static uint16_t instance_port; +}; + +class TestStunTcpServer : public TestStunServer { + public: + static TestStunTcpServer* GetInstance(int address_family); + static void ShutdownInstance(); + static void ConfigurePort(uint16_t port); + virtual ~TestStunTcpServer(); + + virtual nr_socket* GetReceivingSocket(NR_SOCKET s); + virtual nr_socket* GetSendingSocket(nr_socket* sock); + + protected: + TestStunTcpServer() = default; + static void accept_cb(NR_SOCKET sock, int how, void* cb_arg); + + private: + virtual int TryOpenListenSocket(nr_local_addr* addr, uint16_t port); + static UniquePtr<TestStunTcpServer> Create(int address_family); + + static TestStunTcpServer* instance; + static TestStunTcpServer* instance6; + static uint16_t instance_port; + + std::map<NR_SOCKET, nr_socket*> connections_; +}; +} // End of namespace mozilla +#endif diff --git a/dom/media/webrtc/transport/test/test_nr_socket_ice_unittest.cpp b/dom/media/webrtc/transport/test/test_nr_socket_ice_unittest.cpp new file mode 100644 index 0000000000..b55b05f10c --- /dev/null +++ b/dom/media/webrtc/transport/test/test_nr_socket_ice_unittest.cpp @@ -0,0 +1,409 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Some of this code is taken from nricectx.cpp and nricemediastream.cpp +// which in turn contains code cut-and-pasted from nICEr. Copyright is: + +/* +Copyright (c) 2007, Adobe Systems, Incorporated +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of Adobe Systems, Network Resonance nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include "gtest/gtest.h" +#include "gtest_utils.h" + +extern "C" { +#include "ice_ctx.h" +#include "ice_peer_ctx.h" +#include "nICEr/src/net/transport_addr.h" +} + +#include "mtransport_test_utils.h" +#include "nricectx.h" +#include "nricemediastream.h" +#include "runnable_utils.h" +#include "test_nr_socket.h" + +namespace mozilla { + +static unsigned int kDefaultTimeout = 7000; + +class IcePeer { + public: + IcePeer(const char* name, TestNat* nat, UINT4 flags, + MtransportTestUtils* test_utils) + : name_(name), + ice_checking_(false), + ice_connected_(false), + ice_disconnected_(false), + gather_cb_(false), + stream_ready_(false), + stream_failed_(false), + ice_ctx_(nullptr), + peer_ctx_(nullptr), + nat_(nat), + test_utils_(test_utils) { + nr_ice_ctx_create(const_cast<char*>(name_.c_str()), flags, &ice_ctx_); + + if (nat_) { + nr_socket_factory* factory; + nat_->create_socket_factory(&factory); + nr_ice_ctx_set_socket_factory(ice_ctx_, factory); + } + + // Create the handler objects + ice_handler_vtbl_ = new nr_ice_handler_vtbl(); + ice_handler_vtbl_->select_pair = &IcePeer::select_pair; + ice_handler_vtbl_->stream_ready = &IcePeer::stream_ready; + ice_handler_vtbl_->stream_failed = &IcePeer::stream_failed; + ice_handler_vtbl_->ice_connected = &IcePeer::ice_connected; + ice_handler_vtbl_->msg_recvd = &IcePeer::msg_recvd; + ice_handler_vtbl_->ice_checking = &IcePeer::ice_checking; + ice_handler_vtbl_->ice_disconnected = &IcePeer::ice_disconnected; + + ice_handler_ = new nr_ice_handler(); + ice_handler_->vtbl = ice_handler_vtbl_; + ice_handler_->obj = this; + + nr_ice_peer_ctx_create(ice_ctx_, ice_handler_, + const_cast<char*>(name_.c_str()), &peer_ctx_); + + nr_ice_add_media_stream(ice_ctx_, const_cast<char*>(name_.c_str()), "ufrag", + "pass", 2, &ice_media_stream_); + EXPECT_EQ(2UL, GetStreamAttributes().size()); + + nr_ice_media_stream_initialize(ice_ctx_, ice_media_stream_); + } + + virtual ~IcePeer() { Destroy(); } + + void Destroy() { + test_utils_->SyncDispatchToSTS(WrapRunnable(this, &IcePeer::Destroy_s)); + } + + void Destroy_s() { + nr_ice_peer_ctx_destroy(&peer_ctx_); + delete ice_handler_; + delete ice_handler_vtbl_; + nr_ice_ctx_destroy(&ice_ctx_); + } + + void Gather(bool default_route_only = false) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IcePeer::Gather_s, default_route_only)); + } + + void Gather_s(bool default_route_only = false) { + int r = nr_ice_gather(ice_ctx_, &IcePeer::gather_cb, this); + ASSERT_TRUE(r == 0 || r == R_WOULDBLOCK); + } + + std::vector<std::string> GetStreamAttributes() { + std::vector<std::string> attributes; + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&attributes, this, &IcePeer::GetStreamAttributes_s)); + return attributes; + } + + std::vector<std::string> GetStreamAttributes_s() { + char** attrs = nullptr; + int attrct; + std::vector<std::string> ret; + + int r = + nr_ice_media_stream_get_attributes(ice_media_stream_, &attrs, &attrct); + EXPECT_EQ(0, r); + + for (int i = 0; i < attrct; i++) { + ret.push_back(std::string(attrs[i])); + RFREE(attrs[i]); + } + RFREE(attrs); + + return ret; + } + + std::vector<std::string> GetGlobalAttributes() { + std::vector<std::string> attributes; + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&attributes, this, &IcePeer::GetGlobalAttributes_s)); + return attributes; + } + + std::vector<std::string> GetGlobalAttributes_s() { + char** attrs = nullptr; + int attrct; + std::vector<std::string> ret; + + nr_ice_get_global_attributes(ice_ctx_, &attrs, &attrct); + + for (int i = 0; i < attrct; i++) { + ret.push_back(std::string(attrs[i])); + RFREE(attrs[i]); + } + RFREE(attrs); + + return ret; + } + + void ParseGlobalAttributes(std::vector<std::string> attrs) { + std::vector<char*> attrs_in; + attrs_in.reserve(attrs.size()); + for (auto& attr : attrs) { + attrs_in.push_back(const_cast<char*>(attr.c_str())); + } + + int r = nr_ice_peer_ctx_parse_global_attributes( + peer_ctx_, attrs_in.empty() ? nullptr : &attrs_in[0], attrs_in.size()); + ASSERT_EQ(0, r); + } + + void SetControlling(bool controlling) { + peer_ctx_->controlling = controlling ? 1 : 0; + } + + void SetRemoteAttributes(std::vector<std::string> attributes) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &IcePeer::SetRemoteAttributes_s, attributes)); + } + + void SetRemoteAttributes_s(std::vector<std::string> attributes) { + int r; + + std::vector<char*> attrs; + attrs.reserve(attributes.size()); + for (auto& attr : attributes) { + attrs.push_back(const_cast<char*>(attr.c_str())); + } + + if (!attrs.empty()) { + r = nr_ice_peer_ctx_parse_stream_attributes(peer_ctx_, ice_media_stream_, + &attrs[0], attrs.size()); + ASSERT_EQ(0, r); + } + } + + void StartChecks() { + test_utils_->SyncDispatchToSTS(WrapRunnable(this, &IcePeer::StartChecks_s)); + } + + void StartChecks_s() { + int r = nr_ice_peer_ctx_pair_candidates(peer_ctx_); + ASSERT_EQ(0, r); + + r = nr_ice_peer_ctx_start_checks2(peer_ctx_, 1); + ASSERT_EQ(0, r); + } + + // Handler callbacks + static int select_pair(void* obj, nr_ice_media_stream* stream, + int component_id, nr_ice_cand_pair** potentials, + int potential_ct) { + return 0; + } + + static int stream_ready(void* obj, nr_ice_media_stream* stream) { + IcePeer* peer = static_cast<IcePeer*>(obj); + peer->stream_ready_ = true; + return 0; + } + + static int stream_failed(void* obj, nr_ice_media_stream* stream) { + IcePeer* peer = static_cast<IcePeer*>(obj); + peer->stream_failed_ = true; + return 0; + } + + static int ice_checking(void* obj, nr_ice_peer_ctx* pctx) { + IcePeer* peer = static_cast<IcePeer*>(obj); + peer->ice_checking_ = true; + return 0; + } + + static int ice_connected(void* obj, nr_ice_peer_ctx* pctx) { + IcePeer* peer = static_cast<IcePeer*>(obj); + peer->ice_connected_ = true; + return 0; + } + + static int ice_disconnected(void* obj, nr_ice_peer_ctx* pctx) { + IcePeer* peer = static_cast<IcePeer*>(obj); + peer->ice_disconnected_ = true; + return 0; + } + + static int msg_recvd(void* obj, nr_ice_peer_ctx* pctx, + nr_ice_media_stream* stream, int component_id, + UCHAR* msg, int len) { + return 0; + } + + static void gather_cb(NR_SOCKET s, int h, void* arg) { + IcePeer* peer = static_cast<IcePeer*>(arg); + peer->gather_cb_ = true; + } + + std::string name_; + + bool ice_checking_; + bool ice_connected_; + bool ice_disconnected_; + bool gather_cb_; + bool stream_ready_; + bool stream_failed_; + + nr_ice_ctx* ice_ctx_; + nr_ice_handler* ice_handler_; + nr_ice_handler_vtbl* ice_handler_vtbl_; + nr_ice_media_stream* ice_media_stream_; + nr_ice_peer_ctx* peer_ctx_; + TestNat* nat_; + MtransportTestUtils* test_utils_; +}; + +class TestNrSocketIceUnitTest : public ::testing::Test { + public: + void SetUp() override { + NSS_NoDB_Init(nullptr); + NSS_SetDomesticPolicy(); + + test_utils_ = new MtransportTestUtils(); + test_utils2_ = new MtransportTestUtils(); + + NrIceCtx::InitializeGlobals(NrIceCtx::GlobalConfig()); + } + + void TearDown() override { + delete test_utils_; + delete test_utils2_; + } + + MtransportTestUtils* test_utils_; + MtransportTestUtils* test_utils2_; +}; + +TEST_F(TestNrSocketIceUnitTest, TestIcePeer) { + IcePeer peer("IcePeer", nullptr, NR_ICE_CTX_FLAGS_AGGRESSIVE_NOMINATION, + test_utils_); + ASSERT_NE(peer.ice_ctx_, nullptr); + ASSERT_NE(peer.peer_ctx_, nullptr); + ASSERT_NE(peer.ice_media_stream_, nullptr); + ASSERT_EQ(2UL, peer.GetStreamAttributes().size()) + << "Should have ice-ufrag and ice-pwd"; + peer.Gather(); + ASSERT_LT(2UL, peer.GetStreamAttributes().size()) + << "Should have ice-ufrag, ice-pwd, and at least one candidate."; +} + +TEST_F(TestNrSocketIceUnitTest, TestIcePeersNoNAT) { + IcePeer peer("IcePeer", nullptr, NR_ICE_CTX_FLAGS_AGGRESSIVE_NOMINATION, + test_utils_); + IcePeer peer2("IcePeer2", nullptr, NR_ICE_CTX_FLAGS_AGGRESSIVE_NOMINATION, + test_utils2_); + peer.SetControlling(true); + peer2.SetControlling(false); + + peer.Gather(); + peer2.Gather(); + std::vector<std::string> attrs = peer.GetGlobalAttributes(); + peer2.ParseGlobalAttributes(attrs); + std::vector<std::string> attributes = peer.GetStreamAttributes(); + peer2.SetRemoteAttributes(attributes); + + attrs = peer2.GetGlobalAttributes(); + peer.ParseGlobalAttributes(attrs); + attributes = peer2.GetStreamAttributes(); + peer.SetRemoteAttributes(attributes); + peer2.StartChecks(); + peer.StartChecks(); + + ASSERT_TRUE_WAIT(peer.ice_connected_, kDefaultTimeout); + ASSERT_TRUE_WAIT(peer2.ice_connected_, kDefaultTimeout); +} + +TEST_F(TestNrSocketIceUnitTest, TestIcePeersPacketLoss) { + IcePeer peer("IcePeer", nullptr, NR_ICE_CTX_FLAGS_AGGRESSIVE_NOMINATION, + test_utils_); + + RefPtr<TestNat> nat(new TestNat); + class NatDelegate : public TestNat::NatDelegate { + public: + NatDelegate() : messages(0) {} + + int on_read(TestNat* nat, void* buf, size_t maxlen, size_t* len) override { + return 0; + } + + int on_sendto(TestNat* nat, const void* msg, size_t len, int flags, + const nr_transport_addr* to) override { + ++messages; + // 25% packet loss + if (messages % 4 == 0) { + return 1; + } + return 0; + } + + int on_write(TestNat* nat, const void* msg, size_t len, + size_t* written) override { + return 0; + } + + int messages; + } delegate; + nat->nat_delegate_ = &delegate; + + IcePeer peer2("IcePeer2", nat, NR_ICE_CTX_FLAGS_AGGRESSIVE_NOMINATION, + test_utils2_); + peer.SetControlling(true); + peer2.SetControlling(false); + + peer.Gather(); + peer2.Gather(); + std::vector<std::string> attrs = peer.GetGlobalAttributes(); + peer2.ParseGlobalAttributes(attrs); + std::vector<std::string> attributes = peer.GetStreamAttributes(); + peer2.SetRemoteAttributes(attributes); + + attrs = peer2.GetGlobalAttributes(); + peer.ParseGlobalAttributes(attrs); + attributes = peer2.GetStreamAttributes(); + peer.SetRemoteAttributes(attributes); + peer2.StartChecks(); + peer.StartChecks(); + + ASSERT_TRUE_WAIT(peer.ice_connected_, kDefaultTimeout); + ASSERT_TRUE_WAIT(peer2.ice_connected_, kDefaultTimeout); +} + +} // namespace mozilla diff --git a/dom/media/webrtc/transport/test/test_nr_socket_unittest.cpp b/dom/media/webrtc/transport/test/test_nr_socket_unittest.cpp new file mode 100644 index 0000000000..af2779accd --- /dev/null +++ b/dom/media/webrtc/transport/test/test_nr_socket_unittest.cpp @@ -0,0 +1,800 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: bcampen@mozilla.com + +#include <cstddef> + +extern "C" { +#include "r_errors.h" +#include "async_wait.h" +} + +#include "test_nr_socket.h" + +#include "nsCOMPtr.h" +#include "nsNetCID.h" +#include "nsServiceManagerUtils.h" +#include "runnable_utils.h" + +#include <vector> + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +#define DATA_BUF_SIZE 1024 + +namespace mozilla { + +class TestNrSocketTest : public MtransportTest { + public: + TestNrSocketTest() + : MtransportTest(), + wait_done_for_main_(false), + sts_(), + public_addrs_(), + private_addrs_(), + nats_() {} + + void SetUp() override { + MtransportTest::SetUp(); + + // Get the transport service as a dispatch target + nsresult rv; + sts_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + EXPECT_TRUE(NS_SUCCEEDED(rv)) << "Failed to get STS: " << (int)rv; + } + + void TearDown() override { + SyncDispatchToSTS(WrapRunnable(this, &TestNrSocketTest::TearDown_s)); + + MtransportTest::TearDown(); + } + + void TearDown_s() { + public_addrs_.clear(); + private_addrs_.clear(); + nats_.clear(); + sts_ = nullptr; + } + + RefPtr<TestNrSocket> CreateTestNrSocket_s(const char* ip_str, int proto, + TestNat* nat) { + // If no nat is supplied, we create a default NAT which is disabled. This + // is how we simulate a non-natted socket. + RefPtr<TestNrSocket> sock(new TestNrSocket(nat ? nat : new TestNat)); + nr_transport_addr address; + nr_str_port_to_transport_addr(ip_str, 0, proto, &address); + int r = sock->create(&address); + if (r) { + return nullptr; + } + return sock; + } + + void CreatePublicAddrs(size_t count, const char* ip_str = "127.0.0.1", + int proto = IPPROTO_UDP) { + SyncDispatchToSTS(WrapRunnable(this, &TestNrSocketTest::CreatePublicAddrs_s, + count, ip_str, proto)); + } + + void CreatePublicAddrs_s(size_t count, const char* ip_str, int proto) { + while (count--) { + auto sock = CreateTestNrSocket_s(ip_str, proto, nullptr); + ASSERT_TRUE(sock) + << "Failed to create socket"; + public_addrs_.push_back(sock); + } + } + + RefPtr<TestNat> CreatePrivateAddrs(size_t size, + const char* ip_str = "127.0.0.1", + int proto = IPPROTO_UDP) { + RefPtr<TestNat> result; + SyncDispatchToSTS(WrapRunnableRet(&result, this, + &TestNrSocketTest::CreatePrivateAddrs_s, + size, ip_str, proto)); + return result; + } + + RefPtr<TestNat> CreatePrivateAddrs_s(size_t count, const char* ip_str, + int proto) { + RefPtr<TestNat> nat(new TestNat); + while (count--) { + auto sock = CreateTestNrSocket_s(ip_str, proto, nat); + if (!sock) { + EXPECT_TRUE(false) << "Failed to create socket"; + break; + } + private_addrs_.push_back(sock); + } + nat->enabled_ = true; + nats_.push_back(nat); + return nat; + } + + bool CheckConnectivityVia( + TestNrSocket* from, TestNrSocket* to, const nr_transport_addr& via, + nr_transport_addr* sender_external_address = nullptr) { + MOZ_ASSERT(from); + + if (!WaitForWriteable(from)) { + return false; + } + + int result = 0; + SyncDispatchToSTS(WrapRunnableRet( + &result, this, &TestNrSocketTest::SendData_s, from, via)); + if (result) { + return false; + } + + if (!WaitForReadable(to)) { + return false; + } + + nr_transport_addr dummy_outparam; + if (!sender_external_address) { + sender_external_address = &dummy_outparam; + } + + MOZ_ASSERT(to); + SyncDispatchToSTS(WrapRunnableRet(&result, this, + &TestNrSocketTest::RecvData_s, to, + sender_external_address)); + + return !result; + } + + bool CheckConnectivity(TestNrSocket* from, TestNrSocket* to, + nr_transport_addr* sender_external_address = nullptr) { + nr_transport_addr destination_address; + int r = GetAddress(to, &destination_address); + if (r) { + return false; + } + + return CheckConnectivityVia(from, to, destination_address, + sender_external_address); + } + + bool CheckTcpConnectivity(TestNrSocket* from, TestNrSocket* to) { + NrSocketBase* accepted_sock; + if (!Connect(from, to, &accepted_sock)) { + std::cerr << "Connect failed" << std::endl; + return false; + } + + // write on |from|, recv on |accepted_sock| + if (!WaitForWriteable(from)) { + std::cerr << __LINE__ << "WaitForWriteable (1) failed" << std::endl; + return false; + } + + int r; + SyncDispatchToSTS( + WrapRunnableRet(&r, this, &TestNrSocketTest::SendDataTcp_s, from)); + if (r) { + std::cerr << "SendDataTcp_s (1) failed" << std::endl; + return false; + } + + if (!WaitForReadable(accepted_sock)) { + std::cerr << __LINE__ << "WaitForReadable (1) failed" << std::endl; + return false; + } + + SyncDispatchToSTS(WrapRunnableRet( + &r, this, &TestNrSocketTest::RecvDataTcp_s, accepted_sock)); + if (r) { + std::cerr << "RecvDataTcp_s (1) failed" << std::endl; + return false; + } + + if (!WaitForWriteable(accepted_sock)) { + std::cerr << __LINE__ << "WaitForWriteable (2) failed" << std::endl; + return false; + } + + SyncDispatchToSTS(WrapRunnableRet( + &r, this, &TestNrSocketTest::SendDataTcp_s, accepted_sock)); + if (r) { + std::cerr << "SendDataTcp_s (2) failed" << std::endl; + return false; + } + + if (!WaitForReadable(from)) { + std::cerr << __LINE__ << "WaitForReadable (2) failed" << std::endl; + return false; + } + + SyncDispatchToSTS( + WrapRunnableRet(&r, this, &TestNrSocketTest::RecvDataTcp_s, from)); + if (r) { + std::cerr << "RecvDataTcp_s (2) failed" << std::endl; + return false; + } + + return true; + } + + int GetAddress(TestNrSocket* sock, nr_transport_addr_* address) { + MOZ_ASSERT(sock); + MOZ_ASSERT(address); + int r; + SyncDispatchToSTS(WrapRunnableRet(&r, this, &TestNrSocketTest::GetAddress_s, + sock, address)); + return r; + } + + int GetAddress_s(TestNrSocket* sock, nr_transport_addr* address) { + return sock->getaddr(address); + } + + int SendData_s(TestNrSocket* from, const nr_transport_addr& to) { + // It is up to caller to ensure that |from| is writeable. + const char buf[] = "foobajooba"; + return from->sendto(buf, sizeof(buf), 0, &to); + } + + int SendDataTcp_s(NrSocketBase* from) { + // It is up to caller to ensure that |from| is writeable. + const char buf[] = "foobajooba"; + size_t written; + return from->write(buf, sizeof(buf), &written); + } + + int RecvData_s(TestNrSocket* to, nr_transport_addr* from) { + // It is up to caller to ensure that |to| is readable + char buf[DATA_BUF_SIZE]; + size_t len; + // Maybe check that data matches? + int r = to->recvfrom(buf, sizeof(buf), &len, 0, from); + if (!r && (len == 0)) { + r = R_INTERNAL; + } + return r; + } + + int RecvDataTcp_s(NrSocketBase* to) { + // It is up to caller to ensure that |to| is readable + char buf[DATA_BUF_SIZE]; + size_t len; + // Maybe check that data matches? + int r = to->read(buf, sizeof(buf), &len); + if (!r && (len == 0)) { + r = R_INTERNAL; + } + return r; + } + + int Listen_s(TestNrSocket* to) { + // listen on |to| + int r = to->listen(1); + if (r) { + return r; + } + return 0; + } + + int Connect_s(TestNrSocket* from, TestNrSocket* to) { + // connect on |from| + nr_transport_addr destination_address; + int r = to->getaddr(&destination_address); + if (r) { + return r; + } + + r = from->connect(&destination_address); + if (r) { + return r; + } + + return 0; + } + + int Accept_s(TestNrSocket* to, NrSocketBase** accepted_sock) { + nr_socket* sock; + nr_transport_addr source_address; + int r = to->accept(&source_address, &sock); + if (r) { + return r; + } + + *accepted_sock = reinterpret_cast<NrSocketBase*>(sock->obj); + return 0; + } + + bool Connect(TestNrSocket* from, TestNrSocket* to, + NrSocketBase** accepted_sock) { + int r; + SyncDispatchToSTS( + WrapRunnableRet(&r, this, &TestNrSocketTest::Listen_s, to)); + if (r) { + std::cerr << "Listen_s failed: " << r << std::endl; + return false; + } + + SyncDispatchToSTS( + WrapRunnableRet(&r, this, &TestNrSocketTest::Connect_s, from, to)); + if (r && r != R_WOULDBLOCK) { + std::cerr << "Connect_s failed: " << r << std::endl; + return false; + } + + if (!WaitForReadable(to)) { + std::cerr << "WaitForReadable failed" << std::endl; + return false; + } + + SyncDispatchToSTS(WrapRunnableRet(&r, this, &TestNrSocketTest::Accept_s, to, + accepted_sock)); + + if (r) { + std::cerr << "Accept_s failed: " << r << std::endl; + return false; + } + return true; + } + + bool WaitForSocketState(NrSocketBase* sock, int state) { + MOZ_ASSERT(sock); + SyncDispatchToSTS(WrapRunnable( + this, &TestNrSocketTest::WaitForSocketState_s, sock, state)); + + bool res; + WAIT_(wait_done_for_main_, 500, res); + wait_done_for_main_ = false; + + if (!res) { + SyncDispatchToSTS( + WrapRunnable(this, &TestNrSocketTest::CancelWait_s, sock, state)); + } + + return res; + } + + void WaitForSocketState_s(NrSocketBase* sock, int state) { + NR_ASYNC_WAIT(sock, state, &WaitDone, this); + } + + void CancelWait_s(NrSocketBase* sock, int state) { sock->cancel(state); } + + bool WaitForReadable(NrSocketBase* sock) { + return WaitForSocketState(sock, NR_ASYNC_WAIT_READ); + } + + bool WaitForWriteable(NrSocketBase* sock) { + return WaitForSocketState(sock, NR_ASYNC_WAIT_WRITE); + } + + void SyncDispatchToSTS(nsIRunnable* runnable) { + NS_DispatchAndSpinEventLoopUntilComplete( + "TestNrSocketTest::SyncDispatchToSTS"_ns, sts_, do_AddRef(runnable)); + } + + static void WaitDone(void* sock, int how, void* test_fixture) { + TestNrSocketTest* test = static_cast<TestNrSocketTest*>(test_fixture); + test->wait_done_for_main_ = true; + } + + // Simple busywait boolean for the test cases to spin on. + Atomic<bool> wait_done_for_main_; + + nsCOMPtr<nsIEventTarget> sts_; + std::vector<RefPtr<TestNrSocket>> public_addrs_; + std::vector<RefPtr<TestNrSocket>> private_addrs_; + std::vector<RefPtr<TestNat>> nats_; +}; + +} // namespace mozilla + +using mozilla::NrSocketBase; +using mozilla::TestNat; +using mozilla::TestNrSocketTest; + +TEST_F(TestNrSocketTest, UnsafePortRejectedUDP) { + nr_transport_addr address; + ASSERT_FALSE(nr_str_port_to_transport_addr("127.0.0.1", + // ssh + 22, IPPROTO_UDP, &address)); + ASSERT_TRUE(NrSocketBase::IsForbiddenAddress(&address)); +} + +TEST_F(TestNrSocketTest, UnsafePortRejectedTCP) { + nr_transport_addr address; + ASSERT_FALSE(nr_str_port_to_transport_addr("127.0.0.1", + // ssh + 22, IPPROTO_TCP, &address)); + ASSERT_TRUE(NrSocketBase::IsForbiddenAddress(&address)); +} + +TEST_F(TestNrSocketTest, SafePortAcceptedUDP) { + nr_transport_addr address; + ASSERT_FALSE(nr_str_port_to_transport_addr("127.0.0.1", + // stuns + 5349, IPPROTO_UDP, &address)); + ASSERT_FALSE(NrSocketBase::IsForbiddenAddress(&address)); +} + +TEST_F(TestNrSocketTest, SafePortAcceptedTCP) { + nr_transport_addr address; + ASSERT_FALSE(nr_str_port_to_transport_addr("127.0.0.1", + // turns + 5349, IPPROTO_TCP, &address)); + ASSERT_FALSE(NrSocketBase::IsForbiddenAddress(&address)); +} + +TEST_F(TestNrSocketTest, PublicConnectivity) { + CreatePublicAddrs(2); + + ASSERT_TRUE(CheckConnectivity(public_addrs_[0], public_addrs_[1])); + ASSERT_TRUE(CheckConnectivity(public_addrs_[1], public_addrs_[0])); + ASSERT_TRUE(CheckConnectivity(public_addrs_[0], public_addrs_[0])); + ASSERT_TRUE(CheckConnectivity(public_addrs_[1], public_addrs_[1])); +} + +TEST_F(TestNrSocketTest, PrivateConnectivity) { + RefPtr<TestNat> nat(CreatePrivateAddrs(2)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], private_addrs_[1])); + ASSERT_TRUE(CheckConnectivity(private_addrs_[1], private_addrs_[0])); + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], private_addrs_[0])); + ASSERT_TRUE(CheckConnectivity(private_addrs_[1], private_addrs_[1])); +} + +TEST_F(TestNrSocketTest, NoConnectivityWithoutPinhole) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + CreatePublicAddrs(1); + + ASSERT_FALSE(CheckConnectivity(public_addrs_[0], private_addrs_[0])); +} + +TEST_F(TestNrSocketTest, NoConnectivityBetweenSubnets) { + RefPtr<TestNat> nat1(CreatePrivateAddrs(1)); + nat1->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat1->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + RefPtr<TestNat> nat2(CreatePrivateAddrs(1)); + nat2->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat2->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + + ASSERT_FALSE(CheckConnectivity(private_addrs_[0], private_addrs_[1])); + ASSERT_FALSE(CheckConnectivity(private_addrs_[1], private_addrs_[0])); + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], private_addrs_[0])); + ASSERT_TRUE(CheckConnectivity(private_addrs_[1], private_addrs_[1])); +} + +TEST_F(TestNrSocketTest, FullConeAcceptIngress) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + CreatePublicAddrs(2); + + nr_transport_addr sender_external_address; + // Open pinhole to public IP 0 + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address)); + + // Verify that return traffic works + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[0], private_addrs_[0], + sender_external_address)); + + // Verify that other public IP can use the pinhole + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[1], private_addrs_[0], + sender_external_address)); +} + +TEST_F(TestNrSocketTest, FullConeOnePinhole) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + CreatePublicAddrs(2); + + nr_transport_addr sender_external_address; + // Open pinhole to public IP 0 + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address)); + + // Verify that return traffic works + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[0], private_addrs_[0], + sender_external_address)); + + // Send traffic to other public IP, verify that it uses the same pinhole + nr_transport_addr sender_external_address2; + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[1], + &sender_external_address2)); + ASSERT_FALSE(nr_transport_addr_cmp(&sender_external_address, + &sender_external_address2, + NR_TRANSPORT_ADDR_CMP_MODE_ALL)) + << "addr1: " << sender_external_address.as_string + << " addr2: " << sender_external_address2.as_string; +} + +// OS 10.6 doesn't seem to allow us to open ports on 127.0.0.2, and while linux +// does allow this, it has other behavior (see below) that prevents this test +// from working. +TEST_F(TestNrSocketTest, DISABLED_AddressRestrictedCone) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1)); + nat->filtering_type_ = TestNat::ADDRESS_DEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + CreatePublicAddrs(2, "127.0.0.1"); + CreatePublicAddrs(1, "127.0.0.2"); + + nr_transport_addr sender_external_address; + // Open pinhole to public IP 0 + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address)); + + // Verify that return traffic works + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[0], private_addrs_[0], + sender_external_address)); + + // Verify that another address on the same host can use the pinhole + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[1], private_addrs_[0], + sender_external_address)); + + // Linux has a tendency to monkey around with source addresses, doing + // stuff like substituting 127.0.0.1 for packets sent by 127.0.0.2, and even + // going as far as substituting localhost for a packet sent from a real IP + // address when the destination is localhost. The only way to make this test + // work on linux is to have two real IP addresses. +#ifndef __linux__ + // Verify that an address on a different host can't use the pinhole + ASSERT_FALSE(CheckConnectivityVia(public_addrs_[2], private_addrs_[0], + sender_external_address)); +#endif + + // Send traffic to other public IP, verify that it uses the same pinhole + nr_transport_addr sender_external_address2; + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[1], + &sender_external_address2)); + ASSERT_FALSE(nr_transport_addr_cmp(&sender_external_address, + &sender_external_address2, + NR_TRANSPORT_ADDR_CMP_MODE_ALL)) + << "addr1: " << sender_external_address.as_string + << " addr2: " << sender_external_address2.as_string; + + // Verify that the other public IP can now use the pinhole + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[1], private_addrs_[0], + sender_external_address2)); + + // Send traffic to other public IP, verify that it uses the same pinhole + nr_transport_addr sender_external_address3; + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[2], + &sender_external_address3)); + ASSERT_FALSE(nr_transport_addr_cmp(&sender_external_address, + &sender_external_address3, + NR_TRANSPORT_ADDR_CMP_MODE_ALL)) + << "addr1: " << sender_external_address.as_string + << " addr2: " << sender_external_address3.as_string; + + // Verify that the other public IP can now use the pinhole + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[2], private_addrs_[0], + sender_external_address3)); +} + +TEST_F(TestNrSocketTest, RestrictedCone) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1)); + nat->filtering_type_ = TestNat::PORT_DEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + CreatePublicAddrs(2); + + nr_transport_addr sender_external_address; + // Open pinhole to public IP 0 + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address)); + + // Verify that return traffic works + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[0], private_addrs_[0], + sender_external_address)); + + // Verify that other public IP cannot use the pinhole + ASSERT_FALSE(CheckConnectivityVia(public_addrs_[1], private_addrs_[0], + sender_external_address)); + + // Send traffic to other public IP, verify that it uses the same pinhole + nr_transport_addr sender_external_address2; + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[1], + &sender_external_address2)); + ASSERT_FALSE(nr_transport_addr_cmp(&sender_external_address, + &sender_external_address2, + NR_TRANSPORT_ADDR_CMP_MODE_ALL)) + << "addr1: " << sender_external_address.as_string + << " addr2: " << sender_external_address2.as_string; + + // Verify that the other public IP can now use the pinhole + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[1], private_addrs_[0], + sender_external_address2)); +} + +TEST_F(TestNrSocketTest, PortDependentMappingFullCone) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::PORT_DEPENDENT; + CreatePublicAddrs(2); + + nr_transport_addr sender_external_address0; + // Open pinhole to public IP 0 + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address0)); + + // Verify that return traffic works + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[0], private_addrs_[0], + sender_external_address0)); + + // Verify that other public IP can use the pinhole + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[1], private_addrs_[0], + sender_external_address0)); + + // Send traffic to other public IP, verify that it uses a different pinhole + nr_transport_addr sender_external_address1; + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[1], + &sender_external_address1)); + ASSERT_TRUE(nr_transport_addr_cmp(&sender_external_address0, + &sender_external_address1, + NR_TRANSPORT_ADDR_CMP_MODE_ALL)) + << "addr1: " << sender_external_address0.as_string + << " addr2: " << sender_external_address1.as_string; + + // Verify that return traffic works + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[1], private_addrs_[0], + sender_external_address1)); + + // Verify that other public IP can use the original pinhole + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[0], private_addrs_[0], + sender_external_address1)); +} + +TEST_F(TestNrSocketTest, Symmetric) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1)); + nat->filtering_type_ = TestNat::PORT_DEPENDENT; + nat->mapping_type_ = TestNat::PORT_DEPENDENT; + CreatePublicAddrs(2); + + nr_transport_addr sender_external_address; + // Open pinhole to public IP 0 + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address)); + + // Verify that return traffic works + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[0], private_addrs_[0], + sender_external_address)); + + // Verify that other public IP cannot use the pinhole + ASSERT_FALSE(CheckConnectivityVia(public_addrs_[1], private_addrs_[0], + sender_external_address)); + + // Send traffic to other public IP, verify that it uses a new pinhole + nr_transport_addr sender_external_address2; + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[1], + &sender_external_address2)); + ASSERT_TRUE(nr_transport_addr_cmp(&sender_external_address, + &sender_external_address2, + NR_TRANSPORT_ADDR_CMP_MODE_ALL)); + + // Verify that the other public IP can use the new pinhole + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[1], private_addrs_[0], + sender_external_address2)); +} + +TEST_F(TestNrSocketTest, BlockUdp) { + RefPtr<TestNat> nat(CreatePrivateAddrs(2)); + nat->block_udp_ = true; + CreatePublicAddrs(1); + + nr_transport_addr sender_external_address; + ASSERT_FALSE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address)); + + // Make sure UDP behind the NAT still works + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], private_addrs_[1])); + ASSERT_TRUE(CheckConnectivity(private_addrs_[1], private_addrs_[0])); +} + +TEST_F(TestNrSocketTest, DenyHairpinning) { + RefPtr<TestNat> nat(CreatePrivateAddrs(2)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + CreatePublicAddrs(1); + + nr_transport_addr sender_external_address; + // Open pinhole to public IP 0 + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address)); + + // Verify that hairpinning is disallowed + ASSERT_FALSE(CheckConnectivityVia(private_addrs_[1], private_addrs_[0], + sender_external_address)); +} + +TEST_F(TestNrSocketTest, AllowHairpinning) { + RefPtr<TestNat> nat(CreatePrivateAddrs(2)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_timeout_ = 30000; + nat->allow_hairpinning_ = true; + CreatePublicAddrs(1); + + nr_transport_addr sender_external_address; + // Open pinhole to public IP 0, obtain external address + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address)); + + // Verify that hairpinning is allowed + ASSERT_TRUE(CheckConnectivityVia(private_addrs_[1], private_addrs_[0], + sender_external_address)); +} + +TEST_F(TestNrSocketTest, FullConeTimeout) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_timeout_ = 200; + CreatePublicAddrs(2); + + nr_transport_addr sender_external_address; + // Open pinhole to public IP 0 + ASSERT_TRUE(CheckConnectivity(private_addrs_[0], public_addrs_[0], + &sender_external_address)); + + // Verify that return traffic works + ASSERT_TRUE(CheckConnectivityVia(public_addrs_[0], private_addrs_[0], + sender_external_address)); + + PR_Sleep(201); + + // Verify that return traffic does not work + ASSERT_FALSE(CheckConnectivityVia(public_addrs_[0], private_addrs_[0], + sender_external_address)); +} + +TEST_F(TestNrSocketTest, PublicConnectivityTcp) { + CreatePublicAddrs(2, "127.0.0.1", IPPROTO_TCP); + + ASSERT_TRUE(CheckTcpConnectivity(public_addrs_[0], public_addrs_[1])); +} + +TEST_F(TestNrSocketTest, PrivateConnectivityTcp) { + RefPtr<TestNat> nat(CreatePrivateAddrs(2, "127.0.0.1", IPPROTO_TCP)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + + ASSERT_TRUE(CheckTcpConnectivity(private_addrs_[0], private_addrs_[1])); +} + +TEST_F(TestNrSocketTest, PrivateToPublicConnectivityTcp) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1, "127.0.0.1", IPPROTO_TCP)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + CreatePublicAddrs(1, "127.0.0.1", IPPROTO_TCP); + + ASSERT_TRUE(CheckTcpConnectivity(private_addrs_[0], public_addrs_[0])); +} + +TEST_F(TestNrSocketTest, NoConnectivityBetweenSubnetsTcp) { + RefPtr<TestNat> nat1(CreatePrivateAddrs(1, "127.0.0.1", IPPROTO_TCP)); + nat1->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat1->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + RefPtr<TestNat> nat2(CreatePrivateAddrs(1, "127.0.0.1", IPPROTO_TCP)); + nat2->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat2->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + + ASSERT_FALSE(CheckTcpConnectivity(private_addrs_[0], private_addrs_[1])); +} + +TEST_F(TestNrSocketTest, NoConnectivityPublicToPrivateTcp) { + RefPtr<TestNat> nat(CreatePrivateAddrs(1, "127.0.0.1", IPPROTO_TCP)); + nat->filtering_type_ = TestNat::ENDPOINT_INDEPENDENT; + nat->mapping_type_ = TestNat::ENDPOINT_INDEPENDENT; + CreatePublicAddrs(1, "127.0.0.1", IPPROTO_TCP); + + ASSERT_FALSE(CheckTcpConnectivity(public_addrs_[0], private_addrs_[0])); +} diff --git a/dom/media/webrtc/transport/test/transport_unittests.cpp b/dom/media/webrtc/transport/test/transport_unittests.cpp new file mode 100644 index 0000000000..28f9359afb --- /dev/null +++ b/dom/media/webrtc/transport/test/transport_unittests.cpp @@ -0,0 +1,1400 @@ + +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com + +#include <iostream> +#include <string> +#include <algorithm> +#include <functional> + +#ifdef XP_MACOSX +// ensure that Apple Security kit enum goes before "sslproto.h" +# include <CoreFoundation/CFAvailability.h> +# include <Security/CipherSuite.h> +#endif + +#include "mozilla/UniquePtr.h" + +#include "sigslot.h" + +#include "logging.h" +#include "ssl.h" +#include "sslexp.h" +#include "sslproto.h" + +#include "nsThreadUtils.h" + +#include "mediapacket.h" +#include "dtlsidentity.h" +#include "nricectx.h" +#include "nricemediastream.h" +#include "transportflow.h" +#include "transportlayer.h" +#include "transportlayerdtls.h" +#include "transportlayerice.h" +#include "transportlayerlog.h" +#include "transportlayerloopback.h" + +#include "runnable_utils.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +using namespace mozilla; +MOZ_MTLOG_MODULE("mtransport") + +const uint8_t kTlsChangeCipherSpecType = 0x14; +const uint8_t kTlsHandshakeType = 0x16; + +const uint8_t kTlsHandshakeCertificate = 0x0b; +const uint8_t kTlsHandshakeServerKeyExchange = 0x0c; + +const uint8_t kTlsFakeChangeCipherSpec[] = { + kTlsChangeCipherSpecType, // Type + 0xfe, + 0xff, // Version + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x10, // Fictitious sequence # + 0x00, + 0x01, // Length + 0x01 // Value +}; + +// Layer class which can't be initialized. +class TransportLayerDummy : public TransportLayer { + public: + TransportLayerDummy(bool allow_init, bool* destroyed) + : allow_init_(allow_init), destroyed_(destroyed) { + *destroyed_ = false; + } + + virtual ~TransportLayerDummy() { *destroyed_ = true; } + + nsresult InitInternal() override { + return allow_init_ ? NS_OK : NS_ERROR_FAILURE; + } + + TransportResult SendPacket(MediaPacket& packet) override { + MOZ_CRASH(); // Should never be called. + return 0; + } + + TRANSPORT_LAYER_ID("lossy") + + private: + bool allow_init_; + bool* destroyed_; +}; + +class Inspector { + public: + virtual ~Inspector() = default; + + virtual void Inspect(TransportLayer* layer, const unsigned char* data, + size_t len) = 0; +}; + +// Class to simulate various kinds of network lossage +class TransportLayerLossy : public TransportLayer { + public: + TransportLayerLossy() : loss_mask_(0), packet_(0), inspector_(nullptr) {} + ~TransportLayerLossy() = default; + + TransportResult SendPacket(MediaPacket& packet) override { + MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "SendPacket(" << packet.len() << ")"); + + if (loss_mask_ & (1 << (packet_ % 32))) { + MOZ_MTLOG(ML_NOTICE, "Dropping packet"); + ++packet_; + return packet.len(); + } + if (inspector_) { + inspector_->Inspect(this, packet.data(), packet.len()); + } + + ++packet_; + + return downward_->SendPacket(packet); + } + + void SetLoss(uint32_t packet) { loss_mask_ |= (1 << (packet & 32)); } + + void SetInspector(UniquePtr<Inspector> inspector) { + inspector_ = std::move(inspector); + } + + void StateChange(TransportLayer* layer, State state) { TL_SET_STATE(state); } + + void PacketReceived(TransportLayer* layer, MediaPacket& packet) { + SignalPacketReceived(this, packet); + } + + TRANSPORT_LAYER_ID("lossy") + + protected: + void WasInserted() override { + downward_->SignalPacketReceived.connect( + this, &TransportLayerLossy::PacketReceived); + downward_->SignalStateChange.connect(this, + &TransportLayerLossy::StateChange); + + TL_SET_STATE(downward_->state()); + } + + private: + uint32_t loss_mask_; + uint32_t packet_; + UniquePtr<Inspector> inspector_; +}; + +// Process DTLS Records +#define CHECK_LENGTH(expected) \ + do { \ + EXPECT_GE(remaining(), expected); \ + if (remaining() < expected) return false; \ + } while (0) + +class TlsParser { + public: + TlsParser(const unsigned char* data, size_t len) : buffer_(), offset_(0) { + buffer_.Copy(data, len); + } + + bool Read(unsigned char* val) { + if (remaining() < 1) { + return false; + } + *val = *ptr(); + consume(1); + return true; + } + + // Read an integral type of specified width. + bool Read(uint32_t* val, size_t len) { + if (len > sizeof(uint32_t)) return false; + + *val = 0; + + for (size_t i = 0; i < len; ++i) { + unsigned char tmp; + + if (!Read(&tmp)) return false; + + (*val) = ((*val) << 8) + tmp; + } + + return true; + } + + bool Read(unsigned char* val, size_t len) { + if (remaining() < len) { + return false; + } + + if (val) { + memcpy(val, ptr(), len); + } + consume(len); + + return true; + } + + private: + size_t remaining() const { return buffer_.len() - offset_; } + const uint8_t* ptr() const { return buffer_.data() + offset_; } + void consume(size_t len) { offset_ += len; } + + MediaPacket buffer_; + size_t offset_; +}; + +class DtlsRecordParser { + public: + DtlsRecordParser(const unsigned char* data, size_t len) + : buffer_(), offset_(0) { + buffer_.Copy(data, len); + } + + bool NextRecord(uint8_t* ct, UniquePtr<MediaPacket>* buffer) { + if (!remaining()) return false; + + CHECK_LENGTH(13U); + const uint8_t* ctp = reinterpret_cast<const uint8_t*>(ptr()); + consume(11); // ct + version + length + + const uint16_t* tmp = reinterpret_cast<const uint16_t*>(ptr()); + size_t length = ntohs(*tmp); + consume(2); + + CHECK_LENGTH(length); + auto db = MakeUnique<MediaPacket>(); + db->Copy(ptr(), length); + consume(length); + + *ct = *ctp; + *buffer = std::move(db); + + return true; + } + + private: + size_t remaining() const { return buffer_.len() - offset_; } + const uint8_t* ptr() const { return buffer_.data() + offset_; } + void consume(size_t len) { offset_ += len; } + + MediaPacket buffer_; + size_t offset_; +}; + +// Inspector that parses out DTLS records and passes +// them on. +class DtlsRecordInspector : public Inspector { + public: + virtual void Inspect(TransportLayer* layer, const unsigned char* data, + size_t len) { + DtlsRecordParser parser(data, len); + + uint8_t ct; + UniquePtr<MediaPacket> buf; + while (parser.NextRecord(&ct, &buf)) { + OnRecord(layer, ct, buf->data(), buf->len()); + } + } + + virtual void OnRecord(TransportLayer* layer, uint8_t content_type, + const unsigned char* record, size_t len) = 0; +}; + +// Inspector that injects arbitrary packets based on +// DTLS records of various types. +class DtlsInspectorInjector : public DtlsRecordInspector { + public: + DtlsInspectorInjector(uint8_t packet_type, uint8_t handshake_type, + const unsigned char* data, size_t len) + : packet_type_(packet_type), handshake_type_(handshake_type) { + packet_.Copy(data, len); + } + + virtual void OnRecord(TransportLayer* layer, uint8_t content_type, + const unsigned char* data, size_t len) { + // Only inject once. + if (!packet_.data()) { + return; + } + + // Check that the first byte is as requested. + if (content_type != packet_type_) { + return; + } + + if (handshake_type_ != 0xff) { + // Check that the packet is plausibly long enough. + if (len < 1) { + return; + } + + // Check that the handshake type is as requested. + if (data[0] != handshake_type_) { + return; + } + } + + layer->SendPacket(packet_); + packet_.Reset(); + } + + private: + uint8_t packet_type_; + uint8_t handshake_type_; + MediaPacket packet_; +}; + +// Make a copy of the first instance of a message. +class DtlsInspectorRecordHandshakeMessage : public DtlsRecordInspector { + public: + explicit DtlsInspectorRecordHandshakeMessage(uint8_t handshake_type) + : handshake_type_(handshake_type), buffer_() {} + + virtual void OnRecord(TransportLayer* layer, uint8_t content_type, + const unsigned char* data, size_t len) { + // Only do this once. + if (buffer_.len()) { + return; + } + + // Check that the first byte is as requested. + if (content_type != kTlsHandshakeType) { + return; + } + + TlsParser parser(data, len); + unsigned char message_type; + // Read the handshake message type. + if (!parser.Read(&message_type)) { + return; + } + if (message_type != handshake_type_) { + return; + } + + uint32_t length; + if (!parser.Read(&length, 3)) { + return; + } + + uint32_t message_seq; + if (!parser.Read(&message_seq, 2)) { + return; + } + + uint32_t fragment_offset; + if (!parser.Read(&fragment_offset, 3)) { + return; + } + + uint32_t fragment_length; + if (!parser.Read(&fragment_length, 3)) { + return; + } + + if ((fragment_offset != 0) || (fragment_length != length)) { + // This shouldn't happen because all current tests where we + // are using this code don't fragment. + return; + } + + UniquePtr<uint8_t[]> buffer(new uint8_t[length]); + if (!parser.Read(buffer.get(), length)) { + return; + } + buffer_.Take(std::move(buffer), length); + } + + const MediaPacket& buffer() { return buffer_; } + + private: + uint8_t handshake_type_; + MediaPacket buffer_; +}; + +class TlsServerKeyExchangeECDHE { + public: + bool Parse(const unsigned char* data, size_t len) { + TlsParser parser(data, len); + + uint8_t curve_type; + if (!parser.Read(&curve_type)) { + return false; + } + + if (curve_type != 3) { // named_curve + return false; + } + + uint32_t named_curve; + if (!parser.Read(&named_curve, 2)) { + return false; + } + + uint32_t point_length; + if (!parser.Read(&point_length, 1)) { + return false; + } + + UniquePtr<uint8_t[]> key(new uint8_t[point_length]); + if (!parser.Read(key.get(), point_length)) { + return false; + } + public_key_.Take(std::move(key), point_length); + + return true; + } + + MediaPacket public_key_; +}; + +namespace { +class TransportTestPeer : public sigslot::has_slots<> { + public: + TransportTestPeer(nsCOMPtr<nsIEventTarget> target, std::string name, + MtransportTestUtils* utils) + : name_(name), + offerer_(name == "P1"), + target_(target), + received_packets_(0), + received_bytes_(0), + flow_(new TransportFlow(name)), + loopback_(new TransportLayerLoopback()), + logging_(new TransportLayerLogging()), + lossy_(new TransportLayerLossy()), + dtls_(new TransportLayerDtls()), + identity_(DtlsIdentity::Generate()), + ice_ctx_(), + streams_(), + peer_(nullptr), + gathering_complete_(false), + digest_("sha-1"), + enabled_cipersuites_(), + disabled_cipersuites_(), + test_utils_(utils) { + NrIceCtx::InitializeGlobals(NrIceCtx::GlobalConfig()); + ice_ctx_ = NrIceCtx::Create(name); + std::vector<NrIceStunServer> stun_servers; + UniquePtr<NrIceStunServer> server(NrIceStunServer::Create( + std::string((char*)"stun.services.mozilla.com"), 3478)); + stun_servers.push_back(*server); + EXPECT_TRUE(NS_SUCCEEDED(ice_ctx_->SetStunServers(stun_servers))); + + dtls_->SetIdentity(identity_); + dtls_->SetRole(offerer_ ? TransportLayerDtls::SERVER + : TransportLayerDtls::CLIENT); + + nsresult res = identity_->ComputeFingerprint(&digest_); + EXPECT_TRUE(NS_SUCCEEDED(res)); + EXPECT_EQ(20u, digest_.value_.size()); + } + + ~TransportTestPeer() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &TransportTestPeer::DestroyFlow)); + } + + void DestroyFlow() { + disconnect_all(); + if (flow_) { + loopback_->Disconnect(); + flow_ = nullptr; + } + ice_ctx_->Destroy(); + ice_ctx_ = nullptr; + streams_.clear(); + } + + void DisconnectDestroyFlow() { + test_utils_->SyncDispatchToSTS(NS_NewRunnableFunction(__func__, [this] { + loopback_->Disconnect(); + disconnect_all(); // Disconnect from the signals; + flow_ = nullptr; + })); + } + + void SetDtlsAllowAll() { + nsresult res = dtls_->SetVerificationAllowAll(); + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + + void SetAlpn(std::string str, bool withDefault, std::string extra = "") { + std::set<std::string> alpn; + alpn.insert(str); // the one we want to select + if (!extra.empty()) { + alpn.insert(extra); + } + nsresult res = dtls_->SetAlpn(alpn, withDefault ? str : ""); + ASSERT_EQ(NS_OK, res); + } + + const std::string& GetAlpn() const { return dtls_->GetNegotiatedAlpn(); } + + void SetDtlsPeer(TransportTestPeer* peer, int digests, unsigned int damage) { + unsigned int mask = 1; + + for (int i = 0; i < digests; i++) { + DtlsDigest digest_to_set(peer->digest_); + + if (damage & mask) digest_to_set.value_.data()[0]++; + + nsresult res = dtls_->SetVerificationDigest(digest_to_set); + + ASSERT_TRUE(NS_SUCCEEDED(res)); + + mask <<= 1; + } + } + + void SetupSrtp() { + std::vector<uint16_t> srtp_ciphers = + TransportLayerDtls::GetDefaultSrtpCiphers(); + SetSrtpCiphers(srtp_ciphers); + } + + void SetSrtpCiphers(std::vector<uint16_t>& srtp_ciphers) { + ASSERT_TRUE(NS_SUCCEEDED(dtls_->SetSrtpCiphers(srtp_ciphers))); + } + + void ConnectSocket_s(TransportTestPeer* peer) { + nsresult res; + res = loopback_->Init(); + ASSERT_EQ((nsresult)NS_OK, res); + + loopback_->Connect(peer->loopback_); + ASSERT_EQ((nsresult)NS_OK, loopback_->Init()); + ASSERT_EQ((nsresult)NS_OK, logging_->Init()); + ASSERT_EQ((nsresult)NS_OK, lossy_->Init()); + ASSERT_EQ((nsresult)NS_OK, dtls_->Init()); + dtls_->Chain(lossy_); + lossy_->Chain(logging_); + logging_->Chain(loopback_); + + flow_->PushLayer(loopback_); + flow_->PushLayer(logging_); + flow_->PushLayer(lossy_); + flow_->PushLayer(dtls_); + + if (dtls_->state() != TransportLayer::TS_ERROR) { + // Don't execute these blocks if DTLS didn't initialize. + TweakCiphers(dtls_->internal_fd()); + if (post_setup_) { + post_setup_(dtls_->internal_fd()); + } + } + + dtls_->SignalPacketReceived.connect(this, + &TransportTestPeer::PacketReceived); + } + + void TweakCiphers(PRFileDesc* fd) { + for (unsigned short& enabled_cipersuite : enabled_cipersuites_) { + SSL_CipherPrefSet(fd, enabled_cipersuite, PR_TRUE); + } + for (unsigned short& disabled_cipersuite : disabled_cipersuites_) { + SSL_CipherPrefSet(fd, disabled_cipersuite, PR_FALSE); + } + } + + void ConnectSocket(TransportTestPeer* peer) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &TransportTestPeer::ConnectSocket_s, peer)); + } + + nsresult InitIce_s() { + nsresult rv = ice_->Init(); + NS_ENSURE_SUCCESS(rv, rv); + rv = dtls_->Init(); + NS_ENSURE_SUCCESS(rv, rv); + dtls_->Chain(ice_); + flow_->PushLayer(ice_); + flow_->PushLayer(dtls_); + return NS_OK; + } + + void InitIce() { + nsresult res; + + // Attach our slots + ice_ctx_->SignalGatheringStateChange.connect( + this, &TransportTestPeer::GatheringStateChange); + + char name[100]; + snprintf(name, sizeof(name), "%s:stream%d", name_.c_str(), + (int)streams_.size()); + + // Create the media stream + RefPtr<NrIceMediaStream> stream = ice_ctx_->CreateStream(name, name, 1); + + ASSERT_TRUE(stream != nullptr); + stream->SetIceCredentials("ufrag", "pass"); + streams_.push_back(stream); + + // Listen for candidates + stream->SignalCandidate.connect(this, &TransportTestPeer::GotCandidate); + + // Create the transport layer + ice_ = new TransportLayerIce(); + ice_->SetParameters(stream, 1); + + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&res, this, &TransportTestPeer::InitIce_s)); + + ASSERT_EQ((nsresult)NS_OK, res); + + // Listen for media events + dtls_->SignalPacketReceived.connect(this, + &TransportTestPeer::PacketReceived); + dtls_->SignalStateChange.connect(this, &TransportTestPeer::StateChanged); + + // Start gathering + test_utils_->SyncDispatchToSTS(WrapRunnableRet( + &res, ice_ctx_, &NrIceCtx::StartGathering, false, false)); + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + + void ConnectIce(TransportTestPeer* peer) { + peer_ = peer; + + // If gathering is already complete, push the candidates over + if (gathering_complete_) GatheringComplete(); + } + + // New candidate + void GotCandidate(NrIceMediaStream* stream, const std::string& candidate, + const std::string& ufrag, const std::string& mdns_addr, + const std::string& actual_addr) { + std::cerr << "Got candidate " << candidate << " (ufrag=" << ufrag << ")" + << std::endl; + } + + void GatheringStateChange(NrIceCtx* ctx, NrIceCtx::GatheringState state) { + (void)ctx; + if (state == NrIceCtx::ICE_CTX_GATHER_COMPLETE) { + GatheringComplete(); + } + } + + // Gathering complete, so send our candidates and start + // connecting on the other peer. + void GatheringComplete() { + nsresult res; + + // Don't send to the other side + if (!peer_) { + gathering_complete_ = true; + return; + } + + // First send attributes + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&res, peer_->ice_ctx_, &NrIceCtx::ParseGlobalAttributes, + ice_ctx_->GetGlobalAttributes())); + ASSERT_TRUE(NS_SUCCEEDED(res)); + + for (size_t i = 0; i < streams_.size(); ++i) { + test_utils_->SyncDispatchToSTS(WrapRunnableRet( + &res, peer_->streams_[i], &NrIceMediaStream::ConnectToPeer, "ufrag", + "pass", streams_[i]->GetAttributes())); + + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + + // Start checks on the other peer. + test_utils_->SyncDispatchToSTS( + WrapRunnableRet(&res, peer_->ice_ctx_, &NrIceCtx::StartChecks)); + ASSERT_TRUE(NS_SUCCEEDED(res)); + } + + // WrapRunnable/lambda and move semantics (MediaPacket is not copyable) don't + // get along yet, so we need a wrapper. Gross. + static TransportResult SendPacketWrapper(TransportLayer* layer, + MediaPacket* packet) { + return layer->SendPacket(*packet); + } + + TransportResult SendPacket(MediaPacket& packet) { + TransportResult ret; + + test_utils_->SyncDispatchToSTS(WrapRunnableNMRet( + &ret, &TransportTestPeer::SendPacketWrapper, dtls_, &packet)); + + return ret; + } + + void StateChanged(TransportLayer* layer, TransportLayer::State state) { + if (state == TransportLayer::TS_OPEN) { + std::cerr << "Now connected" << std::endl; + } + } + + void PacketReceived(TransportLayer* layer, MediaPacket& packet) { + std::cerr << "Received " << packet.len() << " bytes" << std::endl; + ++received_packets_; + received_bytes_ += packet.len(); + } + + void SetLoss(uint32_t loss) { lossy_->SetLoss(loss); } + + void SetCombinePackets(bool combine) { loopback_->CombinePackets(combine); } + + void SetInspector(UniquePtr<Inspector> inspector) { + lossy_->SetInspector(std::move(inspector)); + } + + void SetInspector(Inspector* in) { + UniquePtr<Inspector> inspector(in); + + lossy_->SetInspector(std::move(inspector)); + } + + void SetCipherSuiteChanges(const std::vector<uint16_t>& enableThese, + const std::vector<uint16_t>& disableThese) { + disabled_cipersuites_ = disableThese; + enabled_cipersuites_ = enableThese; + } + + void SetPostSetup(const std::function<void(PRFileDesc*)>& setup) { + post_setup_ = std::move(setup); + } + + TransportLayer::State state() { + TransportLayer::State tstate; + + RUN_ON_THREAD(test_utils_->sts_target(), + WrapRunnableRet(&tstate, dtls_, &TransportLayer::state)); + + return tstate; + } + + bool connected() { return state() == TransportLayer::TS_OPEN; } + + bool failed() { return state() == TransportLayer::TS_ERROR; } + + size_t receivedPackets() { return received_packets_; } + + size_t receivedBytes() { return received_bytes_; } + + uint16_t cipherSuite() const { + nsresult rv; + uint16_t cipher; + RUN_ON_THREAD( + test_utils_->sts_target(), + WrapRunnableRet(&rv, dtls_, &TransportLayerDtls::GetCipherSuite, + &cipher)); + + if (NS_FAILED(rv)) { + return TLS_NULL_WITH_NULL_NULL; // i.e., not good + } + return cipher; + } + + uint16_t srtpCipher() const { + nsresult rv; + uint16_t cipher; + RUN_ON_THREAD(test_utils_->sts_target(), + WrapRunnableRet(&rv, dtls_, + &TransportLayerDtls::GetSrtpCipher, &cipher)); + if (NS_FAILED(rv)) { + return 0; // the SRTP equivalent of TLS_NULL_WITH_NULL_NULL + } + return cipher; + } + + private: + std::string name_; + bool offerer_; + nsCOMPtr<nsIEventTarget> target_; + size_t received_packets_; + size_t received_bytes_; + RefPtr<TransportFlow> flow_; + TransportLayerLoopback* loopback_; + TransportLayerLogging* logging_; + TransportLayerLossy* lossy_; + TransportLayerDtls* dtls_; + TransportLayerIce* ice_; + RefPtr<DtlsIdentity> identity_; + RefPtr<NrIceCtx> ice_ctx_; + std::vector<RefPtr<NrIceMediaStream> > streams_; + TransportTestPeer* peer_; + bool gathering_complete_; + DtlsDigest digest_; + std::vector<uint16_t> enabled_cipersuites_; + std::vector<uint16_t> disabled_cipersuites_; + MtransportTestUtils* test_utils_; + std::function<void(PRFileDesc* fd)> post_setup_ = nullptr; +}; + +class TransportTest : public MtransportTest { + public: + TransportTest() { + fds_[0] = nullptr; + fds_[1] = nullptr; + p1_ = nullptr; + p2_ = nullptr; + } + + void TearDown() override { + delete p1_; + delete p2_; + + // Can't detach these + // PR_Close(fds_[0]); + // PR_Close(fds_[1]); + MtransportTest::TearDown(); + } + + void DestroyPeerFlows() { + p1_->DisconnectDestroyFlow(); + p2_->DisconnectDestroyFlow(); + } + + void SetUp() override { + MtransportTest::SetUp(); + + nsresult rv; + target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + + Reset(); + } + + void Reset() { + if (p1_) { + delete p1_; + } + if (p2_) { + delete p2_; + } + p1_ = new TransportTestPeer(target_, "P1", test_utils_); + p2_ = new TransportTestPeer(target_, "P2", test_utils_); + } + + void SetupSrtp() { + p1_->SetupSrtp(); + p2_->SetupSrtp(); + } + + void SetDtlsPeer(int digests = 1, unsigned int damage = 0) { + p1_->SetDtlsPeer(p2_, digests, damage); + p2_->SetDtlsPeer(p1_, digests, damage); + } + + void SetDtlsAllowAll() { + p1_->SetDtlsAllowAll(); + p2_->SetDtlsAllowAll(); + } + + void SetAlpn(std::string first, std::string second, + bool withDefaults = true) { + if (!first.empty()) { + p1_->SetAlpn(first, withDefaults, "bogus"); + } + if (!second.empty()) { + p2_->SetAlpn(second, withDefaults); + } + } + + void CheckAlpn(std::string first, std::string second) { + ASSERT_EQ(first, p1_->GetAlpn()); + ASSERT_EQ(second, p2_->GetAlpn()); + } + + void ConnectSocket() { + ConnectSocketInternal(); + ASSERT_TRUE_WAIT(p1_->connected(), 10000); + ASSERT_TRUE_WAIT(p2_->connected(), 10000); + + ASSERT_EQ(p1_->cipherSuite(), p2_->cipherSuite()); + ASSERT_EQ(p1_->srtpCipher(), p2_->srtpCipher()); + } + + void ConnectSocketExpectFail() { + ConnectSocketInternal(); + ASSERT_TRUE_WAIT(p1_->failed(), 10000); + ASSERT_TRUE_WAIT(p2_->failed(), 10000); + } + + void ConnectSocketExpectState(TransportLayer::State s1, + TransportLayer::State s2) { + ConnectSocketInternal(); + ASSERT_EQ_WAIT(s1, p1_->state(), 10000); + ASSERT_EQ_WAIT(s2, p2_->state(), 10000); + } + + void ConnectIce() { + p1_->InitIce(); + p2_->InitIce(); + p1_->ConnectIce(p2_); + p2_->ConnectIce(p1_); + ASSERT_TRUE_WAIT(p1_->connected(), 10000); + ASSERT_TRUE_WAIT(p2_->connected(), 10000); + } + + void TransferTest(size_t count, size_t bytes = 1024) { + unsigned char buf[bytes]; + + for (size_t i = 0; i < count; ++i) { + memset(buf, count & 0xff, sizeof(buf)); + MediaPacket packet; + packet.Copy(buf, sizeof(buf)); + TransportResult rv = p1_->SendPacket(packet); + ASSERT_TRUE(rv > 0); + } + + std::cerr << "Received == " << p2_->receivedPackets() << " packets" + << std::endl; + ASSERT_TRUE_WAIT(count == p2_->receivedPackets(), 10000); + ASSERT_TRUE((count * sizeof(buf)) == p2_->receivedBytes()); + } + + protected: + void ConnectSocketInternal() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(p1_, &TransportTestPeer::ConnectSocket, p2_)); + test_utils_->SyncDispatchToSTS( + WrapRunnable(p2_, &TransportTestPeer::ConnectSocket, p1_)); + } + + PRFileDesc* fds_[2]; + TransportTestPeer* p1_; + TransportTestPeer* p2_; + nsCOMPtr<nsIEventTarget> target_; +}; + +TEST_F(TransportTest, TestNoDtlsVerificationSettings) { + ConnectSocketExpectFail(); +} + +static void DisableChaCha(TransportTestPeer* peer) { + // On ARM, ChaCha20Poly1305 might be preferred; disable it for the tests that + // want to check the cipher suite. It doesn't matter which peer disables the + // suite, disabling on either side has the same effect. + std::vector<uint16_t> chachaSuites; + chachaSuites.push_back(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256); + chachaSuites.push_back(TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256); + peer->SetCipherSuiteChanges(std::vector<uint16_t>(), chachaSuites); +} + +TEST_F(TransportTest, TestConnect) { + SetDtlsPeer(); + DisableChaCha(p1_); + ConnectSocket(); + + // check that we got the right suite + ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite()); + + // no SRTP on this one + ASSERT_EQ(0, p1_->srtpCipher()); +} + +TEST_F(TransportTest, TestConnectSrtp) { + SetupSrtp(); + SetDtlsPeer(); + DisableChaCha(p2_); + ConnectSocket(); + + ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite()); + + // SRTP is on with default value + ASSERT_EQ(kDtlsSrtpAeadAes128Gcm, p1_->srtpCipher()); +} + +TEST_F(TransportTest, TestConnectDestroyFlowsMainThread) { + SetDtlsPeer(); + ConnectSocket(); + DestroyPeerFlows(); +} + +TEST_F(TransportTest, TestConnectAllowAll) { + SetDtlsAllowAll(); + ConnectSocket(); +} + +TEST_F(TransportTest, TestConnectAlpn) { + SetDtlsPeer(); + SetAlpn("a", "a"); + ConnectSocket(); + CheckAlpn("a", "a"); +} + +TEST_F(TransportTest, TestConnectAlpnMismatch) { + SetDtlsPeer(); + SetAlpn("something", "different"); + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestConnectAlpnServerDefault) { + SetDtlsPeer(); + SetAlpn("def", ""); + // server allows default, client doesn't support + ConnectSocket(); + CheckAlpn("def", ""); +} + +TEST_F(TransportTest, TestConnectAlpnClientDefault) { + SetDtlsPeer(); + SetAlpn("", "clientdef"); + // client allows default, but server will ignore the extension + ConnectSocket(); + CheckAlpn("", "clientdef"); +} + +TEST_F(TransportTest, TestConnectClientNoAlpn) { + SetDtlsPeer(); + // Here the server has ALPN, but no default is allowed. + // Reminder: p1 == server, p2 == client + SetAlpn("server-nodefault", "", false); + // The server doesn't see the extension, so negotiates without it. + // But then the server is forced to close when it discovers that ALPN wasn't + // negotiated; the client sees a close. + ConnectSocketExpectState(TransportLayer::TS_ERROR, TransportLayer::TS_CLOSED); +} + +TEST_F(TransportTest, TestConnectServerNoAlpn) { + SetDtlsPeer(); + SetAlpn("", "client-nodefault", false); + // The client aborts; the server doesn't realize this is a problem and just + // sees the close. + ConnectSocketExpectState(TransportLayer::TS_CLOSED, TransportLayer::TS_ERROR); +} + +TEST_F(TransportTest, TestConnectNoDigest) { + SetDtlsPeer(0, 0); + + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestConnectBadDigest) { + SetDtlsPeer(1, 1); + + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestConnectTwoDigests) { + SetDtlsPeer(2, 0); + + ConnectSocket(); +} + +TEST_F(TransportTest, TestConnectTwoDigestsFirstBad) { + SetDtlsPeer(2, 1); + + ConnectSocket(); +} + +TEST_F(TransportTest, TestConnectTwoDigestsSecondBad) { + SetDtlsPeer(2, 2); + + ConnectSocket(); +} + +TEST_F(TransportTest, TestConnectTwoDigestsBothBad) { + SetDtlsPeer(2, 3); + + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestConnectInjectCCS) { + SetDtlsPeer(); + p2_->SetInspector(MakeUnique<DtlsInspectorInjector>( + kTlsHandshakeType, kTlsHandshakeCertificate, kTlsFakeChangeCipherSpec, + sizeof(kTlsFakeChangeCipherSpec))); + + ConnectSocket(); +} + +TEST_F(TransportTest, TestConnectVerifyNewECDHE) { + SetDtlsPeer(); + DtlsInspectorRecordHandshakeMessage* i1 = + new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + p1_->SetInspector(i1); + ConnectSocket(); + TlsServerKeyExchangeECDHE dhe1; + ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len())); + + Reset(); + SetDtlsPeer(); + DtlsInspectorRecordHandshakeMessage* i2 = + new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + p1_->SetInspector(i2); + ConnectSocket(); + TlsServerKeyExchangeECDHE dhe2; + ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len())); + + // Now compare these two to see if they are the same. + ASSERT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) && + (!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(), + dhe1.public_key_.len()))); +} + +TEST_F(TransportTest, TestConnectVerifyReusedECDHE) { + auto set_reuse_ecdhe_key = [](PRFileDesc* fd) { + // TransportLayerDtls automatically sets this pref to false + // so set it back for test. + // This is pretty gross. Dig directly into the NSS FD. The problem + // is that we are testing a feature which TransaportLayerDtls doesn't + // expose. + SECStatus rv = SSL_OptionSet(fd, SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE); + ASSERT_EQ(SECSuccess, rv); + }; + + SetDtlsPeer(); + DtlsInspectorRecordHandshakeMessage* i1 = + new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + p1_->SetInspector(i1); + p1_->SetPostSetup(set_reuse_ecdhe_key); + ConnectSocket(); + TlsServerKeyExchangeECDHE dhe1; + ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len())); + + Reset(); + SetDtlsPeer(); + DtlsInspectorRecordHandshakeMessage* i2 = + new DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + + p1_->SetInspector(i2); + p1_->SetPostSetup(set_reuse_ecdhe_key); + + ConnectSocket(); + TlsServerKeyExchangeECDHE dhe2; + ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len())); + + // Now compare these two to see if they are the same. + ASSERT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len()); + ASSERT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(), + dhe1.public_key_.len())); +} + +TEST_F(TransportTest, TestTransfer) { + SetDtlsPeer(); + ConnectSocket(); + TransferTest(1); +} + +TEST_F(TransportTest, TestTransferMaxSize) { + SetDtlsPeer(); + ConnectSocket(); + /* transportlayerdtls uses a 9216 bytes buffer - as this test uses the + * loopback implementation it does not have to take into account the extra + * bytes added by the DTLS layer below. */ + TransferTest(1, 9216); +} + +TEST_F(TransportTest, TestTransferMultiple) { + SetDtlsPeer(); + ConnectSocket(); + TransferTest(3); +} + +TEST_F(TransportTest, TestTransferCombinedPackets) { + SetDtlsPeer(); + ConnectSocket(); + p2_->SetCombinePackets(true); + TransferTest(3); +} + +TEST_F(TransportTest, TestConnectLoseFirst) { + SetDtlsPeer(); + p1_->SetLoss(0); + ConnectSocket(); + TransferTest(1); +} + +TEST_F(TransportTest, TestConnectIce) { + SetDtlsPeer(); + ConnectIce(); +} + +TEST_F(TransportTest, TestTransferIceMaxSize) { + SetDtlsPeer(); + ConnectIce(); + /* nICEr and transportlayerdtls both use 9216 bytes buffers. But the DTLS + * layer add extra bytes to the packet, which size depends on chosen cipher + * etc. Sending more then 9216 bytes works, but on the receiving side the call + * to PR_recvfrom() will truncate any packet bigger then nICEr's buffer size + * of 9216 bytes, which then results in the DTLS layer discarding the packet. + * Therefore we leave some headroom (according to + * https://bugzilla.mozilla.org/show_bug.cgi?id=1214269#c29 256 bytes should + * be save choice) here for the DTLS bytes to make it safely into the + * receiving buffer in nICEr. */ + TransferTest(1, 8960); +} + +TEST_F(TransportTest, TestTransferIceMultiple) { + SetDtlsPeer(); + ConnectIce(); + TransferTest(3); +} + +TEST_F(TransportTest, TestTransferIceCombinedPackets) { + SetDtlsPeer(); + ConnectIce(); + p2_->SetCombinePackets(true); + TransferTest(3); +} + +// test the default configuration against a peer that supports only +// one of the mandatory-to-implement suites, which should succeed +static void ConfigureOneCipher(TransportTestPeer* peer, uint16_t suite) { + std::vector<uint16_t> justOne; + justOne.push_back(suite); + std::vector<uint16_t> everythingElse( + SSL_GetImplementedCiphers(), + SSL_GetImplementedCiphers() + SSL_GetNumImplementedCiphers()); + std::remove(everythingElse.begin(), everythingElse.end(), suite); + peer->SetCipherSuiteChanges(justOne, everythingElse); +} + +TEST_F(TransportTest, TestCipherMismatch) { + SetDtlsPeer(); + ConfigureOneCipher(p1_, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256); + ConfigureOneCipher(p2_, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA); + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestCipherMandatoryOnlyGcm) { + SetDtlsPeer(); + ConfigureOneCipher(p1_, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256); + ConnectSocket(); + ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite()); +} + +TEST_F(TransportTest, TestCipherMandatoryOnlyCbc) { + SetDtlsPeer(); + ConfigureOneCipher(p1_, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA); + ConnectSocket(); + ASSERT_EQ(TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, p1_->cipherSuite()); +} + +TEST_F(TransportTest, TestSrtpMismatch) { + std::vector<uint16_t> setA; + setA.push_back(kDtlsSrtpAes128CmHmacSha1_80); + std::vector<uint16_t> setB; + setB.push_back(kDtlsSrtpAes128CmHmacSha1_32); + + p1_->SetSrtpCiphers(setA); + p2_->SetSrtpCiphers(setB); + SetDtlsPeer(); + ConnectSocketExpectFail(); + + ASSERT_EQ(0, p1_->srtpCipher()); + ASSERT_EQ(0, p2_->srtpCipher()); +} + +static SECStatus NoopXtnHandler(PRFileDesc* fd, SSLHandshakeType message, + const uint8_t* data, unsigned int len, + SSLAlertDescription* alert, void* arg) { + return SECSuccess; +} + +static PRBool WriteFixedXtn(PRFileDesc* fd, SSLHandshakeType message, + uint8_t* data, unsigned int* len, + unsigned int max_len, void* arg) { + // When we enable TLS 1.3, change ssl_hs_server_hello here to + // ssl_hs_encrypted_extensions. At the same time, add a test that writes to + // ssl_hs_server_hello, which should fail. + if (message != ssl_hs_client_hello && message != ssl_hs_server_hello) { + return false; + } + + auto v = reinterpret_cast<std::vector<uint8_t>*>(arg); + memcpy(data, &((*v)[0]), v->size()); + *len = v->size(); + return true; +} + +// Note that |value| needs to be readable after this function returns. +static void InstallBadSrtpExtensionWriter(TransportTestPeer* peer, + std::vector<uint8_t>* value) { + peer->SetPostSetup([value](PRFileDesc* fd) { + // Override the handler that is installed by the DTLS setup. + SECStatus rv = SSL_InstallExtensionHooks( + fd, ssl_use_srtp_xtn, WriteFixedXtn, value, NoopXtnHandler, nullptr); + ASSERT_EQ(SECSuccess, rv); + }); +} + +TEST_F(TransportTest, TestSrtpErrorServerSendsTwoSrtpCiphers) { + // Server (p1_) sends an extension with two values, and empty MKI. + std::vector<uint8_t> xtn = {0x04, 0x00, 0x01, 0x00, 0x02, 0x00}; + InstallBadSrtpExtensionWriter(p1_, &xtn); + SetupSrtp(); + SetDtlsPeer(); + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestSrtpErrorServerSendsTwoMki) { + // Server (p1_) sends an MKI. + std::vector<uint8_t> xtn = {0x02, 0x00, 0x01, 0x01, 0x00}; + InstallBadSrtpExtensionWriter(p1_, &xtn); + SetupSrtp(); + SetDtlsPeer(); + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestSrtpErrorServerSendsUnknownValue) { + std::vector<uint8_t> xtn = {0x02, 0x9a, 0xf1, 0x00}; + InstallBadSrtpExtensionWriter(p1_, &xtn); + SetupSrtp(); + SetDtlsPeer(); + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestSrtpErrorServerSendsOverflow) { + std::vector<uint8_t> xtn = {0x32, 0x00, 0x01, 0x00}; + InstallBadSrtpExtensionWriter(p1_, &xtn); + SetupSrtp(); + SetDtlsPeer(); + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestSrtpErrorServerSendsUnevenList) { + std::vector<uint8_t> xtn = {0x01, 0x00, 0x00}; + InstallBadSrtpExtensionWriter(p1_, &xtn); + SetupSrtp(); + SetDtlsPeer(); + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, TestSrtpErrorClientSendsUnevenList) { + std::vector<uint8_t> xtn = {0x01, 0x00, 0x00}; + InstallBadSrtpExtensionWriter(p2_, &xtn); + SetupSrtp(); + SetDtlsPeer(); + ConnectSocketExpectFail(); +} + +TEST_F(TransportTest, OnlyServerSendsSrtpXtn) { + p1_->SetupSrtp(); + SetDtlsPeer(); + // This should connect, but with no SRTP extension neogtiated. + // The client side might negotiate a data channel only. + ConnectSocket(); + ASSERT_NE(TLS_NULL_WITH_NULL_NULL, p1_->cipherSuite()); + ASSERT_EQ(0, p1_->srtpCipher()); +} + +TEST_F(TransportTest, OnlyClientSendsSrtpXtn) { + p2_->SetupSrtp(); + SetDtlsPeer(); + // This should connect, but with no SRTP extension neogtiated. + // The server side might negotiate a data channel only. + ConnectSocket(); + ASSERT_NE(TLS_NULL_WITH_NULL_NULL, p1_->cipherSuite()); + ASSERT_EQ(0, p1_->srtpCipher()); +} + +class TransportSrtpParameterTest + : public TransportTest, + public ::testing::WithParamInterface<uint16_t> {}; + +INSTANTIATE_TEST_SUITE_P( + SrtpParamInit, TransportSrtpParameterTest, + ::testing::ValuesIn(TransportLayerDtls::GetDefaultSrtpCiphers())); + +TEST_P(TransportSrtpParameterTest, TestSrtpCiphersMismatchCombinations) { + uint16_t cipher = GetParam(); + std::cerr << "Checking cipher: " << cipher << std::endl; + + p1_->SetupSrtp(); + + std::vector<uint16_t> setB; + setB.push_back(cipher); + + p2_->SetSrtpCiphers(setB); + SetDtlsPeer(); + ConnectSocket(); + + ASSERT_EQ(cipher, p1_->srtpCipher()); + ASSERT_EQ(cipher, p2_->srtpCipher()); +} + +// NSS doesn't support DHE suites on the server end. +// This checks to see if we barf when that's the only option available. +TEST_F(TransportTest, TestDheOnlyFails) { + SetDtlsPeer(); + + // p2_ is the client + // setting this on p1_ (the server) causes NSS to assert + ConfigureOneCipher(p2_, TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + ConnectSocketExpectFail(); +} + +} // end namespace diff --git a/dom/media/webrtc/transport/test/turn_unittest.cpp b/dom/media/webrtc/transport/test/turn_unittest.cpp new file mode 100644 index 0000000000..ae5d0386d9 --- /dev/null +++ b/dom/media/webrtc/transport/test/turn_unittest.cpp @@ -0,0 +1,432 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// Original author: ekr@rtfm.com + +// Some code copied from nICEr. License is: +/* +Copyright (c) 2007, Adobe Systems, Incorporated +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of Adobe Systems, Network Resonance nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include <stdlib.h> +#include <iostream> + +#include "runnable_utils.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +#define USE_TURN + +// nICEr includes +extern "C" { +#include "nr_api.h" +#include "transport_addr.h" +#include "nr_crypto.h" +#include "nr_socket.h" +#include "nr_socket_local.h" +#include "nr_socket_buffered_stun.h" +#include "stun_client_ctx.h" +#include "turn_client_ctx.h" +} + +#include "nricectx.h" + +using namespace mozilla; + +static std::string kDummyTurnServer("192.0.2.1"); // From RFC 5737 + +class TurnClient : public MtransportTest { + public: + TurnClient() + : MtransportTest(), + real_socket_(nullptr), + net_socket_(nullptr), + buffered_socket_(nullptr), + net_fd_(nullptr), + turn_ctx_(nullptr), + allocated_(false), + received_(0), + protocol_(IPPROTO_UDP) {} + + ~TurnClient() = default; + + static void SetUpTestCase() { + NrIceCtx::InitializeGlobals(NrIceCtx::GlobalConfig()); + } + + void SetTcp() { protocol_ = IPPROTO_TCP; } + + void Init_s() { + int r; + nr_transport_addr addr; + r = nr_ip4_port_to_transport_addr(0, 0, protocol_, &addr); + ASSERT_EQ(0, r); + + r = nr_socket_local_create(nullptr, &addr, &real_socket_); + ASSERT_EQ(0, r); + + if (protocol_ == IPPROTO_TCP) { + int r = nr_socket_buffered_stun_create( + real_socket_, 100000, TURN_TCP_FRAMING, &buffered_socket_); + ASSERT_EQ(0, r); + net_socket_ = buffered_socket_; + } else { + net_socket_ = real_socket_; + } + + r = nr_str_port_to_transport_addr(turn_server_.c_str(), 3478, protocol_, + &addr); + ASSERT_EQ(0, r); + + std::vector<unsigned char> password_vec(turn_password_.begin(), + turn_password_.end()); + Data password; + INIT_DATA(password, &password_vec[0], password_vec.size()); + r = nr_turn_client_ctx_create("test", net_socket_, turn_user_.c_str(), + &password, &addr, nullptr, &turn_ctx_); + ASSERT_EQ(0, r); + + r = nr_socket_getfd(net_socket_, &net_fd_); + ASSERT_EQ(0, r); + + NR_ASYNC_WAIT(net_fd_, NR_ASYNC_WAIT_READ, socket_readable_cb, (void*)this); + } + + void TearDown_s() { + nr_turn_client_ctx_destroy(&turn_ctx_); + if (net_fd_) { + NR_ASYNC_CANCEL(net_fd_, NR_ASYNC_WAIT_READ); + } + + nr_socket_destroy(&buffered_socket_); + } + + void TearDown() { + test_utils_->SyncDispatchToSTS(WrapRunnable(this, &TurnClient::TearDown_s)); + } + + void Allocate_s() { + Init_s(); + ASSERT_TRUE(turn_ctx_); + + int r = nr_turn_client_allocate(turn_ctx_, allocate_success_cb, this); + ASSERT_EQ(0, r); + } + + void Allocate(bool expect_success = true) { + test_utils_->SyncDispatchToSTS(WrapRunnable(this, &TurnClient::Allocate_s)); + + if (expect_success) { + ASSERT_TRUE_WAIT(allocated_, 5000); + } else { + PR_Sleep(10000); + ASSERT_FALSE(allocated_); + } + } + + void Allocated() { + if (turn_ctx_->state != NR_TURN_CLIENT_STATE_ALLOCATED) { + std::cerr << "Allocation failed" << std::endl; + return; + } + allocated_ = true; + + int r; + nr_transport_addr addr; + + r = nr_turn_client_get_relayed_address(turn_ctx_, &addr); + ASSERT_EQ(0, r); + + relay_addr_ = addr.as_string; + + std::cerr << "Allocation succeeded with addr=" << relay_addr_ << std::endl; + } + + void Deallocate_s() { + ASSERT_TRUE(turn_ctx_); + + std::cerr << "De-Allocating..." << std::endl; + int r = nr_turn_client_deallocate(turn_ctx_); + ASSERT_EQ(0, r); + } + + void Deallocate() { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &TurnClient::Deallocate_s)); + } + + void RequestPermission_s(const std::string& target) { + nr_transport_addr addr; + int r; + + // Expected pattern here is "IP4:127.0.0.1:3487" + ASSERT_EQ(0, target.compare(0, 4, "IP4:")); + + size_t offset = target.rfind(':'); + ASSERT_NE(std::string::npos, offset); + + std::string host = target.substr(4, offset - 4); + std::string port = target.substr(offset + 1); + + r = nr_str_port_to_transport_addr(host.c_str(), atoi(port.c_str()), + IPPROTO_UDP, &addr); + ASSERT_EQ(0, r); + + r = nr_turn_client_ensure_perm(turn_ctx_, &addr); + ASSERT_EQ(0, r); + } + + void RequestPermission(const std::string& target) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &TurnClient::RequestPermission_s, target)); + } + + void Readable(NR_SOCKET s, int how, void* arg) { + // Re-arm + std::cerr << "Socket is readable" << std::endl; + NR_ASYNC_WAIT(s, how, socket_readable_cb, arg); + + UCHAR buf[8192]; + size_t len_s; + nr_transport_addr addr; + + int r = nr_socket_recvfrom(net_socket_, buf, sizeof(buf), &len_s, 0, &addr); + if (r) { + std::cerr << "Error reading from socket" << std::endl; + return; + } + + ASSERT_LT(len_s, (size_t)INT_MAX); + int len = (int)len_s; + + if (nr_is_stun_response_message(buf, len)) { + std::cerr << "STUN response" << std::endl; + r = nr_turn_client_process_response(turn_ctx_, buf, len, &addr); + + if (r && r != R_REJECTED && r != R_RETRY) { + std::cerr << "Error processing STUN: " << r << std::endl; + } + } else if (nr_is_stun_indication_message(buf, len)) { + std::cerr << "STUN indication" << std::endl; + + /* Process the indication */ + unsigned char data[NR_STUN_MAX_MESSAGE_SIZE]; + size_t datal; + nr_transport_addr remote_addr; + + r = nr_turn_client_parse_data_indication( + turn_ctx_, &addr, buf, len, data, &datal, sizeof(data), &remote_addr); + ASSERT_EQ(0, r); + std::cerr << "Received " << datal << " bytes from " + << remote_addr.as_string << std::endl; + + received_ += datal; + + for (size_t i = 0; i < datal; i++) { + ASSERT_EQ(i & 0xff, data[i]); + } + } else { + if (nr_is_stun_message(buf, len)) { + std::cerr << "STUN message of unexpected type" << std::endl; + } else { + std::cerr << "Not a STUN message" << std::endl; + } + return; + } + } + + void SendTo_s(const std::string& target, int expect_return) { + nr_transport_addr addr; + int r; + + // Expected pattern here is "IP4:127.0.0.1:3487" + ASSERT_EQ(0, target.compare(0, 4, "IP4:")); + + size_t offset = target.rfind(':'); + ASSERT_NE(std::string::npos, offset); + + std::string host = target.substr(4, offset - 4); + std::string port = target.substr(offset + 1); + + r = nr_str_port_to_transport_addr(host.c_str(), atoi(port.c_str()), + IPPROTO_UDP, &addr); + ASSERT_EQ(0, r); + + unsigned char test[100]; + for (size_t i = 0; i < sizeof(test); i++) { + test[i] = i & 0xff; + } + + std::cerr << "Sending test message to " << target << " ..." << std::endl; + + r = nr_turn_client_send_indication(turn_ctx_, test, sizeof(test), 0, &addr); + if (expect_return >= 0) { + ASSERT_EQ(expect_return, r); + } + } + + void SendTo(const std::string& target, int expect_return = 0) { + test_utils_->SyncDispatchToSTS( + WrapRunnable(this, &TurnClient::SendTo_s, target, expect_return)); + } + + int received() const { return received_; } + + static void socket_readable_cb(NR_SOCKET s, int how, void* arg) { + static_cast<TurnClient*>(arg)->Readable(s, how, arg); + } + + static void allocate_success_cb(NR_SOCKET s, int how, void* arg) { + static_cast<TurnClient*>(arg)->Allocated(); + } + + protected: + std::string turn_server_; + nr_socket* real_socket_; + nr_socket* net_socket_; + nr_socket* buffered_socket_; + NR_SOCKET net_fd_; + nr_turn_client_ctx* turn_ctx_; + std::string relay_addr_; + bool allocated_; + int received_; + int protocol_; +}; + +TEST_F(TurnClient, Allocate) { + if (WarnIfTurnNotConfigured()) return; + + Allocate(); +} + +TEST_F(TurnClient, AllocateTcp) { + if (WarnIfTurnNotConfigured()) return; + + SetTcp(); + Allocate(); +} + +TEST_F(TurnClient, AllocateAndHold) { + if (WarnIfTurnNotConfigured()) return; + + Allocate(); + PR_Sleep(20000); + ASSERT_TRUE(turn_ctx_->state == NR_TURN_CLIENT_STATE_ALLOCATED); +} + +TEST_F(TurnClient, SendToSelf) { + if (WarnIfTurnNotConfigured()) return; + + Allocate(); + SendTo(relay_addr_); + ASSERT_TRUE_WAIT(received() == 100, 5000); + SendTo(relay_addr_); + ASSERT_TRUE_WAIT(received() == 200, 1000); +} + +TEST_F(TurnClient, SendToSelfTcp) { + if (WarnIfTurnNotConfigured()) return; + + SetTcp(); + Allocate(); + SendTo(relay_addr_); + ASSERT_TRUE_WAIT(received() == 100, 5000); + SendTo(relay_addr_); + ASSERT_TRUE_WAIT(received() == 200, 1000); +} + +TEST_F(TurnClient, PermissionDenied) { + if (WarnIfTurnNotConfigured()) return; + + Allocate(); + RequestPermission(relay_addr_); + PR_Sleep(1000); + + /* Fake a 403 response */ + nr_turn_permission* perm; + perm = STAILQ_FIRST(&turn_ctx_->permissions); + ASSERT_TRUE(perm); + while (perm) { + perm->stun->last_error_code = 403; + std::cerr << "Set 403's on permission" << std::endl; + perm = STAILQ_NEXT(perm, entry); + } + + SendTo(relay_addr_, R_NOT_PERMITTED); + ASSERT_TRUE(received() == 0); + + // TODO: We should check if we can still send to a second destination, but + // we would need a second TURN client as one client can only handle one + // allocation (maybe as part of bug 1128128 ?). +} + +TEST_F(TurnClient, DeallocateReceiveFailure) { + if (WarnIfTurnNotConfigured()) return; + + Allocate(); + SendTo(relay_addr_); + ASSERT_TRUE_WAIT(received() == 100, 5000); + Deallocate(); + turn_ctx_->state = NR_TURN_CLIENT_STATE_ALLOCATED; + SendTo(relay_addr_); + PR_Sleep(1000); + ASSERT_TRUE(received() == 100); +} + +TEST_F(TurnClient, DeallocateReceiveFailureTcp) { + if (WarnIfTurnNotConfigured()) return; + + SetTcp(); + Allocate(); + SendTo(relay_addr_); + ASSERT_TRUE_WAIT(received() == 100, 5000); + Deallocate(); + turn_ctx_->state = NR_TURN_CLIENT_STATE_ALLOCATED; + /* Either the connection got closed by the TURN server already, then the send + * is going to fail, which we simply ignore. Or the connection is still alive + * and we cand send the data, but it should not get forwarded to us. In either + * case we should not receive more data. */ + SendTo(relay_addr_, -1); + PR_Sleep(1000); + ASSERT_TRUE(received() == 100); +} + +TEST_F(TurnClient, AllocateDummyServer) { + if (WarnIfTurnNotConfigured()) return; + + turn_server_ = kDummyTurnServer; + Allocate(false); +} diff --git a/dom/media/webrtc/transport/test/webrtcproxychannel_unittest.cpp b/dom/media/webrtc/transport/test/webrtcproxychannel_unittest.cpp new file mode 100644 index 0000000000..5bfddc7a3f --- /dev/null +++ b/dom/media/webrtc/transport/test/webrtcproxychannel_unittest.cpp @@ -0,0 +1,754 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <algorithm> +#include <mutex> + +#include "mozilla/net/WebrtcTCPSocket.h" +#include "mozilla/net/WebrtcTCPSocketCallback.h" + +#include "nsISocketTransport.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" + +static const uint32_t kDefaultTestTimeout = 2000; +static const char kReadData[] = "Hello, World!"; +static const size_t kReadDataLength = sizeof(kReadData) - 1; +static const std::string kReadDataString = + std::string(kReadData, kReadDataLength); +static int kDataLargeOuterLoopCount = 128; +static int kDataLargeInnerLoopCount = 1024; + +namespace mozilla { + +using namespace net; +using namespace testing; + +class WebrtcTCPSocketTestCallback; + +class FakeSocketTransportProvider : public nsISocketTransport { + public: + NS_DECL_THREADSAFE_ISUPPORTS + + // nsISocketTransport + NS_IMETHOD GetHost(nsACString& aHost) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetPort(int32_t* aPort) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetScriptableOriginAttributes( + JSContext* cx, JS::MutableHandle<JS::Value> aOriginAttributes) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetScriptableOriginAttributes( + JSContext* cx, JS::Handle<JS::Value> aOriginAttributes) override { + MOZ_ASSERT(false); + return NS_OK; + } + virtual nsresult GetOriginAttributes( + mozilla::OriginAttributes* _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + virtual nsresult SetOriginAttributes( + const mozilla::OriginAttributes& aOriginAttrs) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetPeerAddr(mozilla::net::NetAddr* _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetSelfAddr(mozilla::net::NetAddr* _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD Bind(mozilla::net::NetAddr* aLocalAddr) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetScriptablePeerAddr(nsINetAddr** _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetScriptableSelfAddr(nsINetAddr** _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetTlsSocketControl( + nsITLSSocketControl** aTLSSocketControl) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetSecurityCallbacks( + nsIInterfaceRequestor** aSecurityCallbacks) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetSecurityCallbacks( + nsIInterfaceRequestor* aSecurityCallbacks) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD IsAlive(bool* _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetTimeout(uint32_t aType, uint32_t* _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetTimeout(uint32_t aType, uint32_t aValue) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetLinger(bool aPolarity, int16_t aTimeout) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetReuseAddrPort(bool reuseAddrPort) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetConnectionFlags(uint32_t* aConnectionFlags) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetConnectionFlags(uint32_t aConnectionFlags) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetIsPrivate(bool) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetTlsFlags(uint32_t* aTlsFlags) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetTlsFlags(uint32_t aTlsFlags) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetQoSBits(uint8_t* aQoSBits) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetQoSBits(uint8_t aQoSBits) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetRecvBufferSize(uint32_t* aRecvBufferSize) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetSendBufferSize(uint32_t* aSendBufferSize) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetKeepaliveEnabled(bool* aKeepaliveEnabled) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetKeepaliveEnabled(bool aKeepaliveEnabled) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetKeepaliveVals(int32_t keepaliveIdleTime, + int32_t keepaliveRetryInterval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetResetIPFamilyPreference( + bool* aResetIPFamilyPreference) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetEchConfigUsed(bool* aEchConfigUsed) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetEchConfig(const nsACString& aEchConfig) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD ResolvedByTRR(bool* _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetEffectiveTRRMode( + nsIRequest::TRRMode* aEffectiveTRRMode) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetTrrSkipReason(nsITRRSkipReason::value* aSkipReason) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetRetryDnsIfPossible(bool* aRetryDns) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD GetStatus(nsresult* aStatus) override { + MOZ_ASSERT(false); + return NS_OK; + } + + // nsITransport + NS_IMETHOD OpenInputStream(uint32_t aFlags, uint32_t aSegmentSize, + uint32_t aSegmentCount, + nsIInputStream** _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD OpenOutputStream(uint32_t aFlags, uint32_t aSegmentSize, + uint32_t aSegmentCount, + nsIOutputStream** _retval) override { + MOZ_ASSERT(false); + return NS_OK; + } + NS_IMETHOD SetEventSink(nsITransportEventSink* aSink, + nsIEventTarget* aEventTarget) override { + MOZ_ASSERT(false); + return NS_OK; + } + + // fake except for these methods which are OK to call + // nsISocketTransport + NS_IMETHOD SetRecvBufferSize(uint32_t aRecvBufferSize) override { + return NS_OK; + } + NS_IMETHOD SetSendBufferSize(uint32_t aSendBufferSize) override { + return NS_OK; + } + // nsITransport + NS_IMETHOD Close(nsresult aReason) override { return NS_OK; } + + protected: + virtual ~FakeSocketTransportProvider() = default; +}; + +NS_IMPL_ISUPPORTS(FakeSocketTransportProvider, nsISocketTransport, nsITransport) + +// Implements some common elements to WebrtcTCPSocketTestOutputStream and +// WebrtcTCPSocketTestInputStream. +class WebrtcTCPSocketTestStream { + public: + WebrtcTCPSocketTestStream(); + + void Fail() { mMustFail = true; } + + size_t DataLength(); + template <typename T> + void AppendElements(const T* aBuffer, size_t aLength); + + protected: + virtual ~WebrtcTCPSocketTestStream() = default; + + nsTArray<uint8_t> mData; + std::mutex mDataMutex; + + bool mMustFail; +}; + +WebrtcTCPSocketTestStream::WebrtcTCPSocketTestStream() : mMustFail(false) {} + +template <typename T> +void WebrtcTCPSocketTestStream::AppendElements(const T* aBuffer, + size_t aLength) { + std::lock_guard<std::mutex> guard(mDataMutex); + mData.AppendElements(aBuffer, aLength); +} + +size_t WebrtcTCPSocketTestStream::DataLength() { + std::lock_guard<std::mutex> guard(mDataMutex); + return mData.Length(); +} + +class WebrtcTCPSocketTestInputStream : public nsIAsyncInputStream, + public WebrtcTCPSocketTestStream { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSIASYNCINPUTSTREAM + NS_DECL_NSIINPUTSTREAM + + WebrtcTCPSocketTestInputStream() + : mMaxReadSize(1024 * 1024), mAllowCallbacks(false) {} + + void DoCallback(); + void CallCallback(const nsCOMPtr<nsIInputStreamCallback>& aCallback); + void AllowCallbacks() { mAllowCallbacks = true; } + + size_t mMaxReadSize; + + protected: + virtual ~WebrtcTCPSocketTestInputStream() = default; + + private: + nsCOMPtr<nsIInputStreamCallback> mCallback; + nsCOMPtr<nsIEventTarget> mCallbackTarget; + + bool mAllowCallbacks; +}; + +NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestInputStream, nsIAsyncInputStream, + nsIInputStream) + +nsresult WebrtcTCPSocketTestInputStream::AsyncWait( + nsIInputStreamCallback* aCallback, uint32_t aFlags, + uint32_t aRequestedCount, nsIEventTarget* aEventTarget) { + MOZ_ASSERT(!aEventTarget, "no event target should be set"); + + mCallback = aCallback; + mCallbackTarget = NS_GetCurrentThread(); + + if (mAllowCallbacks && DataLength() > 0) { + DoCallback(); + } + + return NS_OK; +} + +nsresult WebrtcTCPSocketTestInputStream::CloseWithStatus(nsresult aStatus) { + return Close(); +} + +nsresult WebrtcTCPSocketTestInputStream::Close() { return NS_OK; } + +nsresult WebrtcTCPSocketTestInputStream::Available(uint64_t* aAvailable) { + *aAvailable = DataLength(); + return NS_OK; +} + +nsresult WebrtcTCPSocketTestInputStream::StreamStatus() { return NS_OK; } + +nsresult WebrtcTCPSocketTestInputStream::Read(char* aBuffer, uint32_t aCount, + uint32_t* aRead) { + std::lock_guard<std::mutex> guard(mDataMutex); + if (mMustFail) { + return NS_ERROR_FAILURE; + } + *aRead = std::min({(size_t)aCount, mData.Length(), mMaxReadSize}); + memcpy(aBuffer, mData.Elements(), *aRead); + mData.RemoveElementsAt(0, *aRead); + return *aRead > 0 ? NS_OK : NS_BASE_STREAM_WOULD_BLOCK; +} + +nsresult WebrtcTCPSocketTestInputStream::ReadSegments(nsWriteSegmentFun aWriter, + void* aClosure, + uint32_t aCount, + uint32_t* _retval) { + MOZ_ASSERT(false); + return NS_OK; +} + +nsresult WebrtcTCPSocketTestInputStream::IsNonBlocking(bool* aIsNonBlocking) { + *aIsNonBlocking = true; + return NS_OK; +} + +void WebrtcTCPSocketTestInputStream::CallCallback( + const nsCOMPtr<nsIInputStreamCallback>& aCallback) { + aCallback->OnInputStreamReady(this); +} + +void WebrtcTCPSocketTestInputStream::DoCallback() { + if (mCallback) { + mCallbackTarget->Dispatch( + NewRunnableMethod<const nsCOMPtr<nsIInputStreamCallback>&>( + "WebrtcTCPSocketTestInputStream::DoCallback", this, + &WebrtcTCPSocketTestInputStream::CallCallback, + std::move(mCallback))); + + mCallbackTarget = nullptr; + } +} + +class WebrtcTCPSocketTestOutputStream : public nsIAsyncOutputStream, + public WebrtcTCPSocketTestStream { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSIASYNCOUTPUTSTREAM + NS_DECL_NSIOUTPUTSTREAM + + WebrtcTCPSocketTestOutputStream() : mMaxWriteSize(1024 * 1024) {} + + void DoCallback(); + void CallCallback(const nsCOMPtr<nsIOutputStreamCallback>& aCallback); + + std::string DataString(); + + uint32_t mMaxWriteSize; + + protected: + virtual ~WebrtcTCPSocketTestOutputStream() = default; + + private: + nsCOMPtr<nsIOutputStreamCallback> mCallback; + nsCOMPtr<nsIEventTarget> mCallbackTarget; +}; + +NS_IMPL_ISUPPORTS(WebrtcTCPSocketTestOutputStream, nsIAsyncOutputStream, + nsIOutputStream) + +nsresult WebrtcTCPSocketTestOutputStream::AsyncWait( + nsIOutputStreamCallback* aCallback, uint32_t aFlags, + uint32_t aRequestedCount, nsIEventTarget* aEventTarget) { + MOZ_ASSERT(!aEventTarget, "no event target should be set"); + + mCallback = aCallback; + mCallbackTarget = NS_GetCurrentThread(); + + return NS_OK; +} + +nsresult WebrtcTCPSocketTestOutputStream::CloseWithStatus(nsresult aStatus) { + return Close(); +} + +nsresult WebrtcTCPSocketTestOutputStream::Close() { return NS_OK; } + +nsresult WebrtcTCPSocketTestOutputStream::Flush() { return NS_OK; } + +nsresult WebrtcTCPSocketTestOutputStream::StreamStatus() { + return mMustFail ? NS_ERROR_FAILURE : NS_OK; +} + +nsresult WebrtcTCPSocketTestOutputStream::Write(const char* aBuffer, + uint32_t aCount, + uint32_t* aWrote) { + if (mMustFail) { + return NS_ERROR_FAILURE; + } + *aWrote = std::min(aCount, mMaxWriteSize); + AppendElements(aBuffer, *aWrote); + return NS_OK; +} + +nsresult WebrtcTCPSocketTestOutputStream::WriteSegments( + nsReadSegmentFun aReader, void* aClosure, uint32_t aCount, + uint32_t* _retval) { + MOZ_ASSERT(false); + return NS_OK; +} + +nsresult WebrtcTCPSocketTestOutputStream::WriteFrom(nsIInputStream* aFromStream, + uint32_t aCount, + uint32_t* _retval) { + MOZ_ASSERT(false); + return NS_OK; +} + +nsresult WebrtcTCPSocketTestOutputStream::IsNonBlocking(bool* aIsNonBlocking) { + *aIsNonBlocking = true; + return NS_OK; +} + +void WebrtcTCPSocketTestOutputStream::CallCallback( + const nsCOMPtr<nsIOutputStreamCallback>& aCallback) { + aCallback->OnOutputStreamReady(this); +} + +void WebrtcTCPSocketTestOutputStream::DoCallback() { + if (mCallback) { + mCallbackTarget->Dispatch( + NewRunnableMethod<const nsCOMPtr<nsIOutputStreamCallback>&>( + "WebrtcTCPSocketTestOutputStream::CallCallback", this, + &WebrtcTCPSocketTestOutputStream::CallCallback, + std::move(mCallback))); + + mCallbackTarget = nullptr; + } +} + +std::string WebrtcTCPSocketTestOutputStream::DataString() { + std::lock_guard<std::mutex> guard(mDataMutex); + return std::string((char*)mData.Elements(), mData.Length()); +} + +// Fake as in not the real WebrtcTCPSocket but real enough +class FakeWebrtcTCPSocket : public WebrtcTCPSocket { + public: + explicit FakeWebrtcTCPSocket(WebrtcTCPSocketCallback* aCallback) + : WebrtcTCPSocket(aCallback) {} + + protected: + virtual ~FakeWebrtcTCPSocket() = default; + + void InvokeOnClose(nsresult aReason) override; + void InvokeOnConnected() override; + void InvokeOnRead(nsTArray<uint8_t>&& aReadData) override; +}; + +void FakeWebrtcTCPSocket::InvokeOnClose(nsresult aReason) { + mProxyCallbacks->OnClose(aReason); +} + +void FakeWebrtcTCPSocket::InvokeOnConnected() { + mProxyCallbacks->OnConnected("http"_ns); +} + +void FakeWebrtcTCPSocket::InvokeOnRead(nsTArray<uint8_t>&& aReadData) { + mProxyCallbacks->OnRead(std::move(aReadData)); +} + +class WebrtcTCPSocketTest : public MtransportTest { + public: + WebrtcTCPSocketTest() + : MtransportTest(), + mSocketThread(nullptr), + mSocketTransport(nullptr), + mInputStream(nullptr), + mOutputStream(nullptr), + mChannel(nullptr), + mCallback(nullptr), + mOnCloseCalled(false), + mOnConnectedCalled(false) {} + + // WebrtcTCPSocketCallback forwards from mCallback + void OnClose(nsresult aReason); + void OnConnected(const nsACString& aProxyType); + void OnRead(nsTArray<uint8_t>&& aReadData); + + void SetUp() override; + void TearDown() override; + + void DoTransportAvailable(); + + std::string ReadDataAsString(); + std::string GetDataLarge(); + + nsCOMPtr<nsIEventTarget> mSocketThread; + + nsCOMPtr<nsISocketTransport> mSocketTransport; + RefPtr<WebrtcTCPSocketTestInputStream> mInputStream; + RefPtr<WebrtcTCPSocketTestOutputStream> mOutputStream; + RefPtr<FakeWebrtcTCPSocket> mChannel; + RefPtr<WebrtcTCPSocketTestCallback> mCallback; + + bool mOnCloseCalled; + bool mOnConnectedCalled; + + size_t ReadDataLength(); + template <typename T> + void AppendReadData(const T* aBuffer, size_t aLength); + + private: + nsTArray<uint8_t> mReadData; + std::mutex mReadDataMutex; +}; + +class WebrtcTCPSocketTestCallback : public WebrtcTCPSocketCallback { + public: + NS_INLINE_DECL_THREADSAFE_REFCOUNTING(WebrtcTCPSocketTestCallback, override) + + explicit WebrtcTCPSocketTestCallback(WebrtcTCPSocketTest* aTest) + : mTest(aTest) {} + + // WebrtcTCPSocketCallback + void OnClose(nsresult aReason) override; + void OnConnected(const nsACString& aProxyType) override; + void OnRead(nsTArray<uint8_t>&& aReadData) override; + + protected: + virtual ~WebrtcTCPSocketTestCallback() = default; + + private: + WebrtcTCPSocketTest* mTest; +}; + +void WebrtcTCPSocketTest::SetUp() { + nsresult rv; + // WebrtcTCPSocket's threading model is the same as mtransport + // all socket operations are done on the socket thread + // callbacks are invoked on the main thread + mSocketThread = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + + mSocketTransport = new FakeSocketTransportProvider(); + mInputStream = new WebrtcTCPSocketTestInputStream(); + mOutputStream = new WebrtcTCPSocketTestOutputStream(); + mCallback = new WebrtcTCPSocketTestCallback(this); + mChannel = new FakeWebrtcTCPSocket(mCallback.get()); +} + +void WebrtcTCPSocketTest::TearDown() {} + +// WebrtcTCPSocketCallback +void WebrtcTCPSocketTest::OnRead(nsTArray<uint8_t>&& aReadData) { + AppendReadData(aReadData.Elements(), aReadData.Length()); +} + +void WebrtcTCPSocketTest::OnConnected(const nsACString& aProxyType) { + mOnConnectedCalled = true; +} + +void WebrtcTCPSocketTest::OnClose(nsresult aReason) { mOnCloseCalled = true; } + +void WebrtcTCPSocketTest::DoTransportAvailable() { + if (!mSocketThread->IsOnCurrentThread()) { + mSocketThread->Dispatch( + NS_NewRunnableFunction("DoTransportAvailable", [this]() -> void { + nsresult rv; + rv = mChannel->OnTransportAvailable(mSocketTransport, mInputStream, + mOutputStream); + ASSERT_EQ(NS_OK, rv); + })); + } else { + // should always be called on the main thread + MOZ_ASSERT(0); + } +} + +std::string WebrtcTCPSocketTest::ReadDataAsString() { + std::lock_guard<std::mutex> guard(mReadDataMutex); + return std::string((char*)mReadData.Elements(), mReadData.Length()); +} + +std::string WebrtcTCPSocketTest::GetDataLarge() { + std::string data; + for (int i = 0; i < kDataLargeOuterLoopCount * kDataLargeInnerLoopCount; + ++i) { + data += kReadData; + } + return data; +} + +template <typename T> +void WebrtcTCPSocketTest::AppendReadData(const T* aBuffer, size_t aLength) { + std::lock_guard<std::mutex> guard(mReadDataMutex); + mReadData.AppendElements(aBuffer, aLength); +} + +size_t WebrtcTCPSocketTest::ReadDataLength() { + std::lock_guard<std::mutex> guard(mReadDataMutex); + return mReadData.Length(); +} + +void WebrtcTCPSocketTestCallback::OnClose(nsresult aReason) { + mTest->OnClose(aReason); +} + +void WebrtcTCPSocketTestCallback::OnConnected(const nsACString& aProxyType) { + mTest->OnConnected(aProxyType); +} + +void WebrtcTCPSocketTestCallback::OnRead(nsTArray<uint8_t>&& aReadData) { + mTest->OnRead(std::move(aReadData)); +} + +} // namespace mozilla + +typedef mozilla::WebrtcTCPSocketTest WebrtcTCPSocketTest; + +TEST_F(WebrtcTCPSocketTest, SetUp) {} + +TEST_F(WebrtcTCPSocketTest, TransportAvailable) { + DoTransportAvailable(); + ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout); +} + +TEST_F(WebrtcTCPSocketTest, Read) { + DoTransportAvailable(); + ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout); + + mInputStream->AppendElements(kReadData, kReadDataLength); + mInputStream->DoCallback(); + + ASSERT_TRUE_WAIT(ReadDataAsString() == kReadDataString, kDefaultTestTimeout); +} + +TEST_F(WebrtcTCPSocketTest, Write) { + DoTransportAvailable(); + ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout); + + nsTArray<uint8_t> data; + data.AppendElements(kReadData, kReadDataLength); + mChannel->Write(std::move(data)); + + ASSERT_TRUE_WAIT(mChannel->CountUnwrittenBytes() == kReadDataLength, + kDefaultTestTimeout); + + mOutputStream->DoCallback(); + + ASSERT_TRUE_WAIT(mOutputStream->DataString() == kReadDataString, + kDefaultTestTimeout); +} + +TEST_F(WebrtcTCPSocketTest, ReadFail) { + DoTransportAvailable(); + ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout); + + mInputStream->AppendElements(kReadData, kReadDataLength); + mInputStream->Fail(); + mInputStream->DoCallback(); + + ASSERT_TRUE_WAIT(mOnCloseCalled, kDefaultTestTimeout); + ASSERT_EQ(0U, ReadDataLength()); +} + +TEST_F(WebrtcTCPSocketTest, WriteFail) { + DoTransportAvailable(); + ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout); + + nsTArray<uint8_t> array; + array.AppendElements(kReadData, kReadDataLength); + mChannel->Write(std::move(array)); + + ASSERT_TRUE_WAIT(mChannel->CountUnwrittenBytes() == kReadDataLength, + kDefaultTestTimeout); + + mOutputStream->Fail(); + mOutputStream->DoCallback(); + + ASSERT_TRUE_WAIT(mOnCloseCalled, kDefaultTestTimeout); + ASSERT_EQ(0U, mOutputStream->DataLength()); +} + +TEST_F(WebrtcTCPSocketTest, ReadLarge) { + DoTransportAvailable(); + ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout); + + const std::string data = GetDataLarge(); + + mInputStream->AppendElements(data.c_str(), data.length()); + // make sure reading loops more than once + mInputStream->mMaxReadSize = 3072; + mInputStream->AllowCallbacks(); + mInputStream->DoCallback(); + + ASSERT_TRUE_WAIT(ReadDataAsString() == data, kDefaultTestTimeout); +} + +TEST_F(WebrtcTCPSocketTest, WriteLarge) { + DoTransportAvailable(); + ASSERT_TRUE_WAIT(mOnConnectedCalled, kDefaultTestTimeout); + + const std::string data = GetDataLarge(); + + for (int i = 0; i < kDataLargeOuterLoopCount; ++i) { + nsTArray<uint8_t> array; + int chunkSize = kReadDataString.length() * kDataLargeInnerLoopCount; + int offset = i * chunkSize; + array.AppendElements(data.c_str() + offset, chunkSize); + mChannel->Write(std::move(array)); + } + + ASSERT_TRUE_WAIT(mChannel->CountUnwrittenBytes() == data.length(), + kDefaultTestTimeout); + + // make sure writing loops more than once per write request + mOutputStream->mMaxWriteSize = 1024; + mOutputStream->DoCallback(); + + ASSERT_TRUE_WAIT(mOutputStream->DataString() == data, kDefaultTestTimeout); +} |