/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ /* vim: set ts=2 et sw=2 tw=80: */ /* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ // Original author: ekr@rtfm.com #include #include #include "sigslot.h" #include "nsITimer.h" #include "transportflow.h" #include "transportlayer.h" #include "transportlayerloopback.h" #include "runnable_utils.h" #include "usrsctp.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" #include "gtest_utils.h" using namespace mozilla; static bool sctp_logging = false; static int port_number = 5000; namespace { class TransportTestPeer; class SendPeriodic : public nsITimerCallback, public nsINamed { public: SendPeriodic(TransportTestPeer* peer, int to_send) : peer_(peer), to_send_(to_send) {} NS_DECL_THREADSAFE_ISUPPORTS NS_DECL_NSITIMERCALLBACK NS_DECL_NSINAMED protected: virtual ~SendPeriodic() = default; TransportTestPeer* peer_; int to_send_; }; NS_IMPL_ISUPPORTS(SendPeriodic, nsITimerCallback, nsINamed) class TransportTestPeer : public sigslot::has_slots<> { public: TransportTestPeer(std::string name, int local_port, int remote_port, MtransportTestUtils* utils) : name_(name), connected_(false), sent_(0), received_(0), flow_(new TransportFlow()), loopback_(new TransportLayerLoopback()), sctp_(usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, receive_cb, nullptr, 0, nullptr)), timer_(NS_NewTimer()), periodic_(nullptr), test_utils_(utils) { std::cerr << "Creating TransportTestPeer; flow=" << static_cast(flow_.get()) << " local=" << local_port << " remote=" << remote_port << std::endl; usrsctp_register_address(static_cast(this)); int r = usrsctp_set_non_blocking(sctp_, 1); EXPECT_GE(r, 0); struct linger l; l.l_onoff = 1; l.l_linger = 0; r = usrsctp_setsockopt(sctp_, SOL_SOCKET, SO_LINGER, &l, (socklen_t)sizeof(l)); EXPECT_GE(r, 0); struct sctp_event subscription; memset(&subscription, 0, sizeof(subscription)); subscription.se_assoc_id = SCTP_ALL_ASSOC; subscription.se_on = 1; subscription.se_type = SCTP_ASSOC_CHANGE; r = usrsctp_setsockopt(sctp_, IPPROTO_SCTP, SCTP_EVENT, &subscription, sizeof(subscription)); EXPECT_GE(r, 0); memset(&local_addr_, 0, sizeof(local_addr_)); local_addr_.sconn_family = AF_CONN; #if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && \ !defined(__Userspace_os_Android) local_addr_.sconn_len = sizeof(struct sockaddr_conn); #endif local_addr_.sconn_port = htons(local_port); local_addr_.sconn_addr = static_cast(this); memset(&remote_addr_, 0, sizeof(remote_addr_)); remote_addr_.sconn_family = AF_CONN; #if !defined(__Userspace_os_Linux) && !defined(__Userspace_os_Windows) && \ !defined(__Userspace_os_Android) remote_addr_.sconn_len = sizeof(struct sockaddr_conn); #endif remote_addr_.sconn_port = htons(remote_port); remote_addr_.sconn_addr = static_cast(this); nsresult res; res = loopback_->Init(); EXPECT_EQ((nsresult)NS_OK, res); } ~TransportTestPeer() { std::cerr << "Destroying sctp connection flow=" << static_cast(flow_.get()) << std::endl; usrsctp_close(sctp_); usrsctp_deregister_address(static_cast(this)); test_utils_->SyncDispatchToSTS( WrapRunnable(this, &TransportTestPeer::Disconnect_s)); std::cerr << "~TransportTestPeer() completed" << std::endl; } void ConnectSocket(TransportTestPeer* peer) { test_utils_->SyncDispatchToSTS( WrapRunnable(this, &TransportTestPeer::ConnectSocket_s, peer)); } void ConnectSocket_s(TransportTestPeer* peer) { loopback_->Connect(peer->loopback_); ASSERT_EQ((nsresult)NS_OK, loopback_->Init()); flow_->PushLayer(loopback_); loopback_->SignalPacketReceived.connect(this, &TransportTestPeer::PacketReceived); // SCTP here! ASSERT_TRUE(sctp_); std::cerr << "Calling usrsctp_bind()" << std::endl; int r = usrsctp_bind(sctp_, reinterpret_cast(&local_addr_), sizeof(local_addr_)); ASSERT_GE(0, r); std::cerr << "Calling usrsctp_connect()" << std::endl; r = usrsctp_connect(sctp_, reinterpret_cast(&remote_addr_), sizeof(remote_addr_)); ASSERT_GE(0, r); } void Disconnect_s() { disconnect_all(); if (flow_) { flow_ = nullptr; } } void Disconnect() { loopback_->Disconnect(); } void StartTransfer(size_t to_send) { periodic_ = new SendPeriodic(this, to_send); timer_->SetTarget(test_utils_->sts_target()); timer_->InitWithCallback(periodic_, 10, nsITimer::TYPE_REPEATING_SLACK); } void SendOne() { unsigned char buf[100]; memset(buf, sent_ & 0xff, sizeof(buf)); struct sctp_sndinfo info; info.snd_sid = 1; info.snd_flags = 0; info.snd_ppid = 50; // What the heck is this? info.snd_context = 0; info.snd_assoc_id = 0; int r = usrsctp_sendv(sctp_, buf, sizeof(buf), nullptr, 0, static_cast(&info), sizeof(info), SCTP_SENDV_SNDINFO, 0); ASSERT_TRUE(r >= 0); ASSERT_EQ(sizeof(buf), (size_t)r); ++sent_; } int sent() const { return sent_; } int received() const { return received_; } bool connected() const { return connected_; } static TransportResult SendPacket_s(UniquePtr packet, const RefPtr& flow, TransportLayer* layer) { return layer->SendPacket(*packet); } TransportResult SendPacket(const unsigned char* data, size_t len) { UniquePtr packet(new MediaPacket); packet->Copy(data, len); // Uses DISPATCH_NORMAL to avoid possible deadlocks when we're called // from MainThread especially during shutdown (same as DataChannels). // RUN_ON_THREAD short-circuits if already on the STS thread, which is // normal for most transfers outside of connect() and close(). Passes // a refptr to flow_ to avoid any async deletion issues (since we can't // make 'this' into a refptr as it isn't refcounted) RUN_ON_THREAD(test_utils_->sts_target(), WrapRunnableNM(&TransportTestPeer::SendPacket_s, std::move(packet), flow_, loopback_), NS_DISPATCH_NORMAL); return 0; } void PacketReceived(TransportLayer* layer, MediaPacket& packet) { std::cerr << "Received " << packet.len() << " bytes" << std::endl; // Pass the data to SCTP usrsctp_conninput(static_cast(this), packet.data(), packet.len(), 0); } // Process SCTP notification void Notification(union sctp_notification* msg, size_t len) { ASSERT_EQ(msg->sn_header.sn_length, len); if (msg->sn_header.sn_type == SCTP_ASSOC_CHANGE) { struct sctp_assoc_change* change = &msg->sn_assoc_change; if (change->sac_state == SCTP_COMM_UP) { std::cerr << "Connection up" << std::endl; SetConnected(true); } else { std::cerr << "Connection down" << std::endl; SetConnected(false); } } } void SetConnected(bool state) { connected_ = state; } static int conn_output(void* addr, void* buffer, size_t length, uint8_t tos, uint8_t set_df) { TransportTestPeer* peer = static_cast(addr); peer->SendPacket(static_cast(buffer), length); return 0; } static int receive_cb(struct socket* sock, union sctp_sockstore addr, void* data, size_t datalen, struct sctp_rcvinfo rcv, int flags, void* ulp_info) { TransportTestPeer* me = static_cast(addr.sconn.sconn_addr); MOZ_ASSERT(me); if (flags & MSG_NOTIFICATION) { union sctp_notification* notif = static_cast(data); me->Notification(notif, datalen); return 0; } me->received_ += datalen; std::cerr << "receive_cb: sock " << sock << " data " << data << "(" << datalen << ") total received bytes = " << me->received_ << std::endl; return 0; } private: std::string name_; bool connected_; size_t sent_; size_t received_; // Owns the TransportLayerLoopback, but basically does nothing else. RefPtr flow_; TransportLayerLoopback* loopback_; struct sockaddr_conn local_addr_; struct sockaddr_conn remote_addr_; struct socket* sctp_; nsCOMPtr timer_; RefPtr periodic_; MtransportTestUtils* test_utils_; }; // Implemented here because it calls a method of TransportTestPeer NS_IMETHODIMP SendPeriodic::Notify(nsITimer* timer) { peer_->SendOne(); --to_send_; if (!to_send_) { timer->Cancel(); } return NS_OK; } NS_IMETHODIMP SendPeriodic::GetName(nsACString& aName) { aName.AssignLiteral("SendPeriodic"); return NS_OK; } class SctpTransportTest : public MtransportTest { public: SctpTransportTest() = default; ~SctpTransportTest() = default; static void debug_printf(const char* format, ...) { va_list ap; va_start(ap, format); vprintf(format, ap); va_end(ap); } static void SetUpTestCase() { if (sctp_logging) { usrsctp_init(0, &TransportTestPeer::conn_output, debug_printf); usrsctp_sysctl_set_sctp_debug_on(0xffffffff); } else { usrsctp_init(0, &TransportTestPeer::conn_output, nullptr); } } void TearDown() override { if (p1_) p1_->Disconnect(); if (p2_) p2_->Disconnect(); delete p1_; delete p2_; MtransportTest::TearDown(); } void ConnectSocket(int p1port = 0, int p2port = 0) { if (!p1port) p1port = port_number++; if (!p2port) p2port = port_number++; p1_ = new TransportTestPeer("P1", p1port, p2port, test_utils_); p2_ = new TransportTestPeer("P2", p2port, p1port, test_utils_); p1_->ConnectSocket(p2_); p2_->ConnectSocket(p1_); ASSERT_TRUE_WAIT(p1_->connected(), 2000); ASSERT_TRUE_WAIT(p2_->connected(), 2000); } void TestTransfer(int expected = 1) { std::cerr << "Starting trasnsfer test" << std::endl; p1_->StartTransfer(expected); ASSERT_TRUE_WAIT(p1_->sent() == expected, 10000); ASSERT_TRUE_WAIT(p2_->received() == (expected * 100), 10000); std::cerr << "P2 received " << p2_->received() << std::endl; } protected: TransportTestPeer* p1_ = nullptr; TransportTestPeer* p2_ = nullptr; }; TEST_F(SctpTransportTest, TestConnect) { ConnectSocket(); } TEST_F(SctpTransportTest, TestConnectSymmetricalPorts) { ConnectSocket(5002, 5002); } TEST_F(SctpTransportTest, TestTransfer) { ConnectSocket(); TestTransfer(50); } } // end namespace