summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc')
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc3058
1 files changed, 3058 insertions, 0 deletions
diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc
new file mode 100644
index 0000000000..13202846ac
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -0,0 +1,3058 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/dcsctp_socket.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <deque>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/flags/flag.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/common/handover_testing.h"
+#include "net/dcsctp/common/math.h"
+#include "net/dcsctp/packet/chunk/abort_chunk.h"
+#include "net/dcsctp/packet/chunk/chunk.h"
+#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
+#include "net/dcsctp/packet/chunk/data_chunk.h"
+#include "net/dcsctp/packet/chunk/data_common.h"
+#include "net/dcsctp/packet/chunk/error_chunk.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/chunk/idata_chunk.h"
+#include "net/dcsctp/packet/chunk/init_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/init_chunk.h"
+#include "net/dcsctp/packet/chunk/reconfig_chunk.h"
+#include "net/dcsctp/packet/chunk/sack_chunk.h"
+#include "net/dcsctp/packet/chunk/shutdown_chunk.h"
+#include "net/dcsctp/packet/error_cause/error_cause.h"
+#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h"
+#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h"
+#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
+#include "net/dcsctp/packet/parameter/parameter.h"
+#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/packet/tlv_trait.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/public/text_pcap_packet_observer.h"
+#include "net/dcsctp/public/types.h"
+#include "net/dcsctp/rx/reassembly_queue.h"
+#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h"
+#include "net/dcsctp/testing/testing_macros.h"
+#include "rtc_base/gunit.h"
+#include "test/gmock.h"
+
+ABSL_FLAG(bool, dcsctp_capture_packets, false, "Print packet capture.");
+
+namespace dcsctp {
+namespace {
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::Not;
+using ::testing::Property;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+
+constexpr SendOptions kSendOptions;
+constexpr size_t kLargeMessageSize = DcSctpOptions::kMaxSafeMTUSize * 20;
+constexpr size_t kSmallMessageSize = 10;
+constexpr int kMaxBurstPackets = 4;
+constexpr DcSctpOptions kDefaultOptions;
+
+MATCHER_P(HasChunks, chunks, "") {
+ absl::optional<SctpPacket> packet = SctpPacket::Parse(arg, kDefaultOptions);
+ if (!packet.has_value()) {
+ *result_listener << "data didn't parse as an SctpPacket";
+ return false;
+ }
+
+ return ExplainMatchResult(chunks, packet->descriptors(), result_listener);
+}
+
+MATCHER_P(IsChunkType, chunk_type, "") {
+ return ExplainMatchResult(chunk_type, arg.type, result_listener);
+}
+
+MATCHER_P(IsDataChunk, properties, "") {
+ if (arg.type != DataChunk::kType) {
+ *result_listener << "the chunk is not a data chunk";
+ return false;
+ }
+
+ absl::optional<DataChunk> chunk = DataChunk::Parse(arg.data);
+ if (!chunk.has_value()) {
+ *result_listener << "The chunk didn't parse as a data chunk";
+ return false;
+ }
+
+ return ExplainMatchResult(properties, *chunk, result_listener);
+}
+
+MATCHER_P(IsSack, properties, "") {
+ if (arg.type != SackChunk::kType) {
+ *result_listener << "the chunk is not a sack chunk";
+ return false;
+ }
+
+ absl::optional<SackChunk> chunk = SackChunk::Parse(arg.data);
+ if (!chunk.has_value()) {
+ *result_listener << "The chunk didn't parse as a sack chunk";
+ return false;
+ }
+
+ return ExplainMatchResult(properties, *chunk, result_listener);
+}
+
+MATCHER_P(IsReConfig, properties, "") {
+ if (arg.type != ReConfigChunk::kType) {
+ *result_listener << "the chunk is not a re-config chunk";
+ return false;
+ }
+
+ absl::optional<ReConfigChunk> chunk = ReConfigChunk::Parse(arg.data);
+ if (!chunk.has_value()) {
+ *result_listener << "The chunk didn't parse as a re-config chunk";
+ return false;
+ }
+
+ return ExplainMatchResult(properties, *chunk, result_listener);
+}
+
+MATCHER_P(IsHeartbeatAck, properties, "") {
+ if (arg.type != HeartbeatAckChunk::kType) {
+ *result_listener << "the chunk is not a HeartbeatAckChunk";
+ return false;
+ }
+
+ absl::optional<HeartbeatAckChunk> chunk = HeartbeatAckChunk::Parse(arg.data);
+ if (!chunk.has_value()) {
+ *result_listener << "The chunk didn't parse as a HeartbeatAckChunk";
+ return false;
+ }
+
+ return ExplainMatchResult(properties, *chunk, result_listener);
+}
+
+MATCHER_P(IsHeartbeatRequest, properties, "") {
+ if (arg.type != HeartbeatRequestChunk::kType) {
+ *result_listener << "the chunk is not a HeartbeatRequestChunk";
+ return false;
+ }
+
+ absl::optional<HeartbeatRequestChunk> chunk =
+ HeartbeatRequestChunk::Parse(arg.data);
+ if (!chunk.has_value()) {
+ *result_listener << "The chunk didn't parse as a HeartbeatRequestChunk";
+ return false;
+ }
+
+ return ExplainMatchResult(properties, *chunk, result_listener);
+}
+
+MATCHER_P(HasParameters, parameters, "") {
+ return ExplainMatchResult(parameters, arg.parameters().descriptors(),
+ result_listener);
+}
+
+MATCHER_P(IsOutgoingResetRequest, properties, "") {
+ if (arg.type != OutgoingSSNResetRequestParameter::kType) {
+ *result_listener
+ << "the parameter is not an OutgoingSSNResetRequestParameter";
+ return false;
+ }
+
+ absl::optional<OutgoingSSNResetRequestParameter> parameter =
+ OutgoingSSNResetRequestParameter::Parse(arg.data);
+ if (!parameter.has_value()) {
+ *result_listener
+ << "The parameter didn't parse as an OutgoingSSNResetRequestParameter";
+ return false;
+ }
+
+ return ExplainMatchResult(properties, *parameter, result_listener);
+}
+
+MATCHER_P(IsReconfigurationResponse, properties, "") {
+ if (arg.type != ReconfigurationResponseParameter::kType) {
+ *result_listener
+ << "the parameter is not an ReconfigurationResponseParameter";
+ return false;
+ }
+
+ absl::optional<ReconfigurationResponseParameter> parameter =
+ ReconfigurationResponseParameter::Parse(arg.data);
+ if (!parameter.has_value()) {
+ *result_listener
+ << "The parameter didn't parse as an ReconfigurationResponseParameter";
+ return false;
+ }
+
+ return ExplainMatchResult(properties, *parameter, result_listener);
+}
+
+TSN AddTo(TSN tsn, int delta) {
+ return TSN(*tsn + delta);
+}
+
+DcSctpOptions FixupOptions(DcSctpOptions options = {}) {
+ DcSctpOptions fixup = options;
+ // To make the interval more predictable in tests.
+ fixup.heartbeat_interval_include_rtt = false;
+ fixup.max_burst = kMaxBurstPackets;
+ return fixup;
+}
+
+std::unique_ptr<PacketObserver> GetPacketObserver(absl::string_view name) {
+ if (absl::GetFlag(FLAGS_dcsctp_capture_packets)) {
+ return std::make_unique<TextPcapPacketObserver>(name);
+ }
+ return nullptr;
+}
+
+struct SocketUnderTest {
+ explicit SocketUnderTest(absl::string_view name,
+ const DcSctpOptions& opts = kDefaultOptions)
+ : options(FixupOptions(opts)),
+ cb(name),
+ socket(name, cb, GetPacketObserver(name), options) {}
+
+ const DcSctpOptions options;
+ testing::NiceMock<MockDcSctpSocketCallbacks> cb;
+ DcSctpSocket socket;
+};
+
+void ExchangeMessages(SocketUnderTest& a, SocketUnderTest& z) {
+ bool delivered_packet = false;
+ do {
+ delivered_packet = false;
+ std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
+ if (!packet_from_a.empty()) {
+ delivered_packet = true;
+ z.socket.ReceivePacket(std::move(packet_from_a));
+ }
+ std::vector<uint8_t> packet_from_z = z.cb.ConsumeSentPacket();
+ if (!packet_from_z.empty()) {
+ delivered_packet = true;
+ a.socket.ReceivePacket(std::move(packet_from_z));
+ }
+ } while (delivered_packet);
+}
+
+void RunTimers(SocketUnderTest& s) {
+ for (;;) {
+ absl::optional<TimeoutID> timeout_id = s.cb.GetNextExpiredTimeout();
+ if (!timeout_id.has_value()) {
+ break;
+ }
+ s.socket.HandleTimeout(*timeout_id);
+ }
+}
+
+void AdvanceTime(SocketUnderTest& a, SocketUnderTest& z, DurationMs duration) {
+ a.cb.AdvanceTime(duration);
+ z.cb.AdvanceTime(duration);
+
+ RunTimers(a);
+ RunTimers(z);
+}
+
+// Exchanges messages between `a` and `z`, advancing time until there are no
+// more pending timers, or until `max_timeout` is reached.
+void ExchangeMessagesAndAdvanceTime(
+ SocketUnderTest& a,
+ SocketUnderTest& z,
+ DurationMs max_timeout = DurationMs(10000)) {
+ TimeMs time_started = a.cb.TimeMillis();
+ while (a.cb.TimeMillis() - time_started < max_timeout) {
+ ExchangeMessages(a, z);
+
+ DurationMs time_to_next_timeout =
+ std::min(a.cb.GetTimeToNextTimeout(), z.cb.GetTimeToNextTimeout());
+ if (time_to_next_timeout == DurationMs::InfiniteDuration()) {
+ // No more pending timer.
+ return;
+ }
+ AdvanceTime(a, z, time_to_next_timeout);
+ }
+}
+
+// Calls Connect() on `sock_a_` and make the connection established.
+void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) {
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+
+ a.socket.Connect();
+ // Z reads INIT, INIT_ACK, COOKIE_ECHO, COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+std::unique_ptr<SocketUnderTest> HandoverSocket(
+ std::unique_ptr<SocketUnderTest> sut) {
+ EXPECT_EQ(sut->socket.GetHandoverReadiness(), HandoverReadinessStatus());
+
+ bool is_closed = sut->socket.state() == SocketState::kClosed;
+ if (!is_closed) {
+ EXPECT_CALL(sut->cb, OnClosed).Times(1);
+ }
+ absl::optional<DcSctpSocketHandoverState> handover_state =
+ sut->socket.GetHandoverStateAndClose();
+ EXPECT_TRUE(handover_state.has_value());
+ g_handover_state_transformer_for_test(&*handover_state);
+
+ auto handover_socket = std::make_unique<SocketUnderTest>("H", sut->options);
+ if (!is_closed) {
+ EXPECT_CALL(handover_socket->cb, OnConnected).Times(1);
+ }
+ handover_socket->socket.RestoreFromState(*handover_state);
+ return handover_socket;
+}
+
+std::vector<uint32_t> GetReceivedMessagePpids(SocketUnderTest& z) {
+ std::vector<uint32_t> ppids;
+ for (;;) {
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ if (!msg.has_value()) {
+ break;
+ }
+ ppids.push_back(*msg->ppid());
+ }
+ return ppids;
+}
+
+// Test parameter that controls whether to perform handovers during the test. A
+// test can have multiple points where it conditionally hands over socket Z.
+// Either socket Z will be handed over at all those points or handed over never.
+enum class HandoverMode {
+ kNoHandover,
+ kPerformHandovers,
+};
+
+class DcSctpSocketParametrizedTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<HandoverMode> {
+ protected:
+ // Trigger handover for `sut` depending on the current test param.
+ std::unique_ptr<SocketUnderTest> MaybeHandoverSocket(
+ std::unique_ptr<SocketUnderTest> sut) {
+ if (GetParam() == HandoverMode::kPerformHandovers) {
+ return HandoverSocket(std::move(sut));
+ }
+ return sut;
+ }
+
+ // Trigger handover for socket Z depending on the current test param.
+ // Then checks message passing to verify the handed over socket is functional.
+ void MaybeHandoverSocketAndSendMessage(SocketUnderTest& a,
+ std::unique_ptr<SocketUnderTest> z) {
+ if (GetParam() == HandoverMode::kPerformHandovers) {
+ z = HandoverSocket(std::move(z));
+ }
+
+ ExchangeMessages(a, *z);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ ExchangeMessages(a, *z);
+
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ }
+};
+
+INSTANTIATE_TEST_SUITE_P(Handovers,
+ DcSctpSocketParametrizedTest,
+ testing::Values(HandoverMode::kNoHandover,
+ HandoverMode::kPerformHandovers),
+ [](const auto& test_info) {
+ return test_info.param ==
+ HandoverMode::kPerformHandovers
+ ? "WithHandovers"
+ : "NoHandover";
+ });
+
+TEST(DcSctpSocketTest, EstablishConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
+ EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
+
+ a.socket.Connect();
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, EstablishConnectionWithSetupCollision) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
+ EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
+ a.socket.Connect();
+ z.socket.Connect();
+
+ ExchangeMessages(a, z);
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, ShuttingDownWhileEstablishingConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(0);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ a.socket.Connect();
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Drop COOKIE_ACK, just to more easily verify shutdown protocol.
+ z.cb.ConsumeSentPacket();
+
+ // As Socket A has received INIT_ACK, it has a TCB and is connected, while
+ // Socket Z needs to receive COOKIE_ECHO to get there. Socket A still has
+ // timers running at this point.
+ EXPECT_EQ(a.socket.state(), SocketState::kConnecting);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+
+ // Socket A is now shut down, which should make it stop those timers.
+ a.socket.Shutdown();
+
+ EXPECT_CALL(a.cb, OnClosed).Times(1);
+ EXPECT_CALL(z.cb, OnClosed).Times(1);
+
+ // Z reads SHUTDOWN, produces SHUTDOWN_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads SHUTDOWN_COMPLETE.
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
+ EXPECT_TRUE(z.cb.ConsumeSentPacket().empty());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+ EXPECT_EQ(z.socket.state(), SocketState::kClosed);
+}
+
+TEST(DcSctpSocketTest, EstablishSimultaneousConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
+ EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
+ a.socket.Connect();
+
+ // INIT isn't received by Z, as it wasn't ready yet.
+ a.cb.ConsumeSentPacket();
+
+ z.socket.Connect();
+
+ // A reads INIT, produces INIT_ACK
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // Z reads INIT_ACK, sends COOKIE_ECHO
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // A reads COOKIE_ECHO - establishes connection.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+
+ // Proceed with the remaining packets.
+ ExchangeMessages(a, z);
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, EstablishConnectionLostCookieAck) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
+ EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
+
+ a.socket.Connect();
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // COOKIE_ACK is lost.
+ z.cb.ConsumeSentPacket();
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnecting);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+
+ // This will make A re-send the COOKIE_ECHO
+ AdvanceTime(a, z, DurationMs(a.options.t1_cookie_timeout));
+
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, ResendInitAndEstablishConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+ // INIT is never received by Z.
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(InitChunk::kType))));
+
+ AdvanceTime(a, z, a.options.t1_init_timeout);
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, ResendingInitTooManyTimesAborts) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+
+ // INIT is never received by Z.
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(InitChunk::kType))));
+
+ for (int i = 0; i < *a.options.max_init_retransmits; ++i) {
+ AdvanceTime(a, z, a.options.t1_init_timeout * (1 << i));
+
+ // INIT is resent
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(InitChunk::kType))));
+ }
+
+ // Another timeout, after the max init retransmits.
+ EXPECT_CALL(a.cb, OnAborted).Times(1);
+ AdvanceTime(
+ a, z, a.options.t1_init_timeout * (1 << *a.options.max_init_retransmits));
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+}
+
+TEST(DcSctpSocketTest, ResendCookieEchoAndEstablishConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // COOKIE_ECHO is never received by Z.
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType))));
+
+ AdvanceTime(a, z, a.options.t1_init_timeout);
+
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // COOKIE_ECHO is never received by Z.
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType))));
+
+ for (int i = 0; i < *a.options.max_init_retransmits; ++i) {
+ AdvanceTime(a, z, a.options.t1_cookie_timeout * (1 << i));
+
+ // COOKIE_ECHO is resent
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType))));
+ }
+
+ // Another timeout, after the max init retransmits.
+ EXPECT_CALL(a.cb, OnAborted).Times(1);
+ AdvanceTime(
+ a, z,
+ a.options.t1_cookie_timeout * (1 << *a.options.max_init_retransmits));
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+}
+
+TEST(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Connect();
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // COOKIE_ECHO is never received by Z.
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType),
+ IsDataChunk(_))));
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ // There are DATA chunks in the sent packet (that was lost), which means that
+ // the T3-RTX timer is running, but as the socket is in kCookieEcho state, it
+ // will be T1-COOKIE that drives retransmissions, so when the T3-RTX expires,
+ // nothing should be retransmitted.
+ ASSERT_TRUE(a.options.rto_initial < a.options.t1_cookie_timeout);
+ AdvanceTime(a, z, a.options.rto_initial);
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ // When T1-COOKIE expires, both the COOKIE-ECHO and DATA should be present.
+ AdvanceTime(a, z, a.options.t1_cookie_timeout - a.options.rto_initial);
+
+ // And this COOKIE-ECHO and DATA is also lost - never received by Z.
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType),
+ IsDataChunk(_))));
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ // COOKIE_ECHO has exponential backoff.
+ AdvanceTime(a, z, a.options.t1_cookie_timeout * 2);
+
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+
+ ExchangeMessages(a, z);
+ EXPECT_THAT(z.cb.ConsumeReceivedMessage()->payload(),
+ SizeIs(kLargeMessageSize));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ RTC_LOG(LS_INFO) << "Shutting down";
+
+ EXPECT_CALL(z->cb, OnClosed).Times(1);
+ a.socket.Shutdown();
+ // Z reads SHUTDOWN, produces SHUTDOWN_ACK
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+ // Z reads SHUTDOWN_COMPLETE.
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+ EXPECT_EQ(z->socket.state(), SocketState::kClosed);
+
+ z = MaybeHandoverSocket(std::move(z));
+ EXPECT_EQ(z->socket.state(), SocketState::kClosed);
+}
+
+TEST(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ a.socket.Shutdown();
+ // Drop first SHUTDOWN packet.
+ a.cb.ConsumeSentPacket();
+
+ EXPECT_EQ(a.socket.state(), SocketState::kShuttingDown);
+
+ for (int i = 0; i < *a.options.max_retransmissions; ++i) {
+ AdvanceTime(a, z, DurationMs(a.options.rto_initial * (1 << i)));
+
+ // Dropping every shutdown chunk.
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(ShutdownChunk::kType))));
+ EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
+ }
+ // The last expiry, makes it abort the connection.
+ EXPECT_CALL(a.cb, OnAborted).Times(1);
+ AdvanceTime(a, z,
+ a.options.rto_initial * (1 << *a.options.max_retransmissions));
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsChunkType(AbortChunk::kType))));
+ EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
+}
+
+TEST(DcSctpSocketTest, EstablishConnectionWhileSendingData) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+}
+
+TEST(DcSctpSocketTest, SendMessageAfterEstablished) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ a.cb.ConsumeSentPacket();
+
+ RTC_LOG(LS_INFO) << "Advancing time";
+ AdvanceTime(a, *z, a.options.rto_initial);
+
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ std::vector<uint8_t> payload(kLargeMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+ // First DATA
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Second DATA (lost)
+ a.cb.ConsumeSentPacket();
+
+ // Retransmit and handle the rest
+ ExchangeMessages(a, *z);
+
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // Inject a HEARTBEAT chunk
+ SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions());
+ uint8_t info[] = {1, 2, 3, 4};
+ Parameters::Builder params_builder;
+ params_builder.Add(HeartbeatInfoParameter(info));
+ b.Add(HeartbeatRequestChunk(params_builder.Build()));
+ a.socket.ReceivePacket(b.Build());
+
+ // HEARTBEAT_ACK is sent as a reply. Capture it.
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsHeartbeatAck(
+ Property(&HeartbeatAckChunk::info,
+ Optional(Property(&HeartbeatInfoParameter::info,
+ ElementsAre(1, 2, 3, 4))))))));
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ AdvanceTime(a, *z, a.options.heartbeat_interval);
+
+ std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
+ // The info is a single 64-bit number.
+ EXPECT_THAT(
+ packet,
+ HasChunks(ElementsAre(IsHeartbeatRequest(Property(
+ &HeartbeatRequestChunk::info,
+ Optional(Property(&HeartbeatInfoParameter::info, SizeIs(8))))))));
+
+ // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back.
+ z->socket.ReceivePacket(packet);
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ CloseConnectionAfterTooManyLostHeartbeats) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(z->cb, OnClosed).Times(1);
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty());
+ // Force-close socket Z so that it doesn't interfere from now on.
+ z->socket.Close();
+
+ DurationMs time_to_next_hearbeat = a.options.heartbeat_interval;
+
+ for (int i = 0; i < *a.options.max_retransmissions; ++i) {
+ RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending...";
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ // Dropping every heartbeat.
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ SctpPacket hb_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket(), z->options));
+ EXPECT_EQ(hb_packet.descriptors()[0].type, HeartbeatRequestChunk::kType);
+
+ RTC_LOG(LS_INFO) << "Letting the heartbeat expire.";
+ AdvanceTime(a, *z, DurationMs(1000));
+
+ time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000);
+ }
+
+ RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending...";
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ // Last heartbeat
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), Not(IsEmpty()));
+
+ EXPECT_CALL(a.cb, OnAborted).Times(1);
+ // Should suffice as exceeding RTO
+ AdvanceTime(a, *z, DurationMs(1000));
+
+ z = MaybeHandoverSocket(std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty());
+ EXPECT_CALL(z->cb, OnClosed).Times(1);
+ // Force-close socket Z so that it doesn't interfere from now on.
+ z->socket.Close();
+
+ DurationMs time_to_next_hearbeat = a.options.heartbeat_interval;
+
+ for (int i = 0; i < *a.options.max_retransmissions; ++i) {
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ // Dropping every heartbeat.
+ a.cb.ConsumeSentPacket();
+
+ RTC_LOG(LS_INFO) << "Letting the heartbeat expire.";
+ AdvanceTime(a, *z, DurationMs(1000));
+
+ time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000);
+ }
+
+ RTC_LOG(LS_INFO) << "Getting the last heartbeat - and acking it";
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ std::vector<uint8_t> hb_packet_raw = a.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet,
+ SctpPacket::Parse(hb_packet_raw, z->options));
+ ASSERT_THAT(hb_packet.descriptors(), SizeIs(1));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ HeartbeatRequestChunk hb,
+ HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data));
+
+ SctpPacket::Builder b(a.socket.verification_tag(), a.options);
+ b.Add(HeartbeatAckChunk(std::move(hb).extract_parameters()));
+ a.socket.ReceivePacket(b.Build());
+
+ // Should suffice as exceeding RTO - which will not fire.
+ EXPECT_CALL(a.cb, OnAborted).Times(0);
+ AdvanceTime(a, *z, DurationMs(1000));
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ // Verify that we get new heartbeats again.
+ RTC_LOG(LS_INFO) << "Expecting a new heartbeat";
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsHeartbeatRequest(_))));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ResetStream) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {});
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Reset the outgoing stream. This will directly send a RE-CONFIG.
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+
+ // Receiving the packet will trigger a callback, indicating that A has
+ // reset its stream. It will also send a RE-CONFIG with a response.
+ EXPECT_CALL(z->cb, OnIncomingStreamsReset).Times(1);
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // Receiving a response will trigger a callback. Streams are now reset.
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed).Times(1);
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ auto packet1 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(
+ packet1,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0))))));
+ z->socket.ReceivePacket(packet1);
+
+ auto packet2 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(
+ packet2,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(1))))));
+ z->socket.ReceivePacket(packet2);
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->stream_id(), StreamID(1));
+
+ absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->stream_id(), StreamID(1));
+
+ // Reset the outgoing stream. This will directly send a RE-CONFIG.
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+ // RE-CONFIG, req
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // RE-CONFIG, resp
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ auto packet3 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(
+ packet3,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0))))));
+ z->socket.ReceivePacket(packet3);
+
+ auto packet4 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(
+ packet4,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(1))))));
+ z->socket.ReceivePacket(packet4);
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ ResetStreamWillOnlyResetTheRequestedStreams) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+
+ // Send two ordered messages on SID 1
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ auto packet1 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet1, HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::stream_id, StreamID(1)),
+ Property(&DataChunk::ssn, SSN(0)))))));
+ z->socket.ReceivePacket(packet1);
+
+ auto packet2 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet2, HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::stream_id, StreamID(1)),
+ Property(&DataChunk::ssn, SSN(1)))))));
+ z->socket.ReceivePacket(packet2);
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Do the same, for SID 3
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
+ auto packet3 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet3, HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::stream_id, StreamID(3)),
+ Property(&DataChunk::ssn, SSN(0)))))));
+ z->socket.ReceivePacket(packet3);
+ auto packet4 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet4, HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::stream_id, StreamID(3)),
+ Property(&DataChunk::ssn, SSN(1)))))));
+ z->socket.ReceivePacket(packet4);
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Receive all messages.
+ absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->stream_id(), StreamID(1));
+
+ absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->stream_id(), StreamID(1));
+
+ absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg3.has_value());
+ EXPECT_EQ(msg3->stream_id(), StreamID(3));
+
+ absl::optional<DcSctpMessage> msg4 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg4.has_value());
+ EXPECT_EQ(msg4->stream_id(), StreamID(3));
+
+ // Reset SID 1. This will directly send a RE-CONFIG.
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
+ // RE-CONFIG, req
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // RE-CONFIG, resp
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Send a message on SID 1 and 3 - SID 1 should not be reset, but 3 should.
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
+
+ auto packet5 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet5,
+ HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::stream_id, StreamID(1)),
+ Property(&DataChunk::ssn, SSN(2))))))); // Unchanged.
+ z->socket.ReceivePacket(packet5);
+
+ auto packet6 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet6, HasChunks(ElementsAre(IsDataChunk(AllOf(
+ Property(&DataChunk::stream_id, StreamID(3)),
+ Property(&DataChunk::ssn, SSN(0))))))); // Reset
+ z->socket.ReceivePacket(packet6);
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(1);
+ // Let's be evil here - reconnect while a fragmented packet was about to be
+ // sent. The receiving side should get it in full.
+ std::vector<uint8_t> payload(kLargeMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+ // First DATA
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // Create a new association, z2 - and don't use z anymore.
+ SocketUnderTest z2("Z2");
+ z2.socket.Connect();
+
+ // Retransmit and handle the rest. As there will be some chunks in-flight that
+ // have the wrong verification tag, those will yield errors.
+ ExchangeMessages(a, z2);
+
+ absl::optional<DcSctpMessage> msg = z2.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ SendOptions send_options;
+ send_options.max_retransmissions = 0;
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options);
+
+ // First DATA
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Second DATA (lost)
+ a.cb.ConsumeSentPacket();
+ // Third DATA
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // Handle SACK for first DATA
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Handle delayed SACK for third DATA
+ AdvanceTime(a, *z, a.options.delayed_ack_max_timeout);
+
+ // Handle SACK for second DATA
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Now the missing data chunk will be marked as nacked, but it might still be
+ // in-flight and the reported gap could be due to out-of-order delivery. So
+ // the RetransmissionQueue will not mark it as "to be retransmitted" until
+ // after the t3-rtx timer has expired.
+ AdvanceTime(a, *z, a.options.rto_initial);
+
+ // The chunk will be marked as retransmitted, and then as abandoned, which
+ // will trigger a FORWARD-TSN to be sent.
+
+ // FORWARD-TSN (third)
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // Which will trigger a SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->ppid(), PPID(51));
+
+ absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->ppid(), PPID(53));
+
+ absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage();
+ EXPECT_FALSE(msg3.has_value());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ SendOptions send_options;
+ send_options.unordered = IsUnordered(true);
+ send_options.max_retransmissions = 0;
+ std::vector<uint8_t> payload(a.options.mtu * 2 - 100 /* margin */);
+ // Sending first message
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
+ // Sending second message
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
+ // Sending third message
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options);
+ // Sending fourth message
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(54), payload), send_options);
+
+ // First DATA, first fragment
+ std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasChunks(ElementsAre(
+ IsDataChunk(Property(&DataChunk::ppid, PPID(51))))));
+ z->socket.ReceivePacket(std::move(packet));
+
+ // First DATA, second fragment (lost)
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasChunks(ElementsAre(
+ IsDataChunk(Property(&DataChunk::ppid, PPID(51))))));
+
+ // Second DATA, first fragment
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasChunks(ElementsAre(
+ IsDataChunk(Property(&DataChunk::ppid, PPID(52))))));
+ z->socket.ReceivePacket(std::move(packet));
+
+ // Second DATA, second fragment (lost)
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::ppid, PPID(52)),
+ Property(&DataChunk::ssn, SSN(0)))))));
+
+ // Third DATA, first fragment
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::ppid, PPID(53)),
+ Property(&DataChunk::ssn, SSN(0)))))));
+ z->socket.ReceivePacket(std::move(packet));
+
+ // Third DATA, second fragment (lost)
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::ppid, PPID(53)),
+ Property(&DataChunk::ssn, SSN(0)))))));
+
+ // Fourth DATA, first fragment
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::ppid, PPID(54)),
+ Property(&DataChunk::ssn, SSN(0)))))));
+ z->socket.ReceivePacket(std::move(packet));
+
+ // Fourth DATA, second fragment
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk(
+ AllOf(Property(&DataChunk::ppid, PPID(54)),
+ Property(&DataChunk::ssn, SSN(0)))))));
+ z->socket.ReceivePacket(std::move(packet));
+
+ ExchangeMessages(a, *z);
+
+ // Let the RTX timer expire, and exchange FORWARD-TSN/SACKs
+ AdvanceTime(a, *z, a.options.rto_initial);
+
+ ExchangeMessages(a, *z);
+
+ absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->ppid(), PPID(54));
+
+ ASSERT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+struct FakeChunkConfig : ChunkConfig {
+ static constexpr int kType = 0x49;
+ static constexpr size_t kHeaderSize = 4;
+ static constexpr int kVariableLengthAlignment = 0;
+};
+
+class FakeChunk : public Chunk, public TLVTrait<FakeChunkConfig> {
+ public:
+ FakeChunk() {}
+
+ FakeChunk(FakeChunk&& other) = default;
+ FakeChunk& operator=(FakeChunk&& other) = default;
+
+ void SerializeTo(std::vector<uint8_t>& out) const override {
+ AllocateTLV(out);
+ }
+ std::string ToString() const override { return "FAKE"; }
+};
+
+TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // Inject a FAKE chunk
+ SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions());
+ b.Add(FakeChunk());
+ a.socket.ReceivePacket(b.Build());
+
+ // ERROR is sent as a reply. Capture it.
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ SctpPacket reply_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket(), z->options));
+ ASSERT_THAT(reply_packet.descriptors(), SizeIs(1));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ ErrorChunk error, ErrorChunk::Parse(reply_packet.descriptors()[0].data));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ UnrecognizedChunkTypeCause cause,
+ error.error_causes().get<UnrecognizedChunkTypeCause>());
+ EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04));
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // Inject a ERROR chunk
+ SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions());
+ b.Add(
+ ErrorChunk(Parameters::Builder()
+ .Add(UnrecognizedChunkTypeCause({0x49, 0x00, 0x00, 0x04}))
+ .Build()));
+
+ EXPECT_CALL(a.cb, OnError(ErrorKind::kPeerReported,
+ HasSubstr("Unrecognized Chunk Type")));
+ a.socket.ReceivePacket(b.Build());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) {
+ SocketUnderTest a("A");
+
+ constexpr size_t kReceiveWindowBufferSize = 2000;
+ SocketUnderTest z(
+ "Z", {.mtu = 3000,
+ .max_receiver_window_buffer_size = kReceiveWindowBufferSize});
+
+ EXPECT_CALL(z.cb, OnClosed).Times(0);
+ EXPECT_CALL(z.cb, OnAborted).Times(0);
+
+ a.socket.Connect();
+ std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
+ SctpPacket::Parse(init_data, z.options));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ InitChunk init_chunk,
+ InitChunk::Parse(init_packet.descriptors()[0].data));
+ z.socket.ReceivePacket(init_data);
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // Fill up Z2 to the high watermark limit.
+ constexpr size_t kWatermarkLimit =
+ kReceiveWindowBufferSize * ReassemblyQueue::kHighWatermarkLimit;
+ constexpr size_t kRemainingSize = kReceiveWindowBufferSize - kWatermarkLimit;
+
+ TSN tsn = init_chunk.initial_tsn();
+ AnyDataChunk::Options opts;
+ opts.is_beginning = Data::IsBeginning(true);
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(kWatermarkLimit + 1), opts))
+ .Build());
+
+ // First DATA will always trigger a SACK. It's not interesting.
+ EXPECT_THAT(z.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsSack(
+ AllOf(Property(&SackChunk::cumulative_tsn_ack, tsn),
+ Property(&SackChunk::gap_ack_blocks, IsEmpty()))))));
+
+ // This DATA should be accepted - it's advancing cum ack tsn.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 1), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(1),
+ /*options=*/{}))
+ .Build());
+
+ // The receiver might have moved into delayed ack mode.
+ AdvanceTime(a, z, z.options.rto_initial);
+
+ EXPECT_THAT(z.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsSack(
+ AllOf(Property(&SackChunk::cumulative_tsn_ack, AddTo(tsn, 1)),
+ Property(&SackChunk::gap_ack_blocks, IsEmpty()))))));
+
+ // This DATA will not be accepted - it's not advancing cum ack tsn.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(1),
+ /*options=*/{}))
+ .Build());
+
+ // Sack will be sent in IMMEDIATE mode when this is happening.
+ EXPECT_THAT(z.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsSack(
+ AllOf(Property(&SackChunk::cumulative_tsn_ack, AddTo(tsn, 1)),
+ Property(&SackChunk::gap_ack_blocks, IsEmpty()))))));
+
+ // This DATA will not be accepted either.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 4), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(1),
+ /*options=*/{}))
+ .Build());
+
+ // Sack will be sent in IMMEDIATE mode when this is happening.
+ EXPECT_THAT(z.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsSack(
+ AllOf(Property(&SackChunk::cumulative_tsn_ack, AddTo(tsn, 1)),
+ Property(&SackChunk::gap_ack_blocks, IsEmpty()))))));
+
+ // This DATA should be accepted, and it fills the reassembly queue.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 2), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(kRemainingSize),
+ /*options=*/{}))
+ .Build());
+
+ // The receiver might have moved into delayed ack mode.
+ AdvanceTime(a, z, z.options.rto_initial);
+
+ EXPECT_THAT(z.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(IsSack(
+ AllOf(Property(&SackChunk::cumulative_tsn_ack, AddTo(tsn, 2)),
+ Property(&SackChunk::gap_ack_blocks, IsEmpty()))))));
+
+ EXPECT_CALL(z.cb, OnAborted(ErrorKind::kResourceExhaustion, _));
+ EXPECT_CALL(z.cb, OnClosed).Times(0);
+
+ // This DATA will make the connection close. It's too full now.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(kSmallMessageSize),
+ /*options=*/{}))
+ .Build());
+}
+
+TEST(DcSctpSocketTest, SetMaxMessageSize) {
+ SocketUnderTest a("A");
+
+ a.socket.SetMaxMessageSize(42u);
+ EXPECT_EQ(a.socket.options().max_message_size, 42u);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // Mock that the time always goes forward.
+ TimeMs now(0);
+ EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() {
+ now += DurationMs(3);
+ return now;
+ });
+ EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() {
+ now += DurationMs(3);
+ return now;
+ });
+
+ // Queue a few small messages with low lifetime, both ordered and unordered,
+ // and validate that all are delivered.
+ static constexpr int kIterations = 100;
+ for (int i = 0; i < kIterations; ++i) {
+ SendOptions send_options;
+ send_options.unordered = IsUnordered((i % 2) == 0);
+ send_options.lifetime = DurationMs(i % 3); // 0, 1, 2 ms
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
+ }
+
+ ExchangeMessages(a, *z);
+
+ for (int i = 0; i < kIterations; ++i) {
+ EXPECT_TRUE(z->cb.ConsumeReceivedMessage().has_value());
+ }
+
+ EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ // Validate that the sockets really make the time move forward.
+ EXPECT_GE(*now, kIterations * 2);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ DiscardsMessagesWithLowLifetimeIfMustBuffer) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ SendOptions lifetime_0;
+ lifetime_0.unordered = IsUnordered(true);
+ lifetime_0.lifetime = DurationMs(0);
+
+ SendOptions lifetime_1;
+ lifetime_1.unordered = IsUnordered(true);
+ lifetime_1.lifetime = DurationMs(1);
+
+ // Mock that the time always goes forward.
+ TimeMs now(0);
+ EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() {
+ now += DurationMs(3);
+ return now;
+ });
+ EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() {
+ now += DurationMs(3);
+ return now;
+ });
+
+ // Fill up the send buffer with a large message.
+ std::vector<uint8_t> payload(kLargeMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+ // And queue a few small messages with lifetime=0 or 1 ms - can't be sent.
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), lifetime_0);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}), lifetime_1);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}), lifetime_0);
+
+ // Handle all that was sent until congestion window got full.
+ for (;;) {
+ std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
+ if (packet_from_a.empty()) {
+ break;
+ }
+ z->socket.ReceivePacket(std::move(packet_from_a));
+ }
+
+ // Shouldn't be enough to send that large message.
+ EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ // Exchange the rest of the messages, with the time ever increasing.
+ ExchangeMessages(a, *z);
+
+ // The large message should be delivered. It was sent reliably.
+ ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage m1, z->cb.ConsumeReceivedMessage());
+ EXPECT_EQ(m1.stream_id(), StreamID(1));
+ EXPECT_THAT(m1.payload(), SizeIs(kLargeMessageSize));
+
+ // But none of the smaller messages.
+ EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ // Sending a small message will directly send it as a single packet, so
+ // nothing is left in the queue.
+ EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ // Sending a message will directly start sending a few packets, so the
+ // buffered amount is not the full message size.
+ EXPECT_GT(a.socket.buffered_amount(StreamID(1)), 0u);
+ EXPECT_LT(a.socket.buffered_amount(StreamID(1)), kLargeMessageSize);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) {
+ SocketUnderTest a("A");
+ EXPECT_EQ(a.socket.buffered_amount_low_threshold(StreamID(1)), 0u);
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ TriggersOnBufferedAmountLowWithDefaultValueZero) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1)));
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).WillRepeatedly(testing::Return());
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
+ static constexpr size_t kMessageSize = 1000;
+ static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10;
+
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ a.socket.SetBufferedAmountLowThreshold(StreamID(1),
+ kBufferedAmountLowThreshold);
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(0);
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) {
+ static constexpr size_t kMessageSize = 1000;
+ static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2;
+
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ a.socket.SetBufferedAmountLowThreshold(StreamID(1),
+ kBufferedAmountLowThreshold);
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(3);
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(2))).Times(2);
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
+ static constexpr size_t kMessageSize = 1000;
+ static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5;
+
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ a.socket.SetBufferedAmountLowThreshold(StreamID(1),
+ kBufferedAmountLowThreshold);
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+
+ // Add a few messages to fill up the congestion window. When that is full,
+ // messages will start to be fully buffered.
+ while (a.socket.buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) {
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ }
+ size_t initial_buffered = a.socket.buffered_amount(StreamID(1));
+ ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold);
+
+ // Start ACKing packets, which will empty the send queue, and trigger the
+ // callback.
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(1);
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0);
+
+ // Fill up the send queue completely.
+ for (;;) {
+ if (a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions) == SendStatus::kErrorResourceExhaustion) {
+ break;
+ }
+ }
+
+ EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(1);
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, InitialMetricsAreUnset) {
+ SocketUnderTest a("A");
+
+ EXPECT_FALSE(a.socket.GetMetrics().has_value());
+}
+
+TEST(DcSctpSocketTest, MessageInterleavingMetricsAreSet) {
+ std::vector<std::pair<bool, bool>> combinations = {
+ {false, false}, {false, true}, {true, false}, {true, true}};
+ for (const auto& [a_enable, z_enable] : combinations) {
+ DcSctpOptions a_options = {.enable_message_interleaving = a_enable};
+ DcSctpOptions z_options = {.enable_message_interleaving = z_enable};
+
+ SocketUnderTest a("A", a_options);
+ SocketUnderTest z("Z", z_options);
+ ConnectSockets(a, z);
+
+ EXPECT_EQ(a.socket.GetMetrics()->uses_message_interleaving,
+ a_enable && z_enable);
+ }
+}
+
+TEST(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ const size_t initial_a_rwnd = a.options.max_receiver_window_buffer_size *
+ ReassemblyQueue::kHighWatermarkLimit;
+
+ EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 2u);
+ EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 2u);
+ EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 0u);
+ EXPECT_EQ(a.socket.GetMetrics()->cwnd_bytes,
+ a.options.cwnd_mtus_initial * a.options.mtu);
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
+
+ EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 2u);
+ EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 0u);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u);
+
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK
+ EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
+
+ EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value());
+
+ EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 3u);
+ EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 3u);
+ EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 1u);
+
+ EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 3u);
+ EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 1u);
+
+ // Send one more (large - fragmented), and receive the delayed SACK.
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(a.options.mtu * 2 + 1)),
+ kSendOptions);
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 3u);
+
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
+
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u);
+ EXPECT_GT(a.socket.GetMetrics()->peer_rwnd_bytes, 0u);
+ EXPECT_LT(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
+
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
+
+ EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value());
+
+ EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 6u);
+ EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 4u);
+ EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 2u);
+
+ EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 6u);
+ EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 2u);
+
+ // Delayed sack
+ AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
+
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
+ EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 5u);
+ EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
+}
+
+TEST(DcSctpSocketTest, RetransmissionMetricsAreSetForFastRetransmit) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+ ConnectSockets(a, z);
+
+ // Enough to trigger fast retransmit of the missing second packet.
+ std::vector<uint8_t> payload(DcSctpOptions::kMaxSafeMTUSize * 5);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+ // Receive first packet, drop second, receive and retransmit the remaining.
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ a.cb.ConsumeSentPacket();
+ ExchangeMessages(a, z);
+
+ EXPECT_EQ(a.socket.GetMetrics()->rtx_packets_count, 1u);
+ size_t expected_data_size =
+ RoundDownTo4(DcSctpOptions::kMaxSafeMTUSize - SctpPacket::kHeaderSize);
+ EXPECT_EQ(a.socket.GetMetrics()->rtx_bytes_count, expected_data_size);
+}
+
+TEST(DcSctpSocketTest, RetransmissionMetricsAreSetForNormalRetransmit) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+ ConnectSockets(a, z);
+
+ std::vector<uint8_t> payload(kSmallMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+ a.cb.ConsumeSentPacket();
+ AdvanceTime(a, z, a.options.rto_initial);
+ ExchangeMessages(a, z);
+
+ EXPECT_EQ(a.socket.GetMetrics()->rtx_packets_count, 1u);
+ size_t expected_data_size =
+ RoundUpTo4(kSmallMessageSize + DataChunk::kHeaderSize);
+ EXPECT_EQ(a.socket.GetMetrics()->rtx_bytes_count, expected_data_size);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ size_t payload_bytes =
+ a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
+
+ size_t expected_sent_packets = a.options.cwnd_mtus_initial;
+
+ size_t expected_queued_bytes =
+ kLargeMessageSize - expected_sent_packets * payload_bytes;
+
+ size_t expected_queued_packets = expected_queued_bytes / payload_bytes;
+
+ // Due to alignment, padding etc, it's hard to calculate the exact number, but
+ // it should be in this range.
+ EXPECT_GE(a.socket.GetMetrics()->unack_data_count,
+ expected_sent_packets + expected_queued_packets);
+
+ EXPECT_LE(a.socket.GetMetrics()->unack_data_count,
+ expected_sent_packets + expected_queued_packets + 2);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ for (int i = 0; i < kMaxBurstPackets; ++i) {
+ std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, Not(IsEmpty()));
+ z->socket.ReceivePacket(std::move(packet)); // DATA
+ }
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ ExchangeMessages(a, *z);
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // A really large message, to ensure that the congestion window is often full.
+ constexpr size_t kMessageSize = 100000;
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+
+ bool delivered_packet = false;
+ std::vector<size_t> data_packet_sizes;
+ do {
+ delivered_packet = false;
+ std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
+ if (!packet_from_a.empty()) {
+ data_packet_sizes.push_back(packet_from_a.size());
+ delivered_packet = true;
+ z->socket.ReceivePacket(std::move(packet_from_a));
+ }
+ std::vector<uint8_t> packet_from_z = z->cb.ConsumeSentPacket();
+ if (!packet_from_z.empty()) {
+ delivered_packet = true;
+ a.socket.ReceivePacket(std::move(packet_from_z));
+ }
+ } while (delivered_packet);
+
+ size_t packet_payload_bytes =
+ a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
+ // +1 accounts for padding, and rounding up.
+ size_t expected_packets =
+ (kMessageSize + packet_payload_bytes - 1) / packet_payload_bytes + 1;
+ EXPECT_THAT(data_packet_sizes, SizeIs(expected_packets));
+
+ // Remove the last size - it will be the remainder. But all other sizes should
+ // be large.
+ data_packet_sizes.pop_back();
+
+ for (size_t size : data_packet_sizes) {
+ // The 4 is for padding/alignment.
+ EXPECT_GE(size, a.options.mtu - 4);
+ }
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, SendMessagesAfterHandover) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+
+ // Send message before handover to move socket to a not initial state
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ z->cb.ConsumeReceivedMessage();
+
+ z = HandoverSocket(std::move(z));
+
+ absl::optional<DcSctpMessage> msg;
+
+ RTC_LOG(LS_INFO) << "Sending A #1";
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions);
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4));
+
+ RTC_LOG(LS_INFO) << "Sending A #2";
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions);
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(2));
+ EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6));
+
+ RTC_LOG(LS_INFO) << "Sending Z #1";
+
+ z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions);
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // ack
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // data
+
+ msg = a.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3));
+}
+
+TEST(DcSctpSocketTest, CanDetectDcsctpImplementation) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp);
+
+ // As A initiated the connection establishment, Z will not receive enough
+ // information to know about A's implementation
+ EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kUnknown);
+}
+
+TEST(DcSctpSocketTest, BothCanDetectDcsctpImplementation) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ a.socket.Connect();
+ z.socket.Connect();
+
+ ExchangeMessages(a, z);
+
+ EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp);
+ EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kDcsctp);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, CanLoseFirstOrderedMessage) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ SendOptions send_options;
+ send_options.unordered = IsUnordered(false);
+ send_options.max_retransmissions = 0;
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+
+ // Send a first message (SID=1, SSN=0)
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
+
+ // First DATA is lost, and retransmission timer will delete it.
+ a.cb.ConsumeSentPacket();
+ AdvanceTime(a, *z, a.options.rto_initial);
+ ExchangeMessages(a, *z);
+
+ // Send a second message (SID=0, SSN=1).
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
+ ExchangeMessages(a, *z);
+
+ // The Z socket should receive the second message, but not the first.
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->ppid(), PPID(52));
+
+ EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, ReceiveBothUnorderedAndOrderedWithSameTSN) {
+ /* This issue was found by fuzzing. */
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+ std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
+ SctpPacket::Parse(init_data, z.options));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ InitChunk init_chunk,
+ InitChunk::Parse(init_packet.descriptors()[0].data));
+ z.socket.ReceivePacket(init_data);
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // Receive a short unordered message with tsn=INITIAL_TSN+1
+ TSN tsn = init_chunk.initial_tsn();
+ AnyDataChunk::Options opts;
+ opts.is_beginning = Data::IsBeginning(true);
+ opts.is_end = Data::IsEnd(true);
+ opts.is_unordered = IsUnordered(true);
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(10), opts))
+ .Build());
+
+ // Now receive a longer _ordered_ message with [INITIAL_TSN, INITIAL_TSN+1].
+ // This isn't allowed as it reuses TSN=53 with different properties, but it
+ // shouldn't cause any issues.
+ opts.is_unordered = IsUnordered(false);
+ opts.is_end = Data::IsEnd(false);
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(10), opts))
+ .Build());
+
+ opts.is_beginning = Data::IsBeginning(false);
+ opts.is_end = Data::IsEnd(true);
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(10), opts))
+ .Build());
+}
+
+TEST(DcSctpSocketTest, CloseTwoStreamsAtTheSameTime) {
+ // Reported as https://crbug.com/1312009.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(2)))).Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(2)))).Times(1);
+
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+ ExchangeMessages(a, z);
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
+
+ ExchangeMessages(a, z);
+}
+
+TEST(DcSctpSocketTest, CloseThreeStreamsAtTheSameTime) {
+ // Similar to CloseTwoStreamsAtTheSameTime, but ensuring that the two
+ // remaining streams are reset at the same time in the second request.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(
+ UnorderedElementsAre(StreamID(2), StreamID(3))))
+ .Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(
+ UnorderedElementsAre(StreamID(2), StreamID(3))))
+ .Times(1);
+
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), kSendOptions);
+
+ ExchangeMessages(a, z);
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
+
+ ExchangeMessages(a, z);
+}
+
+TEST(DcSctpSocketTest, CloseStreamsWithPendingRequest) {
+ // Checks that stream reset requests are properly paused when they can't be
+ // immediately reset - i.e. when there is already an ongoing stream reset
+ // request (and there can only be a single one in-flight).
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(
+ UnorderedElementsAre(StreamID(2), StreamID(3))))
+ .Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(
+ UnorderedElementsAre(StreamID(2), StreamID(3))))
+ .Times(1);
+
+ ConnectSockets(a, z);
+
+ SendOptions send_options = {.unordered = IsUnordered(false)};
+
+ // Send a few ordered messages
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options);
+
+ ExchangeMessages(a, z);
+
+ // Receive these messages
+ absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->stream_id(), StreamID(1));
+ absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->stream_id(), StreamID(2));
+ absl::optional<DcSctpMessage> msg3 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg3.has_value());
+ EXPECT_EQ(msg3->stream_id(), StreamID(3));
+
+ // Reset the streams - not all at once.
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+
+ std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasChunks(ElementsAre(IsReConfig(HasParameters(
+ ElementsAre(IsOutgoingResetRequest(Property(
+ &OutgoingSSNResetRequestParameter::stream_ids,
+ ElementsAre(StreamID(1))))))))));
+ z.socket.ReceivePacket(std::move(packet));
+
+ // Sending more reset requests while this one is ongoing.
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
+
+ ExchangeMessages(a, z);
+
+ // Send a few more ordered messages
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options);
+
+ ExchangeMessages(a, z);
+
+ // Receive these messages
+ absl::optional<DcSctpMessage> msg4 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg4.has_value());
+ EXPECT_EQ(msg4->stream_id(), StreamID(1));
+ absl::optional<DcSctpMessage> msg5 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg5.has_value());
+ EXPECT_EQ(msg5->stream_id(), StreamID(2));
+ absl::optional<DcSctpMessage> msg6 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg6.has_value());
+ EXPECT_EQ(msg6->stream_id(), StreamID(3));
+}
+
+TEST(DcSctpSocketTest, StreamsHaveInitialPriority) {
+ DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+ SocketUnderTest a("A", options);
+
+ EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)),
+ options.default_stream_priority);
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+ EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)),
+ options.default_stream_priority);
+}
+
+TEST(DcSctpSocketTest, CanChangeStreamPriority) {
+ DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+ SocketUnderTest a("A", options);
+
+ a.socket.SetStreamPriority(StreamID(1), StreamPriority(43));
+ EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+ a.socket.SetStreamPriority(StreamID(2), StreamPriority(43));
+ EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) {
+ DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+ auto a = std::make_unique<SocketUnderTest>("A", options);
+ SocketUnderTest z("Z");
+
+ ConnectSockets(*a, z);
+
+ a->socket.SetStreamPriority(StreamID(1), StreamPriority(43));
+ a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+ a->socket.SetStreamPriority(StreamID(2), StreamPriority(43));
+
+ ExchangeMessages(*a, z);
+
+ a = MaybeHandoverSocket(std::move(a));
+
+ EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
+ EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
+}
+
+TEST(DcSctpSocketTest, ReconnectSocketWithPendingStreamReset) {
+ // This is an issue found by fuzzing, and doesn't really make sense in WebRTC
+ // data channels as a SCTP connection is never ever closed and then
+ // reconnected. SCTP connections are closed when the peer connection is
+ // deleted, and then it doesn't do more with SCTP.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+
+ EXPECT_CALL(z.cb, OnAborted).Times(1);
+ a.socket.Close();
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ a.socket.Connect();
+ ExchangeMessages(a, z);
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
+}
+
+TEST(DcSctpSocketTest, SmallSentMessagesWithPrioWillArriveInSpecificOrder) {
+ DcSctpOptions options = {.enable_message_interleaving = true};
+ SocketUnderTest a("A", options);
+ SocketUnderTest z("A", options);
+
+ a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
+ a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
+ a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));
+
+ // Enqueue messages before connecting the socket, to ensure they aren't send
+ // as soon as Send() is called.
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(103),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+
+ ConnectSockets(a, z);
+ ExchangeMessages(a, z);
+
+ std::vector<uint32_t> received_ppids;
+ for (;;) {
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ if (!msg.has_value()) {
+ break;
+ }
+ received_ppids.push_back(*msg->ppid());
+ }
+
+ EXPECT_THAT(received_ppids, ElementsAre(101, 102, 103, 201, 301));
+}
+
+TEST(DcSctpSocketTest, LargeSentMessagesWithPrioWillArriveInSpecificOrder) {
+ DcSctpOptions options = {.enable_message_interleaving = true};
+ SocketUnderTest a("A", options);
+ SocketUnderTest z("A", options);
+
+ a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
+ a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
+ a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));
+
+ // Enqueue messages before connecting the socket, to ensure they aren't send
+ // as soon as Send() is called.
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ ConnectSockets(a, z);
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201, 301));
+}
+
+TEST(DcSctpSocketTest, MessageWithHigherPrioWillInterruptLowerPrioMessage) {
+ DcSctpOptions options = {.enable_message_interleaving = true};
+ SocketUnderTest a("A", options);
+ SocketUnderTest z("Z", options);
+
+ ConnectSockets(a, z);
+
+ a.socket.SetStreamPriority(StreamID(2), StreamPriority(128));
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ // Due to a non-zero initial congestion window, the message will already start
+ // to send, but will not succeed to be sent completely before filling the
+ // congestion window or stopping due to reaching how many packets that can be
+ // sent at once (max burst). The important thing is that the entire message
+ // doesn't get sent in full.
+
+ // Now enqueue two messages; one small and one large higher priority message.
+ a.socket.SetStreamPriority(StreamID(1), StreamPriority(512));
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201));
+}
+
+TEST(DcSctpSocketTest, LifecycleEventsAreGeneratedForAckedMessages) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(101),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ {.lifecycle_id = LifecycleId(41)});
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(102),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(103),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ {.lifecycle_id = LifecycleId(42)});
+
+ EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(41)));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(41)));
+ EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(42)));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(42)));
+ ExchangeMessages(a, z);
+ // In case of delayed ack.
+ AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 103));
+}
+
+TEST(DcSctpSocketTest, LifecycleEventsForFailMaxRetransmissions) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+ ConnectSockets(a, z);
+
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
+ {
+ .max_retransmissions = 0,
+ .lifecycle_id = LifecycleId(1),
+ });
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload),
+ {
+ .max_retransmissions = 0,
+ .lifecycle_id = LifecycleId(2),
+ });
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload),
+ {
+ .max_retransmissions = 0,
+ .lifecycle_id = LifecycleId(3),
+ });
+
+ // First DATA
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Second DATA (lost)
+ a.cb.ConsumeSentPacket();
+
+ EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(1)));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
+ EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(2),
+ /*maybe_delivered=*/true));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(2)));
+ EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(3)));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(3)));
+ ExchangeMessages(a, z);
+
+ // Handle delayed SACK.
+ AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
+ ExchangeMessages(a, z);
+
+ // The chunk is now NACKed. Let the RTO expire, to discard the message.
+ AdvanceTime(a, z, a.options.rto_initial);
+ ExchangeMessages(a, z);
+
+ // Handle delayed SACK.
+ AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(51, 53));
+}
+
+TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithRetransmitLimit) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+ ConnectSockets(a, z);
+
+ // Will not be able to send it in full within the congestion window, but will
+ // need to wait for SACKs to be received for more fragments to be sent.
+ std::vector<uint8_t> payload(kLargeMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
+ {
+ .max_retransmissions = 0,
+ .lifecycle_id = LifecycleId(1),
+ });
+
+ // First DATA
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Second DATA (lost)
+ a.cb.ConsumeSentPacket();
+
+ EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1),
+ /*maybe_delivered=*/false));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty());
+}
+
+TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithLifetimeLimit) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ // Send it before the socket is connected, to prevent it from being sent too
+ // quickly. The idea is that it should be expired before even attempting to
+ // send it in full.
+ std::vector<uint8_t> payload(kSmallMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
+ {
+ .lifetime = DurationMs(100),
+ .lifecycle_id = LifecycleId(1),
+ });
+
+ AdvanceTime(a, z, DurationMs(200));
+
+ EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1),
+ /*maybe_delivered=*/false));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
+ ConnectSockets(a, z);
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty());
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ExposesTheNumberOfNegotiatedStreams) {
+ DcSctpOptions options_a = {
+ .announced_maximum_incoming_streams = 12,
+ .announced_maximum_outgoing_streams = 45,
+ };
+ SocketUnderTest a("A", options_a);
+
+ DcSctpOptions options_z = {
+ .announced_maximum_incoming_streams = 23,
+ .announced_maximum_outgoing_streams = 34,
+ };
+ auto z = std::make_unique<SocketUnderTest>("Z", options_z);
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(Metrics metrics_a, a.socket.GetMetrics());
+ EXPECT_EQ(metrics_a.negotiated_maximum_incoming_streams, 12);
+ EXPECT_EQ(metrics_a.negotiated_maximum_outgoing_streams, 23);
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(Metrics metrics_z, z->socket.GetMetrics());
+ EXPECT_EQ(metrics_z.negotiated_maximum_incoming_streams, 23);
+ EXPECT_EQ(metrics_z.negotiated_maximum_outgoing_streams, 12);
+}
+
+TEST(DcSctpSocketTest, ResetStreamsDeferred) {
+ // Guaranteed to be fragmented into two fragments.
+ constexpr size_t kTwoFragmentsSize = DcSctpOptions::kMaxSafeMTUSize + 100;
+
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kTwoFragmentsSize)),
+ {});
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(54),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ {});
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+
+ auto data1 = a.cb.ConsumeSentPacket();
+ auto data2 = a.cb.ConsumeSentPacket();
+ auto data3 = a.cb.ConsumeSentPacket();
+ auto reconfig = a.cb.ConsumeSentPacket();
+
+ EXPECT_THAT(
+ data1,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0))))));
+ EXPECT_THAT(
+ data2,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0))))));
+ EXPECT_THAT(
+ data3,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(1))))));
+ EXPECT_THAT(reconfig, HasChunks(ElementsAre(IsReConfig(HasParameters(
+ ElementsAre(IsOutgoingResetRequest(Property(
+ &OutgoingSSNResetRequestParameter::stream_ids,
+ ElementsAre(StreamID(1))))))))));
+
+ // Receive them slightly out of order to make stream resetting deferred.
+ z.socket.ReceivePacket(reconfig);
+
+ z.socket.ReceivePacket(data1);
+ z.socket.ReceivePacket(data2);
+ z.socket.ReceivePacket(data3);
+
+ absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->stream_id(), StreamID(1));
+ EXPECT_EQ(msg1->ppid(), PPID(53));
+ EXPECT_EQ(msg1->payload().size(), kTwoFragmentsSize);
+
+ absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->stream_id(), StreamID(1));
+ EXPECT_EQ(msg2->ppid(), PPID(54));
+ EXPECT_EQ(msg2->payload().size(), kSmallMessageSize);
+
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1))));
+ ExchangeMessages(a, z);
+
+ // Z sent "in progress", which will make A buffer packets until it's sure
+ // that the reconfiguration has been applied. A will retry - wait for that.
+ AdvanceTime(a, z, a.options.rto_initial);
+
+ auto reconfig2 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(reconfig2, HasChunks(ElementsAre(IsReConfig(HasParameters(
+ ElementsAre(IsOutgoingResetRequest(Property(
+ &OutgoingSSNResetRequestParameter::stream_ids,
+ ElementsAre(StreamID(1))))))))));
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1))));
+ z.socket.ReceivePacket(reconfig2);
+
+ auto reconfig3 = z.cb.ConsumeSentPacket();
+ EXPECT_THAT(reconfig3, HasChunks(ElementsAre(IsReConfig(HasParameters(
+ ElementsAre(IsReconfigurationResponse(Property(
+ &ReconfigurationResponseParameter::result,
+ ReconfigurationResponseParameter::Result::
+ kSuccessPerformed))))))));
+ a.socket.ReceivePacket(reconfig3);
+
+ EXPECT_THAT(
+ data1,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0))))));
+ EXPECT_THAT(
+ data2,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0))))));
+ EXPECT_THAT(
+ data3,
+ HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(1))))));
+ EXPECT_THAT(reconfig, HasChunks(ElementsAre(IsReConfig(HasParameters(
+ ElementsAre(IsOutgoingResetRequest(Property(
+ &OutgoingSSNResetRequestParameter::stream_ids,
+ ElementsAre(StreamID(1))))))))));
+
+ // Send a new message after the stream has been reset.
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(55),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ {});
+ ExchangeMessages(a, z);
+
+ absl::optional<DcSctpMessage> msg3 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg3.has_value());
+ EXPECT_EQ(msg3->stream_id(), StreamID(1));
+ EXPECT_EQ(msg3->ppid(), PPID(55));
+ EXPECT_EQ(msg3->payload().size(), kSmallMessageSize);
+}
+
+TEST(DcSctpSocketTest, ResetStreamsWithPausedSenderResumesWhenPerformed) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ {});
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+
+ // Will be queued, as the stream has an outstanding reset operation.
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ {});
+
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1))));
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1))));
+ ExchangeMessages(a, z);
+
+ absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->stream_id(), StreamID(1));
+ EXPECT_EQ(msg1->ppid(), PPID(51));
+ EXPECT_EQ(msg1->payload().size(), kSmallMessageSize);
+
+ absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->stream_id(), StreamID(1));
+ EXPECT_EQ(msg2->ppid(), PPID(52));
+ EXPECT_EQ(msg2->payload().size(), kSmallMessageSize);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ZeroChecksumMetricsAreSet) {
+ std::vector<std::pair<bool, bool>> combinations = {
+ {false, false}, {false, true}, {true, false}, {true, true}};
+ for (const auto& [a_enable, z_enable] : combinations) {
+ DcSctpOptions a_options = {.enable_zero_checksum = a_enable};
+ DcSctpOptions z_options = {.enable_zero_checksum = z_enable};
+
+ SocketUnderTest a("A", a_options);
+ auto z = std::make_unique<SocketUnderTest>("Z", z_options);
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_EQ(a.socket.GetMetrics()->uses_zero_checksum, a_enable && z_enable);
+ EXPECT_EQ(z->socket.GetMetrics()->uses_zero_checksum, a_enable && z_enable);
+ }
+}
+
+TEST(DcSctpSocketTest, AlwaysSendsInitWithNonZeroChecksum) {
+ DcSctpOptions options = {.enable_zero_checksum = true};
+ SocketUnderTest a("A", options);
+
+ a.socket.Connect();
+ std::vector<uint8_t> data = a.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+ SctpPacket::Parse(data, options));
+ EXPECT_THAT(packet.descriptors(),
+ ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+ InitChunk::kType)));
+ EXPECT_THAT(packet.common_header().checksum, Not(Eq(0u)));
+}
+
+TEST(DcSctpSocketTest, MaySendInitAckWithZeroChecksum) {
+ DcSctpOptions options = {.enable_zero_checksum = true};
+ SocketUnderTest a("A", options);
+ SocketUnderTest z("Z", options);
+
+ a.socket.Connect();
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // INIT
+
+ std::vector<uint8_t> data = z.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+ SctpPacket::Parse(data, options));
+ EXPECT_THAT(packet.descriptors(),
+ ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+ InitAckChunk::kType)));
+ EXPECT_THAT(packet.common_header().checksum, 0u);
+}
+
+TEST(DcSctpSocketTest, AlwaysSendsCookieEchoWithNonZeroChecksum) {
+ DcSctpOptions options = {.enable_zero_checksum = true};
+ SocketUnderTest a("A", options);
+ SocketUnderTest z("Z", options);
+
+ a.socket.Connect();
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // INIT
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // INIT-ACK
+
+ std::vector<uint8_t> data = a.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+ SctpPacket::Parse(data, options));
+ EXPECT_THAT(packet.descriptors(),
+ ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+ CookieEchoChunk::kType)));
+ EXPECT_THAT(packet.common_header().checksum, Not(Eq(0u)));
+}
+
+TEST(DcSctpSocketTest, SendsCookieAckWithZeroChecksum) {
+ DcSctpOptions options = {.enable_zero_checksum = true};
+ SocketUnderTest a("A", options);
+ SocketUnderTest z("Z", options);
+
+ a.socket.Connect();
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // INIT
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // INIT-ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // COOKIE-ECHO
+
+ std::vector<uint8_t> data = z.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+ SctpPacket::Parse(data, options));
+ EXPECT_THAT(packet.descriptors(),
+ ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+ CookieAckChunk::kType)));
+ EXPECT_THAT(packet.common_header().checksum, 0u);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendsDataWithZeroChecksum) {
+ DcSctpOptions options = {.enable_zero_checksum = true};
+ SocketUnderTest a("A", options);
+ auto z = std::make_unique<SocketUnderTest>("Z", options);
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ std::vector<uint8_t> data = a.cb.ConsumeSentPacket();
+ z->socket.ReceivePacket(data);
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+ SctpPacket::Parse(data, options));
+ EXPECT_THAT(packet.descriptors(),
+ ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type,
+ DataChunk::kType)));
+ EXPECT_THAT(packet.common_header().checksum, 0u);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, AllPacketsAfterConnectHaveZeroChecksum) {
+ DcSctpOptions options = {.enable_zero_checksum = true};
+ SocketUnderTest a("A", options);
+ auto z = std::make_unique<SocketUnderTest>("Z", options);
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // Send large messages in both directions, and verify that they arrive and
+ // that every packet has zero checksum.
+ std::vector<uint8_t> payload(kLargeMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+ z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+ for (;;) {
+ if (auto data = a.cb.ConsumeSentPacket(); !data.empty()) {
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+ SctpPacket::Parse(data, options));
+ EXPECT_THAT(packet.common_header().checksum, 0u);
+ z->socket.ReceivePacket(std::move(data));
+
+ } else if (auto data = z->cb.ConsumeSentPacket(); !data.empty()) {
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+ SctpPacket::Parse(data, options));
+ EXPECT_THAT(packet.common_header().checksum, 0u);
+ a.socket.ReceivePacket(std::move(data));
+
+ } else {
+ break;
+ }
+ }
+
+ absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_THAT(msg1->payload(), SizeIs(kLargeMessageSize));
+
+ absl::optional<DcSctpMessage> msg2 = a.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_THAT(msg2->payload(), SizeIs(kLargeMessageSize));
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, HandlesForwardTsnOutOfOrderWithStreamResetting) {
+ // This test ensures that receiving FORWARD-TSN and RECONFIG out of order is
+ // handled correctly.
+ SocketUnderTest a("A", {.heartbeat_interval = DurationMs(0)});
+ SocketUnderTest z("Z", {.heartbeat_interval = DurationMs(0)});
+
+ ConnectSockets(a, z);
+ std::vector<uint8_t> payload(kSmallMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
+ {
+ .max_retransmissions = 0,
+ });
+
+ // Packet is lost.
+ EXPECT_THAT(a.cb.ConsumeSentPacket(),
+ HasChunks(ElementsAre(
+ IsDataChunk(AllOf(Property(&DataChunk::ssn, SSN(0)),
+ Property(&DataChunk::ppid, PPID(51)))))));
+ AdvanceTime(a, z, a.options.rto_initial);
+
+ auto fwd_tsn_packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(fwd_tsn_packet,
+ HasChunks(ElementsAre(IsChunkType(ForwardTsnChunk::kType))));
+ // Reset stream 1
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+ auto reconfig_packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(reconfig_packet,
+ HasChunks(ElementsAre(IsChunkType(ReConfigChunk::kType))));
+
+ // These two packets are received in the wrong order.
+ z.socket.ReceivePacket(reconfig_packet);
+ z.socket.ReceivePacket(fwd_tsn_packet);
+ ExchangeMessagesAndAdvanceTime(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), {});
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ auto data_packet_2 = a.cb.ConsumeSentPacket();
+ auto data_packet_3 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(data_packet_2, HasChunks(ElementsAre(IsDataChunk(AllOf(
+ Property(&DataChunk::ssn, SSN(0)),
+ Property(&DataChunk::ppid, PPID(52)))))));
+ EXPECT_THAT(data_packet_3, HasChunks(ElementsAre(IsDataChunk(AllOf(
+ Property(&DataChunk::ssn, SSN(1)),
+ Property(&DataChunk::ppid, PPID(53)))))));
+
+ z.socket.ReceivePacket(data_packet_2);
+ z.socket.ReceivePacket(data_packet_3);
+ ASSERT_THAT(z.cb.ConsumeReceivedMessage(),
+ testing::Optional(Property(&DcSctpMessage::ppid, PPID(52))));
+ ASSERT_THAT(z.cb.ConsumeReceivedMessage(),
+ testing::Optional(Property(&DcSctpMessage::ppid, PPID(53))));
+}
+
+} // namespace
+} // namespace dcsctp