summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/net/dcsctp/socket
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
commit43a97878ce14b72f0981164f87f2e35e14151312 (patch)
tree620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/libwebrtc/net/dcsctp/socket
parentInitial commit. (diff)
downloadfirefox-upstream.tar.xz
firefox-upstream.zip
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/libwebrtc/net/dcsctp/socket')
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/BUILD.gn278
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/DEPS5
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc181
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h100
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/capabilities.h26
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/context.h66
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc1752
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h298
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc518
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc2672
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc196
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h69
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc184
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/mock_context.h72
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h179
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc48
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/packet_sender.h40
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc50
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc78
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/state_cookie.h65
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc53
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc349
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h230
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc802
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc314
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h193
26 files changed, 8818 insertions, 0 deletions
diff --git a/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn
new file mode 100644
index 0000000000..92ce413d0d
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn
@@ -0,0 +1,278 @@
+# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+
+import("../../../webrtc.gni")
+
+rtc_source_set("context") {
+ sources = [ "context.h" ]
+ deps = [
+ "../common:internal_types",
+ "../packet:sctp_packet",
+ "../public:socket",
+ "../public:types",
+ ]
+ absl_deps = [ "//third_party/abseil-cpp/absl/strings" ]
+}
+
+rtc_library("heartbeat_handler") {
+ deps = [
+ ":context",
+ "../../../api:array_view",
+ "../../../rtc_base:checks",
+ "../../../rtc_base:logging",
+ "../packet:bounded_io",
+ "../packet:chunk",
+ "../packet:parameter",
+ "../packet:sctp_packet",
+ "../public:socket",
+ "../public:types",
+ "../timer",
+ ]
+ sources = [
+ "heartbeat_handler.cc",
+ "heartbeat_handler.h",
+ ]
+ absl_deps = [
+ "//third_party/abseil-cpp/absl/functional:bind_front",
+ "//third_party/abseil-cpp/absl/strings",
+ "//third_party/abseil-cpp/absl/types:optional",
+ ]
+}
+
+rtc_library("stream_reset_handler") {
+ deps = [
+ ":context",
+ "../../../api:array_view",
+ "../../../rtc_base:checks",
+ "../../../rtc_base:logging",
+ "../../../rtc_base/containers:flat_set",
+ "../common:internal_types",
+ "../common:str_join",
+ "../packet:chunk",
+ "../packet:parameter",
+ "../packet:sctp_packet",
+ "../packet:tlv_trait",
+ "../public:socket",
+ "../public:types",
+ "../rx:data_tracker",
+ "../rx:reassembly_queue",
+ "../timer",
+ "../tx:retransmission_queue",
+ ]
+ sources = [
+ "stream_reset_handler.cc",
+ "stream_reset_handler.h",
+ ]
+ absl_deps = [
+ "//third_party/abseil-cpp/absl/functional:bind_front",
+ "//third_party/abseil-cpp/absl/strings",
+ "//third_party/abseil-cpp/absl/types:optional",
+ ]
+}
+
+rtc_library("packet_sender") {
+ deps = [
+ "../packet:sctp_packet",
+ "../public:socket",
+ "../public:types",
+ "../timer",
+ ]
+ sources = [
+ "packet_sender.cc",
+ "packet_sender.h",
+ ]
+ absl_deps = []
+}
+
+rtc_library("transmission_control_block") {
+ deps = [
+ ":context",
+ ":heartbeat_handler",
+ ":packet_sender",
+ ":stream_reset_handler",
+ "../../../api:array_view",
+ "../../../api/task_queue:task_queue",
+ "../../../rtc_base:checks",
+ "../../../rtc_base:logging",
+ "../../../rtc_base:stringutils",
+ "../common:sequence_numbers",
+ "../packet:chunk",
+ "../packet:sctp_packet",
+ "../public:socket",
+ "../public:types",
+ "../rx:data_tracker",
+ "../rx:reassembly_queue",
+ "../timer",
+ "../tx:retransmission_error_counter",
+ "../tx:retransmission_queue",
+ "../tx:retransmission_timeout",
+ "../tx:send_queue",
+ ]
+ sources = [
+ "capabilities.h",
+ "transmission_control_block.cc",
+ "transmission_control_block.h",
+ ]
+ absl_deps = [
+ "//third_party/abseil-cpp/absl/functional:bind_front",
+ "//third_party/abseil-cpp/absl/strings",
+ "//third_party/abseil-cpp/absl/types:optional",
+ ]
+}
+
+rtc_library("dcsctp_socket") {
+ deps = [
+ ":context",
+ ":heartbeat_handler",
+ ":packet_sender",
+ ":stream_reset_handler",
+ ":transmission_control_block",
+ "../../../api:array_view",
+ "../../../api:make_ref_counted",
+ "../../../api:refcountedbase",
+ "../../../api:scoped_refptr",
+ "../../../api:sequence_checker",
+ "../../../api/task_queue:task_queue",
+ "../../../rtc_base:checks",
+ "../../../rtc_base:logging",
+ "../../../rtc_base:stringutils",
+ "../common:internal_types",
+ "../packet:bounded_io",
+ "../packet:chunk",
+ "../packet:chunk_validators",
+ "../packet:data",
+ "../packet:error_cause",
+ "../packet:parameter",
+ "../packet:sctp_packet",
+ "../packet:tlv_trait",
+ "../public:socket",
+ "../public:types",
+ "../rx:data_tracker",
+ "../rx:reassembly_queue",
+ "../timer",
+ "../tx:retransmission_error_counter",
+ "../tx:retransmission_queue",
+ "../tx:retransmission_timeout",
+ "../tx:rr_send_queue",
+ "../tx:send_queue",
+ ]
+ sources = [
+ "callback_deferrer.cc",
+ "callback_deferrer.h",
+ "dcsctp_socket.cc",
+ "dcsctp_socket.h",
+ "state_cookie.cc",
+ "state_cookie.h",
+ ]
+ absl_deps = [
+ "//third_party/abseil-cpp/absl/functional:bind_front",
+ "//third_party/abseil-cpp/absl/memory",
+ "//third_party/abseil-cpp/absl/strings",
+ "//third_party/abseil-cpp/absl/types:optional",
+ ]
+}
+
+if (rtc_include_tests) {
+ rtc_source_set("mock_callbacks") {
+ testonly = true
+ sources = [ "mock_dcsctp_socket_callbacks.h" ]
+ deps = [
+ "../../../api:array_view",
+ "../../../api/task_queue:task_queue",
+ "../../../rtc_base:logging",
+ "../../../rtc_base:random",
+ "../../../test:test_support",
+ "../public:socket",
+ "../public:types",
+ "../timer",
+ ]
+ absl_deps = [
+ "//third_party/abseil-cpp/absl/strings",
+ "//third_party/abseil-cpp/absl/types:optional",
+ ]
+ }
+
+ rtc_source_set("mock_context") {
+ testonly = true
+ sources = [ "mock_context.h" ]
+ deps = [
+ ":context",
+ ":mock_callbacks",
+ "../../../test:test_support",
+ "../common:internal_types",
+ "../packet:sctp_packet",
+ "../public:socket",
+ "../public:types",
+ ]
+ absl_deps = [
+ "//third_party/abseil-cpp/absl/strings",
+ "//third_party/abseil-cpp/absl/types:optional",
+ ]
+ }
+
+ rtc_library("dcsctp_socket_unittests") {
+ testonly = true
+
+ deps = [
+ ":dcsctp_socket",
+ ":heartbeat_handler",
+ ":mock_callbacks",
+ ":mock_context",
+ ":packet_sender",
+ ":stream_reset_handler",
+ "../../../api:array_view",
+ "../../../api:create_network_emulation_manager",
+ "../../../api:network_emulation_manager_api",
+ "../../../api/task_queue",
+ "../../../api/task_queue:pending_task_safety_flag",
+ "../../../api/units:time_delta",
+ "../../../call:simulated_network",
+ "../../../rtc_base:checks",
+ "../../../rtc_base:copy_on_write_buffer",
+ "../../../rtc_base:gunit_helpers",
+ "../../../rtc_base:logging",
+ "../../../rtc_base:rtc_base_tests_utils",
+ "../../../rtc_base:socket_address",
+ "../../../rtc_base:stringutils",
+ "../../../rtc_base:timeutils",
+ "../../../test:test_support",
+ "../common:handover_testing",
+ "../common:internal_types",
+ "../packet:chunk",
+ "../packet:error_cause",
+ "../packet:parameter",
+ "../packet:sctp_packet",
+ "../packet:tlv_trait",
+ "../public:socket",
+ "../public:types",
+ "../public:utils",
+ "../rx:data_tracker",
+ "../rx:reassembly_queue",
+ "../testing:data_generator",
+ "../testing:testing_macros",
+ "../timer",
+ "../timer:task_queue_timeout",
+ "../tx:mock_send_queue",
+ "../tx:retransmission_queue",
+ ]
+ absl_deps = [
+ "//third_party/abseil-cpp/absl/flags:flag",
+ "//third_party/abseil-cpp/absl/memory",
+ "//third_party/abseil-cpp/absl/strings",
+ "//third_party/abseil-cpp/absl/types:optional",
+ ]
+ sources = [
+ "dcsctp_socket_network_test.cc",
+ "dcsctp_socket_test.cc",
+ "heartbeat_handler_test.cc",
+ "packet_sender_test.cc",
+ "state_cookie_test.cc",
+ "stream_reset_handler_test.cc",
+ ]
+ }
+}
diff --git a/third_party/libwebrtc/net/dcsctp/socket/DEPS b/third_party/libwebrtc/net/dcsctp/socket/DEPS
new file mode 100644
index 0000000000..d4966290e3
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/DEPS
@@ -0,0 +1,5 @@
+specific_include_rules = {
+ "dcsctp_socket_network_test.cc": [
+ "+call",
+ ]
+}
diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc
new file mode 100644
index 0000000000..123526e782
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc
@@ -0,0 +1,181 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/callback_deferrer.h"
+
+#include "api/make_ref_counted.h"
+
+namespace dcsctp {
+namespace {
+// A wrapper around the move-only DcSctpMessage, to let it be captured in a
+// lambda.
+class MessageDeliverer {
+ public:
+ explicit MessageDeliverer(DcSctpMessage&& message)
+ : state_(rtc::make_ref_counted<State>(std::move(message))) {}
+
+ void Deliver(DcSctpSocketCallbacks& c) {
+ // Really ensure that it's only called once.
+ RTC_DCHECK(!state_->has_delivered);
+ state_->has_delivered = true;
+ c.OnMessageReceived(std::move(state_->message));
+ }
+
+ private:
+ struct State : public rtc::RefCountInterface {
+ explicit State(DcSctpMessage&& m)
+ : has_delivered(false), message(std::move(m)) {}
+ bool has_delivered;
+ DcSctpMessage message;
+ };
+ rtc::scoped_refptr<State> state_;
+};
+} // namespace
+
+void CallbackDeferrer::Prepare() {
+ RTC_DCHECK(!prepared_);
+ prepared_ = true;
+}
+
+void CallbackDeferrer::TriggerDeferred() {
+ // Need to swap here. The client may call into the library from within a
+ // callback, and that might result in adding new callbacks to this instance,
+ // and the vector can't be modified while iterated on.
+ RTC_DCHECK(prepared_);
+ std::vector<std::function<void(DcSctpSocketCallbacks & cb)>> deferred;
+ deferred.swap(deferred_);
+ prepared_ = false;
+
+ for (auto& cb : deferred) {
+ cb(underlying_);
+ }
+}
+
+SendPacketStatus CallbackDeferrer::SendPacketWithStatus(
+ rtc::ArrayView<const uint8_t> data) {
+ // Will not be deferred - call directly.
+ return underlying_.SendPacketWithStatus(data);
+}
+
+std::unique_ptr<Timeout> CallbackDeferrer::CreateTimeout(
+ webrtc::TaskQueueBase::DelayPrecision precision) {
+ // Will not be deferred - call directly.
+ return underlying_.CreateTimeout(precision);
+}
+
+TimeMs CallbackDeferrer::TimeMillis() {
+ // Will not be deferred - call directly.
+ return underlying_.TimeMillis();
+}
+
+uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) {
+ // Will not be deferred - call directly.
+ return underlying_.GetRandomInt(low, high);
+}
+
+void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back(
+ [deliverer = MessageDeliverer(std::move(message))](
+ DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); });
+}
+
+void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back(
+ [error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
+ cb.OnError(error, message);
+ });
+}
+
+void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back(
+ [error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
+ cb.OnAborted(error, message);
+ });
+}
+
+void CallbackDeferrer::OnConnected() {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); });
+}
+
+void CallbackDeferrer::OnClosed() {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); });
+}
+
+void CallbackDeferrer::OnConnectionRestarted() {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back(
+ [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); });
+}
+
+void CallbackDeferrer::OnStreamsResetFailed(
+ rtc::ArrayView<const StreamID> outgoing_streams,
+ absl::string_view reason) {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back(
+ [streams = std::vector<StreamID>(outgoing_streams.begin(),
+ outgoing_streams.end()),
+ reason = std::string(reason)](DcSctpSocketCallbacks& cb) {
+ cb.OnStreamsResetFailed(streams, reason);
+ });
+}
+
+void CallbackDeferrer::OnStreamsResetPerformed(
+ rtc::ArrayView<const StreamID> outgoing_streams) {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back(
+ [streams = std::vector<StreamID>(outgoing_streams.begin(),
+ outgoing_streams.end())](
+ DcSctpSocketCallbacks& cb) { cb.OnStreamsResetPerformed(streams); });
+}
+
+void CallbackDeferrer::OnIncomingStreamsReset(
+ rtc::ArrayView<const StreamID> incoming_streams) {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back(
+ [streams = std::vector<StreamID>(incoming_streams.begin(),
+ incoming_streams.end())](
+ DcSctpSocketCallbacks& cb) { cb.OnIncomingStreamsReset(streams); });
+}
+
+void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) {
+ cb.OnBufferedAmountLow(stream_id);
+ });
+}
+
+void CallbackDeferrer::OnTotalBufferedAmountLow() {
+ RTC_DCHECK(prepared_);
+ deferred_.emplace_back(
+ [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); });
+}
+
+void CallbackDeferrer::OnLifecycleMessageExpired(LifecycleId lifecycle_id,
+ bool maybe_delivered) {
+ // Will not be deferred - call directly.
+ underlying_.OnLifecycleMessageExpired(lifecycle_id, maybe_delivered);
+}
+void CallbackDeferrer::OnLifecycleMessageFullySent(LifecycleId lifecycle_id) {
+ // Will not be deferred - call directly.
+ underlying_.OnLifecycleMessageFullySent(lifecycle_id);
+}
+void CallbackDeferrer::OnLifecycleMessageDelivered(LifecycleId lifecycle_id) {
+ // Will not be deferred - call directly.
+ underlying_.OnLifecycleMessageDelivered(lifecycle_id);
+}
+void CallbackDeferrer::OnLifecycleEnd(LifecycleId lifecycle_id) {
+ // Will not be deferred - call directly.
+ underlying_.OnLifecycleEnd(lifecycle_id);
+}
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h
new file mode 100644
index 0000000000..1c35dda6cf
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_
+#define NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "api/array_view.h"
+#include "api/ref_counted_base.h"
+#include "api/scoped_refptr.h"
+#include "api/task_queue/task_queue_base.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+
+namespace dcsctp {
+// Defers callbacks until they can be safely triggered.
+//
+// There are a lot of callbacks from the dcSCTP library to the client,
+// such as when messages are received or streams are closed. When the client
+// receives these callbacks, the client is expected to be able to call into the
+// library - from within the callback. For example, sending a reply message when
+// a certain SCTP message has been received, or to reconnect when the connection
+// was closed for any reason. This means that the dcSCTP library must always be
+// in a consistent and stable state when these callbacks are delivered, and to
+// ensure that's the case, callbacks are not immediately delivered from where
+// they originate, but instead queued (deferred) by this class. At the end of
+// any public API method that may result in callbacks, they are triggered and
+// then delivered.
+//
+// There are a number of exceptions, which is clearly annotated in the API.
+class CallbackDeferrer : public DcSctpSocketCallbacks {
+ public:
+ class ScopedDeferrer {
+ public:
+ explicit ScopedDeferrer(CallbackDeferrer& callback_deferrer)
+ : callback_deferrer_(callback_deferrer) {
+ callback_deferrer_.Prepare();
+ }
+
+ ~ScopedDeferrer() { callback_deferrer_.TriggerDeferred(); }
+
+ private:
+ CallbackDeferrer& callback_deferrer_;
+ };
+
+ explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying)
+ : underlying_(underlying) {}
+
+ // Implementation of DcSctpSocketCallbacks
+ SendPacketStatus SendPacketWithStatus(
+ rtc::ArrayView<const uint8_t> data) override;
+ std::unique_ptr<Timeout> CreateTimeout(
+ webrtc::TaskQueueBase::DelayPrecision precision) override;
+ TimeMs TimeMillis() override;
+ uint32_t GetRandomInt(uint32_t low, uint32_t high) override;
+ void OnMessageReceived(DcSctpMessage message) override;
+ void OnError(ErrorKind error, absl::string_view message) override;
+ void OnAborted(ErrorKind error, absl::string_view message) override;
+ void OnConnected() override;
+ void OnClosed() override;
+ void OnConnectionRestarted() override;
+ void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,
+ absl::string_view reason) override;
+ void OnStreamsResetPerformed(
+ rtc::ArrayView<const StreamID> outgoing_streams) override;
+ void OnIncomingStreamsReset(
+ rtc::ArrayView<const StreamID> incoming_streams) override;
+ void OnBufferedAmountLow(StreamID stream_id) override;
+ void OnTotalBufferedAmountLow() override;
+
+ void OnLifecycleMessageExpired(LifecycleId lifecycle_id,
+ bool maybe_delivered) override;
+ void OnLifecycleMessageFullySent(LifecycleId lifecycle_id) override;
+ void OnLifecycleMessageDelivered(LifecycleId lifecycle_id) override;
+ void OnLifecycleEnd(LifecycleId lifecycle_id) override;
+
+ private:
+ void Prepare();
+ void TriggerDeferred();
+
+ DcSctpSocketCallbacks& underlying_;
+ bool prepared_ = false;
+ std::vector<std::function<void(DcSctpSocketCallbacks& cb)>> deferred_;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/capabilities.h b/third_party/libwebrtc/net/dcsctp/socket/capabilities.h
new file mode 100644
index 0000000000..c6d3692b2d
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/capabilities.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_CAPABILITIES_H_
+#define NET_DCSCTP_SOCKET_CAPABILITIES_H_
+
+namespace dcsctp {
+// Indicates what the association supports, meaning that both parties
+// support it and that feature can be used.
+struct Capabilities {
+ // RFC3758 Partial Reliability Extension
+ bool partial_reliability = false;
+ // RFC8260 Stream Schedulers and User Message Interleaving
+ bool message_interleaving = false;
+ // RFC6525 Stream Reconfiguration
+ bool reconfig = false;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_CAPABILITIES_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/context.h b/third_party/libwebrtc/net/dcsctp/socket/context.h
new file mode 100644
index 0000000000..eca5b9e4fb
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/context.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_CONTEXT_H_
+#define NET_DCSCTP_SOCKET_CONTEXT_H_
+
+#include <cstdint>
+
+#include "absl/strings/string_view.h"
+#include "net/dcsctp/common/internal_types.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/public/types.h"
+
+namespace dcsctp {
+
+// A set of helper methods used by handlers to e.g. send packets.
+//
+// Implemented by the TransmissionControlBlock.
+class Context {
+ public:
+ virtual ~Context() = default;
+
+ // Indicates if a connection has been established.
+ virtual bool is_connection_established() const = 0;
+
+ // Returns this side's initial TSN value.
+ virtual TSN my_initial_tsn() const = 0;
+
+ // Returns the peer's initial TSN value.
+ virtual TSN peer_initial_tsn() const = 0;
+
+ // Returns the socket callbacks.
+ virtual DcSctpSocketCallbacks& callbacks() const = 0;
+
+ // Observes a measured RTT value, in milliseconds.
+ virtual void ObserveRTT(DurationMs rtt_ms) = 0;
+
+ // Returns the current Retransmission Timeout (rto) value, in milliseconds.
+ virtual DurationMs current_rto() const = 0;
+
+ // Increments the transmission error counter, given a human readable reason.
+ virtual bool IncrementTxErrorCounter(absl::string_view reason) = 0;
+
+ // Clears the transmission error counter.
+ virtual void ClearTxErrorCounter() = 0;
+
+ // Returns true if there have been too many retransmission errors.
+ virtual bool HasTooManyTxErrors() const = 0;
+
+ // Returns a PacketBuilder, filled in with the correct verification tag.
+ virtual SctpPacket::Builder PacketBuilder() const = 0;
+
+ // Builds the packet from `builder` and sends it.
+ virtual void Send(SctpPacket::Builder& builder) = 0;
+};
+
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_CONTEXT_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc
new file mode 100644
index 0000000000..53838193ec
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc
@@ -0,0 +1,1752 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/dcsctp_socket.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/functional/bind_front.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "api/task_queue/task_queue_base.h"
+#include "net/dcsctp/packet/chunk/abort_chunk.h"
+#include "net/dcsctp/packet/chunk/chunk.h"
+#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
+#include "net/dcsctp/packet/chunk/data_chunk.h"
+#include "net/dcsctp/packet/chunk/data_common.h"
+#include "net/dcsctp/packet/chunk/error_chunk.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/chunk/idata_chunk.h"
+#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h"
+#include "net/dcsctp/packet/chunk/init_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/init_chunk.h"
+#include "net/dcsctp/packet/chunk/reconfig_chunk.h"
+#include "net/dcsctp/packet/chunk/sack_chunk.h"
+#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/shutdown_chunk.h"
+#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h"
+#include "net/dcsctp/packet/chunk_validators.h"
+#include "net/dcsctp/packet/data.h"
+#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h"
+#include "net/dcsctp/packet/error_cause/error_cause.h"
+#include "net/dcsctp/packet/error_cause/no_user_data_cause.h"
+#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h"
+#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h"
+#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h"
+#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h"
+#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h"
+#include "net/dcsctp/packet/parameter/parameter.h"
+#include "net/dcsctp/packet/parameter/state_cookie_parameter.h"
+#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/packet/tlv_trait.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/public/packet_observer.h"
+#include "net/dcsctp/public/types.h"
+#include "net/dcsctp/rx/data_tracker.h"
+#include "net/dcsctp/rx/reassembly_queue.h"
+#include "net/dcsctp/socket/callback_deferrer.h"
+#include "net/dcsctp/socket/capabilities.h"
+#include "net/dcsctp/socket/heartbeat_handler.h"
+#include "net/dcsctp/socket/state_cookie.h"
+#include "net/dcsctp/socket/stream_reset_handler.h"
+#include "net/dcsctp/socket/transmission_control_block.h"
+#include "net/dcsctp/timer/timer.h"
+#include "net/dcsctp/tx/retransmission_queue.h"
+#include "net/dcsctp/tx/send_queue.h"
+#include "rtc_base/checks.h"
+#include "rtc_base/logging.h"
+#include "rtc_base/strings/string_builder.h"
+#include "rtc_base/strings/string_format.h"
+
+namespace dcsctp {
+namespace {
+
+// https://tools.ietf.org/html/rfc4960#section-5.1
+constexpr uint32_t kMinVerificationTag = 1;
+constexpr uint32_t kMaxVerificationTag = std::numeric_limits<uint32_t>::max();
+
+// https://tools.ietf.org/html/rfc4960#section-3.3.2
+constexpr uint32_t kMinInitialTsn = 0;
+constexpr uint32_t kMaxInitialTsn = std::numeric_limits<uint32_t>::max();
+
+Capabilities GetCapabilities(const DcSctpOptions& options,
+ const Parameters& parameters) {
+ Capabilities capabilities;
+ absl::optional<SupportedExtensionsParameter> supported_extensions =
+ parameters.get<SupportedExtensionsParameter>();
+
+ if (options.enable_partial_reliability) {
+ capabilities.partial_reliability =
+ parameters.get<ForwardTsnSupportedParameter>().has_value();
+ if (supported_extensions.has_value()) {
+ capabilities.partial_reliability |=
+ supported_extensions->supports(ForwardTsnChunk::kType);
+ }
+ }
+
+ if (options.enable_message_interleaving && supported_extensions.has_value()) {
+ capabilities.message_interleaving =
+ supported_extensions->supports(IDataChunk::kType) &&
+ supported_extensions->supports(IForwardTsnChunk::kType);
+ }
+ if (supported_extensions.has_value() &&
+ supported_extensions->supports(ReConfigChunk::kType)) {
+ capabilities.reconfig = true;
+ }
+ return capabilities;
+}
+
+void AddCapabilityParameters(const DcSctpOptions& options,
+ Parameters::Builder& builder) {
+ std::vector<uint8_t> chunk_types = {ReConfigChunk::kType};
+
+ if (options.enable_partial_reliability) {
+ builder.Add(ForwardTsnSupportedParameter());
+ chunk_types.push_back(ForwardTsnChunk::kType);
+ }
+ if (options.enable_message_interleaving) {
+ chunk_types.push_back(IDataChunk::kType);
+ chunk_types.push_back(IForwardTsnChunk::kType);
+ }
+ builder.Add(SupportedExtensionsParameter(std::move(chunk_types)));
+}
+
+TieTag MakeTieTag(DcSctpSocketCallbacks& cb) {
+ uint32_t tie_tag_upper =
+ cb.GetRandomInt(0, std::numeric_limits<uint32_t>::max());
+ uint32_t tie_tag_lower =
+ cb.GetRandomInt(1, std::numeric_limits<uint32_t>::max());
+ return TieTag(static_cast<uint64_t>(tie_tag_upper) << 32 |
+ static_cast<uint64_t>(tie_tag_lower));
+}
+
+SctpImplementation DeterminePeerImplementation(
+ rtc::ArrayView<const uint8_t> cookie) {
+ if (cookie.size() > 8) {
+ absl::string_view magic(reinterpret_cast<const char*>(cookie.data()), 8);
+ if (magic == "dcSCTP00") {
+ return SctpImplementation::kDcsctp;
+ }
+ if (magic == "KAME-BSD") {
+ return SctpImplementation::kUsrSctp;
+ }
+ }
+ return SctpImplementation::kOther;
+}
+} // namespace
+
+DcSctpSocket::DcSctpSocket(absl::string_view log_prefix,
+ DcSctpSocketCallbacks& callbacks,
+ std::unique_ptr<PacketObserver> packet_observer,
+ const DcSctpOptions& options)
+ : log_prefix_(std::string(log_prefix) + ": "),
+ packet_observer_(std::move(packet_observer)),
+ options_(options),
+ callbacks_(callbacks),
+ timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) {
+ return callbacks_.CreateTimeout(precision);
+ }),
+ t1_init_(timer_manager_.CreateTimer(
+ "t1-init",
+ absl::bind_front(&DcSctpSocket::OnInitTimerExpiry, this),
+ TimerOptions(options.t1_init_timeout,
+ TimerBackoffAlgorithm::kExponential,
+ options.max_init_retransmits))),
+ t1_cookie_(timer_manager_.CreateTimer(
+ "t1-cookie",
+ absl::bind_front(&DcSctpSocket::OnCookieTimerExpiry, this),
+ TimerOptions(options.t1_cookie_timeout,
+ TimerBackoffAlgorithm::kExponential,
+ options.max_init_retransmits))),
+ t2_shutdown_(timer_manager_.CreateTimer(
+ "t2-shutdown",
+ absl::bind_front(&DcSctpSocket::OnShutdownTimerExpiry, this),
+ TimerOptions(options.t2_shutdown_timeout,
+ TimerBackoffAlgorithm::kExponential,
+ options.max_retransmissions))),
+ packet_sender_(callbacks_,
+ absl::bind_front(&DcSctpSocket::OnSentPacket, this)),
+ send_queue_(log_prefix_,
+ &callbacks_,
+ options_.max_send_buffer_size,
+ options_.mtu,
+ options_.default_stream_priority,
+ options_.total_buffered_amount_low_threshold) {}
+
+std::string DcSctpSocket::log_prefix() const {
+ return log_prefix_ + "[" + std::string(ToString(state_)) + "] ";
+}
+
+bool DcSctpSocket::IsConsistent() const {
+ if (tcb_ != nullptr && tcb_->reassembly_queue().HasMessages()) {
+ return false;
+ }
+ switch (state_) {
+ case State::kClosed:
+ return (tcb_ == nullptr && !t1_init_->is_running() &&
+ !t1_cookie_->is_running() && !t2_shutdown_->is_running());
+ case State::kCookieWait:
+ return (tcb_ == nullptr && t1_init_->is_running() &&
+ !t1_cookie_->is_running() && !t2_shutdown_->is_running());
+ case State::kCookieEchoed:
+ return (tcb_ != nullptr && !t1_init_->is_running() &&
+ t1_cookie_->is_running() && !t2_shutdown_->is_running() &&
+ tcb_->has_cookie_echo_chunk());
+ case State::kEstablished:
+ return (tcb_ != nullptr && !t1_init_->is_running() &&
+ !t1_cookie_->is_running() && !t2_shutdown_->is_running());
+ case State::kShutdownPending:
+ return (tcb_ != nullptr && !t1_init_->is_running() &&
+ !t1_cookie_->is_running() && !t2_shutdown_->is_running());
+ case State::kShutdownSent:
+ return (tcb_ != nullptr && !t1_init_->is_running() &&
+ !t1_cookie_->is_running() && t2_shutdown_->is_running());
+ case State::kShutdownReceived:
+ return (tcb_ != nullptr && !t1_init_->is_running() &&
+ !t1_cookie_->is_running() && !t2_shutdown_->is_running());
+ case State::kShutdownAckSent:
+ return (tcb_ != nullptr && !t1_init_->is_running() &&
+ !t1_cookie_->is_running() && t2_shutdown_->is_running());
+ }
+}
+
+constexpr absl::string_view DcSctpSocket::ToString(DcSctpSocket::State state) {
+ switch (state) {
+ case DcSctpSocket::State::kClosed:
+ return "CLOSED";
+ case DcSctpSocket::State::kCookieWait:
+ return "COOKIE_WAIT";
+ case DcSctpSocket::State::kCookieEchoed:
+ return "COOKIE_ECHOED";
+ case DcSctpSocket::State::kEstablished:
+ return "ESTABLISHED";
+ case DcSctpSocket::State::kShutdownPending:
+ return "SHUTDOWN_PENDING";
+ case DcSctpSocket::State::kShutdownSent:
+ return "SHUTDOWN_SENT";
+ case DcSctpSocket::State::kShutdownReceived:
+ return "SHUTDOWN_RECEIVED";
+ case DcSctpSocket::State::kShutdownAckSent:
+ return "SHUTDOWN_ACK_SENT";
+ }
+}
+
+void DcSctpSocket::SetState(State state, absl::string_view reason) {
+ if (state_ != state) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Socket state changed from "
+ << ToString(state_) << " to " << ToString(state)
+ << " due to " << reason;
+ state_ = state;
+ }
+}
+
+void DcSctpSocket::SendInit() {
+ Parameters::Builder params_builder;
+ AddCapabilityParameters(options_, params_builder);
+ InitChunk init(/*initiate_tag=*/connect_params_.verification_tag,
+ /*a_rwnd=*/options_.max_receiver_window_buffer_size,
+ options_.announced_maximum_outgoing_streams,
+ options_.announced_maximum_incoming_streams,
+ connect_params_.initial_tsn, params_builder.Build());
+ SctpPacket::Builder b(VerificationTag(0), options_);
+ b.Add(init);
+ packet_sender_.Send(b);
+}
+
+void DcSctpSocket::MakeConnectionParameters() {
+ VerificationTag new_verification_tag(
+ callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag));
+ TSN initial_tsn(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn));
+ connect_params_.initial_tsn = initial_tsn;
+ connect_params_.verification_tag = new_verification_tag;
+}
+
+void DcSctpSocket::Connect() {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
+ if (state_ == State::kClosed) {
+ MakeConnectionParameters();
+ RTC_DLOG(LS_INFO)
+ << log_prefix()
+ << rtc::StringFormat(
+ "Connecting. my_verification_tag=%08x, my_initial_tsn=%u",
+ *connect_params_.verification_tag, *connect_params_.initial_tsn);
+ SendInit();
+ t1_init_->Start();
+ SetState(State::kCookieWait, "Connect called");
+ } else {
+ RTC_DLOG(LS_WARNING) << log_prefix()
+ << "Called Connect on a socket that is not closed";
+ }
+ RTC_DCHECK(IsConsistent());
+}
+
+void DcSctpSocket::CreateTransmissionControlBlock(
+ const Capabilities& capabilities,
+ VerificationTag my_verification_tag,
+ TSN my_initial_tsn,
+ VerificationTag peer_verification_tag,
+ TSN peer_initial_tsn,
+ size_t a_rwnd,
+ TieTag tie_tag) {
+ metrics_.uses_message_interleaving = capabilities.message_interleaving;
+ tcb_ = std::make_unique<TransmissionControlBlock>(
+ timer_manager_, log_prefix_, options_, capabilities, callbacks_,
+ send_queue_, my_verification_tag, my_initial_tsn, peer_verification_tag,
+ peer_initial_tsn, a_rwnd, tie_tag, packet_sender_,
+ [this]() { return state_ == State::kEstablished; });
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Created TCB: " << tcb_->ToString();
+}
+
+void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
+ if (state_ != State::kClosed) {
+ callbacks_.OnError(ErrorKind::kUnsupportedOperation,
+ "Only closed socket can be restored from state");
+ } else {
+ if (state.socket_state ==
+ DcSctpSocketHandoverState::SocketState::kConnected) {
+ VerificationTag my_verification_tag =
+ VerificationTag(state.my_verification_tag);
+ connect_params_.verification_tag = my_verification_tag;
+
+ Capabilities capabilities;
+ capabilities.partial_reliability = state.capabilities.partial_reliability;
+ capabilities.message_interleaving =
+ state.capabilities.message_interleaving;
+ capabilities.reconfig = state.capabilities.reconfig;
+
+ send_queue_.RestoreFromState(state);
+
+ CreateTransmissionControlBlock(
+ capabilities, my_verification_tag, TSN(state.my_initial_tsn),
+ VerificationTag(state.peer_verification_tag),
+ TSN(state.peer_initial_tsn), static_cast<size_t>(0),
+ TieTag(state.tie_tag));
+
+ tcb_->RestoreFromState(state);
+
+ SetState(State::kEstablished, "restored from handover state");
+ callbacks_.OnConnected();
+ }
+ }
+
+ RTC_DCHECK(IsConsistent());
+}
+
+void DcSctpSocket::Shutdown() {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
+ if (tcb_ != nullptr) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "Upon receipt of the SHUTDOWN primitive from its upper layer, the
+ // endpoint enters the SHUTDOWN-PENDING state and remains there until all
+ // outstanding data has been acknowledged by its peer."
+
+ // TODO(webrtc:12739): Remove this check, as it just hides the problem that
+ // the socket can transition from ShutdownSent to ShutdownPending, or
+ // ShutdownAckSent to ShutdownPending which is illegal.
+ if (state_ != State::kShutdownSent && state_ != State::kShutdownAckSent) {
+ SetState(State::kShutdownPending, "Shutdown called");
+ t1_init_->Stop();
+ t1_cookie_->Stop();
+ MaybeSendShutdownOrAck();
+ }
+ } else {
+ // Connection closed before even starting to connect, or during the initial
+ // connection phase. There is no outstanding data, so the socket can just
+ // be closed (stopping any connection timers, if any), as this is the
+ // client's intention, by calling Shutdown.
+ InternalClose(ErrorKind::kNoError, "");
+ }
+ RTC_DCHECK(IsConsistent());
+}
+
+void DcSctpSocket::Close() {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
+ if (state_ != State::kClosed) {
+ if (tcb_ != nullptr) {
+ SctpPacket::Builder b = tcb_->PacketBuilder();
+ b.Add(AbortChunk(/*filled_in_verification_tag=*/true,
+ Parameters::Builder()
+ .Add(UserInitiatedAbortCause("Close called"))
+ .Build()));
+ packet_sender_.Send(b);
+ }
+ InternalClose(ErrorKind::kNoError, "");
+ } else {
+ RTC_DLOG(LS_INFO) << log_prefix() << "Called Close on a closed socket";
+ }
+ RTC_DCHECK(IsConsistent());
+}
+
+void DcSctpSocket::CloseConnectionBecauseOfTooManyTransmissionErrors() {
+ packet_sender_.Send(tcb_->PacketBuilder().Add(AbortChunk(
+ true, Parameters::Builder()
+ .Add(UserInitiatedAbortCause("Too many retransmissions"))
+ .Build())));
+ InternalClose(ErrorKind::kTooManyRetries, "Too many retransmissions");
+}
+
+void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) {
+ if (state_ != State::kClosed) {
+ t1_init_->Stop();
+ t1_cookie_->Stop();
+ t2_shutdown_->Stop();
+ tcb_ = nullptr;
+
+ if (error == ErrorKind::kNoError) {
+ callbacks_.OnClosed();
+ } else {
+ callbacks_.OnAborted(error, message);
+ }
+ SetState(State::kClosed, message);
+ }
+ // This method's purpose is to abort/close and make it consistent by ensuring
+ // that e.g. all timers really are stopped.
+ RTC_DCHECK(IsConsistent());
+}
+
+void DcSctpSocket::SetStreamPriority(StreamID stream_id,
+ StreamPriority priority) {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ send_queue_.SetStreamPriority(stream_id, priority);
+}
+StreamPriority DcSctpSocket::GetStreamPriority(StreamID stream_id) const {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ return send_queue_.GetStreamPriority(stream_id);
+}
+
+SendStatus DcSctpSocket::Send(DcSctpMessage message,
+ const SendOptions& send_options) {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+ LifecycleId lifecycle_id = send_options.lifecycle_id;
+
+ if (message.payload().empty()) {
+ if (lifecycle_id.IsSet()) {
+ callbacks_.OnLifecycleEnd(lifecycle_id);
+ }
+ callbacks_.OnError(ErrorKind::kProtocolViolation,
+ "Unable to send empty message");
+ return SendStatus::kErrorMessageEmpty;
+ }
+ if (message.payload().size() > options_.max_message_size) {
+ if (lifecycle_id.IsSet()) {
+ callbacks_.OnLifecycleEnd(lifecycle_id);
+ }
+ callbacks_.OnError(ErrorKind::kProtocolViolation,
+ "Unable to send too large message");
+ return SendStatus::kErrorMessageTooLarge;
+ }
+ if (state_ == State::kShutdownPending || state_ == State::kShutdownSent ||
+ state_ == State::kShutdownReceived || state_ == State::kShutdownAckSent) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "An endpoint should reject any new data request from its upper layer
+ // if it is in the SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, or
+ // SHUTDOWN-ACK-SENT state."
+ if (lifecycle_id.IsSet()) {
+ callbacks_.OnLifecycleEnd(lifecycle_id);
+ }
+ callbacks_.OnError(ErrorKind::kWrongSequence,
+ "Unable to send message as the socket is shutting down");
+ return SendStatus::kErrorShuttingDown;
+ }
+ if (send_queue_.IsFull()) {
+ if (lifecycle_id.IsSet()) {
+ callbacks_.OnLifecycleEnd(lifecycle_id);
+ }
+ callbacks_.OnError(ErrorKind::kResourceExhaustion,
+ "Unable to send message as the send queue is full");
+ return SendStatus::kErrorResourceExhaustion;
+ }
+
+ TimeMs now = callbacks_.TimeMillis();
+ ++metrics_.tx_messages_count;
+ send_queue_.Add(now, std::move(message), send_options);
+ if (tcb_ != nullptr) {
+ tcb_->SendBufferedPackets(now);
+ }
+
+ RTC_DCHECK(IsConsistent());
+ return SendStatus::kSuccess;
+}
+
+ResetStreamsStatus DcSctpSocket::ResetStreams(
+ rtc::ArrayView<const StreamID> outgoing_streams) {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
+ if (tcb_ == nullptr) {
+ callbacks_.OnError(ErrorKind::kWrongSequence,
+ "Can't reset streams as the socket is not connected");
+ return ResetStreamsStatus::kNotConnected;
+ }
+ if (!tcb_->capabilities().reconfig) {
+ callbacks_.OnError(ErrorKind::kUnsupportedOperation,
+ "Can't reset streams as the peer doesn't support it");
+ return ResetStreamsStatus::kNotSupported;
+ }
+
+ tcb_->stream_reset_handler().ResetStreams(outgoing_streams);
+ MaybeSendResetStreamsRequest();
+
+ RTC_DCHECK(IsConsistent());
+ return ResetStreamsStatus::kPerformed;
+}
+
+SocketState DcSctpSocket::state() const {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ switch (state_) {
+ case State::kClosed:
+ return SocketState::kClosed;
+ case State::kCookieWait:
+ case State::kCookieEchoed:
+ return SocketState::kConnecting;
+ case State::kEstablished:
+ return SocketState::kConnected;
+ case State::kShutdownPending:
+ case State::kShutdownSent:
+ case State::kShutdownReceived:
+ case State::kShutdownAckSent:
+ return SocketState::kShuttingDown;
+ }
+}
+
+void DcSctpSocket::SetMaxMessageSize(size_t max_message_size) {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ options_.max_message_size = max_message_size;
+}
+
+size_t DcSctpSocket::buffered_amount(StreamID stream_id) const {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ return send_queue_.buffered_amount(stream_id);
+}
+
+size_t DcSctpSocket::buffered_amount_low_threshold(StreamID stream_id) const {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ return send_queue_.buffered_amount_low_threshold(stream_id);
+}
+
+void DcSctpSocket::SetBufferedAmountLowThreshold(StreamID stream_id,
+ size_t bytes) {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ send_queue_.SetBufferedAmountLowThreshold(stream_id, bytes);
+}
+
+absl::optional<Metrics> DcSctpSocket::GetMetrics() const {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+
+ if (tcb_ == nullptr) {
+ return absl::nullopt;
+ }
+
+ Metrics metrics = metrics_;
+ metrics.cwnd_bytes = tcb_->cwnd();
+ metrics.srtt_ms = tcb_->current_srtt().value();
+ size_t packet_payload_size =
+ options_.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
+ metrics.unack_data_count =
+ tcb_->retransmission_queue().outstanding_items() +
+ (send_queue_.total_buffered_amount() + packet_payload_size - 1) /
+ packet_payload_size;
+ metrics.peer_rwnd_bytes = tcb_->retransmission_queue().rwnd();
+
+ return metrics;
+}
+
+void DcSctpSocket::MaybeSendShutdownOnPacketReceived(const SctpPacket& packet) {
+ if (state_ == State::kShutdownSent) {
+ bool has_data_chunk =
+ std::find_if(packet.descriptors().begin(), packet.descriptors().end(),
+ [](const SctpPacket::ChunkDescriptor& descriptor) {
+ return descriptor.type == DataChunk::kType;
+ }) != packet.descriptors().end();
+ if (has_data_chunk) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "While in the SHUTDOWN-SENT state, the SHUTDOWN sender MUST immediately
+ // respond to each received packet containing one or more DATA chunks with
+ // a SHUTDOWN chunk and restart the T2-shutdown timer.""
+ SendShutdown();
+ t2_shutdown_->set_duration(tcb_->current_rto());
+ t2_shutdown_->Start();
+ }
+ }
+}
+
+void DcSctpSocket::MaybeSendResetStreamsRequest() {
+ absl::optional<ReConfigChunk> reconfig =
+ tcb_->stream_reset_handler().MakeStreamResetRequest();
+ if (reconfig.has_value()) {
+ SctpPacket::Builder builder = tcb_->PacketBuilder();
+ builder.Add(*reconfig);
+ packet_sender_.Send(builder);
+ }
+}
+
+bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) {
+ const CommonHeader& header = packet.common_header();
+ VerificationTag my_verification_tag =
+ tcb_ != nullptr ? tcb_->my_verification_tag() : VerificationTag(0);
+
+ if (header.verification_tag == VerificationTag(0)) {
+ if (packet.descriptors().size() == 1 &&
+ packet.descriptors()[0].type == InitChunk::kType) {
+ // https://tools.ietf.org/html/rfc4960#section-8.5.1
+ // "When an endpoint receives an SCTP packet with the Verification Tag
+ // set to 0, it should verify that the packet contains only an INIT chunk.
+ // Otherwise, the receiver MUST silently discard the packet.""
+ return true;
+ }
+ callbacks_.OnError(
+ ErrorKind::kParseFailed,
+ "Only a single INIT chunk can be present in packets sent on "
+ "verification_tag = 0");
+ return false;
+ }
+
+ if (packet.descriptors().size() == 1 &&
+ packet.descriptors()[0].type == AbortChunk::kType) {
+ // https://tools.ietf.org/html/rfc4960#section-8.5.1
+ // "The receiver of an ABORT MUST accept the packet if the Verification
+ // Tag field of the packet matches its own tag and the T bit is not set OR
+ // if it is set to its peer's tag and the T bit is set in the Chunk Flags.
+ // Otherwise, the receiver MUST silently discard the packet and take no
+ // further action."
+ bool t_bit = (packet.descriptors()[0].flags & 0x01) != 0;
+ if (t_bit && tcb_ == nullptr) {
+ // Can't verify the tag - assume it's okey.
+ return true;
+ }
+ if ((!t_bit && header.verification_tag == my_verification_tag) ||
+ (t_bit && header.verification_tag == tcb_->peer_verification_tag())) {
+ return true;
+ }
+ callbacks_.OnError(ErrorKind::kParseFailed,
+ "ABORT chunk verification tag was wrong");
+ return false;
+ }
+
+ if (packet.descriptors()[0].type == InitAckChunk::kType) {
+ if (header.verification_tag == connect_params_.verification_tag) {
+ return true;
+ }
+ callbacks_.OnError(
+ ErrorKind::kParseFailed,
+ rtc::StringFormat(
+ "Packet has invalid verification tag: %08x, expected %08x",
+ *header.verification_tag, *connect_params_.verification_tag));
+ return false;
+ }
+
+ if (packet.descriptors()[0].type == CookieEchoChunk::kType) {
+ // Handled in chunk handler (due to RFC 4960, section 5.2.4).
+ return true;
+ }
+
+ if (packet.descriptors().size() == 1 &&
+ packet.descriptors()[0].type == ShutdownCompleteChunk::kType) {
+ // https://tools.ietf.org/html/rfc4960#section-8.5.1
+ // "The receiver of a SHUTDOWN COMPLETE shall accept the packet if the
+ // Verification Tag field of the packet matches its own tag and the T bit is
+ // not set OR if it is set to its peer's tag and the T bit is set in the
+ // Chunk Flags. Otherwise, the receiver MUST silently discard the packet
+ // and take no further action."
+ bool t_bit = (packet.descriptors()[0].flags & 0x01) != 0;
+ if (t_bit && tcb_ == nullptr) {
+ // Can't verify the tag - assume it's okey.
+ return true;
+ }
+ if ((!t_bit && header.verification_tag == my_verification_tag) ||
+ (t_bit && header.verification_tag == tcb_->peer_verification_tag())) {
+ return true;
+ }
+ callbacks_.OnError(ErrorKind::kParseFailed,
+ "SHUTDOWN_COMPLETE chunk verification tag was wrong");
+ return false;
+ }
+
+ // https://tools.ietf.org/html/rfc4960#section-8.5
+ // "When receiving an SCTP packet, the endpoint MUST ensure that the value
+ // in the Verification Tag field of the received SCTP packet matches its own
+ // tag. If the received Verification Tag value does not match the receiver's
+ // own tag value, the receiver shall silently discard the packet and shall not
+ // process it any further..."
+ if (header.verification_tag == my_verification_tag) {
+ return true;
+ }
+
+ callbacks_.OnError(
+ ErrorKind::kParseFailed,
+ rtc::StringFormat(
+ "Packet has invalid verification tag: %08x, expected %08x",
+ *header.verification_tag, *my_verification_tag));
+ return false;
+}
+
+void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
+ timer_manager_.HandleTimeout(timeout_id);
+
+ if (tcb_ != nullptr && tcb_->HasTooManyTxErrors()) {
+ // Tearing down the TCB has to be done outside the handlers.
+ CloseConnectionBecauseOfTooManyTransmissionErrors();
+ }
+
+ RTC_DCHECK(IsConsistent());
+}
+
+void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
+ ++metrics_.rx_packets_count;
+
+ if (packet_observer_ != nullptr) {
+ packet_observer_->OnReceivedPacket(callbacks_.TimeMillis(), data);
+ }
+
+ absl::optional<SctpPacket> packet =
+ SctpPacket::Parse(data, options_.disable_checksum_verification);
+ if (!packet.has_value()) {
+ // https://tools.ietf.org/html/rfc4960#section-6.8
+ // "The default procedure for handling invalid SCTP packets is to
+ // silently discard them."
+ callbacks_.OnError(ErrorKind::kParseFailed,
+ "Failed to parse received SCTP packet");
+ RTC_DCHECK(IsConsistent());
+ return;
+ }
+
+ if (RTC_DLOG_IS_ON) {
+ for (const auto& descriptor : packet->descriptors()) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received "
+ << DebugConvertChunkToString(descriptor.data);
+ }
+ }
+
+ if (!ValidatePacket(*packet)) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Packet failed verification tag check - dropping";
+ RTC_DCHECK(IsConsistent());
+ return;
+ }
+
+ MaybeSendShutdownOnPacketReceived(*packet);
+
+ for (const auto& descriptor : packet->descriptors()) {
+ if (!Dispatch(packet->common_header(), descriptor)) {
+ break;
+ }
+ }
+
+ if (tcb_ != nullptr) {
+ tcb_->data_tracker().ObservePacketEnd();
+ tcb_->MaybeSendSack();
+ }
+
+ RTC_DCHECK(IsConsistent());
+}
+
+void DcSctpSocket::DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload) {
+ auto packet = SctpPacket::Parse(payload);
+ RTC_DCHECK(packet.has_value());
+
+ for (const auto& desc : packet->descriptors()) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Sent "
+ << DebugConvertChunkToString(desc.data);
+ }
+}
+
+bool DcSctpSocket::Dispatch(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ switch (descriptor.type) {
+ case DataChunk::kType:
+ HandleData(header, descriptor);
+ break;
+ case InitChunk::kType:
+ HandleInit(header, descriptor);
+ break;
+ case InitAckChunk::kType:
+ HandleInitAck(header, descriptor);
+ break;
+ case SackChunk::kType:
+ HandleSack(header, descriptor);
+ break;
+ case HeartbeatRequestChunk::kType:
+ HandleHeartbeatRequest(header, descriptor);
+ break;
+ case HeartbeatAckChunk::kType:
+ HandleHeartbeatAck(header, descriptor);
+ break;
+ case AbortChunk::kType:
+ HandleAbort(header, descriptor);
+ break;
+ case ErrorChunk::kType:
+ HandleError(header, descriptor);
+ break;
+ case CookieEchoChunk::kType:
+ HandleCookieEcho(header, descriptor);
+ break;
+ case CookieAckChunk::kType:
+ HandleCookieAck(header, descriptor);
+ break;
+ case ShutdownChunk::kType:
+ HandleShutdown(header, descriptor);
+ break;
+ case ShutdownAckChunk::kType:
+ HandleShutdownAck(header, descriptor);
+ break;
+ case ShutdownCompleteChunk::kType:
+ HandleShutdownComplete(header, descriptor);
+ break;
+ case ReConfigChunk::kType:
+ HandleReconfig(header, descriptor);
+ break;
+ case ForwardTsnChunk::kType:
+ HandleForwardTsn(header, descriptor);
+ break;
+ case IDataChunk::kType:
+ HandleIData(header, descriptor);
+ break;
+ case IForwardTsnChunk::kType:
+ HandleIForwardTsn(header, descriptor);
+ break;
+ default:
+ return HandleUnrecognizedChunk(descriptor);
+ }
+ return true;
+}
+
+bool DcSctpSocket::HandleUnrecognizedChunk(
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ bool report_as_error = (descriptor.type & 0x40) != 0;
+ bool continue_processing = (descriptor.type & 0x80) != 0;
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received unknown chunk: "
+ << static_cast<int>(descriptor.type);
+ if (report_as_error) {
+ rtc::StringBuilder sb;
+ sb << "Received unknown chunk of type: "
+ << static_cast<int>(descriptor.type) << " with report-error bit set";
+ callbacks_.OnError(ErrorKind::kParseFailed, sb.str());
+ RTC_DLOG(LS_VERBOSE)
+ << log_prefix()
+ << "Unknown chunk, with type indicating it should be reported.";
+
+ // https://tools.ietf.org/html/rfc4960#section-3.2
+ // "... report in an ERROR chunk using the 'Unrecognized Chunk Type'
+ // cause."
+ if (tcb_ != nullptr) {
+ // Need TCB - this chunk must be sent with a correct verification tag.
+ packet_sender_.Send(tcb_->PacketBuilder().Add(
+ ErrorChunk(Parameters::Builder()
+ .Add(UnrecognizedChunkTypeCause(std::vector<uint8_t>(
+ descriptor.data.begin(), descriptor.data.end())))
+ .Build())));
+ }
+ }
+ if (!continue_processing) {
+ // https://tools.ietf.org/html/rfc4960#section-3.2
+ // "Stop processing this SCTP packet and discard it, do not process any
+ // further chunks within it."
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Unknown chunk, with type indicating not to "
+ "process any further chunks";
+ }
+
+ return continue_processing;
+}
+
+absl::optional<DurationMs> DcSctpSocket::OnInitTimerExpiry() {
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t1_init_->name()
+ << " has expired: " << t1_init_->expiration_count()
+ << "/" << t1_init_->options().max_restarts.value_or(-1);
+ RTC_DCHECK(state_ == State::kCookieWait);
+
+ if (t1_init_->is_running()) {
+ SendInit();
+ } else {
+ InternalClose(ErrorKind::kTooManyRetries, "No INIT_ACK received");
+ }
+ RTC_DCHECK(IsConsistent());
+ return absl::nullopt;
+}
+
+absl::optional<DurationMs> DcSctpSocket::OnCookieTimerExpiry() {
+ // https://tools.ietf.org/html/rfc4960#section-4
+ // "If the T1-cookie timer expires, the endpoint MUST retransmit COOKIE
+ // ECHO and restart the T1-cookie timer without changing state. This MUST
+ // be repeated up to 'Max.Init.Retransmits' times. After that, the endpoint
+ // MUST abort the initialization process and report the error to the SCTP
+ // user."
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t1_cookie_->name()
+ << " has expired: " << t1_cookie_->expiration_count()
+ << "/"
+ << t1_cookie_->options().max_restarts.value_or(-1);
+
+ RTC_DCHECK(state_ == State::kCookieEchoed);
+
+ if (t1_cookie_->is_running()) {
+ tcb_->SendBufferedPackets(callbacks_.TimeMillis());
+ } else {
+ InternalClose(ErrorKind::kTooManyRetries, "No COOKIE_ACK received");
+ }
+
+ RTC_DCHECK(IsConsistent());
+ return absl::nullopt;
+}
+
+absl::optional<DurationMs> DcSctpSocket::OnShutdownTimerExpiry() {
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t2_shutdown_->name()
+ << " has expired: " << t2_shutdown_->expiration_count()
+ << "/"
+ << t2_shutdown_->options().max_restarts.value_or(-1);
+
+ if (!t2_shutdown_->is_running()) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "An endpoint should limit the number of retransmissions of the SHUTDOWN
+ // chunk to the protocol parameter 'Association.Max.Retrans'. If this
+ // threshold is exceeded, the endpoint should destroy the TCB..."
+
+ packet_sender_.Send(tcb_->PacketBuilder().Add(
+ AbortChunk(true, Parameters::Builder()
+ .Add(UserInitiatedAbortCause(
+ "Too many retransmissions of SHUTDOWN"))
+ .Build())));
+
+ InternalClose(ErrorKind::kTooManyRetries, "No SHUTDOWN_ACK received");
+ RTC_DCHECK(IsConsistent());
+ return absl::nullopt;
+ }
+
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "If the timer expires, the endpoint must resend the SHUTDOWN with the
+ // updated last sequential TSN received from its peer."
+ SendShutdown();
+ RTC_DCHECK(IsConsistent());
+ return tcb_->current_rto();
+}
+
+void DcSctpSocket::OnSentPacket(rtc::ArrayView<const uint8_t> packet,
+ SendPacketStatus status) {
+ // The packet observer is invoked even if the packet was failed to be sent, to
+ // indicate an attempt was made.
+ if (packet_observer_ != nullptr) {
+ packet_observer_->OnSentPacket(callbacks_.TimeMillis(), packet);
+ }
+
+ if (status == SendPacketStatus::kSuccess) {
+ if (RTC_DLOG_IS_ON) {
+ DebugPrintOutgoing(packet);
+ }
+
+ // The heartbeat interval timer is restarted for every sent packet, to
+ // fire when the outgoing channel is inactive.
+ if (tcb_ != nullptr) {
+ tcb_->heartbeat_handler().RestartTimer();
+ }
+
+ ++metrics_.tx_packets_count;
+ }
+}
+
+bool DcSctpSocket::ValidateHasTCB() {
+ if (tcb_ != nullptr) {
+ return true;
+ }
+
+ callbacks_.OnError(
+ ErrorKind::kNotConnected,
+ "Received unexpected commands on socket that is not connected");
+ return false;
+}
+
+void DcSctpSocket::ReportFailedToParseChunk(int chunk_type) {
+ rtc::StringBuilder sb;
+ sb << "Failed to parse chunk of type: " << chunk_type;
+ callbacks_.OnError(ErrorKind::kParseFailed, sb.str());
+}
+
+void DcSctpSocket::HandleData(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<DataChunk> chunk = DataChunk::Parse(descriptor.data);
+ if (ValidateParseSuccess(chunk) && ValidateHasTCB()) {
+ HandleDataCommon(*chunk);
+ }
+}
+
+void DcSctpSocket::HandleIData(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<IDataChunk> chunk = IDataChunk::Parse(descriptor.data);
+ if (ValidateParseSuccess(chunk) && ValidateHasTCB()) {
+ HandleDataCommon(*chunk);
+ }
+}
+
+void DcSctpSocket::HandleDataCommon(AnyDataChunk& chunk) {
+ TSN tsn = chunk.tsn();
+ AnyDataChunk::ImmediateAckFlag immediate_ack = chunk.options().immediate_ack;
+ Data data = std::move(chunk).extract();
+
+ if (data.payload.empty()) {
+ // Empty DATA chunks are illegal.
+ packet_sender_.Send(tcb_->PacketBuilder().Add(
+ ErrorChunk(Parameters::Builder().Add(NoUserDataCause(tsn)).Build())));
+ callbacks_.OnError(ErrorKind::kProtocolViolation,
+ "Received DATA chunk with no user data");
+ return;
+ }
+
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Handle DATA, queue_size="
+ << tcb_->reassembly_queue().queued_bytes()
+ << ", water_mark="
+ << tcb_->reassembly_queue().watermark_bytes()
+ << ", full=" << tcb_->reassembly_queue().is_full()
+ << ", above="
+ << tcb_->reassembly_queue().is_above_watermark();
+
+ if (tcb_->reassembly_queue().is_full()) {
+ // If the reassembly queue is full, there is nothing that can be done. The
+ // specification only allows dropping gap-ack-blocks, and that's not
+ // likely to help as the socket has been trying to fill gaps since the
+ // watermark was reached.
+ packet_sender_.Send(tcb_->PacketBuilder().Add(AbortChunk(
+ true, Parameters::Builder().Add(OutOfResourceErrorCause()).Build())));
+ InternalClose(ErrorKind::kResourceExhaustion,
+ "Reassembly Queue is exhausted");
+ return;
+ }
+
+ if (tcb_->reassembly_queue().is_above_watermark()) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Is above high watermark";
+ // If the reassembly queue is above its high watermark, only accept data
+ // chunks that increase its cumulative ack tsn in an attempt to fill gaps
+ // to deliver messages.
+ if (!tcb_->data_tracker().will_increase_cum_ack_tsn(tsn)) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Rejected data because of exceeding watermark";
+ tcb_->data_tracker().ForceImmediateSack();
+ return;
+ }
+ }
+
+ if (!tcb_->data_tracker().IsTSNValid(tsn)) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Rejected data because of failing TSN validity";
+ return;
+ }
+
+ if (tcb_->data_tracker().Observe(tsn, immediate_ack)) {
+ tcb_->reassembly_queue().MaybeResetStreamsDeferred(
+ tcb_->data_tracker().last_cumulative_acked_tsn());
+ tcb_->reassembly_queue().Add(tsn, std::move(data));
+ DeliverReassembledMessages();
+ }
+}
+
+void DcSctpSocket::HandleInit(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<InitChunk> chunk = InitChunk::Parse(descriptor.data);
+ if (!ValidateParseSuccess(chunk)) {
+ return;
+ }
+
+ if (chunk->initiate_tag() == VerificationTag(0) ||
+ chunk->nbr_outbound_streams() == 0 || chunk->nbr_inbound_streams() == 0) {
+ // https://tools.ietf.org/html/rfc4960#section-3.3.2
+ // "If the value of the Initiate Tag in a received INIT chunk is found
+ // to be 0, the receiver MUST treat it as an error and close the
+ // association by transmitting an ABORT."
+
+ // "A receiver of an INIT with the OS value set to 0 SHOULD abort the
+ // association."
+
+ // "A receiver of an INIT with the MIS value of 0 SHOULD abort the
+ // association."
+
+ packet_sender_.Send(
+ SctpPacket::Builder(VerificationTag(0), options_)
+ .Add(AbortChunk(
+ /*filled_in_verification_tag=*/false,
+ Parameters::Builder()
+ .Add(ProtocolViolationCause("INIT malformed"))
+ .Build())));
+ InternalClose(ErrorKind::kProtocolViolation, "Received invalid INIT");
+ return;
+ }
+
+ if (state_ == State::kShutdownAckSent) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "If an endpoint is in the SHUTDOWN-ACK-SENT state and receives an
+ // INIT chunk (e.g., if the SHUTDOWN COMPLETE was lost) with source and
+ // destination transport addresses (either in the IP addresses or in the
+ // INIT chunk) that belong to this association, it should discard the INIT
+ // chunk and retransmit the SHUTDOWN ACK chunk."
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Received Init indicating lost ShutdownComplete";
+ SendShutdownAck();
+ return;
+ }
+
+ TieTag tie_tag(0);
+ if (state_ == State::kClosed) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Received Init in closed state (normal)";
+
+ MakeConnectionParameters();
+ } else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) {
+ // https://tools.ietf.org/html/rfc4960#section-5.2.1
+ // "This usually indicates an initialization collision, i.e., each
+ // endpoint is attempting, at about the same time, to establish an
+ // association with the other endpoint. Upon receipt of an INIT in the
+ // COOKIE-WAIT state, an endpoint MUST respond with an INIT ACK using the
+ // same parameters it sent in its original INIT chunk (including its
+ // Initiate Tag, unchanged). When responding, the endpoint MUST send the
+ // INIT ACK back to the same address that the original INIT (sent by this
+ // endpoint) was sent."
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Received Init indicating simultaneous connections";
+ } else {
+ RTC_DCHECK(tcb_ != nullptr);
+ // https://tools.ietf.org/html/rfc4960#section-5.2.2
+ // "The outbound SCTP packet containing this INIT ACK MUST carry a
+ // Verification Tag value equal to the Initiate Tag found in the
+ // unexpected INIT. And the INIT ACK MUST contain a new Initiate Tag
+ // (randomly generated; see Section 5.3.1). Other parameters for the
+ // endpoint SHOULD be copied from the existing parameters of the
+ // association (e.g., number of outbound streams) into the INIT ACK and
+ // cookie."
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Received Init indicating restarted connection";
+ // Create a new verification tag - different from the previous one.
+ for (int tries = 0; tries < 10; ++tries) {
+ connect_params_.verification_tag = VerificationTag(
+ callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag));
+ if (connect_params_.verification_tag != tcb_->my_verification_tag()) {
+ break;
+ }
+ }
+
+ // Make the initial TSN make a large jump, so that there is no overlap
+ // with the old and new association.
+ connect_params_.initial_tsn =
+ TSN(*tcb_->retransmission_queue().next_tsn() + 1000000);
+ tie_tag = tcb_->tie_tag();
+ }
+
+ RTC_DLOG(LS_VERBOSE)
+ << log_prefix()
+ << rtc::StringFormat(
+ "Proceeding with connection. my_verification_tag=%08x, "
+ "my_initial_tsn=%u, peer_verification_tag=%08x, "
+ "peer_initial_tsn=%u",
+ *connect_params_.verification_tag, *connect_params_.initial_tsn,
+ *chunk->initiate_tag(), *chunk->initial_tsn());
+
+ Capabilities capabilities = GetCapabilities(options_, chunk->parameters());
+
+ SctpPacket::Builder b(chunk->initiate_tag(), options_);
+ Parameters::Builder params_builder =
+ Parameters::Builder().Add(StateCookieParameter(
+ StateCookie(chunk->initiate_tag(), chunk->initial_tsn(),
+ chunk->a_rwnd(), tie_tag, capabilities)
+ .Serialize()));
+ AddCapabilityParameters(options_, params_builder);
+
+ InitAckChunk init_ack(/*initiate_tag=*/connect_params_.verification_tag,
+ options_.max_receiver_window_buffer_size,
+ options_.announced_maximum_outgoing_streams,
+ options_.announced_maximum_incoming_streams,
+ connect_params_.initial_tsn, params_builder.Build());
+ b.Add(init_ack);
+ packet_sender_.Send(b);
+}
+
+void DcSctpSocket::HandleInitAck(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<InitAckChunk> chunk = InitAckChunk::Parse(descriptor.data);
+ if (!ValidateParseSuccess(chunk)) {
+ return;
+ }
+
+ if (state_ != State::kCookieWait) {
+ // https://tools.ietf.org/html/rfc4960#section-5.2.3
+ // "If an INIT ACK is received by an endpoint in any state other than
+ // the COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk."
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Received INIT_ACK in unexpected state";
+ return;
+ }
+
+ auto cookie = chunk->parameters().get<StateCookieParameter>();
+ if (!cookie.has_value()) {
+ packet_sender_.Send(
+ SctpPacket::Builder(connect_params_.verification_tag, options_)
+ .Add(AbortChunk(
+ /*filled_in_verification_tag=*/false,
+ Parameters::Builder()
+ .Add(ProtocolViolationCause("INIT-ACK malformed"))
+ .Build())));
+ InternalClose(ErrorKind::kProtocolViolation,
+ "InitAck chunk doesn't contain a cookie");
+ return;
+ }
+ Capabilities capabilities = GetCapabilities(options_, chunk->parameters());
+ t1_init_->Stop();
+
+ metrics_.peer_implementation = DeterminePeerImplementation(cookie->data());
+
+ // If the connection is re-established (peer restarted, but re-used old
+ // connection), make sure that all message identifiers are reset and any
+ // partly sent message is re-sent in full. The same is true when the socket
+ // is closed and later re-opened, which never happens in WebRTC, but is a
+ // valid operation on the SCTP level. Note that in case of handover, the
+ // send queue is already re-configured, and shouldn't be reset.
+ send_queue_.Reset();
+
+ CreateTransmissionControlBlock(capabilities, connect_params_.verification_tag,
+ connect_params_.initial_tsn,
+ chunk->initiate_tag(), chunk->initial_tsn(),
+ chunk->a_rwnd(), MakeTieTag(callbacks_));
+
+ SetState(State::kCookieEchoed, "INIT_ACK received");
+
+ // The connection isn't fully established just yet.
+ tcb_->SetCookieEchoChunk(CookieEchoChunk(cookie->data()));
+ tcb_->SendBufferedPackets(callbacks_.TimeMillis());
+ t1_cookie_->Start();
+}
+
+void DcSctpSocket::HandleCookieEcho(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<CookieEchoChunk> chunk =
+ CookieEchoChunk::Parse(descriptor.data);
+ if (!ValidateParseSuccess(chunk)) {
+ return;
+ }
+
+ absl::optional<StateCookie> cookie =
+ StateCookie::Deserialize(chunk->cookie());
+ if (!cookie.has_value()) {
+ callbacks_.OnError(ErrorKind::kParseFailed, "Failed to parse state cookie");
+ return;
+ }
+
+ if (tcb_ != nullptr) {
+ if (!HandleCookieEchoWithTCB(header, *cookie)) {
+ return;
+ }
+ } else {
+ if (header.verification_tag != connect_params_.verification_tag) {
+ callbacks_.OnError(
+ ErrorKind::kParseFailed,
+ rtc::StringFormat(
+ "Received CookieEcho with invalid verification tag: %08x, "
+ "expected %08x",
+ *header.verification_tag, *connect_params_.verification_tag));
+ return;
+ }
+ }
+
+ // The init timer can be running on simultaneous connections.
+ t1_init_->Stop();
+ t1_cookie_->Stop();
+ if (state_ != State::kEstablished) {
+ if (tcb_ != nullptr) {
+ tcb_->ClearCookieEchoChunk();
+ }
+ SetState(State::kEstablished, "COOKIE_ECHO received");
+ callbacks_.OnConnected();
+ }
+
+ if (tcb_ == nullptr) {
+ // If the connection is re-established (peer restarted, but re-used old
+ // connection), make sure that all message identifiers are reset and any
+ // partly sent message is re-sent in full. The same is true when the socket
+ // is closed and later re-opened, which never happens in WebRTC, but is a
+ // valid operation on the SCTP level. Note that in case of handover, the
+ // send queue is already re-configured, and shouldn't be reset.
+ send_queue_.Reset();
+
+ CreateTransmissionControlBlock(
+ cookie->capabilities(), connect_params_.verification_tag,
+ connect_params_.initial_tsn, cookie->initiate_tag(),
+ cookie->initial_tsn(), cookie->a_rwnd(), MakeTieTag(callbacks_));
+ }
+
+ SctpPacket::Builder b = tcb_->PacketBuilder();
+ b.Add(CookieAckChunk());
+
+ // https://tools.ietf.org/html/rfc4960#section-5.1
+ // "A COOKIE ACK chunk may be bundled with any pending DATA chunks (and/or
+ // SACK chunks), but the COOKIE ACK chunk MUST be the first chunk in the
+ // packet."
+ tcb_->SendBufferedPackets(b, callbacks_.TimeMillis());
+}
+
+bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header,
+ const StateCookie& cookie) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Handling CookieEchoChunk with TCB. local_tag="
+ << *tcb_->my_verification_tag()
+ << ", peer_tag=" << *header.verification_tag
+ << ", tcb_tag=" << *tcb_->peer_verification_tag()
+ << ", cookie_tag=" << *cookie.initiate_tag()
+ << ", local_tie_tag=" << *tcb_->tie_tag()
+ << ", peer_tie_tag=" << *cookie.tie_tag();
+ // https://tools.ietf.org/html/rfc4960#section-5.2.4
+ // "Handle a COOKIE ECHO when a TCB Exists"
+ if (header.verification_tag != tcb_->my_verification_tag() &&
+ tcb_->peer_verification_tag() != cookie.initiate_tag() &&
+ cookie.tie_tag() == tcb_->tie_tag()) {
+ // "A) In this case, the peer may have restarted."
+ if (state_ == State::kShutdownAckSent) {
+ // "If the endpoint is in the SHUTDOWN-ACK-SENT state and recognizes
+ // that the peer has restarted ... it MUST NOT set up a new association
+ // but instead resend the SHUTDOWN ACK and send an ERROR chunk with a
+ // "Cookie Received While Shutting Down" error cause to its peer."
+ SctpPacket::Builder b(cookie.initiate_tag(), options_);
+ b.Add(ShutdownAckChunk());
+ b.Add(ErrorChunk(Parameters::Builder()
+ .Add(CookieReceivedWhileShuttingDownCause())
+ .Build()));
+ packet_sender_.Send(b);
+ callbacks_.OnError(ErrorKind::kWrongSequence,
+ "Received COOKIE-ECHO while shutting down");
+ return false;
+ }
+
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Received COOKIE-ECHO indicating a restarted peer";
+
+ tcb_ = nullptr;
+ callbacks_.OnConnectionRestarted();
+ } else if (header.verification_tag == tcb_->my_verification_tag() &&
+ tcb_->peer_verification_tag() != cookie.initiate_tag()) {
+ // TODO(boivie): Handle the peer_tag == 0?
+ // "B) In this case, both sides may be attempting to start an
+ // association at about the same time, but the peer endpoint started its
+ // INIT after responding to the local endpoint's INIT."
+ RTC_DLOG(LS_VERBOSE)
+ << log_prefix()
+ << "Received COOKIE-ECHO indicating simultaneous connections";
+ tcb_ = nullptr;
+ } else if (header.verification_tag != tcb_->my_verification_tag() &&
+ tcb_->peer_verification_tag() == cookie.initiate_tag() &&
+ cookie.tie_tag() == TieTag(0)) {
+ // "C) In this case, the local endpoint's cookie has arrived late.
+ // Before it arrived, the local endpoint sent an INIT and received an
+ // INIT ACK and finally sent a COOKIE ECHO with the peer's same tag but
+ // a new tag of its own. The cookie should be silently discarded. The
+ // endpoint SHOULD NOT change states and should leave any timers
+ // running."
+ RTC_DLOG(LS_VERBOSE)
+ << log_prefix()
+ << "Received COOKIE-ECHO indicating a late COOKIE-ECHO. Discarding";
+ return false;
+ } else if (header.verification_tag == tcb_->my_verification_tag() &&
+ tcb_->peer_verification_tag() == cookie.initiate_tag()) {
+ // "D) When both local and remote tags match, the endpoint should enter
+ // the ESTABLISHED state, if it is in the COOKIE-ECHOED state. It
+ // should stop any cookie timer that may be running and send a COOKIE
+ // ACK."
+ RTC_DLOG(LS_VERBOSE)
+ << log_prefix()
+ << "Received duplicate COOKIE-ECHO, probably because of peer not "
+ "receiving COOKIE-ACK and retransmitting COOKIE-ECHO. Continuing.";
+ }
+ return true;
+}
+
+void DcSctpSocket::HandleCookieAck(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<CookieAckChunk> chunk = CookieAckChunk::Parse(descriptor.data);
+ if (!ValidateParseSuccess(chunk)) {
+ return;
+ }
+
+ if (state_ != State::kCookieEchoed) {
+ // https://tools.ietf.org/html/rfc4960#section-5.2.5
+ // "At any state other than COOKIE-ECHOED, an endpoint should silently
+ // discard a received COOKIE ACK chunk."
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Received COOKIE_ACK not in COOKIE_ECHOED state";
+ return;
+ }
+
+ // RFC 4960, Errata ID: 4400
+ t1_cookie_->Stop();
+ tcb_->ClearCookieEchoChunk();
+ SetState(State::kEstablished, "COOKIE_ACK received");
+ tcb_->SendBufferedPackets(callbacks_.TimeMillis());
+ callbacks_.OnConnected();
+}
+
+void DcSctpSocket::DeliverReassembledMessages() {
+ if (tcb_->reassembly_queue().HasMessages()) {
+ for (auto& message : tcb_->reassembly_queue().FlushMessages()) {
+ ++metrics_.rx_messages_count;
+ callbacks_.OnMessageReceived(std::move(message));
+ }
+ }
+}
+
+void DcSctpSocket::HandleSack(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<SackChunk> chunk = SackChunk::Parse(descriptor.data);
+
+ if (ValidateParseSuccess(chunk) && ValidateHasTCB()) {
+ TimeMs now = callbacks_.TimeMillis();
+ SackChunk sack = ChunkValidators::Clean(*std::move(chunk));
+
+ if (tcb_->retransmission_queue().HandleSack(now, sack)) {
+ MaybeSendShutdownOrAck();
+ // Receiving an ACK may make the socket go into fast recovery mode.
+ // https://datatracker.ietf.org/doc/html/rfc4960#section-7.2.4
+ // "Determine how many of the earliest (i.e., lowest TSN) DATA chunks
+ // marked for retransmission will fit into a single packet, subject to
+ // constraint of the path MTU of the destination transport address to
+ // which the packet is being sent. Call this value K. Retransmit those K
+ // DATA chunks in a single packet. When a Fast Retransmit is being
+ // performed, the sender SHOULD ignore the value of cwnd and SHOULD NOT
+ // delay retransmission for this single packet."
+ tcb_->MaybeSendFastRetransmit();
+
+ // Receiving an ACK will decrease outstanding bytes (maybe now below
+ // cwnd?) or indicate packet loss that may result in sending FORWARD-TSN.
+ tcb_->SendBufferedPackets(now);
+ } else {
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Dropping out-of-order SACK with TSN "
+ << *sack.cumulative_tsn_ack();
+ }
+ }
+}
+
+void DcSctpSocket::HandleHeartbeatRequest(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<HeartbeatRequestChunk> chunk =
+ HeartbeatRequestChunk::Parse(descriptor.data);
+
+ if (ValidateParseSuccess(chunk) && ValidateHasTCB()) {
+ tcb_->heartbeat_handler().HandleHeartbeatRequest(*std::move(chunk));
+ }
+}
+
+void DcSctpSocket::HandleHeartbeatAck(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<HeartbeatAckChunk> chunk =
+ HeartbeatAckChunk::Parse(descriptor.data);
+
+ if (ValidateParseSuccess(chunk) && ValidateHasTCB()) {
+ tcb_->heartbeat_handler().HandleHeartbeatAck(*std::move(chunk));
+ }
+}
+
+void DcSctpSocket::HandleAbort(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<AbortChunk> chunk = AbortChunk::Parse(descriptor.data);
+ if (ValidateParseSuccess(chunk)) {
+ std::string error_string = ErrorCausesToString(chunk->error_causes());
+ if (tcb_ == nullptr) {
+ // https://tools.ietf.org/html/rfc4960#section-3.3.7
+ // "If an endpoint receives an ABORT with a format error or no TCB is
+ // found, it MUST silently discard it."
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received ABORT (" << error_string
+ << ") on a connection with no TCB. Ignoring";
+ return;
+ }
+
+ RTC_DLOG(LS_WARNING) << log_prefix() << "Received ABORT (" << error_string
+ << ") - closing connection.";
+ InternalClose(ErrorKind::kPeerReported, error_string);
+ }
+}
+
+void DcSctpSocket::HandleError(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<ErrorChunk> chunk = ErrorChunk::Parse(descriptor.data);
+ if (ValidateParseSuccess(chunk)) {
+ std::string error_string = ErrorCausesToString(chunk->error_causes());
+ if (tcb_ == nullptr) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received ERROR (" << error_string
+ << ") on a connection with no TCB. Ignoring";
+ return;
+ }
+
+ RTC_DLOG(LS_WARNING) << log_prefix() << "Received ERROR: " << error_string;
+ callbacks_.OnError(ErrorKind::kPeerReported,
+ "Peer reported error: " + error_string);
+ }
+}
+
+void DcSctpSocket::HandleReconfig(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<ReConfigChunk> chunk = ReConfigChunk::Parse(descriptor.data);
+ if (ValidateParseSuccess(chunk) && ValidateHasTCB()) {
+ tcb_->stream_reset_handler().HandleReConfig(*std::move(chunk));
+ // Handling this response may result in outgoing stream resets finishing
+ // (either successfully or with failure). If there still are pending streams
+ // that were waiting for this request to finish, continue resetting them.
+ MaybeSendResetStreamsRequest();
+ }
+}
+
+void DcSctpSocket::HandleShutdown(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ if (!ValidateParseSuccess(ShutdownChunk::Parse(descriptor.data))) {
+ return;
+ }
+
+ if (state_ == State::kClosed) {
+ return;
+ } else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "If a SHUTDOWN is received in the COOKIE-WAIT or COOKIE ECHOED state,
+ // the SHUTDOWN chunk SHOULD be silently discarded."
+ } else if (state_ == State::kShutdownSent) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "If an endpoint is in the SHUTDOWN-SENT state and receives a
+ // SHUTDOWN chunk from its peer, the endpoint shall respond immediately
+ // with a SHUTDOWN ACK to its peer, and move into the SHUTDOWN-ACK-SENT
+ // state restarting its T2-shutdown timer."
+ SendShutdownAck();
+ SetState(State::kShutdownAckSent, "SHUTDOWN received");
+ } else if (state_ == State::kShutdownAckSent) {
+ // TODO(webrtc:12739): This condition should be removed and handled by the
+ // next (state_ != State::kShutdownReceived).
+ return;
+ } else if (state_ != State::kShutdownReceived) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix()
+ << "Received SHUTDOWN - shutting down the socket";
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "Upon reception of the SHUTDOWN, the peer endpoint shall enter the
+ // SHUTDOWN-RECEIVED state, stop accepting new data from its SCTP user,
+ // and verify, by checking the Cumulative TSN Ack field of the chunk, that
+ // all its outstanding DATA chunks have been received by the SHUTDOWN
+ // sender."
+ SetState(State::kShutdownReceived, "SHUTDOWN received");
+ MaybeSendShutdownOrAck();
+ }
+}
+
+void DcSctpSocket::HandleShutdownAck(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ if (!ValidateParseSuccess(ShutdownAckChunk::Parse(descriptor.data))) {
+ return;
+ }
+
+ if (state_ == State::kShutdownSent || state_ == State::kShutdownAckSent) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "Upon the receipt of the SHUTDOWN ACK, the SHUTDOWN sender shall stop
+ // the T2-shutdown timer, send a SHUTDOWN COMPLETE chunk to its peer, and
+ // remove all record of the association."
+
+ // "If an endpoint is in the SHUTDOWN-ACK-SENT state and receives a
+ // SHUTDOWN ACK, it shall stop the T2-shutdown timer, send a SHUTDOWN
+ // COMPLETE chunk to its peer, and remove all record of the association."
+
+ SctpPacket::Builder b = tcb_->PacketBuilder();
+ b.Add(ShutdownCompleteChunk(/*tag_reflected=*/false));
+ packet_sender_.Send(b);
+ InternalClose(ErrorKind::kNoError, "");
+ } else {
+ // https://tools.ietf.org/html/rfc4960#section-8.5.1
+ // "If the receiver is in COOKIE-ECHOED or COOKIE-WAIT state
+ // the procedures in Section 8.4 SHOULD be followed; in other words, it
+ // should be treated as an Out Of The Blue packet."
+
+ // https://tools.ietf.org/html/rfc4960#section-8.4
+ // "If the packet contains a SHUTDOWN ACK chunk, the receiver
+ // should respond to the sender of the OOTB packet with a SHUTDOWN
+ // COMPLETE. When sending the SHUTDOWN COMPLETE, the receiver of the OOTB
+ // packet must fill in the Verification Tag field of the outbound packet
+ // with the Verification Tag received in the SHUTDOWN ACK and set the T
+ // bit in the Chunk Flags to indicate that the Verification Tag is
+ // reflected."
+
+ SctpPacket::Builder b(header.verification_tag, options_);
+ b.Add(ShutdownCompleteChunk(/*tag_reflected=*/true));
+ packet_sender_.Send(b);
+ }
+}
+
+void DcSctpSocket::HandleShutdownComplete(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ if (!ValidateParseSuccess(ShutdownCompleteChunk::Parse(descriptor.data))) {
+ return;
+ }
+
+ if (state_ == State::kShutdownAckSent) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "Upon reception of the SHUTDOWN COMPLETE chunk, the endpoint will
+ // verify that it is in the SHUTDOWN-ACK-SENT state; if it is not, the
+ // chunk should be discarded. If the endpoint is in the SHUTDOWN-ACK-SENT
+ // state, the endpoint should stop the T2-shutdown timer and remove all
+ // knowledge of the association (and thus the association enters the
+ // CLOSED state)."
+ InternalClose(ErrorKind::kNoError, "");
+ }
+}
+
+void DcSctpSocket::HandleForwardTsn(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<ForwardTsnChunk> chunk =
+ ForwardTsnChunk::Parse(descriptor.data);
+ if (ValidateParseSuccess(chunk) && ValidateHasTCB()) {
+ HandleForwardTsnCommon(*chunk);
+ }
+}
+
+void DcSctpSocket::HandleIForwardTsn(
+ const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor) {
+ absl::optional<IForwardTsnChunk> chunk =
+ IForwardTsnChunk::Parse(descriptor.data);
+ if (ValidateParseSuccess(chunk) && ValidateHasTCB()) {
+ HandleForwardTsnCommon(*chunk);
+ }
+}
+
+void DcSctpSocket::HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk) {
+ if (!tcb_->capabilities().partial_reliability) {
+ SctpPacket::Builder b = tcb_->PacketBuilder();
+ b.Add(AbortChunk(/*filled_in_verification_tag=*/true,
+ Parameters::Builder()
+ .Add(ProtocolViolationCause(
+ "I-FORWARD-TSN received, but not indicated "
+ "during connection establishment"))
+ .Build()));
+ packet_sender_.Send(b);
+
+ callbacks_.OnError(ErrorKind::kProtocolViolation,
+ "Received a FORWARD_TSN without announced peer support");
+ return;
+ }
+ tcb_->data_tracker().HandleForwardTsn(chunk.new_cumulative_tsn());
+ tcb_->reassembly_queue().Handle(chunk);
+ // A forward TSN - for ordered streams - may allow messages to be
+ // delivered.
+ DeliverReassembledMessages();
+
+ // Processing a FORWARD_TSN might result in sending a SACK.
+ tcb_->MaybeSendSack();
+}
+
+void DcSctpSocket::MaybeSendShutdownOrAck() {
+ if (tcb_->retransmission_queue().outstanding_bytes() != 0) {
+ return;
+ }
+
+ if (state_ == State::kShutdownPending) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "Once all its outstanding data has been acknowledged, the endpoint
+ // shall send a SHUTDOWN chunk to its peer including in the Cumulative TSN
+ // Ack field the last sequential TSN it has received from the peer. It
+ // shall then start the T2-shutdown timer and enter the SHUTDOWN-SENT
+ // state.""
+
+ SendShutdown();
+ t2_shutdown_->set_duration(tcb_->current_rto());
+ t2_shutdown_->Start();
+ SetState(State::kShutdownSent, "No more outstanding data");
+ } else if (state_ == State::kShutdownReceived) {
+ // https://tools.ietf.org/html/rfc4960#section-9.2
+ // "If the receiver of the SHUTDOWN has no more outstanding DATA
+ // chunks, the SHUTDOWN receiver MUST send a SHUTDOWN ACK and start a
+ // T2-shutdown timer of its own, entering the SHUTDOWN-ACK-SENT state. If
+ // the timer expires, the endpoint must resend the SHUTDOWN ACK."
+
+ SendShutdownAck();
+ SetState(State::kShutdownAckSent, "No more outstanding data");
+ }
+}
+
+void DcSctpSocket::SendShutdown() {
+ SctpPacket::Builder b = tcb_->PacketBuilder();
+ b.Add(ShutdownChunk(tcb_->data_tracker().last_cumulative_acked_tsn()));
+ packet_sender_.Send(b);
+}
+
+void DcSctpSocket::SendShutdownAck() {
+ packet_sender_.Send(tcb_->PacketBuilder().Add(ShutdownAckChunk()));
+ t2_shutdown_->set_duration(tcb_->current_rto());
+ t2_shutdown_->Start();
+}
+
+HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ HandoverReadinessStatus status;
+ if (state_ != State::kClosed && state_ != State::kEstablished) {
+ status.Add(HandoverUnreadinessReason::kWrongConnectionState);
+ }
+ status.Add(send_queue_.GetHandoverReadiness());
+ if (tcb_) {
+ status.Add(tcb_->GetHandoverReadiness());
+ }
+ return status;
+}
+
+absl::optional<DcSctpSocketHandoverState>
+DcSctpSocket::GetHandoverStateAndClose() {
+ RTC_DCHECK_RUN_ON(&thread_checker_);
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+
+ if (!GetHandoverReadiness().IsReady()) {
+ return absl::nullopt;
+ }
+
+ DcSctpSocketHandoverState state;
+
+ if (state_ == State::kClosed) {
+ state.socket_state = DcSctpSocketHandoverState::SocketState::kClosed;
+ } else if (state_ == State::kEstablished) {
+ state.socket_state = DcSctpSocketHandoverState::SocketState::kConnected;
+ tcb_->AddHandoverState(state);
+ send_queue_.AddHandoverState(state);
+ InternalClose(ErrorKind::kNoError, "handover");
+ }
+
+ return std::move(state);
+}
+
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h
new file mode 100644
index 0000000000..157c515d65
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h
@@ -0,0 +1,298 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_
+#define NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/strings/string_view.h"
+#include "api/array_view.h"
+#include "api/sequence_checker.h"
+#include "net/dcsctp/packet/chunk/abort_chunk.h"
+#include "net/dcsctp/packet/chunk/chunk.h"
+#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
+#include "net/dcsctp/packet/chunk/data_chunk.h"
+#include "net/dcsctp/packet/chunk/data_common.h"
+#include "net/dcsctp/packet/chunk/error_chunk.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/chunk/idata_chunk.h"
+#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h"
+#include "net/dcsctp/packet/chunk/init_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/init_chunk.h"
+#include "net/dcsctp/packet/chunk/reconfig_chunk.h"
+#include "net/dcsctp/packet/chunk/sack_chunk.h"
+#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/shutdown_chunk.h"
+#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h"
+#include "net/dcsctp/packet/data.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/public/packet_observer.h"
+#include "net/dcsctp/rx/data_tracker.h"
+#include "net/dcsctp/rx/reassembly_queue.h"
+#include "net/dcsctp/socket/callback_deferrer.h"
+#include "net/dcsctp/socket/packet_sender.h"
+#include "net/dcsctp/socket/state_cookie.h"
+#include "net/dcsctp/socket/transmission_control_block.h"
+#include "net/dcsctp/timer/timer.h"
+#include "net/dcsctp/tx/retransmission_error_counter.h"
+#include "net/dcsctp/tx/retransmission_queue.h"
+#include "net/dcsctp/tx/retransmission_timeout.h"
+#include "net/dcsctp/tx/rr_send_queue.h"
+
+namespace dcsctp {
+
+// DcSctpSocket represents a single SCTP socket, to be used over DTLS.
+//
+// Every dcSCTP is completely isolated from any other socket.
+//
+// This class manages all packet and chunk dispatching and mainly handles the
+// connection sequences (connect, close, shutdown, etc) as well as managing
+// the Transmission Control Block (tcb).
+//
+// This class is thread-compatible.
+class DcSctpSocket : public DcSctpSocketInterface {
+ public:
+ // Instantiates a DcSctpSocket, which interacts with the world through the
+ // `callbacks` interface and is configured using `options`.
+ //
+ // For debugging, `log_prefix` will prefix all debug logs, and a
+ // `packet_observer` can be attached to e.g. dump sent and received packets.
+ DcSctpSocket(absl::string_view log_prefix,
+ DcSctpSocketCallbacks& callbacks,
+ std::unique_ptr<PacketObserver> packet_observer,
+ const DcSctpOptions& options);
+
+ DcSctpSocket(const DcSctpSocket&) = delete;
+ DcSctpSocket& operator=(const DcSctpSocket&) = delete;
+
+ // Implementation of `DcSctpSocketInterface`.
+ void ReceivePacket(rtc::ArrayView<const uint8_t> data) override;
+ void HandleTimeout(TimeoutID timeout_id) override;
+ void Connect() override;
+ void RestoreFromState(const DcSctpSocketHandoverState& state) override;
+ void Shutdown() override;
+ void Close() override;
+ SendStatus Send(DcSctpMessage message,
+ const SendOptions& send_options) override;
+ ResetStreamsStatus ResetStreams(
+ rtc::ArrayView<const StreamID> outgoing_streams) override;
+ SocketState state() const override;
+ const DcSctpOptions& options() const override { return options_; }
+ void SetMaxMessageSize(size_t max_message_size) override;
+ void SetStreamPriority(StreamID stream_id, StreamPriority priority) override;
+ StreamPriority GetStreamPriority(StreamID stream_id) const override;
+ size_t buffered_amount(StreamID stream_id) const override;
+ size_t buffered_amount_low_threshold(StreamID stream_id) const override;
+ void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override;
+ absl::optional<Metrics> GetMetrics() const override;
+ HandoverReadinessStatus GetHandoverReadiness() const override;
+ absl::optional<DcSctpSocketHandoverState> GetHandoverStateAndClose() override;
+ SctpImplementation peer_implementation() const override {
+ return metrics_.peer_implementation;
+ }
+ // Returns this socket's verification tag, or zero if not yet connected.
+ VerificationTag verification_tag() const {
+ return tcb_ != nullptr ? tcb_->my_verification_tag() : VerificationTag(0);
+ }
+
+ private:
+ // Parameter proposals valid during the connect phase.
+ struct ConnectParameters {
+ TSN initial_tsn = TSN(0);
+ VerificationTag verification_tag = VerificationTag(0);
+ };
+
+ // Detailed state (separate from SocketState, which is the public state).
+ enum class State {
+ kClosed,
+ kCookieWait,
+ // TCB valid in these:
+ kCookieEchoed,
+ kEstablished,
+ kShutdownPending,
+ kShutdownSent,
+ kShutdownReceived,
+ kShutdownAckSent,
+ };
+
+ // Returns the log prefix used for debug logging.
+ std::string log_prefix() const;
+
+ bool IsConsistent() const;
+ static constexpr absl::string_view ToString(DcSctpSocket::State state);
+
+ void CreateTransmissionControlBlock(const Capabilities& capabilities,
+ VerificationTag my_verification_tag,
+ TSN my_initial_tsn,
+ VerificationTag peer_verification_tag,
+ TSN peer_initial_tsn,
+ size_t a_rwnd,
+ TieTag tie_tag);
+
+ // Changes the socket state, given a `reason` (for debugging/logging).
+ void SetState(State state, absl::string_view reason);
+ // Fills in `connect_params` with random verification tag and initial TSN.
+ void MakeConnectionParameters();
+ // Closes the association. Note that the TCB will not be valid past this call.
+ void InternalClose(ErrorKind error, absl::string_view message);
+ // Closes the association, because of too many retransmission errors.
+ void CloseConnectionBecauseOfTooManyTransmissionErrors();
+ // Timer expiration handlers
+ absl::optional<DurationMs> OnInitTimerExpiry();
+ absl::optional<DurationMs> OnCookieTimerExpiry();
+ absl::optional<DurationMs> OnShutdownTimerExpiry();
+ void OnSentPacket(rtc::ArrayView<const uint8_t> packet,
+ SendPacketStatus status);
+ // Sends SHUTDOWN or SHUTDOWN-ACK if the socket is shutting down and if all
+ // outstanding data has been acknowledged.
+ void MaybeSendShutdownOrAck();
+ // If the socket is shutting down, responds SHUTDOWN to any incoming DATA.
+ void MaybeSendShutdownOnPacketReceived(const SctpPacket& packet);
+ // If there are streams pending to be reset, send a request to reset them.
+ void MaybeSendResetStreamsRequest();
+ // Sends a INIT chunk.
+ void SendInit();
+ // Sends a SHUTDOWN chunk.
+ void SendShutdown();
+ // Sends a SHUTDOWN-ACK chunk.
+ void SendShutdownAck();
+ // Validates the SCTP packet, as a whole - not the validity of individual
+ // chunks within it, as that's done in the different chunk handlers.
+ bool ValidatePacket(const SctpPacket& packet);
+ // Parses `payload`, which is a serialized packet that is just going to be
+ // sent and prints all chunks.
+ void DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload);
+ // Called whenever there may be reassembled messages, and delivers those.
+ void DeliverReassembledMessages();
+ // Returns true if there is a TCB, and false otherwise (and reports an error).
+ bool ValidateHasTCB();
+
+ // Returns true if the parsing of a chunk of type `T` succeeded. If it didn't,
+ // it reports an error and returns false.
+ template <class T>
+ bool ValidateParseSuccess(const absl::optional<T>& c) {
+ if (c.has_value()) {
+ return true;
+ }
+
+ ReportFailedToParseChunk(T::kType);
+ return false;
+ }
+
+ // Reports failing to have parsed a chunk with the provided `chunk_type`.
+ void ReportFailedToParseChunk(int chunk_type);
+ // Called when unknown chunks are received. May report an error.
+ bool HandleUnrecognizedChunk(const SctpPacket::ChunkDescriptor& descriptor);
+
+ // Will dispatch more specific chunk handlers.
+ bool Dispatch(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming DATA chunks.
+ void HandleData(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming I-DATA chunks.
+ void HandleIData(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Common handler for DATA and I-DATA chunks.
+ void HandleDataCommon(AnyDataChunk& chunk);
+ // Handles incoming INIT chunks.
+ void HandleInit(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming INIT-ACK chunks.
+ void HandleInitAck(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming SACK chunks.
+ void HandleSack(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming HEARTBEAT chunks.
+ void HandleHeartbeatRequest(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming HEARTBEAT-ACK chunks.
+ void HandleHeartbeatAck(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming ABORT chunks.
+ void HandleAbort(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming ERROR chunks.
+ void HandleError(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming COOKIE-ECHO chunks.
+ void HandleCookieEcho(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles receiving COOKIE-ECHO when there already is a TCB. The return value
+ // indicates if the processing should continue.
+ bool HandleCookieEchoWithTCB(const CommonHeader& header,
+ const StateCookie& cookie);
+ // Handles incoming COOKIE-ACK chunks.
+ void HandleCookieAck(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming SHUTDOWN chunks.
+ void HandleShutdown(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming SHUTDOWN-ACK chunks.
+ void HandleShutdownAck(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming FORWARD-TSN chunks.
+ void HandleForwardTsn(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming I-FORWARD-TSN chunks.
+ void HandleIForwardTsn(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Handles incoming RE-CONFIG chunks.
+ void HandleReconfig(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+ // Common handled for FORWARD-TSN/I-FORWARD-TSN.
+ void HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk);
+ // Handles incoming SHUTDOWN-COMPLETE chunks
+ void HandleShutdownComplete(const CommonHeader& header,
+ const SctpPacket::ChunkDescriptor& descriptor);
+
+ const std::string log_prefix_;
+ const std::unique_ptr<PacketObserver> packet_observer_;
+ RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_;
+ Metrics metrics_;
+ DcSctpOptions options_;
+
+ // Enqueues callbacks and dispatches them just before returning to the caller.
+ CallbackDeferrer callbacks_;
+
+ TimerManager timer_manager_;
+ const std::unique_ptr<Timer> t1_init_;
+ const std::unique_ptr<Timer> t1_cookie_;
+ const std::unique_ptr<Timer> t2_shutdown_;
+
+ // Packets that failed to be sent, but should be retried.
+ PacketSender packet_sender_;
+
+ // The actual SendQueue implementation. As data can be sent on a socket before
+ // the connection is established, this component is not in the TCB.
+ RRSendQueue send_queue_;
+
+ // Contains verification tag and initial TSN between having sent the INIT
+ // until the connection is established (there is no TCB at this point).
+ ConnectParameters connect_params_;
+ // The socket state.
+ State state_ = State::kClosed;
+ // If the connection is established, contains a transmission control block.
+ std::unique_ptr<TransmissionControlBlock> tcb_;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc
new file mode 100644
index 0000000000..f097bfa095
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc
@@ -0,0 +1,518 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include <cstdint>
+#include <deque>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "api/task_queue/pending_task_safety_flag.h"
+#include "api/task_queue/task_queue_base.h"
+#include "api/test/create_network_emulation_manager.h"
+#include "api/test/network_emulation_manager.h"
+#include "api/units/time_delta.h"
+#include "call/simulated_network.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/public/types.h"
+#include "net/dcsctp/socket/dcsctp_socket.h"
+#include "net/dcsctp/testing/testing_macros.h"
+#include "net/dcsctp/timer/task_queue_timeout.h"
+#include "rtc_base/copy_on_write_buffer.h"
+#include "rtc_base/gunit.h"
+#include "rtc_base/logging.h"
+#include "rtc_base/socket_address.h"
+#include "rtc_base/strings/string_format.h"
+#include "rtc_base/time_utils.h"
+#include "test/gmock.h"
+
+#if !defined(WEBRTC_ANDROID) && defined(NDEBUG) && \
+ !defined(THREAD_SANITIZER) && !defined(MEMORY_SANITIZER)
+#define DCSCTP_NDEBUG_TEST(t) t
+#else
+// In debug mode, and when MSAN or TSAN sanitizers are enabled, these tests are
+// too expensive to run due to extensive consistency checks that iterate on all
+// outstanding chunks. Same with low-end Android devices, which have
+// difficulties with these tests.
+#define DCSCTP_NDEBUG_TEST(t) DISABLED_##t
+#endif
+
+namespace dcsctp {
+namespace {
+using ::testing::AllOf;
+using ::testing::Ge;
+using ::testing::Le;
+using ::testing::SizeIs;
+
+constexpr StreamID kStreamId(1);
+constexpr PPID kPpid(53);
+constexpr size_t kSmallPayloadSize = 10;
+constexpr size_t kLargePayloadSize = 10000;
+constexpr size_t kHugePayloadSize = 262144;
+constexpr size_t kBufferedAmountLowThreshold = kLargePayloadSize * 2;
+constexpr webrtc::TimeDelta kPrintBandwidthDuration =
+ webrtc::TimeDelta::Seconds(1);
+constexpr webrtc::TimeDelta kBenchmarkRuntime(webrtc::TimeDelta::Seconds(10));
+constexpr webrtc::TimeDelta kAWhile(webrtc::TimeDelta::Seconds(1));
+
+inline int GetUniqueSeed() {
+ static int seed = 0;
+ return ++seed;
+}
+
+DcSctpOptions MakeOptionsForTest() {
+ DcSctpOptions options;
+
+ // Throughput numbers are affected by the MTU. Ensure it's constant.
+ options.mtu = 1200;
+
+ // By disabling the heartbeat interval, there will no timers at all running
+ // when the socket is idle, which makes it easy to just continue the test
+ // until there are no more scheduled tasks. Note that it _will_ run for longer
+ // than necessary as timers aren't cancelled when they are stopped (as that's
+ // not supported), but it's still simulated time and passes quickly.
+ options.heartbeat_interval = DurationMs(0);
+ return options;
+}
+
+// When doing throughput tests, knowing what each actor should do.
+enum class ActorMode {
+ kAtRest,
+ kThroughputSender,
+ kThroughputReceiver,
+ kLimitedRetransmissionSender,
+};
+
+// An abstraction around EmulatedEndpoint, representing a bound socket that
+// will send its packet to a given destination.
+class BoundSocket : public webrtc::EmulatedNetworkReceiverInterface {
+ public:
+ void Bind(webrtc::EmulatedEndpoint* endpoint) {
+ endpoint_ = endpoint;
+ uint16_t port = endpoint->BindReceiver(0, this).value();
+ source_address_ =
+ rtc::SocketAddress(endpoint_->GetPeerLocalAddress(), port);
+ }
+
+ void SetDestination(const BoundSocket& socket) {
+ dest_address_ = socket.source_address_;
+ }
+
+ void SetReceiver(std::function<void(rtc::CopyOnWriteBuffer)> receiver) {
+ receiver_ = std::move(receiver);
+ }
+
+ void SendPacket(rtc::ArrayView<const uint8_t> data) {
+ endpoint_->SendPacket(source_address_, dest_address_,
+ rtc::CopyOnWriteBuffer(data.data(), data.size()));
+ }
+
+ private:
+ // Implementation of `webrtc::EmulatedNetworkReceiverInterface`.
+ void OnPacketReceived(webrtc::EmulatedIpPacket packet) override {
+ receiver_(std::move(packet.data));
+ }
+
+ std::function<void(rtc::CopyOnWriteBuffer)> receiver_;
+ webrtc::EmulatedEndpoint* endpoint_ = nullptr;
+ rtc::SocketAddress source_address_;
+ rtc::SocketAddress dest_address_;
+};
+
+// Sends at a constant rate but with random packet sizes.
+class SctpActor : public DcSctpSocketCallbacks {
+ public:
+ SctpActor(absl::string_view name,
+ BoundSocket& emulated_socket,
+ const DcSctpOptions& sctp_options)
+ : log_prefix_(std::string(name) + ": "),
+ thread_(rtc::Thread::Current()),
+ emulated_socket_(emulated_socket),
+ timeout_factory_(
+ *thread_,
+ [this]() { return TimeMillis(); },
+ [this](dcsctp::TimeoutID timeout_id) {
+ sctp_socket_.HandleTimeout(timeout_id);
+ }),
+ random_(GetUniqueSeed()),
+ sctp_socket_(name, *this, nullptr, sctp_options),
+ last_bandwidth_printout_(TimeMs(TimeMillis())) {
+ emulated_socket.SetReceiver([this](rtc::CopyOnWriteBuffer buf) {
+ // The receiver will be executed on the NetworkEmulation task queue, but
+ // the dcSCTP socket is owned by `thread_` and is not thread-safe.
+ thread_->PostTask([this, buf] { this->sctp_socket_.ReceivePacket(buf); });
+ });
+ }
+
+ void PrintBandwidth() {
+ TimeMs now = TimeMillis();
+ DurationMs duration = now - last_bandwidth_printout_;
+
+ double bitrate_mbps =
+ static_cast<double>(received_bytes_ * 8) / *duration / 1000;
+ RTC_LOG(LS_INFO) << log_prefix()
+ << rtc::StringFormat("Received %0.2f Mbps", bitrate_mbps);
+
+ received_bitrate_mbps_.push_back(bitrate_mbps);
+ received_bytes_ = 0;
+ last_bandwidth_printout_ = now;
+ // Print again in a second.
+ if (mode_ == ActorMode::kThroughputReceiver) {
+ thread_->PostDelayedTask(
+ SafeTask(safety_.flag(), [this] { PrintBandwidth(); }),
+ kPrintBandwidthDuration);
+ }
+ }
+
+ void SendPacket(rtc::ArrayView<const uint8_t> data) override {
+ emulated_socket_.SendPacket(data);
+ }
+
+ std::unique_ptr<Timeout> CreateTimeout(
+ webrtc::TaskQueueBase::DelayPrecision precision) override {
+ return timeout_factory_.CreateTimeout(precision);
+ }
+
+ TimeMs TimeMillis() override { return TimeMs(rtc::TimeMillis()); }
+
+ uint32_t GetRandomInt(uint32_t low, uint32_t high) override {
+ return random_.Rand(low, high);
+ }
+
+ void OnMessageReceived(DcSctpMessage message) override {
+ received_bytes_ += message.payload().size();
+ last_received_message_ = std::move(message);
+ }
+
+ void OnError(ErrorKind error, absl::string_view message) override {
+ RTC_LOG(LS_WARNING) << log_prefix() << "Socket error: " << ToString(error)
+ << "; " << message;
+ }
+
+ void OnAborted(ErrorKind error, absl::string_view message) override {
+ RTC_LOG(LS_ERROR) << log_prefix() << "Socket abort: " << ToString(error)
+ << "; " << message;
+ }
+
+ void OnConnected() override {}
+
+ void OnClosed() override {}
+
+ void OnConnectionRestarted() override {}
+
+ void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,
+ absl::string_view reason) override {}
+
+ void OnStreamsResetPerformed(
+ rtc::ArrayView<const StreamID> outgoing_streams) override {}
+
+ void OnIncomingStreamsReset(
+ rtc::ArrayView<const StreamID> incoming_streams) override {}
+
+ void NotifyOutgoingMessageBufferEmpty() override {}
+
+ void OnBufferedAmountLow(StreamID stream_id) override {
+ if (mode_ == ActorMode::kThroughputSender) {
+ std::vector<uint8_t> payload(kHugePayloadSize);
+ sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)),
+ SendOptions());
+
+ } else if (mode_ == ActorMode::kLimitedRetransmissionSender) {
+ while (sctp_socket_.buffered_amount(kStreamId) <
+ kBufferedAmountLowThreshold * 2) {
+ SendOptions send_options;
+ send_options.max_retransmissions = 0;
+ sctp_socket_.Send(
+ DcSctpMessage(kStreamId, kPpid,
+ std::vector<uint8_t>(kLargePayloadSize)),
+ send_options);
+
+ send_options.max_retransmissions = absl::nullopt;
+ sctp_socket_.Send(
+ DcSctpMessage(kStreamId, kPpid,
+ std::vector<uint8_t>(kSmallPayloadSize)),
+ send_options);
+ }
+ }
+ }
+
+ absl::optional<DcSctpMessage> ConsumeReceivedMessage() {
+ if (!last_received_message_.has_value()) {
+ return absl::nullopt;
+ }
+ DcSctpMessage ret = *std::move(last_received_message_);
+ last_received_message_ = absl::nullopt;
+ return ret;
+ }
+
+ DcSctpSocket& sctp_socket() { return sctp_socket_; }
+
+ void SetActorMode(ActorMode mode) {
+ mode_ = mode;
+ if (mode_ == ActorMode::kThroughputSender) {
+ sctp_socket_.SetBufferedAmountLowThreshold(kStreamId,
+ kBufferedAmountLowThreshold);
+ std::vector<uint8_t> payload(kHugePayloadSize);
+ sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)),
+ SendOptions());
+
+ } else if (mode_ == ActorMode::kLimitedRetransmissionSender) {
+ sctp_socket_.SetBufferedAmountLowThreshold(kStreamId,
+ kBufferedAmountLowThreshold);
+ std::vector<uint8_t> payload(kHugePayloadSize);
+ sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)),
+ SendOptions());
+
+ } else if (mode == ActorMode::kThroughputReceiver) {
+ thread_->PostDelayedTask(
+ SafeTask(safety_.flag(), [this] { PrintBandwidth(); }),
+ kPrintBandwidthDuration);
+ }
+ }
+
+ // Returns the average bitrate, stripping the first `remove_first_n` that
+ // represent the time it took to ramp up the congestion control algorithm.
+ double avg_received_bitrate_mbps(size_t remove_first_n = 3) const {
+ std::vector<double> bitrates = received_bitrate_mbps_;
+ bitrates.erase(bitrates.begin(), bitrates.begin() + remove_first_n);
+
+ double sum = 0;
+ for (double bitrate : bitrates) {
+ sum += bitrate;
+ }
+
+ return sum / bitrates.size();
+ }
+
+ private:
+ std::string log_prefix() const {
+ rtc::StringBuilder sb;
+ sb << log_prefix_;
+ sb << rtc::TimeMillis();
+ sb << ": ";
+ return sb.Release();
+ }
+
+ ActorMode mode_ = ActorMode::kAtRest;
+ const std::string log_prefix_;
+ rtc::Thread* thread_;
+ BoundSocket& emulated_socket_;
+ TaskQueueTimeoutFactory timeout_factory_;
+ webrtc::Random random_;
+ DcSctpSocket sctp_socket_;
+ size_t received_bytes_ = 0;
+ absl::optional<DcSctpMessage> last_received_message_;
+ TimeMs last_bandwidth_printout_;
+ // Per-second received bitrates, in Mbps
+ std::vector<double> received_bitrate_mbps_;
+ webrtc::ScopedTaskSafety safety_;
+};
+
+class DcSctpSocketNetworkTest : public testing::Test {
+ protected:
+ DcSctpSocketNetworkTest()
+ : options_(MakeOptionsForTest()),
+ emulation_(webrtc::CreateNetworkEmulationManager(
+ webrtc::TimeMode::kSimulated)) {}
+
+ void MakeNetwork(const webrtc::BuiltInNetworkBehaviorConfig& config) {
+ webrtc::EmulatedEndpoint* endpoint_a =
+ emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig());
+ webrtc::EmulatedEndpoint* endpoint_z =
+ emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig());
+
+ webrtc::EmulatedNetworkNode* node1 = emulation_->CreateEmulatedNode(config);
+ webrtc::EmulatedNetworkNode* node2 = emulation_->CreateEmulatedNode(config);
+
+ emulation_->CreateRoute(endpoint_a, {node1}, endpoint_z);
+ emulation_->CreateRoute(endpoint_z, {node2}, endpoint_a);
+
+ emulated_socket_a_.Bind(endpoint_a);
+ emulated_socket_z_.Bind(endpoint_z);
+
+ emulated_socket_a_.SetDestination(emulated_socket_z_);
+ emulated_socket_z_.SetDestination(emulated_socket_a_);
+ }
+
+ void Sleep(webrtc::TimeDelta duration) {
+ // Sleep in one-millisecond increments, to let timers expire when expected.
+ for (int i = 0; i < duration.ms(); ++i) {
+ emulation_->time_controller()->AdvanceTime(webrtc::TimeDelta::Millis(1));
+ }
+ }
+
+ DcSctpOptions options_;
+ std::unique_ptr<webrtc::NetworkEmulationManager> emulation_;
+ BoundSocket emulated_socket_a_;
+ BoundSocket emulated_socket_z_;
+};
+
+TEST_F(DcSctpSocketNetworkTest, CanConnectAndShutdown) {
+ webrtc::BuiltInNetworkBehaviorConfig pipe_config;
+ MakeNetwork(pipe_config);
+
+ SctpActor sender("A", emulated_socket_a_, options_);
+ SctpActor receiver("Z", emulated_socket_z_, options_);
+ EXPECT_THAT(sender.sctp_socket().state(), SocketState::kClosed);
+
+ sender.sctp_socket().Connect();
+ Sleep(kAWhile);
+ EXPECT_THAT(sender.sctp_socket().state(), SocketState::kConnected);
+
+ sender.sctp_socket().Shutdown();
+ Sleep(kAWhile);
+ EXPECT_THAT(sender.sctp_socket().state(), SocketState::kClosed);
+}
+
+TEST_F(DcSctpSocketNetworkTest, CanSendLargeMessage) {
+ webrtc::BuiltInNetworkBehaviorConfig pipe_config;
+ pipe_config.queue_delay_ms = 30;
+ MakeNetwork(pipe_config);
+
+ SctpActor sender("A", emulated_socket_a_, options_);
+ SctpActor receiver("Z", emulated_socket_z_, options_);
+ sender.sctp_socket().Connect();
+
+ constexpr size_t kPayloadSize = 100 * 1024;
+
+ std::vector<uint8_t> payload(kPayloadSize);
+ sender.sctp_socket().Send(DcSctpMessage(kStreamId, kPpid, payload),
+ SendOptions());
+
+ Sleep(kAWhile);
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage message,
+ receiver.ConsumeReceivedMessage());
+
+ EXPECT_THAT(message.payload(), SizeIs(kPayloadSize));
+
+ sender.sctp_socket().Shutdown();
+ Sleep(kAWhile);
+}
+
+TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithLowBandwidth) {
+ webrtc::BuiltInNetworkBehaviorConfig pipe_config;
+ pipe_config.queue_delay_ms = 30;
+ pipe_config.link_capacity_kbps = 1000;
+ MakeNetwork(pipe_config);
+
+ SctpActor sender("A", emulated_socket_a_, options_);
+ SctpActor receiver("Z", emulated_socket_z_, options_);
+ sender.sctp_socket().Connect();
+
+ sender.SetActorMode(ActorMode::kThroughputSender);
+ receiver.SetActorMode(ActorMode::kThroughputReceiver);
+
+ Sleep(kBenchmarkRuntime);
+ sender.SetActorMode(ActorMode::kAtRest);
+ receiver.SetActorMode(ActorMode::kAtRest);
+
+ Sleep(kAWhile);
+
+ sender.sctp_socket().Shutdown();
+
+ Sleep(kAWhile);
+
+ // Verify that the bitrates are in the range of 0.5-1.0 Mbps.
+ double bitrate = receiver.avg_received_bitrate_mbps();
+ EXPECT_THAT(bitrate, AllOf(Ge(0.5), Le(1.0)));
+}
+
+TEST_F(DcSctpSocketNetworkTest,
+ DCSCTP_NDEBUG_TEST(CanSendMessagesReliablyWithMediumBandwidth)) {
+ webrtc::BuiltInNetworkBehaviorConfig pipe_config;
+ pipe_config.queue_delay_ms = 30;
+ pipe_config.link_capacity_kbps = 18000;
+ MakeNetwork(pipe_config);
+
+ SctpActor sender("A", emulated_socket_a_, options_);
+ SctpActor receiver("Z", emulated_socket_z_, options_);
+ sender.sctp_socket().Connect();
+
+ sender.SetActorMode(ActorMode::kThroughputSender);
+ receiver.SetActorMode(ActorMode::kThroughputReceiver);
+
+ Sleep(kBenchmarkRuntime);
+ sender.SetActorMode(ActorMode::kAtRest);
+ receiver.SetActorMode(ActorMode::kAtRest);
+
+ Sleep(kAWhile);
+
+ sender.sctp_socket().Shutdown();
+
+ Sleep(kAWhile);
+
+ // Verify that the bitrates are in the range of 16-18 Mbps.
+ double bitrate = receiver.avg_received_bitrate_mbps();
+ EXPECT_THAT(bitrate, AllOf(Ge(16), Le(18)));
+}
+
+TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithMuchPacketLoss) {
+ webrtc::BuiltInNetworkBehaviorConfig config;
+ config.queue_delay_ms = 30;
+ config.loss_percent = 1;
+ MakeNetwork(config);
+
+ SctpActor sender("A", emulated_socket_a_, options_);
+ SctpActor receiver("Z", emulated_socket_z_, options_);
+ sender.sctp_socket().Connect();
+
+ sender.SetActorMode(ActorMode::kThroughputSender);
+ receiver.SetActorMode(ActorMode::kThroughputReceiver);
+
+ Sleep(kBenchmarkRuntime);
+ sender.SetActorMode(ActorMode::kAtRest);
+ receiver.SetActorMode(ActorMode::kAtRest);
+
+ Sleep(kAWhile);
+
+ sender.sctp_socket().Shutdown();
+
+ Sleep(kAWhile);
+
+ // TCP calculator gives: 1200 MTU, 60ms RTT and 1% packet loss -> 1.6Mbps.
+ // This test is doing slightly better (doesn't have any additional header
+ // overhead etc). Verify that the bitrates are in the range of 1.5-2.5 Mbps.
+ double bitrate = receiver.avg_received_bitrate_mbps();
+ EXPECT_THAT(bitrate, AllOf(Ge(1.5), Le(2.5)));
+}
+
+TEST_F(DcSctpSocketNetworkTest, DCSCTP_NDEBUG_TEST(HasHighBandwidth)) {
+ webrtc::BuiltInNetworkBehaviorConfig pipe_config;
+ pipe_config.queue_delay_ms = 30;
+ MakeNetwork(pipe_config);
+
+ SctpActor sender("A", emulated_socket_a_, options_);
+ SctpActor receiver("Z", emulated_socket_z_, options_);
+ sender.sctp_socket().Connect();
+
+ sender.SetActorMode(ActorMode::kThroughputSender);
+ receiver.SetActorMode(ActorMode::kThroughputReceiver);
+
+ Sleep(kBenchmarkRuntime);
+
+ sender.SetActorMode(ActorMode::kAtRest);
+ receiver.SetActorMode(ActorMode::kAtRest);
+ Sleep(kAWhile);
+
+ sender.sctp_socket().Shutdown();
+ Sleep(kAWhile);
+
+ // Verify that the bitrate is in the range of 540-640 Mbps
+ double bitrate = receiver.avg_received_bitrate_mbps();
+ EXPECT_THAT(bitrate, AllOf(Ge(520), Le(640)));
+}
+} // namespace
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc
new file mode 100644
index 0000000000..d1e2d904e0
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -0,0 +1,2672 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/dcsctp_socket.h"
+
+#include <cstdint>
+#include <deque>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/flags/flag.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/common/handover_testing.h"
+#include "net/dcsctp/packet/chunk/chunk.h"
+#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
+#include "net/dcsctp/packet/chunk/data_chunk.h"
+#include "net/dcsctp/packet/chunk/data_common.h"
+#include "net/dcsctp/packet/chunk/error_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/chunk/idata_chunk.h"
+#include "net/dcsctp/packet/chunk/init_chunk.h"
+#include "net/dcsctp/packet/chunk/sack_chunk.h"
+#include "net/dcsctp/packet/chunk/shutdown_chunk.h"
+#include "net/dcsctp/packet/error_cause/error_cause.h"
+#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h"
+#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h"
+#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
+#include "net/dcsctp/packet/parameter/parameter.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/packet/tlv_trait.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/public/text_pcap_packet_observer.h"
+#include "net/dcsctp/public/types.h"
+#include "net/dcsctp/rx/reassembly_queue.h"
+#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h"
+#include "net/dcsctp/testing/testing_macros.h"
+#include "rtc_base/gunit.h"
+#include "test/gmock.h"
+
+ABSL_FLAG(bool, dcsctp_capture_packets, false, "Print packet capture.");
+
+namespace dcsctp {
+namespace {
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::ElementsAre;
+using ::testing::HasSubstr;
+using ::testing::IsEmpty;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+
+constexpr SendOptions kSendOptions;
+constexpr size_t kLargeMessageSize = DcSctpOptions::kMaxSafeMTUSize * 20;
+constexpr size_t kSmallMessageSize = 10;
+constexpr int kMaxBurstPackets = 4;
+
+MATCHER_P(HasDataChunkWithStreamId, stream_id, "") {
+ absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
+ if (!packet.has_value()) {
+ *result_listener << "data didn't parse as an SctpPacket";
+ return false;
+ }
+
+ if (packet->descriptors()[0].type != DataChunk::kType) {
+ *result_listener << "the first chunk in the packet is not a data chunk";
+ return false;
+ }
+
+ absl::optional<DataChunk> dc =
+ DataChunk::Parse(packet->descriptors()[0].data);
+ if (!dc.has_value()) {
+ *result_listener << "The first chunk didn't parse as a data chunk";
+ return false;
+ }
+
+ if (dc->stream_id() != stream_id) {
+ *result_listener << "the stream_id is " << *dc->stream_id();
+ return false;
+ }
+
+ return true;
+}
+
+MATCHER_P(HasDataChunkWithPPID, ppid, "") {
+ absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
+ if (!packet.has_value()) {
+ *result_listener << "data didn't parse as an SctpPacket";
+ return false;
+ }
+
+ if (packet->descriptors()[0].type != DataChunk::kType) {
+ *result_listener << "the first chunk in the packet is not a data chunk";
+ return false;
+ }
+
+ absl::optional<DataChunk> dc =
+ DataChunk::Parse(packet->descriptors()[0].data);
+ if (!dc.has_value()) {
+ *result_listener << "The first chunk didn't parse as a data chunk";
+ return false;
+ }
+
+ if (dc->ppid() != ppid) {
+ *result_listener << "the ppid is " << *dc->ppid();
+ return false;
+ }
+
+ return true;
+}
+
+MATCHER_P(HasDataChunkWithSsn, ssn, "") {
+ absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
+ if (!packet.has_value()) {
+ *result_listener << "data didn't parse as an SctpPacket";
+ return false;
+ }
+
+ if (packet->descriptors()[0].type != DataChunk::kType) {
+ *result_listener << "the first chunk in the packet is not a data chunk";
+ return false;
+ }
+
+ absl::optional<DataChunk> dc =
+ DataChunk::Parse(packet->descriptors()[0].data);
+ if (!dc.has_value()) {
+ *result_listener << "The first chunk didn't parse as a data chunk";
+ return false;
+ }
+
+ if (dc->ssn() != ssn) {
+ *result_listener << "the ssn is " << *dc->ssn();
+ return false;
+ }
+
+ return true;
+}
+
+MATCHER_P(HasDataChunkWithMid, mid, "") {
+ absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
+ if (!packet.has_value()) {
+ *result_listener << "data didn't parse as an SctpPacket";
+ return false;
+ }
+
+ if (packet->descriptors()[0].type != IDataChunk::kType) {
+ *result_listener << "the first chunk in the packet is not an i-data chunk";
+ return false;
+ }
+
+ absl::optional<IDataChunk> dc =
+ IDataChunk::Parse(packet->descriptors()[0].data);
+ if (!dc.has_value()) {
+ *result_listener << "The first chunk didn't parse as an i-data chunk";
+ return false;
+ }
+
+ if (dc->message_id() != mid) {
+ *result_listener << "the mid is " << *dc->message_id();
+ return false;
+ }
+
+ return true;
+}
+
+MATCHER_P(HasSackWithCumAckTsn, tsn, "") {
+ absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
+ if (!packet.has_value()) {
+ *result_listener << "data didn't parse as an SctpPacket";
+ return false;
+ }
+
+ if (packet->descriptors()[0].type != SackChunk::kType) {
+ *result_listener << "the first chunk in the packet is not a data chunk";
+ return false;
+ }
+
+ absl::optional<SackChunk> sc =
+ SackChunk::Parse(packet->descriptors()[0].data);
+ if (!sc.has_value()) {
+ *result_listener << "The first chunk didn't parse as a data chunk";
+ return false;
+ }
+
+ if (sc->cumulative_tsn_ack() != tsn) {
+ *result_listener << "the cum_ack_tsn is " << *sc->cumulative_tsn_ack();
+ return false;
+ }
+
+ return true;
+}
+
+MATCHER(HasSackWithNoGapAckBlocks, "") {
+ absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
+ if (!packet.has_value()) {
+ *result_listener << "data didn't parse as an SctpPacket";
+ return false;
+ }
+
+ if (packet->descriptors()[0].type != SackChunk::kType) {
+ *result_listener << "the first chunk in the packet is not a data chunk";
+ return false;
+ }
+
+ absl::optional<SackChunk> sc =
+ SackChunk::Parse(packet->descriptors()[0].data);
+ if (!sc.has_value()) {
+ *result_listener << "The first chunk didn't parse as a data chunk";
+ return false;
+ }
+
+ if (!sc->gap_ack_blocks().empty()) {
+ *result_listener << "there are gap ack blocks";
+ return false;
+ }
+
+ return true;
+}
+
+MATCHER_P(HasReconfigWithStreams, streams_matcher, "") {
+ absl::optional<SctpPacket> packet = SctpPacket::Parse(arg);
+ if (!packet.has_value()) {
+ *result_listener << "data didn't parse as an SctpPacket";
+ return false;
+ }
+
+ if (packet->descriptors()[0].type != ReConfigChunk::kType) {
+ *result_listener << "the first chunk in the packet is not a data chunk";
+ return false;
+ }
+
+ absl::optional<ReConfigChunk> reconfig =
+ ReConfigChunk::Parse(packet->descriptors()[0].data);
+ if (!reconfig.has_value()) {
+ *result_listener << "The first chunk didn't parse as a data chunk";
+ return false;
+ }
+
+ const Parameters& parameters = reconfig->parameters();
+ if (parameters.descriptors().size() != 1 ||
+ parameters.descriptors()[0].type !=
+ OutgoingSSNResetRequestParameter::kType) {
+ *result_listener << "Expected the reconfig chunk to have an outgoing SSN "
+ "reset request parameter";
+ return false;
+ }
+
+ absl::optional<OutgoingSSNResetRequestParameter> p =
+ OutgoingSSNResetRequestParameter::Parse(parameters.descriptors()[0].data);
+ testing::Matcher<rtc::ArrayView<const StreamID>> matcher = streams_matcher;
+ if (!matcher.MatchAndExplain(p->stream_ids(), result_listener)) {
+ return false;
+ }
+
+ return true;
+}
+
+TSN AddTo(TSN tsn, int delta) {
+ return TSN(*tsn + delta);
+}
+
+DcSctpOptions FixupOptions(DcSctpOptions options = {}) {
+ DcSctpOptions fixup = options;
+ // To make the interval more predictable in tests.
+ fixup.heartbeat_interval_include_rtt = false;
+ fixup.max_burst = kMaxBurstPackets;
+ return fixup;
+}
+
+std::unique_ptr<PacketObserver> GetPacketObserver(absl::string_view name) {
+ if (absl::GetFlag(FLAGS_dcsctp_capture_packets)) {
+ return std::make_unique<TextPcapPacketObserver>(name);
+ }
+ return nullptr;
+}
+
+struct SocketUnderTest {
+ explicit SocketUnderTest(absl::string_view name,
+ const DcSctpOptions& opts = {})
+ : options(FixupOptions(opts)),
+ cb(name),
+ socket(name, cb, GetPacketObserver(name), options) {}
+
+ const DcSctpOptions options;
+ testing::NiceMock<MockDcSctpSocketCallbacks> cb;
+ DcSctpSocket socket;
+};
+
+void ExchangeMessages(SocketUnderTest& a, SocketUnderTest& z) {
+ bool delivered_packet = false;
+ do {
+ delivered_packet = false;
+ std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
+ if (!packet_from_a.empty()) {
+ delivered_packet = true;
+ z.socket.ReceivePacket(std::move(packet_from_a));
+ }
+ std::vector<uint8_t> packet_from_z = z.cb.ConsumeSentPacket();
+ if (!packet_from_z.empty()) {
+ delivered_packet = true;
+ a.socket.ReceivePacket(std::move(packet_from_z));
+ }
+ } while (delivered_packet);
+}
+
+void RunTimers(SocketUnderTest& s) {
+ for (;;) {
+ absl::optional<TimeoutID> timeout_id = s.cb.GetNextExpiredTimeout();
+ if (!timeout_id.has_value()) {
+ break;
+ }
+ s.socket.HandleTimeout(*timeout_id);
+ }
+}
+
+void AdvanceTime(SocketUnderTest& a, SocketUnderTest& z, DurationMs duration) {
+ a.cb.AdvanceTime(duration);
+ z.cb.AdvanceTime(duration);
+
+ RunTimers(a);
+ RunTimers(z);
+}
+
+// Calls Connect() on `sock_a_` and make the connection established.
+void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) {
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+
+ a.socket.Connect();
+ // Z reads INIT, INIT_ACK, COOKIE_ECHO, COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+std::unique_ptr<SocketUnderTest> HandoverSocket(
+ std::unique_ptr<SocketUnderTest> sut) {
+ EXPECT_EQ(sut->socket.GetHandoverReadiness(), HandoverReadinessStatus());
+
+ bool is_closed = sut->socket.state() == SocketState::kClosed;
+ if (!is_closed) {
+ EXPECT_CALL(sut->cb, OnClosed).Times(1);
+ }
+ absl::optional<DcSctpSocketHandoverState> handover_state =
+ sut->socket.GetHandoverStateAndClose();
+ EXPECT_TRUE(handover_state.has_value());
+ g_handover_state_transformer_for_test(&*handover_state);
+
+ auto handover_socket = std::make_unique<SocketUnderTest>("H", sut->options);
+ if (!is_closed) {
+ EXPECT_CALL(handover_socket->cb, OnConnected).Times(1);
+ }
+ handover_socket->socket.RestoreFromState(*handover_state);
+ return handover_socket;
+}
+
+std::vector<uint32_t> GetReceivedMessagePpids(SocketUnderTest& z) {
+ std::vector<uint32_t> ppids;
+ for (;;) {
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ if (!msg.has_value()) {
+ break;
+ }
+ ppids.push_back(*msg->ppid());
+ }
+ return ppids;
+}
+
+// Test parameter that controls whether to perform handovers during the test. A
+// test can have multiple points where it conditionally hands over socket Z.
+// Either socket Z will be handed over at all those points or handed over never.
+enum class HandoverMode {
+ kNoHandover,
+ kPerformHandovers,
+};
+
+class DcSctpSocketParametrizedTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<HandoverMode> {
+ protected:
+ // Trigger handover for `sut` depending on the current test param.
+ std::unique_ptr<SocketUnderTest> MaybeHandoverSocket(
+ std::unique_ptr<SocketUnderTest> sut) {
+ if (GetParam() == HandoverMode::kPerformHandovers) {
+ return HandoverSocket(std::move(sut));
+ }
+ return sut;
+ }
+
+ // Trigger handover for socket Z depending on the current test param.
+ // Then checks message passing to verify the handed over socket is functional.
+ void MaybeHandoverSocketAndSendMessage(SocketUnderTest& a,
+ std::unique_ptr<SocketUnderTest> z) {
+ if (GetParam() == HandoverMode::kPerformHandovers) {
+ z = HandoverSocket(std::move(z));
+ }
+
+ ExchangeMessages(a, *z);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ ExchangeMessages(a, *z);
+
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ }
+};
+
+INSTANTIATE_TEST_SUITE_P(Handovers,
+ DcSctpSocketParametrizedTest,
+ testing::Values(HandoverMode::kNoHandover,
+ HandoverMode::kPerformHandovers),
+ [](const auto& test_info) {
+ return test_info.param ==
+ HandoverMode::kPerformHandovers
+ ? "WithHandovers"
+ : "NoHandover";
+ });
+
+TEST(DcSctpSocketTest, EstablishConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
+ EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
+
+ a.socket.Connect();
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, EstablishConnectionWithSetupCollision) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
+ EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
+ a.socket.Connect();
+ z.socket.Connect();
+
+ ExchangeMessages(a, z);
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, ShuttingDownWhileEstablishingConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(0);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ a.socket.Connect();
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Drop COOKIE_ACK, just to more easily verify shutdown protocol.
+ z.cb.ConsumeSentPacket();
+
+ // As Socket A has received INIT_ACK, it has a TCB and is connected, while
+ // Socket Z needs to receive COOKIE_ECHO to get there. Socket A still has
+ // timers running at this point.
+ EXPECT_EQ(a.socket.state(), SocketState::kConnecting);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+
+ // Socket A is now shut down, which should make it stop those timers.
+ a.socket.Shutdown();
+
+ EXPECT_CALL(a.cb, OnClosed).Times(1);
+ EXPECT_CALL(z.cb, OnClosed).Times(1);
+
+ // Z reads SHUTDOWN, produces SHUTDOWN_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads SHUTDOWN_COMPLETE.
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
+ EXPECT_TRUE(z.cb.ConsumeSentPacket().empty());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+ EXPECT_EQ(z.socket.state(), SocketState::kClosed);
+}
+
+TEST(DcSctpSocketTest, EstablishSimultaneousConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
+ EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
+ a.socket.Connect();
+
+ // INIT isn't received by Z, as it wasn't ready yet.
+ a.cb.ConsumeSentPacket();
+
+ z.socket.Connect();
+
+ // A reads INIT, produces INIT_ACK
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // Z reads INIT_ACK, sends COOKIE_ECHO
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // A reads COOKIE_ECHO - establishes connection.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+
+ // Proceed with the remaining packets.
+ ExchangeMessages(a, z);
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, EstablishConnectionLostCookieAck) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0);
+ EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0);
+
+ a.socket.Connect();
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // COOKIE_ACK is lost.
+ z.cb.ConsumeSentPacket();
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnecting);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+
+ // This will make A re-send the COOKIE_ECHO
+ AdvanceTime(a, z, DurationMs(a.options.t1_cookie_timeout));
+
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, ResendInitAndEstablishConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+ // INIT is never received by Z.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType);
+
+ AdvanceTime(a, z, a.options.t1_init_timeout);
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, ResendingInitTooManyTimesAborts) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+
+ // INIT is never received by Z.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType);
+
+ for (int i = 0; i < *a.options.max_init_retransmits; ++i) {
+ AdvanceTime(a, z, a.options.t1_init_timeout * (1 << i));
+
+ // INIT is resent
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(resent_init_packet.descriptors()[0].type, InitChunk::kType);
+ }
+
+ // Another timeout, after the max init retransmits.
+ EXPECT_CALL(a.cb, OnAborted).Times(1);
+ AdvanceTime(
+ a, z, a.options.t1_init_timeout * (1 << *a.options.max_init_retransmits));
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+}
+
+TEST(DcSctpSocketTest, ResendCookieEchoAndEstablishConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // COOKIE_ECHO is never received by Z.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType);
+
+ AdvanceTime(a, z, a.options.t1_init_timeout);
+
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+}
+
+TEST(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // COOKIE_ECHO is never received by Z.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType);
+
+ for (int i = 0; i < *a.options.max_init_retransmits; ++i) {
+ AdvanceTime(a, z, a.options.t1_cookie_timeout * (1 << i));
+
+ // COOKIE_ECHO is resent
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(resent_init_packet.descriptors()[0].type, CookieEchoChunk::kType);
+ }
+
+ // Another timeout, after the max init retransmits.
+ EXPECT_CALL(a.cb, OnAborted).Times(1);
+ AdvanceTime(
+ a, z,
+ a.options.t1_cookie_timeout * (1 << *a.options.max_init_retransmits));
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+}
+
+TEST(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Connect();
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // COOKIE_ECHO is never received by Z.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet1,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_THAT(cookie_echo_packet1.descriptors(), SizeIs(2));
+ EXPECT_EQ(cookie_echo_packet1.descriptors()[0].type, CookieEchoChunk::kType);
+ EXPECT_EQ(cookie_echo_packet1.descriptors()[1].type, DataChunk::kType);
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ // There are DATA chunks in the sent packet (that was lost), which means that
+ // the T3-RTX timer is running, but as the socket is in kCookieEcho state, it
+ // will be T1-COOKIE that drives retransmissions, so when the T3-RTX expires,
+ // nothing should be retransmitted.
+ ASSERT_TRUE(a.options.rto_initial < a.options.t1_cookie_timeout);
+ AdvanceTime(a, z, a.options.rto_initial);
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ // When T1-COOKIE expires, both the COOKIE-ECHO and DATA should be present.
+ AdvanceTime(a, z, a.options.t1_cookie_timeout - a.options.rto_initial);
+
+ // And this COOKIE-ECHO and DATA is also lost - never received by Z.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet2,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_THAT(cookie_echo_packet2.descriptors(), SizeIs(2));
+ EXPECT_EQ(cookie_echo_packet2.descriptors()[0].type, CookieEchoChunk::kType);
+ EXPECT_EQ(cookie_echo_packet2.descriptors()[1].type, DataChunk::kType);
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ // COOKIE_ECHO has exponential backoff.
+ AdvanceTime(a, z, a.options.t1_cookie_timeout * 2);
+
+ // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+
+ ExchangeMessages(a, z);
+ EXPECT_THAT(z.cb.ConsumeReceivedMessage()->payload(),
+ SizeIs(kLargeMessageSize));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ RTC_LOG(LS_INFO) << "Shutting down";
+
+ EXPECT_CALL(z->cb, OnClosed).Times(1);
+ a.socket.Shutdown();
+ // Z reads SHUTDOWN, produces SHUTDOWN_ACK
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+ // Z reads SHUTDOWN_COMPLETE.
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+ EXPECT_EQ(z->socket.state(), SocketState::kClosed);
+
+ z = MaybeHandoverSocket(std::move(z));
+ EXPECT_EQ(z->socket.state(), SocketState::kClosed);
+}
+
+TEST(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ a.socket.Shutdown();
+ // Drop first SHUTDOWN packet.
+ a.cb.ConsumeSentPacket();
+
+ EXPECT_EQ(a.socket.state(), SocketState::kShuttingDown);
+
+ for (int i = 0; i < *a.options.max_retransmissions; ++i) {
+ AdvanceTime(a, z, DurationMs(a.options.rto_initial * (1 << i)));
+
+ // Dropping every shutdown chunk.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(packet.descriptors()[0].type, ShutdownChunk::kType);
+ EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
+ }
+ // The last expiry, makes it abort the connection.
+ EXPECT_CALL(a.cb, OnAborted).Times(1);
+ AdvanceTime(a, z,
+ a.options.rto_initial * (1 << *a.options.max_retransmissions));
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(packet.descriptors()[0].type, AbortChunk::kType);
+ EXPECT_TRUE(a.cb.ConsumeSentPacket().empty());
+}
+
+TEST(DcSctpSocketTest, EstablishConnectionWhileSendingData) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+
+ // Z reads INIT, produces INIT_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // // A reads INIT_ACK, produces COOKIE_ECHO
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ // // Z reads COOKIE_ECHO, produces COOKIE_ACK
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // // A reads COOKIE_ACK.
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ EXPECT_EQ(a.socket.state(), SocketState::kConnected);
+ EXPECT_EQ(z.socket.state(), SocketState::kConnected);
+
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+}
+
+TEST(DcSctpSocketTest, SendMessageAfterEstablished) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ a.cb.ConsumeSentPacket();
+
+ RTC_LOG(LS_INFO) << "Advancing time";
+ AdvanceTime(a, *z, a.options.rto_initial);
+
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ std::vector<uint8_t> payload(kLargeMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+ // First DATA
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Second DATA (lost)
+ a.cb.ConsumeSentPacket();
+
+ // Retransmit and handle the rest
+ ExchangeMessages(a, *z);
+
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // Inject a HEARTBEAT chunk
+ SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions());
+ uint8_t info[] = {1, 2, 3, 4};
+ Parameters::Builder params_builder;
+ params_builder.Add(HeartbeatInfoParameter(info));
+ b.Add(HeartbeatRequestChunk(params_builder.Build()));
+ a.socket.ReceivePacket(b.Build());
+
+ // HEARTBEAT_ACK is sent as a reply. Capture it.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket ack_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ ASSERT_THAT(ack_packet.descriptors(), SizeIs(1));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ HeartbeatAckChunk ack,
+ HeartbeatAckChunk::Parse(ack_packet.descriptors()[0].data));
+ ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, ack.info());
+ EXPECT_THAT(info_param.info(), ElementsAre(1, 2, 3, 4));
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ AdvanceTime(a, *z, a.options.heartbeat_interval);
+
+ std::vector<uint8_t> hb_packet_raw = a.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet,
+ SctpPacket::Parse(hb_packet_raw));
+ ASSERT_THAT(hb_packet.descriptors(), SizeIs(1));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ HeartbeatRequestChunk hb,
+ HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data));
+ ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, hb.info());
+
+ // The info is a single 64-bit number.
+ EXPECT_THAT(hb.info()->info(), SizeIs(8));
+
+ // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back.
+ z->socket.ReceivePacket(hb_packet_raw);
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ CloseConnectionAfterTooManyLostHeartbeats) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(z->cb, OnClosed).Times(1);
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty());
+ // Force-close socket Z so that it doesn't interfere from now on.
+ z->socket.Close();
+
+ DurationMs time_to_next_hearbeat = a.options.heartbeat_interval;
+
+ for (int i = 0; i < *a.options.max_retransmissions; ++i) {
+ RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending...";
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ // Dropping every heartbeat.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(hb_packet.descriptors()[0].type, HeartbeatRequestChunk::kType);
+
+ RTC_LOG(LS_INFO) << "Letting the heartbeat expire.";
+ AdvanceTime(a, *z, DurationMs(1000));
+
+ time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000);
+ }
+
+ RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending...";
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ // Last heartbeat
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), Not(IsEmpty()));
+
+ EXPECT_CALL(a.cb, OnAborted).Times(1);
+ // Should suffice as exceeding RTO
+ AdvanceTime(a, *z, DurationMs(1000));
+
+ z = MaybeHandoverSocket(std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty());
+ EXPECT_CALL(z->cb, OnClosed).Times(1);
+ // Force-close socket Z so that it doesn't interfere from now on.
+ z->socket.Close();
+
+ DurationMs time_to_next_hearbeat = a.options.heartbeat_interval;
+
+ for (int i = 0; i < *a.options.max_retransmissions; ++i) {
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ // Dropping every heartbeat.
+ a.cb.ConsumeSentPacket();
+
+ RTC_LOG(LS_INFO) << "Letting the heartbeat expire.";
+ AdvanceTime(a, *z, DurationMs(1000));
+
+ time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000);
+ }
+
+ RTC_LOG(LS_INFO) << "Getting the last heartbeat - and acking it";
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ std::vector<uint8_t> hb_packet_raw = a.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet,
+ SctpPacket::Parse(hb_packet_raw));
+ ASSERT_THAT(hb_packet.descriptors(), SizeIs(1));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ HeartbeatRequestChunk hb,
+ HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data));
+
+ SctpPacket::Builder b(a.socket.verification_tag(), a.options);
+ b.Add(HeartbeatAckChunk(std::move(hb).extract_parameters()));
+ a.socket.ReceivePacket(b.Build());
+
+ // Should suffice as exceeding RTO - which will not fire.
+ EXPECT_CALL(a.cb, OnAborted).Times(0);
+ AdvanceTime(a, *z, DurationMs(1000));
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ // Verify that we get new heartbeats again.
+ RTC_LOG(LS_INFO) << "Expecting a new heartbeat";
+ AdvanceTime(a, *z, time_to_next_hearbeat);
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket another_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ EXPECT_EQ(another_packet.descriptors()[0].type, HeartbeatRequestChunk::kType);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ResetStream) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {});
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Reset the outgoing stream. This will directly send a RE-CONFIG.
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+
+ // Receiving the packet will trigger a callback, indicating that A has
+ // reset its stream. It will also send a RE-CONFIG with a response.
+ EXPECT_CALL(z->cb, OnIncomingStreamsReset).Times(1);
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // Receiving a response will trigger a callback. Streams are now reset.
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed).Times(1);
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ auto packet1 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0)));
+ z->socket.ReceivePacket(packet1);
+
+ auto packet2 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1)));
+ z->socket.ReceivePacket(packet2);
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->stream_id(), StreamID(1));
+
+ absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->stream_id(), StreamID(1));
+
+ // Reset the outgoing stream. This will directly send a RE-CONFIG.
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+ // RE-CONFIG, req
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // RE-CONFIG, resp
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ auto packet3 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0)));
+ z->socket.ReceivePacket(packet3);
+
+ auto packet4 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1)));
+ z->socket.ReceivePacket(packet4);
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ ResetStreamWillOnlyResetTheRequestedStreams) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+
+ // Send two ordered messages on SID 1
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ auto packet1 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1)));
+ EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0)));
+ z->socket.ReceivePacket(packet1);
+
+ auto packet2 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1)));
+ EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1)));
+ z->socket.ReceivePacket(packet2);
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Do the same, for SID 3
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
+ auto packet3 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet3, HasDataChunkWithStreamId(StreamID(3)));
+ EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0)));
+ z->socket.ReceivePacket(packet3);
+ auto packet4 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet4, HasDataChunkWithStreamId(StreamID(3)));
+ EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1)));
+ z->socket.ReceivePacket(packet4);
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Receive all messages.
+ absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->stream_id(), StreamID(1));
+
+ absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->stream_id(), StreamID(1));
+
+ absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg3.has_value());
+ EXPECT_EQ(msg3->stream_id(), StreamID(3));
+
+ absl::optional<DcSctpMessage> msg4 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg4.has_value());
+ EXPECT_EQ(msg4->stream_id(), StreamID(3));
+
+ // Reset SID 1. This will directly send a RE-CONFIG.
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
+ // RE-CONFIG, req
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // RE-CONFIG, resp
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Send a message on SID 1 and 3 - SID 1 should not be reset, but 3 should.
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {});
+
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {});
+
+ auto packet5 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet5, HasDataChunkWithStreamId(StreamID(1)));
+ EXPECT_THAT(packet5, HasDataChunkWithSsn(SSN(2))); // Unchanged.
+ z->socket.ReceivePacket(packet5);
+
+ auto packet6 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet6, HasDataChunkWithStreamId(StreamID(3)));
+ EXPECT_THAT(packet6, HasDataChunkWithSsn(SSN(0))); // Reset.
+ z->socket.ReceivePacket(packet6);
+
+ // Handle SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnConnectionRestarted).Times(1);
+ // Let's be evil here - reconnect while a fragmented packet was about to be
+ // sent. The receiving side should get it in full.
+ std::vector<uint8_t> payload(kLargeMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+ // First DATA
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // Create a new association, z2 - and don't use z anymore.
+ SocketUnderTest z2("Z2");
+ z2.socket.Connect();
+
+ // Retransmit and handle the rest. As there will be some chunks in-flight that
+ // have the wrong verification tag, those will yield errors.
+ ExchangeMessages(a, z2);
+
+ absl::optional<DcSctpMessage> msg = z2.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ SendOptions send_options;
+ send_options.max_retransmissions = 0;
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options);
+
+ // First DATA
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Second DATA (lost)
+ a.cb.ConsumeSentPacket();
+ // Third DATA
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // Handle SACK for first DATA
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Handle delayed SACK for third DATA
+ AdvanceTime(a, *z, a.options.delayed_ack_max_timeout);
+
+ // Handle SACK for second DATA
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ // Now the missing data chunk will be marked as nacked, but it might still be
+ // in-flight and the reported gap could be due to out-of-order delivery. So
+ // the RetransmissionQueue will not mark it as "to be retransmitted" until
+ // after the t3-rtx timer has expired.
+ AdvanceTime(a, *z, a.options.rto_initial);
+
+ // The chunk will be marked as retransmitted, and then as abandoned, which
+ // will trigger a FORWARD-TSN to be sent.
+
+ // FORWARD-TSN (third)
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ // Which will trigger a SACK
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket());
+
+ absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->ppid(), PPID(51));
+
+ absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->ppid(), PPID(53));
+
+ absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage();
+ EXPECT_FALSE(msg3.has_value());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ SendOptions send_options;
+ send_options.unordered = IsUnordered(true);
+ send_options.max_retransmissions = 0;
+ std::vector<uint8_t> payload(a.options.mtu * 2 - 100 /* margin */);
+ // Sending first message
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
+ // Sending second message
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
+ // Sending third message
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options);
+ // Sending fourth message
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(54), payload), send_options);
+
+ // First DATA, first fragment
+ std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51)));
+ z->socket.ReceivePacket(std::move(packet));
+
+ // First DATA, second fragment (lost)
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51)));
+
+ // Second DATA, first fragment
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52)));
+ z->socket.ReceivePacket(std::move(packet));
+
+ // Second DATA, second fragment (lost)
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52)));
+ EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
+
+ // Third DATA, first fragment
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53)));
+ EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
+ z->socket.ReceivePacket(std::move(packet));
+
+ // Third DATA, second fragment (lost)
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53)));
+ EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
+
+ // Fourth DATA, first fragment
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54)));
+ EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
+ z->socket.ReceivePacket(std::move(packet));
+
+ // Fourth DATA, second fragment
+ packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54)));
+ EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0)));
+ z->socket.ReceivePacket(std::move(packet));
+
+ ExchangeMessages(a, *z);
+
+ // Let the RTX timer expire, and exchange FORWARD-TSN/SACKs
+ AdvanceTime(a, *z, a.options.rto_initial);
+
+ ExchangeMessages(a, *z);
+
+ absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->ppid(), PPID(54));
+
+ ASSERT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+struct FakeChunkConfig : ChunkConfig {
+ static constexpr int kType = 0x49;
+ static constexpr size_t kHeaderSize = 4;
+ static constexpr int kVariableLengthAlignment = 0;
+};
+
+class FakeChunk : public Chunk, public TLVTrait<FakeChunkConfig> {
+ public:
+ FakeChunk() {}
+
+ FakeChunk(FakeChunk&& other) = default;
+ FakeChunk& operator=(FakeChunk&& other) = default;
+
+ void SerializeTo(std::vector<uint8_t>& out) const override {
+ AllocateTLV(out);
+ }
+ std::string ToString() const override { return "FAKE"; }
+};
+
+TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // Inject a FAKE chunk
+ SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions());
+ b.Add(FakeChunk());
+ a.socket.ReceivePacket(b.Build());
+
+ // ERROR is sent as a reply. Capture it.
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket reply_packet,
+ SctpPacket::Parse(a.cb.ConsumeSentPacket()));
+ ASSERT_THAT(reply_packet.descriptors(), SizeIs(1));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ ErrorChunk error, ErrorChunk::Parse(reply_packet.descriptors()[0].data));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ UnrecognizedChunkTypeCause cause,
+ error.error_causes().get<UnrecognizedChunkTypeCause>());
+ EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04));
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // Inject a ERROR chunk
+ SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions());
+ b.Add(
+ ErrorChunk(Parameters::Builder()
+ .Add(UnrecognizedChunkTypeCause({0x49, 0x00, 0x00, 0x04}))
+ .Build()));
+
+ EXPECT_CALL(a.cb, OnError(ErrorKind::kPeerReported,
+ HasSubstr("Unrecognized Chunk Type")));
+ a.socket.ReceivePacket(b.Build());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) {
+ SocketUnderTest a("A");
+
+ constexpr size_t kReceiveWindowBufferSize = 2000;
+ SocketUnderTest z(
+ "Z", {.mtu = 3000,
+ .max_receiver_window_buffer_size = kReceiveWindowBufferSize});
+
+ EXPECT_CALL(z.cb, OnClosed).Times(0);
+ EXPECT_CALL(z.cb, OnAborted).Times(0);
+
+ a.socket.Connect();
+ std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
+ SctpPacket::Parse(init_data));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ InitChunk init_chunk,
+ InitChunk::Parse(init_packet.descriptors()[0].data));
+ z.socket.ReceivePacket(init_data);
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // Fill up Z2 to the high watermark limit.
+ constexpr size_t kWatermarkLimit =
+ kReceiveWindowBufferSize * ReassemblyQueue::kHighWatermarkLimit;
+ constexpr size_t kRemainingSize = kReceiveWindowBufferSize - kWatermarkLimit;
+
+ TSN tsn = init_chunk.initial_tsn();
+ AnyDataChunk::Options opts;
+ opts.is_beginning = Data::IsBeginning(true);
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(kWatermarkLimit + 1), opts))
+ .Build());
+
+ // First DATA will always trigger a SACK. It's not interesting.
+ EXPECT_THAT(z.cb.ConsumeSentPacket(),
+ AllOf(HasSackWithCumAckTsn(tsn), HasSackWithNoGapAckBlocks()));
+
+ // This DATA should be accepted - it's advancing cum ack tsn.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 1), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(1),
+ /*options=*/{}))
+ .Build());
+
+ // The receiver might have moved into delayed ack mode.
+ AdvanceTime(a, z, z.options.rto_initial);
+
+ EXPECT_THAT(
+ z.cb.ConsumeSentPacket(),
+ AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks()));
+
+ // This DATA will not be accepted - it's not advancing cum ack tsn.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(1),
+ /*options=*/{}))
+ .Build());
+
+ // Sack will be sent in IMMEDIATE mode when this is happening.
+ EXPECT_THAT(
+ z.cb.ConsumeSentPacket(),
+ AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks()));
+
+ // This DATA will not be accepted either.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 4), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(1),
+ /*options=*/{}))
+ .Build());
+
+ // Sack will be sent in IMMEDIATE mode when this is happening.
+ EXPECT_THAT(
+ z.cb.ConsumeSentPacket(),
+ AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks()));
+
+ // This DATA should be accepted, and it fills the reassembly queue.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 2), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(kRemainingSize),
+ /*options=*/{}))
+ .Build());
+
+ // The receiver might have moved into delayed ack mode.
+ AdvanceTime(a, z, z.options.rto_initial);
+
+ EXPECT_THAT(
+ z.cb.ConsumeSentPacket(),
+ AllOf(HasSackWithCumAckTsn(AddTo(tsn, 2)), HasSackWithNoGapAckBlocks()));
+
+ EXPECT_CALL(z.cb, OnAborted(ErrorKind::kResourceExhaustion, _));
+ EXPECT_CALL(z.cb, OnClosed).Times(0);
+
+ // This DATA will make the connection close. It's too full now.
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(kSmallMessageSize),
+ /*options=*/{}))
+ .Build());
+}
+
+TEST(DcSctpSocketTest, SetMaxMessageSize) {
+ SocketUnderTest a("A");
+
+ a.socket.SetMaxMessageSize(42u);
+ EXPECT_EQ(a.socket.options().max_message_size, 42u);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // Mock that the time always goes forward.
+ TimeMs now(0);
+ EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() {
+ now += DurationMs(3);
+ return now;
+ });
+ EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() {
+ now += DurationMs(3);
+ return now;
+ });
+
+ // Queue a few small messages with low lifetime, both ordered and unordered,
+ // and validate that all are delivered.
+ static constexpr int kIterations = 100;
+ for (int i = 0; i < kIterations; ++i) {
+ SendOptions send_options;
+ send_options.unordered = IsUnordered((i % 2) == 0);
+ send_options.lifetime = DurationMs(i % 3); // 0, 1, 2 ms
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
+ }
+
+ ExchangeMessages(a, *z);
+
+ for (int i = 0; i < kIterations; ++i) {
+ EXPECT_TRUE(z->cb.ConsumeReceivedMessage().has_value());
+ }
+
+ EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ // Validate that the sockets really make the time move forward.
+ EXPECT_GE(*now, kIterations * 2);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ DiscardsMessagesWithLowLifetimeIfMustBuffer) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ SendOptions lifetime_0;
+ lifetime_0.unordered = IsUnordered(true);
+ lifetime_0.lifetime = DurationMs(0);
+
+ SendOptions lifetime_1;
+ lifetime_1.unordered = IsUnordered(true);
+ lifetime_1.lifetime = DurationMs(1);
+
+ // Mock that the time always goes forward.
+ TimeMs now(0);
+ EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() {
+ now += DurationMs(3);
+ return now;
+ });
+ EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() {
+ now += DurationMs(3);
+ return now;
+ });
+
+ // Fill up the send buffer with a large message.
+ std::vector<uint8_t> payload(kLargeMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions);
+
+ // And queue a few small messages with lifetime=0 or 1 ms - can't be sent.
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), lifetime_0);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}), lifetime_1);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}), lifetime_0);
+
+ // Handle all that was sent until congestion window got full.
+ for (;;) {
+ std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
+ if (packet_from_a.empty()) {
+ break;
+ }
+ z->socket.ReceivePacket(std::move(packet_from_a));
+ }
+
+ // Shouldn't be enough to send that large message.
+ EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ // Exchange the rest of the messages, with the time ever increasing.
+ ExchangeMessages(a, *z);
+
+ // The large message should be delivered. It was sent reliably.
+ ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage m1, z->cb.ConsumeReceivedMessage());
+ EXPECT_EQ(m1.stream_id(), StreamID(1));
+ EXPECT_THAT(m1.payload(), SizeIs(kLargeMessageSize));
+
+ // But none of the smaller messages.
+ EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ // Sending a small message will directly send it as a single packet, so
+ // nothing is left in the queue.
+ EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ // Sending a message will directly start sending a few packets, so the
+ // buffered amount is not the full message size.
+ EXPECT_GT(a.socket.buffered_amount(StreamID(1)), 0u);
+ EXPECT_LT(a.socket.buffered_amount(StreamID(1)), kLargeMessageSize);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) {
+ SocketUnderTest a("A");
+ EXPECT_EQ(a.socket.buffered_amount_low_threshold(StreamID(1)), 0u);
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ TriggersOnBufferedAmountLowWithDefaultValueZero) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1)));
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).WillRepeatedly(testing::Return());
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ DoesntTriggerOnBufferedAmountLowIfBelowThreshold) {
+ static constexpr size_t kMessageSize = 1000;
+ static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10;
+
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ a.socket.SetBufferedAmountLowThreshold(StreamID(1),
+ kBufferedAmountLowThreshold);
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(0);
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) {
+ static constexpr size_t kMessageSize = 1000;
+ static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2;
+
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ a.socket.SetBufferedAmountLowThreshold(StreamID(1),
+ kBufferedAmountLowThreshold);
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(3);
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(2))).Times(2);
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) {
+ static constexpr size_t kMessageSize = 1000;
+ static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5;
+
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ a.socket.SetBufferedAmountLowThreshold(StreamID(1),
+ kBufferedAmountLowThreshold);
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0);
+
+ // Add a few messages to fill up the congestion window. When that is full,
+ // messages will start to be fully buffered.
+ while (a.socket.buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) {
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+ }
+ size_t initial_buffered = a.socket.buffered_amount(StreamID(1));
+ ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold);
+
+ // Start ACKing packets, which will empty the send queue, and trigger the
+ // callback.
+ EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(1);
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ DoesntTriggerOnTotalBufferAmountLowWhenBelow) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest,
+ TriggersOnTotalBufferAmountLowWhenCrossingThreshold) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0);
+
+ // Fill up the send queue completely.
+ for (;;) {
+ if (a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions) == SendStatus::kErrorResourceExhaustion) {
+ break;
+ }
+ }
+
+ EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(1);
+ ExchangeMessages(a, *z);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, InitialMetricsAreUnset) {
+ SocketUnderTest a("A");
+
+ EXPECT_FALSE(a.socket.GetMetrics().has_value());
+}
+
+TEST(DcSctpSocketTest, MessageInterleavingMetricsAreSet) {
+ std::vector<std::pair<bool, bool>> combinations = {
+ {false, false}, {false, true}, {true, false}, {true, true}};
+ for (const auto& [a_enable, z_enable] : combinations) {
+ DcSctpOptions a_options = {.enable_message_interleaving = a_enable};
+ DcSctpOptions z_options = {.enable_message_interleaving = z_enable};
+
+ SocketUnderTest a("A", a_options);
+ SocketUnderTest z("Z", z_options);
+ ConnectSockets(a, z);
+
+ EXPECT_EQ(a.socket.GetMetrics()->uses_message_interleaving,
+ a_enable && z_enable);
+ }
+}
+
+TEST(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ const size_t initial_a_rwnd = a.options.max_receiver_window_buffer_size *
+ ReassemblyQueue::kHighWatermarkLimit;
+
+ EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 2u);
+ EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 2u);
+ EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 0u);
+ EXPECT_EQ(a.socket.GetMetrics()->cwnd_bytes,
+ a.options.cwnd_mtus_initial * a.options.mtu);
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
+
+ EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 2u);
+ EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 0u);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u);
+
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK
+ EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
+
+ EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value());
+
+ EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 3u);
+ EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 3u);
+ EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 1u);
+
+ EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 3u);
+ EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 1u);
+
+ // Send one more (large - fragmented), and receive the delayed SACK.
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(a.options.mtu * 2 + 1)),
+ kSendOptions);
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 3u);
+
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
+
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u);
+ EXPECT_GT(a.socket.GetMetrics()->peer_rwnd_bytes, 0u);
+ EXPECT_LT(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
+
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA
+
+ EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value());
+
+ EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 6u);
+ EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 4u);
+ EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 2u);
+
+ EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 6u);
+ EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 2u);
+
+ // Delayed sack
+ AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
+
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK
+ EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u);
+ EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 5u);
+ EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ size_t payload_bytes =
+ a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
+
+ size_t expected_sent_packets = a.options.cwnd_mtus_initial;
+
+ size_t expected_queued_bytes =
+ kLargeMessageSize - expected_sent_packets * payload_bytes;
+
+ size_t expected_queued_packets = expected_queued_bytes / payload_bytes;
+
+ // Due to alignment, padding etc, it's hard to calculate the exact number, but
+ // it should be in this range.
+ EXPECT_GE(a.socket.GetMetrics()->unack_data_count,
+ expected_sent_packets + expected_queued_packets);
+
+ EXPECT_LE(a.socket.GetMetrics()->unack_data_count,
+ expected_sent_packets + expected_queued_packets + 2);
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ for (int i = 0; i < kMaxBurstPackets; ++i) {
+ std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, Not(IsEmpty()));
+ z->socket.ReceivePacket(std::move(packet)); // DATA
+ }
+
+ EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty());
+
+ ExchangeMessages(a, *z);
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ // A really large message, to ensure that the congestion window is often full.
+ constexpr size_t kMessageSize = 100000;
+ a.socket.Send(
+ DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)),
+ kSendOptions);
+
+ bool delivered_packet = false;
+ std::vector<size_t> data_packet_sizes;
+ do {
+ delivered_packet = false;
+ std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket();
+ if (!packet_from_a.empty()) {
+ data_packet_sizes.push_back(packet_from_a.size());
+ delivered_packet = true;
+ z->socket.ReceivePacket(std::move(packet_from_a));
+ }
+ std::vector<uint8_t> packet_from_z = z->cb.ConsumeSentPacket();
+ if (!packet_from_z.empty()) {
+ delivered_packet = true;
+ a.socket.ReceivePacket(std::move(packet_from_z));
+ }
+ } while (delivered_packet);
+
+ size_t packet_payload_bytes =
+ a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize;
+ // +1 accounts for padding, and rounding up.
+ size_t expected_packets =
+ (kMessageSize + packet_payload_bytes - 1) / packet_payload_bytes + 1;
+ EXPECT_THAT(data_packet_sizes, SizeIs(expected_packets));
+
+ // Remove the last size - it will be the remainder. But all other sizes should
+ // be large.
+ data_packet_sizes.pop_back();
+
+ for (size_t size : data_packet_sizes) {
+ // The 4 is for padding/alignment.
+ EXPECT_GE(size, a.options.mtu - 4);
+ }
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, SendMessagesAfterHandover) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+
+ // Send message before handover to move socket to a not initial state
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ z->cb.ConsumeReceivedMessage();
+
+ z = HandoverSocket(std::move(z));
+
+ absl::optional<DcSctpMessage> msg;
+
+ RTC_LOG(LS_INFO) << "Sending A #1";
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions);
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4));
+
+ RTC_LOG(LS_INFO) << "Sending A #2";
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions);
+ z->socket.ReceivePacket(a.cb.ConsumeSentPacket());
+
+ msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(2));
+ EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6));
+
+ RTC_LOG(LS_INFO) << "Sending Z #1";
+
+ z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions);
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // ack
+ a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // data
+
+ msg = a.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3));
+}
+
+TEST(DcSctpSocketTest, CanDetectDcsctpImplementation) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp);
+
+ // As A initiated the connection establishment, Z will not receive enough
+ // information to know about A's implementation
+ EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kUnknown);
+}
+
+TEST(DcSctpSocketTest, BothCanDetectDcsctpImplementation) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ a.socket.Connect();
+ z.socket.Connect();
+
+ ExchangeMessages(a, z);
+
+ EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp);
+ EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kDcsctp);
+}
+
+TEST_P(DcSctpSocketParametrizedTest, CanLoseFirstOrderedMessage) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ SendOptions send_options;
+ send_options.unordered = IsUnordered(false);
+ send_options.max_retransmissions = 0;
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+
+ // Send a first message (SID=1, SSN=0)
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options);
+
+ // First DATA is lost, and retransmission timer will delete it.
+ a.cb.ConsumeSentPacket();
+ AdvanceTime(a, *z, a.options.rto_initial);
+ ExchangeMessages(a, *z);
+
+ // Send a second message (SID=0, SSN=1).
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options);
+ ExchangeMessages(a, *z);
+
+ // The Z socket should receive the second message, but not the first.
+ absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->ppid(), PPID(52));
+
+ EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
+TEST(DcSctpSocketTest, ReceiveBothUnorderedAndOrderedWithSameTSN) {
+ /* This issue was found by fuzzing. */
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+ std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet,
+ SctpPacket::Parse(init_data));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ InitChunk init_chunk,
+ InitChunk::Parse(init_packet.descriptors()[0].data));
+ z.socket.ReceivePacket(init_data);
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ a.socket.ReceivePacket(z.cb.ConsumeSentPacket());
+
+ // Receive a short unordered message with tsn=INITIAL_TSN+1
+ TSN tsn = init_chunk.initial_tsn();
+ AnyDataChunk::Options opts;
+ opts.is_beginning = Data::IsBeginning(true);
+ opts.is_end = Data::IsEnd(true);
+ opts.is_unordered = IsUnordered(true);
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(10), opts))
+ .Build());
+
+ // Now receive a longer _ordered_ message with [INITIAL_TSN, INITIAL_TSN+1].
+ // This isn't allowed as it reuses TSN=53 with different properties, but it
+ // shouldn't cause any issues.
+ opts.is_unordered = IsUnordered(false);
+ opts.is_end = Data::IsEnd(false);
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(10), opts))
+ .Build());
+
+ opts.is_beginning = Data::IsBeginning(false);
+ opts.is_end = Data::IsEnd(true);
+ z.socket.ReceivePacket(
+ SctpPacket::Builder(z.socket.verification_tag(), z.options)
+ .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53),
+ std::vector<uint8_t>(10), opts))
+ .Build());
+}
+
+TEST(DcSctpSocketTest, CloseTwoStreamsAtTheSameTime) {
+ // Reported as https://crbug.com/1312009.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(2)))).Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(2)))).Times(1);
+
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+ ExchangeMessages(a, z);
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
+
+ ExchangeMessages(a, z);
+}
+
+TEST(DcSctpSocketTest, CloseThreeStreamsAtTheSameTime) {
+ // Similar to CloseTwoStreamsAtTheSameTime, but ensuring that the two
+ // remaining streams are reset at the same time in the second request.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(
+ UnorderedElementsAre(StreamID(2), StreamID(3))))
+ .Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(
+ UnorderedElementsAre(StreamID(2), StreamID(3))))
+ .Times(1);
+
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), kSendOptions);
+
+ ExchangeMessages(a, z);
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
+
+ ExchangeMessages(a, z);
+}
+
+TEST(DcSctpSocketTest, CloseStreamsWithPendingRequest) {
+ // Checks that stream reset requests are properly paused when they can't be
+ // immediately reset - i.e. when there is already an ongoing stream reset
+ // request (and there can only be a single one in-flight).
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(z.cb, OnIncomingStreamsReset(
+ UnorderedElementsAre(StreamID(2), StreamID(3))))
+ .Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1);
+ EXPECT_CALL(a.cb, OnStreamsResetPerformed(
+ UnorderedElementsAre(StreamID(2), StreamID(3))))
+ .Times(1);
+
+ ConnectSockets(a, z);
+
+ SendOptions send_options = {.unordered = IsUnordered(false)};
+
+ // Send a few ordered messages
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options);
+
+ ExchangeMessages(a, z);
+
+ // Receive these messages
+ absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg1.has_value());
+ EXPECT_EQ(msg1->stream_id(), StreamID(1));
+ absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg2.has_value());
+ EXPECT_EQ(msg2->stream_id(), StreamID(2));
+ absl::optional<DcSctpMessage> msg3 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg3.has_value());
+ EXPECT_EQ(msg3->stream_id(), StreamID(3));
+
+ // Reset the streams - not all at once.
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+
+ std::vector<uint8_t> packet = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet, HasReconfigWithStreams(ElementsAre(StreamID(1))));
+ z.socket.ReceivePacket(std::move(packet));
+
+ // Sending more reset requests while this one is ongoing.
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)}));
+
+ ExchangeMessages(a, z);
+
+ // Send a few more ordered messages
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options);
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options);
+
+ ExchangeMessages(a, z);
+
+ // Receive these messages
+ absl::optional<DcSctpMessage> msg4 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg4.has_value());
+ EXPECT_EQ(msg4->stream_id(), StreamID(1));
+ absl::optional<DcSctpMessage> msg5 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg5.has_value());
+ EXPECT_EQ(msg5->stream_id(), StreamID(2));
+ absl::optional<DcSctpMessage> msg6 = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg6.has_value());
+ EXPECT_EQ(msg6->stream_id(), StreamID(3));
+}
+
+TEST(DcSctpSocketTest, StreamsHaveInitialPriority) {
+ DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+ SocketUnderTest a("A", options);
+
+ EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)),
+ options.default_stream_priority);
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+ EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)),
+ options.default_stream_priority);
+}
+
+TEST(DcSctpSocketTest, CanChangeStreamPriority) {
+ DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+ SocketUnderTest a("A", options);
+
+ a.socket.SetStreamPriority(StreamID(1), StreamPriority(43));
+ EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+
+ a.socket.SetStreamPriority(StreamID(2), StreamPriority(43));
+ EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
+}
+
+TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) {
+ DcSctpOptions options = {.default_stream_priority = StreamPriority(42)};
+ auto a = std::make_unique<SocketUnderTest>("A", options);
+ SocketUnderTest z("Z");
+
+ ConnectSockets(*a, z);
+
+ a->socket.SetStreamPriority(StreamID(1), StreamPriority(43));
+ a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions);
+ a->socket.SetStreamPriority(StreamID(2), StreamPriority(43));
+
+ ExchangeMessages(*a, z);
+
+ a = MaybeHandoverSocket(std::move(a));
+
+ EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), StreamPriority(43));
+ EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), StreamPriority(43));
+}
+
+TEST(DcSctpSocketTest, ReconnectSocketWithPendingStreamReset) {
+ // This is an issue found by fuzzing, and doesn't really make sense in WebRTC
+ // data channels as a SCTP connection is never ever closed and then
+ // reconnected. SCTP connections are closed when the peer connection is
+ // deleted, and then it doesn't do more with SCTP.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ ConnectSockets(a, z);
+
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)}));
+
+ EXPECT_CALL(z.cb, OnAborted).Times(1);
+ a.socket.Close();
+
+ EXPECT_EQ(a.socket.state(), SocketState::kClosed);
+
+ EXPECT_CALL(a.cb, OnConnected).Times(1);
+ EXPECT_CALL(z.cb, OnConnected).Times(1);
+ a.socket.Connect();
+ ExchangeMessages(a, z);
+ a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)}));
+}
+
+TEST(DcSctpSocketTest, SmallSentMessagesWithPrioWillArriveInSpecificOrder) {
+ DcSctpOptions options = {.enable_message_interleaving = true};
+ SocketUnderTest a("A", options);
+ SocketUnderTest z("A", options);
+
+ a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
+ a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
+ a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));
+
+ // Enqueue messages before connecting the socket, to ensure they aren't send
+ // as soon as Send() is called.
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(103),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+
+ ConnectSockets(a, z);
+ ExchangeMessages(a, z);
+
+ std::vector<uint32_t> received_ppids;
+ for (;;) {
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ if (!msg.has_value()) {
+ break;
+ }
+ received_ppids.push_back(*msg->ppid());
+ }
+
+ EXPECT_THAT(received_ppids, ElementsAre(101, 102, 103, 201, 301));
+}
+
+TEST(DcSctpSocketTest, LargeSentMessagesWithPrioWillArriveInSpecificOrder) {
+ DcSctpOptions options = {.enable_message_interleaving = true};
+ SocketUnderTest a("A", options);
+ SocketUnderTest z("A", options);
+
+ a.socket.SetStreamPriority(StreamID(1), StreamPriority(700));
+ a.socket.SetStreamPriority(StreamID(2), StreamPriority(200));
+ a.socket.SetStreamPriority(StreamID(3), StreamPriority(100));
+
+ // Enqueue messages before connecting the socket, to ensure they aren't send
+ // as soon as Send() is called.
+ a.socket.Send(DcSctpMessage(StreamID(3), PPID(301),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ ConnectSockets(a, z);
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201, 301));
+}
+
+TEST(DcSctpSocketTest, MessageWithHigherPrioWillInterruptLowerPrioMessage) {
+ DcSctpOptions options = {.enable_message_interleaving = true};
+ SocketUnderTest a("A", options);
+ SocketUnderTest z("Z", options);
+
+ ConnectSockets(a, z);
+
+ a.socket.SetStreamPriority(StreamID(2), StreamPriority(128));
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(201),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ // Due to a non-zero initial congestion window, the message will already start
+ // to send, but will not succeed to be sent completely before filling the
+ // congestion window or stopping due to reaching how many packets that can be
+ // sent at once (max burst). The important thing is that the entire message
+ // doesn't get sent in full.
+
+ // Now enqueue two messages; one small and one large higher priority message.
+ a.socket.SetStreamPriority(StreamID(1), StreamPriority(512));
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(101),
+ std::vector<uint8_t>(kSmallMessageSize)),
+ kSendOptions);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(102),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201));
+}
+
+TEST(DcSctpSocketTest, LifecycleEventsAreGeneratedForAckedMessages) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+ ConnectSockets(a, z);
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(101),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ {.lifecycle_id = LifecycleId(41)});
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(102),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+
+ a.socket.Send(DcSctpMessage(StreamID(2), PPID(103),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ {.lifecycle_id = LifecycleId(42)});
+
+ EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(41)));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(41)));
+ EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(42)));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(42)));
+ ExchangeMessages(a, z);
+ // In case of delayed ack.
+ AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 103));
+}
+
+TEST(DcSctpSocketTest, LifecycleEventsForFailMaxRetransmissions) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+ ConnectSockets(a, z);
+
+ std::vector<uint8_t> payload(a.options.mtu - 100);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
+ {
+ .max_retransmissions = 0,
+ .lifecycle_id = LifecycleId(1),
+ });
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload),
+ {
+ .max_retransmissions = 0,
+ .lifecycle_id = LifecycleId(2),
+ });
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload),
+ {
+ .max_retransmissions = 0,
+ .lifecycle_id = LifecycleId(3),
+ });
+
+ // First DATA
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Second DATA (lost)
+ a.cb.ConsumeSentPacket();
+
+ EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(1)));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
+ EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(2),
+ /*maybe_delivered=*/true));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(2)));
+ EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(3)));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(3)));
+ ExchangeMessages(a, z);
+
+ // Handle delayed SACK.
+ AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
+ ExchangeMessages(a, z);
+
+ // The chunk is now NACKed. Let the RTO expire, to discard the message.
+ AdvanceTime(a, z, a.options.rto_initial);
+ ExchangeMessages(a, z);
+
+ // Handle delayed SACK.
+ AdvanceTime(a, z, a.options.delayed_ack_max_timeout);
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(51, 53));
+}
+
+TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithRetransmitLimit) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+ ConnectSockets(a, z);
+
+ // Will not be able to send it in full within the congestion window, but will
+ // need to wait for SACKs to be received for more fragments to be sent.
+ std::vector<uint8_t> payload(kLargeMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
+ {
+ .max_retransmissions = 0,
+ .lifecycle_id = LifecycleId(1),
+ });
+
+ // First DATA
+ z.socket.ReceivePacket(a.cb.ConsumeSentPacket());
+ // Second DATA (lost)
+ a.cb.ConsumeSentPacket();
+
+ EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1),
+ /*maybe_delivered=*/false));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty());
+}
+
+TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithLifetimeLimit) {
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ // Send it before the socket is connected, to prevent it from being sent too
+ // quickly. The idea is that it should be expired before even attempting to
+ // send it in full.
+ std::vector<uint8_t> payload(kSmallMessageSize);
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload),
+ {
+ .lifetime = DurationMs(100),
+ .lifecycle_id = LifecycleId(1),
+ });
+
+ AdvanceTime(a, z, DurationMs(200));
+
+ EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1),
+ /*maybe_delivered=*/false));
+ EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1)));
+ ConnectSockets(a, z);
+ ExchangeMessages(a, z);
+
+ EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty());
+}
+
+} // namespace
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc
new file mode 100644
index 0000000000..9588b85b59
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc
@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/heartbeat_handler.h"
+
+#include <stddef.h>
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/functional/bind_front.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/packet/bounded_byte_reader.h"
+#include "net/dcsctp/packet/bounded_byte_writer.h"
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h"
+#include "net/dcsctp/packet/parameter/parameter.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/socket/context.h"
+#include "net/dcsctp/timer/timer.h"
+#include "rtc_base/logging.h"
+
+namespace dcsctp {
+
+// This is stored (in serialized form) as HeartbeatInfoParameter sent in
+// HeartbeatRequestChunk and received back in HeartbeatAckChunk. It should be
+// well understood that this data may be modified by the peer, so it can't
+// be trusted.
+//
+// It currently only stores a timestamp, in millisecond precision, to allow for
+// RTT measurements. If that would be manipulated by the peer, it would just
+// result in incorrect RTT measurements, which isn't an issue.
+class HeartbeatInfo {
+ public:
+ static constexpr size_t kBufferSize = sizeof(uint64_t);
+ static_assert(kBufferSize == 8, "Unexpected buffer size");
+
+ explicit HeartbeatInfo(TimeMs created_at) : created_at_(created_at) {}
+
+ std::vector<uint8_t> Serialize() {
+ uint32_t high_bits = static_cast<uint32_t>(*created_at_ >> 32);
+ uint32_t low_bits = static_cast<uint32_t>(*created_at_);
+
+ std::vector<uint8_t> data(kBufferSize);
+ BoundedByteWriter<kBufferSize> writer(data);
+ writer.Store32<0>(high_bits);
+ writer.Store32<4>(low_bits);
+ return data;
+ }
+
+ static absl::optional<HeartbeatInfo> Deserialize(
+ rtc::ArrayView<const uint8_t> data) {
+ if (data.size() != kBufferSize) {
+ RTC_LOG(LS_WARNING) << "Invalid heartbeat info: " << data.size()
+ << " bytes";
+ return absl::nullopt;
+ }
+
+ BoundedByteReader<kBufferSize> reader(data);
+ uint32_t high_bits = reader.Load32<0>();
+ uint32_t low_bits = reader.Load32<4>();
+
+ uint64_t created_at = static_cast<uint64_t>(high_bits) << 32 | low_bits;
+ return HeartbeatInfo(TimeMs(created_at));
+ }
+
+ TimeMs created_at() const { return created_at_; }
+
+ private:
+ const TimeMs created_at_;
+};
+
+HeartbeatHandler::HeartbeatHandler(absl::string_view log_prefix,
+ const DcSctpOptions& options,
+ Context* context,
+ TimerManager* timer_manager)
+ : log_prefix_(std::string(log_prefix) + "heartbeat: "),
+ ctx_(context),
+ timer_manager_(timer_manager),
+ interval_duration_(options.heartbeat_interval),
+ interval_duration_should_include_rtt_(
+ options.heartbeat_interval_include_rtt),
+ interval_timer_(timer_manager_->CreateTimer(
+ "heartbeat-interval",
+ absl::bind_front(&HeartbeatHandler::OnIntervalTimerExpiry, this),
+ TimerOptions(interval_duration_, TimerBackoffAlgorithm::kFixed))),
+ timeout_timer_(timer_manager_->CreateTimer(
+ "heartbeat-timeout",
+ absl::bind_front(&HeartbeatHandler::OnTimeoutTimerExpiry, this),
+ TimerOptions(options.rto_initial,
+ TimerBackoffAlgorithm::kExponential,
+ /*max_restarts=*/0))) {
+ // The interval timer must always be running as long as the association is up.
+ RestartTimer();
+}
+
+void HeartbeatHandler::RestartTimer() {
+ if (interval_duration_ == DurationMs(0)) {
+ // Heartbeating has been disabled.
+ return;
+ }
+
+ if (interval_duration_should_include_rtt_) {
+ // The RTT should be used, but it's not easy accessible. The RTO will
+ // suffice.
+ interval_timer_->set_duration(interval_duration_ + ctx_->current_rto());
+ } else {
+ interval_timer_->set_duration(interval_duration_);
+ }
+
+ interval_timer_->Start();
+}
+
+void HeartbeatHandler::HandleHeartbeatRequest(HeartbeatRequestChunk chunk) {
+ // https://tools.ietf.org/html/rfc4960#section-8.3
+ // "The receiver of the HEARTBEAT should immediately respond with a
+ // HEARTBEAT ACK that contains the Heartbeat Information TLV, together with
+ // any other received TLVs, copied unchanged from the received HEARTBEAT
+ // chunk."
+ ctx_->Send(ctx_->PacketBuilder().Add(
+ HeartbeatAckChunk(std::move(chunk).extract_parameters())));
+}
+
+void HeartbeatHandler::HandleHeartbeatAck(HeartbeatAckChunk chunk) {
+ timeout_timer_->Stop();
+ absl::optional<HeartbeatInfoParameter> info_param = chunk.info();
+ if (!info_param.has_value()) {
+ ctx_->callbacks().OnError(
+ ErrorKind::kParseFailed,
+ "Failed to parse HEARTBEAT-ACK; No Heartbeat Info parameter");
+ return;
+ }
+ absl::optional<HeartbeatInfo> info =
+ HeartbeatInfo::Deserialize(info_param->info());
+ if (!info.has_value()) {
+ ctx_->callbacks().OnError(ErrorKind::kParseFailed,
+ "Failed to parse HEARTBEAT-ACK; Failed to "
+ "deserialized Heartbeat info parameter");
+ return;
+ }
+
+ TimeMs now = ctx_->callbacks().TimeMillis();
+ if (info->created_at() > TimeMs(0) && info->created_at() <= now) {
+ ctx_->ObserveRTT(now - info->created_at());
+ }
+
+ // https://tools.ietf.org/html/rfc4960#section-8.1
+ // "The counter shall be reset each time ... a HEARTBEAT ACK is received from
+ // the peer endpoint."
+ ctx_->ClearTxErrorCounter();
+}
+
+absl::optional<DurationMs> HeartbeatHandler::OnIntervalTimerExpiry() {
+ if (ctx_->is_connection_established()) {
+ HeartbeatInfo info(ctx_->callbacks().TimeMillis());
+ timeout_timer_->set_duration(ctx_->current_rto());
+ timeout_timer_->Start();
+ RTC_DLOG(LS_INFO) << log_prefix_ << "Sending HEARTBEAT with timeout "
+ << *timeout_timer_->duration();
+
+ Parameters parameters = Parameters::Builder()
+ .Add(HeartbeatInfoParameter(info.Serialize()))
+ .Build();
+
+ ctx_->Send(ctx_->PacketBuilder().Add(
+ HeartbeatRequestChunk(std::move(parameters))));
+ } else {
+ RTC_DLOG(LS_VERBOSE)
+ << log_prefix_
+ << "Will not send HEARTBEAT when connection not established";
+ }
+ return absl::nullopt;
+}
+
+absl::optional<DurationMs> HeartbeatHandler::OnTimeoutTimerExpiry() {
+ // Note that the timeout timer is not restarted. It will be started again when
+ // the interval timer expires.
+ RTC_DCHECK(!timeout_timer_->is_running());
+ ctx_->IncrementTxErrorCounter("HEARTBEAT timeout");
+ return absl::nullopt;
+}
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h
new file mode 100644
index 0000000000..14c3109534
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_
+#define NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_
+
+#include <stdint.h>
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/socket/context.h"
+#include "net/dcsctp/timer/timer.h"
+
+namespace dcsctp {
+
+// HeartbeatHandler handles all logic around sending heartbeats and receiving
+// the responses, as well as receiving incoming heartbeat requests.
+//
+// Heartbeats are sent on idle connections to ensure that the connection is
+// still healthy and to measure the RTT. If a number of heartbeats time out,
+// the connection will eventually be closed.
+class HeartbeatHandler {
+ public:
+ HeartbeatHandler(absl::string_view log_prefix,
+ const DcSctpOptions& options,
+ Context* context,
+ TimerManager* timer_manager);
+
+ // Called when the heartbeat interval timer should be restarted. This is
+ // generally done every time data is sent, which makes the timer expire when
+ // the connection is idle.
+ void RestartTimer();
+
+ // Called on received HeartbeatRequestChunk chunks.
+ void HandleHeartbeatRequest(HeartbeatRequestChunk chunk);
+
+ // Called on received HeartbeatRequestChunk chunks.
+ void HandleHeartbeatAck(HeartbeatAckChunk chunk);
+
+ private:
+ absl::optional<DurationMs> OnIntervalTimerExpiry();
+ absl::optional<DurationMs> OnTimeoutTimerExpiry();
+
+ const std::string log_prefix_;
+ Context* ctx_;
+ TimerManager* timer_manager_;
+ // The time for a connection to be idle before a heartbeat is sent.
+ const DurationMs interval_duration_;
+ // Adding RTT to the duration will add some jitter, which is good in
+ // production, but less good in unit tests, which is why it can be disabled.
+ const bool interval_duration_should_include_rtt_;
+ const std::unique_ptr<Timer> interval_timer_;
+ const std::unique_ptr<Timer> timeout_timer_;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc
new file mode 100644
index 0000000000..faa0e3da06
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc
@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/heartbeat_handler.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "api/task_queue/task_queue_base.h"
+#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h"
+#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h"
+#include "net/dcsctp/public/types.h"
+#include "net/dcsctp/socket/mock_context.h"
+#include "net/dcsctp/testing/testing_macros.h"
+#include "rtc_base/gunit.h"
+#include "test/gmock.h"
+
+namespace dcsctp {
+namespace {
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::NiceMock;
+using ::testing::Return;
+using ::testing::SizeIs;
+
+constexpr DurationMs kHeartbeatInterval = DurationMs(30'000);
+
+DcSctpOptions MakeOptions(DurationMs heartbeat_interval) {
+ DcSctpOptions options;
+ options.heartbeat_interval_include_rtt = false;
+ options.heartbeat_interval = heartbeat_interval;
+ return options;
+}
+
+class HeartbeatHandlerTestBase : public testing::Test {
+ protected:
+ explicit HeartbeatHandlerTestBase(DurationMs heartbeat_interval)
+ : options_(MakeOptions(heartbeat_interval)),
+ context_(&callbacks_),
+ timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) {
+ return callbacks_.CreateTimeout(precision);
+ }),
+ handler_("log: ", options_, &context_, &timer_manager_) {}
+
+ void AdvanceTime(DurationMs duration) {
+ callbacks_.AdvanceTime(duration);
+ for (;;) {
+ absl::optional<TimeoutID> timeout_id = callbacks_.GetNextExpiredTimeout();
+ if (!timeout_id.has_value()) {
+ break;
+ }
+ timer_manager_.HandleTimeout(*timeout_id);
+ }
+ }
+
+ const DcSctpOptions options_;
+ NiceMock<MockDcSctpSocketCallbacks> callbacks_;
+ NiceMock<MockContext> context_;
+ TimerManager timer_manager_;
+ HeartbeatHandler handler_;
+};
+
+class HeartbeatHandlerTest : public HeartbeatHandlerTestBase {
+ protected:
+ HeartbeatHandlerTest() : HeartbeatHandlerTestBase(kHeartbeatInterval) {}
+};
+
+class DisabledHeartbeatHandlerTest : public HeartbeatHandlerTestBase {
+ protected:
+ DisabledHeartbeatHandlerTest() : HeartbeatHandlerTestBase(DurationMs(0)) {}
+};
+
+TEST_F(HeartbeatHandlerTest, HasRunningHeartbeatIntervalTimer) {
+ AdvanceTime(options_.heartbeat_interval);
+
+ // Validate that a heartbeat request was sent.
+ std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload));
+ ASSERT_THAT(packet.descriptors(), SizeIs(1));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ HeartbeatRequestChunk request,
+ HeartbeatRequestChunk::Parse(packet.descriptors()[0].data));
+
+ EXPECT_TRUE(request.info().has_value());
+}
+
+TEST_F(HeartbeatHandlerTest, RepliesToHeartbeatRequests) {
+ uint8_t info_data[] = {1, 2, 3, 4, 5};
+ HeartbeatRequestChunk request(
+ Parameters::Builder().Add(HeartbeatInfoParameter(info_data)).Build());
+
+ handler_.HandleHeartbeatRequest(std::move(request));
+
+ std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload));
+ ASSERT_THAT(packet.descriptors(), SizeIs(1));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ HeartbeatAckChunk response,
+ HeartbeatAckChunk::Parse(packet.descriptors()[0].data));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ HeartbeatInfoParameter param,
+ response.parameters().get<HeartbeatInfoParameter>());
+
+ EXPECT_THAT(param.info(), ElementsAre(1, 2, 3, 4, 5));
+}
+
+TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) {
+ AdvanceTime(options_.heartbeat_interval);
+
+ // Grab the request, and make a response.
+ std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload));
+ ASSERT_THAT(packet.descriptors(), SizeIs(1));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ HeartbeatRequestChunk req,
+ HeartbeatRequestChunk::Parse(packet.descriptors()[0].data));
+
+ HeartbeatAckChunk ack(std::move(req).extract_parameters());
+
+ // Respond a while later. This RTT will be measured by the handler
+ constexpr DurationMs rtt(313);
+
+ EXPECT_CALL(context_, ObserveRTT(rtt)).Times(1);
+
+ callbacks_.AdvanceTime(rtt);
+ handler_.HandleHeartbeatAck(std::move(ack));
+}
+
+TEST_F(HeartbeatHandlerTest, DoesntObserveInvalidHeartbeats) {
+ AdvanceTime(options_.heartbeat_interval);
+
+ // Grab the request, and make a response.
+ std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket();
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload));
+ ASSERT_THAT(packet.descriptors(), SizeIs(1));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ HeartbeatRequestChunk req,
+ HeartbeatRequestChunk::Parse(packet.descriptors()[0].data));
+
+ HeartbeatAckChunk ack(std::move(req).extract_parameters());
+
+ EXPECT_CALL(context_, ObserveRTT).Times(0);
+
+ // Go backwards in time - which make the HEARTBEAT-ACK have an invalid
+ // timestamp in it, as it will be in the future.
+ callbacks_.AdvanceTime(DurationMs(-100));
+
+ handler_.HandleHeartbeatAck(std::move(ack));
+}
+
+TEST_F(HeartbeatHandlerTest, IncreasesErrorIfNotAckedInTime) {
+ DurationMs rto(105);
+ EXPECT_CALL(context_, current_rto).WillOnce(Return(rto));
+ AdvanceTime(options_.heartbeat_interval);
+
+ // Validate that a request was sent.
+ EXPECT_THAT(callbacks_.ConsumeSentPacket(), Not(IsEmpty()));
+
+ EXPECT_CALL(context_, IncrementTxErrorCounter).Times(1);
+ AdvanceTime(rto);
+}
+
+TEST_F(DisabledHeartbeatHandlerTest, IsReallyDisabled) {
+ AdvanceTime(options_.heartbeat_interval);
+
+ // Validate that a request was NOT sent.
+ EXPECT_THAT(callbacks_.ConsumeSentPacket(), IsEmpty());
+}
+
+} // namespace
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/mock_context.h b/third_party/libwebrtc/net/dcsctp/socket/mock_context.h
new file mode 100644
index 0000000000..88e71d1b35
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/mock_context.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_
+#define NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_
+
+#include <cstdint>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/socket/context.h"
+#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h"
+#include "test/gmock.h"
+
+namespace dcsctp {
+
+class MockContext : public Context {
+ public:
+ static constexpr TSN MyInitialTsn() { return TSN(990); }
+ static constexpr TSN PeerInitialTsn() { return TSN(10); }
+ static constexpr VerificationTag PeerVerificationTag() {
+ return VerificationTag(0x01234567);
+ }
+
+ explicit MockContext(MockDcSctpSocketCallbacks* callbacks)
+ : callbacks_(*callbacks) {
+ ON_CALL(*this, is_connection_established)
+ .WillByDefault(testing::Return(true));
+ ON_CALL(*this, my_initial_tsn)
+ .WillByDefault(testing::Return(MyInitialTsn()));
+ ON_CALL(*this, peer_initial_tsn)
+ .WillByDefault(testing::Return(PeerInitialTsn()));
+ ON_CALL(*this, callbacks).WillByDefault(testing::ReturnRef(callbacks_));
+ ON_CALL(*this, current_rto).WillByDefault(testing::Return(DurationMs(123)));
+ ON_CALL(*this, Send).WillByDefault([this](SctpPacket::Builder& builder) {
+ callbacks_.SendPacketWithStatus(builder.Build());
+ });
+ }
+
+ MOCK_METHOD(bool, is_connection_established, (), (const, override));
+ MOCK_METHOD(TSN, my_initial_tsn, (), (const, override));
+ MOCK_METHOD(TSN, peer_initial_tsn, (), (const, override));
+ MOCK_METHOD(DcSctpSocketCallbacks&, callbacks, (), (const, override));
+
+ MOCK_METHOD(void, ObserveRTT, (DurationMs rtt_ms), (override));
+ MOCK_METHOD(DurationMs, current_rto, (), (const, override));
+ MOCK_METHOD(bool,
+ IncrementTxErrorCounter,
+ (absl::string_view reason),
+ (override));
+ MOCK_METHOD(void, ClearTxErrorCounter, (), (override));
+ MOCK_METHOD(bool, HasTooManyTxErrors, (), (const, override));
+ SctpPacket::Builder PacketBuilder() const override {
+ return SctpPacket::Builder(PeerVerificationTag(), options_);
+ }
+ MOCK_METHOD(void, Send, (SctpPacket::Builder & builder), (override));
+
+ DcSctpOptions options_;
+ MockDcSctpSocketCallbacks& callbacks_;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
new file mode 100644
index 0000000000..8b2a772fa3
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h
@@ -0,0 +1,179 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_
+#define NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_
+
+#include <cstdint>
+#include <deque>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "api/task_queue/task_queue_base.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/public/timeout.h"
+#include "net/dcsctp/public/types.h"
+#include "net/dcsctp/timer/fake_timeout.h"
+#include "rtc_base/logging.h"
+#include "rtc_base/random.h"
+#include "test/gmock.h"
+
+namespace dcsctp {
+
+namespace internal {
+// It can be argued if a mocked random number generator should be deterministic
+// or if it should be have as a "real" random number generator. In this
+// implementation, each instantiation of `MockDcSctpSocketCallbacks` will have
+// their `GetRandomInt` return different sequences, but each instantiation will
+// always generate the same sequence of random numbers. This to make it easier
+// to compare logs from tests, but still to let e.g. two different sockets (used
+// in the same test) get different random numbers, so that they don't start e.g.
+// on the same sequence number. While that isn't an issue in the protocol, it
+// just makes debugging harder as the two sockets would look exactly the same.
+//
+// In a real implementation of `DcSctpSocketCallbacks` the random number
+// generator backing `GetRandomInt` should be seeded externally and correctly.
+inline int GetUniqueSeed() {
+ static int seed = 0;
+ return ++seed;
+}
+} // namespace internal
+
+class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks {
+ public:
+ explicit MockDcSctpSocketCallbacks(absl::string_view name = "")
+ : log_prefix_(name.empty() ? "" : std::string(name) + ": "),
+ random_(internal::GetUniqueSeed()),
+ timeout_manager_([this]() { return now_; }) {
+ ON_CALL(*this, SendPacketWithStatus)
+ .WillByDefault([this](rtc::ArrayView<const uint8_t> data) {
+ sent_packets_.emplace_back(
+ std::vector<uint8_t>(data.begin(), data.end()));
+ return SendPacketStatus::kSuccess;
+ });
+ ON_CALL(*this, OnMessageReceived)
+ .WillByDefault([this](DcSctpMessage message) {
+ received_messages_.emplace_back(std::move(message));
+ });
+
+ ON_CALL(*this, OnError)
+ .WillByDefault([this](ErrorKind error, absl::string_view message) {
+ RTC_LOG(LS_WARNING)
+ << log_prefix_ << "Socket error: " << ToString(error) << "; "
+ << message;
+ });
+ ON_CALL(*this, OnAborted)
+ .WillByDefault([this](ErrorKind error, absl::string_view message) {
+ RTC_LOG(LS_WARNING)
+ << log_prefix_ << "Socket abort: " << ToString(error) << "; "
+ << message;
+ });
+ ON_CALL(*this, TimeMillis).WillByDefault([this]() { return now_; });
+ }
+
+ MOCK_METHOD(SendPacketStatus,
+ SendPacketWithStatus,
+ (rtc::ArrayView<const uint8_t> data),
+ (override));
+
+ std::unique_ptr<Timeout> CreateTimeout(
+ webrtc::TaskQueueBase::DelayPrecision precision) override {
+ // The fake timeout manager does not implement |precision|.
+ return timeout_manager_.CreateTimeout();
+ }
+
+ MOCK_METHOD(TimeMs, TimeMillis, (), (override));
+ uint32_t GetRandomInt(uint32_t low, uint32_t high) override {
+ return random_.Rand(low, high);
+ }
+
+ MOCK_METHOD(void, OnMessageReceived, (DcSctpMessage message), (override));
+ MOCK_METHOD(void,
+ OnError,
+ (ErrorKind error, absl::string_view message),
+ (override));
+ MOCK_METHOD(void,
+ OnAborted,
+ (ErrorKind error, absl::string_view message),
+ (override));
+ MOCK_METHOD(void, OnConnected, (), (override));
+ MOCK_METHOD(void, OnClosed, (), (override));
+ MOCK_METHOD(void, OnConnectionRestarted, (), (override));
+ MOCK_METHOD(void,
+ OnStreamsResetFailed,
+ (rtc::ArrayView<const StreamID> outgoing_streams,
+ absl::string_view reason),
+ (override));
+ MOCK_METHOD(void,
+ OnStreamsResetPerformed,
+ (rtc::ArrayView<const StreamID> outgoing_streams),
+ (override));
+ MOCK_METHOD(void,
+ OnIncomingStreamsReset,
+ (rtc::ArrayView<const StreamID> incoming_streams),
+ (override));
+ MOCK_METHOD(void, OnBufferedAmountLow, (StreamID stream_id), (override));
+ MOCK_METHOD(void, OnTotalBufferedAmountLow, (), (override));
+ MOCK_METHOD(void,
+ OnLifecycleMessageExpired,
+ (LifecycleId lifecycle_id, bool maybe_delivered),
+ (override));
+ MOCK_METHOD(void,
+ OnLifecycleMessageFullySent,
+ (LifecycleId lifecycle_id),
+ (override));
+ MOCK_METHOD(void,
+ OnLifecycleMessageDelivered,
+ (LifecycleId lifecycle_id),
+ (override));
+ MOCK_METHOD(void, OnLifecycleEnd, (LifecycleId lifecycle_id), (override));
+
+ bool HasPacket() const { return !sent_packets_.empty(); }
+
+ std::vector<uint8_t> ConsumeSentPacket() {
+ if (sent_packets_.empty()) {
+ return {};
+ }
+ std::vector<uint8_t> ret = std::move(sent_packets_.front());
+ sent_packets_.pop_front();
+ return ret;
+ }
+ absl::optional<DcSctpMessage> ConsumeReceivedMessage() {
+ if (received_messages_.empty()) {
+ return absl::nullopt;
+ }
+ DcSctpMessage ret = std::move(received_messages_.front());
+ received_messages_.pop_front();
+ return ret;
+ }
+
+ void AdvanceTime(DurationMs duration_ms) { now_ = now_ + duration_ms; }
+ void SetTime(TimeMs now) { now_ = now; }
+
+ absl::optional<TimeoutID> GetNextExpiredTimeout() {
+ return timeout_manager_.GetNextExpiredTimeout();
+ }
+
+ private:
+ const std::string log_prefix_;
+ TimeMs now_ = TimeMs(0);
+ webrtc::Random random_;
+ FakeTimeoutManager timeout_manager_;
+ std::deque<std::vector<uint8_t>> sent_packets_;
+ std::deque<DcSctpMessage> received_messages_;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc
new file mode 100644
index 0000000000..85392e205d
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/packet_sender.h"
+
+#include <utility>
+#include <vector>
+
+#include "net/dcsctp/public/types.h"
+
+namespace dcsctp {
+
+PacketSender::PacketSender(DcSctpSocketCallbacks& callbacks,
+ std::function<void(rtc::ArrayView<const uint8_t>,
+ SendPacketStatus)> on_sent_packet)
+ : callbacks_(callbacks), on_sent_packet_(std::move(on_sent_packet)) {}
+
+bool PacketSender::Send(SctpPacket::Builder& builder) {
+ if (builder.empty()) {
+ return false;
+ }
+
+ std::vector<uint8_t> payload = builder.Build();
+
+ SendPacketStatus status = callbacks_.SendPacketWithStatus(payload);
+ on_sent_packet_(payload, status);
+ switch (status) {
+ case SendPacketStatus::kSuccess: {
+ return true;
+ }
+ case SendPacketStatus::kTemporaryFailure: {
+ // TODO(boivie): Queue this packet to be retried to be sent later.
+ return false;
+ }
+
+ case SendPacketStatus::kError: {
+ // Nothing that can be done.
+ return false;
+ }
+ }
+}
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/packet_sender.h b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.h
new file mode 100644
index 0000000000..7af4d3c47b
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_PACKET_SENDER_H_
+#define NET_DCSCTP_SOCKET_PACKET_SENDER_H_
+
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+
+namespace dcsctp {
+
+// The PacketSender sends packets to the network using the provided callback
+// interface. When an attempt to send a packet is made, the `on_sent_packet`
+// callback will be triggered.
+class PacketSender {
+ public:
+ PacketSender(DcSctpSocketCallbacks& callbacks,
+ std::function<void(rtc::ArrayView<const uint8_t>,
+ SendPacketStatus)> on_sent_packet);
+
+ // Sends the packet, and returns true if it was sent successfully.
+ bool Send(SctpPacket::Builder& builder);
+
+ private:
+ DcSctpSocketCallbacks& callbacks_;
+
+ // Callback that will be triggered for every send attempt, indicating the
+ // status of the operation.
+ std::function<void(rtc::ArrayView<const uint8_t>, SendPacketStatus)>
+ on_sent_packet_;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_PACKET_SENDER_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc b/third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc
new file mode 100644
index 0000000000..079dc36a41
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/packet_sender.h"
+
+#include "net/dcsctp/common/internal_types.h"
+#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
+#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h"
+#include "rtc_base/gunit.h"
+#include "test/gmock.h"
+
+namespace dcsctp {
+namespace {
+using ::testing::_;
+
+constexpr VerificationTag kVerificationTag(123);
+
+class PacketSenderTest : public testing::Test {
+ protected:
+ PacketSenderTest() : sender_(callbacks_, on_send_fn_.AsStdFunction()) {}
+
+ SctpPacket::Builder PacketBuilder() const {
+ return SctpPacket::Builder(kVerificationTag, options_);
+ }
+
+ DcSctpOptions options_;
+ testing::NiceMock<MockDcSctpSocketCallbacks> callbacks_;
+ testing::MockFunction<void(rtc::ArrayView<const uint8_t>, SendPacketStatus)>
+ on_send_fn_;
+ PacketSender sender_;
+};
+
+TEST_F(PacketSenderTest, SendPacketCallsCallback) {
+ EXPECT_CALL(on_send_fn_, Call(_, SendPacketStatus::kSuccess));
+ EXPECT_TRUE(sender_.Send(PacketBuilder().Add(CookieAckChunk())));
+
+ EXPECT_CALL(callbacks_, SendPacketWithStatus)
+ .WillOnce(testing::Return(SendPacketStatus::kError));
+ EXPECT_CALL(on_send_fn_, Call(_, SendPacketStatus::kError));
+ EXPECT_FALSE(sender_.Send(PacketBuilder().Add(CookieAckChunk())));
+}
+
+} // namespace
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc
new file mode 100644
index 0000000000..7d04cbb0d7
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/state_cookie.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/packet/bounded_byte_reader.h"
+#include "net/dcsctp/packet/bounded_byte_writer.h"
+#include "net/dcsctp/socket/capabilities.h"
+#include "rtc_base/logging.h"
+
+namespace dcsctp {
+
+// Magic values, which the state cookie is prefixed with.
+constexpr uint32_t kMagic1 = 1684230979;
+constexpr uint32_t kMagic2 = 1414541360;
+constexpr size_t StateCookie::kCookieSize;
+
+std::vector<uint8_t> StateCookie::Serialize() {
+ std::vector<uint8_t> cookie;
+ cookie.resize(kCookieSize);
+ BoundedByteWriter<kCookieSize> buffer(cookie);
+ buffer.Store32<0>(kMagic1);
+ buffer.Store32<4>(kMagic2);
+ buffer.Store32<8>(*initiate_tag_);
+ buffer.Store32<12>(*initial_tsn_);
+ buffer.Store32<16>(a_rwnd_);
+ buffer.Store32<20>(static_cast<uint32_t>(*tie_tag_ >> 32));
+ buffer.Store32<24>(static_cast<uint32_t>(*tie_tag_));
+ buffer.Store8<28>(capabilities_.partial_reliability);
+ buffer.Store8<29>(capabilities_.message_interleaving);
+ buffer.Store8<30>(capabilities_.reconfig);
+ return cookie;
+}
+
+absl::optional<StateCookie> StateCookie::Deserialize(
+ rtc::ArrayView<const uint8_t> cookie) {
+ if (cookie.size() != kCookieSize) {
+ RTC_DLOG(LS_WARNING) << "Invalid state cookie: " << cookie.size()
+ << " bytes";
+ return absl::nullopt;
+ }
+
+ BoundedByteReader<kCookieSize> buffer(cookie);
+ uint32_t magic1 = buffer.Load32<0>();
+ uint32_t magic2 = buffer.Load32<4>();
+ if (magic1 != kMagic1 || magic2 != kMagic2) {
+ RTC_DLOG(LS_WARNING) << "Invalid state cookie; wrong magic";
+ return absl::nullopt;
+ }
+
+ VerificationTag verification_tag(buffer.Load32<8>());
+ TSN initial_tsn(buffer.Load32<12>());
+ uint32_t a_rwnd = buffer.Load32<16>();
+ uint32_t tie_tag_upper = buffer.Load32<20>();
+ uint32_t tie_tag_lower = buffer.Load32<24>();
+ TieTag tie_tag(static_cast<uint64_t>(tie_tag_upper) << 32 |
+ static_cast<uint64_t>(tie_tag_lower));
+ Capabilities capabilities;
+ capabilities.partial_reliability = buffer.Load8<28>() != 0;
+ capabilities.message_interleaving = buffer.Load8<29>() != 0;
+ capabilities.reconfig = buffer.Load8<30>() != 0;
+
+ return StateCookie(verification_tag, initial_tsn, a_rwnd, tie_tag,
+ capabilities);
+}
+
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h
new file mode 100644
index 0000000000..df4b801397
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_STATE_COOKIE_H_
+#define NET_DCSCTP_SOCKET_STATE_COOKIE_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/common/internal_types.h"
+#include "net/dcsctp/socket/capabilities.h"
+
+namespace dcsctp {
+
+// This is serialized as a state cookie and put in INIT_ACK. The client then
+// responds with this in COOKIE_ECHO.
+//
+// NOTE: Expect that the client will modify it to try to exploit the library.
+// Do not trust anything in it; no pointers or anything like that.
+class StateCookie {
+ public:
+ static constexpr size_t kCookieSize = 31;
+
+ StateCookie(VerificationTag initiate_tag,
+ TSN initial_tsn,
+ uint32_t a_rwnd,
+ TieTag tie_tag,
+ Capabilities capabilities)
+ : initiate_tag_(initiate_tag),
+ initial_tsn_(initial_tsn),
+ a_rwnd_(a_rwnd),
+ tie_tag_(tie_tag),
+ capabilities_(capabilities) {}
+
+ // Returns a serialized version of this cookie.
+ std::vector<uint8_t> Serialize();
+
+ // Deserializes the cookie, and returns absl::nullopt if that failed.
+ static absl::optional<StateCookie> Deserialize(
+ rtc::ArrayView<const uint8_t> cookie);
+
+ VerificationTag initiate_tag() const { return initiate_tag_; }
+ TSN initial_tsn() const { return initial_tsn_; }
+ uint32_t a_rwnd() const { return a_rwnd_; }
+ TieTag tie_tag() const { return tie_tag_; }
+ const Capabilities& capabilities() const { return capabilities_; }
+
+ private:
+ const VerificationTag initiate_tag_;
+ const TSN initial_tsn_;
+ const uint32_t a_rwnd_;
+ const TieTag tie_tag_;
+ const Capabilities capabilities_;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_STATE_COOKIE_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc
new file mode 100644
index 0000000000..870620ba14
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/state_cookie.h"
+
+#include "net/dcsctp/testing/testing_macros.h"
+#include "rtc_base/gunit.h"
+#include "test/gmock.h"
+
+namespace dcsctp {
+namespace {
+using ::testing::SizeIs;
+
+TEST(StateCookieTest, SerializeAndDeserialize) {
+ Capabilities capabilities = {/*partial_reliability=*/true,
+ /*message_interleaving=*/false,
+ /*reconfig=*/true};
+ StateCookie cookie(VerificationTag(123), TSN(456),
+ /*a_rwnd=*/789, TieTag(101112), capabilities);
+ std::vector<uint8_t> serialized = cookie.Serialize();
+ EXPECT_THAT(serialized, SizeIs(StateCookie::kCookieSize));
+ ASSERT_HAS_VALUE_AND_ASSIGN(StateCookie deserialized,
+ StateCookie::Deserialize(serialized));
+ EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123));
+ EXPECT_EQ(deserialized.initial_tsn(), TSN(456));
+ EXPECT_EQ(deserialized.a_rwnd(), 789u);
+ EXPECT_EQ(deserialized.tie_tag(), TieTag(101112));
+ EXPECT_TRUE(deserialized.capabilities().partial_reliability);
+ EXPECT_FALSE(deserialized.capabilities().message_interleaving);
+ EXPECT_TRUE(deserialized.capabilities().reconfig);
+}
+
+TEST(StateCookieTest, ValidateMagicValue) {
+ Capabilities capabilities = {/*partial_reliability=*/true,
+ /*message_interleaving=*/false,
+ /*reconfig=*/true};
+ StateCookie cookie(VerificationTag(123), TSN(456),
+ /*a_rwnd=*/789, TieTag(101112), capabilities);
+ std::vector<uint8_t> serialized = cookie.Serialize();
+ ASSERT_THAT(serialized, SizeIs(StateCookie::kCookieSize));
+
+ absl::string_view magic(reinterpret_cast<const char*>(serialized.data()), 8);
+ EXPECT_EQ(magic, "dcSCTP00");
+}
+
+} // namespace
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc
new file mode 100644
index 0000000000..9d86953f44
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc
@@ -0,0 +1,349 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/stream_reset_handler.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/common/internal_types.h"
+#include "net/dcsctp/common/str_join.h"
+#include "net/dcsctp/packet/chunk/reconfig_chunk.h"
+#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h"
+#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h"
+#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h"
+#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
+#include "net/dcsctp/packet/parameter/parameter.h"
+#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h"
+#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/packet/tlv_trait.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/rx/data_tracker.h"
+#include "net/dcsctp/rx/reassembly_queue.h"
+#include "net/dcsctp/socket/context.h"
+#include "net/dcsctp/timer/timer.h"
+#include "net/dcsctp/tx/retransmission_queue.h"
+#include "rtc_base/logging.h"
+
+namespace dcsctp {
+namespace {
+using ResponseResult = ReconfigurationResponseParameter::Result;
+
+bool DescriptorsAre(const std::vector<ParameterDescriptor>& c,
+ uint16_t e1,
+ uint16_t e2) {
+ return (c[0].type == e1 && c[1].type == e2) ||
+ (c[0].type == e2 && c[1].type == e1);
+}
+
+} // namespace
+
+bool StreamResetHandler::Validate(const ReConfigChunk& chunk) {
+ const Parameters& parameters = chunk.parameters();
+
+ // https://tools.ietf.org/html/rfc6525#section-3.1
+ // "Note that each RE-CONFIG chunk holds at least one parameter
+ // and at most two parameters. Only the following combinations are allowed:"
+ std::vector<ParameterDescriptor> descriptors = parameters.descriptors();
+ if (descriptors.size() == 1) {
+ if ((descriptors[0].type == OutgoingSSNResetRequestParameter::kType) ||
+ (descriptors[0].type == IncomingSSNResetRequestParameter::kType) ||
+ (descriptors[0].type == SSNTSNResetRequestParameter::kType) ||
+ (descriptors[0].type == AddOutgoingStreamsRequestParameter::kType) ||
+ (descriptors[0].type == AddIncomingStreamsRequestParameter::kType) ||
+ (descriptors[0].type == ReconfigurationResponseParameter::kType)) {
+ return true;
+ }
+ } else if (descriptors.size() == 2) {
+ if (DescriptorsAre(descriptors, OutgoingSSNResetRequestParameter::kType,
+ IncomingSSNResetRequestParameter::kType) ||
+ DescriptorsAre(descriptors, AddOutgoingStreamsRequestParameter::kType,
+ AddIncomingStreamsRequestParameter::kType) ||
+ DescriptorsAre(descriptors, ReconfigurationResponseParameter::kType,
+ OutgoingSSNResetRequestParameter::kType) ||
+ DescriptorsAre(descriptors, ReconfigurationResponseParameter::kType,
+ ReconfigurationResponseParameter::kType)) {
+ return true;
+ }
+ }
+
+ RTC_LOG(LS_WARNING) << "Invalid set of RE-CONFIG parameters";
+ return false;
+}
+
+absl::optional<std::vector<ReconfigurationResponseParameter>>
+StreamResetHandler::Process(const ReConfigChunk& chunk) {
+ if (!Validate(chunk)) {
+ return absl::nullopt;
+ }
+
+ std::vector<ReconfigurationResponseParameter> responses;
+
+ for (const ParameterDescriptor& desc : chunk.parameters().descriptors()) {
+ switch (desc.type) {
+ case OutgoingSSNResetRequestParameter::kType:
+ HandleResetOutgoing(desc, responses);
+ break;
+
+ case IncomingSSNResetRequestParameter::kType:
+ HandleResetIncoming(desc, responses);
+ break;
+
+ case ReconfigurationResponseParameter::kType:
+ HandleResponse(desc);
+ break;
+ }
+ }
+
+ return responses;
+}
+
+void StreamResetHandler::HandleReConfig(ReConfigChunk chunk) {
+ absl::optional<std::vector<ReconfigurationResponseParameter>> responses =
+ Process(chunk);
+
+ if (!responses.has_value()) {
+ ctx_->callbacks().OnError(ErrorKind::kParseFailed,
+ "Failed to parse RE-CONFIG command");
+ return;
+ }
+
+ if (!responses->empty()) {
+ SctpPacket::Builder b = ctx_->PacketBuilder();
+ Parameters::Builder params_builder;
+ for (const auto& response : *responses) {
+ params_builder.Add(response);
+ }
+ b.Add(ReConfigChunk(params_builder.Build()));
+ ctx_->Send(b);
+ }
+}
+
+bool StreamResetHandler::ValidateReqSeqNbr(
+ ReconfigRequestSN req_seq_nbr,
+ std::vector<ReconfigurationResponseParameter>& responses) {
+ if (req_seq_nbr == last_processed_req_seq_nbr_) {
+ // This has already been performed previously.
+ RTC_DLOG(LS_VERBOSE) << log_prefix_ << "req=" << *req_seq_nbr
+ << " already processed";
+ responses.push_back(ReconfigurationResponseParameter(
+ req_seq_nbr, ResponseResult::kSuccessNothingToDo));
+ return false;
+ }
+
+ if (req_seq_nbr != ReconfigRequestSN(*last_processed_req_seq_nbr_ + 1)) {
+ // Too old, too new, from wrong association etc.
+ // This is expected to happen when handing over a RTCPeerConnection from one
+ // server to another. The client will notice this and may decide to close
+ // old data channels, which may be sent to the wrong (or both) servers
+ // during a handover.
+ RTC_DLOG(LS_VERBOSE) << log_prefix_ << "req=" << *req_seq_nbr
+ << " bad seq_nbr";
+ responses.push_back(ReconfigurationResponseParameter(
+ req_seq_nbr, ResponseResult::kErrorBadSequenceNumber));
+ return false;
+ }
+
+ return true;
+}
+
+void StreamResetHandler::HandleResetOutgoing(
+ const ParameterDescriptor& descriptor,
+ std::vector<ReconfigurationResponseParameter>& responses) {
+ absl::optional<OutgoingSSNResetRequestParameter> req =
+ OutgoingSSNResetRequestParameter::Parse(descriptor.data);
+ if (!req.has_value()) {
+ ctx_->callbacks().OnError(ErrorKind::kParseFailed,
+ "Failed to parse Outgoing Reset command");
+ return;
+ }
+
+ if (ValidateReqSeqNbr(req->request_sequence_number(), responses)) {
+ ResponseResult result;
+
+ RTC_DLOG(LS_VERBOSE) << log_prefix_
+ << "Reset outgoing streams with req_seq_nbr="
+ << *req->request_sequence_number();
+
+ last_processed_req_seq_nbr_ = req->request_sequence_number();
+ result = reassembly_queue_->ResetStreams(
+ *req, data_tracker_->last_cumulative_acked_tsn());
+ if (result == ResponseResult::kSuccessPerformed) {
+ ctx_->callbacks().OnIncomingStreamsReset(req->stream_ids());
+ }
+ responses.push_back(ReconfigurationResponseParameter(
+ req->request_sequence_number(), result));
+ }
+}
+
+void StreamResetHandler::HandleResetIncoming(
+ const ParameterDescriptor& descriptor,
+ std::vector<ReconfigurationResponseParameter>& responses) {
+ absl::optional<IncomingSSNResetRequestParameter> req =
+ IncomingSSNResetRequestParameter::Parse(descriptor.data);
+ if (!req.has_value()) {
+ ctx_->callbacks().OnError(ErrorKind::kParseFailed,
+ "Failed to parse Incoming Reset command");
+ return;
+ }
+ if (ValidateReqSeqNbr(req->request_sequence_number(), responses)) {
+ responses.push_back(ReconfigurationResponseParameter(
+ req->request_sequence_number(), ResponseResult::kSuccessNothingToDo));
+ last_processed_req_seq_nbr_ = req->request_sequence_number();
+ }
+}
+
+void StreamResetHandler::HandleResponse(const ParameterDescriptor& descriptor) {
+ absl::optional<ReconfigurationResponseParameter> resp =
+ ReconfigurationResponseParameter::Parse(descriptor.data);
+ if (!resp.has_value()) {
+ ctx_->callbacks().OnError(
+ ErrorKind::kParseFailed,
+ "Failed to parse Reconfiguration Response command");
+ return;
+ }
+
+ if (current_request_.has_value() && current_request_->has_been_sent() &&
+ resp->response_sequence_number() == current_request_->req_seq_nbr()) {
+ reconfig_timer_->Stop();
+
+ switch (resp->result()) {
+ case ResponseResult::kSuccessNothingToDo:
+ case ResponseResult::kSuccessPerformed:
+ RTC_DLOG(LS_VERBOSE)
+ << log_prefix_ << "Reset stream success, req_seq_nbr="
+ << *current_request_->req_seq_nbr() << ", streams="
+ << StrJoin(current_request_->streams(), ",",
+ [](rtc::StringBuilder& sb, StreamID stream_id) {
+ sb << *stream_id;
+ });
+ ctx_->callbacks().OnStreamsResetPerformed(current_request_->streams());
+ current_request_ = absl::nullopt;
+ retransmission_queue_->CommitResetStreams();
+ break;
+ case ResponseResult::kInProgress:
+ RTC_DLOG(LS_VERBOSE)
+ << log_prefix_ << "Reset stream still pending, req_seq_nbr="
+ << *current_request_->req_seq_nbr() << ", streams="
+ << StrJoin(current_request_->streams(), ",",
+ [](rtc::StringBuilder& sb, StreamID stream_id) {
+ sb << *stream_id;
+ });
+ // Force this request to be sent again, but with new req_seq_nbr.
+ current_request_->PrepareRetransmission();
+ reconfig_timer_->set_duration(ctx_->current_rto());
+ reconfig_timer_->Start();
+ break;
+ case ResponseResult::kErrorRequestAlreadyInProgress:
+ case ResponseResult::kDenied:
+ case ResponseResult::kErrorWrongSSN:
+ case ResponseResult::kErrorBadSequenceNumber:
+ RTC_DLOG(LS_WARNING)
+ << log_prefix_ << "Reset stream error=" << ToString(resp->result())
+ << ", req_seq_nbr=" << *current_request_->req_seq_nbr()
+ << ", streams="
+ << StrJoin(current_request_->streams(), ",",
+ [](rtc::StringBuilder& sb, StreamID stream_id) {
+ sb << *stream_id;
+ });
+ ctx_->callbacks().OnStreamsResetFailed(current_request_->streams(),
+ ToString(resp->result()));
+ current_request_ = absl::nullopt;
+ retransmission_queue_->RollbackResetStreams();
+ break;
+ }
+ }
+}
+
+absl::optional<ReConfigChunk> StreamResetHandler::MakeStreamResetRequest() {
+ // Only send stream resets if there are streams to reset, and no current
+ // ongoing request (there can only be one at a time), and if the stream
+ // can be reset.
+ if (current_request_.has_value() ||
+ !retransmission_queue_->HasStreamsReadyToBeReset()) {
+ return absl::nullopt;
+ }
+
+ current_request_.emplace(TSN(*retransmission_queue_->next_tsn() - 1),
+ retransmission_queue_->GetStreamsReadyToBeReset());
+ reconfig_timer_->set_duration(ctx_->current_rto());
+ reconfig_timer_->Start();
+ return MakeReconfigChunk();
+}
+
+ReConfigChunk StreamResetHandler::MakeReconfigChunk() {
+ // The req_seq_nbr will be empty if the request has never been sent before,
+ // or if it was sent, but the sender responded "in progress", and then the
+ // req_seq_nbr will be cleared to re-send with a new number. But if the
+ // request is re-sent due to timeout (reconfig-timer expiring), the same
+ // req_seq_nbr will be used.
+ RTC_DCHECK(current_request_.has_value());
+
+ if (!current_request_->has_been_sent()) {
+ current_request_->PrepareToSend(next_outgoing_req_seq_nbr_);
+ next_outgoing_req_seq_nbr_ =
+ ReconfigRequestSN(*next_outgoing_req_seq_nbr_ + 1);
+ }
+
+ Parameters::Builder params_builder =
+ Parameters::Builder().Add(OutgoingSSNResetRequestParameter(
+ current_request_->req_seq_nbr(), current_request_->req_seq_nbr(),
+ current_request_->sender_last_assigned_tsn(),
+ current_request_->streams()));
+
+ return ReConfigChunk(params_builder.Build());
+}
+
+void StreamResetHandler::ResetStreams(
+ rtc::ArrayView<const StreamID> outgoing_streams) {
+ for (StreamID stream_id : outgoing_streams) {
+ retransmission_queue_->PrepareResetStream(stream_id);
+ }
+}
+
+absl::optional<DurationMs> StreamResetHandler::OnReconfigTimerExpiry() {
+ if (current_request_->has_been_sent()) {
+ // There is an outstanding request, which timed out while waiting for a
+ // response.
+ if (!ctx_->IncrementTxErrorCounter("RECONFIG timeout")) {
+ // Timed out. The connection will close after processing the timers.
+ return absl::nullopt;
+ }
+ } else {
+ // There is no outstanding request, but there is a prepared one. This means
+ // that the receiver has previously responded "in progress", which resulted
+ // in retrying the request (but with a new req_seq_nbr) after a while.
+ }
+
+ ctx_->Send(ctx_->PacketBuilder().Add(MakeReconfigChunk()));
+ return ctx_->current_rto();
+}
+
+HandoverReadinessStatus StreamResetHandler::GetHandoverReadiness() const {
+ HandoverReadinessStatus status;
+ if (retransmission_queue_->HasStreamsReadyToBeReset()) {
+ status.Add(HandoverUnreadinessReason::kPendingStreamReset);
+ }
+ if (current_request_.has_value()) {
+ status.Add(HandoverUnreadinessReason::kPendingStreamResetRequest);
+ }
+ return status;
+}
+
+void StreamResetHandler::AddHandoverState(DcSctpSocketHandoverState& state) {
+ state.rx.last_completed_reset_req_sn = last_processed_req_seq_nbr_.value();
+ state.tx.next_reset_req_sn = next_outgoing_req_seq_nbr_.value();
+}
+
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h
new file mode 100644
index 0000000000..6e49665538
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_
+#define NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/functional/bind_front.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/common/internal_types.h"
+#include "net/dcsctp/packet/chunk/reconfig_chunk.h"
+#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h"
+#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
+#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/rx/data_tracker.h"
+#include "net/dcsctp/rx/reassembly_queue.h"
+#include "net/dcsctp/socket/context.h"
+#include "net/dcsctp/timer/timer.h"
+#include "net/dcsctp/tx/retransmission_queue.h"
+#include "rtc_base/containers/flat_set.h"
+
+namespace dcsctp {
+
+// StreamResetHandler handles sending outgoing stream reset requests (to close
+// an SCTP stream, which translates to closing a data channel).
+//
+// It also handles incoming "outgoing stream reset requests", when the peer
+// wants to close its data channel.
+//
+// Resetting streams is an asynchronous operation where the client will request
+// a request a stream to be reset, but then it might not be performed exactly at
+// this point. First, the sender might need to discard all messages that have
+// been enqueued for this stream, or it may select to wait until all have been
+// sent. At least, it must wait for the currently sending fragmented message to
+// be fully sent, because a stream can't be reset while having received half a
+// message. In the stream reset request, the "sender's last assigned TSN" is
+// provided, which is simply the TSN for which the receiver should've received
+// all messages before this value, before the stream can be reset. Since
+// fragments can get lost or sent out-of-order, the receiver of a request may
+// not have received all the data just yet, and then it will respond to the
+// sender: "In progress". In other words, try again. The sender will then need
+// to start a timer and try the very same request again (but with a new sequence
+// number) until the receiver successfully performs the operation.
+//
+// All this can take some time, and may be driven by timers, so the client will
+// ultimately be notified using callbacks.
+//
+// In this implementation, when a stream is reset, the queued but not-yet-sent
+// messages will be discarded, but that may change in the future. RFC8831 allows
+// both behaviors.
+class StreamResetHandler {
+ public:
+ StreamResetHandler(absl::string_view log_prefix,
+ Context* context,
+ TimerManager* timer_manager,
+ DataTracker* data_tracker,
+ ReassemblyQueue* reassembly_queue,
+ RetransmissionQueue* retransmission_queue,
+ const DcSctpSocketHandoverState* handover_state = nullptr)
+ : log_prefix_(std::string(log_prefix) + "reset: "),
+ ctx_(context),
+ data_tracker_(data_tracker),
+ reassembly_queue_(reassembly_queue),
+ retransmission_queue_(retransmission_queue),
+ reconfig_timer_(timer_manager->CreateTimer(
+ "re-config",
+ absl::bind_front(&StreamResetHandler::OnReconfigTimerExpiry, this),
+ TimerOptions(DurationMs(0)))),
+ next_outgoing_req_seq_nbr_(
+ handover_state
+ ? ReconfigRequestSN(handover_state->tx.next_reset_req_sn)
+ : ReconfigRequestSN(*ctx_->my_initial_tsn())),
+ last_processed_req_seq_nbr_(
+ handover_state ? ReconfigRequestSN(
+ handover_state->rx.last_completed_reset_req_sn)
+ : ReconfigRequestSN(*ctx_->peer_initial_tsn() - 1)) {
+ }
+
+ // Initiates reset of the provided streams. While there can only be one
+ // ongoing stream reset request at any time, this method can be called at any
+ // time and also multiple times. It will enqueue requests that can't be
+ // directly fulfilled, and will asynchronously process them when any ongoing
+ // request has completed.
+ void ResetStreams(rtc::ArrayView<const StreamID> outgoing_streams);
+
+ // Creates a Reset Streams request that must be sent if returned. Will start
+ // the reconfig timer. Will return absl::nullopt if there is no need to
+ // create a request (no streams to reset) or if there already is an ongoing
+ // stream reset request that hasn't completed yet.
+ absl::optional<ReConfigChunk> MakeStreamResetRequest();
+
+ // Called when handling and incoming RE-CONFIG chunk.
+ void HandleReConfig(ReConfigChunk chunk);
+
+ HandoverReadinessStatus GetHandoverReadiness() const;
+
+ void AddHandoverState(DcSctpSocketHandoverState& state);
+
+ private:
+ // Represents a stream request operation. There can only be one ongoing at
+ // any time, and a sent request may either succeed, fail or result in the
+ // receiver signaling that it can't process it right now, and then it will be
+ // retried.
+ class CurrentRequest {
+ public:
+ CurrentRequest(TSN sender_last_assigned_tsn, std::vector<StreamID> streams)
+ : req_seq_nbr_(absl::nullopt),
+ sender_last_assigned_tsn_(sender_last_assigned_tsn),
+ streams_(std::move(streams)) {}
+
+ // Returns the current request sequence number, if this request has been
+ // sent (check `has_been_sent` first). Will return 0 if the request is just
+ // prepared (or scheduled for retransmission) but not yet sent.
+ ReconfigRequestSN req_seq_nbr() const {
+ return req_seq_nbr_.value_or(ReconfigRequestSN(0));
+ }
+
+ // The sender's last assigned TSN, from the retransmission queue. The
+ // receiver uses this to know when all data up to this TSN has been
+ // received, to know when to safely reset the stream.
+ TSN sender_last_assigned_tsn() const { return sender_last_assigned_tsn_; }
+
+ // The streams that are to be reset.
+ const std::vector<StreamID>& streams() const { return streams_; }
+
+ // If this request has been sent yet. If not, then it's either because it
+ // has only been prepared and not yet sent, or because the received couldn't
+ // apply the request, and then the exact same request will be retried, but
+ // with a new sequence number.
+ bool has_been_sent() const { return req_seq_nbr_.has_value(); }
+
+ // If the receiver can't apply the request yet (and answered "In Progress"),
+ // this will be called to prepare the request to be retransmitted at a later
+ // time.
+ void PrepareRetransmission() { req_seq_nbr_ = absl::nullopt; }
+
+ // If the request hasn't been sent yet, this assigns it a request number.
+ void PrepareToSend(ReconfigRequestSN new_req_seq_nbr) {
+ req_seq_nbr_ = new_req_seq_nbr;
+ }
+
+ private:
+ // If this is set, this request has been sent. If it's not set, the request
+ // has been prepared, but has not yet been sent. This is typically used when
+ // the peer responded "in progress" and the same request (but a different
+ // request number) must be sent again.
+ absl::optional<ReconfigRequestSN> req_seq_nbr_;
+ // The sender's (that's us) last assigned TSN, from the retransmission
+ // queue.
+ TSN sender_last_assigned_tsn_;
+ // The streams that are to be reset in this request.
+ const std::vector<StreamID> streams_;
+ };
+
+ // Called to validate an incoming RE-CONFIG chunk.
+ bool Validate(const ReConfigChunk& chunk);
+
+ // Processes a stream stream reconfiguration chunk and may either return
+ // absl::nullopt (on protocol errors), or a list of responses - either 0, 1
+ // or 2.
+ absl::optional<std::vector<ReconfigurationResponseParameter>> Process(
+ const ReConfigChunk& chunk);
+
+ // Creates the actual RE-CONFIG chunk. A request (which set `current_request`)
+ // must have been created prior.
+ ReConfigChunk MakeReconfigChunk();
+
+ // Called to validate the `req_seq_nbr`, that it's the next in sequence. If it
+ // fails to validate, and returns false, it will also add a response to
+ // `responses`.
+ bool ValidateReqSeqNbr(
+ ReconfigRequestSN req_seq_nbr,
+ std::vector<ReconfigurationResponseParameter>& responses);
+
+ // Called when this socket receives an outgoing stream reset request. It might
+ // either be performed straight away, or have to be deferred, and the result
+ // of that will be put in `responses`.
+ void HandleResetOutgoing(
+ const ParameterDescriptor& descriptor,
+ std::vector<ReconfigurationResponseParameter>& responses);
+
+ // Called when this socket receives an incoming stream reset request. This
+ // isn't really supported, but a successful response is put in `responses`.
+ void HandleResetIncoming(
+ const ParameterDescriptor& descriptor,
+ std::vector<ReconfigurationResponseParameter>& responses);
+
+ // Called when receiving a response to an outgoing stream reset request. It
+ // will either commit the stream resetting, if the operation was successful,
+ // or will schedule a retry if it was deferred. And if it failed, the
+ // operation will be rolled back.
+ void HandleResponse(const ParameterDescriptor& descriptor);
+
+ // Expiration handler for the Reconfig timer.
+ absl::optional<DurationMs> OnReconfigTimerExpiry();
+
+ const std::string log_prefix_;
+ Context* ctx_;
+ DataTracker* data_tracker_;
+ ReassemblyQueue* reassembly_queue_;
+ RetransmissionQueue* retransmission_queue_;
+ const std::unique_ptr<Timer> reconfig_timer_;
+
+ // The next sequence number for outgoing stream requests.
+ ReconfigRequestSN next_outgoing_req_seq_nbr_;
+
+ // The current stream request operation.
+ absl::optional<CurrentRequest> current_request_;
+
+ // For incoming requests - last processed request sequence number.
+ ReconfigRequestSN last_processed_req_seq_nbr_;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_
diff --git a/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc
new file mode 100644
index 0000000000..493b4c4bf7
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc
@@ -0,0 +1,802 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/stream_reset_handler.h"
+
+#include <array>
+#include <cstdint>
+#include <memory>
+#include <type_traits>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "api/task_queue/task_queue_base.h"
+#include "net/dcsctp/common/handover_testing.h"
+#include "net/dcsctp/common/internal_types.h"
+#include "net/dcsctp/packet/chunk/reconfig_chunk.h"
+#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h"
+#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
+#include "net/dcsctp/packet/parameter/parameter.h"
+#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "net/dcsctp/rx/data_tracker.h"
+#include "net/dcsctp/rx/reassembly_queue.h"
+#include "net/dcsctp/socket/mock_context.h"
+#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h"
+#include "net/dcsctp/testing/data_generator.h"
+#include "net/dcsctp/testing/testing_macros.h"
+#include "net/dcsctp/timer/timer.h"
+#include "net/dcsctp/tx/mock_send_queue.h"
+#include "net/dcsctp/tx/retransmission_queue.h"
+#include "rtc_base/gunit.h"
+#include "test/gmock.h"
+
+namespace dcsctp {
+namespace {
+using ::testing::IsEmpty;
+using ::testing::NiceMock;
+using ::testing::Return;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+using ResponseResult = ReconfigurationResponseParameter::Result;
+
+constexpr TSN kMyInitialTsn = MockContext::MyInitialTsn();
+constexpr ReconfigRequestSN kMyInitialReqSn = ReconfigRequestSN(*kMyInitialTsn);
+constexpr TSN kPeerInitialTsn = MockContext::PeerInitialTsn();
+constexpr ReconfigRequestSN kPeerInitialReqSn =
+ ReconfigRequestSN(*kPeerInitialTsn);
+constexpr uint32_t kArwnd = 131072;
+constexpr DurationMs kRto = DurationMs(250);
+
+constexpr std::array<uint8_t, 4> kShortPayload = {1, 2, 3, 4};
+
+MATCHER_P3(SctpMessageIs, stream_id, ppid, expected_payload, "") {
+ if (arg.stream_id() != stream_id) {
+ *result_listener << "the stream_id is " << *arg.stream_id();
+ return false;
+ }
+
+ if (arg.ppid() != ppid) {
+ *result_listener << "the ppid is " << *arg.ppid();
+ return false;
+ }
+
+ if (std::vector<uint8_t>(arg.payload().begin(), arg.payload().end()) !=
+ std::vector<uint8_t>(expected_payload.begin(), expected_payload.end())) {
+ *result_listener << "the payload is wrong";
+ return false;
+ }
+ return true;
+}
+
+TSN AddTo(TSN tsn, int delta) {
+ return TSN(*tsn + delta);
+}
+
+ReconfigRequestSN AddTo(ReconfigRequestSN req_sn, int delta) {
+ return ReconfigRequestSN(*req_sn + delta);
+}
+
+class StreamResetHandlerTest : public testing::Test {
+ protected:
+ StreamResetHandlerTest()
+ : ctx_(&callbacks_),
+ timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) {
+ return callbacks_.CreateTimeout(precision);
+ }),
+ delayed_ack_timer_(timer_manager_.CreateTimer(
+ "test/delayed_ack",
+ []() { return absl::nullopt; },
+ TimerOptions(DurationMs(0)))),
+ t3_rtx_timer_(timer_manager_.CreateTimer(
+ "test/t3_rtx",
+ []() { return absl::nullopt; },
+ TimerOptions(DurationMs(0)))),
+ data_tracker_(std::make_unique<DataTracker>("log: ",
+ delayed_ack_timer_.get(),
+ kPeerInitialTsn)),
+ reasm_(std::make_unique<ReassemblyQueue>("log: ",
+ kPeerInitialTsn,
+ kArwnd)),
+ retransmission_queue_(std::make_unique<RetransmissionQueue>(
+ "",
+ &callbacks_,
+ kMyInitialTsn,
+ kArwnd,
+ producer_,
+ [](DurationMs rtt_ms) {},
+ []() {},
+ *t3_rtx_timer_,
+ DcSctpOptions())),
+ handler_(
+ std::make_unique<StreamResetHandler>("log: ",
+ &ctx_,
+ &timer_manager_,
+ data_tracker_.get(),
+ reasm_.get(),
+ retransmission_queue_.get())) {
+ EXPECT_CALL(ctx_, current_rto).WillRepeatedly(Return(kRto));
+ }
+
+ void AdvanceTime(DurationMs duration) {
+ callbacks_.AdvanceTime(kRto);
+ for (;;) {
+ absl::optional<TimeoutID> timeout_id = callbacks_.GetNextExpiredTimeout();
+ if (!timeout_id.has_value()) {
+ break;
+ }
+ timer_manager_.HandleTimeout(*timeout_id);
+ }
+ }
+
+ // Handles the passed in RE-CONFIG `chunk` and returns the responses
+ // that are sent in the response RE-CONFIG.
+ std::vector<ReconfigurationResponseParameter> HandleAndCatchResponse(
+ ReConfigChunk chunk) {
+ handler_->HandleReConfig(std::move(chunk));
+
+ std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket();
+ if (payload.empty()) {
+ EXPECT_TRUE(false);
+ return {};
+ }
+
+ std::vector<ReconfigurationResponseParameter> responses;
+ absl::optional<SctpPacket> p = SctpPacket::Parse(payload);
+ if (!p.has_value()) {
+ EXPECT_TRUE(false);
+ return {};
+ }
+ if (p->descriptors().size() != 1) {
+ EXPECT_TRUE(false);
+ return {};
+ }
+ absl::optional<ReConfigChunk> response_chunk =
+ ReConfigChunk::Parse(p->descriptors()[0].data);
+ if (!response_chunk.has_value()) {
+ EXPECT_TRUE(false);
+ return {};
+ }
+ for (const auto& desc : response_chunk->parameters().descriptors()) {
+ if (desc.type == ReconfigurationResponseParameter::kType) {
+ absl::optional<ReconfigurationResponseParameter> response =
+ ReconfigurationResponseParameter::Parse(desc.data);
+ if (!response.has_value()) {
+ EXPECT_TRUE(false);
+ return {};
+ }
+ responses.emplace_back(*std::move(response));
+ }
+ }
+ return responses;
+ }
+
+ void PerformHandover() {
+ EXPECT_TRUE(handler_->GetHandoverReadiness().IsReady());
+ EXPECT_TRUE(data_tracker_->GetHandoverReadiness().IsReady());
+ EXPECT_TRUE(reasm_->GetHandoverReadiness().IsReady());
+ EXPECT_TRUE(retransmission_queue_->GetHandoverReadiness().IsReady());
+
+ DcSctpSocketHandoverState state;
+ handler_->AddHandoverState(state);
+ data_tracker_->AddHandoverState(state);
+ reasm_->AddHandoverState(state);
+
+ retransmission_queue_->AddHandoverState(state);
+
+ g_handover_state_transformer_for_test(&state);
+
+ data_tracker_ = std::make_unique<DataTracker>(
+ "log: ", delayed_ack_timer_.get(), kPeerInitialTsn);
+ data_tracker_->RestoreFromState(state);
+ reasm_ =
+ std::make_unique<ReassemblyQueue>("log: ", kPeerInitialTsn, kArwnd);
+ reasm_->RestoreFromState(state);
+ retransmission_queue_ = std::make_unique<RetransmissionQueue>(
+ "", &callbacks_, kMyInitialTsn, kArwnd, producer_,
+ [](DurationMs rtt_ms) {}, []() {}, *t3_rtx_timer_, DcSctpOptions(),
+ /*supports_partial_reliability=*/true,
+ /*use_message_interleaving=*/false);
+ retransmission_queue_->RestoreFromState(state);
+ handler_ = std::make_unique<StreamResetHandler>(
+ "log: ", &ctx_, &timer_manager_, data_tracker_.get(), reasm_.get(),
+ retransmission_queue_.get(), &state);
+ }
+
+ DataGenerator gen_;
+ NiceMock<MockDcSctpSocketCallbacks> callbacks_;
+ NiceMock<MockContext> ctx_;
+ NiceMock<MockSendQueue> producer_;
+ TimerManager timer_manager_;
+ std::unique_ptr<Timer> delayed_ack_timer_;
+ std::unique_ptr<Timer> t3_rtx_timer_;
+ std::unique_ptr<DataTracker> data_tracker_;
+ std::unique_ptr<ReassemblyQueue> reasm_;
+ std::unique_ptr<RetransmissionQueue> retransmission_queue_;
+ std::unique_ptr<StreamResetHandler> handler_;
+};
+
+TEST_F(StreamResetHandlerTest, ChunkWithNoParametersReturnsError) {
+ EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0);
+ EXPECT_CALL(callbacks_, OnError).Times(1);
+ handler_->HandleReConfig(ReConfigChunk(Parameters()));
+}
+
+TEST_F(StreamResetHandlerTest, ChunkWithInvalidParametersReturnsError) {
+ Parameters::Builder builder;
+ // Two OutgoingSSNResetRequestParameter in a RE-CONFIG is not valid.
+ builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(1),
+ ReconfigRequestSN(10),
+ kPeerInitialTsn, {StreamID(1)}));
+ builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(2),
+ ReconfigRequestSN(10),
+ kPeerInitialTsn, {StreamID(2)}));
+
+ EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0);
+ EXPECT_CALL(callbacks_, OnError).Times(1);
+ handler_->HandleReConfig(ReConfigChunk(builder.Build()));
+}
+
+TEST_F(StreamResetHandlerTest, FailToDeliverWithoutResettingStream) {
+ reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE"));
+ reasm_->Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE"));
+
+ data_tracker_->Observe(kPeerInitialTsn);
+ data_tracker_->Observe(AddTo(kPeerInitialTsn, 1));
+ EXPECT_THAT(reasm_->FlushMessages(),
+ UnorderedElementsAre(
+ SctpMessageIs(StreamID(1), PPID(53), kShortPayload),
+ SctpMessageIs(StreamID(1), PPID(53), kShortPayload)));
+
+ gen_.ResetStream();
+ reasm_->Add(AddTo(kPeerInitialTsn, 2), gen_.Ordered({1, 2, 3, 4}, "BE"));
+ EXPECT_THAT(reasm_->FlushMessages(), IsEmpty());
+}
+
+TEST_F(StreamResetHandlerTest, ResetStreamsNotDeferred) {
+ reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE"));
+ reasm_->Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE"));
+
+ data_tracker_->Observe(kPeerInitialTsn);
+ data_tracker_->Observe(AddTo(kPeerInitialTsn, 1));
+ EXPECT_THAT(reasm_->FlushMessages(),
+ UnorderedElementsAre(
+ SctpMessageIs(StreamID(1), PPID(53), kShortPayload),
+ SctpMessageIs(StreamID(1), PPID(53), kShortPayload)));
+
+ Parameters::Builder builder;
+ builder.Add(OutgoingSSNResetRequestParameter(
+ kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1),
+ {StreamID(1)}));
+
+ std::vector<ReconfigurationResponseParameter> responses =
+ HandleAndCatchResponse(ReConfigChunk(builder.Build()));
+ EXPECT_THAT(responses, SizeIs(1));
+ EXPECT_EQ(responses[0].result(), ResponseResult::kSuccessPerformed);
+
+ gen_.ResetStream();
+ reasm_->Add(AddTo(kPeerInitialTsn, 2), gen_.Ordered({1, 2, 3, 4}, "BE"));
+ EXPECT_THAT(reasm_->FlushMessages(),
+ UnorderedElementsAre(
+ SctpMessageIs(StreamID(1), PPID(53), kShortPayload)));
+}
+
+TEST_F(StreamResetHandlerTest, ResetStreamsDeferred) {
+ DataGeneratorOptions opts;
+ opts.message_id = MID(0);
+ reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE", opts));
+
+ opts.message_id = MID(1);
+ reasm_->Add(AddTo(kPeerInitialTsn, 1),
+ gen_.Ordered({1, 2, 3, 4}, "BE", opts));
+
+ data_tracker_->Observe(kPeerInitialTsn);
+ data_tracker_->Observe(AddTo(kPeerInitialTsn, 1));
+ EXPECT_THAT(reasm_->FlushMessages(),
+ UnorderedElementsAre(
+ SctpMessageIs(StreamID(1), PPID(53), kShortPayload),
+ SctpMessageIs(StreamID(1), PPID(53), kShortPayload)));
+
+ Parameters::Builder builder;
+ builder.Add(OutgoingSSNResetRequestParameter(
+ kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 3),
+ {StreamID(1)}));
+
+ std::vector<ReconfigurationResponseParameter> responses =
+ HandleAndCatchResponse(ReConfigChunk(builder.Build()));
+ EXPECT_THAT(responses, SizeIs(1));
+ EXPECT_EQ(responses[0].result(), ResponseResult::kInProgress);
+
+ opts.message_id = MID(1);
+ opts.ppid = PPID(5);
+ reasm_->Add(AddTo(kPeerInitialTsn, 5),
+ gen_.Ordered({1, 2, 3, 4}, "BE", opts));
+ reasm_->MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 1));
+
+ opts.message_id = MID(0);
+ opts.ppid = PPID(4);
+ reasm_->Add(AddTo(kPeerInitialTsn, 4),
+ gen_.Ordered({1, 2, 3, 4}, "BE", opts));
+ reasm_->MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 1));
+
+ opts.message_id = MID(3);
+ opts.ppid = PPID(3);
+ reasm_->Add(AddTo(kPeerInitialTsn, 3),
+ gen_.Ordered({1, 2, 3, 4}, "BE", opts));
+ reasm_->MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 1));
+
+ opts.message_id = MID(2);
+ opts.ppid = PPID(2);
+ reasm_->Add(AddTo(kPeerInitialTsn, 2),
+ gen_.Ordered({1, 2, 3, 4}, "BE", opts));
+ reasm_->MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 5));
+
+ EXPECT_THAT(
+ reasm_->FlushMessages(),
+ UnorderedElementsAre(SctpMessageIs(StreamID(1), PPID(2), kShortPayload),
+ SctpMessageIs(StreamID(1), PPID(3), kShortPayload),
+ SctpMessageIs(StreamID(1), PPID(4), kShortPayload),
+ SctpMessageIs(StreamID(1), PPID(5), kShortPayload)));
+}
+
+TEST_F(StreamResetHandlerTest, SendOutgoingRequestDirectly) {
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(42)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42)}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(42)})));
+
+ absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest();
+ ASSERT_TRUE(reconfig.has_value());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req,
+ reconfig->parameters().get<OutgoingSSNResetRequestParameter>());
+
+ EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn);
+ EXPECT_EQ(req.sender_last_assigned_tsn(),
+ TSN(*retransmission_queue_->next_tsn() - 1));
+ EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42)));
+}
+
+TEST_F(StreamResetHandlerTest, ResetMultipleStreamsInOneRequest) {
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(40)));
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(41)));
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))).Times(2);
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(43)));
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(44)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42)}));
+ handler_->ResetStreams(
+ std::vector<StreamID>({StreamID(43), StreamID(44), StreamID(41)}));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42), StreamID(40)}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(
+ std::vector<StreamID>({StreamID(40), StreamID(41), StreamID(42),
+ StreamID(43), StreamID(44)})));
+ absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest();
+ ASSERT_TRUE(reconfig.has_value());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req,
+ reconfig->parameters().get<OutgoingSSNResetRequestParameter>());
+
+ EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn);
+ EXPECT_EQ(req.sender_last_assigned_tsn(),
+ TSN(*retransmission_queue_->next_tsn() - 1));
+ EXPECT_THAT(req.stream_ids(),
+ UnorderedElementsAre(StreamID(40), StreamID(41), StreamID(42),
+ StreamID(43), StreamID(44)));
+}
+
+TEST_F(StreamResetHandlerTest, SendOutgoingRequestDeferred) {
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(42)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42)}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset())
+ .WillOnce(Return(false))
+ .WillOnce(Return(false))
+ .WillOnce(Return(true));
+
+ EXPECT_FALSE(handler_->MakeStreamResetRequest().has_value());
+ EXPECT_FALSE(handler_->MakeStreamResetRequest().has_value());
+ EXPECT_TRUE(handler_->MakeStreamResetRequest().has_value());
+}
+
+TEST_F(StreamResetHandlerTest, SendOutgoingResettingOnPositiveResponse) {
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(42)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42)}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(42)})));
+
+ absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest();
+ ASSERT_TRUE(reconfig.has_value());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req,
+ reconfig->parameters().get<OutgoingSSNResetRequestParameter>());
+
+ Parameters::Builder builder;
+ builder.Add(ReconfigurationResponseParameter(
+ req.request_sequence_number(), ResponseResult::kSuccessPerformed));
+ ReConfigChunk response_reconfig(builder.Build());
+
+ EXPECT_CALL(producer_, CommitResetStreams);
+ EXPECT_CALL(producer_, RollbackResetStreams).Times(0);
+
+ // Processing a response shouldn't result in sending anything.
+ EXPECT_CALL(callbacks_, OnError).Times(0);
+ EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0);
+ handler_->HandleReConfig(std::move(response_reconfig));
+}
+
+TEST_F(StreamResetHandlerTest, SendOutgoingResetRollbackOnError) {
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(42)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42)}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(42)})));
+
+ absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest();
+ ASSERT_TRUE(reconfig.has_value());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req,
+ reconfig->parameters().get<OutgoingSSNResetRequestParameter>());
+
+ Parameters::Builder builder;
+ builder.Add(ReconfigurationResponseParameter(
+ req.request_sequence_number(), ResponseResult::kErrorBadSequenceNumber));
+ ReConfigChunk response_reconfig(builder.Build());
+
+ EXPECT_CALL(producer_, CommitResetStreams).Times(0);
+ EXPECT_CALL(producer_, RollbackResetStreams);
+
+ // Only requests should result in sending responses.
+ EXPECT_CALL(callbacks_, OnError).Times(0);
+ EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0);
+ handler_->HandleReConfig(std::move(response_reconfig));
+}
+
+TEST_F(StreamResetHandlerTest, SendOutgoingResetRetransmitOnInProgress) {
+ static constexpr StreamID kStreamToReset = StreamID(42);
+
+ EXPECT_CALL(producer_, PrepareResetStream(kStreamToReset));
+ handler_->ResetStreams(std::vector<StreamID>({kStreamToReset}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({kStreamToReset})));
+
+ absl::optional<ReConfigChunk> reconfig1 = handler_->MakeStreamResetRequest();
+ ASSERT_TRUE(reconfig1.has_value());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req1,
+ reconfig1->parameters().get<OutgoingSSNResetRequestParameter>());
+
+ // Simulate that the peer responded "In Progress".
+ Parameters::Builder builder;
+ builder.Add(ReconfigurationResponseParameter(req1.request_sequence_number(),
+ ResponseResult::kInProgress));
+ ReConfigChunk response_reconfig(builder.Build());
+
+ EXPECT_CALL(producer_, CommitResetStreams()).Times(0);
+ EXPECT_CALL(producer_, RollbackResetStreams()).Times(0);
+
+ // Processing a response shouldn't result in sending anything.
+ EXPECT_CALL(callbacks_, OnError).Times(0);
+ EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0);
+ handler_->HandleReConfig(std::move(response_reconfig));
+
+ // Let some time pass, so that the reconfig timer expires, and retries the
+ // same request.
+ EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(1);
+ AdvanceTime(kRto);
+
+ std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket();
+ ASSERT_FALSE(payload.empty());
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload));
+ ASSERT_THAT(packet.descriptors(), SizeIs(1));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ ReConfigChunk reconfig2,
+ ReConfigChunk::Parse(packet.descriptors()[0].data));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req2,
+ reconfig2.parameters().get<OutgoingSSNResetRequestParameter>());
+
+ EXPECT_EQ(req2.request_sequence_number(),
+ AddTo(req1.request_sequence_number(), 1));
+ EXPECT_THAT(req2.stream_ids(), UnorderedElementsAre(kStreamToReset));
+}
+
+TEST_F(StreamResetHandlerTest, ResetWhileRequestIsSentWillQueue) {
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(42)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42)}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(42)})));
+
+ absl::optional<ReConfigChunk> reconfig1 = handler_->MakeStreamResetRequest();
+ ASSERT_TRUE(reconfig1.has_value());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req1,
+ reconfig1->parameters().get<OutgoingSSNResetRequestParameter>());
+ EXPECT_EQ(req1.request_sequence_number(), kMyInitialReqSn);
+ EXPECT_EQ(req1.sender_last_assigned_tsn(),
+ AddTo(retransmission_queue_->next_tsn(), -1));
+ EXPECT_THAT(req1.stream_ids(), UnorderedElementsAre(StreamID(42)));
+
+ // Streams reset while the request is in-flight will be queued.
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(41)));
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(43)));
+ StreamID stream_ids[] = {StreamID(41), StreamID(43)};
+ handler_->ResetStreams(stream_ids);
+ EXPECT_EQ(handler_->MakeStreamResetRequest(), absl::nullopt);
+
+ Parameters::Builder builder;
+ builder.Add(ReconfigurationResponseParameter(
+ req1.request_sequence_number(), ResponseResult::kSuccessPerformed));
+ ReConfigChunk response_reconfig(builder.Build());
+
+ EXPECT_CALL(producer_, CommitResetStreams()).Times(1);
+ EXPECT_CALL(producer_, RollbackResetStreams()).Times(0);
+
+ // Processing a response shouldn't result in sending anything.
+ EXPECT_CALL(callbacks_, OnError).Times(0);
+ EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0);
+ handler_->HandleReConfig(std::move(response_reconfig));
+
+ // Response has been processed. A new request can be sent.
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(41), StreamID(43)})));
+
+ absl::optional<ReConfigChunk> reconfig2 = handler_->MakeStreamResetRequest();
+ ASSERT_TRUE(reconfig2.has_value());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req2,
+ reconfig2->parameters().get<OutgoingSSNResetRequestParameter>());
+ EXPECT_EQ(req2.request_sequence_number(), AddTo(kMyInitialReqSn, 1));
+ EXPECT_EQ(req2.sender_last_assigned_tsn(),
+ TSN(*retransmission_queue_->next_tsn() - 1));
+ EXPECT_THAT(req2.stream_ids(),
+ UnorderedElementsAre(StreamID(41), StreamID(43)));
+}
+
+TEST_F(StreamResetHandlerTest, SendIncomingResetJustReturnsNothingPerformed) {
+ Parameters::Builder builder;
+ builder.Add(
+ IncomingSSNResetRequestParameter(kPeerInitialReqSn, {StreamID(1)}));
+
+ std::vector<ReconfigurationResponseParameter> responses =
+ HandleAndCatchResponse(ReConfigChunk(builder.Build()));
+ ASSERT_THAT(responses, SizeIs(1));
+ EXPECT_THAT(responses[0].response_sequence_number(), kPeerInitialReqSn);
+ EXPECT_THAT(responses[0].result(), ResponseResult::kSuccessNothingToDo);
+}
+
+TEST_F(StreamResetHandlerTest, SendSameRequestTwiceReturnsNothingToDo) {
+ reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE"));
+ reasm_->Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE"));
+
+ data_tracker_->Observe(kPeerInitialTsn);
+ data_tracker_->Observe(AddTo(kPeerInitialTsn, 1));
+ EXPECT_THAT(reasm_->FlushMessages(),
+ UnorderedElementsAre(
+ SctpMessageIs(StreamID(1), PPID(53), kShortPayload),
+ SctpMessageIs(StreamID(1), PPID(53), kShortPayload)));
+
+ Parameters::Builder builder1;
+ builder1.Add(OutgoingSSNResetRequestParameter(
+ kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1),
+ {StreamID(1)}));
+
+ std::vector<ReconfigurationResponseParameter> responses1 =
+ HandleAndCatchResponse(ReConfigChunk(builder1.Build()));
+ EXPECT_THAT(responses1, SizeIs(1));
+ EXPECT_EQ(responses1[0].result(), ResponseResult::kSuccessPerformed);
+
+ Parameters::Builder builder2;
+ builder2.Add(OutgoingSSNResetRequestParameter(
+ kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1),
+ {StreamID(1)}));
+
+ std::vector<ReconfigurationResponseParameter> responses2 =
+ HandleAndCatchResponse(ReConfigChunk(builder2.Build()));
+ EXPECT_THAT(responses2, SizeIs(1));
+ EXPECT_EQ(responses2[0].result(), ResponseResult::kSuccessNothingToDo);
+}
+
+TEST_F(StreamResetHandlerTest,
+ HandoverIsAllowedOnlyWhenNoStreamIsBeingOrWillBeReset) {
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(42)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42)}));
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_EQ(
+ handler_->GetHandoverReadiness(),
+ HandoverReadinessStatus(HandoverUnreadinessReason::kPendingStreamReset));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset())
+ .WillOnce(Return(true))
+ .WillOnce(Return(false));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(42)})));
+
+ ASSERT_TRUE(handler_->MakeStreamResetRequest().has_value());
+ EXPECT_EQ(handler_->GetHandoverReadiness(),
+ HandoverReadinessStatus(
+ HandoverUnreadinessReason::kPendingStreamResetRequest));
+
+ // Reset more streams while the request is in-flight.
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(41)));
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(43)));
+ StreamID stream_ids[] = {StreamID(41), StreamID(43)};
+ handler_->ResetStreams(stream_ids);
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_EQ(handler_->GetHandoverReadiness(),
+ HandoverReadinessStatus()
+ .Add(HandoverUnreadinessReason::kPendingStreamResetRequest)
+ .Add(HandoverUnreadinessReason::kPendingStreamReset));
+
+ // Processing a response to first request.
+ EXPECT_CALL(producer_, CommitResetStreams()).Times(1);
+ handler_->HandleReConfig(
+ ReConfigChunk(Parameters::Builder()
+ .Add(ReconfigurationResponseParameter(
+ kMyInitialReqSn, ResponseResult::kSuccessPerformed))
+ .Build()));
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_EQ(
+ handler_->GetHandoverReadiness(),
+ HandoverReadinessStatus(HandoverUnreadinessReason::kPendingStreamReset));
+
+ // Second request can be sent.
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset())
+ .WillOnce(Return(true))
+ .WillOnce(Return(false));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(41), StreamID(43)})));
+
+ ASSERT_TRUE(handler_->MakeStreamResetRequest().has_value());
+ EXPECT_EQ(handler_->GetHandoverReadiness(),
+ HandoverReadinessStatus(
+ HandoverUnreadinessReason::kPendingStreamResetRequest));
+
+ // Processing a response to second request.
+ EXPECT_CALL(producer_, CommitResetStreams()).Times(1);
+ handler_->HandleReConfig(ReConfigChunk(
+ Parameters::Builder()
+ .Add(ReconfigurationResponseParameter(
+ AddTo(kMyInitialReqSn, 1), ResponseResult::kSuccessPerformed))
+ .Build()));
+
+ // Seconds response has been processed. No pending resets.
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(false));
+
+ EXPECT_TRUE(handler_->GetHandoverReadiness().IsReady());
+}
+
+TEST_F(StreamResetHandlerTest, HandoverInInitialState) {
+ PerformHandover();
+
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(42)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42)}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(42)})));
+
+ absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest();
+ ASSERT_TRUE(reconfig.has_value());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req,
+ reconfig->parameters().get<OutgoingSSNResetRequestParameter>());
+
+ EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn);
+ EXPECT_EQ(req.sender_last_assigned_tsn(),
+ TSN(*retransmission_queue_->next_tsn() - 1));
+ EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42)));
+}
+
+TEST_F(StreamResetHandlerTest, HandoverAfterHavingResetOneStream) {
+ // Reset one stream
+ {
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(42)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(42)}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset())
+ .WillOnce(Return(true))
+ .WillOnce(Return(false));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(42)})));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk reconfig,
+ handler_->MakeStreamResetRequest());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req,
+ reconfig.parameters().get<OutgoingSSNResetRequestParameter>());
+ EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn);
+ EXPECT_EQ(req.sender_last_assigned_tsn(),
+ TSN(*retransmission_queue_->next_tsn() - 1));
+ EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42)));
+
+ EXPECT_CALL(producer_, CommitResetStreams()).Times(1);
+ handler_->HandleReConfig(
+ ReConfigChunk(Parameters::Builder()
+ .Add(ReconfigurationResponseParameter(
+ req.request_sequence_number(),
+ ResponseResult::kSuccessPerformed))
+ .Build()));
+ }
+
+ PerformHandover();
+
+ // Reset another stream after handover
+ {
+ EXPECT_CALL(producer_, PrepareResetStream(StreamID(43)));
+ handler_->ResetStreams(std::vector<StreamID>({StreamID(43)}));
+
+ EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true));
+ EXPECT_CALL(producer_, GetStreamsReadyToBeReset())
+ .WillOnce(Return(std::vector<StreamID>({StreamID(43)})));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk reconfig,
+ handler_->MakeStreamResetRequest());
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ OutgoingSSNResetRequestParameter req,
+ reconfig.parameters().get<OutgoingSSNResetRequestParameter>());
+
+ EXPECT_EQ(req.request_sequence_number(),
+ ReconfigRequestSN(kMyInitialReqSn.value() + 1));
+ EXPECT_EQ(req.sender_last_assigned_tsn(),
+ TSN(*retransmission_queue_->next_tsn() - 1));
+ EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(43)));
+ }
+}
+
+TEST_F(StreamResetHandlerTest, PerformCloseAfterOneFirstFailing) {
+ // Inject a stream reset on the first expected TSN (which hasn't been seen).
+ Parameters::Builder builder;
+ builder.Add(OutgoingSSNResetRequestParameter(
+ kPeerInitialReqSn, ReconfigRequestSN(3), kPeerInitialTsn, {StreamID(1)}));
+
+ // The socket is expected to say "in progress" as that TSN hasn't been seen.
+ std::vector<ReconfigurationResponseParameter> responses =
+ HandleAndCatchResponse(ReConfigChunk(builder.Build()));
+ EXPECT_THAT(responses, SizeIs(1));
+ EXPECT_EQ(responses[0].result(), ResponseResult::kInProgress);
+
+ // Let the socket receive the TSN.
+ DataGeneratorOptions opts;
+ opts.message_id = MID(0);
+ reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE", opts));
+ reasm_->MaybeResetStreamsDeferred(kPeerInitialTsn);
+ data_tracker_->Observe(kPeerInitialTsn);
+
+ // And emulate that time has passed, and the peer retries the stream reset,
+ // but now with an incremented request sequence number.
+ Parameters::Builder builder2;
+ builder2.Add(OutgoingSSNResetRequestParameter(
+ ReconfigRequestSN(*kPeerInitialReqSn + 1), ReconfigRequestSN(3),
+ kPeerInitialTsn, {StreamID(1)}));
+
+ // This is supposed to be handled well.
+ std::vector<ReconfigurationResponseParameter> responses2 =
+ HandleAndCatchResponse(ReConfigChunk(builder2.Build()));
+ EXPECT_THAT(responses2, SizeIs(1));
+ EXPECT_EQ(responses2[0].result(), ResponseResult::kSuccessPerformed);
+}
+} // namespace
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc
new file mode 100644
index 0000000000..d769e26069
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc
@@ -0,0 +1,314 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/socket/transmission_control_block.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "net/dcsctp/packet/chunk/data_chunk.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
+#include "net/dcsctp/packet/chunk/idata_chunk.h"
+#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h"
+#include "net/dcsctp/packet/chunk/reconfig_chunk.h"
+#include "net/dcsctp/packet/chunk/sack_chunk.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/rx/data_tracker.h"
+#include "net/dcsctp/rx/reassembly_queue.h"
+#include "net/dcsctp/socket/capabilities.h"
+#include "net/dcsctp/socket/stream_reset_handler.h"
+#include "net/dcsctp/timer/timer.h"
+#include "net/dcsctp/tx/retransmission_queue.h"
+#include "net/dcsctp/tx/retransmission_timeout.h"
+#include "rtc_base/logging.h"
+#include "rtc_base/strings/string_builder.h"
+
+namespace dcsctp {
+
+TransmissionControlBlock::TransmissionControlBlock(
+ TimerManager& timer_manager,
+ absl::string_view log_prefix,
+ const DcSctpOptions& options,
+ const Capabilities& capabilities,
+ DcSctpSocketCallbacks& callbacks,
+ SendQueue& send_queue,
+ VerificationTag my_verification_tag,
+ TSN my_initial_tsn,
+ VerificationTag peer_verification_tag,
+ TSN peer_initial_tsn,
+ size_t a_rwnd,
+ TieTag tie_tag,
+ PacketSender& packet_sender,
+ std::function<bool()> is_connection_established)
+ : log_prefix_(log_prefix),
+ options_(options),
+ timer_manager_(timer_manager),
+ capabilities_(capabilities),
+ callbacks_(callbacks),
+ t3_rtx_(timer_manager_.CreateTimer(
+ "t3-rtx",
+ absl::bind_front(&TransmissionControlBlock::OnRtxTimerExpiry, this),
+ TimerOptions(options.rto_initial,
+ TimerBackoffAlgorithm::kExponential,
+ /*max_restarts=*/absl::nullopt,
+ options.max_timer_backoff_duration))),
+ delayed_ack_timer_(timer_manager_.CreateTimer(
+ "delayed-ack",
+ absl::bind_front(&TransmissionControlBlock::OnDelayedAckTimerExpiry,
+ this),
+ TimerOptions(options.delayed_ack_max_timeout,
+ TimerBackoffAlgorithm::kExponential,
+ /*max_restarts=*/0,
+ /*max_backoff_duration=*/absl::nullopt,
+ webrtc::TaskQueueBase::DelayPrecision::kHigh))),
+ my_verification_tag_(my_verification_tag),
+ my_initial_tsn_(my_initial_tsn),
+ peer_verification_tag_(peer_verification_tag),
+ peer_initial_tsn_(peer_initial_tsn),
+ tie_tag_(tie_tag),
+ is_connection_established_(std::move(is_connection_established)),
+ packet_sender_(packet_sender),
+ rto_(options),
+ tx_error_counter_(log_prefix, options),
+ data_tracker_(log_prefix, delayed_ack_timer_.get(), peer_initial_tsn),
+ reassembly_queue_(log_prefix,
+ peer_initial_tsn,
+ options.max_receiver_window_buffer_size,
+ capabilities.message_interleaving),
+ retransmission_queue_(
+ log_prefix,
+ &callbacks_,
+ my_initial_tsn,
+ a_rwnd,
+ send_queue,
+ absl::bind_front(&TransmissionControlBlock::ObserveRTT, this),
+ [this]() { tx_error_counter_.Clear(); },
+ *t3_rtx_,
+ options,
+ capabilities.partial_reliability,
+ capabilities.message_interleaving),
+ stream_reset_handler_(log_prefix,
+ this,
+ &timer_manager,
+ &data_tracker_,
+ &reassembly_queue_,
+ &retransmission_queue_),
+ heartbeat_handler_(log_prefix, options, this, &timer_manager_) {
+ send_queue.EnableMessageInterleaving(capabilities.message_interleaving);
+}
+
+void TransmissionControlBlock::ObserveRTT(DurationMs rtt) {
+ DurationMs prev_rto = rto_.rto();
+ rto_.ObserveRTT(rtt);
+ RTC_DLOG(LS_VERBOSE) << log_prefix_ << "new rtt=" << *rtt
+ << ", srtt=" << *rto_.srtt() << ", rto=" << *rto_.rto()
+ << " (" << *prev_rto << ")";
+ t3_rtx_->set_duration(rto_.rto());
+
+ DurationMs delayed_ack_tmo =
+ std::min(rto_.rto() * 0.5, options_.delayed_ack_max_timeout);
+ delayed_ack_timer_->set_duration(delayed_ack_tmo);
+}
+
+absl::optional<DurationMs> TransmissionControlBlock::OnRtxTimerExpiry() {
+ TimeMs now = callbacks_.TimeMillis();
+ RTC_DLOG(LS_INFO) << log_prefix_ << "Timer " << t3_rtx_->name()
+ << " has expired";
+ if (cookie_echo_chunk_.has_value()) {
+ // In the COOKIE_ECHO state, let the T1-COOKIE timer trigger
+ // retransmissions, to avoid having two timers doing that.
+ RTC_DLOG(LS_VERBOSE) << "Not retransmitting as T1-cookie is active.";
+ } else {
+ if (IncrementTxErrorCounter("t3-rtx expired")) {
+ retransmission_queue_.HandleT3RtxTimerExpiry();
+ SendBufferedPackets(now);
+ }
+ }
+ return absl::nullopt;
+}
+
+absl::optional<DurationMs> TransmissionControlBlock::OnDelayedAckTimerExpiry() {
+ data_tracker_.HandleDelayedAckTimerExpiry();
+ MaybeSendSack();
+ return absl::nullopt;
+}
+
+void TransmissionControlBlock::MaybeSendSack() {
+ if (data_tracker_.ShouldSendAck(/*also_if_delayed=*/false)) {
+ SctpPacket::Builder builder = PacketBuilder();
+ builder.Add(
+ data_tracker_.CreateSelectiveAck(reassembly_queue_.remaining_bytes()));
+ Send(builder);
+ }
+}
+
+void TransmissionControlBlock::MaybeSendForwardTsn(SctpPacket::Builder& builder,
+ TimeMs now) {
+ if (now >= limit_forward_tsn_until_ &&
+ retransmission_queue_.ShouldSendForwardTsn(now)) {
+ if (capabilities_.message_interleaving) {
+ builder.Add(retransmission_queue_.CreateIForwardTsn());
+ } else {
+ builder.Add(retransmission_queue_.CreateForwardTsn());
+ }
+ packet_sender_.Send(builder);
+ // https://datatracker.ietf.org/doc/html/rfc3758
+ // "IMPLEMENTATION NOTE: An implementation may wish to limit the number of
+ // duplicate FORWARD TSN chunks it sends by ... waiting a full RTT before
+ // sending a duplicate FORWARD TSN."
+ // "Any delay applied to the sending of FORWARD TSN chunk SHOULD NOT exceed
+ // 200ms and MUST NOT exceed 500ms".
+ limit_forward_tsn_until_ = now + std::min(DurationMs(200), rto_.srtt());
+ }
+}
+
+void TransmissionControlBlock::MaybeSendFastRetransmit() {
+ if (!retransmission_queue_.has_data_to_be_fast_retransmitted()) {
+ return;
+ }
+
+ // https://datatracker.ietf.org/doc/html/rfc4960#section-7.2.4
+ // "Determine how many of the earliest (i.e., lowest TSN) DATA chunks marked
+ // for retransmission will fit into a single packet, subject to constraint of
+ // the path MTU of the destination transport address to which the packet is
+ // being sent. Call this value K. Retransmit those K DATA chunks in a single
+ // packet. When a Fast Retransmit is being performed, the sender SHOULD
+ // ignore the value of cwnd and SHOULD NOT delay retransmission for this
+ // single packet."
+
+ SctpPacket::Builder builder(peer_verification_tag_, options_);
+ auto chunks = retransmission_queue_.GetChunksForFastRetransmit(
+ builder.bytes_remaining());
+ for (auto& [tsn, data] : chunks) {
+ if (capabilities_.message_interleaving) {
+ builder.Add(IDataChunk(tsn, std::move(data), false));
+ } else {
+ builder.Add(DataChunk(tsn, std::move(data), false));
+ }
+ }
+ packet_sender_.Send(builder);
+}
+
+void TransmissionControlBlock::SendBufferedPackets(SctpPacket::Builder& builder,
+ TimeMs now) {
+ for (int packet_idx = 0;
+ packet_idx < options_.max_burst && retransmission_queue_.can_send_data();
+ ++packet_idx) {
+ // Only add control chunks to the first packet that is sent, if sending
+ // multiple packets in one go (as allowed by the congestion window).
+ if (packet_idx == 0) {
+ if (cookie_echo_chunk_.has_value()) {
+ // https://tools.ietf.org/html/rfc4960#section-5.1
+ // "The COOKIE ECHO chunk can be bundled with any pending outbound DATA
+ // chunks, but it MUST be the first chunk in the packet..."
+ RTC_DCHECK(builder.empty());
+ builder.Add(*cookie_echo_chunk_);
+ }
+
+ // https://tools.ietf.org/html/rfc4960#section-6
+ // "Before an endpoint transmits a DATA chunk, if any received DATA
+ // chunks have not been acknowledged (e.g., due to delayed ack), the
+ // sender should create a SACK and bundle it with the outbound DATA chunk,
+ // as long as the size of the final SCTP packet does not exceed the
+ // current MTU."
+ if (data_tracker_.ShouldSendAck(/*also_if_delayed=*/true)) {
+ builder.Add(data_tracker_.CreateSelectiveAck(
+ reassembly_queue_.remaining_bytes()));
+ }
+ MaybeSendForwardTsn(builder, now);
+ absl::optional<ReConfigChunk> reconfig =
+ stream_reset_handler_.MakeStreamResetRequest();
+ if (reconfig.has_value()) {
+ builder.Add(*reconfig);
+ }
+ }
+
+ auto chunks =
+ retransmission_queue_.GetChunksToSend(now, builder.bytes_remaining());
+ for (auto& [tsn, data] : chunks) {
+ if (capabilities_.message_interleaving) {
+ builder.Add(IDataChunk(tsn, std::move(data), false));
+ } else {
+ builder.Add(DataChunk(tsn, std::move(data), false));
+ }
+ }
+
+ if (!packet_sender_.Send(builder)) {
+ break;
+ }
+
+ if (cookie_echo_chunk_.has_value()) {
+ // https://tools.ietf.org/html/rfc4960#section-5.1
+ // "... until the COOKIE ACK is returned the sender MUST NOT send any
+ // other packets to the peer."
+ break;
+ }
+ }
+}
+
+std::string TransmissionControlBlock::ToString() const {
+ rtc::StringBuilder sb;
+
+ sb.AppendFormat(
+ "verification_tag=%08x, last_cumulative_ack=%u, capabilities=",
+ *peer_verification_tag_, *data_tracker_.last_cumulative_acked_tsn());
+
+ if (capabilities_.partial_reliability) {
+ sb << "PR,";
+ }
+ if (capabilities_.message_interleaving) {
+ sb << "IL,";
+ }
+ if (capabilities_.reconfig) {
+ sb << "Reconfig,";
+ }
+
+ return sb.Release();
+}
+
+HandoverReadinessStatus TransmissionControlBlock::GetHandoverReadiness() const {
+ HandoverReadinessStatus status;
+ status.Add(data_tracker_.GetHandoverReadiness());
+ status.Add(stream_reset_handler_.GetHandoverReadiness());
+ status.Add(reassembly_queue_.GetHandoverReadiness());
+ status.Add(retransmission_queue_.GetHandoverReadiness());
+ return status;
+}
+
+void TransmissionControlBlock::AddHandoverState(
+ DcSctpSocketHandoverState& state) {
+ state.capabilities.partial_reliability = capabilities_.partial_reliability;
+ state.capabilities.message_interleaving = capabilities_.message_interleaving;
+ state.capabilities.reconfig = capabilities_.reconfig;
+
+ state.my_verification_tag = my_verification_tag().value();
+ state.peer_verification_tag = peer_verification_tag().value();
+ state.my_initial_tsn = my_initial_tsn().value();
+ state.peer_initial_tsn = peer_initial_tsn().value();
+ state.tie_tag = tie_tag().value();
+
+ data_tracker_.AddHandoverState(state);
+ stream_reset_handler_.AddHandoverState(state);
+ reassembly_queue_.AddHandoverState(state);
+ retransmission_queue_.AddHandoverState(state);
+}
+
+void TransmissionControlBlock::RestoreFromState(
+ const DcSctpSocketHandoverState& state) {
+ data_tracker_.RestoreFromState(state);
+ retransmission_queue_.RestoreFromState(state);
+ reassembly_queue_.RestoreFromState(state);
+}
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h
new file mode 100644
index 0000000000..8e0e9a3ec5
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h
@@ -0,0 +1,193 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_
+#define NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/functional/bind_front.h"
+#include "absl/strings/string_view.h"
+#include "api/task_queue/task_queue_base.h"
+#include "net/dcsctp/common/sequence_numbers.h"
+#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_options.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/rx/data_tracker.h"
+#include "net/dcsctp/rx/reassembly_queue.h"
+#include "net/dcsctp/socket/capabilities.h"
+#include "net/dcsctp/socket/context.h"
+#include "net/dcsctp/socket/heartbeat_handler.h"
+#include "net/dcsctp/socket/packet_sender.h"
+#include "net/dcsctp/socket/stream_reset_handler.h"
+#include "net/dcsctp/timer/timer.h"
+#include "net/dcsctp/tx/retransmission_error_counter.h"
+#include "net/dcsctp/tx/retransmission_queue.h"
+#include "net/dcsctp/tx/retransmission_timeout.h"
+#include "net/dcsctp/tx/send_queue.h"
+
+namespace dcsctp {
+
+// The TransmissionControlBlock (TCB) represents an open connection to a peer,
+// and holds all the resources for that. If the connection is e.g. shutdown,
+// closed or restarted, this object will be deleted and/or replaced.
+class TransmissionControlBlock : public Context {
+ public:
+ TransmissionControlBlock(TimerManager& timer_manager,
+ absl::string_view log_prefix,
+ const DcSctpOptions& options,
+ const Capabilities& capabilities,
+ DcSctpSocketCallbacks& callbacks,
+ SendQueue& send_queue,
+ VerificationTag my_verification_tag,
+ TSN my_initial_tsn,
+ VerificationTag peer_verification_tag,
+ TSN peer_initial_tsn,
+ size_t a_rwnd,
+ TieTag tie_tag,
+ PacketSender& packet_sender,
+ std::function<bool()> is_connection_established);
+
+ // Implementation of `Context`.
+ bool is_connection_established() const override {
+ return is_connection_established_();
+ }
+ TSN my_initial_tsn() const override { return my_initial_tsn_; }
+ TSN peer_initial_tsn() const override { return peer_initial_tsn_; }
+ DcSctpSocketCallbacks& callbacks() const override { return callbacks_; }
+ void ObserveRTT(DurationMs rtt) override;
+ DurationMs current_rto() const override { return rto_.rto(); }
+ bool IncrementTxErrorCounter(absl::string_view reason) override {
+ return tx_error_counter_.Increment(reason);
+ }
+ void ClearTxErrorCounter() override { tx_error_counter_.Clear(); }
+ SctpPacket::Builder PacketBuilder() const override {
+ return SctpPacket::Builder(peer_verification_tag_, options_);
+ }
+ bool HasTooManyTxErrors() const override {
+ return tx_error_counter_.IsExhausted();
+ }
+ void Send(SctpPacket::Builder& builder) override {
+ packet_sender_.Send(builder);
+ }
+
+ // Other accessors
+ DataTracker& data_tracker() { return data_tracker_; }
+ ReassemblyQueue& reassembly_queue() { return reassembly_queue_; }
+ RetransmissionQueue& retransmission_queue() { return retransmission_queue_; }
+ StreamResetHandler& stream_reset_handler() { return stream_reset_handler_; }
+ HeartbeatHandler& heartbeat_handler() { return heartbeat_handler_; }
+ size_t cwnd() const { return retransmission_queue_.cwnd(); }
+ DurationMs current_srtt() const { return rto_.srtt(); }
+
+ // Returns this socket's verification tag, set in all packet headers.
+ VerificationTag my_verification_tag() const { return my_verification_tag_; }
+ // Returns the peer's verification tag, which should be in received packets.
+ VerificationTag peer_verification_tag() const {
+ return peer_verification_tag_;
+ }
+ // All negotiated supported capabilities.
+ const Capabilities& capabilities() const { return capabilities_; }
+ // A 64-bit tie-tag, used to e.g. detect reconnections.
+ TieTag tie_tag() const { return tie_tag_; }
+
+ // Sends a SACK, if there is a need to.
+ void MaybeSendSack();
+
+ // Sends a FORWARD-TSN, if it is needed and allowed (rate-limited).
+ void MaybeSendForwardTsn(SctpPacket::Builder& builder, TimeMs now);
+
+ // Will be set while the socket is in kCookieEcho state. In this state, there
+ // can only be a single packet outstanding, and it must contain the COOKIE
+ // ECHO chunk as the first chunk in that packet, until the COOKIE ACK has been
+ // received, which will make the socket call `ClearCookieEchoChunk`.
+ void SetCookieEchoChunk(CookieEchoChunk chunk) {
+ cookie_echo_chunk_ = std::move(chunk);
+ }
+
+ // Called when the COOKIE ACK chunk has been received, to allow further
+ // packets to be sent.
+ void ClearCookieEchoChunk() { cookie_echo_chunk_ = absl::nullopt; }
+
+ bool has_cookie_echo_chunk() const { return cookie_echo_chunk_.has_value(); }
+
+ void MaybeSendFastRetransmit();
+
+ // Fills `builder` (which may already be filled with control chunks) with
+ // other control and data chunks, and sends packets as much as can be
+ // allowed by the congestion control algorithm.
+ void SendBufferedPackets(SctpPacket::Builder& builder, TimeMs now);
+
+ // As above, but without passing in a builder. If `cookie_echo_chunk_` is
+ // present, then only one packet will be sent, with this chunk as the first
+ // chunk.
+ void SendBufferedPackets(TimeMs now) {
+ SctpPacket::Builder builder(peer_verification_tag_, options_);
+ SendBufferedPackets(builder, now);
+ }
+
+ // Returns a textual representation of this object, for logging.
+ std::string ToString() const;
+
+ HandoverReadinessStatus GetHandoverReadiness() const;
+
+ void AddHandoverState(DcSctpSocketHandoverState& state);
+ void RestoreFromState(const DcSctpSocketHandoverState& handover_state);
+
+ private:
+ // Will be called when the retransmission timer (t3-rtx) expires.
+ absl::optional<DurationMs> OnRtxTimerExpiry();
+ // Will be called when the delayed ack timer expires.
+ absl::optional<DurationMs> OnDelayedAckTimerExpiry();
+
+ const std::string log_prefix_;
+ const DcSctpOptions options_;
+ TimerManager& timer_manager_;
+ // Negotiated capabilities that both peers support.
+ const Capabilities capabilities_;
+ DcSctpSocketCallbacks& callbacks_;
+ // The data retransmission timer, called t3-rtx in SCTP.
+ const std::unique_ptr<Timer> t3_rtx_;
+ // Delayed ack timer, which triggers when acks should be sent (when delayed).
+ const std::unique_ptr<Timer> delayed_ack_timer_;
+ const VerificationTag my_verification_tag_;
+ const TSN my_initial_tsn_;
+ const VerificationTag peer_verification_tag_;
+ const TSN peer_initial_tsn_;
+ // Nonce, used to detect reconnections.
+ const TieTag tie_tag_;
+ const std::function<bool()> is_connection_established_;
+ PacketSender& packet_sender_;
+ // Rate limiting of FORWARD-TSN. Next can be sent at or after this timestamp.
+ TimeMs limit_forward_tsn_until_ = TimeMs(0);
+
+ RetransmissionTimeout rto_;
+ RetransmissionErrorCounter tx_error_counter_;
+ DataTracker data_tracker_;
+ ReassemblyQueue reassembly_queue_;
+ RetransmissionQueue retransmission_queue_;
+ StreamResetHandler stream_reset_handler_;
+ HeartbeatHandler heartbeat_handler_;
+
+ // Only valid when the socket state == State::kCookieEchoed. In this state,
+ // the socket must wait for COOKIE ACK to continue sending any packets (not
+ // including a COOKIE ECHO). So if `cookie_echo_chunk_` is present, the
+ // SendBufferedChunks will always only just send one packet, with this chunk
+ // as the first chunk in the packet.
+ absl::optional<CookieEchoChunk> cookie_echo_chunk_ = absl::nullopt;
+};
+} // namespace dcsctp
+
+#endif // NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_