diff options
Diffstat (limited to 'third_party/libwebrtc/net/dcsctp/socket')
27 files changed, 9580 insertions, 0 deletions
diff --git a/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn new file mode 100644 index 0000000000..681ddd47e9 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn @@ -0,0 +1,281 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("context") { + sources = [ "context.h" ] + deps = [ + "../common:internal_types", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_library("heartbeat_handler") { + deps = [ + ":context", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../packet:bounded_io", + "../packet:chunk", + "../packet:parameter", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../timer", + ] + sources = [ + "heartbeat_handler.cc", + "heartbeat_handler.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("stream_reset_handler") { + deps = [ + ":context", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base/containers:flat_set", + "../common:internal_types", + "../common:str_join", + "../packet:chunk", + "../packet:parameter", + "../packet:sctp_packet", + "../packet:tlv_trait", + "../public:socket", + "../public:types", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../timer", + "../tx:retransmission_queue", + ] + sources = [ + "stream_reset_handler.cc", + "stream_reset_handler.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("packet_sender") { + deps = [ + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../timer", + ] + sources = [ + "packet_sender.cc", + "packet_sender.h", + ] + absl_deps = [] +} + +rtc_library("transmission_control_block") { + deps = [ + ":context", + ":heartbeat_handler", + ":packet_sender", + ":stream_reset_handler", + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../timer", + "../tx:retransmission_error_counter", + "../tx:retransmission_queue", + "../tx:retransmission_timeout", + "../tx:send_queue", + ] + sources = [ + "capabilities.h", + "transmission_control_block.cc", + "transmission_control_block.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("dcsctp_socket") { + deps = [ + ":context", + ":heartbeat_handler", + ":packet_sender", + ":stream_reset_handler", + ":transmission_control_block", + "../../../api:array_view", + "../../../api:make_ref_counted", + "../../../api:refcountedbase", + "../../../api:scoped_refptr", + "../../../api:sequence_checker", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:internal_types", + "../packet:bounded_io", + "../packet:chunk", + "../packet:chunk_validators", + "../packet:data", + "../packet:error_cause", + "../packet:parameter", + "../packet:sctp_packet", + "../packet:tlv_trait", + "../public:socket", + "../public:types", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../timer", + "../tx:retransmission_error_counter", + "../tx:retransmission_queue", + "../tx:retransmission_timeout", + "../tx:rr_send_queue", + "../tx:send_queue", + ] + sources = [ + "callback_deferrer.cc", + "callback_deferrer.h", + "dcsctp_socket.cc", + "dcsctp_socket.h", + "state_cookie.cc", + "state_cookie.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_source_set("mock_callbacks") { + testonly = true + sources = [ "mock_dcsctp_socket_callbacks.h" ] + deps = [ + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:logging", + "../../../rtc_base:random", + "../../../test:test_support", + "../public:socket", + "../public:types", + "../timer", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + } + + rtc_source_set("mock_context") { + testonly = true + sources = [ "mock_context.h" ] + deps = [ + ":context", + ":mock_callbacks", + "../../../test:test_support", + "../common:internal_types", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + } + + rtc_library("dcsctp_socket_unittests") { + testonly = true + + deps = [ + ":dcsctp_socket", + ":heartbeat_handler", + ":mock_callbacks", + ":mock_context", + ":packet_sender", + ":stream_reset_handler", + ":transmission_control_block", + "../../../api:array_view", + "../../../api:create_network_emulation_manager", + "../../../api:network_emulation_manager_api", + "../../../api/task_queue", + "../../../api/task_queue:pending_task_safety_flag", + "../../../api/units:time_delta", + "../../../call:simulated_network", + "../../../rtc_base:checks", + "../../../rtc_base:copy_on_write_buffer", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:logging", + "../../../rtc_base:rtc_base_tests_utils", + "../../../rtc_base:socket_address", + "../../../rtc_base:stringutils", + "../../../rtc_base:timeutils", + "../../../test:test_support", + "../common:handover_testing", + "../common:internal_types", + "../common:math", + "../packet:chunk", + "../packet:error_cause", + "../packet:parameter", + "../packet:sctp_packet", + "../packet:tlv_trait", + "../public:socket", + "../public:types", + "../public:utils", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../testing:data_generator", + "../testing:testing_macros", + "../timer", + "../timer:task_queue_timeout", + "../tx:mock_send_queue", + "../tx:retransmission_queue", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + sources = [ + "dcsctp_socket_network_test.cc", + "dcsctp_socket_test.cc", + "heartbeat_handler_test.cc", + "packet_sender_test.cc", + "state_cookie_test.cc", + "stream_reset_handler_test.cc", + "transmission_control_block_test.cc", + ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/socket/DEPS b/third_party/libwebrtc/net/dcsctp/socket/DEPS new file mode 100644 index 0000000000..d4966290e3 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/DEPS @@ -0,0 +1,5 @@ +specific_include_rules = { + "dcsctp_socket_network_test.cc": [ + "+call", + ] +} diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc new file mode 100644 index 0000000000..123526e782 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/callback_deferrer.h" + +#include "api/make_ref_counted.h" + +namespace dcsctp { +namespace { +// A wrapper around the move-only DcSctpMessage, to let it be captured in a +// lambda. +class MessageDeliverer { + public: + explicit MessageDeliverer(DcSctpMessage&& message) + : state_(rtc::make_ref_counted<State>(std::move(message))) {} + + void Deliver(DcSctpSocketCallbacks& c) { + // Really ensure that it's only called once. + RTC_DCHECK(!state_->has_delivered); + state_->has_delivered = true; + c.OnMessageReceived(std::move(state_->message)); + } + + private: + struct State : public rtc::RefCountInterface { + explicit State(DcSctpMessage&& m) + : has_delivered(false), message(std::move(m)) {} + bool has_delivered; + DcSctpMessage message; + }; + rtc::scoped_refptr<State> state_; +}; +} // namespace + +void CallbackDeferrer::Prepare() { + RTC_DCHECK(!prepared_); + prepared_ = true; +} + +void CallbackDeferrer::TriggerDeferred() { + // Need to swap here. The client may call into the library from within a + // callback, and that might result in adding new callbacks to this instance, + // and the vector can't be modified while iterated on. + RTC_DCHECK(prepared_); + std::vector<std::function<void(DcSctpSocketCallbacks & cb)>> deferred; + deferred.swap(deferred_); + prepared_ = false; + + for (auto& cb : deferred) { + cb(underlying_); + } +} + +SendPacketStatus CallbackDeferrer::SendPacketWithStatus( + rtc::ArrayView<const uint8_t> data) { + // Will not be deferred - call directly. + return underlying_.SendPacketWithStatus(data); +} + +std::unique_ptr<Timeout> CallbackDeferrer::CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) { + // Will not be deferred - call directly. + return underlying_.CreateTimeout(precision); +} + +TimeMs CallbackDeferrer::TimeMillis() { + // Will not be deferred - call directly. + return underlying_.TimeMillis(); +} + +uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) { + // Will not be deferred - call directly. + return underlying_.GetRandomInt(low, high); +} + +void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [deliverer = MessageDeliverer(std::move(message))]( + DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); }); +} + +void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { + cb.OnError(error, message); + }); +} + +void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { + cb.OnAborted(error, message); + }); +} + +void CallbackDeferrer::OnConnected() { + RTC_DCHECK(prepared_); + deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); }); +} + +void CallbackDeferrer::OnClosed() { + RTC_DCHECK(prepared_); + deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); }); +} + +void CallbackDeferrer::OnConnectionRestarted() { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); }); +} + +void CallbackDeferrer::OnStreamsResetFailed( + rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [streams = std::vector<StreamID>(outgoing_streams.begin(), + outgoing_streams.end()), + reason = std::string(reason)](DcSctpSocketCallbacks& cb) { + cb.OnStreamsResetFailed(streams, reason); + }); +} + +void CallbackDeferrer::OnStreamsResetPerformed( + rtc::ArrayView<const StreamID> outgoing_streams) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [streams = std::vector<StreamID>(outgoing_streams.begin(), + outgoing_streams.end())]( + DcSctpSocketCallbacks& cb) { cb.OnStreamsResetPerformed(streams); }); +} + +void CallbackDeferrer::OnIncomingStreamsReset( + rtc::ArrayView<const StreamID> incoming_streams) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [streams = std::vector<StreamID>(incoming_streams.begin(), + incoming_streams.end())]( + DcSctpSocketCallbacks& cb) { cb.OnIncomingStreamsReset(streams); }); +} + +void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) { + RTC_DCHECK(prepared_); + deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) { + cb.OnBufferedAmountLow(stream_id); + }); +} + +void CallbackDeferrer::OnTotalBufferedAmountLow() { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); }); +} + +void CallbackDeferrer::OnLifecycleMessageExpired(LifecycleId lifecycle_id, + bool maybe_delivered) { + // Will not be deferred - call directly. + underlying_.OnLifecycleMessageExpired(lifecycle_id, maybe_delivered); +} +void CallbackDeferrer::OnLifecycleMessageFullySent(LifecycleId lifecycle_id) { + // Will not be deferred - call directly. + underlying_.OnLifecycleMessageFullySent(lifecycle_id); +} +void CallbackDeferrer::OnLifecycleMessageDelivered(LifecycleId lifecycle_id) { + // Will not be deferred - call directly. + underlying_.OnLifecycleMessageDelivered(lifecycle_id); +} +void CallbackDeferrer::OnLifecycleEnd(LifecycleId lifecycle_id) { + // Will not be deferred - call directly. + underlying_.OnLifecycleEnd(lifecycle_id); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h new file mode 100644 index 0000000000..1c35dda6cf --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_ +#define NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_ + +#include <cstdint> +#include <functional> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "api/ref_counted_base.h" +#include "api/scoped_refptr.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" + +namespace dcsctp { +// Defers callbacks until they can be safely triggered. +// +// There are a lot of callbacks from the dcSCTP library to the client, +// such as when messages are received or streams are closed. When the client +// receives these callbacks, the client is expected to be able to call into the +// library - from within the callback. For example, sending a reply message when +// a certain SCTP message has been received, or to reconnect when the connection +// was closed for any reason. This means that the dcSCTP library must always be +// in a consistent and stable state when these callbacks are delivered, and to +// ensure that's the case, callbacks are not immediately delivered from where +// they originate, but instead queued (deferred) by this class. At the end of +// any public API method that may result in callbacks, they are triggered and +// then delivered. +// +// There are a number of exceptions, which is clearly annotated in the API. +class CallbackDeferrer : public DcSctpSocketCallbacks { + public: + class ScopedDeferrer { + public: + explicit ScopedDeferrer(CallbackDeferrer& callback_deferrer) + : callback_deferrer_(callback_deferrer) { + callback_deferrer_.Prepare(); + } + + ~ScopedDeferrer() { callback_deferrer_.TriggerDeferred(); } + + private: + CallbackDeferrer& callback_deferrer_; + }; + + explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying) + : underlying_(underlying) {} + + // Implementation of DcSctpSocketCallbacks + SendPacketStatus SendPacketWithStatus( + rtc::ArrayView<const uint8_t> data) override; + std::unique_ptr<Timeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) override; + TimeMs TimeMillis() override; + uint32_t GetRandomInt(uint32_t low, uint32_t high) override; + void OnMessageReceived(DcSctpMessage message) override; + void OnError(ErrorKind error, absl::string_view message) override; + void OnAborted(ErrorKind error, absl::string_view message) override; + void OnConnected() override; + void OnClosed() override; + void OnConnectionRestarted() override; + void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason) override; + void OnStreamsResetPerformed( + rtc::ArrayView<const StreamID> outgoing_streams) override; + void OnIncomingStreamsReset( + rtc::ArrayView<const StreamID> incoming_streams) override; + void OnBufferedAmountLow(StreamID stream_id) override; + void OnTotalBufferedAmountLow() override; + + void OnLifecycleMessageExpired(LifecycleId lifecycle_id, + bool maybe_delivered) override; + void OnLifecycleMessageFullySent(LifecycleId lifecycle_id) override; + void OnLifecycleMessageDelivered(LifecycleId lifecycle_id) override; + void OnLifecycleEnd(LifecycleId lifecycle_id) override; + + private: + void Prepare(); + void TriggerDeferred(); + + DcSctpSocketCallbacks& underlying_; + bool prepared_ = false; + std::vector<std::function<void(DcSctpSocketCallbacks& cb)>> deferred_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/capabilities.h b/third_party/libwebrtc/net/dcsctp/socket/capabilities.h new file mode 100644 index 0000000000..286509a40a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/capabilities.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_CAPABILITIES_H_ +#define NET_DCSCTP_SOCKET_CAPABILITIES_H_ + +#include <cstdint> +namespace dcsctp { +// Indicates what the association supports, meaning that both parties +// support it and that feature can be used. +struct Capabilities { + // RFC3758 Partial Reliability Extension + bool partial_reliability = false; + // RFC8260 Stream Schedulers and User Message Interleaving + bool message_interleaving = false; + // RFC6525 Stream Reconfiguration + bool reconfig = false; + // https://datatracker.ietf.org/doc/draft-tuexen-tsvwg-sctp-zero-checksum/ + bool zero_checksum = false; + // Negotiated maximum incoming and outgoing stream count. + uint16_t negotiated_maximum_incoming_streams = 0; + uint16_t negotiated_maximum_outgoing_streams = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_CAPABILITIES_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/context.h b/third_party/libwebrtc/net/dcsctp/socket/context.h new file mode 100644 index 0000000000..eca5b9e4fb --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/context.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_CONTEXT_H_ +#define NET_DCSCTP_SOCKET_CONTEXT_H_ + +#include <cstdint> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// A set of helper methods used by handlers to e.g. send packets. +// +// Implemented by the TransmissionControlBlock. +class Context { + public: + virtual ~Context() = default; + + // Indicates if a connection has been established. + virtual bool is_connection_established() const = 0; + + // Returns this side's initial TSN value. + virtual TSN my_initial_tsn() const = 0; + + // Returns the peer's initial TSN value. + virtual TSN peer_initial_tsn() const = 0; + + // Returns the socket callbacks. + virtual DcSctpSocketCallbacks& callbacks() const = 0; + + // Observes a measured RTT value, in milliseconds. + virtual void ObserveRTT(DurationMs rtt_ms) = 0; + + // Returns the current Retransmission Timeout (rto) value, in milliseconds. + virtual DurationMs current_rto() const = 0; + + // Increments the transmission error counter, given a human readable reason. + virtual bool IncrementTxErrorCounter(absl::string_view reason) = 0; + + // Clears the transmission error counter. + virtual void ClearTxErrorCounter() = 0; + + // Returns true if there have been too many retransmission errors. + virtual bool HasTooManyTxErrors() const = 0; + + // Returns a PacketBuilder, filled in with the correct verification tag. + virtual SctpPacket::Builder PacketBuilder() const = 0; + + // Builds the packet from `builder` and sends it. + virtual void Send(SctpPacket::Builder& builder) = 0; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_CONTEXT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc new file mode 100644 index 0000000000..32bcdaaacf --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc @@ -0,0 +1,1797 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/dcsctp_socket.h" + +#include <algorithm> +#include <cstdint> +#include <limits> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" +#include "net/dcsctp/packet/chunk_validators.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/no_user_data_cause.h" +#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h" +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" +#include "net/dcsctp/packet/parameter/zero_checksum_acceptable_chunk_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/callback_deferrer.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/heartbeat_handler.h" +#include "net/dcsctp/socket/state_cookie.h" +#include "net/dcsctp/socket/stream_reset_handler.h" +#include "net/dcsctp/socket/transmission_control_block.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" +#include "rtc_base/strings/string_format.h" + +namespace dcsctp { +namespace { + +// https://tools.ietf.org/html/rfc4960#section-5.1 +constexpr uint32_t kMinVerificationTag = 1; +constexpr uint32_t kMaxVerificationTag = std::numeric_limits<uint32_t>::max(); + +// https://tools.ietf.org/html/rfc4960#section-3.3.2 +constexpr uint32_t kMinInitialTsn = 0; +constexpr uint32_t kMaxInitialTsn = std::numeric_limits<uint32_t>::max(); + +Capabilities ComputeCapabilities(const DcSctpOptions& options, + uint16_t peer_nbr_outbound_streams, + uint16_t peer_nbr_inbound_streams, + const Parameters& parameters) { + Capabilities capabilities; + absl::optional<SupportedExtensionsParameter> supported_extensions = + parameters.get<SupportedExtensionsParameter>(); + + if (options.enable_partial_reliability) { + capabilities.partial_reliability = + parameters.get<ForwardTsnSupportedParameter>().has_value(); + if (supported_extensions.has_value()) { + capabilities.partial_reliability |= + supported_extensions->supports(ForwardTsnChunk::kType); + } + } + + if (options.enable_message_interleaving && supported_extensions.has_value()) { + capabilities.message_interleaving = + supported_extensions->supports(IDataChunk::kType) && + supported_extensions->supports(IForwardTsnChunk::kType); + } + if (supported_extensions.has_value() && + supported_extensions->supports(ReConfigChunk::kType)) { + capabilities.reconfig = true; + } + + if (options.enable_zero_checksum && + parameters.get<ZeroChecksumAcceptableChunkParameter>().has_value()) { + capabilities.zero_checksum = true; + } + + capabilities.negotiated_maximum_incoming_streams = std::min( + options.announced_maximum_incoming_streams, peer_nbr_outbound_streams); + capabilities.negotiated_maximum_outgoing_streams = std::min( + options.announced_maximum_outgoing_streams, peer_nbr_inbound_streams); + + return capabilities; +} + +void AddCapabilityParameters(const DcSctpOptions& options, + Parameters::Builder& builder) { + std::vector<uint8_t> chunk_types = {ReConfigChunk::kType}; + + if (options.enable_partial_reliability) { + builder.Add(ForwardTsnSupportedParameter()); + chunk_types.push_back(ForwardTsnChunk::kType); + } + if (options.enable_message_interleaving) { + chunk_types.push_back(IDataChunk::kType); + chunk_types.push_back(IForwardTsnChunk::kType); + } + if (options.enable_zero_checksum) { + builder.Add(ZeroChecksumAcceptableChunkParameter()); + } + builder.Add(SupportedExtensionsParameter(std::move(chunk_types))); +} + +TieTag MakeTieTag(DcSctpSocketCallbacks& cb) { + uint32_t tie_tag_upper = + cb.GetRandomInt(0, std::numeric_limits<uint32_t>::max()); + uint32_t tie_tag_lower = + cb.GetRandomInt(1, std::numeric_limits<uint32_t>::max()); + return TieTag(static_cast<uint64_t>(tie_tag_upper) << 32 | + static_cast<uint64_t>(tie_tag_lower)); +} + +SctpImplementation DeterminePeerImplementation( + rtc::ArrayView<const uint8_t> cookie) { + if (cookie.size() > 8) { + absl::string_view magic(reinterpret_cast<const char*>(cookie.data()), 8); + if (magic == "dcSCTP00") { + return SctpImplementation::kDcsctp; + } + if (magic == "KAME-BSD") { + return SctpImplementation::kUsrSctp; + } + } + return SctpImplementation::kOther; +} +} // namespace + +DcSctpSocket::DcSctpSocket(absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr<PacketObserver> packet_observer, + const DcSctpOptions& options) + : log_prefix_(std::string(log_prefix) + ": "), + packet_observer_(std::move(packet_observer)), + options_(options), + callbacks_(callbacks), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return callbacks_.CreateTimeout(precision); + }), + t1_init_(timer_manager_.CreateTimer( + "t1-init", + absl::bind_front(&DcSctpSocket::OnInitTimerExpiry, this), + TimerOptions(options.t1_init_timeout, + TimerBackoffAlgorithm::kExponential, + options.max_init_retransmits))), + t1_cookie_(timer_manager_.CreateTimer( + "t1-cookie", + absl::bind_front(&DcSctpSocket::OnCookieTimerExpiry, this), + TimerOptions(options.t1_cookie_timeout, + TimerBackoffAlgorithm::kExponential, + options.max_init_retransmits))), + t2_shutdown_(timer_manager_.CreateTimer( + "t2-shutdown", + absl::bind_front(&DcSctpSocket::OnShutdownTimerExpiry, this), + TimerOptions(options.t2_shutdown_timeout, + TimerBackoffAlgorithm::kExponential, + options.max_retransmissions))), + packet_sender_(callbacks_, + absl::bind_front(&DcSctpSocket::OnSentPacket, this)), + send_queue_(log_prefix_, + &callbacks_, + options_.max_send_buffer_size, + options_.mtu, + options_.default_stream_priority, + options_.total_buffered_amount_low_threshold) {} + +std::string DcSctpSocket::log_prefix() const { + return log_prefix_ + "[" + std::string(ToString(state_)) + "] "; +} + +bool DcSctpSocket::IsConsistent() const { + if (tcb_ != nullptr && tcb_->reassembly_queue().HasMessages()) { + return false; + } + switch (state_) { + case State::kClosed: + return (tcb_ == nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kCookieWait: + return (tcb_ == nullptr && t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kCookieEchoed: + return (tcb_ != nullptr && !t1_init_->is_running() && + t1_cookie_->is_running() && !t2_shutdown_->is_running() && + tcb_->has_cookie_echo_chunk()); + case State::kEstablished: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kShutdownPending: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kShutdownSent: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && t2_shutdown_->is_running()); + case State::kShutdownReceived: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kShutdownAckSent: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && t2_shutdown_->is_running()); + } +} + +constexpr absl::string_view DcSctpSocket::ToString(DcSctpSocket::State state) { + switch (state) { + case DcSctpSocket::State::kClosed: + return "CLOSED"; + case DcSctpSocket::State::kCookieWait: + return "COOKIE_WAIT"; + case DcSctpSocket::State::kCookieEchoed: + return "COOKIE_ECHOED"; + case DcSctpSocket::State::kEstablished: + return "ESTABLISHED"; + case DcSctpSocket::State::kShutdownPending: + return "SHUTDOWN_PENDING"; + case DcSctpSocket::State::kShutdownSent: + return "SHUTDOWN_SENT"; + case DcSctpSocket::State::kShutdownReceived: + return "SHUTDOWN_RECEIVED"; + case DcSctpSocket::State::kShutdownAckSent: + return "SHUTDOWN_ACK_SENT"; + } +} + +void DcSctpSocket::SetState(State state, absl::string_view reason) { + if (state_ != state) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Socket state changed from " + << ToString(state_) << " to " << ToString(state) + << " due to " << reason; + state_ = state; + } +} + +void DcSctpSocket::SendInit() { + Parameters::Builder params_builder; + AddCapabilityParameters(options_, params_builder); + InitChunk init(/*initiate_tag=*/connect_params_.verification_tag, + /*a_rwnd=*/options_.max_receiver_window_buffer_size, + options_.announced_maximum_outgoing_streams, + options_.announced_maximum_incoming_streams, + connect_params_.initial_tsn, params_builder.Build()); + SctpPacket::Builder b(VerificationTag(0), options_); + b.Add(init); + // https://www.ietf.org/archive/id/draft-tuexen-tsvwg-sctp-zero-checksum-01.html#section-4.2 + // "When an end point sends a packet containing an INIT chunk, it MUST include + // a correct CRC32c checksum in the packet containing the INIT chunk." + packet_sender_.Send(b, /*write_checksum=*/true); +} + +void DcSctpSocket::MakeConnectionParameters() { + VerificationTag new_verification_tag( + callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); + TSN initial_tsn(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn)); + connect_params_.initial_tsn = initial_tsn; + connect_params_.verification_tag = new_verification_tag; +} + +void DcSctpSocket::Connect() { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (state_ == State::kClosed) { + MakeConnectionParameters(); + RTC_DLOG(LS_INFO) + << log_prefix() + << rtc::StringFormat( + "Connecting. my_verification_tag=%08x, my_initial_tsn=%u", + *connect_params_.verification_tag, *connect_params_.initial_tsn); + SendInit(); + t1_init_->Start(); + SetState(State::kCookieWait, "Connect called"); + } else { + RTC_DLOG(LS_WARNING) << log_prefix() + << "Called Connect on a socket that is not closed"; + } + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::CreateTransmissionControlBlock( + const Capabilities& capabilities, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag) { + metrics_.uses_message_interleaving = capabilities.message_interleaving; + metrics_.uses_zero_checksum = capabilities.zero_checksum; + metrics_.negotiated_maximum_incoming_streams = + capabilities.negotiated_maximum_incoming_streams; + metrics_.negotiated_maximum_outgoing_streams = + capabilities.negotiated_maximum_outgoing_streams; + tcb_ = std::make_unique<TransmissionControlBlock>( + timer_manager_, log_prefix_, options_, capabilities, callbacks_, + send_queue_, my_verification_tag, my_initial_tsn, peer_verification_tag, + peer_initial_tsn, a_rwnd, tie_tag, packet_sender_, + [this]() { return state_ == State::kEstablished; }); + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Created TCB: " << tcb_->ToString(); +} + +void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (state_ != State::kClosed) { + callbacks_.OnError(ErrorKind::kUnsupportedOperation, + "Only closed socket can be restored from state"); + } else { + if (state.socket_state == + DcSctpSocketHandoverState::SocketState::kConnected) { + VerificationTag my_verification_tag = + VerificationTag(state.my_verification_tag); + connect_params_.verification_tag = my_verification_tag; + + Capabilities capabilities; + capabilities.partial_reliability = state.capabilities.partial_reliability; + capabilities.message_interleaving = + state.capabilities.message_interleaving; + capabilities.reconfig = state.capabilities.reconfig; + capabilities.zero_checksum = state.capabilities.zero_checksum; + capabilities.negotiated_maximum_incoming_streams = + state.capabilities.negotiated_maximum_incoming_streams; + capabilities.negotiated_maximum_outgoing_streams = + state.capabilities.negotiated_maximum_outgoing_streams; + + send_queue_.RestoreFromState(state); + + CreateTransmissionControlBlock( + capabilities, my_verification_tag, TSN(state.my_initial_tsn), + VerificationTag(state.peer_verification_tag), + TSN(state.peer_initial_tsn), static_cast<size_t>(0), + TieTag(state.tie_tag)); + + tcb_->RestoreFromState(state); + + SetState(State::kEstablished, "restored from handover state"); + callbacks_.OnConnected(); + } + } + + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::Shutdown() { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (tcb_ != nullptr) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon receipt of the SHUTDOWN primitive from its upper layer, the + // endpoint enters the SHUTDOWN-PENDING state and remains there until all + // outstanding data has been acknowledged by its peer." + + // TODO(webrtc:12739): Remove this check, as it just hides the problem that + // the socket can transition from ShutdownSent to ShutdownPending, or + // ShutdownAckSent to ShutdownPending which is illegal. + if (state_ != State::kShutdownSent && state_ != State::kShutdownAckSent) { + SetState(State::kShutdownPending, "Shutdown called"); + t1_init_->Stop(); + t1_cookie_->Stop(); + MaybeSendShutdownOrAck(); + } + } else { + // Connection closed before even starting to connect, or during the initial + // connection phase. There is no outstanding data, so the socket can just + // be closed (stopping any connection timers, if any), as this is the + // client's intention, by calling Shutdown. + InternalClose(ErrorKind::kNoError, ""); + } + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::Close() { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (state_ != State::kClosed) { + if (tcb_ != nullptr) { + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(AbortChunk(/*filled_in_verification_tag=*/true, + Parameters::Builder() + .Add(UserInitiatedAbortCause("Close called")) + .Build())); + packet_sender_.Send(b); + } + InternalClose(ErrorKind::kNoError, ""); + } else { + RTC_DLOG(LS_INFO) << log_prefix() << "Called Close on a closed socket"; + } + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::CloseConnectionBecauseOfTooManyTransmissionErrors() { + packet_sender_.Send(tcb_->PacketBuilder().Add(AbortChunk( + true, Parameters::Builder() + .Add(UserInitiatedAbortCause("Too many retransmissions")) + .Build()))); + InternalClose(ErrorKind::kTooManyRetries, "Too many retransmissions"); +} + +void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) { + if (state_ != State::kClosed) { + t1_init_->Stop(); + t1_cookie_->Stop(); + t2_shutdown_->Stop(); + tcb_ = nullptr; + + if (error == ErrorKind::kNoError) { + callbacks_.OnClosed(); + } else { + callbacks_.OnAborted(error, message); + } + SetState(State::kClosed, message); + } + // This method's purpose is to abort/close and make it consistent by ensuring + // that e.g. all timers really are stopped. + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::SetStreamPriority(StreamID stream_id, + StreamPriority priority) { + RTC_DCHECK_RUN_ON(&thread_checker_); + send_queue_.SetStreamPriority(stream_id, priority); +} +StreamPriority DcSctpSocket::GetStreamPriority(StreamID stream_id) const { + RTC_DCHECK_RUN_ON(&thread_checker_); + return send_queue_.GetStreamPriority(stream_id); +} + +SendStatus DcSctpSocket::Send(DcSctpMessage message, + const SendOptions& send_options) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + LifecycleId lifecycle_id = send_options.lifecycle_id; + + if (message.payload().empty()) { + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Unable to send empty message"); + return SendStatus::kErrorMessageEmpty; + } + if (message.payload().size() > options_.max_message_size) { + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Unable to send too large message"); + return SendStatus::kErrorMessageTooLarge; + } + if (state_ == State::kShutdownPending || state_ == State::kShutdownSent || + state_ == State::kShutdownReceived || state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "An endpoint should reject any new data request from its upper layer + // if it is in the SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, or + // SHUTDOWN-ACK-SENT state." + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } + callbacks_.OnError(ErrorKind::kWrongSequence, + "Unable to send message as the socket is shutting down"); + return SendStatus::kErrorShuttingDown; + } + if (send_queue_.IsFull()) { + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } + callbacks_.OnError(ErrorKind::kResourceExhaustion, + "Unable to send message as the send queue is full"); + return SendStatus::kErrorResourceExhaustion; + } + + TimeMs now = callbacks_.TimeMillis(); + ++metrics_.tx_messages_count; + send_queue_.Add(now, std::move(message), send_options); + if (tcb_ != nullptr) { + tcb_->SendBufferedPackets(now); + } + + RTC_DCHECK(IsConsistent()); + return SendStatus::kSuccess; +} + +ResetStreamsStatus DcSctpSocket::ResetStreams( + rtc::ArrayView<const StreamID> outgoing_streams) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (tcb_ == nullptr) { + callbacks_.OnError(ErrorKind::kWrongSequence, + "Can't reset streams as the socket is not connected"); + return ResetStreamsStatus::kNotConnected; + } + if (!tcb_->capabilities().reconfig) { + callbacks_.OnError(ErrorKind::kUnsupportedOperation, + "Can't reset streams as the peer doesn't support it"); + return ResetStreamsStatus::kNotSupported; + } + + tcb_->stream_reset_handler().ResetStreams(outgoing_streams); + MaybeSendResetStreamsRequest(); + + RTC_DCHECK(IsConsistent()); + return ResetStreamsStatus::kPerformed; +} + +SocketState DcSctpSocket::state() const { + RTC_DCHECK_RUN_ON(&thread_checker_); + switch (state_) { + case State::kClosed: + return SocketState::kClosed; + case State::kCookieWait: + case State::kCookieEchoed: + return SocketState::kConnecting; + case State::kEstablished: + return SocketState::kConnected; + case State::kShutdownPending: + case State::kShutdownSent: + case State::kShutdownReceived: + case State::kShutdownAckSent: + return SocketState::kShuttingDown; + } +} + +void DcSctpSocket::SetMaxMessageSize(size_t max_message_size) { + RTC_DCHECK_RUN_ON(&thread_checker_); + options_.max_message_size = max_message_size; +} + +size_t DcSctpSocket::buffered_amount(StreamID stream_id) const { + RTC_DCHECK_RUN_ON(&thread_checker_); + return send_queue_.buffered_amount(stream_id); +} + +size_t DcSctpSocket::buffered_amount_low_threshold(StreamID stream_id) const { + RTC_DCHECK_RUN_ON(&thread_checker_); + return send_queue_.buffered_amount_low_threshold(stream_id); +} + +void DcSctpSocket::SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) { + RTC_DCHECK_RUN_ON(&thread_checker_); + send_queue_.SetBufferedAmountLowThreshold(stream_id, bytes); +} + +absl::optional<Metrics> DcSctpSocket::GetMetrics() const { + RTC_DCHECK_RUN_ON(&thread_checker_); + + if (tcb_ == nullptr) { + return absl::nullopt; + } + + Metrics metrics = metrics_; + metrics.cwnd_bytes = tcb_->cwnd(); + metrics.srtt_ms = tcb_->current_srtt().value(); + size_t packet_payload_size = + options_.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; + metrics.unack_data_count = + tcb_->retransmission_queue().outstanding_items() + + (send_queue_.total_buffered_amount() + packet_payload_size - 1) / + packet_payload_size; + metrics.peer_rwnd_bytes = tcb_->retransmission_queue().rwnd(); + metrics.negotiated_maximum_incoming_streams = + tcb_->capabilities().negotiated_maximum_incoming_streams; + metrics.negotiated_maximum_incoming_streams = + tcb_->capabilities().negotiated_maximum_incoming_streams; + metrics.rtx_packets_count = tcb_->retransmission_queue().rtx_packets_count(); + metrics.rtx_bytes_count = tcb_->retransmission_queue().rtx_bytes_count(); + + return metrics; +} + +void DcSctpSocket::MaybeSendShutdownOnPacketReceived(const SctpPacket& packet) { + if (state_ == State::kShutdownSent) { + bool has_data_chunk = + std::find_if(packet.descriptors().begin(), packet.descriptors().end(), + [](const SctpPacket::ChunkDescriptor& descriptor) { + return descriptor.type == DataChunk::kType; + }) != packet.descriptors().end(); + if (has_data_chunk) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "While in the SHUTDOWN-SENT state, the SHUTDOWN sender MUST immediately + // respond to each received packet containing one or more DATA chunks with + // a SHUTDOWN chunk and restart the T2-shutdown timer."" + SendShutdown(); + t2_shutdown_->set_duration(tcb_->current_rto()); + t2_shutdown_->Start(); + } + } +} + +void DcSctpSocket::MaybeSendResetStreamsRequest() { + absl::optional<ReConfigChunk> reconfig = + tcb_->stream_reset_handler().MakeStreamResetRequest(); + if (reconfig.has_value()) { + SctpPacket::Builder builder = tcb_->PacketBuilder(); + builder.Add(*reconfig); + packet_sender_.Send(builder); + } +} + +bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) { + const CommonHeader& header = packet.common_header(); + VerificationTag my_verification_tag = + tcb_ != nullptr ? tcb_->my_verification_tag() : VerificationTag(0); + + if (header.verification_tag == VerificationTag(0)) { + if (packet.descriptors().size() == 1 && + packet.descriptors()[0].type == InitChunk::kType) { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "When an endpoint receives an SCTP packet with the Verification Tag + // set to 0, it should verify that the packet contains only an INIT chunk. + // Otherwise, the receiver MUST silently discard the packet."" + return true; + } + callbacks_.OnError( + ErrorKind::kParseFailed, + "Only a single INIT chunk can be present in packets sent on " + "verification_tag = 0"); + return false; + } + + if (packet.descriptors().size() == 1 && + packet.descriptors()[0].type == AbortChunk::kType) { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "The receiver of an ABORT MUST accept the packet if the Verification + // Tag field of the packet matches its own tag and the T bit is not set OR + // if it is set to its peer's tag and the T bit is set in the Chunk Flags. + // Otherwise, the receiver MUST silently discard the packet and take no + // further action." + bool t_bit = (packet.descriptors()[0].flags & 0x01) != 0; + if (t_bit && tcb_ == nullptr) { + // Can't verify the tag - assume it's okey. + return true; + } + if ((!t_bit && header.verification_tag == my_verification_tag) || + (t_bit && header.verification_tag == tcb_->peer_verification_tag())) { + return true; + } + callbacks_.OnError(ErrorKind::kParseFailed, + "ABORT chunk verification tag was wrong"); + return false; + } + + if (packet.descriptors()[0].type == InitAckChunk::kType) { + if (header.verification_tag == connect_params_.verification_tag) { + return true; + } + callbacks_.OnError( + ErrorKind::kParseFailed, + rtc::StringFormat( + "Packet has invalid verification tag: %08x, expected %08x", + *header.verification_tag, *connect_params_.verification_tag)); + return false; + } + + if (packet.descriptors()[0].type == CookieEchoChunk::kType) { + // Handled in chunk handler (due to RFC 4960, section 5.2.4). + return true; + } + + if (packet.descriptors().size() == 1 && + packet.descriptors()[0].type == ShutdownCompleteChunk::kType) { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "The receiver of a SHUTDOWN COMPLETE shall accept the packet if the + // Verification Tag field of the packet matches its own tag and the T bit is + // not set OR if it is set to its peer's tag and the T bit is set in the + // Chunk Flags. Otherwise, the receiver MUST silently discard the packet + // and take no further action." + bool t_bit = (packet.descriptors()[0].flags & 0x01) != 0; + if (t_bit && tcb_ == nullptr) { + // Can't verify the tag - assume it's okey. + return true; + } + if ((!t_bit && header.verification_tag == my_verification_tag) || + (t_bit && header.verification_tag == tcb_->peer_verification_tag())) { + return true; + } + callbacks_.OnError(ErrorKind::kParseFailed, + "SHUTDOWN_COMPLETE chunk verification tag was wrong"); + return false; + } + + // https://tools.ietf.org/html/rfc4960#section-8.5 + // "When receiving an SCTP packet, the endpoint MUST ensure that the value + // in the Verification Tag field of the received SCTP packet matches its own + // tag. If the received Verification Tag value does not match the receiver's + // own tag value, the receiver shall silently discard the packet and shall not + // process it any further..." + if (header.verification_tag == my_verification_tag) { + return true; + } + + callbacks_.OnError( + ErrorKind::kParseFailed, + rtc::StringFormat( + "Packet has invalid verification tag: %08x, expected %08x", + *header.verification_tag, *my_verification_tag)); + return false; +} + +void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + timer_manager_.HandleTimeout(timeout_id); + + if (tcb_ != nullptr && tcb_->HasTooManyTxErrors()) { + // Tearing down the TCB has to be done outside the handlers. + CloseConnectionBecauseOfTooManyTransmissionErrors(); + } + + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + ++metrics_.rx_packets_count; + + if (packet_observer_ != nullptr) { + packet_observer_->OnReceivedPacket(callbacks_.TimeMillis(), data); + } + + absl::optional<SctpPacket> packet = SctpPacket::Parse(data, options_); + if (!packet.has_value()) { + // https://tools.ietf.org/html/rfc4960#section-6.8 + // "The default procedure for handling invalid SCTP packets is to + // silently discard them." + callbacks_.OnError(ErrorKind::kParseFailed, + "Failed to parse received SCTP packet"); + RTC_DCHECK(IsConsistent()); + return; + } + + if (RTC_DLOG_IS_ON) { + for (const auto& descriptor : packet->descriptors()) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received " + << DebugConvertChunkToString(descriptor.data); + } + } + + if (!ValidatePacket(*packet)) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Packet failed verification tag check - dropping"; + RTC_DCHECK(IsConsistent()); + return; + } + + MaybeSendShutdownOnPacketReceived(*packet); + + for (const auto& descriptor : packet->descriptors()) { + if (!Dispatch(packet->common_header(), descriptor)) { + break; + } + } + + if (tcb_ != nullptr) { + tcb_->data_tracker().ObservePacketEnd(); + tcb_->MaybeSendSack(); + } + + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload) { + auto packet = SctpPacket::Parse(payload, options_); + RTC_DCHECK(packet.has_value()); + + for (const auto& desc : packet->descriptors()) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Sent " + << DebugConvertChunkToString(desc.data); + } +} + +bool DcSctpSocket::Dispatch(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + switch (descriptor.type) { + case DataChunk::kType: + HandleData(header, descriptor); + break; + case InitChunk::kType: + HandleInit(header, descriptor); + break; + case InitAckChunk::kType: + HandleInitAck(header, descriptor); + break; + case SackChunk::kType: + HandleSack(header, descriptor); + break; + case HeartbeatRequestChunk::kType: + HandleHeartbeatRequest(header, descriptor); + break; + case HeartbeatAckChunk::kType: + HandleHeartbeatAck(header, descriptor); + break; + case AbortChunk::kType: + HandleAbort(header, descriptor); + break; + case ErrorChunk::kType: + HandleError(header, descriptor); + break; + case CookieEchoChunk::kType: + HandleCookieEcho(header, descriptor); + break; + case CookieAckChunk::kType: + HandleCookieAck(header, descriptor); + break; + case ShutdownChunk::kType: + HandleShutdown(header, descriptor); + break; + case ShutdownAckChunk::kType: + HandleShutdownAck(header, descriptor); + break; + case ShutdownCompleteChunk::kType: + HandleShutdownComplete(header, descriptor); + break; + case ReConfigChunk::kType: + HandleReconfig(header, descriptor); + break; + case ForwardTsnChunk::kType: + HandleForwardTsn(header, descriptor); + break; + case IDataChunk::kType: + HandleIData(header, descriptor); + break; + case IForwardTsnChunk::kType: + HandleIForwardTsn(header, descriptor); + break; + default: + return HandleUnrecognizedChunk(descriptor); + } + return true; +} + +bool DcSctpSocket::HandleUnrecognizedChunk( + const SctpPacket::ChunkDescriptor& descriptor) { + bool report_as_error = (descriptor.type & 0x40) != 0; + bool continue_processing = (descriptor.type & 0x80) != 0; + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received unknown chunk: " + << static_cast<int>(descriptor.type); + if (report_as_error) { + rtc::StringBuilder sb; + sb << "Received unknown chunk of type: " + << static_cast<int>(descriptor.type) << " with report-error bit set"; + callbacks_.OnError(ErrorKind::kParseFailed, sb.str()); + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Unknown chunk, with type indicating it should be reported."; + + // https://tools.ietf.org/html/rfc4960#section-3.2 + // "... report in an ERROR chunk using the 'Unrecognized Chunk Type' + // cause." + if (tcb_ != nullptr) { + // Need TCB - this chunk must be sent with a correct verification tag. + packet_sender_.Send(tcb_->PacketBuilder().Add( + ErrorChunk(Parameters::Builder() + .Add(UnrecognizedChunkTypeCause(std::vector<uint8_t>( + descriptor.data.begin(), descriptor.data.end()))) + .Build()))); + } + } + if (!continue_processing) { + // https://tools.ietf.org/html/rfc4960#section-3.2 + // "Stop processing this SCTP packet and discard it, do not process any + // further chunks within it." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Unknown chunk, with type indicating not to " + "process any further chunks"; + } + + return continue_processing; +} + +absl::optional<DurationMs> DcSctpSocket::OnInitTimerExpiry() { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t1_init_->name() + << " has expired: " << t1_init_->expiration_count() + << "/" << t1_init_->options().max_restarts.value_or(-1); + RTC_DCHECK(state_ == State::kCookieWait); + + if (t1_init_->is_running()) { + SendInit(); + } else { + InternalClose(ErrorKind::kTooManyRetries, "No INIT_ACK received"); + } + RTC_DCHECK(IsConsistent()); + return absl::nullopt; +} + +absl::optional<DurationMs> DcSctpSocket::OnCookieTimerExpiry() { + // https://tools.ietf.org/html/rfc4960#section-4 + // "If the T1-cookie timer expires, the endpoint MUST retransmit COOKIE + // ECHO and restart the T1-cookie timer without changing state. This MUST + // be repeated up to 'Max.Init.Retransmits' times. After that, the endpoint + // MUST abort the initialization process and report the error to the SCTP + // user." + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t1_cookie_->name() + << " has expired: " << t1_cookie_->expiration_count() + << "/" + << t1_cookie_->options().max_restarts.value_or(-1); + + RTC_DCHECK(state_ == State::kCookieEchoed); + + if (t1_cookie_->is_running()) { + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); + } else { + InternalClose(ErrorKind::kTooManyRetries, "No COOKIE_ACK received"); + } + + RTC_DCHECK(IsConsistent()); + return absl::nullopt; +} + +absl::optional<DurationMs> DcSctpSocket::OnShutdownTimerExpiry() { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t2_shutdown_->name() + << " has expired: " << t2_shutdown_->expiration_count() + << "/" + << t2_shutdown_->options().max_restarts.value_or(-1); + + if (!t2_shutdown_->is_running()) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "An endpoint should limit the number of retransmissions of the SHUTDOWN + // chunk to the protocol parameter 'Association.Max.Retrans'. If this + // threshold is exceeded, the endpoint should destroy the TCB..." + + packet_sender_.Send(tcb_->PacketBuilder().Add( + AbortChunk(true, Parameters::Builder() + .Add(UserInitiatedAbortCause( + "Too many retransmissions of SHUTDOWN")) + .Build()))); + + InternalClose(ErrorKind::kTooManyRetries, "No SHUTDOWN_ACK received"); + RTC_DCHECK(IsConsistent()); + return absl::nullopt; + } + + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If the timer expires, the endpoint must resend the SHUTDOWN with the + // updated last sequential TSN received from its peer." + SendShutdown(); + RTC_DCHECK(IsConsistent()); + return tcb_->current_rto(); +} + +void DcSctpSocket::OnSentPacket(rtc::ArrayView<const uint8_t> packet, + SendPacketStatus status) { + // The packet observer is invoked even if the packet was failed to be sent, to + // indicate an attempt was made. + if (packet_observer_ != nullptr) { + packet_observer_->OnSentPacket(callbacks_.TimeMillis(), packet); + } + + if (status == SendPacketStatus::kSuccess) { + if (RTC_DLOG_IS_ON) { + DebugPrintOutgoing(packet); + } + + // The heartbeat interval timer is restarted for every sent packet, to + // fire when the outgoing channel is inactive. + if (tcb_ != nullptr) { + tcb_->heartbeat_handler().RestartTimer(); + } + + ++metrics_.tx_packets_count; + } +} + +bool DcSctpSocket::ValidateHasTCB() { + if (tcb_ != nullptr) { + return true; + } + + callbacks_.OnError( + ErrorKind::kNotConnected, + "Received unexpected commands on socket that is not connected"); + return false; +} + +void DcSctpSocket::ReportFailedToParseChunk(int chunk_type) { + rtc::StringBuilder sb; + sb << "Failed to parse chunk of type: " << chunk_type; + callbacks_.OnError(ErrorKind::kParseFailed, sb.str()); +} + +void DcSctpSocket::HandleData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<DataChunk> chunk = DataChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleDataCommon(*chunk); + } +} + +void DcSctpSocket::HandleIData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<IDataChunk> chunk = IDataChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleDataCommon(*chunk); + } +} + +void DcSctpSocket::HandleDataCommon(AnyDataChunk& chunk) { + TSN tsn = chunk.tsn(); + AnyDataChunk::ImmediateAckFlag immediate_ack = chunk.options().immediate_ack; + Data data = std::move(chunk).extract(); + + if (data.payload.empty()) { + // Empty DATA chunks are illegal. + packet_sender_.Send(tcb_->PacketBuilder().Add( + ErrorChunk(Parameters::Builder().Add(NoUserDataCause(tsn)).Build()))); + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Received DATA chunk with no user data"); + return; + } + + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Handle DATA, queue_size=" + << tcb_->reassembly_queue().queued_bytes() + << ", water_mark=" + << tcb_->reassembly_queue().watermark_bytes() + << ", full=" << tcb_->reassembly_queue().is_full() + << ", above=" + << tcb_->reassembly_queue().is_above_watermark(); + + if (tcb_->reassembly_queue().is_full()) { + // If the reassembly queue is full, there is nothing that can be done. The + // specification only allows dropping gap-ack-blocks, and that's not + // likely to help as the socket has been trying to fill gaps since the + // watermark was reached. + packet_sender_.Send(tcb_->PacketBuilder().Add(AbortChunk( + true, Parameters::Builder().Add(OutOfResourceErrorCause()).Build()))); + InternalClose(ErrorKind::kResourceExhaustion, + "Reassembly Queue is exhausted"); + return; + } + + if (tcb_->reassembly_queue().is_above_watermark()) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Is above high watermark"; + // If the reassembly queue is above its high watermark, only accept data + // chunks that increase its cumulative ack tsn in an attempt to fill gaps + // to deliver messages. + if (!tcb_->data_tracker().will_increase_cum_ack_tsn(tsn)) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Rejected data because of exceeding watermark"; + tcb_->data_tracker().ForceImmediateSack(); + return; + } + } + + if (!tcb_->data_tracker().IsTSNValid(tsn)) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Rejected data because of failing TSN validity"; + return; + } + + if (tcb_->data_tracker().Observe(tsn, immediate_ack)) { + tcb_->reassembly_queue().Add(tsn, std::move(data)); + MaybeDeliverMessages(); + } +} + +void DcSctpSocket::HandleInit(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<InitChunk> chunk = InitChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + if (chunk->initiate_tag() == VerificationTag(0) || + chunk->nbr_outbound_streams() == 0 || chunk->nbr_inbound_streams() == 0) { + // https://tools.ietf.org/html/rfc4960#section-3.3.2 + // "If the value of the Initiate Tag in a received INIT chunk is found + // to be 0, the receiver MUST treat it as an error and close the + // association by transmitting an ABORT." + + // "A receiver of an INIT with the OS value set to 0 SHOULD abort the + // association." + + // "A receiver of an INIT with the MIS value of 0 SHOULD abort the + // association." + + packet_sender_.Send( + SctpPacket::Builder(VerificationTag(0), options_) + .Add(AbortChunk( + /*filled_in_verification_tag=*/false, + Parameters::Builder() + .Add(ProtocolViolationCause("INIT malformed")) + .Build()))); + InternalClose(ErrorKind::kProtocolViolation, "Received invalid INIT"); + return; + } + + if (state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If an endpoint is in the SHUTDOWN-ACK-SENT state and receives an + // INIT chunk (e.g., if the SHUTDOWN COMPLETE was lost) with source and + // destination transport addresses (either in the IP addresses or in the + // INIT chunk) that belong to this association, it should discard the INIT + // chunk and retransmit the SHUTDOWN ACK chunk." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init indicating lost ShutdownComplete"; + SendShutdownAck(); + return; + } + + TieTag tie_tag(0); + if (state_ == State::kClosed) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init in closed state (normal)"; + + MakeConnectionParameters(); + } else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) { + // https://tools.ietf.org/html/rfc4960#section-5.2.1 + // "This usually indicates an initialization collision, i.e., each + // endpoint is attempting, at about the same time, to establish an + // association with the other endpoint. Upon receipt of an INIT in the + // COOKIE-WAIT state, an endpoint MUST respond with an INIT ACK using the + // same parameters it sent in its original INIT chunk (including its + // Initiate Tag, unchanged). When responding, the endpoint MUST send the + // INIT ACK back to the same address that the original INIT (sent by this + // endpoint) was sent." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init indicating simultaneous connections"; + } else { + RTC_DCHECK(tcb_ != nullptr); + // https://tools.ietf.org/html/rfc4960#section-5.2.2 + // "The outbound SCTP packet containing this INIT ACK MUST carry a + // Verification Tag value equal to the Initiate Tag found in the + // unexpected INIT. And the INIT ACK MUST contain a new Initiate Tag + // (randomly generated; see Section 5.3.1). Other parameters for the + // endpoint SHOULD be copied from the existing parameters of the + // association (e.g., number of outbound streams) into the INIT ACK and + // cookie." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init indicating restarted connection"; + // Create a new verification tag - different from the previous one. + for (int tries = 0; tries < 10; ++tries) { + connect_params_.verification_tag = VerificationTag( + callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); + if (connect_params_.verification_tag != tcb_->my_verification_tag()) { + break; + } + } + + // Make the initial TSN make a large jump, so that there is no overlap + // with the old and new association. + connect_params_.initial_tsn = + TSN(*tcb_->retransmission_queue().next_tsn() + 1000000); + tie_tag = tcb_->tie_tag(); + } + + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << rtc::StringFormat( + "Proceeding with connection. my_verification_tag=%08x, " + "my_initial_tsn=%u, peer_verification_tag=%08x, " + "peer_initial_tsn=%u", + *connect_params_.verification_tag, *connect_params_.initial_tsn, + *chunk->initiate_tag(), *chunk->initial_tsn()); + + Capabilities capabilities = + ComputeCapabilities(options_, chunk->nbr_outbound_streams(), + chunk->nbr_inbound_streams(), chunk->parameters()); + + SctpPacket::Builder b(chunk->initiate_tag(), options_); + Parameters::Builder params_builder = + Parameters::Builder().Add(StateCookieParameter( + StateCookie(chunk->initiate_tag(), chunk->initial_tsn(), + chunk->a_rwnd(), tie_tag, capabilities) + .Serialize())); + AddCapabilityParameters(options_, params_builder); + + InitAckChunk init_ack(/*initiate_tag=*/connect_params_.verification_tag, + options_.max_receiver_window_buffer_size, + options_.announced_maximum_outgoing_streams, + options_.announced_maximum_incoming_streams, + connect_params_.initial_tsn, params_builder.Build()); + b.Add(init_ack); + // If the peer has signaled that it supports zero checksum, INIT-ACK can then + // have its checksum as zero. + packet_sender_.Send(b, /*write_checksum=*/!capabilities.zero_checksum); +} + +void DcSctpSocket::HandleInitAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<InitAckChunk> chunk = InitAckChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + if (state_ != State::kCookieWait) { + // https://tools.ietf.org/html/rfc4960#section-5.2.3 + // "If an INIT ACK is received by an endpoint in any state other than + // the COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received INIT_ACK in unexpected state"; + return; + } + + auto cookie = chunk->parameters().get<StateCookieParameter>(); + if (!cookie.has_value()) { + packet_sender_.Send( + SctpPacket::Builder(connect_params_.verification_tag, options_) + .Add(AbortChunk( + /*filled_in_verification_tag=*/false, + Parameters::Builder() + .Add(ProtocolViolationCause("INIT-ACK malformed")) + .Build()))); + InternalClose(ErrorKind::kProtocolViolation, + "InitAck chunk doesn't contain a cookie"); + return; + } + Capabilities capabilities = + ComputeCapabilities(options_, chunk->nbr_outbound_streams(), + chunk->nbr_inbound_streams(), chunk->parameters()); + t1_init_->Stop(); + + metrics_.peer_implementation = DeterminePeerImplementation(cookie->data()); + + // If the connection is re-established (peer restarted, but re-used old + // connection), make sure that all message identifiers are reset and any + // partly sent message is re-sent in full. The same is true when the socket + // is closed and later re-opened, which never happens in WebRTC, but is a + // valid operation on the SCTP level. Note that in case of handover, the + // send queue is already re-configured, and shouldn't be reset. + send_queue_.Reset(); + + CreateTransmissionControlBlock(capabilities, connect_params_.verification_tag, + connect_params_.initial_tsn, + chunk->initiate_tag(), chunk->initial_tsn(), + chunk->a_rwnd(), MakeTieTag(callbacks_)); + + SetState(State::kCookieEchoed, "INIT_ACK received"); + + // The connection isn't fully established just yet. + tcb_->SetCookieEchoChunk(CookieEchoChunk(cookie->data())); + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); + t1_cookie_->Start(); +} + +void DcSctpSocket::HandleCookieEcho( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<CookieEchoChunk> chunk = + CookieEchoChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + absl::optional<StateCookie> cookie = + StateCookie::Deserialize(chunk->cookie()); + if (!cookie.has_value()) { + callbacks_.OnError(ErrorKind::kParseFailed, "Failed to parse state cookie"); + return; + } + + if (tcb_ != nullptr) { + if (!HandleCookieEchoWithTCB(header, *cookie)) { + return; + } + } else { + if (header.verification_tag != connect_params_.verification_tag) { + callbacks_.OnError( + ErrorKind::kParseFailed, + rtc::StringFormat( + "Received CookieEcho with invalid verification tag: %08x, " + "expected %08x", + *header.verification_tag, *connect_params_.verification_tag)); + return; + } + } + + // The init timer can be running on simultaneous connections. + t1_init_->Stop(); + t1_cookie_->Stop(); + if (state_ != State::kEstablished) { + if (tcb_ != nullptr) { + tcb_->ClearCookieEchoChunk(); + } + SetState(State::kEstablished, "COOKIE_ECHO received"); + callbacks_.OnConnected(); + } + + if (tcb_ == nullptr) { + // If the connection is re-established (peer restarted, but re-used old + // connection), make sure that all message identifiers are reset and any + // partly sent message is re-sent in full. The same is true when the socket + // is closed and later re-opened, which never happens in WebRTC, but is a + // valid operation on the SCTP level. Note that in case of handover, the + // send queue is already re-configured, and shouldn't be reset. + send_queue_.Reset(); + + CreateTransmissionControlBlock( + cookie->capabilities(), connect_params_.verification_tag, + connect_params_.initial_tsn, cookie->initiate_tag(), + cookie->initial_tsn(), cookie->a_rwnd(), MakeTieTag(callbacks_)); + } + + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(CookieAckChunk()); + + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "A COOKIE ACK chunk may be bundled with any pending DATA chunks (and/or + // SACK chunks), but the COOKIE ACK chunk MUST be the first chunk in the + // packet." + tcb_->SendBufferedPackets(b, callbacks_.TimeMillis()); +} + +bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header, + const StateCookie& cookie) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Handling CookieEchoChunk with TCB. local_tag=" + << *tcb_->my_verification_tag() + << ", peer_tag=" << *header.verification_tag + << ", tcb_tag=" << *tcb_->peer_verification_tag() + << ", cookie_tag=" << *cookie.initiate_tag() + << ", local_tie_tag=" << *tcb_->tie_tag() + << ", peer_tie_tag=" << *cookie.tie_tag(); + // https://tools.ietf.org/html/rfc4960#section-5.2.4 + // "Handle a COOKIE ECHO when a TCB Exists" + if (header.verification_tag != tcb_->my_verification_tag() && + tcb_->peer_verification_tag() != cookie.initiate_tag() && + cookie.tie_tag() == tcb_->tie_tag()) { + // "A) In this case, the peer may have restarted." + if (state_ == State::kShutdownAckSent) { + // "If the endpoint is in the SHUTDOWN-ACK-SENT state and recognizes + // that the peer has restarted ... it MUST NOT set up a new association + // but instead resend the SHUTDOWN ACK and send an ERROR chunk with a + // "Cookie Received While Shutting Down" error cause to its peer." + SctpPacket::Builder b(cookie.initiate_tag(), options_); + b.Add(ShutdownAckChunk()); + b.Add(ErrorChunk(Parameters::Builder() + .Add(CookieReceivedWhileShuttingDownCause()) + .Build())); + packet_sender_.Send(b); + callbacks_.OnError(ErrorKind::kWrongSequence, + "Received COOKIE-ECHO while shutting down"); + return false; + } + + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received COOKIE-ECHO indicating a restarted peer"; + + tcb_ = nullptr; + callbacks_.OnConnectionRestarted(); + } else if (header.verification_tag == tcb_->my_verification_tag() && + tcb_->peer_verification_tag() != cookie.initiate_tag()) { + // TODO(boivie): Handle the peer_tag == 0? + // "B) In this case, both sides may be attempting to start an + // association at about the same time, but the peer endpoint started its + // INIT after responding to the local endpoint's INIT." + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Received COOKIE-ECHO indicating simultaneous connections"; + tcb_ = nullptr; + } else if (header.verification_tag != tcb_->my_verification_tag() && + tcb_->peer_verification_tag() == cookie.initiate_tag() && + cookie.tie_tag() == TieTag(0)) { + // "C) In this case, the local endpoint's cookie has arrived late. + // Before it arrived, the local endpoint sent an INIT and received an + // INIT ACK and finally sent a COOKIE ECHO with the peer's same tag but + // a new tag of its own. The cookie should be silently discarded. The + // endpoint SHOULD NOT change states and should leave any timers + // running." + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Received COOKIE-ECHO indicating a late COOKIE-ECHO. Discarding"; + return false; + } else if (header.verification_tag == tcb_->my_verification_tag() && + tcb_->peer_verification_tag() == cookie.initiate_tag()) { + // "D) When both local and remote tags match, the endpoint should enter + // the ESTABLISHED state, if it is in the COOKIE-ECHOED state. It + // should stop any cookie timer that may be running and send a COOKIE + // ACK." + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Received duplicate COOKIE-ECHO, probably because of peer not " + "receiving COOKIE-ACK and retransmitting COOKIE-ECHO. Continuing."; + } + return true; +} + +void DcSctpSocket::HandleCookieAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<CookieAckChunk> chunk = CookieAckChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + if (state_ != State::kCookieEchoed) { + // https://tools.ietf.org/html/rfc4960#section-5.2.5 + // "At any state other than COOKIE-ECHOED, an endpoint should silently + // discard a received COOKIE ACK chunk." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received COOKIE_ACK not in COOKIE_ECHOED state"; + return; + } + + // RFC 4960, Errata ID: 4400 + t1_cookie_->Stop(); + tcb_->ClearCookieEchoChunk(); + SetState(State::kEstablished, "COOKIE_ACK received"); + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); + callbacks_.OnConnected(); +} + +void DcSctpSocket::MaybeDeliverMessages() { + for (auto& message : tcb_->reassembly_queue().FlushMessages()) { + ++metrics_.rx_messages_count; + callbacks_.OnMessageReceived(std::move(message)); + } +} + +void DcSctpSocket::HandleSack(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<SackChunk> chunk = SackChunk::Parse(descriptor.data); + + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + TimeMs now = callbacks_.TimeMillis(); + SackChunk sack = ChunkValidators::Clean(*std::move(chunk)); + + if (tcb_->retransmission_queue().HandleSack(now, sack)) { + MaybeSendShutdownOrAck(); + // Receiving an ACK may make the socket go into fast recovery mode. + // https://datatracker.ietf.org/doc/html/rfc4960#section-7.2.4 + // "Determine how many of the earliest (i.e., lowest TSN) DATA chunks + // marked for retransmission will fit into a single packet, subject to + // constraint of the path MTU of the destination transport address to + // which the packet is being sent. Call this value K. Retransmit those K + // DATA chunks in a single packet. When a Fast Retransmit is being + // performed, the sender SHOULD ignore the value of cwnd and SHOULD NOT + // delay retransmission for this single packet." + tcb_->MaybeSendFastRetransmit(); + + // Receiving an ACK will decrease outstanding bytes (maybe now below + // cwnd?) or indicate packet loss that may result in sending FORWARD-TSN. + tcb_->SendBufferedPackets(now); + } else { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Dropping out-of-order SACK with TSN " + << *sack.cumulative_tsn_ack(); + } + } +} + +void DcSctpSocket::HandleHeartbeatRequest( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<HeartbeatRequestChunk> chunk = + HeartbeatRequestChunk::Parse(descriptor.data); + + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + tcb_->heartbeat_handler().HandleHeartbeatRequest(*std::move(chunk)); + } +} + +void DcSctpSocket::HandleHeartbeatAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<HeartbeatAckChunk> chunk = + HeartbeatAckChunk::Parse(descriptor.data); + + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + tcb_->heartbeat_handler().HandleHeartbeatAck(*std::move(chunk)); + } +} + +void DcSctpSocket::HandleAbort(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<AbortChunk> chunk = AbortChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk)) { + std::string error_string = ErrorCausesToString(chunk->error_causes()); + if (tcb_ == nullptr) { + // https://tools.ietf.org/html/rfc4960#section-3.3.7 + // "If an endpoint receives an ABORT with a format error or no TCB is + // found, it MUST silently discard it." + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received ABORT (" << error_string + << ") on a connection with no TCB. Ignoring"; + return; + } + + RTC_DLOG(LS_WARNING) << log_prefix() << "Received ABORT (" << error_string + << ") - closing connection."; + InternalClose(ErrorKind::kPeerReported, error_string); + } +} + +void DcSctpSocket::HandleError(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<ErrorChunk> chunk = ErrorChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk)) { + std::string error_string = ErrorCausesToString(chunk->error_causes()); + if (tcb_ == nullptr) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received ERROR (" << error_string + << ") on a connection with no TCB. Ignoring"; + return; + } + + RTC_DLOG(LS_WARNING) << log_prefix() << "Received ERROR: " << error_string; + callbacks_.OnError(ErrorKind::kPeerReported, + "Peer reported error: " + error_string); + } +} + +void DcSctpSocket::HandleReconfig( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + TimeMs now = callbacks_.TimeMillis(); + absl::optional<ReConfigChunk> chunk = ReConfigChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + tcb_->stream_reset_handler().HandleReConfig(*std::move(chunk)); + // Handling this response may result in outgoing stream resets finishing + // (either successfully or with failure). If there still are pending streams + // that were waiting for this request to finish, continue resetting them. + MaybeSendResetStreamsRequest(); + + // If a response was processed, pending to-be-reset streams may now have + // become unpaused. Try to send more DATA chunks. + tcb_->SendBufferedPackets(now); + + // If it leaves "deferred reset processing", there may be chunks to deliver + // that were queued while waiting for the stream to reset. + MaybeDeliverMessages(); + } +} + +void DcSctpSocket::HandleShutdown( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + if (!ValidateParseSuccess(ShutdownChunk::Parse(descriptor.data))) { + return; + } + + if (state_ == State::kClosed) { + return; + } else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If a SHUTDOWN is received in the COOKIE-WAIT or COOKIE ECHOED state, + // the SHUTDOWN chunk SHOULD be silently discarded." + } else if (state_ == State::kShutdownSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If an endpoint is in the SHUTDOWN-SENT state and receives a + // SHUTDOWN chunk from its peer, the endpoint shall respond immediately + // with a SHUTDOWN ACK to its peer, and move into the SHUTDOWN-ACK-SENT + // state restarting its T2-shutdown timer." + SendShutdownAck(); + SetState(State::kShutdownAckSent, "SHUTDOWN received"); + } else if (state_ == State::kShutdownAckSent) { + // TODO(webrtc:12739): This condition should be removed and handled by the + // next (state_ != State::kShutdownReceived). + return; + } else if (state_ != State::kShutdownReceived) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received SHUTDOWN - shutting down the socket"; + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon reception of the SHUTDOWN, the peer endpoint shall enter the + // SHUTDOWN-RECEIVED state, stop accepting new data from its SCTP user, + // and verify, by checking the Cumulative TSN Ack field of the chunk, that + // all its outstanding DATA chunks have been received by the SHUTDOWN + // sender." + SetState(State::kShutdownReceived, "SHUTDOWN received"); + MaybeSendShutdownOrAck(); + } +} + +void DcSctpSocket::HandleShutdownAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + if (!ValidateParseSuccess(ShutdownAckChunk::Parse(descriptor.data))) { + return; + } + + if (state_ == State::kShutdownSent || state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon the receipt of the SHUTDOWN ACK, the SHUTDOWN sender shall stop + // the T2-shutdown timer, send a SHUTDOWN COMPLETE chunk to its peer, and + // remove all record of the association." + + // "If an endpoint is in the SHUTDOWN-ACK-SENT state and receives a + // SHUTDOWN ACK, it shall stop the T2-shutdown timer, send a SHUTDOWN + // COMPLETE chunk to its peer, and remove all record of the association." + + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(ShutdownCompleteChunk(/*tag_reflected=*/false)); + packet_sender_.Send(b); + InternalClose(ErrorKind::kNoError, ""); + } else { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "If the receiver is in COOKIE-ECHOED or COOKIE-WAIT state + // the procedures in Section 8.4 SHOULD be followed; in other words, it + // should be treated as an Out Of The Blue packet." + + // https://tools.ietf.org/html/rfc4960#section-8.4 + // "If the packet contains a SHUTDOWN ACK chunk, the receiver + // should respond to the sender of the OOTB packet with a SHUTDOWN + // COMPLETE. When sending the SHUTDOWN COMPLETE, the receiver of the OOTB + // packet must fill in the Verification Tag field of the outbound packet + // with the Verification Tag received in the SHUTDOWN ACK and set the T + // bit in the Chunk Flags to indicate that the Verification Tag is + // reflected." + + SctpPacket::Builder b(header.verification_tag, options_); + b.Add(ShutdownCompleteChunk(/*tag_reflected=*/true)); + packet_sender_.Send(b); + } +} + +void DcSctpSocket::HandleShutdownComplete( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + if (!ValidateParseSuccess(ShutdownCompleteChunk::Parse(descriptor.data))) { + return; + } + + if (state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon reception of the SHUTDOWN COMPLETE chunk, the endpoint will + // verify that it is in the SHUTDOWN-ACK-SENT state; if it is not, the + // chunk should be discarded. If the endpoint is in the SHUTDOWN-ACK-SENT + // state, the endpoint should stop the T2-shutdown timer and remove all + // knowledge of the association (and thus the association enters the + // CLOSED state)." + InternalClose(ErrorKind::kNoError, ""); + } +} + +void DcSctpSocket::HandleForwardTsn( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<ForwardTsnChunk> chunk = + ForwardTsnChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleForwardTsnCommon(*chunk); + } +} + +void DcSctpSocket::HandleIForwardTsn( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<IForwardTsnChunk> chunk = + IForwardTsnChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleForwardTsnCommon(*chunk); + } +} + +void DcSctpSocket::HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk) { + if (!tcb_->capabilities().partial_reliability) { + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(AbortChunk(/*filled_in_verification_tag=*/true, + Parameters::Builder() + .Add(ProtocolViolationCause( + "I-FORWARD-TSN received, but not indicated " + "during connection establishment")) + .Build())); + packet_sender_.Send(b); + + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Received a FORWARD_TSN without announced peer support"); + return; + } + if (tcb_->data_tracker().HandleForwardTsn(chunk.new_cumulative_tsn())) { + tcb_->reassembly_queue().HandleForwardTsn(chunk.new_cumulative_tsn(), + chunk.skipped_streams()); + } + + // A forward TSN - for ordered streams - may allow messages to be delivered. + MaybeDeliverMessages(); +} + +void DcSctpSocket::MaybeSendShutdownOrAck() { + if (tcb_->retransmission_queue().outstanding_bytes() != 0) { + return; + } + + if (state_ == State::kShutdownPending) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Once all its outstanding data has been acknowledged, the endpoint + // shall send a SHUTDOWN chunk to its peer including in the Cumulative TSN + // Ack field the last sequential TSN it has received from the peer. It + // shall then start the T2-shutdown timer and enter the SHUTDOWN-SENT + // state."" + + SendShutdown(); + t2_shutdown_->set_duration(tcb_->current_rto()); + t2_shutdown_->Start(); + SetState(State::kShutdownSent, "No more outstanding data"); + } else if (state_ == State::kShutdownReceived) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If the receiver of the SHUTDOWN has no more outstanding DATA + // chunks, the SHUTDOWN receiver MUST send a SHUTDOWN ACK and start a + // T2-shutdown timer of its own, entering the SHUTDOWN-ACK-SENT state. If + // the timer expires, the endpoint must resend the SHUTDOWN ACK." + + SendShutdownAck(); + SetState(State::kShutdownAckSent, "No more outstanding data"); + } +} + +void DcSctpSocket::SendShutdown() { + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(ShutdownChunk(tcb_->data_tracker().last_cumulative_acked_tsn())); + packet_sender_.Send(b); +} + +void DcSctpSocket::SendShutdownAck() { + packet_sender_.Send(tcb_->PacketBuilder().Add(ShutdownAckChunk())); + t2_shutdown_->set_duration(tcb_->current_rto()); + t2_shutdown_->Start(); +} + +HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const { + RTC_DCHECK_RUN_ON(&thread_checker_); + HandoverReadinessStatus status; + if (state_ != State::kClosed && state_ != State::kEstablished) { + status.Add(HandoverUnreadinessReason::kWrongConnectionState); + } + status.Add(send_queue_.GetHandoverReadiness()); + if (tcb_) { + status.Add(tcb_->GetHandoverReadiness()); + } + return status; +} + +absl::optional<DcSctpSocketHandoverState> +DcSctpSocket::GetHandoverStateAndClose() { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (!GetHandoverReadiness().IsReady()) { + return absl::nullopt; + } + + DcSctpSocketHandoverState state; + + if (state_ == State::kClosed) { + state.socket_state = DcSctpSocketHandoverState::SocketState::kClosed; + } else if (state_ == State::kEstablished) { + state.socket_state = DcSctpSocketHandoverState::SocketState::kConnected; + tcb_->AddHandoverState(state); + send_queue_.AddHandoverState(state); + InternalClose(ErrorKind::kNoError, "handover"); + } + + return std::move(state); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h new file mode 100644 index 0000000000..f91eb3ead4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h @@ -0,0 +1,299 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_ +#define NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "api/sequence_checker.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/callback_deferrer.h" +#include "net/dcsctp/socket/packet_sender.h" +#include "net/dcsctp/socket/state_cookie.h" +#include "net/dcsctp/socket/transmission_control_block.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_error_counter.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "net/dcsctp/tx/rr_send_queue.h" + +namespace dcsctp { + +// DcSctpSocket represents a single SCTP socket, to be used over DTLS. +// +// Every dcSCTP is completely isolated from any other socket. +// +// This class manages all packet and chunk dispatching and mainly handles the +// connection sequences (connect, close, shutdown, etc) as well as managing +// the Transmission Control Block (tcb). +// +// This class is thread-compatible. +class DcSctpSocket : public DcSctpSocketInterface { + public: + // Instantiates a DcSctpSocket, which interacts with the world through the + // `callbacks` interface and is configured using `options`. + // + // For debugging, `log_prefix` will prefix all debug logs, and a + // `packet_observer` can be attached to e.g. dump sent and received packets. + DcSctpSocket(absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr<PacketObserver> packet_observer, + const DcSctpOptions& options); + + DcSctpSocket(const DcSctpSocket&) = delete; + DcSctpSocket& operator=(const DcSctpSocket&) = delete; + + // Implementation of `DcSctpSocketInterface`. + void ReceivePacket(rtc::ArrayView<const uint8_t> data) override; + void HandleTimeout(TimeoutID timeout_id) override; + void Connect() override; + void RestoreFromState(const DcSctpSocketHandoverState& state) override; + void Shutdown() override; + void Close() override; + SendStatus Send(DcSctpMessage message, + const SendOptions& send_options) override; + ResetStreamsStatus ResetStreams( + rtc::ArrayView<const StreamID> outgoing_streams) override; + SocketState state() const override; + const DcSctpOptions& options() const override { return options_; } + void SetMaxMessageSize(size_t max_message_size) override; + void SetStreamPriority(StreamID stream_id, StreamPriority priority) override; + StreamPriority GetStreamPriority(StreamID stream_id) const override; + size_t buffered_amount(StreamID stream_id) const override; + size_t buffered_amount_low_threshold(StreamID stream_id) const override; + void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + absl::optional<Metrics> GetMetrics() const override; + HandoverReadinessStatus GetHandoverReadiness() const override; + absl::optional<DcSctpSocketHandoverState> GetHandoverStateAndClose() override; + SctpImplementation peer_implementation() const override { + return metrics_.peer_implementation; + } + // Returns this socket's verification tag, or zero if not yet connected. + VerificationTag verification_tag() const { + return tcb_ != nullptr ? tcb_->my_verification_tag() : VerificationTag(0); + } + + private: + // Parameter proposals valid during the connect phase. + struct ConnectParameters { + TSN initial_tsn = TSN(0); + VerificationTag verification_tag = VerificationTag(0); + }; + + // Detailed state (separate from SocketState, which is the public state). + enum class State { + kClosed, + kCookieWait, + // TCB valid in these: + kCookieEchoed, + kEstablished, + kShutdownPending, + kShutdownSent, + kShutdownReceived, + kShutdownAckSent, + }; + + // Returns the log prefix used for debug logging. + std::string log_prefix() const; + + bool IsConsistent() const; + static constexpr absl::string_view ToString(DcSctpSocket::State state); + + void CreateTransmissionControlBlock(const Capabilities& capabilities, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag); + + // Changes the socket state, given a `reason` (for debugging/logging). + void SetState(State state, absl::string_view reason); + // Fills in `connect_params` with random verification tag and initial TSN. + void MakeConnectionParameters(); + // Closes the association. Note that the TCB will not be valid past this call. + void InternalClose(ErrorKind error, absl::string_view message); + // Closes the association, because of too many retransmission errors. + void CloseConnectionBecauseOfTooManyTransmissionErrors(); + // Timer expiration handlers + absl::optional<DurationMs> OnInitTimerExpiry(); + absl::optional<DurationMs> OnCookieTimerExpiry(); + absl::optional<DurationMs> OnShutdownTimerExpiry(); + void OnSentPacket(rtc::ArrayView<const uint8_t> packet, + SendPacketStatus status); + // Sends SHUTDOWN or SHUTDOWN-ACK if the socket is shutting down and if all + // outstanding data has been acknowledged. + void MaybeSendShutdownOrAck(); + // If the socket is shutting down, responds SHUTDOWN to any incoming DATA. + void MaybeSendShutdownOnPacketReceived(const SctpPacket& packet); + // If there are streams pending to be reset, send a request to reset them. + void MaybeSendResetStreamsRequest(); + // Sends a INIT chunk. + void SendInit(); + // Sends a SHUTDOWN chunk. + void SendShutdown(); + // Sends a SHUTDOWN-ACK chunk. + void SendShutdownAck(); + // Validates the SCTP packet, as a whole - not the validity of individual + // chunks within it, as that's done in the different chunk handlers. + bool ValidatePacket(const SctpPacket& packet); + // Parses `payload`, which is a serialized packet that is just going to be + // sent and prints all chunks. + void DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload); + // Called whenever data has been received, or the cumulative acknowledgment + // TSN has moved, that may result in delivering messages. + void MaybeDeliverMessages(); + // Returns true if there is a TCB, and false otherwise (and reports an error). + bool ValidateHasTCB(); + + // Returns true if the parsing of a chunk of type `T` succeeded. If it didn't, + // it reports an error and returns false. + template <class T> + bool ValidateParseSuccess(const absl::optional<T>& c) { + if (c.has_value()) { + return true; + } + + ReportFailedToParseChunk(T::kType); + return false; + } + + // Reports failing to have parsed a chunk with the provided `chunk_type`. + void ReportFailedToParseChunk(int chunk_type); + // Called when unknown chunks are received. May report an error. + bool HandleUnrecognizedChunk(const SctpPacket::ChunkDescriptor& descriptor); + + // Will dispatch more specific chunk handlers. + bool Dispatch(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming DATA chunks. + void HandleData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming I-DATA chunks. + void HandleIData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Common handler for DATA and I-DATA chunks. + void HandleDataCommon(AnyDataChunk& chunk); + // Handles incoming INIT chunks. + void HandleInit(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming INIT-ACK chunks. + void HandleInitAck(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming SACK chunks. + void HandleSack(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming HEARTBEAT chunks. + void HandleHeartbeatRequest(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming HEARTBEAT-ACK chunks. + void HandleHeartbeatAck(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming ABORT chunks. + void HandleAbort(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming ERROR chunks. + void HandleError(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming COOKIE-ECHO chunks. + void HandleCookieEcho(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles receiving COOKIE-ECHO when there already is a TCB. The return value + // indicates if the processing should continue. + bool HandleCookieEchoWithTCB(const CommonHeader& header, + const StateCookie& cookie); + // Handles incoming COOKIE-ACK chunks. + void HandleCookieAck(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming SHUTDOWN chunks. + void HandleShutdown(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming SHUTDOWN-ACK chunks. + void HandleShutdownAck(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming FORWARD-TSN chunks. + void HandleForwardTsn(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming I-FORWARD-TSN chunks. + void HandleIForwardTsn(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming RE-CONFIG chunks. + void HandleReconfig(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Common handled for FORWARD-TSN/I-FORWARD-TSN. + void HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk); + // Handles incoming SHUTDOWN-COMPLETE chunks + void HandleShutdownComplete(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + + const std::string log_prefix_; + const std::unique_ptr<PacketObserver> packet_observer_; + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_; + Metrics metrics_; + DcSctpOptions options_; + + // Enqueues callbacks and dispatches them just before returning to the caller. + CallbackDeferrer callbacks_; + + TimerManager timer_manager_; + const std::unique_ptr<Timer> t1_init_; + const std::unique_ptr<Timer> t1_cookie_; + const std::unique_ptr<Timer> t2_shutdown_; + + // Packets that failed to be sent, but should be retried. + PacketSender packet_sender_; + + // The actual SendQueue implementation. As data can be sent on a socket before + // the connection is established, this component is not in the TCB. + RRSendQueue send_queue_; + + // Contains verification tag and initial TSN between having sent the INIT + // until the connection is established (there is no TCB at this point). + ConnectParameters connect_params_; + // The socket state. + State state_ = State::kClosed; + // If the connection is established, contains a transmission control block. + std::unique_ptr<TransmissionControlBlock> tcb_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc new file mode 100644 index 0000000000..f097bfa095 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc @@ -0,0 +1,518 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include <cstdint> +#include <deque> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/pending_task_safety_flag.h" +#include "api/task_queue/task_queue_base.h" +#include "api/test/create_network_emulation_manager.h" +#include "api/test/network_emulation_manager.h" +#include "api/units/time_delta.h" +#include "call/simulated_network.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/socket/dcsctp_socket.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/timer/task_queue_timeout.h" +#include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/gunit.h" +#include "rtc_base/logging.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/strings/string_format.h" +#include "rtc_base/time_utils.h" +#include "test/gmock.h" + +#if !defined(WEBRTC_ANDROID) && defined(NDEBUG) && \ + !defined(THREAD_SANITIZER) && !defined(MEMORY_SANITIZER) +#define DCSCTP_NDEBUG_TEST(t) t +#else +// In debug mode, and when MSAN or TSAN sanitizers are enabled, these tests are +// too expensive to run due to extensive consistency checks that iterate on all +// outstanding chunks. Same with low-end Android devices, which have +// difficulties with these tests. +#define DCSCTP_NDEBUG_TEST(t) DISABLED_##t +#endif + +namespace dcsctp { +namespace { +using ::testing::AllOf; +using ::testing::Ge; +using ::testing::Le; +using ::testing::SizeIs; + +constexpr StreamID kStreamId(1); +constexpr PPID kPpid(53); +constexpr size_t kSmallPayloadSize = 10; +constexpr size_t kLargePayloadSize = 10000; +constexpr size_t kHugePayloadSize = 262144; +constexpr size_t kBufferedAmountLowThreshold = kLargePayloadSize * 2; +constexpr webrtc::TimeDelta kPrintBandwidthDuration = + webrtc::TimeDelta::Seconds(1); +constexpr webrtc::TimeDelta kBenchmarkRuntime(webrtc::TimeDelta::Seconds(10)); +constexpr webrtc::TimeDelta kAWhile(webrtc::TimeDelta::Seconds(1)); + +inline int GetUniqueSeed() { + static int seed = 0; + return ++seed; +} + +DcSctpOptions MakeOptionsForTest() { + DcSctpOptions options; + + // Throughput numbers are affected by the MTU. Ensure it's constant. + options.mtu = 1200; + + // By disabling the heartbeat interval, there will no timers at all running + // when the socket is idle, which makes it easy to just continue the test + // until there are no more scheduled tasks. Note that it _will_ run for longer + // than necessary as timers aren't cancelled when they are stopped (as that's + // not supported), but it's still simulated time and passes quickly. + options.heartbeat_interval = DurationMs(0); + return options; +} + +// When doing throughput tests, knowing what each actor should do. +enum class ActorMode { + kAtRest, + kThroughputSender, + kThroughputReceiver, + kLimitedRetransmissionSender, +}; + +// An abstraction around EmulatedEndpoint, representing a bound socket that +// will send its packet to a given destination. +class BoundSocket : public webrtc::EmulatedNetworkReceiverInterface { + public: + void Bind(webrtc::EmulatedEndpoint* endpoint) { + endpoint_ = endpoint; + uint16_t port = endpoint->BindReceiver(0, this).value(); + source_address_ = + rtc::SocketAddress(endpoint_->GetPeerLocalAddress(), port); + } + + void SetDestination(const BoundSocket& socket) { + dest_address_ = socket.source_address_; + } + + void SetReceiver(std::function<void(rtc::CopyOnWriteBuffer)> receiver) { + receiver_ = std::move(receiver); + } + + void SendPacket(rtc::ArrayView<const uint8_t> data) { + endpoint_->SendPacket(source_address_, dest_address_, + rtc::CopyOnWriteBuffer(data.data(), data.size())); + } + + private: + // Implementation of `webrtc::EmulatedNetworkReceiverInterface`. + void OnPacketReceived(webrtc::EmulatedIpPacket packet) override { + receiver_(std::move(packet.data)); + } + + std::function<void(rtc::CopyOnWriteBuffer)> receiver_; + webrtc::EmulatedEndpoint* endpoint_ = nullptr; + rtc::SocketAddress source_address_; + rtc::SocketAddress dest_address_; +}; + +// Sends at a constant rate but with random packet sizes. +class SctpActor : public DcSctpSocketCallbacks { + public: + SctpActor(absl::string_view name, + BoundSocket& emulated_socket, + const DcSctpOptions& sctp_options) + : log_prefix_(std::string(name) + ": "), + thread_(rtc::Thread::Current()), + emulated_socket_(emulated_socket), + timeout_factory_( + *thread_, + [this]() { return TimeMillis(); }, + [this](dcsctp::TimeoutID timeout_id) { + sctp_socket_.HandleTimeout(timeout_id); + }), + random_(GetUniqueSeed()), + sctp_socket_(name, *this, nullptr, sctp_options), + last_bandwidth_printout_(TimeMs(TimeMillis())) { + emulated_socket.SetReceiver([this](rtc::CopyOnWriteBuffer buf) { + // The receiver will be executed on the NetworkEmulation task queue, but + // the dcSCTP socket is owned by `thread_` and is not thread-safe. + thread_->PostTask([this, buf] { this->sctp_socket_.ReceivePacket(buf); }); + }); + } + + void PrintBandwidth() { + TimeMs now = TimeMillis(); + DurationMs duration = now - last_bandwidth_printout_; + + double bitrate_mbps = + static_cast<double>(received_bytes_ * 8) / *duration / 1000; + RTC_LOG(LS_INFO) << log_prefix() + << rtc::StringFormat("Received %0.2f Mbps", bitrate_mbps); + + received_bitrate_mbps_.push_back(bitrate_mbps); + received_bytes_ = 0; + last_bandwidth_printout_ = now; + // Print again in a second. + if (mode_ == ActorMode::kThroughputReceiver) { + thread_->PostDelayedTask( + SafeTask(safety_.flag(), [this] { PrintBandwidth(); }), + kPrintBandwidthDuration); + } + } + + void SendPacket(rtc::ArrayView<const uint8_t> data) override { + emulated_socket_.SendPacket(data); + } + + std::unique_ptr<Timeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) override { + return timeout_factory_.CreateTimeout(precision); + } + + TimeMs TimeMillis() override { return TimeMs(rtc::TimeMillis()); } + + uint32_t GetRandomInt(uint32_t low, uint32_t high) override { + return random_.Rand(low, high); + } + + void OnMessageReceived(DcSctpMessage message) override { + received_bytes_ += message.payload().size(); + last_received_message_ = std::move(message); + } + + void OnError(ErrorKind error, absl::string_view message) override { + RTC_LOG(LS_WARNING) << log_prefix() << "Socket error: " << ToString(error) + << "; " << message; + } + + void OnAborted(ErrorKind error, absl::string_view message) override { + RTC_LOG(LS_ERROR) << log_prefix() << "Socket abort: " << ToString(error) + << "; " << message; + } + + void OnConnected() override {} + + void OnClosed() override {} + + void OnConnectionRestarted() override {} + + void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason) override {} + + void OnStreamsResetPerformed( + rtc::ArrayView<const StreamID> outgoing_streams) override {} + + void OnIncomingStreamsReset( + rtc::ArrayView<const StreamID> incoming_streams) override {} + + void NotifyOutgoingMessageBufferEmpty() override {} + + void OnBufferedAmountLow(StreamID stream_id) override { + if (mode_ == ActorMode::kThroughputSender) { + std::vector<uint8_t> payload(kHugePayloadSize); + sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)), + SendOptions()); + + } else if (mode_ == ActorMode::kLimitedRetransmissionSender) { + while (sctp_socket_.buffered_amount(kStreamId) < + kBufferedAmountLowThreshold * 2) { + SendOptions send_options; + send_options.max_retransmissions = 0; + sctp_socket_.Send( + DcSctpMessage(kStreamId, kPpid, + std::vector<uint8_t>(kLargePayloadSize)), + send_options); + + send_options.max_retransmissions = absl::nullopt; + sctp_socket_.Send( + DcSctpMessage(kStreamId, kPpid, + std::vector<uint8_t>(kSmallPayloadSize)), + send_options); + } + } + } + + absl::optional<DcSctpMessage> ConsumeReceivedMessage() { + if (!last_received_message_.has_value()) { + return absl::nullopt; + } + DcSctpMessage ret = *std::move(last_received_message_); + last_received_message_ = absl::nullopt; + return ret; + } + + DcSctpSocket& sctp_socket() { return sctp_socket_; } + + void SetActorMode(ActorMode mode) { + mode_ = mode; + if (mode_ == ActorMode::kThroughputSender) { + sctp_socket_.SetBufferedAmountLowThreshold(kStreamId, + kBufferedAmountLowThreshold); + std::vector<uint8_t> payload(kHugePayloadSize); + sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)), + SendOptions()); + + } else if (mode_ == ActorMode::kLimitedRetransmissionSender) { + sctp_socket_.SetBufferedAmountLowThreshold(kStreamId, + kBufferedAmountLowThreshold); + std::vector<uint8_t> payload(kHugePayloadSize); + sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)), + SendOptions()); + + } else if (mode == ActorMode::kThroughputReceiver) { + thread_->PostDelayedTask( + SafeTask(safety_.flag(), [this] { PrintBandwidth(); }), + kPrintBandwidthDuration); + } + } + + // Returns the average bitrate, stripping the first `remove_first_n` that + // represent the time it took to ramp up the congestion control algorithm. + double avg_received_bitrate_mbps(size_t remove_first_n = 3) const { + std::vector<double> bitrates = received_bitrate_mbps_; + bitrates.erase(bitrates.begin(), bitrates.begin() + remove_first_n); + + double sum = 0; + for (double bitrate : bitrates) { + sum += bitrate; + } + + return sum / bitrates.size(); + } + + private: + std::string log_prefix() const { + rtc::StringBuilder sb; + sb << log_prefix_; + sb << rtc::TimeMillis(); + sb << ": "; + return sb.Release(); + } + + ActorMode mode_ = ActorMode::kAtRest; + const std::string log_prefix_; + rtc::Thread* thread_; + BoundSocket& emulated_socket_; + TaskQueueTimeoutFactory timeout_factory_; + webrtc::Random random_; + DcSctpSocket sctp_socket_; + size_t received_bytes_ = 0; + absl::optional<DcSctpMessage> last_received_message_; + TimeMs last_bandwidth_printout_; + // Per-second received bitrates, in Mbps + std::vector<double> received_bitrate_mbps_; + webrtc::ScopedTaskSafety safety_; +}; + +class DcSctpSocketNetworkTest : public testing::Test { + protected: + DcSctpSocketNetworkTest() + : options_(MakeOptionsForTest()), + emulation_(webrtc::CreateNetworkEmulationManager( + webrtc::TimeMode::kSimulated)) {} + + void MakeNetwork(const webrtc::BuiltInNetworkBehaviorConfig& config) { + webrtc::EmulatedEndpoint* endpoint_a = + emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig()); + webrtc::EmulatedEndpoint* endpoint_z = + emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig()); + + webrtc::EmulatedNetworkNode* node1 = emulation_->CreateEmulatedNode(config); + webrtc::EmulatedNetworkNode* node2 = emulation_->CreateEmulatedNode(config); + + emulation_->CreateRoute(endpoint_a, {node1}, endpoint_z); + emulation_->CreateRoute(endpoint_z, {node2}, endpoint_a); + + emulated_socket_a_.Bind(endpoint_a); + emulated_socket_z_.Bind(endpoint_z); + + emulated_socket_a_.SetDestination(emulated_socket_z_); + emulated_socket_z_.SetDestination(emulated_socket_a_); + } + + void Sleep(webrtc::TimeDelta duration) { + // Sleep in one-millisecond increments, to let timers expire when expected. + for (int i = 0; i < duration.ms(); ++i) { + emulation_->time_controller()->AdvanceTime(webrtc::TimeDelta::Millis(1)); + } + } + + DcSctpOptions options_; + std::unique_ptr<webrtc::NetworkEmulationManager> emulation_; + BoundSocket emulated_socket_a_; + BoundSocket emulated_socket_z_; +}; + +TEST_F(DcSctpSocketNetworkTest, CanConnectAndShutdown) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + EXPECT_THAT(sender.sctp_socket().state(), SocketState::kClosed); + + sender.sctp_socket().Connect(); + Sleep(kAWhile); + EXPECT_THAT(sender.sctp_socket().state(), SocketState::kConnected); + + sender.sctp_socket().Shutdown(); + Sleep(kAWhile); + EXPECT_THAT(sender.sctp_socket().state(), SocketState::kClosed); +} + +TEST_F(DcSctpSocketNetworkTest, CanSendLargeMessage) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + pipe_config.queue_delay_ms = 30; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + constexpr size_t kPayloadSize = 100 * 1024; + + std::vector<uint8_t> payload(kPayloadSize); + sender.sctp_socket().Send(DcSctpMessage(kStreamId, kPpid, payload), + SendOptions()); + + Sleep(kAWhile); + + ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage message, + receiver.ConsumeReceivedMessage()); + + EXPECT_THAT(message.payload(), SizeIs(kPayloadSize)); + + sender.sctp_socket().Shutdown(); + Sleep(kAWhile); +} + +TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithLowBandwidth) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + pipe_config.queue_delay_ms = 30; + pipe_config.link_capacity_kbps = 1000; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + sender.SetActorMode(ActorMode::kThroughputSender); + receiver.SetActorMode(ActorMode::kThroughputReceiver); + + Sleep(kBenchmarkRuntime); + sender.SetActorMode(ActorMode::kAtRest); + receiver.SetActorMode(ActorMode::kAtRest); + + Sleep(kAWhile); + + sender.sctp_socket().Shutdown(); + + Sleep(kAWhile); + + // Verify that the bitrates are in the range of 0.5-1.0 Mbps. + double bitrate = receiver.avg_received_bitrate_mbps(); + EXPECT_THAT(bitrate, AllOf(Ge(0.5), Le(1.0))); +} + +TEST_F(DcSctpSocketNetworkTest, + DCSCTP_NDEBUG_TEST(CanSendMessagesReliablyWithMediumBandwidth)) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + pipe_config.queue_delay_ms = 30; + pipe_config.link_capacity_kbps = 18000; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + sender.SetActorMode(ActorMode::kThroughputSender); + receiver.SetActorMode(ActorMode::kThroughputReceiver); + + Sleep(kBenchmarkRuntime); + sender.SetActorMode(ActorMode::kAtRest); + receiver.SetActorMode(ActorMode::kAtRest); + + Sleep(kAWhile); + + sender.sctp_socket().Shutdown(); + + Sleep(kAWhile); + + // Verify that the bitrates are in the range of 16-18 Mbps. + double bitrate = receiver.avg_received_bitrate_mbps(); + EXPECT_THAT(bitrate, AllOf(Ge(16), Le(18))); +} + +TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithMuchPacketLoss) { + webrtc::BuiltInNetworkBehaviorConfig config; + config.queue_delay_ms = 30; + config.loss_percent = 1; + MakeNetwork(config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + sender.SetActorMode(ActorMode::kThroughputSender); + receiver.SetActorMode(ActorMode::kThroughputReceiver); + + Sleep(kBenchmarkRuntime); + sender.SetActorMode(ActorMode::kAtRest); + receiver.SetActorMode(ActorMode::kAtRest); + + Sleep(kAWhile); + + sender.sctp_socket().Shutdown(); + + Sleep(kAWhile); + + // TCP calculator gives: 1200 MTU, 60ms RTT and 1% packet loss -> 1.6Mbps. + // This test is doing slightly better (doesn't have any additional header + // overhead etc). Verify that the bitrates are in the range of 1.5-2.5 Mbps. + double bitrate = receiver.avg_received_bitrate_mbps(); + EXPECT_THAT(bitrate, AllOf(Ge(1.5), Le(2.5))); +} + +TEST_F(DcSctpSocketNetworkTest, DCSCTP_NDEBUG_TEST(HasHighBandwidth)) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + pipe_config.queue_delay_ms = 30; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + sender.SetActorMode(ActorMode::kThroughputSender); + receiver.SetActorMode(ActorMode::kThroughputReceiver); + + Sleep(kBenchmarkRuntime); + + sender.SetActorMode(ActorMode::kAtRest); + receiver.SetActorMode(ActorMode::kAtRest); + Sleep(kAWhile); + + sender.sctp_socket().Shutdown(); + Sleep(kAWhile); + + // Verify that the bitrate is in the range of 540-640 Mbps + double bitrate = receiver.avg_received_bitrate_mbps(); + EXPECT_THAT(bitrate, AllOf(Ge(520), Le(640))); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc new file mode 100644 index 0000000000..13202846ac --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc @@ -0,0 +1,3058 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/dcsctp_socket.h" + +#include <algorithm> +#include <cstdint> +#include <deque> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/text_pcap_packet_observer.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +ABSL_FLAG(bool, dcsctp_capture_packets, false, "Print packet capture."); + +namespace dcsctp { +namespace { +using ::testing::_; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::Property; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr SendOptions kSendOptions; +constexpr size_t kLargeMessageSize = DcSctpOptions::kMaxSafeMTUSize * 20; +constexpr size_t kSmallMessageSize = 10; +constexpr int kMaxBurstPackets = 4; +constexpr DcSctpOptions kDefaultOptions; + +MATCHER_P(HasChunks, chunks, "") { + absl::optional<SctpPacket> packet = SctpPacket::Parse(arg, kDefaultOptions); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + return ExplainMatchResult(chunks, packet->descriptors(), result_listener); +} + +MATCHER_P(IsChunkType, chunk_type, "") { + return ExplainMatchResult(chunk_type, arg.type, result_listener); +} + +MATCHER_P(IsDataChunk, properties, "") { + if (arg.type != DataChunk::kType) { + *result_listener << "the chunk is not a data chunk"; + return false; + } + + absl::optional<DataChunk> chunk = DataChunk::Parse(arg.data); + if (!chunk.has_value()) { + *result_listener << "The chunk didn't parse as a data chunk"; + return false; + } + + return ExplainMatchResult(properties, *chunk, result_listener); +} + +MATCHER_P(IsSack, properties, "") { + if (arg.type != SackChunk::kType) { + *result_listener << "the chunk is not a sack chunk"; + return false; + } + + absl::optional<SackChunk> chunk = SackChunk::Parse(arg.data); + if (!chunk.has_value()) { + *result_listener << "The chunk didn't parse as a sack chunk"; + return false; + } + + return ExplainMatchResult(properties, *chunk, result_listener); +} + +MATCHER_P(IsReConfig, properties, "") { + if (arg.type != ReConfigChunk::kType) { + *result_listener << "the chunk is not a re-config chunk"; + return false; + } + + absl::optional<ReConfigChunk> chunk = ReConfigChunk::Parse(arg.data); + if (!chunk.has_value()) { + *result_listener << "The chunk didn't parse as a re-config chunk"; + return false; + } + + return ExplainMatchResult(properties, *chunk, result_listener); +} + +MATCHER_P(IsHeartbeatAck, properties, "") { + if (arg.type != HeartbeatAckChunk::kType) { + *result_listener << "the chunk is not a HeartbeatAckChunk"; + return false; + } + + absl::optional<HeartbeatAckChunk> chunk = HeartbeatAckChunk::Parse(arg.data); + if (!chunk.has_value()) { + *result_listener << "The chunk didn't parse as a HeartbeatAckChunk"; + return false; + } + + return ExplainMatchResult(properties, *chunk, result_listener); +} + +MATCHER_P(IsHeartbeatRequest, properties, "") { + if (arg.type != HeartbeatRequestChunk::kType) { + *result_listener << "the chunk is not a HeartbeatRequestChunk"; + return false; + } + + absl::optional<HeartbeatRequestChunk> chunk = + HeartbeatRequestChunk::Parse(arg.data); + if (!chunk.has_value()) { + *result_listener << "The chunk didn't parse as a HeartbeatRequestChunk"; + return false; + } + + return ExplainMatchResult(properties, *chunk, result_listener); +} + +MATCHER_P(HasParameters, parameters, "") { + return ExplainMatchResult(parameters, arg.parameters().descriptors(), + result_listener); +} + +MATCHER_P(IsOutgoingResetRequest, properties, "") { + if (arg.type != OutgoingSSNResetRequestParameter::kType) { + *result_listener + << "the parameter is not an OutgoingSSNResetRequestParameter"; + return false; + } + + absl::optional<OutgoingSSNResetRequestParameter> parameter = + OutgoingSSNResetRequestParameter::Parse(arg.data); + if (!parameter.has_value()) { + *result_listener + << "The parameter didn't parse as an OutgoingSSNResetRequestParameter"; + return false; + } + + return ExplainMatchResult(properties, *parameter, result_listener); +} + +MATCHER_P(IsReconfigurationResponse, properties, "") { + if (arg.type != ReconfigurationResponseParameter::kType) { + *result_listener + << "the parameter is not an ReconfigurationResponseParameter"; + return false; + } + + absl::optional<ReconfigurationResponseParameter> parameter = + ReconfigurationResponseParameter::Parse(arg.data); + if (!parameter.has_value()) { + *result_listener + << "The parameter didn't parse as an ReconfigurationResponseParameter"; + return false; + } + + return ExplainMatchResult(properties, *parameter, result_listener); +} + +TSN AddTo(TSN tsn, int delta) { + return TSN(*tsn + delta); +} + +DcSctpOptions FixupOptions(DcSctpOptions options = {}) { + DcSctpOptions fixup = options; + // To make the interval more predictable in tests. + fixup.heartbeat_interval_include_rtt = false; + fixup.max_burst = kMaxBurstPackets; + return fixup; +} + +std::unique_ptr<PacketObserver> GetPacketObserver(absl::string_view name) { + if (absl::GetFlag(FLAGS_dcsctp_capture_packets)) { + return std::make_unique<TextPcapPacketObserver>(name); + } + return nullptr; +} + +struct SocketUnderTest { + explicit SocketUnderTest(absl::string_view name, + const DcSctpOptions& opts = kDefaultOptions) + : options(FixupOptions(opts)), + cb(name), + socket(name, cb, GetPacketObserver(name), options) {} + + const DcSctpOptions options; + testing::NiceMock<MockDcSctpSocketCallbacks> cb; + DcSctpSocket socket; +}; + +void ExchangeMessages(SocketUnderTest& a, SocketUnderTest& z) { + bool delivered_packet = false; + do { + delivered_packet = false; + std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket(); + if (!packet_from_a.empty()) { + delivered_packet = true; + z.socket.ReceivePacket(std::move(packet_from_a)); + } + std::vector<uint8_t> packet_from_z = z.cb.ConsumeSentPacket(); + if (!packet_from_z.empty()) { + delivered_packet = true; + a.socket.ReceivePacket(std::move(packet_from_z)); + } + } while (delivered_packet); +} + +void RunTimers(SocketUnderTest& s) { + for (;;) { + absl::optional<TimeoutID> timeout_id = s.cb.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + s.socket.HandleTimeout(*timeout_id); + } +} + +void AdvanceTime(SocketUnderTest& a, SocketUnderTest& z, DurationMs duration) { + a.cb.AdvanceTime(duration); + z.cb.AdvanceTime(duration); + + RunTimers(a); + RunTimers(z); +} + +// Exchanges messages between `a` and `z`, advancing time until there are no +// more pending timers, or until `max_timeout` is reached. +void ExchangeMessagesAndAdvanceTime( + SocketUnderTest& a, + SocketUnderTest& z, + DurationMs max_timeout = DurationMs(10000)) { + TimeMs time_started = a.cb.TimeMillis(); + while (a.cb.TimeMillis() - time_started < max_timeout) { + ExchangeMessages(a, z); + + DurationMs time_to_next_timeout = + std::min(a.cb.GetTimeToNextTimeout(), z.cb.GetTimeToNextTimeout()); + if (time_to_next_timeout == DurationMs::InfiniteDuration()) { + // No more pending timer. + return; + } + AdvanceTime(a, z, time_to_next_timeout); + } +} + +// Calls Connect() on `sock_a_` and make the connection established. +void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) { + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + + a.socket.Connect(); + // Z reads INIT, INIT_ACK, COOKIE_ECHO, COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +std::unique_ptr<SocketUnderTest> HandoverSocket( + std::unique_ptr<SocketUnderTest> sut) { + EXPECT_EQ(sut->socket.GetHandoverReadiness(), HandoverReadinessStatus()); + + bool is_closed = sut->socket.state() == SocketState::kClosed; + if (!is_closed) { + EXPECT_CALL(sut->cb, OnClosed).Times(1); + } + absl::optional<DcSctpSocketHandoverState> handover_state = + sut->socket.GetHandoverStateAndClose(); + EXPECT_TRUE(handover_state.has_value()); + g_handover_state_transformer_for_test(&*handover_state); + + auto handover_socket = std::make_unique<SocketUnderTest>("H", sut->options); + if (!is_closed) { + EXPECT_CALL(handover_socket->cb, OnConnected).Times(1); + } + handover_socket->socket.RestoreFromState(*handover_state); + return handover_socket; +} + +std::vector<uint32_t> GetReceivedMessagePpids(SocketUnderTest& z) { + std::vector<uint32_t> ppids; + for (;;) { + absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage(); + if (!msg.has_value()) { + break; + } + ppids.push_back(*msg->ppid()); + } + return ppids; +} + +// Test parameter that controls whether to perform handovers during the test. A +// test can have multiple points where it conditionally hands over socket Z. +// Either socket Z will be handed over at all those points or handed over never. +enum class HandoverMode { + kNoHandover, + kPerformHandovers, +}; + +class DcSctpSocketParametrizedTest + : public ::testing::Test, + public ::testing::WithParamInterface<HandoverMode> { + protected: + // Trigger handover for `sut` depending on the current test param. + std::unique_ptr<SocketUnderTest> MaybeHandoverSocket( + std::unique_ptr<SocketUnderTest> sut) { + if (GetParam() == HandoverMode::kPerformHandovers) { + return HandoverSocket(std::move(sut)); + } + return sut; + } + + // Trigger handover for socket Z depending on the current test param. + // Then checks message passing to verify the handed over socket is functional. + void MaybeHandoverSocketAndSendMessage(SocketUnderTest& a, + std::unique_ptr<SocketUnderTest> z) { + if (GetParam() == HandoverMode::kPerformHandovers) { + z = HandoverSocket(std::move(z)); + } + + ExchangeMessages(a, *z); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + ExchangeMessages(a, *z); + + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + } +}; + +INSTANTIATE_TEST_SUITE_P(Handovers, + DcSctpSocketParametrizedTest, + testing::Values(HandoverMode::kNoHandover, + HandoverMode::kPerformHandovers), + [](const auto& test_info) { + return test_info.param == + HandoverMode::kPerformHandovers + ? "WithHandovers" + : "NoHandover"; + }); + +TEST(DcSctpSocketTest, EstablishConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + + a.socket.Connect(); + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, EstablishConnectionWithSetupCollision) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + a.socket.Connect(); + z.socket.Connect(); + + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, ShuttingDownWhileEstablishingConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(0); + EXPECT_CALL(z.cb, OnConnected).Times(1); + a.socket.Connect(); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Drop COOKIE_ACK, just to more easily verify shutdown protocol. + z.cb.ConsumeSentPacket(); + + // As Socket A has received INIT_ACK, it has a TCB and is connected, while + // Socket Z needs to receive COOKIE_ECHO to get there. Socket A still has + // timers running at this point. + EXPECT_EQ(a.socket.state(), SocketState::kConnecting); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); + + // Socket A is now shut down, which should make it stop those timers. + a.socket.Shutdown(); + + EXPECT_CALL(a.cb, OnClosed).Times(1); + EXPECT_CALL(z.cb, OnClosed).Times(1); + + // Z reads SHUTDOWN, produces SHUTDOWN_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads SHUTDOWN_COMPLETE. + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + EXPECT_TRUE(a.cb.ConsumeSentPacket().empty()); + EXPECT_TRUE(z.cb.ConsumeSentPacket().empty()); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + EXPECT_EQ(z.socket.state(), SocketState::kClosed); +} + +TEST(DcSctpSocketTest, EstablishSimultaneousConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + a.socket.Connect(); + + // INIT isn't received by Z, as it wasn't ready yet. + a.cb.ConsumeSentPacket(); + + z.socket.Connect(); + + // A reads INIT, produces INIT_ACK + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // Z reads INIT_ACK, sends COOKIE_ECHO + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // A reads COOKIE_ECHO - establishes connection. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + + // Proceed with the remaining packets. + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, EstablishConnectionLostCookieAck) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + + a.socket.Connect(); + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // COOKIE_ACK is lost. + z.cb.ConsumeSentPacket(); + + EXPECT_EQ(a.socket.state(), SocketState::kConnecting); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); + + // This will make A re-send the COOKIE_ECHO + AdvanceTime(a, z, DurationMs(a.options.t1_cookie_timeout)); + + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, ResendInitAndEstablishConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + // INIT is never received by Z. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(InitChunk::kType)))); + + AdvanceTime(a, z, a.options.t1_init_timeout); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, ResendingInitTooManyTimesAborts) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + + // INIT is never received by Z. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(InitChunk::kType)))); + + for (int i = 0; i < *a.options.max_init_retransmits; ++i) { + AdvanceTime(a, z, a.options.t1_init_timeout * (1 << i)); + + // INIT is resent + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(InitChunk::kType)))); + } + + // Another timeout, after the max init retransmits. + EXPECT_CALL(a.cb, OnAborted).Times(1); + AdvanceTime( + a, z, a.options.t1_init_timeout * (1 << *a.options.max_init_retransmits)); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); +} + +TEST(DcSctpSocketTest, ResendCookieEchoAndEstablishConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // COOKIE_ECHO is never received by Z. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType)))); + + AdvanceTime(a, z, a.options.t1_init_timeout); + + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // COOKIE_ECHO is never received by Z. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType)))); + + for (int i = 0; i < *a.options.max_init_retransmits; ++i) { + AdvanceTime(a, z, a.options.t1_cookie_timeout * (1 << i)); + + // COOKIE_ECHO is resent + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType)))); + } + + // Another timeout, after the max init retransmits. + EXPECT_CALL(a.cb, OnAborted).Times(1); + AdvanceTime( + a, z, + a.options.t1_cookie_timeout * (1 << *a.options.max_init_retransmits)); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); +} + +TEST(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + a.socket.Connect(); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // COOKIE_ECHO is never received by Z. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType), + IsDataChunk(_)))); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + // There are DATA chunks in the sent packet (that was lost), which means that + // the T3-RTX timer is running, but as the socket is in kCookieEcho state, it + // will be T1-COOKIE that drives retransmissions, so when the T3-RTX expires, + // nothing should be retransmitted. + ASSERT_TRUE(a.options.rto_initial < a.options.t1_cookie_timeout); + AdvanceTime(a, z, a.options.rto_initial); + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + // When T1-COOKIE expires, both the COOKIE-ECHO and DATA should be present. + AdvanceTime(a, z, a.options.t1_cookie_timeout - a.options.rto_initial); + + // And this COOKIE-ECHO and DATA is also lost - never received by Z. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(CookieEchoChunk::kType), + IsDataChunk(_)))); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + // COOKIE_ECHO has exponential backoff. + AdvanceTime(a, z, a.options.t1_cookie_timeout * 2); + + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); + + ExchangeMessages(a, z); + EXPECT_THAT(z.cb.ConsumeReceivedMessage()->payload(), + SizeIs(kLargeMessageSize)); +} + +TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + RTC_LOG(LS_INFO) << "Shutting down"; + + EXPECT_CALL(z->cb, OnClosed).Times(1); + a.socket.Shutdown(); + // Z reads SHUTDOWN, produces SHUTDOWN_ACK + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + // Z reads SHUTDOWN_COMPLETE. + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + EXPECT_EQ(z->socket.state(), SocketState::kClosed); + + z = MaybeHandoverSocket(std::move(z)); + EXPECT_EQ(z->socket.state(), SocketState::kClosed); +} + +TEST(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + a.socket.Shutdown(); + // Drop first SHUTDOWN packet. + a.cb.ConsumeSentPacket(); + + EXPECT_EQ(a.socket.state(), SocketState::kShuttingDown); + + for (int i = 0; i < *a.options.max_retransmissions; ++i) { + AdvanceTime(a, z, DurationMs(a.options.rto_initial * (1 << i))); + + // Dropping every shutdown chunk. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(ShutdownChunk::kType)))); + EXPECT_TRUE(a.cb.ConsumeSentPacket().empty()); + } + // The last expiry, makes it abort the connection. + EXPECT_CALL(a.cb, OnAborted).Times(1); + AdvanceTime(a, z, + a.options.rto_initial * (1 << *a.options.max_retransmissions)); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsChunkType(AbortChunk::kType)))); + EXPECT_TRUE(a.cb.ConsumeSentPacket().empty()); +} + +TEST(DcSctpSocketTest, EstablishConnectionWhileSendingData) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); + + absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); +} + +TEST(DcSctpSocketTest, SendMessageAfterEstablished) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); +} + +TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.cb.ConsumeSentPacket(); + + RTC_LOG(LS_INFO) << "Advancing time"; + AdvanceTime(a, *z, a.options.rto_initial); + + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + std::vector<uint8_t> payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + // First DATA + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + + // Retransmit and handle the rest + ExchangeMessages(a, *z); + + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload)); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // Inject a HEARTBEAT chunk + SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions()); + uint8_t info[] = {1, 2, 3, 4}; + Parameters::Builder params_builder; + params_builder.Add(HeartbeatInfoParameter(info)); + b.Add(HeartbeatRequestChunk(params_builder.Build())); + a.socket.ReceivePacket(b.Build()); + + // HEARTBEAT_ACK is sent as a reply. Capture it. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsHeartbeatAck( + Property(&HeartbeatAckChunk::info, + Optional(Property(&HeartbeatInfoParameter::info, + ElementsAre(1, 2, 3, 4)))))))); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + AdvanceTime(a, *z, a.options.heartbeat_interval); + + std::vector<uint8_t> packet = a.cb.ConsumeSentPacket(); + // The info is a single 64-bit number. + EXPECT_THAT( + packet, + HasChunks(ElementsAre(IsHeartbeatRequest(Property( + &HeartbeatRequestChunk::info, + Optional(Property(&HeartbeatInfoParameter::info, SizeIs(8)))))))); + + // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back. + z->socket.ReceivePacket(packet); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + CloseConnectionAfterTooManyLostHeartbeats) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(z->cb, OnClosed).Times(1); + EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty()); + // Force-close socket Z so that it doesn't interfere from now on. + z->socket.Close(); + + DurationMs time_to_next_hearbeat = a.options.heartbeat_interval; + + for (int i = 0; i < *a.options.max_retransmissions; ++i) { + RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending..."; + AdvanceTime(a, *z, time_to_next_hearbeat); + + // Dropping every heartbeat. + ASSERT_HAS_VALUE_AND_ASSIGN( + SctpPacket hb_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket(), z->options)); + EXPECT_EQ(hb_packet.descriptors()[0].type, HeartbeatRequestChunk::kType); + + RTC_LOG(LS_INFO) << "Letting the heartbeat expire."; + AdvanceTime(a, *z, DurationMs(1000)); + + time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000); + } + + RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending..."; + AdvanceTime(a, *z, time_to_next_hearbeat); + + // Last heartbeat + EXPECT_THAT(a.cb.ConsumeSentPacket(), Not(IsEmpty())); + + EXPECT_CALL(a.cb, OnAborted).Times(1); + // Should suffice as exceeding RTO + AdvanceTime(a, *z, DurationMs(1000)); + + z = MaybeHandoverSocket(std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty()); + EXPECT_CALL(z->cb, OnClosed).Times(1); + // Force-close socket Z so that it doesn't interfere from now on. + z->socket.Close(); + + DurationMs time_to_next_hearbeat = a.options.heartbeat_interval; + + for (int i = 0; i < *a.options.max_retransmissions; ++i) { + AdvanceTime(a, *z, time_to_next_hearbeat); + + // Dropping every heartbeat. + a.cb.ConsumeSentPacket(); + + RTC_LOG(LS_INFO) << "Letting the heartbeat expire."; + AdvanceTime(a, *z, DurationMs(1000)); + + time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000); + } + + RTC_LOG(LS_INFO) << "Getting the last heartbeat - and acking it"; + AdvanceTime(a, *z, time_to_next_hearbeat); + + std::vector<uint8_t> hb_packet_raw = a.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, + SctpPacket::Parse(hb_packet_raw, z->options)); + ASSERT_THAT(hb_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk hb, + HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data)); + + SctpPacket::Builder b(a.socket.verification_tag(), a.options); + b.Add(HeartbeatAckChunk(std::move(hb).extract_parameters())); + a.socket.ReceivePacket(b.Build()); + + // Should suffice as exceeding RTO - which will not fire. + EXPECT_CALL(a.cb, OnAborted).Times(0); + AdvanceTime(a, *z, DurationMs(1000)); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + // Verify that we get new heartbeats again. + RTC_LOG(LS_INFO) << "Expecting a new heartbeat"; + AdvanceTime(a, *z, time_to_next_hearbeat); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsHeartbeatRequest(_)))); +} + +TEST_P(DcSctpSocketParametrizedTest, ResetStream) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {}); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Reset the outgoing stream. This will directly send a RE-CONFIG. + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + + // Receiving the packet will trigger a callback, indicating that A has + // reset its stream. It will also send a RE-CONFIG with a response. + EXPECT_CALL(z->cb, OnIncomingStreamsReset).Times(1); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // Receiving a response will trigger a callback. Streams are now reset. + EXPECT_CALL(a.cb, OnStreamsResetPerformed).Times(1); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + std::vector<uint8_t> payload(a.options.mtu - 100); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet1 = a.cb.ConsumeSentPacket(); + EXPECT_THAT( + packet1, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0)))))); + z->socket.ReceivePacket(packet1); + + auto packet2 = a.cb.ConsumeSentPacket(); + EXPECT_THAT( + packet2, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(1)))))); + z->socket.ReceivePacket(packet2); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + + absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(1)); + + // Reset the outgoing stream. This will directly send a RE-CONFIG. + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + // RE-CONFIG, req + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // RE-CONFIG, resp + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet3 = a.cb.ConsumeSentPacket(); + EXPECT_THAT( + packet3, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0)))))); + z->socket.ReceivePacket(packet3); + + auto packet4 = a.cb.ConsumeSentPacket(); + EXPECT_THAT( + packet4, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(1)))))); + z->socket.ReceivePacket(packet4); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + ResetStreamWillOnlyResetTheRequestedStreams) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + std::vector<uint8_t> payload(a.options.mtu - 100); + + // Send two ordered messages on SID 1 + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet1 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet1, HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::stream_id, StreamID(1)), + Property(&DataChunk::ssn, SSN(0))))))); + z->socket.ReceivePacket(packet1); + + auto packet2 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet2, HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::stream_id, StreamID(1)), + Property(&DataChunk::ssn, SSN(1))))))); + z->socket.ReceivePacket(packet2); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Do the same, for SID 3 + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + auto packet3 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet3, HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::stream_id, StreamID(3)), + Property(&DataChunk::ssn, SSN(0))))))); + z->socket.ReceivePacket(packet3); + auto packet4 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet4, HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::stream_id, StreamID(3)), + Property(&DataChunk::ssn, SSN(1))))))); + z->socket.ReceivePacket(packet4); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Receive all messages. + absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + + absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(1)); + + absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg3.has_value()); + EXPECT_EQ(msg3->stream_id(), StreamID(3)); + + absl::optional<DcSctpMessage> msg4 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg4.has_value()); + EXPECT_EQ(msg4->stream_id(), StreamID(3)); + + // Reset SID 1. This will directly send a RE-CONFIG. + a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)})); + // RE-CONFIG, req + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // RE-CONFIG, resp + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Send a message on SID 1 and 3 - SID 1 should not be reset, but 3 should. + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + + auto packet5 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet5, + HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::stream_id, StreamID(1)), + Property(&DataChunk::ssn, SSN(2))))))); // Unchanged. + z->socket.ReceivePacket(packet5); + + auto packet6 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet6, HasChunks(ElementsAre(IsDataChunk(AllOf( + Property(&DataChunk::stream_id, StreamID(3)), + Property(&DataChunk::ssn, SSN(0))))))); // Reset + z->socket.ReceivePacket(packet6); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(1); + // Let's be evil here - reconnect while a fragmented packet was about to be + // sent. The receiving side should get it in full. + std::vector<uint8_t> payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + // First DATA + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // Create a new association, z2 - and don't use z anymore. + SocketUnderTest z2("Z2"); + z2.socket.Connect(); + + // Retransmit and handle the rest. As there will be some chunks in-flight that + // have the wrong verification tag, those will yield errors. + ExchangeMessages(a, z2); + + absl::optional<DcSctpMessage> msg = z2.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + SendOptions send_options; + send_options.max_retransmissions = 0; + std::vector<uint8_t> payload(a.options.mtu - 100); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options); + + // First DATA + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + // Third DATA + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // Handle SACK for first DATA + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Handle delayed SACK for third DATA + AdvanceTime(a, *z, a.options.delayed_ack_max_timeout); + + // Handle SACK for second DATA + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Now the missing data chunk will be marked as nacked, but it might still be + // in-flight and the reported gap could be due to out-of-order delivery. So + // the RetransmissionQueue will not mark it as "to be retransmitted" until + // after the t3-rtx timer has expired. + AdvanceTime(a, *z, a.options.rto_initial); + + // The chunk will be marked as retransmitted, and then as abandoned, which + // will trigger a FORWARD-TSN to be sent. + + // FORWARD-TSN (third) + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // Which will trigger a SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->ppid(), PPID(51)); + + absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->ppid(), PPID(53)); + + absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage(); + EXPECT_FALSE(msg3.has_value()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + SendOptions send_options; + send_options.unordered = IsUnordered(true); + send_options.max_retransmissions = 0; + std::vector<uint8_t> payload(a.options.mtu * 2 - 100 /* margin */); + // Sending first message + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + // Sending second message + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); + // Sending third message + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options); + // Sending fourth message + a.socket.Send(DcSctpMessage(StreamID(1), PPID(54), payload), send_options); + + // First DATA, first fragment + std::vector<uint8_t> packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasChunks(ElementsAre( + IsDataChunk(Property(&DataChunk::ppid, PPID(51)))))); + z->socket.ReceivePacket(std::move(packet)); + + // First DATA, second fragment (lost) + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasChunks(ElementsAre( + IsDataChunk(Property(&DataChunk::ppid, PPID(51)))))); + + // Second DATA, first fragment + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasChunks(ElementsAre( + IsDataChunk(Property(&DataChunk::ppid, PPID(52)))))); + z->socket.ReceivePacket(std::move(packet)); + + // Second DATA, second fragment (lost) + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::ppid, PPID(52)), + Property(&DataChunk::ssn, SSN(0))))))); + + // Third DATA, first fragment + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::ppid, PPID(53)), + Property(&DataChunk::ssn, SSN(0))))))); + z->socket.ReceivePacket(std::move(packet)); + + // Third DATA, second fragment (lost) + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::ppid, PPID(53)), + Property(&DataChunk::ssn, SSN(0))))))); + + // Fourth DATA, first fragment + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::ppid, PPID(54)), + Property(&DataChunk::ssn, SSN(0))))))); + z->socket.ReceivePacket(std::move(packet)); + + // Fourth DATA, second fragment + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasChunks(ElementsAre(IsDataChunk( + AllOf(Property(&DataChunk::ppid, PPID(54)), + Property(&DataChunk::ssn, SSN(0))))))); + z->socket.ReceivePacket(std::move(packet)); + + ExchangeMessages(a, *z); + + // Let the RTX timer expire, and exchange FORWARD-TSN/SACKs + AdvanceTime(a, *z, a.options.rto_initial); + + ExchangeMessages(a, *z); + + absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->ppid(), PPID(54)); + + ASSERT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +struct FakeChunkConfig : ChunkConfig { + static constexpr int kType = 0x49; + static constexpr size_t kHeaderSize = 4; + static constexpr int kVariableLengthAlignment = 0; +}; + +class FakeChunk : public Chunk, public TLVTrait<FakeChunkConfig> { + public: + FakeChunk() {} + + FakeChunk(FakeChunk&& other) = default; + FakeChunk& operator=(FakeChunk&& other) = default; + + void SerializeTo(std::vector<uint8_t>& out) const override { + AllocateTLV(out); + } + std::string ToString() const override { return "FAKE"; } +}; + +TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // Inject a FAKE chunk + SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions()); + b.Add(FakeChunk()); + a.socket.ReceivePacket(b.Build()); + + // ERROR is sent as a reply. Capture it. + ASSERT_HAS_VALUE_AND_ASSIGN( + SctpPacket reply_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket(), z->options)); + ASSERT_THAT(reply_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + ErrorChunk error, ErrorChunk::Parse(reply_packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UnrecognizedChunkTypeCause cause, + error.error_causes().get<UnrecognizedChunkTypeCause>()); + EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04)); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // Inject a ERROR chunk + SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions()); + b.Add( + ErrorChunk(Parameters::Builder() + .Add(UnrecognizedChunkTypeCause({0x49, 0x00, 0x00, 0x04})) + .Build())); + + EXPECT_CALL(a.cb, OnError(ErrorKind::kPeerReported, + HasSubstr("Unrecognized Chunk Type"))); + a.socket.ReceivePacket(b.Build()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) { + SocketUnderTest a("A"); + + constexpr size_t kReceiveWindowBufferSize = 2000; + SocketUnderTest z( + "Z", {.mtu = 3000, + .max_receiver_window_buffer_size = kReceiveWindowBufferSize}); + + EXPECT_CALL(z.cb, OnClosed).Times(0); + EXPECT_CALL(z.cb, OnAborted).Times(0); + + a.socket.Connect(); + std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(init_data, z.options)); + ASSERT_HAS_VALUE_AND_ASSIGN( + InitChunk init_chunk, + InitChunk::Parse(init_packet.descriptors()[0].data)); + z.socket.ReceivePacket(init_data); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // Fill up Z2 to the high watermark limit. + constexpr size_t kWatermarkLimit = + kReceiveWindowBufferSize * ReassemblyQueue::kHighWatermarkLimit; + constexpr size_t kRemainingSize = kReceiveWindowBufferSize - kWatermarkLimit; + + TSN tsn = init_chunk.initial_tsn(); + AnyDataChunk::Options opts; + opts.is_beginning = Data::IsBeginning(true); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(kWatermarkLimit + 1), opts)) + .Build()); + + // First DATA will always trigger a SACK. It's not interesting. + EXPECT_THAT(z.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsSack( + AllOf(Property(&SackChunk::cumulative_tsn_ack, tsn), + Property(&SackChunk::gap_ack_blocks, IsEmpty())))))); + + // This DATA should be accepted - it's advancing cum ack tsn. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 1), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(1), + /*options=*/{})) + .Build()); + + // The receiver might have moved into delayed ack mode. + AdvanceTime(a, z, z.options.rto_initial); + + EXPECT_THAT(z.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsSack( + AllOf(Property(&SackChunk::cumulative_tsn_ack, AddTo(tsn, 1)), + Property(&SackChunk::gap_ack_blocks, IsEmpty())))))); + + // This DATA will not be accepted - it's not advancing cum ack tsn. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(1), + /*options=*/{})) + .Build()); + + // Sack will be sent in IMMEDIATE mode when this is happening. + EXPECT_THAT(z.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsSack( + AllOf(Property(&SackChunk::cumulative_tsn_ack, AddTo(tsn, 1)), + Property(&SackChunk::gap_ack_blocks, IsEmpty())))))); + + // This DATA will not be accepted either. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 4), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(1), + /*options=*/{})) + .Build()); + + // Sack will be sent in IMMEDIATE mode when this is happening. + EXPECT_THAT(z.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsSack( + AllOf(Property(&SackChunk::cumulative_tsn_ack, AddTo(tsn, 1)), + Property(&SackChunk::gap_ack_blocks, IsEmpty())))))); + + // This DATA should be accepted, and it fills the reassembly queue. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 2), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(kRemainingSize), + /*options=*/{})) + .Build()); + + // The receiver might have moved into delayed ack mode. + AdvanceTime(a, z, z.options.rto_initial); + + EXPECT_THAT(z.cb.ConsumeSentPacket(), + HasChunks(ElementsAre(IsSack( + AllOf(Property(&SackChunk::cumulative_tsn_ack, AddTo(tsn, 2)), + Property(&SackChunk::gap_ack_blocks, IsEmpty())))))); + + EXPECT_CALL(z.cb, OnAborted(ErrorKind::kResourceExhaustion, _)); + EXPECT_CALL(z.cb, OnClosed).Times(0); + + // This DATA will make the connection close. It's too full now. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(kSmallMessageSize), + /*options=*/{})) + .Build()); +} + +TEST(DcSctpSocketTest, SetMaxMessageSize) { + SocketUnderTest a("A"); + + a.socket.SetMaxMessageSize(42u); + EXPECT_EQ(a.socket.options().max_message_size, 42u); +} + +TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // Mock that the time always goes forward. + TimeMs now(0); + EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + + // Queue a few small messages with low lifetime, both ordered and unordered, + // and validate that all are delivered. + static constexpr int kIterations = 100; + for (int i = 0; i < kIterations; ++i) { + SendOptions send_options; + send_options.unordered = IsUnordered((i % 2) == 0); + send_options.lifetime = DurationMs(i % 3); // 0, 1, 2 ms + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); + } + + ExchangeMessages(a, *z); + + for (int i = 0; i < kIterations; ++i) { + EXPECT_TRUE(z->cb.ConsumeReceivedMessage().has_value()); + } + + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + // Validate that the sockets really make the time move forward. + EXPECT_GE(*now, kIterations * 2); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + DiscardsMessagesWithLowLifetimeIfMustBuffer) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + SendOptions lifetime_0; + lifetime_0.unordered = IsUnordered(true); + lifetime_0.lifetime = DurationMs(0); + + SendOptions lifetime_1; + lifetime_1.unordered = IsUnordered(true); + lifetime_1.lifetime = DurationMs(1); + + // Mock that the time always goes forward. + TimeMs now(0); + EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + + // Fill up the send buffer with a large message. + std::vector<uint8_t> payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + // And queue a few small messages with lifetime=0 or 1 ms - can't be sent. + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), lifetime_0); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}), lifetime_1); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}), lifetime_0); + + // Handle all that was sent until congestion window got full. + for (;;) { + std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket(); + if (packet_from_a.empty()) { + break; + } + z->socket.ReceivePacket(std::move(packet_from_a)); + } + + // Shouldn't be enough to send that large message. + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + // Exchange the rest of the messages, with the time ever increasing. + ExchangeMessages(a, *z); + + // The large message should be delivered. It was sent reliably. + ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage m1, z->cb.ConsumeReceivedMessage()); + EXPECT_EQ(m1.stream_id(), StreamID(1)); + EXPECT_THAT(m1.payload(), SizeIs(kLargeMessageSize)); + + // But none of the smaller messages. + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + // Sending a small message will directly send it as a single packet, so + // nothing is left in the queue. + EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + // Sending a message will directly start sending a few packets, so the + // buffered amount is not the full message size. + EXPECT_GT(a.socket.buffered_amount(StreamID(1)), 0u); + EXPECT_LT(a.socket.buffered_amount(StreamID(1)), kLargeMessageSize); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) { + SocketUnderTest a("A"); + EXPECT_EQ(a.socket.buffered_amount_low_threshold(StreamID(1)), 0u); +} + +TEST_P(DcSctpSocketParametrizedTest, + TriggersOnBufferedAmountLowWithDefaultValueZero) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + EXPECT_CALL(a.cb, OnBufferedAmountLow).WillRepeatedly(testing::Return()); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + DoesntTriggerOnBufferedAmountLowIfBelowThreshold) { + static constexpr size_t kMessageSize = 1000; + static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10; + + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + a.socket.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(0); + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) { + static constexpr size_t kMessageSize = 1000; + static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2; + + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + a.socket.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(3); + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(2))).Times(2); + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) { + static constexpr size_t kMessageSize = 1000; + static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5; + + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + a.socket.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + + // Add a few messages to fill up the congestion window. When that is full, + // messages will start to be fully buffered. + while (a.socket.buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) { + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kMessageSize)), + kSendOptions); + } + size_t initial_buffered = a.socket.buffered_amount(StreamID(1)); + ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold); + + // Start ACKing packets, which will empty the send queue, and trigger the + // callback. + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(1); + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + DoesntTriggerOnTotalBufferAmountLowWhenBelow) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + TriggersOnTotalBufferAmountLowWhenCrossingThreshold) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0); + + // Fill up the send queue completely. + for (;;) { + if (a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions) == SendStatus::kErrorResourceExhaustion) { + break; + } + } + + EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(1); + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, InitialMetricsAreUnset) { + SocketUnderTest a("A"); + + EXPECT_FALSE(a.socket.GetMetrics().has_value()); +} + +TEST(DcSctpSocketTest, MessageInterleavingMetricsAreSet) { + std::vector<std::pair<bool, bool>> combinations = { + {false, false}, {false, true}, {true, false}, {true, true}}; + for (const auto& [a_enable, z_enable] : combinations) { + DcSctpOptions a_options = {.enable_message_interleaving = a_enable}; + DcSctpOptions z_options = {.enable_message_interleaving = z_enable}; + + SocketUnderTest a("A", a_options); + SocketUnderTest z("Z", z_options); + ConnectSockets(a, z); + + EXPECT_EQ(a.socket.GetMetrics()->uses_message_interleaving, + a_enable && z_enable); + } +} + +TEST(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + const size_t initial_a_rwnd = a.options.max_receiver_window_buffer_size * + ReassemblyQueue::kHighWatermarkLimit; + + EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 2u); + EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 2u); + EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 0u); + EXPECT_EQ(a.socket.GetMetrics()->cwnd_bytes, + a.options.cwnd_mtus_initial * a.options.mtu); + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u); + + EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 2u); + EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 0u); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u); + + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK + EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd); + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u); + + EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value()); + + EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 3u); + EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 3u); + EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 1u); + + EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 3u); + EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 1u); + + // Send one more (large - fragmented), and receive the delayed SACK. + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(a.options.mtu * 2 + 1)), + kSendOptions); + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 3u); + + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u); + EXPECT_GT(a.socket.GetMetrics()->peer_rwnd_bytes, 0u); + EXPECT_LT(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd); + + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + + EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value()); + + EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 6u); + EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 4u); + EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 2u); + + EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 6u); + EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 2u); + + // Delayed sack + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u); + EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 5u); + EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd); +} + +TEST(DcSctpSocketTest, RetransmissionMetricsAreSetForFastRetransmit) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + // Enough to trigger fast retransmit of the missing second packet. + std::vector<uint8_t> payload(DcSctpOptions::kMaxSafeMTUSize * 5); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + // Receive first packet, drop second, receive and retransmit the remaining. + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.cb.ConsumeSentPacket(); + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.GetMetrics()->rtx_packets_count, 1u); + size_t expected_data_size = + RoundDownTo4(DcSctpOptions::kMaxSafeMTUSize - SctpPacket::kHeaderSize); + EXPECT_EQ(a.socket.GetMetrics()->rtx_bytes_count, expected_data_size); +} + +TEST(DcSctpSocketTest, RetransmissionMetricsAreSetForNormalRetransmit) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + std::vector<uint8_t> payload(kSmallMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + a.cb.ConsumeSentPacket(); + AdvanceTime(a, z, a.options.rto_initial); + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.GetMetrics()->rtx_packets_count, 1u); + size_t expected_data_size = + RoundUpTo4(kSmallMessageSize + DataChunk::kHeaderSize); + EXPECT_EQ(a.socket.GetMetrics()->rtx_bytes_count, expected_data_size); +} + +TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + size_t payload_bytes = + a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; + + size_t expected_sent_packets = a.options.cwnd_mtus_initial; + + size_t expected_queued_bytes = + kLargeMessageSize - expected_sent_packets * payload_bytes; + + size_t expected_queued_packets = expected_queued_bytes / payload_bytes; + + // Due to alignment, padding etc, it's hard to calculate the exact number, but + // it should be in this range. + EXPECT_GE(a.socket.GetMetrics()->unack_data_count, + expected_sent_packets + expected_queued_packets); + + EXPECT_LE(a.socket.GetMetrics()->unack_data_count, + expected_sent_packets + expected_queued_packets + 2); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + for (int i = 0; i < kMaxBurstPackets; ++i) { + std::vector<uint8_t> packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, Not(IsEmpty())); + z->socket.ReceivePacket(std::move(packet)); // DATA + } + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + ExchangeMessages(a, *z); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // A really large message, to ensure that the congestion window is often full. + constexpr size_t kMessageSize = 100000; + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + + bool delivered_packet = false; + std::vector<size_t> data_packet_sizes; + do { + delivered_packet = false; + std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket(); + if (!packet_from_a.empty()) { + data_packet_sizes.push_back(packet_from_a.size()); + delivered_packet = true; + z->socket.ReceivePacket(std::move(packet_from_a)); + } + std::vector<uint8_t> packet_from_z = z->cb.ConsumeSentPacket(); + if (!packet_from_z.empty()) { + delivered_packet = true; + a.socket.ReceivePacket(std::move(packet_from_z)); + } + } while (delivered_packet); + + size_t packet_payload_bytes = + a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; + // +1 accounts for padding, and rounding up. + size_t expected_packets = + (kMessageSize + packet_payload_bytes - 1) / packet_payload_bytes + 1; + EXPECT_THAT(data_packet_sizes, SizeIs(expected_packets)); + + // Remove the last size - it will be the remainder. But all other sizes should + // be large. + data_packet_sizes.pop_back(); + + for (size_t size : data_packet_sizes) { + // The 4 is for padding/alignment. + EXPECT_GE(size, a.options.mtu - 4); + } + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, SendMessagesAfterHandover) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + + // Send message before handover to move socket to a not initial state + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + z->cb.ConsumeReceivedMessage(); + + z = HandoverSocket(std::move(z)); + + absl::optional<DcSctpMessage> msg; + + RTC_LOG(LS_INFO) << "Sending A #1"; + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4)); + + RTC_LOG(LS_INFO) << "Sending A #2"; + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(2)); + EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6)); + + RTC_LOG(LS_INFO) << "Sending Z #1"; + + z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // ack + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // data + + msg = a.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3)); +} + +TEST(DcSctpSocketTest, CanDetectDcsctpImplementation) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp); + + // As A initiated the connection establishment, Z will not receive enough + // information to know about A's implementation + EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kUnknown); +} + +TEST(DcSctpSocketTest, BothCanDetectDcsctpImplementation) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + a.socket.Connect(); + z.socket.Connect(); + + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp); + EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kDcsctp); +} + +TEST_P(DcSctpSocketParametrizedTest, CanLoseFirstOrderedMessage) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + SendOptions send_options; + send_options.unordered = IsUnordered(false); + send_options.max_retransmissions = 0; + std::vector<uint8_t> payload(a.options.mtu - 100); + + // Send a first message (SID=1, SSN=0) + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + + // First DATA is lost, and retransmission timer will delete it. + a.cb.ConsumeSentPacket(); + AdvanceTime(a, *z, a.options.rto_initial); + ExchangeMessages(a, *z); + + // Send a second message (SID=0, SSN=1). + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); + ExchangeMessages(a, *z); + + // The Z socket should receive the second message, but not the first. + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->ppid(), PPID(52)); + + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, ReceiveBothUnorderedAndOrderedWithSameTSN) { + /* This issue was found by fuzzing. */ + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(init_data, z.options)); + ASSERT_HAS_VALUE_AND_ASSIGN( + InitChunk init_chunk, + InitChunk::Parse(init_packet.descriptors()[0].data)); + z.socket.ReceivePacket(init_data); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // Receive a short unordered message with tsn=INITIAL_TSN+1 + TSN tsn = init_chunk.initial_tsn(); + AnyDataChunk::Options opts; + opts.is_beginning = Data::IsBeginning(true); + opts.is_end = Data::IsEnd(true); + opts.is_unordered = IsUnordered(true); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(10), opts)) + .Build()); + + // Now receive a longer _ordered_ message with [INITIAL_TSN, INITIAL_TSN+1]. + // This isn't allowed as it reuses TSN=53 with different properties, but it + // shouldn't cause any issues. + opts.is_unordered = IsUnordered(false); + opts.is_end = Data::IsEnd(false); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(10), opts)) + .Build()); + + opts.is_beginning = Data::IsBeginning(false); + opts.is_end = Data::IsEnd(true); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(10), opts)) + .Build()); +} + +TEST(DcSctpSocketTest, CloseTwoStreamsAtTheSameTime) { + // Reported as https://crbug.com/1312009. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(2)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(2)))).Times(1); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + ExchangeMessages(a, z); + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)})); + + ExchangeMessages(a, z); +} + +TEST(DcSctpSocketTest, CloseThreeStreamsAtTheSameTime) { + // Similar to CloseTwoStreamsAtTheSameTime, but ensuring that the two + // remaining streams are reset at the same time in the second request. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(z.cb, OnIncomingStreamsReset( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), kSendOptions); + + ExchangeMessages(a, z); + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)})); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)})); + + ExchangeMessages(a, z); +} + +TEST(DcSctpSocketTest, CloseStreamsWithPendingRequest) { + // Checks that stream reset requests are properly paused when they can't be + // immediately reset - i.e. when there is already an ongoing stream reset + // request (and there can only be a single one in-flight). + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(z.cb, OnIncomingStreamsReset( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + + ConnectSockets(a, z); + + SendOptions send_options = {.unordered = IsUnordered(false)}; + + // Send a few ordered messages + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options); + + ExchangeMessages(a, z); + + // Receive these messages + absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(2)); + absl::optional<DcSctpMessage> msg3 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg3.has_value()); + EXPECT_EQ(msg3->stream_id(), StreamID(3)); + + // Reset the streams - not all at once. + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + + std::vector<uint8_t> packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasChunks(ElementsAre(IsReConfig(HasParameters( + ElementsAre(IsOutgoingResetRequest(Property( + &OutgoingSSNResetRequestParameter::stream_ids, + ElementsAre(StreamID(1)))))))))); + z.socket.ReceivePacket(std::move(packet)); + + // Sending more reset requests while this one is ongoing. + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)})); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)})); + + ExchangeMessages(a, z); + + // Send a few more ordered messages + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options); + + ExchangeMessages(a, z); + + // Receive these messages + absl::optional<DcSctpMessage> msg4 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg4.has_value()); + EXPECT_EQ(msg4->stream_id(), StreamID(1)); + absl::optional<DcSctpMessage> msg5 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg5.has_value()); + EXPECT_EQ(msg5->stream_id(), StreamID(2)); + absl::optional<DcSctpMessage> msg6 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg6.has_value()); + EXPECT_EQ(msg6->stream_id(), StreamID(3)); +} + +TEST(DcSctpSocketTest, StreamsHaveInitialPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + SocketUnderTest a("A", options); + + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), + options.default_stream_priority); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), + options.default_stream_priority); +} + +TEST(DcSctpSocketTest, CanChangeStreamPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + SocketUnderTest a("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(43)); + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), StreamPriority(43)); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + a.socket.SetStreamPriority(StreamID(2), StreamPriority(43)); + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), StreamPriority(43)); +} + +TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + auto a = std::make_unique<SocketUnderTest>("A", options); + SocketUnderTest z("Z"); + + ConnectSockets(*a, z); + + a->socket.SetStreamPriority(StreamID(1), StreamPriority(43)); + a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + a->socket.SetStreamPriority(StreamID(2), StreamPriority(43)); + + ExchangeMessages(*a, z); + + a = MaybeHandoverSocket(std::move(a)); + + EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), StreamPriority(43)); + EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), StreamPriority(43)); +} + +TEST(DcSctpSocketTest, ReconnectSocketWithPendingStreamReset) { + // This is an issue found by fuzzing, and doesn't really make sense in WebRTC + // data channels as a SCTP connection is never ever closed and then + // reconnected. SCTP connections are closed when the peer connection is + // deleted, and then it doesn't do more with SCTP. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + + EXPECT_CALL(z.cb, OnAborted).Times(1); + a.socket.Close(); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + a.socket.Connect(); + ExchangeMessages(a, z); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)})); +} + +TEST(DcSctpSocketTest, SmallSentMessagesWithPrioWillArriveInSpecificOrder) { + DcSctpOptions options = {.enable_message_interleaving = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(700)); + a.socket.SetStreamPriority(StreamID(2), StreamPriority(200)); + a.socket.SetStreamPriority(StreamID(3), StreamPriority(100)); + + // Enqueue messages before connecting the socket, to ensure they aren't send + // as soon as Send() is called. + a.socket.Send(DcSctpMessage(StreamID(3), PPID(301), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(101), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(201), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(102), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(103), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + + ConnectSockets(a, z); + ExchangeMessages(a, z); + + std::vector<uint32_t> received_ppids; + for (;;) { + absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage(); + if (!msg.has_value()) { + break; + } + received_ppids.push_back(*msg->ppid()); + } + + EXPECT_THAT(received_ppids, ElementsAre(101, 102, 103, 201, 301)); +} + +TEST(DcSctpSocketTest, LargeSentMessagesWithPrioWillArriveInSpecificOrder) { + DcSctpOptions options = {.enable_message_interleaving = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(700)); + a.socket.SetStreamPriority(StreamID(2), StreamPriority(200)); + a.socket.SetStreamPriority(StreamID(3), StreamPriority(100)); + + // Enqueue messages before connecting the socket, to ensure they aren't send + // as soon as Send() is called. + a.socket.Send(DcSctpMessage(StreamID(3), PPID(301), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(101), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(201), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(102), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + ConnectSockets(a, z); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201, 301)); +} + +TEST(DcSctpSocketTest, MessageWithHigherPrioWillInterruptLowerPrioMessage) { + DcSctpOptions options = {.enable_message_interleaving = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("Z", options); + + ConnectSockets(a, z); + + a.socket.SetStreamPriority(StreamID(2), StreamPriority(128)); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(201), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + // Due to a non-zero initial congestion window, the message will already start + // to send, but will not succeed to be sent completely before filling the + // congestion window or stopping due to reaching how many packets that can be + // sent at once (max burst). The important thing is that the entire message + // doesn't get sent in full. + + // Now enqueue two messages; one small and one large higher priority message. + a.socket.SetStreamPriority(StreamID(1), StreamPriority(512)); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(101), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(102), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201)); +} + +TEST(DcSctpSocketTest, LifecycleEventsAreGeneratedForAckedMessages) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(101), + std::vector<uint8_t>(kLargeMessageSize)), + {.lifecycle_id = LifecycleId(41)}); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(102), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(103), + std::vector<uint8_t>(kLargeMessageSize)), + {.lifecycle_id = LifecycleId(42)}); + + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(41))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(41))); + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(42))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(42))); + ExchangeMessages(a, z); + // In case of delayed ack. + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 103)); +} + +TEST(DcSctpSocketTest, LifecycleEventsForFailMaxRetransmissions) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + std::vector<uint8_t> payload(a.options.mtu - 100); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(1), + }); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(2), + }); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(3), + }); + + // First DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(1))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1))); + EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(2), + /*maybe_delivered=*/true)); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(2))); + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(3))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(3))); + ExchangeMessages(a, z); + + // Handle delayed SACK. + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + ExchangeMessages(a, z); + + // The chunk is now NACKed. Let the RTO expire, to discard the message. + AdvanceTime(a, z, a.options.rto_initial); + ExchangeMessages(a, z); + + // Handle delayed SACK. + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(51, 53)); +} + +TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithRetransmitLimit) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + // Will not be able to send it in full within the congestion window, but will + // need to wait for SACKs to be received for more fragments to be sent. + std::vector<uint8_t> payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(1), + }); + + // First DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + + EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1))); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty()); +} + +TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithLifetimeLimit) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + // Send it before the socket is connected, to prevent it from being sent too + // quickly. The idea is that it should be expired before even attempting to + // send it in full. + std::vector<uint8_t> payload(kSmallMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .lifetime = DurationMs(100), + .lifecycle_id = LifecycleId(1), + }); + + AdvanceTime(a, z, DurationMs(200)); + + EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1))); + ConnectSockets(a, z); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty()); +} + +TEST_P(DcSctpSocketParametrizedTest, ExposesTheNumberOfNegotiatedStreams) { + DcSctpOptions options_a = { + .announced_maximum_incoming_streams = 12, + .announced_maximum_outgoing_streams = 45, + }; + SocketUnderTest a("A", options_a); + + DcSctpOptions options_z = { + .announced_maximum_incoming_streams = 23, + .announced_maximum_outgoing_streams = 34, + }; + auto z = std::make_unique<SocketUnderTest>("Z", options_z); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + ASSERT_HAS_VALUE_AND_ASSIGN(Metrics metrics_a, a.socket.GetMetrics()); + EXPECT_EQ(metrics_a.negotiated_maximum_incoming_streams, 12); + EXPECT_EQ(metrics_a.negotiated_maximum_outgoing_streams, 23); + + ASSERT_HAS_VALUE_AND_ASSIGN(Metrics metrics_z, z->socket.GetMetrics()); + EXPECT_EQ(metrics_z.negotiated_maximum_incoming_streams, 23); + EXPECT_EQ(metrics_z.negotiated_maximum_outgoing_streams, 12); +} + +TEST(DcSctpSocketTest, ResetStreamsDeferred) { + // Guaranteed to be fragmented into two fragments. + constexpr size_t kTwoFragmentsSize = DcSctpOptions::kMaxSafeMTUSize + 100; + + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kTwoFragmentsSize)), + {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(54), + std::vector<uint8_t>(kSmallMessageSize)), + {}); + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + + auto data1 = a.cb.ConsumeSentPacket(); + auto data2 = a.cb.ConsumeSentPacket(); + auto data3 = a.cb.ConsumeSentPacket(); + auto reconfig = a.cb.ConsumeSentPacket(); + + EXPECT_THAT( + data1, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0)))))); + EXPECT_THAT( + data2, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0)))))); + EXPECT_THAT( + data3, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(1)))))); + EXPECT_THAT(reconfig, HasChunks(ElementsAre(IsReConfig(HasParameters( + ElementsAre(IsOutgoingResetRequest(Property( + &OutgoingSSNResetRequestParameter::stream_ids, + ElementsAre(StreamID(1)))))))))); + + // Receive them slightly out of order to make stream resetting deferred. + z.socket.ReceivePacket(reconfig); + + z.socket.ReceivePacket(data1); + z.socket.ReceivePacket(data2); + z.socket.ReceivePacket(data3); + + absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + EXPECT_EQ(msg1->ppid(), PPID(53)); + EXPECT_EQ(msg1->payload().size(), kTwoFragmentsSize); + + absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(1)); + EXPECT_EQ(msg2->ppid(), PPID(54)); + EXPECT_EQ(msg2->payload().size(), kSmallMessageSize); + + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))); + ExchangeMessages(a, z); + + // Z sent "in progress", which will make A buffer packets until it's sure + // that the reconfiguration has been applied. A will retry - wait for that. + AdvanceTime(a, z, a.options.rto_initial); + + auto reconfig2 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(reconfig2, HasChunks(ElementsAre(IsReConfig(HasParameters( + ElementsAre(IsOutgoingResetRequest(Property( + &OutgoingSSNResetRequestParameter::stream_ids, + ElementsAre(StreamID(1)))))))))); + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))); + z.socket.ReceivePacket(reconfig2); + + auto reconfig3 = z.cb.ConsumeSentPacket(); + EXPECT_THAT(reconfig3, HasChunks(ElementsAre(IsReConfig(HasParameters( + ElementsAre(IsReconfigurationResponse(Property( + &ReconfigurationResponseParameter::result, + ReconfigurationResponseParameter::Result:: + kSuccessPerformed)))))))); + a.socket.ReceivePacket(reconfig3); + + EXPECT_THAT( + data1, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0)))))); + EXPECT_THAT( + data2, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(0)))))); + EXPECT_THAT( + data3, + HasChunks(ElementsAre(IsDataChunk(Property(&DataChunk::ssn, SSN(1)))))); + EXPECT_THAT(reconfig, HasChunks(ElementsAre(IsReConfig(HasParameters( + ElementsAre(IsOutgoingResetRequest(Property( + &OutgoingSSNResetRequestParameter::stream_ids, + ElementsAre(StreamID(1)))))))))); + + // Send a new message after the stream has been reset. + a.socket.Send(DcSctpMessage(StreamID(1), PPID(55), + std::vector<uint8_t>(kSmallMessageSize)), + {}); + ExchangeMessages(a, z); + + absl::optional<DcSctpMessage> msg3 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg3.has_value()); + EXPECT_EQ(msg3->stream_id(), StreamID(1)); + EXPECT_EQ(msg3->ppid(), PPID(55)); + EXPECT_EQ(msg3->payload().size(), kSmallMessageSize); +} + +TEST(DcSctpSocketTest, ResetStreamsWithPausedSenderResumesWhenPerformed) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), + std::vector<uint8_t>(kSmallMessageSize)), + {}); + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + + // Will be queued, as the stream has an outstanding reset operation. + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), + std::vector<uint8_t>(kSmallMessageSize)), + {}); + + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))); + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))); + ExchangeMessages(a, z); + + absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + EXPECT_EQ(msg1->ppid(), PPID(51)); + EXPECT_EQ(msg1->payload().size(), kSmallMessageSize); + + absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(1)); + EXPECT_EQ(msg2->ppid(), PPID(52)); + EXPECT_EQ(msg2->payload().size(), kSmallMessageSize); +} + +TEST_P(DcSctpSocketParametrizedTest, ZeroChecksumMetricsAreSet) { + std::vector<std::pair<bool, bool>> combinations = { + {false, false}, {false, true}, {true, false}, {true, true}}; + for (const auto& [a_enable, z_enable] : combinations) { + DcSctpOptions a_options = {.enable_zero_checksum = a_enable}; + DcSctpOptions z_options = {.enable_zero_checksum = z_enable}; + + SocketUnderTest a("A", a_options); + auto z = std::make_unique<SocketUnderTest>("Z", z_options); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_EQ(a.socket.GetMetrics()->uses_zero_checksum, a_enable && z_enable); + EXPECT_EQ(z->socket.GetMetrics()->uses_zero_checksum, a_enable && z_enable); + } +} + +TEST(DcSctpSocketTest, AlwaysSendsInitWithNonZeroChecksum) { + DcSctpOptions options = {.enable_zero_checksum = true}; + SocketUnderTest a("A", options); + + a.socket.Connect(); + std::vector<uint8_t> data = a.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(data, options)); + EXPECT_THAT(packet.descriptors(), + ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type, + InitChunk::kType))); + EXPECT_THAT(packet.common_header().checksum, Not(Eq(0u))); +} + +TEST(DcSctpSocketTest, MaySendInitAckWithZeroChecksum) { + DcSctpOptions options = {.enable_zero_checksum = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("Z", options); + + a.socket.Connect(); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // INIT + + std::vector<uint8_t> data = z.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(data, options)); + EXPECT_THAT(packet.descriptors(), + ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type, + InitAckChunk::kType))); + EXPECT_THAT(packet.common_header().checksum, 0u); +} + +TEST(DcSctpSocketTest, AlwaysSendsCookieEchoWithNonZeroChecksum) { + DcSctpOptions options = {.enable_zero_checksum = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("Z", options); + + a.socket.Connect(); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // INIT + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // INIT-ACK + + std::vector<uint8_t> data = a.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(data, options)); + EXPECT_THAT(packet.descriptors(), + ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type, + CookieEchoChunk::kType))); + EXPECT_THAT(packet.common_header().checksum, Not(Eq(0u))); +} + +TEST(DcSctpSocketTest, SendsCookieAckWithZeroChecksum) { + DcSctpOptions options = {.enable_zero_checksum = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("Z", options); + + a.socket.Connect(); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // INIT + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // INIT-ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // COOKIE-ECHO + + std::vector<uint8_t> data = z.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(data, options)); + EXPECT_THAT(packet.descriptors(), + ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type, + CookieAckChunk::kType))); + EXPECT_THAT(packet.common_header().checksum, 0u); +} + +TEST_P(DcSctpSocketParametrizedTest, SendsDataWithZeroChecksum) { + DcSctpOptions options = {.enable_zero_checksum = true}; + SocketUnderTest a("A", options); + auto z = std::make_unique<SocketUnderTest>("Z", options); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + std::vector<uint8_t> payload(a.options.mtu - 100); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + std::vector<uint8_t> data = a.cb.ConsumeSentPacket(); + z->socket.ReceivePacket(data); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(data, options)); + EXPECT_THAT(packet.descriptors(), + ElementsAre(testing::Field(&SctpPacket::ChunkDescriptor::type, + DataChunk::kType))); + EXPECT_THAT(packet.common_header().checksum, 0u); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, AllPacketsAfterConnectHaveZeroChecksum) { + DcSctpOptions options = {.enable_zero_checksum = true}; + SocketUnderTest a("A", options); + auto z = std::make_unique<SocketUnderTest>("Z", options); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // Send large messages in both directions, and verify that they arrive and + // that every packet has zero checksum. + std::vector<uint8_t> payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + for (;;) { + if (auto data = a.cb.ConsumeSentPacket(); !data.empty()) { + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(data, options)); + EXPECT_THAT(packet.common_header().checksum, 0u); + z->socket.ReceivePacket(std::move(data)); + + } else if (auto data = z->cb.ConsumeSentPacket(); !data.empty()) { + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(data, options)); + EXPECT_THAT(packet.common_header().checksum, 0u); + a.socket.ReceivePacket(std::move(data)); + + } else { + break; + } + } + + absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_THAT(msg1->payload(), SizeIs(kLargeMessageSize)); + + absl::optional<DcSctpMessage> msg2 = a.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_THAT(msg2->payload(), SizeIs(kLargeMessageSize)); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, HandlesForwardTsnOutOfOrderWithStreamResetting) { + // This test ensures that receiving FORWARD-TSN and RECONFIG out of order is + // handled correctly. + SocketUnderTest a("A", {.heartbeat_interval = DurationMs(0)}); + SocketUnderTest z("Z", {.heartbeat_interval = DurationMs(0)}); + + ConnectSockets(a, z); + std::vector<uint8_t> payload(kSmallMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .max_retransmissions = 0, + }); + + // Packet is lost. + EXPECT_THAT(a.cb.ConsumeSentPacket(), + HasChunks(ElementsAre( + IsDataChunk(AllOf(Property(&DataChunk::ssn, SSN(0)), + Property(&DataChunk::ppid, PPID(51))))))); + AdvanceTime(a, z, a.options.rto_initial); + + auto fwd_tsn_packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(fwd_tsn_packet, + HasChunks(ElementsAre(IsChunkType(ForwardTsnChunk::kType)))); + // Reset stream 1 + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + auto reconfig_packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(reconfig_packet, + HasChunks(ElementsAre(IsChunkType(ReConfigChunk::kType)))); + + // These two packets are received in the wrong order. + z.socket.ReceivePacket(reconfig_packet); + z.socket.ReceivePacket(fwd_tsn_packet); + ExchangeMessagesAndAdvanceTime(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto data_packet_2 = a.cb.ConsumeSentPacket(); + auto data_packet_3 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(data_packet_2, HasChunks(ElementsAre(IsDataChunk(AllOf( + Property(&DataChunk::ssn, SSN(0)), + Property(&DataChunk::ppid, PPID(52))))))); + EXPECT_THAT(data_packet_3, HasChunks(ElementsAre(IsDataChunk(AllOf( + Property(&DataChunk::ssn, SSN(1)), + Property(&DataChunk::ppid, PPID(53))))))); + + z.socket.ReceivePacket(data_packet_2); + z.socket.ReceivePacket(data_packet_3); + ASSERT_THAT(z.cb.ConsumeReceivedMessage(), + testing::Optional(Property(&DcSctpMessage::ppid, PPID(52)))); + ASSERT_THAT(z.cb.ConsumeReceivedMessage(), + testing::Optional(Property(&DcSctpMessage::ppid, PPID(53)))); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc new file mode 100644 index 0000000000..902dff962f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/heartbeat_handler.h" + +#include <stddef.h> + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +// This is stored (in serialized form) as HeartbeatInfoParameter sent in +// HeartbeatRequestChunk and received back in HeartbeatAckChunk. It should be +// well understood that this data may be modified by the peer, so it can't +// be trusted. +// +// It currently only stores a timestamp, in millisecond precision, to allow for +// RTT measurements. If that would be manipulated by the peer, it would just +// result in incorrect RTT measurements, which isn't an issue. +class HeartbeatInfo { + public: + static constexpr size_t kBufferSize = sizeof(uint64_t); + static_assert(kBufferSize == 8, "Unexpected buffer size"); + + explicit HeartbeatInfo(TimeMs created_at) : created_at_(created_at) {} + + std::vector<uint8_t> Serialize() { + uint32_t high_bits = static_cast<uint32_t>(*created_at_ >> 32); + uint32_t low_bits = static_cast<uint32_t>(*created_at_); + + std::vector<uint8_t> data(kBufferSize); + BoundedByteWriter<kBufferSize> writer(data); + writer.Store32<0>(high_bits); + writer.Store32<4>(low_bits); + return data; + } + + static absl::optional<HeartbeatInfo> Deserialize( + rtc::ArrayView<const uint8_t> data) { + if (data.size() != kBufferSize) { + RTC_LOG(LS_WARNING) << "Invalid heartbeat info: " << data.size() + << " bytes"; + return absl::nullopt; + } + + BoundedByteReader<kBufferSize> reader(data); + uint32_t high_bits = reader.Load32<0>(); + uint32_t low_bits = reader.Load32<4>(); + + uint64_t created_at = static_cast<uint64_t>(high_bits) << 32 | low_bits; + return HeartbeatInfo(TimeMs(created_at)); + } + + TimeMs created_at() const { return created_at_; } + + private: + const TimeMs created_at_; +}; + +HeartbeatHandler::HeartbeatHandler(absl::string_view log_prefix, + const DcSctpOptions& options, + Context* context, + TimerManager* timer_manager) + : log_prefix_(log_prefix), + ctx_(context), + timer_manager_(timer_manager), + interval_duration_(options.heartbeat_interval), + interval_duration_should_include_rtt_( + options.heartbeat_interval_include_rtt), + interval_timer_(timer_manager_->CreateTimer( + "heartbeat-interval", + absl::bind_front(&HeartbeatHandler::OnIntervalTimerExpiry, this), + TimerOptions(interval_duration_, TimerBackoffAlgorithm::kFixed))), + timeout_timer_(timer_manager_->CreateTimer( + "heartbeat-timeout", + absl::bind_front(&HeartbeatHandler::OnTimeoutTimerExpiry, this), + TimerOptions(options.rto_initial, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/0))) { + // The interval timer must always be running as long as the association is up. + RestartTimer(); +} + +void HeartbeatHandler::RestartTimer() { + if (interval_duration_ == DurationMs(0)) { + // Heartbeating has been disabled. + return; + } + + if (interval_duration_should_include_rtt_) { + // The RTT should be used, but it's not easy accessible. The RTO will + // suffice. + interval_timer_->set_duration(interval_duration_ + ctx_->current_rto()); + } else { + interval_timer_->set_duration(interval_duration_); + } + + interval_timer_->Start(); +} + +void HeartbeatHandler::HandleHeartbeatRequest(HeartbeatRequestChunk chunk) { + // https://tools.ietf.org/html/rfc4960#section-8.3 + // "The receiver of the HEARTBEAT should immediately respond with a + // HEARTBEAT ACK that contains the Heartbeat Information TLV, together with + // any other received TLVs, copied unchanged from the received HEARTBEAT + // chunk." + ctx_->Send(ctx_->PacketBuilder().Add( + HeartbeatAckChunk(std::move(chunk).extract_parameters()))); +} + +void HeartbeatHandler::HandleHeartbeatAck(HeartbeatAckChunk chunk) { + timeout_timer_->Stop(); + absl::optional<HeartbeatInfoParameter> info_param = chunk.info(); + if (!info_param.has_value()) { + ctx_->callbacks().OnError( + ErrorKind::kParseFailed, + "Failed to parse HEARTBEAT-ACK; No Heartbeat Info parameter"); + return; + } + absl::optional<HeartbeatInfo> info = + HeartbeatInfo::Deserialize(info_param->info()); + if (!info.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse HEARTBEAT-ACK; Failed to " + "deserialized Heartbeat info parameter"); + return; + } + + TimeMs now = ctx_->callbacks().TimeMillis(); + if (info->created_at() > TimeMs(0) && info->created_at() <= now) { + ctx_->ObserveRTT(now - info->created_at()); + } + + // https://tools.ietf.org/html/rfc4960#section-8.1 + // "The counter shall be reset each time ... a HEARTBEAT ACK is received from + // the peer endpoint." + ctx_->ClearTxErrorCounter(); +} + +absl::optional<DurationMs> HeartbeatHandler::OnIntervalTimerExpiry() { + if (ctx_->is_connection_established()) { + HeartbeatInfo info(ctx_->callbacks().TimeMillis()); + timeout_timer_->set_duration(ctx_->current_rto()); + timeout_timer_->Start(); + RTC_DLOG(LS_INFO) << log_prefix_ << "Sending HEARTBEAT with timeout " + << *timeout_timer_->duration(); + + Parameters parameters = Parameters::Builder() + .Add(HeartbeatInfoParameter(info.Serialize())) + .Build(); + + ctx_->Send(ctx_->PacketBuilder().Add( + HeartbeatRequestChunk(std::move(parameters)))); + } else { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "Will not send HEARTBEAT when connection not established"; + } + return absl::nullopt; +} + +absl::optional<DurationMs> HeartbeatHandler::OnTimeoutTimerExpiry() { + // Note that the timeout timer is not restarted. It will be started again when + // the interval timer expires. + RTC_DCHECK(!timeout_timer_->is_running()); + ctx_->IncrementTxErrorCounter("HEARTBEAT timeout"); + return absl::nullopt; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h new file mode 100644 index 0000000000..318b02955b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_ +#define NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_ + +#include <stdint.h> + +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" + +namespace dcsctp { + +// HeartbeatHandler handles all logic around sending heartbeats and receiving +// the responses, as well as receiving incoming heartbeat requests. +// +// Heartbeats are sent on idle connections to ensure that the connection is +// still healthy and to measure the RTT. If a number of heartbeats time out, +// the connection will eventually be closed. +class HeartbeatHandler { + public: + HeartbeatHandler(absl::string_view log_prefix, + const DcSctpOptions& options, + Context* context, + TimerManager* timer_manager); + + // Called when the heartbeat interval timer should be restarted. This is + // generally done every time data is sent, which makes the timer expire when + // the connection is idle. + void RestartTimer(); + + // Called on received HeartbeatRequestChunk chunks. + void HandleHeartbeatRequest(HeartbeatRequestChunk chunk); + + // Called on received HeartbeatRequestChunk chunks. + void HandleHeartbeatAck(HeartbeatAckChunk chunk); + + private: + absl::optional<DurationMs> OnIntervalTimerExpiry(); + absl::optional<DurationMs> OnTimeoutTimerExpiry(); + + const absl::string_view log_prefix_; + Context* ctx_; + TimerManager* timer_manager_; + // The time for a connection to be idle before a heartbeat is sent. + const DurationMs interval_duration_; + // Adding RTT to the duration will add some jitter, which is good in + // production, but less good in unit tests, which is why it can be disabled. + const bool interval_duration_should_include_rtt_; + const std::unique_ptr<Timer> interval_timer_; + const std::unique_ptr<Timer> timeout_timer_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc new file mode 100644 index 0000000000..d573192440 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc @@ -0,0 +1,189 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/heartbeat_handler.h" + +#include <memory> +#include <utility> +#include <vector> + +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/socket/mock_context.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SizeIs; + +constexpr DurationMs kHeartbeatInterval = DurationMs(30'000); + +DcSctpOptions MakeOptions(DurationMs heartbeat_interval) { + DcSctpOptions options; + options.heartbeat_interval_include_rtt = false; + options.heartbeat_interval = heartbeat_interval; + options.enable_zero_checksum = false; + return options; +} + +class HeartbeatHandlerTestBase : public testing::Test { + protected: + explicit HeartbeatHandlerTestBase(DurationMs heartbeat_interval) + : options_(MakeOptions(heartbeat_interval)), + context_(&callbacks_), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return callbacks_.CreateTimeout(precision); + }), + handler_("log: ", options_, &context_, &timer_manager_) {} + + void AdvanceTime(DurationMs duration) { + callbacks_.AdvanceTime(duration); + for (;;) { + absl::optional<TimeoutID> timeout_id = callbacks_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + timer_manager_.HandleTimeout(*timeout_id); + } + } + + const DcSctpOptions options_; + NiceMock<MockDcSctpSocketCallbacks> callbacks_; + NiceMock<MockContext> context_; + TimerManager timer_manager_; + HeartbeatHandler handler_; +}; + +class HeartbeatHandlerTest : public HeartbeatHandlerTestBase { + protected: + HeartbeatHandlerTest() : HeartbeatHandlerTestBase(kHeartbeatInterval) {} +}; + +class DisabledHeartbeatHandlerTest : public HeartbeatHandlerTestBase { + protected: + DisabledHeartbeatHandlerTest() : HeartbeatHandlerTestBase(DurationMs(0)) {} +}; + +TEST_F(HeartbeatHandlerTest, HasRunningHeartbeatIntervalTimer) { + AdvanceTime(options_.heartbeat_interval); + + // Validate that a heartbeat request was sent. + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(payload, options_)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk request, + HeartbeatRequestChunk::Parse(packet.descriptors()[0].data)); + + EXPECT_TRUE(request.info().has_value()); +} + +TEST_F(HeartbeatHandlerTest, RepliesToHeartbeatRequests) { + uint8_t info_data[] = {1, 2, 3, 4, 5}; + HeartbeatRequestChunk request( + Parameters::Builder().Add(HeartbeatInfoParameter(info_data)).Build()); + + handler_.HandleHeartbeatRequest(std::move(request)); + + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(payload, options_)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatAckChunk response, + HeartbeatAckChunk::Parse(packet.descriptors()[0].data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatInfoParameter param, + response.parameters().get<HeartbeatInfoParameter>()); + + EXPECT_THAT(param.info(), ElementsAre(1, 2, 3, 4, 5)); +} + +TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) { + AdvanceTime(options_.heartbeat_interval); + + // Grab the request, and make a response. + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(payload, options_)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk req, + HeartbeatRequestChunk::Parse(packet.descriptors()[0].data)); + + HeartbeatAckChunk ack(std::move(req).extract_parameters()); + + // Respond a while later. This RTT will be measured by the handler + constexpr DurationMs rtt(313); + + EXPECT_CALL(context_, ObserveRTT(rtt)).Times(1); + + callbacks_.AdvanceTime(rtt); + handler_.HandleHeartbeatAck(std::move(ack)); +} + +TEST_F(HeartbeatHandlerTest, DoesntObserveInvalidHeartbeats) { + AdvanceTime(options_.heartbeat_interval); + + // Grab the request, and make a response. + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(payload, options_)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk req, + HeartbeatRequestChunk::Parse(packet.descriptors()[0].data)); + + HeartbeatAckChunk ack(std::move(req).extract_parameters()); + + EXPECT_CALL(context_, ObserveRTT).Times(0); + + // Go backwards in time - which make the HEARTBEAT-ACK have an invalid + // timestamp in it, as it will be in the future. + callbacks_.AdvanceTime(DurationMs(-100)); + + handler_.HandleHeartbeatAck(std::move(ack)); +} + +TEST_F(HeartbeatHandlerTest, IncreasesErrorIfNotAckedInTime) { + DurationMs rto(105); + EXPECT_CALL(context_, current_rto).WillOnce(Return(rto)); + AdvanceTime(options_.heartbeat_interval); + + // Validate that a request was sent. + EXPECT_THAT(callbacks_.ConsumeSentPacket(), Not(IsEmpty())); + + EXPECT_CALL(context_, IncrementTxErrorCounter).Times(1); + AdvanceTime(rto); +} + +TEST_F(DisabledHeartbeatHandlerTest, IsReallyDisabled) { + AdvanceTime(options_.heartbeat_interval); + + // Validate that a request was NOT sent. + EXPECT_THAT(callbacks_.ConsumeSentPacket(), IsEmpty()); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/mock_context.h b/third_party/libwebrtc/net/dcsctp/socket/mock_context.h new file mode 100644 index 0000000000..88e71d1b35 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/mock_context.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_ +#define NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_ + +#include <cstdint> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockContext : public Context { + public: + static constexpr TSN MyInitialTsn() { return TSN(990); } + static constexpr TSN PeerInitialTsn() { return TSN(10); } + static constexpr VerificationTag PeerVerificationTag() { + return VerificationTag(0x01234567); + } + + explicit MockContext(MockDcSctpSocketCallbacks* callbacks) + : callbacks_(*callbacks) { + ON_CALL(*this, is_connection_established) + .WillByDefault(testing::Return(true)); + ON_CALL(*this, my_initial_tsn) + .WillByDefault(testing::Return(MyInitialTsn())); + ON_CALL(*this, peer_initial_tsn) + .WillByDefault(testing::Return(PeerInitialTsn())); + ON_CALL(*this, callbacks).WillByDefault(testing::ReturnRef(callbacks_)); + ON_CALL(*this, current_rto).WillByDefault(testing::Return(DurationMs(123))); + ON_CALL(*this, Send).WillByDefault([this](SctpPacket::Builder& builder) { + callbacks_.SendPacketWithStatus(builder.Build()); + }); + } + + MOCK_METHOD(bool, is_connection_established, (), (const, override)); + MOCK_METHOD(TSN, my_initial_tsn, (), (const, override)); + MOCK_METHOD(TSN, peer_initial_tsn, (), (const, override)); + MOCK_METHOD(DcSctpSocketCallbacks&, callbacks, (), (const, override)); + + MOCK_METHOD(void, ObserveRTT, (DurationMs rtt_ms), (override)); + MOCK_METHOD(DurationMs, current_rto, (), (const, override)); + MOCK_METHOD(bool, + IncrementTxErrorCounter, + (absl::string_view reason), + (override)); + MOCK_METHOD(void, ClearTxErrorCounter, (), (override)); + MOCK_METHOD(bool, HasTooManyTxErrors, (), (const, override)); + SctpPacket::Builder PacketBuilder() const override { + return SctpPacket::Builder(PeerVerificationTag(), options_); + } + MOCK_METHOD(void, Send, (SctpPacket::Builder & builder), (override)); + + DcSctpOptions options_; + MockDcSctpSocketCallbacks& callbacks_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h new file mode 100644 index 0000000000..150c1b9fa5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_ +#define NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_ + +#include <cstdint> +#include <deque> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/timeout.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "rtc_base/logging.h" +#include "rtc_base/random.h" +#include "test/gmock.h" + +namespace dcsctp { + +namespace internal { +// It can be argued if a mocked random number generator should be deterministic +// or if it should be have as a "real" random number generator. In this +// implementation, each instantiation of `MockDcSctpSocketCallbacks` will have +// their `GetRandomInt` return different sequences, but each instantiation will +// always generate the same sequence of random numbers. This to make it easier +// to compare logs from tests, but still to let e.g. two different sockets (used +// in the same test) get different random numbers, so that they don't start e.g. +// on the same sequence number. While that isn't an issue in the protocol, it +// just makes debugging harder as the two sockets would look exactly the same. +// +// In a real implementation of `DcSctpSocketCallbacks` the random number +// generator backing `GetRandomInt` should be seeded externally and correctly. +inline int GetUniqueSeed() { + static int seed = 0; + return ++seed; +} +} // namespace internal + +class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks { + public: + explicit MockDcSctpSocketCallbacks(absl::string_view name = "") + : log_prefix_(name.empty() ? "" : std::string(name) + ": "), + random_(internal::GetUniqueSeed()), + timeout_manager_([this]() { return now_; }) { + ON_CALL(*this, SendPacketWithStatus) + .WillByDefault([this](rtc::ArrayView<const uint8_t> data) { + sent_packets_.emplace_back( + std::vector<uint8_t>(data.begin(), data.end())); + return SendPacketStatus::kSuccess; + }); + ON_CALL(*this, OnMessageReceived) + .WillByDefault([this](DcSctpMessage message) { + received_messages_.emplace_back(std::move(message)); + }); + + ON_CALL(*this, OnError) + .WillByDefault([this](ErrorKind error, absl::string_view message) { + RTC_LOG(LS_WARNING) + << log_prefix_ << "Socket error: " << ToString(error) << "; " + << message; + }); + ON_CALL(*this, OnAborted) + .WillByDefault([this](ErrorKind error, absl::string_view message) { + RTC_LOG(LS_WARNING) + << log_prefix_ << "Socket abort: " << ToString(error) << "; " + << message; + }); + ON_CALL(*this, TimeMillis).WillByDefault([this]() { return now_; }); + } + + MOCK_METHOD(SendPacketStatus, + SendPacketWithStatus, + (rtc::ArrayView<const uint8_t> data), + (override)); + + std::unique_ptr<Timeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) override { + // The fake timeout manager does not implement |precision|. + return timeout_manager_.CreateTimeout(); + } + + MOCK_METHOD(TimeMs, TimeMillis, (), (override)); + uint32_t GetRandomInt(uint32_t low, uint32_t high) override { + return random_.Rand(low, high); + } + + MOCK_METHOD(void, OnMessageReceived, (DcSctpMessage message), (override)); + MOCK_METHOD(void, + OnError, + (ErrorKind error, absl::string_view message), + (override)); + MOCK_METHOD(void, + OnAborted, + (ErrorKind error, absl::string_view message), + (override)); + MOCK_METHOD(void, OnConnected, (), (override)); + MOCK_METHOD(void, OnClosed, (), (override)); + MOCK_METHOD(void, OnConnectionRestarted, (), (override)); + MOCK_METHOD(void, + OnStreamsResetFailed, + (rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason), + (override)); + MOCK_METHOD(void, + OnStreamsResetPerformed, + (rtc::ArrayView<const StreamID> outgoing_streams), + (override)); + MOCK_METHOD(void, + OnIncomingStreamsReset, + (rtc::ArrayView<const StreamID> incoming_streams), + (override)); + MOCK_METHOD(void, OnBufferedAmountLow, (StreamID stream_id), (override)); + MOCK_METHOD(void, OnTotalBufferedAmountLow, (), (override)); + MOCK_METHOD(void, + OnLifecycleMessageExpired, + (LifecycleId lifecycle_id, bool maybe_delivered), + (override)); + MOCK_METHOD(void, + OnLifecycleMessageFullySent, + (LifecycleId lifecycle_id), + (override)); + MOCK_METHOD(void, + OnLifecycleMessageDelivered, + (LifecycleId lifecycle_id), + (override)); + MOCK_METHOD(void, OnLifecycleEnd, (LifecycleId lifecycle_id), (override)); + + bool HasPacket() const { return !sent_packets_.empty(); } + + std::vector<uint8_t> ConsumeSentPacket() { + if (sent_packets_.empty()) { + return {}; + } + std::vector<uint8_t> ret = std::move(sent_packets_.front()); + sent_packets_.pop_front(); + return ret; + } + absl::optional<DcSctpMessage> ConsumeReceivedMessage() { + if (received_messages_.empty()) { + return absl::nullopt; + } + DcSctpMessage ret = std::move(received_messages_.front()); + received_messages_.pop_front(); + return ret; + } + + void AdvanceTime(DurationMs duration_ms) { now_ = now_ + duration_ms; } + void SetTime(TimeMs now) { now_ = now; } + + absl::optional<TimeoutID> GetNextExpiredTimeout() { + return timeout_manager_.GetNextExpiredTimeout(); + } + + DurationMs GetTimeToNextTimeout() const { + return timeout_manager_.GetTimeToNextTimeout(); + } + + private: + const std::string log_prefix_; + TimeMs now_ = TimeMs(0); + webrtc::Random random_; + FakeTimeoutManager timeout_manager_; + std::deque<std::vector<uint8_t>> sent_packets_; + std::deque<DcSctpMessage> received_messages_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc new file mode 100644 index 0000000000..f0134eea9b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/packet_sender.h" + +#include <utility> +#include <vector> + +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +PacketSender::PacketSender(DcSctpSocketCallbacks& callbacks, + std::function<void(rtc::ArrayView<const uint8_t>, + SendPacketStatus)> on_sent_packet) + : callbacks_(callbacks), on_sent_packet_(std::move(on_sent_packet)) {} + +bool PacketSender::Send(SctpPacket::Builder& builder, bool write_checksum) { + if (builder.empty()) { + return false; + } + + std::vector<uint8_t> payload = builder.Build(write_checksum); + + SendPacketStatus status = callbacks_.SendPacketWithStatus(payload); + on_sent_packet_(payload, status); + switch (status) { + case SendPacketStatus::kSuccess: { + return true; + } + case SendPacketStatus::kTemporaryFailure: { + // TODO(boivie): Queue this packet to be retried to be sent later. + return false; + } + + case SendPacketStatus::kError: { + // Nothing that can be done. + return false; + } + } +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/packet_sender.h b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.h new file mode 100644 index 0000000000..395c2efcba --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_PACKET_SENDER_H_ +#define NET_DCSCTP_SOCKET_PACKET_SENDER_H_ + +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" + +namespace dcsctp { + +// The PacketSender sends packets to the network using the provided callback +// interface. When an attempt to send a packet is made, the `on_sent_packet` +// callback will be triggered. +class PacketSender { + public: + PacketSender(DcSctpSocketCallbacks& callbacks, + std::function<void(rtc::ArrayView<const uint8_t>, + SendPacketStatus)> on_sent_packet); + + // Sends the packet, and returns true if it was sent successfully. + bool Send(SctpPacket::Builder& builder, bool write_checksum = true); + + private: + DcSctpSocketCallbacks& callbacks_; + + // Callback that will be triggered for every send attempt, indicating the + // status of the operation. + std::function<void(rtc::ArrayView<const uint8_t>, SendPacketStatus)> + on_sent_packet_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_PACKET_SENDER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc b/third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc new file mode 100644 index 0000000000..079dc36a41 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/packet_sender.h" + +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::_; + +constexpr VerificationTag kVerificationTag(123); + +class PacketSenderTest : public testing::Test { + protected: + PacketSenderTest() : sender_(callbacks_, on_send_fn_.AsStdFunction()) {} + + SctpPacket::Builder PacketBuilder() const { + return SctpPacket::Builder(kVerificationTag, options_); + } + + DcSctpOptions options_; + testing::NiceMock<MockDcSctpSocketCallbacks> callbacks_; + testing::MockFunction<void(rtc::ArrayView<const uint8_t>, SendPacketStatus)> + on_send_fn_; + PacketSender sender_; +}; + +TEST_F(PacketSenderTest, SendPacketCallsCallback) { + EXPECT_CALL(on_send_fn_, Call(_, SendPacketStatus::kSuccess)); + EXPECT_TRUE(sender_.Send(PacketBuilder().Add(CookieAckChunk()))); + + EXPECT_CALL(callbacks_, SendPacketWithStatus) + .WillOnce(testing::Return(SendPacketStatus::kError)); + EXPECT_CALL(on_send_fn_, Call(_, SendPacketStatus::kError)); + EXPECT_FALSE(sender_.Send(PacketBuilder().Add(CookieAckChunk()))); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc new file mode 100644 index 0000000000..624d783a3b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/state_cookie.h" + +#include <cstdint> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/socket/capabilities.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +// Magic values, which the state cookie is prefixed with. +constexpr uint32_t kMagic1 = 1684230979; +constexpr uint32_t kMagic2 = 1414541360; +constexpr size_t StateCookie::kCookieSize; + +std::vector<uint8_t> StateCookie::Serialize() { + std::vector<uint8_t> cookie; + cookie.resize(kCookieSize); + BoundedByteWriter<kCookieSize> buffer(cookie); + buffer.Store32<0>(kMagic1); + buffer.Store32<4>(kMagic2); + buffer.Store32<8>(*initiate_tag_); + buffer.Store32<12>(*initial_tsn_); + buffer.Store32<16>(a_rwnd_); + buffer.Store32<20>(static_cast<uint32_t>(*tie_tag_ >> 32)); + buffer.Store32<24>(static_cast<uint32_t>(*tie_tag_)); + buffer.Store8<28>(capabilities_.partial_reliability); + buffer.Store8<29>(capabilities_.message_interleaving); + buffer.Store8<30>(capabilities_.reconfig); + buffer.Store16<32>(capabilities_.negotiated_maximum_incoming_streams); + buffer.Store16<34>(capabilities_.negotiated_maximum_outgoing_streams); + buffer.Store8<36>(capabilities_.zero_checksum); + return cookie; +} + +absl::optional<StateCookie> StateCookie::Deserialize( + rtc::ArrayView<const uint8_t> cookie) { + if (cookie.size() != kCookieSize) { + RTC_DLOG(LS_WARNING) << "Invalid state cookie: " << cookie.size() + << " bytes"; + return absl::nullopt; + } + + BoundedByteReader<kCookieSize> buffer(cookie); + uint32_t magic1 = buffer.Load32<0>(); + uint32_t magic2 = buffer.Load32<4>(); + if (magic1 != kMagic1 || magic2 != kMagic2) { + RTC_DLOG(LS_WARNING) << "Invalid state cookie; wrong magic"; + return absl::nullopt; + } + + VerificationTag verification_tag(buffer.Load32<8>()); + TSN initial_tsn(buffer.Load32<12>()); + uint32_t a_rwnd = buffer.Load32<16>(); + uint32_t tie_tag_upper = buffer.Load32<20>(); + uint32_t tie_tag_lower = buffer.Load32<24>(); + TieTag tie_tag(static_cast<uint64_t>(tie_tag_upper) << 32 | + static_cast<uint64_t>(tie_tag_lower)); + Capabilities capabilities; + capabilities.partial_reliability = buffer.Load8<28>() != 0; + capabilities.message_interleaving = buffer.Load8<29>() != 0; + capabilities.reconfig = buffer.Load8<30>() != 0; + capabilities.negotiated_maximum_incoming_streams = buffer.Load16<32>(); + capabilities.negotiated_maximum_outgoing_streams = buffer.Load16<34>(); + capabilities.zero_checksum = buffer.Load8<36>() != 0; + + return StateCookie(verification_tag, initial_tsn, a_rwnd, tie_tag, + capabilities); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h new file mode 100644 index 0000000000..34cd6d3690 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_STATE_COOKIE_H_ +#define NET_DCSCTP_SOCKET_STATE_COOKIE_H_ + +#include <cstdint> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/socket/capabilities.h" + +namespace dcsctp { + +// This is serialized as a state cookie and put in INIT_ACK. The client then +// responds with this in COOKIE_ECHO. +// +// NOTE: Expect that the client will modify it to try to exploit the library. +// Do not trust anything in it; no pointers or anything like that. +class StateCookie { + public: + static constexpr size_t kCookieSize = 37; + + StateCookie(VerificationTag initiate_tag, + TSN initial_tsn, + uint32_t a_rwnd, + TieTag tie_tag, + Capabilities capabilities) + : initiate_tag_(initiate_tag), + initial_tsn_(initial_tsn), + a_rwnd_(a_rwnd), + tie_tag_(tie_tag), + capabilities_(capabilities) {} + + // Returns a serialized version of this cookie. + std::vector<uint8_t> Serialize(); + + // Deserializes the cookie, and returns absl::nullopt if that failed. + static absl::optional<StateCookie> Deserialize( + rtc::ArrayView<const uint8_t> cookie); + + VerificationTag initiate_tag() const { return initiate_tag_; } + TSN initial_tsn() const { return initial_tsn_; } + uint32_t a_rwnd() const { return a_rwnd_; } + TieTag tie_tag() const { return tie_tag_; } + const Capabilities& capabilities() const { return capabilities_; } + + private: + const VerificationTag initiate_tag_; + const TSN initial_tsn_; + const uint32_t a_rwnd_; + const TieTag tie_tag_; + const Capabilities capabilities_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_STATE_COOKIE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc new file mode 100644 index 0000000000..19be71a1ca --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/state_cookie.h" + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +TEST(StateCookieTest, SerializeAndDeserialize) { + Capabilities capabilities = {.partial_reliability = true, + .message_interleaving = false, + .reconfig = true, + .zero_checksum = true, + .negotiated_maximum_incoming_streams = 123, + .negotiated_maximum_outgoing_streams = 234}; + StateCookie cookie(VerificationTag(123), TSN(456), + /*a_rwnd=*/789, TieTag(101112), capabilities); + std::vector<uint8_t> serialized = cookie.Serialize(); + EXPECT_THAT(serialized, SizeIs(StateCookie::kCookieSize)); + ASSERT_HAS_VALUE_AND_ASSIGN(StateCookie deserialized, + StateCookie::Deserialize(serialized)); + EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123)); + EXPECT_EQ(deserialized.initial_tsn(), TSN(456)); + EXPECT_EQ(deserialized.a_rwnd(), 789u); + EXPECT_EQ(deserialized.tie_tag(), TieTag(101112)); + EXPECT_TRUE(deserialized.capabilities().partial_reliability); + EXPECT_FALSE(deserialized.capabilities().message_interleaving); + EXPECT_TRUE(deserialized.capabilities().reconfig); + EXPECT_TRUE(deserialized.capabilities().zero_checksum); + EXPECT_EQ(deserialized.capabilities().negotiated_maximum_incoming_streams, + 123); + EXPECT_EQ(deserialized.capabilities().negotiated_maximum_outgoing_streams, + 234); +} + +TEST(StateCookieTest, ValidateMagicValue) { + Capabilities capabilities = {.partial_reliability = true, + .message_interleaving = false, + .reconfig = true}; + StateCookie cookie(VerificationTag(123), TSN(456), + /*a_rwnd=*/789, TieTag(101112), capabilities); + std::vector<uint8_t> serialized = cookie.Serialize(); + ASSERT_THAT(serialized, SizeIs(StateCookie::kCookieSize)); + + absl::string_view magic(reinterpret_cast<const char*>(serialized.data()), 8); + EXPECT_EQ(magic, "dcSCTP00"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc new file mode 100644 index 0000000000..2094309afe --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc @@ -0,0 +1,385 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/stream_reset_handler.h" + +#include <cstdint> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace { +using ResponseResult = ReconfigurationResponseParameter::Result; + +bool DescriptorsAre(const std::vector<ParameterDescriptor>& c, + uint16_t e1, + uint16_t e2) { + return (c[0].type == e1 && c[1].type == e2) || + (c[0].type == e2 && c[1].type == e1); +} + +} // namespace + +bool StreamResetHandler::Validate(const ReConfigChunk& chunk) { + const Parameters& parameters = chunk.parameters(); + + // https://tools.ietf.org/html/rfc6525#section-3.1 + // "Note that each RE-CONFIG chunk holds at least one parameter + // and at most two parameters. Only the following combinations are allowed:" + std::vector<ParameterDescriptor> descriptors = parameters.descriptors(); + if (descriptors.size() == 1) { + if ((descriptors[0].type == OutgoingSSNResetRequestParameter::kType) || + (descriptors[0].type == IncomingSSNResetRequestParameter::kType) || + (descriptors[0].type == SSNTSNResetRequestParameter::kType) || + (descriptors[0].type == AddOutgoingStreamsRequestParameter::kType) || + (descriptors[0].type == AddIncomingStreamsRequestParameter::kType) || + (descriptors[0].type == ReconfigurationResponseParameter::kType)) { + return true; + } + } else if (descriptors.size() == 2) { + if (DescriptorsAre(descriptors, OutgoingSSNResetRequestParameter::kType, + IncomingSSNResetRequestParameter::kType) || + DescriptorsAre(descriptors, AddOutgoingStreamsRequestParameter::kType, + AddIncomingStreamsRequestParameter::kType) || + DescriptorsAre(descriptors, ReconfigurationResponseParameter::kType, + OutgoingSSNResetRequestParameter::kType) || + DescriptorsAre(descriptors, ReconfigurationResponseParameter::kType, + ReconfigurationResponseParameter::kType)) { + return true; + } + } + + RTC_LOG(LS_WARNING) << "Invalid set of RE-CONFIG parameters"; + return false; +} + +absl::optional<std::vector<ReconfigurationResponseParameter>> +StreamResetHandler::Process(const ReConfigChunk& chunk) { + if (!Validate(chunk)) { + return absl::nullopt; + } + + std::vector<ReconfigurationResponseParameter> responses; + + for (const ParameterDescriptor& desc : chunk.parameters().descriptors()) { + switch (desc.type) { + case OutgoingSSNResetRequestParameter::kType: + HandleResetOutgoing(desc, responses); + break; + + case IncomingSSNResetRequestParameter::kType: + HandleResetIncoming(desc, responses); + break; + + case ReconfigurationResponseParameter::kType: + HandleResponse(desc); + break; + } + } + + return responses; +} + +void StreamResetHandler::HandleReConfig(ReConfigChunk chunk) { + absl::optional<std::vector<ReconfigurationResponseParameter>> responses = + Process(chunk); + + if (!responses.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse RE-CONFIG command"); + return; + } + + if (!responses->empty()) { + SctpPacket::Builder b = ctx_->PacketBuilder(); + Parameters::Builder params_builder; + for (const auto& response : *responses) { + params_builder.Add(response); + } + b.Add(ReConfigChunk(params_builder.Build())); + ctx_->Send(b); + } +} + +bool StreamResetHandler::ValidateReqSeqNbr( + UnwrappedReconfigRequestSn req_seq_nbr, + std::vector<ReconfigurationResponseParameter>& responses) { + if (req_seq_nbr == last_processed_req_seq_nbr_) { + // https://www.rfc-editor.org/rfc/rfc6525.html#section-5.2.1 "If the + // received RE-CONFIG chunk contains at least one request and based on the + // analysis of the Re-configuration Request Sequence Numbers this is the + // last received RE-CONFIG chunk (i.e., a retransmission), the same + // RE-CONFIG chunk MUST to be sent back in response, as it was earlier." + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "req=" << *req_seq_nbr + << " already processed, returning result=" + << ToString(last_processed_req_result_); + responses.push_back(ReconfigurationResponseParameter( + req_seq_nbr.Wrap(), last_processed_req_result_)); + return false; + } + + if (req_seq_nbr != last_processed_req_seq_nbr_.next_value()) { + // Too old, too new, from wrong association etc. + // This is expected to happen when handing over a RTCPeerConnection from one + // server to another. The client will notice this and may decide to close + // old data channels, which may be sent to the wrong (or both) servers + // during a handover. + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "req=" << *req_seq_nbr + << " bad seq_nbr"; + responses.push_back(ReconfigurationResponseParameter( + req_seq_nbr.Wrap(), ResponseResult::kErrorBadSequenceNumber)); + return false; + } + + return true; +} + +void StreamResetHandler::HandleResetOutgoing( + const ParameterDescriptor& descriptor, + std::vector<ReconfigurationResponseParameter>& responses) { + absl::optional<OutgoingSSNResetRequestParameter> req = + OutgoingSSNResetRequestParameter::Parse(descriptor.data); + if (!req.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse Outgoing Reset command"); + return; + } + + UnwrappedReconfigRequestSn request_sn = + incoming_reconfig_request_sn_unwrapper_.Unwrap( + req->request_sequence_number()); + + if (ValidateReqSeqNbr(request_sn, responses)) { + last_processed_req_seq_nbr_ = request_sn; + if (data_tracker_->IsLaterThanCumulativeAckedTsn( + req->sender_last_assigned_tsn())) { + // https://datatracker.ietf.org/doc/html/rfc6525#section-5.2.2 + // E2) "If the Sender's Last Assigned TSN is greater than the cumulative + // acknowledgment point, then the endpoint MUST enter 'deferred reset + // processing'." + reassembly_queue_->EnterDeferredReset(req->sender_last_assigned_tsn(), + req->stream_ids()); + // "If the endpoint enters 'deferred reset processing', it MUST put a + // Re-configuration Response Parameter into a RE-CONFIG chunk indicating + // 'In progress' and MUST send the RE-CONFIG chunk. + last_processed_req_result_ = ResponseResult::kInProgress; + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Reset outgoing; Sender last_assigned=" + << *req->sender_last_assigned_tsn() + << " - not yet reached -> InProgress"; + } else { + // https://datatracker.ietf.org/doc/html/rfc6525#section-5.2.2 + // E3) If no stream numbers are listed in the parameter, then all incoming + // streams MUST be reset to 0 as the next expected SSN. If specific stream + // numbers are listed, then only these specific streams MUST be reset to + // 0, and all other non-listed SSNs remain unchanged. E4: Any queued TSNs + // (queued at step E2) MUST now be released and processed normally. + reassembly_queue_->ResetStreamsAndLeaveDeferredReset(req->stream_ids()); + ctx_->callbacks().OnIncomingStreamsReset(req->stream_ids()); + last_processed_req_result_ = ResponseResult::kSuccessPerformed; + + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Reset outgoing; Sender last_assigned=" + << *req->sender_last_assigned_tsn() + << " - reached -> SuccessPerformed"; + } + responses.push_back(ReconfigurationResponseParameter( + req->request_sequence_number(), last_processed_req_result_)); + } +} + +void StreamResetHandler::HandleResetIncoming( + const ParameterDescriptor& descriptor, + std::vector<ReconfigurationResponseParameter>& responses) { + absl::optional<IncomingSSNResetRequestParameter> req = + IncomingSSNResetRequestParameter::Parse(descriptor.data); + if (!req.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse Incoming Reset command"); + return; + } + + UnwrappedReconfigRequestSn request_sn = + incoming_reconfig_request_sn_unwrapper_.Unwrap( + req->request_sequence_number()); + + if (ValidateReqSeqNbr(request_sn, responses)) { + responses.push_back(ReconfigurationResponseParameter( + req->request_sequence_number(), ResponseResult::kSuccessNothingToDo)); + last_processed_req_seq_nbr_ = request_sn; + } +} + +void StreamResetHandler::HandleResponse(const ParameterDescriptor& descriptor) { + absl::optional<ReconfigurationResponseParameter> resp = + ReconfigurationResponseParameter::Parse(descriptor.data); + if (!resp.has_value()) { + ctx_->callbacks().OnError( + ErrorKind::kParseFailed, + "Failed to parse Reconfiguration Response command"); + return; + } + + if (current_request_.has_value() && current_request_->has_been_sent() && + resp->response_sequence_number() == current_request_->req_seq_nbr()) { + reconfig_timer_->Stop(); + + switch (resp->result()) { + case ResponseResult::kSuccessNothingToDo: + case ResponseResult::kSuccessPerformed: + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Reset stream success, req_seq_nbr=" + << *current_request_->req_seq_nbr() << ", streams=" + << StrJoin(current_request_->streams(), ",", + [](rtc::StringBuilder& sb, StreamID stream_id) { + sb << *stream_id; + }); + ctx_->callbacks().OnStreamsResetPerformed(current_request_->streams()); + current_request_ = absl::nullopt; + retransmission_queue_->CommitResetStreams(); + break; + case ResponseResult::kInProgress: + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Reset stream still pending, req_seq_nbr=" + << *current_request_->req_seq_nbr() << ", streams=" + << StrJoin(current_request_->streams(), ",", + [](rtc::StringBuilder& sb, StreamID stream_id) { + sb << *stream_id; + }); + // Force this request to be sent again, but with new req_seq_nbr. + current_request_->PrepareRetransmission(); + reconfig_timer_->set_duration(ctx_->current_rto()); + reconfig_timer_->Start(); + break; + case ResponseResult::kErrorRequestAlreadyInProgress: + case ResponseResult::kDenied: + case ResponseResult::kErrorWrongSSN: + case ResponseResult::kErrorBadSequenceNumber: + RTC_DLOG(LS_WARNING) + << log_prefix_ << "Reset stream error=" << ToString(resp->result()) + << ", req_seq_nbr=" << *current_request_->req_seq_nbr() + << ", streams=" + << StrJoin(current_request_->streams(), ",", + [](rtc::StringBuilder& sb, StreamID stream_id) { + sb << *stream_id; + }); + ctx_->callbacks().OnStreamsResetFailed(current_request_->streams(), + ToString(resp->result())); + current_request_ = absl::nullopt; + retransmission_queue_->RollbackResetStreams(); + break; + } + } +} + +absl::optional<ReConfigChunk> StreamResetHandler::MakeStreamResetRequest() { + // Only send stream resets if there are streams to reset, and no current + // ongoing request (there can only be one at a time), and if the stream + // can be reset. + if (current_request_.has_value() || + !retransmission_queue_->HasStreamsReadyToBeReset()) { + return absl::nullopt; + } + + current_request_.emplace(retransmission_queue_->last_assigned_tsn(), + retransmission_queue_->BeginResetStreams()); + reconfig_timer_->set_duration(ctx_->current_rto()); + reconfig_timer_->Start(); + return MakeReconfigChunk(); +} + +ReConfigChunk StreamResetHandler::MakeReconfigChunk() { + // The req_seq_nbr will be empty if the request has never been sent before, + // or if it was sent, but the sender responded "in progress", and then the + // req_seq_nbr will be cleared to re-send with a new number. But if the + // request is re-sent due to timeout (reconfig-timer expiring), the same + // req_seq_nbr will be used. + RTC_DCHECK(current_request_.has_value()); + + if (!current_request_->has_been_sent()) { + current_request_->PrepareToSend(next_outgoing_req_seq_nbr_); + next_outgoing_req_seq_nbr_ = + ReconfigRequestSN(*next_outgoing_req_seq_nbr_ + 1); + } + + Parameters::Builder params_builder = + Parameters::Builder().Add(OutgoingSSNResetRequestParameter( + current_request_->req_seq_nbr(), current_request_->req_seq_nbr(), + current_request_->sender_last_assigned_tsn(), + current_request_->streams())); + + return ReConfigChunk(params_builder.Build()); +} + +void StreamResetHandler::ResetStreams( + rtc::ArrayView<const StreamID> outgoing_streams) { + for (StreamID stream_id : outgoing_streams) { + retransmission_queue_->PrepareResetStream(stream_id); + } +} + +absl::optional<DurationMs> StreamResetHandler::OnReconfigTimerExpiry() { + if (current_request_->has_been_sent()) { + // There is an outstanding request, which timed out while waiting for a + // response. + if (!ctx_->IncrementTxErrorCounter("RECONFIG timeout")) { + // Timed out. The connection will close after processing the timers. + return absl::nullopt; + } + } else { + // There is no outstanding request, but there is a prepared one. This means + // that the receiver has previously responded "in progress", which resulted + // in retrying the request (but with a new req_seq_nbr) after a while. + } + + ctx_->Send(ctx_->PacketBuilder().Add(MakeReconfigChunk())); + return ctx_->current_rto(); +} + +HandoverReadinessStatus StreamResetHandler::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (retransmission_queue_->HasStreamsReadyToBeReset()) { + status.Add(HandoverUnreadinessReason::kPendingStreamReset); + } + if (current_request_.has_value()) { + status.Add(HandoverUnreadinessReason::kPendingStreamResetRequest); + } + return status; +} + +void StreamResetHandler::AddHandoverState(DcSctpSocketHandoverState& state) { + state.rx.last_completed_reset_req_sn = + last_processed_req_seq_nbr_.Wrap().value(); + state.tx.next_reset_req_sn = next_outgoing_req_seq_nbr_.value(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h new file mode 100644 index 0000000000..c335130175 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_ +#define NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "rtc_base/containers/flat_set.h" + +namespace dcsctp { + +// StreamResetHandler handles sending outgoing stream reset requests (to close +// an SCTP stream, which translates to closing a data channel). +// +// It also handles incoming "outgoing stream reset requests", when the peer +// wants to close its data channel. +// +// Resetting streams is an asynchronous operation where the client will request +// a request a stream to be reset, but then it might not be performed exactly at +// this point. First, the sender might need to discard all messages that have +// been enqueued for this stream, or it may select to wait until all have been +// sent. At least, it must wait for the currently sending fragmented message to +// be fully sent, because a stream can't be reset while having received half a +// message. In the stream reset request, the "sender's last assigned TSN" is +// provided, which is simply the TSN for which the receiver should've received +// all messages before this value, before the stream can be reset. Since +// fragments can get lost or sent out-of-order, the receiver of a request may +// not have received all the data just yet, and then it will respond to the +// sender: "In progress". In other words, try again. The sender will then need +// to start a timer and try the very same request again (but with a new sequence +// number) until the receiver successfully performs the operation. +// +// All this can take some time, and may be driven by timers, so the client will +// ultimately be notified using callbacks. +// +// In this implementation, when a stream is reset, the queued but not-yet-sent +// messages will be discarded, but that may change in the future. RFC8831 allows +// both behaviors. +class StreamResetHandler { + public: + StreamResetHandler(absl::string_view log_prefix, + Context* context, + TimerManager* timer_manager, + DataTracker* data_tracker, + ReassemblyQueue* reassembly_queue, + RetransmissionQueue* retransmission_queue, + const DcSctpSocketHandoverState* handover_state = nullptr) + : log_prefix_(log_prefix), + ctx_(context), + data_tracker_(data_tracker), + reassembly_queue_(reassembly_queue), + retransmission_queue_(retransmission_queue), + reconfig_timer_(timer_manager->CreateTimer( + "re-config", + absl::bind_front(&StreamResetHandler::OnReconfigTimerExpiry, this), + TimerOptions(DurationMs(0)))), + next_outgoing_req_seq_nbr_( + handover_state + ? ReconfigRequestSN(handover_state->tx.next_reset_req_sn) + : ReconfigRequestSN(*ctx_->my_initial_tsn())), + last_processed_req_seq_nbr_( + incoming_reconfig_request_sn_unwrapper_.Unwrap( + handover_state + ? ReconfigRequestSN( + handover_state->rx.last_completed_reset_req_sn) + : ReconfigRequestSN(*ctx_->peer_initial_tsn() - 1))), + last_processed_req_result_( + ReconfigurationResponseParameter::Result::kSuccessNothingToDo) {} + + // Initiates reset of the provided streams. While there can only be one + // ongoing stream reset request at any time, this method can be called at any + // time and also multiple times. It will enqueue requests that can't be + // directly fulfilled, and will asynchronously process them when any ongoing + // request has completed. + void ResetStreams(rtc::ArrayView<const StreamID> outgoing_streams); + + // Creates a Reset Streams request that must be sent if returned. Will start + // the reconfig timer. Will return absl::nullopt if there is no need to + // create a request (no streams to reset) or if there already is an ongoing + // stream reset request that hasn't completed yet. + absl::optional<ReConfigChunk> MakeStreamResetRequest(); + + // Called when handling and incoming RE-CONFIG chunk. + void HandleReConfig(ReConfigChunk chunk); + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + + private: + using UnwrappedReconfigRequestSn = UnwrappedSequenceNumber<ReconfigRequestSN>; + // Represents a stream request operation. There can only be one ongoing at + // any time, and a sent request may either succeed, fail or result in the + // receiver signaling that it can't process it right now, and then it will be + // retried. + class CurrentRequest { + public: + CurrentRequest(TSN sender_last_assigned_tsn, std::vector<StreamID> streams) + : req_seq_nbr_(absl::nullopt), + sender_last_assigned_tsn_(sender_last_assigned_tsn), + streams_(std::move(streams)) {} + + // Returns the current request sequence number, if this request has been + // sent (check `has_been_sent` first). Will return 0 if the request is just + // prepared (or scheduled for retransmission) but not yet sent. + ReconfigRequestSN req_seq_nbr() const { + return req_seq_nbr_.value_or(ReconfigRequestSN(0)); + } + + // The sender's last assigned TSN, from the retransmission queue. The + // receiver uses this to know when all data up to this TSN has been + // received, to know when to safely reset the stream. + TSN sender_last_assigned_tsn() const { return sender_last_assigned_tsn_; } + + // The streams that are to be reset. + const std::vector<StreamID>& streams() const { return streams_; } + + // If this request has been sent yet. If not, then it's either because it + // has only been prepared and not yet sent, or because the received couldn't + // apply the request, and then the exact same request will be retried, but + // with a new sequence number. + bool has_been_sent() const { return req_seq_nbr_.has_value(); } + + // If the receiver can't apply the request yet (and answered "In Progress"), + // this will be called to prepare the request to be retransmitted at a later + // time. + void PrepareRetransmission() { req_seq_nbr_ = absl::nullopt; } + + // If the request hasn't been sent yet, this assigns it a request number. + void PrepareToSend(ReconfigRequestSN new_req_seq_nbr) { + req_seq_nbr_ = new_req_seq_nbr; + } + + private: + // If this is set, this request has been sent. If it's not set, the request + // has been prepared, but has not yet been sent. This is typically used when + // the peer responded "in progress" and the same request (but a different + // request number) must be sent again. + absl::optional<ReconfigRequestSN> req_seq_nbr_; + // The sender's (that's us) last assigned TSN, from the retransmission + // queue. + TSN sender_last_assigned_tsn_; + // The streams that are to be reset in this request. + const std::vector<StreamID> streams_; + }; + + // Called to validate an incoming RE-CONFIG chunk. + bool Validate(const ReConfigChunk& chunk); + + // Processes a stream stream reconfiguration chunk and may either return + // absl::nullopt (on protocol errors), or a list of responses - either 0, 1 + // or 2. + absl::optional<std::vector<ReconfigurationResponseParameter>> Process( + const ReConfigChunk& chunk); + + // Creates the actual RE-CONFIG chunk. A request (which set `current_request`) + // must have been created prior. + ReConfigChunk MakeReconfigChunk(); + + // Called to validate the `req_seq_nbr`, that it's the next in sequence. If it + // fails to validate, and returns false, it will also add a response to + // `responses`. + bool ValidateReqSeqNbr( + UnwrappedReconfigRequestSn req_seq_nbr, + std::vector<ReconfigurationResponseParameter>& responses); + + // Called when this socket receives an outgoing stream reset request. It might + // either be performed straight away, or have to be deferred, and the result + // of that will be put in `responses`. + void HandleResetOutgoing( + const ParameterDescriptor& descriptor, + std::vector<ReconfigurationResponseParameter>& responses); + + // Called when this socket receives an incoming stream reset request. This + // isn't really supported, but a successful response is put in `responses`. + void HandleResetIncoming( + const ParameterDescriptor& descriptor, + std::vector<ReconfigurationResponseParameter>& responses); + + // Called when receiving a response to an outgoing stream reset request. It + // will either commit the stream resetting, if the operation was successful, + // or will schedule a retry if it was deferred. And if it failed, the + // operation will be rolled back. + void HandleResponse(const ParameterDescriptor& descriptor); + + // Expiration handler for the Reconfig timer. + absl::optional<DurationMs> OnReconfigTimerExpiry(); + + const absl::string_view log_prefix_; + Context* ctx_; + DataTracker* data_tracker_; + ReassemblyQueue* reassembly_queue_; + RetransmissionQueue* retransmission_queue_; + UnwrappedReconfigRequestSn::Unwrapper incoming_reconfig_request_sn_unwrapper_; + const std::unique_ptr<Timer> reconfig_timer_; + + // The next sequence number for outgoing stream requests. + ReconfigRequestSN next_outgoing_req_seq_nbr_; + + // The current stream request operation. + absl::optional<CurrentRequest> current_request_; + + // For incoming requests - last processed request sequence number. + UnwrappedReconfigRequestSn last_processed_req_seq_nbr_; + // The result from last processed incoming request + ReconfigurationResponseParameter::Result last_processed_req_result_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc new file mode 100644 index 0000000000..091d717f8a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc @@ -0,0 +1,916 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/stream_reset_handler.h" + +#include <array> +#include <cstdint> +#include <memory> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/mock_context.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/data_generator.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/mock_send_queue.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::IsEmpty; +using ::testing::NiceMock; +using ::testing::Property; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; +using ResponseResult = ReconfigurationResponseParameter::Result; +using SkippedStream = AnyForwardTsnChunk::SkippedStream; + +constexpr TSN kMyInitialTsn = MockContext::MyInitialTsn(); +constexpr ReconfigRequestSN kMyInitialReqSn = ReconfigRequestSN(*kMyInitialTsn); +constexpr TSN kPeerInitialTsn = MockContext::PeerInitialTsn(); +constexpr ReconfigRequestSN kPeerInitialReqSn = + ReconfigRequestSN(*kPeerInitialTsn); +constexpr uint32_t kArwnd = 131072; +constexpr DurationMs kRto = DurationMs(250); + +constexpr std::array<uint8_t, 4> kShortPayload = {1, 2, 3, 4}; + +MATCHER_P3(SctpMessageIs, stream_id, ppid, expected_payload, "") { + if (arg.stream_id() != stream_id) { + *result_listener << "the stream_id is " << *arg.stream_id(); + return false; + } + + if (arg.ppid() != ppid) { + *result_listener << "the ppid is " << *arg.ppid(); + return false; + } + + if (std::vector<uint8_t>(arg.payload().begin(), arg.payload().end()) != + std::vector<uint8_t>(expected_payload.begin(), expected_payload.end())) { + *result_listener << "the payload is wrong"; + return false; + } + return true; +} + +TSN AddTo(TSN tsn, int delta) { + return TSN(*tsn + delta); +} + +ReconfigRequestSN AddTo(ReconfigRequestSN req_sn, int delta) { + return ReconfigRequestSN(*req_sn + delta); +} + +class StreamResetHandlerTest : public testing::Test { + protected: + StreamResetHandlerTest() + : ctx_(&callbacks_), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return callbacks_.CreateTimeout(precision); + }), + delayed_ack_timer_(timer_manager_.CreateTimer( + "test/delayed_ack", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))), + t3_rtx_timer_(timer_manager_.CreateTimer( + "test/t3_rtx", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))), + data_tracker_(std::make_unique<DataTracker>("log: ", + delayed_ack_timer_.get(), + kPeerInitialTsn)), + reasm_(std::make_unique<ReassemblyQueue>("log: ", + kPeerInitialTsn, + kArwnd)), + retransmission_queue_(std::make_unique<RetransmissionQueue>( + "", + &callbacks_, + kMyInitialTsn, + kArwnd, + producer_, + [](DurationMs rtt_ms) {}, + []() {}, + *t3_rtx_timer_, + DcSctpOptions())), + handler_( + std::make_unique<StreamResetHandler>("log: ", + &ctx_, + &timer_manager_, + data_tracker_.get(), + reasm_.get(), + retransmission_queue_.get())) { + EXPECT_CALL(ctx_, current_rto).WillRepeatedly(Return(kRto)); + } + + void AdvanceTime(DurationMs duration) { + callbacks_.AdvanceTime(kRto); + for (;;) { + absl::optional<TimeoutID> timeout_id = callbacks_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + timer_manager_.HandleTimeout(*timeout_id); + } + } + + // Handles the passed in RE-CONFIG `chunk` and returns the responses + // that are sent in the response RE-CONFIG. + std::vector<ReconfigurationResponseParameter> HandleAndCatchResponse( + ReConfigChunk chunk) { + handler_->HandleReConfig(std::move(chunk)); + + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + if (payload.empty()) { + EXPECT_TRUE(false); + return {}; + } + + std::vector<ReconfigurationResponseParameter> responses; + absl::optional<SctpPacket> p = SctpPacket::Parse(payload, DcSctpOptions()); + if (!p.has_value()) { + EXPECT_TRUE(false); + return {}; + } + if (p->descriptors().size() != 1) { + EXPECT_TRUE(false); + return {}; + } + absl::optional<ReConfigChunk> response_chunk = + ReConfigChunk::Parse(p->descriptors()[0].data); + if (!response_chunk.has_value()) { + EXPECT_TRUE(false); + return {}; + } + for (const auto& desc : response_chunk->parameters().descriptors()) { + if (desc.type == ReconfigurationResponseParameter::kType) { + absl::optional<ReconfigurationResponseParameter> response = + ReconfigurationResponseParameter::Parse(desc.data); + if (!response.has_value()) { + EXPECT_TRUE(false); + return {}; + } + responses.emplace_back(*std::move(response)); + } + } + return responses; + } + + void PerformHandover() { + EXPECT_TRUE(handler_->GetHandoverReadiness().IsReady()); + EXPECT_TRUE(data_tracker_->GetHandoverReadiness().IsReady()); + EXPECT_TRUE(reasm_->GetHandoverReadiness().IsReady()); + EXPECT_TRUE(retransmission_queue_->GetHandoverReadiness().IsReady()); + + DcSctpSocketHandoverState state; + handler_->AddHandoverState(state); + data_tracker_->AddHandoverState(state); + reasm_->AddHandoverState(state); + + retransmission_queue_->AddHandoverState(state); + + g_handover_state_transformer_for_test(&state); + + data_tracker_ = std::make_unique<DataTracker>( + "log: ", delayed_ack_timer_.get(), kPeerInitialTsn); + data_tracker_->RestoreFromState(state); + reasm_ = + std::make_unique<ReassemblyQueue>("log: ", kPeerInitialTsn, kArwnd); + reasm_->RestoreFromState(state); + retransmission_queue_ = std::make_unique<RetransmissionQueue>( + "", &callbacks_, kMyInitialTsn, kArwnd, producer_, + [](DurationMs rtt_ms) {}, []() {}, *t3_rtx_timer_, DcSctpOptions(), + /*supports_partial_reliability=*/true, + /*use_message_interleaving=*/false); + retransmission_queue_->RestoreFromState(state); + handler_ = std::make_unique<StreamResetHandler>( + "log: ", &ctx_, &timer_manager_, data_tracker_.get(), reasm_.get(), + retransmission_queue_.get(), &state); + } + + DataGenerator gen_; + NiceMock<MockDcSctpSocketCallbacks> callbacks_; + NiceMock<MockContext> ctx_; + NiceMock<MockSendQueue> producer_; + TimerManager timer_manager_; + std::unique_ptr<Timer> delayed_ack_timer_; + std::unique_ptr<Timer> t3_rtx_timer_; + std::unique_ptr<DataTracker> data_tracker_; + std::unique_ptr<ReassemblyQueue> reasm_; + std::unique_ptr<RetransmissionQueue> retransmission_queue_; + std::unique_ptr<StreamResetHandler> handler_; +}; + +TEST_F(StreamResetHandlerTest, ChunkWithNoParametersReturnsError) { + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + EXPECT_CALL(callbacks_, OnError).Times(1); + handler_->HandleReConfig(ReConfigChunk(Parameters())); +} + +TEST_F(StreamResetHandlerTest, ChunkWithInvalidParametersReturnsError) { + Parameters::Builder builder; + // Two OutgoingSSNResetRequestParameter in a RE-CONFIG is not valid. + builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(1), + ReconfigRequestSN(10), + kPeerInitialTsn, {StreamID(1)})); + builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(2), + ReconfigRequestSN(10), + kPeerInitialTsn, {StreamID(2)})); + + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + EXPECT_CALL(callbacks_, OnError).Times(1); + handler_->HandleReConfig(ReConfigChunk(builder.Build())); +} + +TEST_F(StreamResetHandlerTest, FailToDeliverWithoutResettingStream) { + reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE")); + reasm_->Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE")); + + data_tracker_->Observe(kPeerInitialTsn); + data_tracker_->Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + gen_.ResetStream(); + reasm_->Add(AddTo(kPeerInitialTsn, 2), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm_->FlushMessages(), IsEmpty()); +} + +TEST_F(StreamResetHandlerTest, ResetStreamsNotDeferred) { + reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE")); + reasm_->Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE")); + + data_tracker_->Observe(kPeerInitialTsn); + data_tracker_->Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1), + {StreamID(1)})); + + std::vector<ReconfigurationResponseParameter> responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + EXPECT_THAT(responses, SizeIs(1)); + EXPECT_EQ(responses[0].result(), ResponseResult::kSuccessPerformed); + + gen_.ResetStream(); + reasm_->Add(AddTo(kPeerInitialTsn, 2), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); +} + +TEST_F(StreamResetHandlerTest, ResetStreamsDeferred) { + constexpr StreamID kStreamId = StreamID(1); + data_tracker_->Observe(TSN(10)); + reasm_->Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE", {.mid = MID(0)})); + + data_tracker_->Observe(TSN(11)); + reasm_->Add(TSN(11), gen_.Ordered({1, 2, 3, 4}, "BE", {.mid = MID(1)})); + + EXPECT_THAT( + reasm_->FlushMessages(), + UnorderedElementsAre(SctpMessageIs(kStreamId, PPID(53), kShortPayload), + SctpMessageIs(kStreamId, PPID(53), kShortPayload))); + + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(10), ReconfigRequestSN(3), TSN(13), {kStreamId})); + EXPECT_THAT(HandleAndCatchResponse(ReConfigChunk(builder.Build())), + ElementsAre(Property(&ReconfigurationResponseParameter::result, + ResponseResult::kInProgress))); + + data_tracker_->Observe(TSN(15)); + reasm_->Add(TSN(15), gen_.Ordered({1, 2, 3, 4}, "BE", + {.mid = MID(1), .ppid = PPID(5)})); + + data_tracker_->Observe(TSN(14)); + reasm_->Add(TSN(14), gen_.Ordered({1, 2, 3, 4}, "BE", + {.mid = MID(0), .ppid = PPID(4)})); + + data_tracker_->Observe(TSN(13)); + reasm_->Add(TSN(13), gen_.Ordered({1, 2, 3, 4}, "BE", + {.mid = MID(3), .ppid = PPID(3)})); + + data_tracker_->Observe(TSN(12)); + reasm_->Add(TSN(12), gen_.Ordered({1, 2, 3, 4}, "BE", + {.mid = MID(2), .ppid = PPID(2)})); + + builder.Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(11), ReconfigRequestSN(4), TSN(13), {kStreamId})); + EXPECT_THAT(HandleAndCatchResponse(ReConfigChunk(builder.Build())), + ElementsAre(Property(&ReconfigurationResponseParameter::result, + ResponseResult::kSuccessPerformed))); + + EXPECT_THAT( + reasm_->FlushMessages(), + UnorderedElementsAre(SctpMessageIs(kStreamId, PPID(2), kShortPayload), + SctpMessageIs(kStreamId, PPID(3), kShortPayload), + SctpMessageIs(kStreamId, PPID(4), kShortPayload), + SctpMessageIs(kStreamId, PPID(5), kShortPayload))); +} + +TEST_F(StreamResetHandlerTest, ResetStreamsDeferredOnlySelectedStreams) { + // This test verifies the receiving behavior of receiving messages on + // streams 1, 2 and 3, and receiving a reset request on stream 1, 2, causing + // deferred reset processing. + + // Reset stream 1,2 with "last assigned TSN=12" + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(10), + ReconfigRequestSN(3), TSN(12), + {StreamID(1), StreamID(2)})); + EXPECT_THAT(HandleAndCatchResponse(ReConfigChunk(builder.Build())), + ElementsAre(Property(&ReconfigurationResponseParameter::result, + ResponseResult::kInProgress))); + + // TSN 10, SID 1 - before TSN 12 -> deliver + data_tracker_->Observe(TSN(10)); + reasm_->Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE", + {.stream_id = StreamID(1), + .mid = MID(0), + .ppid = PPID(1001)})); + + // TSN 11, SID 2 - before TSN 12 -> deliver + data_tracker_->Observe(TSN(11)); + reasm_->Add(TSN(11), gen_.Ordered({1, 2, 3, 4}, "BE", + {.stream_id = StreamID(2), + .mid = MID(0), + .ppid = PPID(1002)})); + + // TSN 12, SID 3 - at TSN 12 -> deliver + data_tracker_->Observe(TSN(12)); + reasm_->Add(TSN(12), gen_.Ordered({1, 2, 3, 4}, "BE", + {.stream_id = StreamID(3), + .mid = MID(0), + .ppid = PPID(1003)})); + + // TSN 13, SID 1 - after TSN 12 and SID=1 -> defer + data_tracker_->Observe(TSN(13)); + reasm_->Add(TSN(13), gen_.Ordered({1, 2, 3, 4}, "BE", + {.stream_id = StreamID(1), + .mid = MID(0), + .ppid = PPID(1004)})); + + // TSN 14, SID 2 - after TSN 12 and SID=2 -> defer + data_tracker_->Observe(TSN(14)); + reasm_->Add(TSN(14), gen_.Ordered({1, 2, 3, 4}, "BE", + {.stream_id = StreamID(2), + .mid = MID(0), + .ppid = PPID(1005)})); + + // TSN 15, SID 3 - after TSN 12, but SID 3 is not reset -> deliver + data_tracker_->Observe(TSN(15)); + reasm_->Add(TSN(15), gen_.Ordered({1, 2, 3, 4}, "BE", + {.stream_id = StreamID(3), + .mid = MID(1), + .ppid = PPID(1006)})); + + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(1001), kShortPayload), + SctpMessageIs(StreamID(2), PPID(1002), kShortPayload), + SctpMessageIs(StreamID(3), PPID(1003), kShortPayload), + SctpMessageIs(StreamID(3), PPID(1006), kShortPayload))); + + builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(11), + ReconfigRequestSN(3), TSN(13), + {StreamID(1), StreamID(2)})); + EXPECT_THAT(HandleAndCatchResponse(ReConfigChunk(builder.Build())), + ElementsAre(Property(&ReconfigurationResponseParameter::result, + ResponseResult::kSuccessPerformed))); + + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(1004), kShortPayload), + SctpMessageIs(StreamID(2), PPID(1005), kShortPayload))); +} + +TEST_F(StreamResetHandlerTest, ResetStreamsDefersForwardTsn) { + // This test verifies that FORWARD-TSNs are deferred if they want to move + // the cumulative ack TSN point past sender's last assigned TSN. + static constexpr StreamID kStreamId = StreamID(42); + + // Simulate sender sends: + // * TSN 10 (SSN=0, BE, lost), + // * TSN 11 (SSN=1, BE, lost), + // * TSN 12 (SSN=2, BE, lost) + // * RESET THE STREAM + // * TSN 13 (SSN=0, B, received) + // * TSN 14 (SSN=0, E, lost), + // * TSN 15 (SSN=1, BE, received) + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(10), ReconfigRequestSN(3), TSN(12), {kStreamId})); + EXPECT_THAT(HandleAndCatchResponse(ReConfigChunk(builder.Build())), + ElementsAre(Property(&ReconfigurationResponseParameter::result, + ResponseResult::kInProgress))); + + // TSN 13, B, after TSN=12 -> defer + data_tracker_->Observe(TSN(13)); + reasm_->Add(TSN(13), + gen_.Ordered( + {1, 2, 3, 4}, "B", + {.stream_id = kStreamId, .mid = MID(0), .ppid = PPID(1004)})); + + // TSN 15, BE, after TSN=12 -> defer + data_tracker_->Observe(TSN(15)); + reasm_->Add(TSN(15), + gen_.Ordered( + {1, 2, 3, 4}, "BE", + {.stream_id = kStreamId, .mid = MID(1), .ppid = PPID(1005)})); + + // Time passes, sender decides to send FORWARD-TSN up to the RESET. + data_tracker_->HandleForwardTsn(TSN(12)); + reasm_->HandleForwardTsn( + TSN(12), std::vector<SkippedStream>({SkippedStream(kStreamId, SSN(2))})); + + // The receiver sends a SACK in response to that. The stream hasn't been + // reset yet, but the sender now decides that TSN=13-14 is to be skipped. + // As this has a TSN 14, after TSN=12 -> defer it. + data_tracker_->HandleForwardTsn(TSN(14)); + reasm_->HandleForwardTsn( + TSN(14), std::vector<SkippedStream>({SkippedStream(kStreamId, SSN(0))})); + + // Reset the stream -> deferred TSNs should be re-added. + builder.Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(11), ReconfigRequestSN(3), TSN(12), {kStreamId})); + EXPECT_THAT(HandleAndCatchResponse(ReConfigChunk(builder.Build())), + ElementsAre(Property(&ReconfigurationResponseParameter::result, + ResponseResult::kSuccessPerformed))); + + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(kStreamId, PPID(1005), kShortPayload))); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingRequestDirectly) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42))); +} + +TEST_F(StreamResetHandlerTest, ResetMultipleStreamsInOneRequest) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(40))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(41))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))).Times(2); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(43))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(44))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + handler_->ResetStreams( + std::vector<StreamID>({StreamID(43), StreamID(44), StreamID(41)})); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42), StreamID(40)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return( + std::vector<StreamID>({StreamID(40), StreamID(41), StreamID(42), + StreamID(43), StreamID(44)}))); + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), + UnorderedElementsAre(StreamID(40), StreamID(41), StreamID(42), + StreamID(43), StreamID(44))); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingRequestDeferred) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()) + .WillOnce(Return(false)) + .WillOnce(Return(false)) + .WillOnce(Return(true)); + + EXPECT_FALSE(handler_->MakeStreamResetRequest().has_value()); + EXPECT_FALSE(handler_->MakeStreamResetRequest().has_value()); + EXPECT_TRUE(handler_->MakeStreamResetRequest().has_value()); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingResettingOnPositiveResponse) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter( + req.request_sequence_number(), ResponseResult::kSuccessPerformed)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams); + EXPECT_CALL(producer_, RollbackResetStreams).Times(0); + + // Processing a response shouldn't result in sending anything. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + handler_->HandleReConfig(std::move(response_reconfig)); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingResetRollbackOnError) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter( + req.request_sequence_number(), ResponseResult::kErrorBadSequenceNumber)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams).Times(0); + EXPECT_CALL(producer_, RollbackResetStreams); + + // Only requests should result in sending responses. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + handler_->HandleReConfig(std::move(response_reconfig)); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingResetRetransmitOnInProgress) { + static constexpr StreamID kStreamToReset = StreamID(42); + + EXPECT_CALL(producer_, PrepareResetStream(kStreamToReset)); + handler_->ResetStreams(std::vector<StreamID>({kStreamToReset})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({kStreamToReset}))); + + absl::optional<ReConfigChunk> reconfig1 = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig1.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req1, + reconfig1->parameters().get<OutgoingSSNResetRequestParameter>()); + + // Simulate that the peer responded "In Progress". + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter(req1.request_sequence_number(), + ResponseResult::kInProgress)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(0); + EXPECT_CALL(producer_, RollbackResetStreams()).Times(0); + + // Processing a response shouldn't result in sending anything. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + handler_->HandleReConfig(std::move(response_reconfig)); + + // Let some time pass, so that the reconfig timer expires, and retries the + // same request. + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(1); + AdvanceTime(kRto); + + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_FALSE(payload.empty()); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(payload, DcSctpOptions())); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + ReConfigChunk reconfig2, + ReConfigChunk::Parse(packet.descriptors()[0].data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req2, + reconfig2.parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req2.request_sequence_number(), + AddTo(req1.request_sequence_number(), 1)); + EXPECT_THAT(req2.stream_ids(), UnorderedElementsAre(kStreamToReset)); +} + +TEST_F(StreamResetHandlerTest, ResetWhileRequestIsSentWillQueue) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig1 = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig1.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req1, + reconfig1->parameters().get<OutgoingSSNResetRequestParameter>()); + EXPECT_EQ(req1.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req1.sender_last_assigned_tsn(), + AddTo(retransmission_queue_->next_tsn(), -1)); + EXPECT_THAT(req1.stream_ids(), UnorderedElementsAre(StreamID(42))); + + // Streams reset while the request is in-flight will be queued. + EXPECT_CALL(producer_, PrepareResetStream(StreamID(41))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(43))); + StreamID stream_ids[] = {StreamID(41), StreamID(43)}; + handler_->ResetStreams(stream_ids); + EXPECT_EQ(handler_->MakeStreamResetRequest(), absl::nullopt); + + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter( + req1.request_sequence_number(), ResponseResult::kSuccessPerformed)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + EXPECT_CALL(producer_, RollbackResetStreams()).Times(0); + + // Processing a response shouldn't result in sending anything. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + handler_->HandleReConfig(std::move(response_reconfig)); + + // Response has been processed. A new request can be sent. + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(41), StreamID(43)}))); + + absl::optional<ReConfigChunk> reconfig2 = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig2.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req2, + reconfig2->parameters().get<OutgoingSSNResetRequestParameter>()); + EXPECT_EQ(req2.request_sequence_number(), AddTo(kMyInitialReqSn, 1)); + EXPECT_EQ(req2.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req2.stream_ids(), + UnorderedElementsAre(StreamID(41), StreamID(43))); +} + +TEST_F(StreamResetHandlerTest, SendIncomingResetJustReturnsNothingPerformed) { + Parameters::Builder builder; + builder.Add( + IncomingSSNResetRequestParameter(kPeerInitialReqSn, {StreamID(1)})); + + std::vector<ReconfigurationResponseParameter> responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + ASSERT_THAT(responses, SizeIs(1)); + EXPECT_THAT(responses[0].response_sequence_number(), kPeerInitialReqSn); + EXPECT_THAT(responses[0].result(), ResponseResult::kSuccessNothingToDo); +} + +TEST_F(StreamResetHandlerTest, SendSameRequestTwiceIsIdempotent) { + // Simulate that receiving the same chunk twice (due to network issues, + // or retransmissions, causing a RECONFIG to be re-received) is idempotent. + for (int i = 0; i < 2; ++i) { + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1), + {StreamID(1)})); + + std::vector<ReconfigurationResponseParameter> responses1 = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + EXPECT_THAT(responses1, SizeIs(1)); + EXPECT_EQ(responses1[0].result(), ResponseResult::kInProgress); + } +} + +TEST_F(StreamResetHandlerTest, + HandoverIsAllowedOnlyWhenNoStreamIsBeingOrWillBeReset) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_EQ( + handler_->GetHandoverReadiness(), + HandoverReadinessStatus(HandoverUnreadinessReason::kPendingStreamReset)); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + ASSERT_TRUE(handler_->MakeStreamResetRequest().has_value()); + EXPECT_EQ(handler_->GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kPendingStreamResetRequest)); + + // Reset more streams while the request is in-flight. + EXPECT_CALL(producer_, PrepareResetStream(StreamID(41))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(43))); + StreamID stream_ids[] = {StreamID(41), StreamID(43)}; + handler_->ResetStreams(stream_ids); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_EQ(handler_->GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kPendingStreamResetRequest) + .Add(HandoverUnreadinessReason::kPendingStreamReset)); + + // Processing a response to first request. + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + handler_->HandleReConfig( + ReConfigChunk(Parameters::Builder() + .Add(ReconfigurationResponseParameter( + kMyInitialReqSn, ResponseResult::kSuccessPerformed)) + .Build())); + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_EQ( + handler_->GetHandoverReadiness(), + HandoverReadinessStatus(HandoverUnreadinessReason::kPendingStreamReset)); + + // Second request can be sent. + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(41), StreamID(43)}))); + + ASSERT_TRUE(handler_->MakeStreamResetRequest().has_value()); + EXPECT_EQ(handler_->GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kPendingStreamResetRequest)); + + // Processing a response to second request. + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + handler_->HandleReConfig(ReConfigChunk( + Parameters::Builder() + .Add(ReconfigurationResponseParameter( + AddTo(kMyInitialReqSn, 1), ResponseResult::kSuccessPerformed)) + .Build())); + + // Seconds response has been processed. No pending resets. + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(false)); + + EXPECT_TRUE(handler_->GetHandoverReadiness().IsReady()); +} + +TEST_F(StreamResetHandlerTest, HandoverInInitialState) { + PerformHandover(); + + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42))); +} + +TEST_F(StreamResetHandlerTest, HandoverAfterHavingResetOneStream) { + // Reset one stream + { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk reconfig, + handler_->MakeStreamResetRequest()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig.parameters().get<OutgoingSSNResetRequestParameter>()); + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42))); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + handler_->HandleReConfig( + ReConfigChunk(Parameters::Builder() + .Add(ReconfigurationResponseParameter( + req.request_sequence_number(), + ResponseResult::kSuccessPerformed)) + .Build())); + } + + PerformHandover(); + + // Reset another stream after handover + { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(43))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(43)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(43)}))); + + ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk reconfig, + handler_->MakeStreamResetRequest()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig.parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req.request_sequence_number(), + ReconfigRequestSN(kMyInitialReqSn.value() + 1)); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(43))); + } +} + +TEST_F(StreamResetHandlerTest, PerformCloseAfterOneFirstFailing) { + // Inject a stream reset on the first expected TSN (which hasn't been seen). + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), kPeerInitialTsn, {StreamID(1)})); + + // The socket is expected to say "in progress" as that TSN hasn't been seen. + std::vector<ReconfigurationResponseParameter> responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + EXPECT_THAT(responses, SizeIs(1)); + EXPECT_EQ(responses[0].result(), ResponseResult::kInProgress); + + // Let the socket receive the TSN. + DataGeneratorOptions opts; + opts.mid = MID(0); + reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + data_tracker_->Observe(kPeerInitialTsn); + + // And emulate that time has passed, and the peer retries the stream reset, + // but now with an incremented request sequence number. + Parameters::Builder builder2; + builder2.Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(*kPeerInitialReqSn + 1), ReconfigRequestSN(3), + kPeerInitialTsn, {StreamID(1)})); + + // This is supposed to be handled well. + std::vector<ReconfigurationResponseParameter> responses2 = + HandleAndCatchResponse(ReConfigChunk(builder2.Build())); + EXPECT_THAT(responses2, SizeIs(1)); + EXPECT_EQ(responses2[0].result(), ResponseResult::kSuccessPerformed); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc new file mode 100644 index 0000000000..0621b48e80 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc @@ -0,0 +1,333 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/transmission_control_block.h" + +#include <algorithm> +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/stream_reset_handler.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +TransmissionControlBlock::TransmissionControlBlock( + TimerManager& timer_manager, + absl::string_view log_prefix, + const DcSctpOptions& options, + const Capabilities& capabilities, + DcSctpSocketCallbacks& callbacks, + SendQueue& send_queue, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag, + PacketSender& packet_sender, + std::function<bool()> is_connection_established) + : log_prefix_(log_prefix), + options_(options), + timer_manager_(timer_manager), + capabilities_(capabilities), + callbacks_(callbacks), + t3_rtx_(timer_manager_.CreateTimer( + "t3-rtx", + absl::bind_front(&TransmissionControlBlock::OnRtxTimerExpiry, this), + TimerOptions(options.rto_initial, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/absl::nullopt, + options.max_timer_backoff_duration.has_value() + ? *options.max_timer_backoff_duration + : DurationMs::InfiniteDuration()))), + delayed_ack_timer_(timer_manager_.CreateTimer( + "delayed-ack", + absl::bind_front(&TransmissionControlBlock::OnDelayedAckTimerExpiry, + this), + TimerOptions(options.delayed_ack_max_timeout, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/0, + /*max_backoff_duration=*/DurationMs::InfiniteDuration(), + webrtc::TaskQueueBase::DelayPrecision::kHigh))), + my_verification_tag_(my_verification_tag), + my_initial_tsn_(my_initial_tsn), + peer_verification_tag_(peer_verification_tag), + peer_initial_tsn_(peer_initial_tsn), + tie_tag_(tie_tag), + is_connection_established_(std::move(is_connection_established)), + packet_sender_(packet_sender), + rto_(options), + tx_error_counter_(log_prefix, options), + data_tracker_(log_prefix, delayed_ack_timer_.get(), peer_initial_tsn), + reassembly_queue_(log_prefix, + peer_initial_tsn, + options.max_receiver_window_buffer_size, + capabilities.message_interleaving), + retransmission_queue_( + log_prefix, + &callbacks_, + my_initial_tsn, + a_rwnd, + send_queue, + absl::bind_front(&TransmissionControlBlock::ObserveRTT, this), + [this]() { tx_error_counter_.Clear(); }, + *t3_rtx_, + options, + capabilities.partial_reliability, + capabilities.message_interleaving), + stream_reset_handler_(log_prefix, + this, + &timer_manager, + &data_tracker_, + &reassembly_queue_, + &retransmission_queue_), + heartbeat_handler_(log_prefix, options, this, &timer_manager_) { + send_queue.EnableMessageInterleaving(capabilities.message_interleaving); +} + +void TransmissionControlBlock::ObserveRTT(DurationMs rtt) { + DurationMs prev_rto = rto_.rto(); + rto_.ObserveRTT(rtt); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "new rtt=" << *rtt + << ", srtt=" << *rto_.srtt() << ", rto=" << *rto_.rto() + << " (" << *prev_rto << ")"; + t3_rtx_->set_duration(rto_.rto()); + + DurationMs delayed_ack_tmo = + std::min(rto_.rto() * 0.5, options_.delayed_ack_max_timeout); + delayed_ack_timer_->set_duration(delayed_ack_tmo); +} + +absl::optional<DurationMs> TransmissionControlBlock::OnRtxTimerExpiry() { + TimeMs now = callbacks_.TimeMillis(); + RTC_DLOG(LS_INFO) << log_prefix_ << "Timer " << t3_rtx_->name() + << " has expired"; + if (cookie_echo_chunk_.has_value()) { + // In the COOKIE_ECHO state, let the T1-COOKIE timer trigger + // retransmissions, to avoid having two timers doing that. + RTC_DLOG(LS_VERBOSE) << "Not retransmitting as T1-cookie is active."; + } else { + if (IncrementTxErrorCounter("t3-rtx expired")) { + retransmission_queue_.HandleT3RtxTimerExpiry(); + SendBufferedPackets(now); + } + } + return absl::nullopt; +} + +absl::optional<DurationMs> TransmissionControlBlock::OnDelayedAckTimerExpiry() { + data_tracker_.HandleDelayedAckTimerExpiry(); + MaybeSendSack(); + return absl::nullopt; +} + +void TransmissionControlBlock::MaybeSendSack() { + if (data_tracker_.ShouldSendAck(/*also_if_delayed=*/false)) { + SctpPacket::Builder builder = PacketBuilder(); + builder.Add( + data_tracker_.CreateSelectiveAck(reassembly_queue_.remaining_bytes())); + Send(builder); + } +} + +void TransmissionControlBlock::MaybeSendForwardTsn(SctpPacket::Builder& builder, + TimeMs now) { + if (now >= limit_forward_tsn_until_ && + retransmission_queue_.ShouldSendForwardTsn(now)) { + if (capabilities_.message_interleaving) { + builder.Add(retransmission_queue_.CreateIForwardTsn()); + } else { + builder.Add(retransmission_queue_.CreateForwardTsn()); + } + Send(builder); + // https://datatracker.ietf.org/doc/html/rfc3758 + // "IMPLEMENTATION NOTE: An implementation may wish to limit the number of + // duplicate FORWARD TSN chunks it sends by ... waiting a full RTT before + // sending a duplicate FORWARD TSN." + // "Any delay applied to the sending of FORWARD TSN chunk SHOULD NOT exceed + // 200ms and MUST NOT exceed 500ms". + limit_forward_tsn_until_ = now + std::min(DurationMs(200), rto_.srtt()); + } +} + +void TransmissionControlBlock::MaybeSendFastRetransmit() { + if (!retransmission_queue_.has_data_to_be_fast_retransmitted()) { + return; + } + + // https://datatracker.ietf.org/doc/html/rfc4960#section-7.2.4 + // "Determine how many of the earliest (i.e., lowest TSN) DATA chunks marked + // for retransmission will fit into a single packet, subject to constraint of + // the path MTU of the destination transport address to which the packet is + // being sent. Call this value K. Retransmit those K DATA chunks in a single + // packet. When a Fast Retransmit is being performed, the sender SHOULD + // ignore the value of cwnd and SHOULD NOT delay retransmission for this + // single packet." + + SctpPacket::Builder builder(peer_verification_tag_, options_); + auto chunks = retransmission_queue_.GetChunksForFastRetransmit( + builder.bytes_remaining()); + for (auto& [tsn, data] : chunks) { + if (capabilities_.message_interleaving) { + builder.Add(IDataChunk(tsn, std::move(data), false)); + } else { + builder.Add(DataChunk(tsn, std::move(data), false)); + } + } + Send(builder); +} + +void TransmissionControlBlock::SendBufferedPackets(SctpPacket::Builder& builder, + TimeMs now) { + for (int packet_idx = 0; + packet_idx < options_.max_burst && retransmission_queue_.can_send_data(); + ++packet_idx) { + // Only add control chunks to the first packet that is sent, if sending + // multiple packets in one go (as allowed by the congestion window). + if (packet_idx == 0) { + if (cookie_echo_chunk_.has_value()) { + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "The COOKIE ECHO chunk can be bundled with any pending outbound DATA + // chunks, but it MUST be the first chunk in the packet..." + RTC_DCHECK(builder.empty()); + builder.Add(*cookie_echo_chunk_); + } + + // https://tools.ietf.org/html/rfc4960#section-6 + // "Before an endpoint transmits a DATA chunk, if any received DATA + // chunks have not been acknowledged (e.g., due to delayed ack), the + // sender should create a SACK and bundle it with the outbound DATA chunk, + // as long as the size of the final SCTP packet does not exceed the + // current MTU." + if (data_tracker_.ShouldSendAck(/*also_if_delayed=*/true)) { + builder.Add(data_tracker_.CreateSelectiveAck( + reassembly_queue_.remaining_bytes())); + } + MaybeSendForwardTsn(builder, now); + absl::optional<ReConfigChunk> reconfig = + stream_reset_handler_.MakeStreamResetRequest(); + if (reconfig.has_value()) { + builder.Add(*reconfig); + } + } + + auto chunks = + retransmission_queue_.GetChunksToSend(now, builder.bytes_remaining()); + for (auto& [tsn, data] : chunks) { + if (capabilities_.message_interleaving) { + builder.Add(IDataChunk(tsn, std::move(data), false)); + } else { + builder.Add(DataChunk(tsn, std::move(data), false)); + } + } + + // https://www.ietf.org/archive/id/draft-tuexen-tsvwg-sctp-zero-checksum-02.html#section-4.2 + // "When an end point sends a packet containing a COOKIE ECHO chunk, it MUST + // include a correct CRC32c checksum in the packet containing the COOKIE + // ECHO chunk." + bool write_checksum = + !capabilities_.zero_checksum || cookie_echo_chunk_.has_value(); + if (!packet_sender_.Send(builder, write_checksum)) { + break; + } + + if (cookie_echo_chunk_.has_value()) { + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "... until the COOKIE ACK is returned the sender MUST NOT send any + // other packets to the peer." + break; + } + } +} + +std::string TransmissionControlBlock::ToString() const { + rtc::StringBuilder sb; + + sb.AppendFormat( + "verification_tag=%08x, last_cumulative_ack=%u, capabilities=", + *peer_verification_tag_, *data_tracker_.last_cumulative_acked_tsn()); + + if (capabilities_.partial_reliability) { + sb << "PR,"; + } + if (capabilities_.message_interleaving) { + sb << "IL,"; + } + if (capabilities_.reconfig) { + sb << "Reconfig,"; + } + if (capabilities_.zero_checksum) { + sb << "ZeroChecksum,"; + } + sb << " max_in=" << capabilities_.negotiated_maximum_incoming_streams; + sb << " max_out=" << capabilities_.negotiated_maximum_outgoing_streams; + + return sb.Release(); +} + +HandoverReadinessStatus TransmissionControlBlock::GetHandoverReadiness() const { + HandoverReadinessStatus status; + status.Add(data_tracker_.GetHandoverReadiness()); + status.Add(stream_reset_handler_.GetHandoverReadiness()); + status.Add(reassembly_queue_.GetHandoverReadiness()); + status.Add(retransmission_queue_.GetHandoverReadiness()); + return status; +} + +void TransmissionControlBlock::AddHandoverState( + DcSctpSocketHandoverState& state) { + state.capabilities.partial_reliability = capabilities_.partial_reliability; + state.capabilities.message_interleaving = capabilities_.message_interleaving; + state.capabilities.reconfig = capabilities_.reconfig; + state.capabilities.zero_checksum = capabilities_.zero_checksum; + state.capabilities.negotiated_maximum_incoming_streams = + capabilities_.negotiated_maximum_incoming_streams; + state.capabilities.negotiated_maximum_outgoing_streams = + capabilities_.negotiated_maximum_outgoing_streams; + + state.my_verification_tag = my_verification_tag().value(); + state.peer_verification_tag = peer_verification_tag().value(); + state.my_initial_tsn = my_initial_tsn().value(); + state.peer_initial_tsn = peer_initial_tsn().value(); + state.tie_tag = tie_tag().value(); + + data_tracker_.AddHandoverState(state); + stream_reset_handler_.AddHandoverState(state); + reassembly_queue_.AddHandoverState(state); + retransmission_queue_.AddHandoverState(state); +} + +void TransmissionControlBlock::RestoreFromState( + const DcSctpSocketHandoverState& state) { + data_tracker_.RestoreFromState(state); + retransmission_queue_.RestoreFromState(state); + reassembly_queue_.RestoreFromState(state); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h new file mode 100644 index 0000000000..46a39d5a7b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_ +#define NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_ + +#include <cstdint> +#include <functional> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/strings/string_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/socket/heartbeat_handler.h" +#include "net/dcsctp/socket/packet_sender.h" +#include "net/dcsctp/socket/stream_reset_handler.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_error_counter.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "net/dcsctp/tx/send_queue.h" + +namespace dcsctp { + +// The TransmissionControlBlock (TCB) represents an open connection to a peer, +// and holds all the resources for that. If the connection is e.g. shutdown, +// closed or restarted, this object will be deleted and/or replaced. +class TransmissionControlBlock : public Context { + public: + TransmissionControlBlock(TimerManager& timer_manager, + absl::string_view log_prefix, + const DcSctpOptions& options, + const Capabilities& capabilities, + DcSctpSocketCallbacks& callbacks, + SendQueue& send_queue, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag, + PacketSender& packet_sender, + std::function<bool()> is_connection_established); + + // Implementation of `Context`. + bool is_connection_established() const override { + return is_connection_established_(); + } + TSN my_initial_tsn() const override { return my_initial_tsn_; } + TSN peer_initial_tsn() const override { return peer_initial_tsn_; } + DcSctpSocketCallbacks& callbacks() const override { return callbacks_; } + void ObserveRTT(DurationMs rtt) override; + DurationMs current_rto() const override { return rto_.rto(); } + bool IncrementTxErrorCounter(absl::string_view reason) override { + return tx_error_counter_.Increment(reason); + } + void ClearTxErrorCounter() override { tx_error_counter_.Clear(); } + SctpPacket::Builder PacketBuilder() const override { + return SctpPacket::Builder(peer_verification_tag_, options_); + } + bool HasTooManyTxErrors() const override { + return tx_error_counter_.IsExhausted(); + } + void Send(SctpPacket::Builder& builder) override { + packet_sender_.Send(builder, + /*write_checksum=*/!capabilities_.zero_checksum); + } + + // Other accessors + DataTracker& data_tracker() { return data_tracker_; } + ReassemblyQueue& reassembly_queue() { return reassembly_queue_; } + RetransmissionQueue& retransmission_queue() { return retransmission_queue_; } + StreamResetHandler& stream_reset_handler() { return stream_reset_handler_; } + HeartbeatHandler& heartbeat_handler() { return heartbeat_handler_; } + size_t cwnd() const { return retransmission_queue_.cwnd(); } + DurationMs current_srtt() const { return rto_.srtt(); } + + // Returns this socket's verification tag, set in all packet headers. + VerificationTag my_verification_tag() const { return my_verification_tag_; } + // Returns the peer's verification tag, which should be in received packets. + VerificationTag peer_verification_tag() const { + return peer_verification_tag_; + } + // All negotiated supported capabilities. + const Capabilities& capabilities() const { return capabilities_; } + // A 64-bit tie-tag, used to e.g. detect reconnections. + TieTag tie_tag() const { return tie_tag_; } + + // Sends a SACK, if there is a need to. + void MaybeSendSack(); + + // Sends a FORWARD-TSN, if it is needed and allowed (rate-limited). + void MaybeSendForwardTsn(SctpPacket::Builder& builder, TimeMs now); + + // Will be set while the socket is in kCookieEcho state. In this state, there + // can only be a single packet outstanding, and it must contain the COOKIE + // ECHO chunk as the first chunk in that packet, until the COOKIE ACK has been + // received, which will make the socket call `ClearCookieEchoChunk`. + void SetCookieEchoChunk(CookieEchoChunk chunk) { + cookie_echo_chunk_ = std::move(chunk); + } + + // Called when the COOKIE ACK chunk has been received, to allow further + // packets to be sent. + void ClearCookieEchoChunk() { cookie_echo_chunk_ = absl::nullopt; } + + bool has_cookie_echo_chunk() const { return cookie_echo_chunk_.has_value(); } + + void MaybeSendFastRetransmit(); + + // Fills `builder` (which may already be filled with control chunks) with + // other control and data chunks, and sends packets as much as can be + // allowed by the congestion control algorithm. + void SendBufferedPackets(SctpPacket::Builder& builder, TimeMs now); + + // As above, but without passing in a builder. If `cookie_echo_chunk_` is + // present, then only one packet will be sent, with this chunk as the first + // chunk. + void SendBufferedPackets(TimeMs now) { + SctpPacket::Builder builder(peer_verification_tag_, options_); + SendBufferedPackets(builder, now); + } + + // Returns a textual representation of this object, for logging. + std::string ToString() const; + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& handover_state); + + private: + // Will be called when the retransmission timer (t3-rtx) expires. + absl::optional<DurationMs> OnRtxTimerExpiry(); + // Will be called when the delayed ack timer expires. + absl::optional<DurationMs> OnDelayedAckTimerExpiry(); + + const absl::string_view log_prefix_; + const DcSctpOptions options_; + TimerManager& timer_manager_; + // Negotiated capabilities that both peers support. + const Capabilities capabilities_; + DcSctpSocketCallbacks& callbacks_; + // The data retransmission timer, called t3-rtx in SCTP. + const std::unique_ptr<Timer> t3_rtx_; + // Delayed ack timer, which triggers when acks should be sent (when delayed). + const std::unique_ptr<Timer> delayed_ack_timer_; + const VerificationTag my_verification_tag_; + const TSN my_initial_tsn_; + const VerificationTag peer_verification_tag_; + const TSN peer_initial_tsn_; + // Nonce, used to detect reconnections. + const TieTag tie_tag_; + const std::function<bool()> is_connection_established_; + PacketSender& packet_sender_; + // Rate limiting of FORWARD-TSN. Next can be sent at or after this timestamp. + TimeMs limit_forward_tsn_until_ = TimeMs(0); + + RetransmissionTimeout rto_; + RetransmissionErrorCounter tx_error_counter_; + DataTracker data_tracker_; + ReassemblyQueue reassembly_queue_; + RetransmissionQueue retransmission_queue_; + StreamResetHandler stream_reset_handler_; + HeartbeatHandler heartbeat_handler_; + + // Only valid when the socket state == State::kCookieEchoed. In this state, + // the socket must wait for COOKIE ACK to continue sending any packets (not + // including a COOKIE ECHO). So if `cookie_echo_chunk_` is present, the + // SendBufferedChunks will always only just send one packet, with this chunk + // as the first chunk in the packet. + absl::optional<CookieEchoChunk> cookie_echo_chunk_ = absl::nullopt; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block_test.cc b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block_test.cc new file mode 100644 index 0000000000..6106fbb309 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block_test.cc @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2023 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/transmission_control_block.h" + +#include <array> +#include <cstdint> +#include <memory> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/mock_context.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/data_generator.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/mock_send_queue.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::Return; +using ::testing::StrictMock; + +constexpr VerificationTag kMyVerificationTag = VerificationTag(123); +constexpr VerificationTag kPeerVerificationTag = VerificationTag(456); +constexpr TSN kMyInitialTsn = TSN(10); +constexpr TSN kPeerInitialTsn = TSN(1000); +constexpr size_t kArwnd = 65536; +constexpr TieTag kTieTag = TieTag(12345678); + +class TransmissionControlBlockTest : public testing::Test { + protected: + TransmissionControlBlockTest() + : sender_(callbacks_, on_send_fn_.AsStdFunction()), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return callbacks_.CreateTimeout(precision); + }) {} + + DcSctpOptions options_; + Capabilities capabilities_; + StrictMock<MockDcSctpSocketCallbacks> callbacks_; + StrictMock<MockSendQueue> send_queue_; + testing::MockFunction<void(rtc::ArrayView<const uint8_t>, SendPacketStatus)> + on_send_fn_; + testing::MockFunction<bool()> on_connection_established; + PacketSender sender_; + TimerManager timer_manager_; +}; + +TEST_F(TransmissionControlBlockTest, LogsBasicInfoInToString) { + EXPECT_CALL(send_queue_, EnableMessageInterleaving); + + capabilities_.negotiated_maximum_incoming_streams = 1000; + capabilities_.negotiated_maximum_outgoing_streams = 2000; + TransmissionControlBlock tcb( + timer_manager_, "log: ", options_, capabilities_, callbacks_, send_queue_, + kMyVerificationTag, kMyInitialTsn, kPeerVerificationTag, kPeerInitialTsn, + kArwnd, kTieTag, sender_, on_connection_established.AsStdFunction()); + + EXPECT_EQ(tcb.ToString(), + "verification_tag=000001c8, last_cumulative_ack=999, capabilities= " + "max_in=1000 max_out=2000"); +} + +TEST_F(TransmissionControlBlockTest, LogsAllCapabilitiesInToSring) { + EXPECT_CALL(send_queue_, EnableMessageInterleaving); + + capabilities_.negotiated_maximum_incoming_streams = 1000; + capabilities_.negotiated_maximum_outgoing_streams = 2000; + capabilities_.message_interleaving = true; + capabilities_.partial_reliability = true; + capabilities_.zero_checksum = true; + capabilities_.reconfig = true; + + TransmissionControlBlock tcb( + timer_manager_, "log: ", options_, capabilities_, callbacks_, send_queue_, + kMyVerificationTag, kMyInitialTsn, kPeerVerificationTag, kPeerInitialTsn, + kArwnd, kTieTag, sender_, on_connection_established.AsStdFunction()); + + EXPECT_EQ( + tcb.ToString(), + "verification_tag=000001c8, last_cumulative_ack=999, " + "capabilities=PR,IL,Reconfig,ZeroChecksum, max_in=1000 max_out=2000"); +} + +TEST_F(TransmissionControlBlockTest, IsInitiallyHandoverReady) { + EXPECT_CALL(send_queue_, EnableMessageInterleaving); + EXPECT_CALL(send_queue_, HasStreamsReadyToBeReset).WillOnce(Return(false)); + + TransmissionControlBlock tcb( + timer_manager_, "log: ", options_, capabilities_, callbacks_, send_queue_, + kMyVerificationTag, kMyInitialTsn, kPeerVerificationTag, kPeerInitialTsn, + kArwnd, kTieTag, sender_, on_connection_established.AsStdFunction()); + + EXPECT_TRUE(tcb.GetHandoverReadiness().IsReady()); +} +} // namespace +} // namespace dcsctp |