diff options
Diffstat (limited to 'third_party/libwebrtc/net/dcsctp/rx')
15 files changed, 3863 insertions, 0 deletions
diff --git a/third_party/libwebrtc/net/dcsctp/rx/BUILD.gn b/third_party/libwebrtc/net/dcsctp/rx/BUILD.gn new file mode 100644 index 0000000000..8ef60dcd5f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/BUILD.gn @@ -0,0 +1,149 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_library("data_tracker") { + deps = [ + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../timer", + ] + sources = [ + "data_tracker.cc", + "data_tracker.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_source_set("reassembly_streams") { + deps = [ + "../../../api:array_view", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../public:types", + ] + sources = [ "reassembly_streams.h" ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_library("interleaved_reassembly_streams") { + deps = [ + ":reassembly_streams", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:types", + ] + sources = [ + "interleaved_reassembly_streams.cc", + "interleaved_reassembly_streams.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} +rtc_library("traditional_reassembly_streams") { + deps = [ + ":reassembly_streams", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:types", + ] + sources = [ + "traditional_reassembly_streams.cc", + "traditional_reassembly_streams.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("reassembly_queue") { + deps = [ + ":interleaved_reassembly_streams", + ":reassembly_streams", + ":traditional_reassembly_streams", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../common:internal_types", + "../common:sequence_numbers", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../packet:parameter", + "../public:socket", + "../public:types", + ] + sources = [ + "reassembly_queue.cc", + "reassembly_queue.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_rx_unittests") { + testonly = true + + deps = [ + ":data_tracker", + ":interleaved_reassembly_streams", + ":reassembly_queue", + ":reassembly_streams", + ":traditional_reassembly_streams", + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../test:test_support", + "../common:handover_testing", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:types", + "../testing:data_generator", + "../timer", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ + "data_tracker_test.cc", + "interleaved_reassembly_streams_test.cc", + "reassembly_queue_test.cc", + "traditional_reassembly_streams_test.cc", + ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/rx/data_tracker.cc b/third_party/libwebrtc/net/dcsctp/rx/data_tracker.cc new file mode 100644 index 0000000000..1f2e43f7f5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/data_tracker.cc @@ -0,0 +1,386 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/data_tracker.h" + +#include <algorithm> +#include <cstdint> +#include <iterator> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/timer/timer.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +constexpr size_t DataTracker::kMaxDuplicateTsnReported; +constexpr size_t DataTracker::kMaxGapAckBlocksReported; + +bool DataTracker::AdditionalTsnBlocks::Add(UnwrappedTSN tsn) { + // Find any block to expand. It will look for any block that includes (also + // when expanded) the provided `tsn`. It will return the block that is greater + // than, or equal to `tsn`. + auto it = absl::c_lower_bound( + blocks_, tsn, [&](const TsnRange& elem, const UnwrappedTSN& t) { + return elem.last.next_value() < t; + }); + + if (it == blocks_.end()) { + // No matching block found. There is no greater than, or equal block - which + // means that this TSN is greater than any block. It can then be inserted at + // the end. + blocks_.emplace_back(tsn, tsn); + return true; + } + + if (tsn >= it->first && tsn <= it->last) { + // It's already in this block. + return false; + } + + if (it->last.next_value() == tsn) { + // This block can be expanded to the right, or merged with the next. + auto next_it = it + 1; + if (next_it != blocks_.end() && tsn.next_value() == next_it->first) { + // Expanding it would make it adjacent to next block - merge those. + it->last = next_it->last; + blocks_.erase(next_it); + return true; + } + + // Expand to the right + it->last = tsn; + return true; + } + + if (it->first == tsn.next_value()) { + // This block can be expanded to the left. Merging to the left would've been + // covered by the above "merge to the right". Both blocks (expand a + // right-most block to the left and expand a left-most block to the right) + // would match, but the left-most would be returned by std::lower_bound. + RTC_DCHECK(it == blocks_.begin() || (it - 1)->last.next_value() != tsn); + + // Expand to the left. + it->first = tsn; + return true; + } + + // Need to create a new block in the middle. + blocks_.emplace(it, tsn, tsn); + return true; +} + +void DataTracker::AdditionalTsnBlocks::EraseTo(UnwrappedTSN tsn) { + // Find the block that is greater than or equals `tsn`. + auto it = absl::c_lower_bound( + blocks_, tsn, [&](const TsnRange& elem, const UnwrappedTSN& t) { + return elem.last < t; + }); + + // The block that is found is greater or equal (or possibly ::end, when no + // block is greater or equal). All blocks before this block can be safely + // removed. the TSN might be within this block, so possibly truncate it. + bool tsn_is_within_block = it != blocks_.end() && tsn >= it->first; + blocks_.erase(blocks_.begin(), it); + + if (tsn_is_within_block) { + blocks_.front().first = tsn.next_value(); + } +} + +void DataTracker::AdditionalTsnBlocks::PopFront() { + RTC_DCHECK(!blocks_.empty()); + blocks_.erase(blocks_.begin()); +} + +bool DataTracker::IsTSNValid(TSN tsn) const { + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.PeekUnwrap(tsn); + + // Note that this method doesn't return `false` for old DATA chunks, as those + // are actually valid, and receiving those may affect the generated SACK + // response (by setting "duplicate TSNs"). + + uint32_t difference = + UnwrappedTSN::Difference(unwrapped_tsn, last_cumulative_acked_tsn_); + if (difference > kMaxAcceptedOutstandingFragments) { + return false; + } + return true; +} + +bool DataTracker::Observe(TSN tsn, + AnyDataChunk::ImmediateAckFlag immediate_ack) { + bool is_duplicate = false; + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(tsn); + + // IsTSNValid must be called prior to calling this method. + RTC_DCHECK( + UnwrappedTSN::Difference(unwrapped_tsn, last_cumulative_acked_tsn_) <= + kMaxAcceptedOutstandingFragments); + + // Old chunk already seen before? + if (unwrapped_tsn <= last_cumulative_acked_tsn_) { + if (duplicate_tsns_.size() < kMaxDuplicateTsnReported) { + duplicate_tsns_.insert(unwrapped_tsn.Wrap()); + } + // https://datatracker.ietf.org/doc/html/rfc4960#section-6.2 + // "When a packet arrives with duplicate DATA chunk(s) and with no new DATA + // chunk(s), the endpoint MUST immediately send a SACK with no delay. If a + // packet arrives with duplicate DATA chunk(s) bundled with new DATA chunks, + // the endpoint MAY immediately send a SACK." + UpdateAckState(AckState::kImmediate, "duplicate data"); + is_duplicate = true; + } else { + if (unwrapped_tsn == last_cumulative_acked_tsn_.next_value()) { + last_cumulative_acked_tsn_ = unwrapped_tsn; + // The cumulative acked tsn may be moved even further, if a gap was + // filled. + if (!additional_tsn_blocks_.empty() && + additional_tsn_blocks_.front().first == + last_cumulative_acked_tsn_.next_value()) { + last_cumulative_acked_tsn_ = additional_tsn_blocks_.front().last; + additional_tsn_blocks_.PopFront(); + } + } else { + bool inserted = additional_tsn_blocks_.Add(unwrapped_tsn); + if (!inserted) { + // Already seen before. + if (duplicate_tsns_.size() < kMaxDuplicateTsnReported) { + duplicate_tsns_.insert(unwrapped_tsn.Wrap()); + } + // https://datatracker.ietf.org/doc/html/rfc4960#section-6.2 + // "When a packet arrives with duplicate DATA chunk(s) and with no new + // DATA chunk(s), the endpoint MUST immediately send a SACK with no + // delay. If a packet arrives with duplicate DATA chunk(s) bundled with + // new DATA chunks, the endpoint MAY immediately send a SACK." + // No need to do this. SACKs are sent immediately on packet loss below. + is_duplicate = true; + } + } + } + + // https://tools.ietf.org/html/rfc4960#section-6.7 + // "Upon the reception of a new DATA chunk, an endpoint shall examine the + // continuity of the TSNs received. If the endpoint detects a gap in + // the received DATA chunk sequence, it SHOULD send a SACK with Gap Ack + // Blocks immediately. The data receiver continues sending a SACK after + // receipt of each SCTP packet that doesn't fill the gap." + if (!additional_tsn_blocks_.empty()) { + UpdateAckState(AckState::kImmediate, "packet loss"); + } + + // https://tools.ietf.org/html/rfc7053#section-5.2 + // "Upon receipt of an SCTP packet containing a DATA chunk with the I + // bit set, the receiver SHOULD NOT delay the sending of the corresponding + // SACK chunk, i.e., the receiver SHOULD immediately respond with the + // corresponding SACK chunk." + if (*immediate_ack) { + UpdateAckState(AckState::kImmediate, "immediate-ack bit set"); + } + + if (!seen_packet_) { + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "After the reception of the first DATA chunk in an association the + // endpoint MUST immediately respond with a SACK to acknowledge the DATA + // chunk." + seen_packet_ = true; + UpdateAckState(AckState::kImmediate, "first DATA chunk"); + } + + // https://tools.ietf.org/html/rfc4960#section-6.2 + // "Specifically, an acknowledgement SHOULD be generated for at least + // every second packet (not every second DATA chunk) received, and SHOULD be + // generated within 200 ms of the arrival of any unacknowledged DATA chunk." + if (ack_state_ == AckState::kIdle) { + UpdateAckState(AckState::kBecomingDelayed, "received DATA when idle"); + } else if (ack_state_ == AckState::kDelayed) { + UpdateAckState(AckState::kImmediate, "received DATA when already delayed"); + } + return !is_duplicate; +} + +void DataTracker::HandleForwardTsn(TSN new_cumulative_ack) { + // ForwardTSN is sent to make the receiver (this socket) "forget" about partly + // received (or not received at all) data, up until `new_cumulative_ack`. + + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(new_cumulative_ack); + UnwrappedTSN prev_last_cum_ack_tsn = last_cumulative_acked_tsn_; + + // Old chunk already seen before? + if (unwrapped_tsn <= last_cumulative_acked_tsn_) { + // https://tools.ietf.org/html/rfc3758#section-3.6 + // "Note, if the "New Cumulative TSN" value carried in the arrived + // FORWARD TSN chunk is found to be behind or at the current cumulative TSN + // point, the data receiver MUST treat this FORWARD TSN as out-of-date and + // MUST NOT update its Cumulative TSN. The receiver SHOULD send a SACK to + // its peer (the sender of the FORWARD TSN) since such a duplicate may + // indicate the previous SACK was lost in the network." + UpdateAckState(AckState::kImmediate, + "FORWARD_TSN new_cumulative_tsn was behind"); + return; + } + + // https://tools.ietf.org/html/rfc3758#section-3.6 + // "When a FORWARD TSN chunk arrives, the data receiver MUST first update + // its cumulative TSN point to the value carried in the FORWARD TSN chunk, and + // then MUST further advance its cumulative TSN point locally if possible, as + // shown by the following example..." + + // The `new_cumulative_ack` will become the current + // `last_cumulative_acked_tsn_`, and if there have been prior "gaps" that are + // now overlapping with the new value, remove them. + last_cumulative_acked_tsn_ = unwrapped_tsn; + additional_tsn_blocks_.EraseTo(unwrapped_tsn); + + // See if the `last_cumulative_acked_tsn_` can be moved even further: + if (!additional_tsn_blocks_.empty() && + additional_tsn_blocks_.front().first == + last_cumulative_acked_tsn_.next_value()) { + last_cumulative_acked_tsn_ = additional_tsn_blocks_.front().last; + additional_tsn_blocks_.PopFront(); + } + + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "FORWARD_TSN, cum_ack_tsn=" + << *prev_last_cum_ack_tsn.Wrap() << "->" + << *new_cumulative_ack << "->" + << *last_cumulative_acked_tsn_.Wrap(); + + // https://tools.ietf.org/html/rfc3758#section-3.6 + // "Any time a FORWARD TSN chunk arrives, for the purposes of sending a + // SACK, the receiver MUST follow the same rules as if a DATA chunk had been + // received (i.e., follow the delayed sack rules specified in ..." + if (ack_state_ == AckState::kIdle) { + UpdateAckState(AckState::kBecomingDelayed, + "received FORWARD_TSN when idle"); + } else if (ack_state_ == AckState::kDelayed) { + UpdateAckState(AckState::kImmediate, + "received FORWARD_TSN when already delayed"); + } +} + +SackChunk DataTracker::CreateSelectiveAck(size_t a_rwnd) { + // Note that in SCTP, the receiver side is allowed to discard received data + // and signal that to the sender, but only chunks that have previously been + // reported in the gap-ack-blocks. However, this implementation will never do + // that. So this SACK produced is more like a NR-SACK as explained in + // https://ieeexplore.ieee.org/document/4697037 and which there is an RFC + // draft at https://tools.ietf.org/html/draft-tuexen-tsvwg-sctp-multipath-17. + std::set<TSN> duplicate_tsns; + duplicate_tsns_.swap(duplicate_tsns); + + return SackChunk(last_cumulative_acked_tsn_.Wrap(), a_rwnd, + CreateGapAckBlocks(), std::move(duplicate_tsns)); +} + +std::vector<SackChunk::GapAckBlock> DataTracker::CreateGapAckBlocks() const { + const auto& blocks = additional_tsn_blocks_.blocks(); + std::vector<SackChunk::GapAckBlock> gap_ack_blocks; + gap_ack_blocks.reserve(std::min(blocks.size(), kMaxGapAckBlocksReported)); + for (size_t i = 0; i < blocks.size() && i < kMaxGapAckBlocksReported; ++i) { + auto start_diff = + UnwrappedTSN::Difference(blocks[i].first, last_cumulative_acked_tsn_); + auto end_diff = + UnwrappedTSN::Difference(blocks[i].last, last_cumulative_acked_tsn_); + gap_ack_blocks.emplace_back(static_cast<uint16_t>(start_diff), + static_cast<uint16_t>(end_diff)); + } + + return gap_ack_blocks; +} + +bool DataTracker::ShouldSendAck(bool also_if_delayed) { + if (ack_state_ == AckState::kImmediate || + (also_if_delayed && (ack_state_ == AckState::kBecomingDelayed || + ack_state_ == AckState::kDelayed))) { + UpdateAckState(AckState::kIdle, "sending SACK"); + return true; + } + + return false; +} + +bool DataTracker::will_increase_cum_ack_tsn(TSN tsn) const { + UnwrappedTSN unwrapped = tsn_unwrapper_.PeekUnwrap(tsn); + return unwrapped == last_cumulative_acked_tsn_.next_value(); +} + +void DataTracker::ForceImmediateSack() { + ack_state_ = AckState::kImmediate; +} + +void DataTracker::HandleDelayedAckTimerExpiry() { + UpdateAckState(AckState::kImmediate, "delayed ack timer expired"); +} + +void DataTracker::ObservePacketEnd() { + if (ack_state_ == AckState::kBecomingDelayed) { + UpdateAckState(AckState::kDelayed, "packet end"); + } +} + +void DataTracker::UpdateAckState(AckState new_state, absl::string_view reason) { + if (new_state != ack_state_) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "State changed from " + << ToString(ack_state_) << " to " + << ToString(new_state) << " due to " << reason; + if (ack_state_ == AckState::kDelayed) { + delayed_ack_timer_.Stop(); + } else if (new_state == AckState::kDelayed) { + delayed_ack_timer_.Start(); + } + ack_state_ = new_state; + } +} + +absl::string_view DataTracker::ToString(AckState ack_state) { + switch (ack_state) { + case AckState::kIdle: + return "IDLE"; + case AckState::kBecomingDelayed: + return "BECOMING_DELAYED"; + case AckState::kDelayed: + return "DELAYED"; + case AckState::kImmediate: + return "IMMEDIATE"; + } +} + +HandoverReadinessStatus DataTracker::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (!additional_tsn_blocks_.empty()) { + status.Add(HandoverUnreadinessReason::kDataTrackerTsnBlocksPending); + } + return status; +} + +void DataTracker::AddHandoverState(DcSctpSocketHandoverState& state) { + state.rx.last_cumulative_acked_tsn = last_cumulative_acked_tsn().value(); + state.rx.seen_packet = seen_packet_; +} + +void DataTracker::RestoreFromState(const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(additional_tsn_blocks_.empty()); + RTC_DCHECK(duplicate_tsns_.empty()); + RTC_DCHECK(!seen_packet_); + + seen_packet_ = state.rx.seen_packet; + last_cumulative_acked_tsn_ = + tsn_unwrapper_.Unwrap(TSN(state.rx.last_cumulative_acked_tsn)); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/data_tracker.h b/third_party/libwebrtc/net/dcsctp/rx/data_tracker.h new file mode 100644 index 0000000000..ea077a9b57 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/data_tracker.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_DATA_TRACKER_H_ +#define NET_DCSCTP_RX_DATA_TRACKER_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <cstdint> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" +#include "net/dcsctp/timer/timer.h" + +namespace dcsctp { + +// Keeps track of received DATA chunks and handles all logic for _when_ to +// create SACKs and also _how_ to generate them. +// +// It only uses TSNs to track delivery and doesn't need to be aware of streams. +// +// SACKs are optimally sent every second packet on connections with no packet +// loss. When packet loss is detected, it's sent for every packet. When SACKs +// are not sent directly, a timer is used to send a SACK delayed (by RTO/2, or +// 200ms, whatever is smallest). +class DataTracker { + public: + // The maximum number of duplicate TSNs that will be reported in a SACK. + static constexpr size_t kMaxDuplicateTsnReported = 20; + // The maximum number of gap-ack-blocks that will be reported in a SACK. + static constexpr size_t kMaxGapAckBlocksReported = 20; + + // The maximum number of accepted in-flight DATA chunks. This indicates the + // maximum difference from this buffer's last cumulative ack TSN, and any + // received data. Data received beyond this limit will be dropped, which will + // force the transmitter to send data that actually increases the last + // cumulative acked TSN. + static constexpr uint32_t kMaxAcceptedOutstandingFragments = 100000; + + DataTracker(absl::string_view log_prefix, + Timer* delayed_ack_timer, + TSN peer_initial_tsn) + : log_prefix_(std::string(log_prefix) + "dtrack: "), + seen_packet_(false), + delayed_ack_timer_(*delayed_ack_timer), + last_cumulative_acked_tsn_( + tsn_unwrapper_.Unwrap(TSN(*peer_initial_tsn - 1))) {} + + // Indicates if the provided TSN is valid. If this return false, the data + // should be dropped and not added to any other buffers, which essentially + // means that there is intentional packet loss. + bool IsTSNValid(TSN tsn) const; + + // Call for every incoming data chunk. Returns `true` if `tsn` was seen for + // the first time, and `false` if it has been seen before (a duplicate `tsn`). + bool Observe(TSN tsn, + AnyDataChunk::ImmediateAckFlag immediate_ack = + AnyDataChunk::ImmediateAckFlag(false)); + // Called at the end of processing an SCTP packet. + void ObservePacketEnd(); + + // Called for incoming FORWARD-TSN/I-FORWARD-TSN chunks + void HandleForwardTsn(TSN new_cumulative_ack); + + // Indicates if a SACK should be sent. There may be other reasons to send a + // SACK, but if this function indicates so, it should be sent as soon as + // possible. Calling this function will make it clear a flag so that if it's + // called again, it will probably return false. + // + // If the delayed ack timer is running, this method will return false _unless_ + // `also_if_delayed` is set to true. Then it will return true as well. + bool ShouldSendAck(bool also_if_delayed = false); + + // Returns the last cumulative ack TSN - the last seen data chunk's TSN + // value before any packet loss was detected. + TSN last_cumulative_acked_tsn() const { + return TSN(last_cumulative_acked_tsn_.Wrap()); + } + + // Returns true if the received `tsn` would increase the cumulative ack TSN. + bool will_increase_cum_ack_tsn(TSN tsn) const; + + // Forces `ShouldSendSack` to return true. + void ForceImmediateSack(); + + // Note that this will clear `duplicates_`, so every SackChunk that is + // consumed must be sent. + SackChunk CreateSelectiveAck(size_t a_rwnd); + + void HandleDelayedAckTimerExpiry(); + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); + + private: + enum class AckState { + // No need to send an ACK. + kIdle, + + // Has received data chunks (but not yet end of packet). + kBecomingDelayed, + + // Has received data chunks and the end of a packet. Delayed ack timer is + // running and a SACK will be sent on expiry, or if DATA is sent, or after + // next packet with data. + kDelayed, + + // Send a SACK immediately after handling this packet. + kImmediate, + }; + + // Represents ranges of TSNs that have been received that are not directly + // following the last cumulative acked TSN. This information is returned to + // the sender in the "gap ack blocks" in the SACK chunk. The blocks are always + // non-overlapping and non-adjacent. + class AdditionalTsnBlocks { + public: + // Represents an inclusive range of received TSNs, i.e. [first, last]. + struct TsnRange { + TsnRange(UnwrappedTSN first, UnwrappedTSN last) + : first(first), last(last) {} + UnwrappedTSN first; + UnwrappedTSN last; + }; + + // Adds a TSN to the set. This will try to expand any existing block and + // might merge blocks to ensure that all blocks are non-adjacent. If a + // current block can't be expanded, a new block is created. + // + // The return value indicates if `tsn` was added. If false is returned, the + // `tsn` was already represented in one of the blocks. + bool Add(UnwrappedTSN tsn); + + // Erases all TSNs up to, and including `tsn`. This will remove all blocks + // that are completely below `tsn` and may truncate a block where `tsn` is + // within that block. In that case, the frontmost block's start TSN will be + // the next following tsn after `tsn`. + void EraseTo(UnwrappedTSN tsn); + + // Removes the first block. Must not be called on an empty set. + void PopFront(); + + const std::vector<TsnRange>& blocks() const { return blocks_; } + + bool empty() const { return blocks_.empty(); } + + const TsnRange& front() const { return blocks_.front(); } + + private: + // A sorted vector of non-overlapping and non-adjacent blocks. + std::vector<TsnRange> blocks_; + }; + + std::vector<SackChunk::GapAckBlock> CreateGapAckBlocks() const; + void UpdateAckState(AckState new_state, absl::string_view reason); + static absl::string_view ToString(AckState ack_state); + + const std::string log_prefix_; + // If a packet has ever been seen. + bool seen_packet_; + Timer& delayed_ack_timer_; + AckState ack_state_ = AckState::kIdle; + UnwrappedTSN::Unwrapper tsn_unwrapper_; + + // All TSNs up until (and including) this value have been seen. + UnwrappedTSN last_cumulative_acked_tsn_; + // Received TSNs that are not directly following `last_cumulative_acked_tsn_`. + AdditionalTsnBlocks additional_tsn_blocks_; + std::set<TSN> duplicate_tsns_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_DATA_TRACKER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/data_tracker_test.cc b/third_party/libwebrtc/net/dcsctp/rx/data_tracker_test.cc new file mode 100644 index 0000000000..f74dd6eb0b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/data_tracker_test.cc @@ -0,0 +1,739 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/data_tracker.h" + +#include <cstdint> +#include <initializer_list> +#include <memory> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "net/dcsctp/timer/timer.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr size_t kArwnd = 10000; +constexpr TSN kInitialTSN(11); + +class DataTrackerTest : public testing::Test { + protected: + DataTrackerTest() + : timeout_manager_([this]() { return now_; }), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return timeout_manager_.CreateTimeout(precision); + }), + timer_(timer_manager_.CreateTimer( + "test/delayed_ack", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))), + tracker_( + std::make_unique<DataTracker>("log: ", timer_.get(), kInitialTSN)) { + } + + void Observer(std::initializer_list<uint32_t> tsns, + bool expect_as_duplicate = false) { + for (const uint32_t tsn : tsns) { + if (expect_as_duplicate) { + EXPECT_FALSE( + tracker_->Observe(TSN(tsn), AnyDataChunk::ImmediateAckFlag(false))); + } else { + EXPECT_TRUE( + tracker_->Observe(TSN(tsn), AnyDataChunk::ImmediateAckFlag(false))); + } + } + } + + void HandoverTracker() { + EXPECT_TRUE(tracker_->GetHandoverReadiness().IsReady()); + DcSctpSocketHandoverState state; + tracker_->AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + tracker_ = + std::make_unique<DataTracker>("log: ", timer_.get(), kInitialTSN); + tracker_->RestoreFromState(state); + } + + TimeMs now_ = TimeMs(0); + FakeTimeoutManager timeout_manager_; + TimerManager timer_manager_; + std::unique_ptr<Timer> timer_; + std::unique_ptr<DataTracker> tracker_; +}; + +TEST_F(DataTrackerTest, Empty) { + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ObserverSingleInOrderPacket) { + Observer({11}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ObserverManyInOrderMovesCumulativeTsnAck) { + Observer({11, 12, 13}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(13)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ObserveOutOfOrderMovesCumulativeTsnAck) { + Observer({12, 13, 14, 11}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, SingleGap) { + Observer({12}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ExampleFromRFC4960Section334) { + Observer({11, 12, 14, 15, 17}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(12)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 5))); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, AckAlreadyReceivedChunk) { + Observer({11}); + SackChunk sack1 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack1.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack1.gap_ack_blocks(), IsEmpty()); + + // Receive old chunk + Observer({8}, /*expect_as_duplicate=*/true); + SackChunk sack2 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack2.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack2.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, DoubleSendRetransmittedChunk) { + Observer({11, 13, 14, 15}); + SackChunk sack1 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack1.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack1.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 4))); + + // Fill in the hole. + Observer({12, 16, 17, 18}); + SackChunk sack2 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack2.cumulative_tsn_ack(), TSN(18)); + EXPECT_THAT(sack2.gap_ack_blocks(), IsEmpty()); + + // Receive chunk 12 again. + Observer({12}, /*expect_as_duplicate=*/true); + Observer({19, 20, 21}); + SackChunk sack3 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack3.cumulative_tsn_ack(), TSN(21)); + EXPECT_THAT(sack3.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ForwardTsnSimple) { + // Messages (11, 12, 13), (14, 15) - first message expires. + Observer({11, 12, 15}); + + tracker_->HandleForwardTsn(TSN(13)); + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(13)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); +} + +TEST_F(DataTrackerTest, ForwardTsnSkipsFromGapBlock) { + // Messages (11, 12, 13), (14, 15) - first message expires. + Observer({11, 12, 14}); + + tracker_->HandleForwardTsn(TSN(13)); + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ExampleFromRFC3758) { + tracker_->HandleForwardTsn(TSN(102)); + + Observer({102}, /*expect_as_duplicate=*/true); + Observer({104, 105, 107}); + + tracker_->HandleForwardTsn(TSN(103)); + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(105)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); +} + +TEST_F(DataTrackerTest, EmptyAllAcks) { + Observer({11, 13, 14, 15}); + + tracker_->HandleForwardTsn(TSN(100)); + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(100)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, SetsArwndCorrectly) { + SackChunk sack1 = tracker_->CreateSelectiveAck(/*a_rwnd=*/100); + EXPECT_EQ(sack1.a_rwnd(), 100u); + + SackChunk sack2 = tracker_->CreateSelectiveAck(/*a_rwnd=*/101); + EXPECT_EQ(sack2.a_rwnd(), 101u); +} + +TEST_F(DataTrackerTest, WillIncreaseCumAckTsn) { + EXPECT_EQ(tracker_->last_cumulative_acked_tsn(), TSN(10)); + EXPECT_FALSE(tracker_->will_increase_cum_ack_tsn(TSN(10))); + EXPECT_TRUE(tracker_->will_increase_cum_ack_tsn(TSN(11))); + EXPECT_FALSE(tracker_->will_increase_cum_ack_tsn(TSN(12))); + + Observer({11, 12, 13, 14, 15}); + EXPECT_EQ(tracker_->last_cumulative_acked_tsn(), TSN(15)); + EXPECT_FALSE(tracker_->will_increase_cum_ack_tsn(TSN(15))); + EXPECT_TRUE(tracker_->will_increase_cum_ack_tsn(TSN(16))); + EXPECT_FALSE(tracker_->will_increase_cum_ack_tsn(TSN(17))); +} + +TEST_F(DataTrackerTest, ForceShouldSendSackImmediately) { + EXPECT_FALSE(tracker_->ShouldSendAck()); + + tracker_->ForceImmediateSack(); + + EXPECT_TRUE(tracker_->ShouldSendAck()); +} + +TEST_F(DataTrackerTest, WillAcceptValidTSNs) { + // The initial TSN is always one more than the last, which is our base. + TSN last_tsn = TSN(*kInitialTSN - 1); + int limit = static_cast<int>(DataTracker::kMaxAcceptedOutstandingFragments); + + for (int i = -limit; i <= limit; ++i) { + EXPECT_TRUE(tracker_->IsTSNValid(TSN(*last_tsn + i))); + } +} + +TEST_F(DataTrackerTest, WillNotAcceptInvalidTSNs) { + // The initial TSN is always one more than the last, which is our base. + TSN last_tsn = TSN(*kInitialTSN - 1); + + size_t limit = DataTracker::kMaxAcceptedOutstandingFragments; + EXPECT_FALSE(tracker_->IsTSNValid(TSN(*last_tsn + limit + 1))); + EXPECT_FALSE(tracker_->IsTSNValid(TSN(*last_tsn - (limit + 1)))); + EXPECT_FALSE(tracker_->IsTSNValid(TSN(*last_tsn + 0x8000000))); + EXPECT_FALSE(tracker_->IsTSNValid(TSN(*last_tsn - 0x8000000))); +} + +TEST_F(DataTrackerTest, ReportSingleDuplicateTsns) { + Observer({11, 12}); + Observer({11}, /*expect_as_duplicate=*/true); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(12)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), UnorderedElementsAre(TSN(11))); +} + +TEST_F(DataTrackerTest, ReportMultipleDuplicateTsns) { + Observer({11, 12, 13, 14}); + Observer({12, 13, 12, 13}, /*expect_as_duplicate=*/true); + Observer({15, 16}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(16)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), UnorderedElementsAre(TSN(12), TSN(13))); +} + +TEST_F(DataTrackerTest, ReportDuplicateTsnsInGapAckBlocks) { + Observer({11, /*12,*/ 13, 14}); + Observer({13, 14}, /*expect_as_duplicate=*/true); + Observer({15, 16}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 5))); + EXPECT_THAT(sack.duplicate_tsns(), UnorderedElementsAre(TSN(13), TSN(14))); +} + +TEST_F(DataTrackerTest, ClearsDuplicateTsnsAfterCreatingSack) { + Observer({11, 12, 13, 14}); + Observer({12, 13, 12, 13}, /*expect_as_duplicate=*/true); + Observer({15, 16}); + SackChunk sack1 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack1.cumulative_tsn_ack(), TSN(16)); + EXPECT_THAT(sack1.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack1.duplicate_tsns(), UnorderedElementsAre(TSN(12), TSN(13))); + + Observer({17}); + SackChunk sack2 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack2.cumulative_tsn_ack(), TSN(17)); + EXPECT_THAT(sack2.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack2.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, LimitsNumberOfDuplicatesReported) { + for (size_t i = 0; i < DataTracker::kMaxDuplicateTsnReported + 10; ++i) { + TSN tsn(11 + i); + tracker_->Observe(tsn, AnyDataChunk::ImmediateAckFlag(false)); + tracker_->Observe(tsn, AnyDataChunk::ImmediateAckFlag(false)); + } + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), + SizeIs(DataTracker::kMaxDuplicateTsnReported)); +} + +TEST_F(DataTrackerTest, LimitsNumberOfGapAckBlocksReported) { + for (size_t i = 0; i < DataTracker::kMaxGapAckBlocksReported + 10; ++i) { + TSN tsn(11 + i * 2); + tracker_->Observe(tsn, AnyDataChunk::ImmediateAckFlag(false)); + } + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), + SizeIs(DataTracker::kMaxGapAckBlocksReported)); +} + +TEST_F(DataTrackerTest, SendsSackForFirstPacketObserved) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, SendsSackEverySecondPacketWhenThereIsNoPacketLoss) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + Observer({13}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({14}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + Observer({15}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, SendsSackEveryPacketOnPacketLoss) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({13}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({14}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({15}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({16}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + // Fill the hole. + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + // Goes back to every second packet + Observer({17}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({18}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, SendsSackOnDuplicateDataChunks) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({11}, /*expect_as_duplicate=*/true); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + // Goes back to every second packet + Observer({13}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + // Duplicate again + Observer({12}, /*expect_as_duplicate=*/true); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, GapAckBlockAddSingleBlock) { + Observer({12}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); +} + +TEST_F(DataTrackerTest, GapAckBlockAddsAnother) { + Observer({12}); + Observer({14}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2), + SackChunk::GapAckBlock(4, 4))); +} + +TEST_F(DataTrackerTest, GapAckBlockAddsDuplicate) { + Observer({12}); + Observer({12}, /*expect_as_duplicate=*/true); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); + EXPECT_THAT(sack.duplicate_tsns(), ElementsAre(TSN(12))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToRight) { + Observer({12}); + Observer({13}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToRightWithOther) { + Observer({12}); + Observer({20}); + Observer({30}); + Observer({21}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 2), // + SackChunk::GapAckBlock(10, 11), // + SackChunk::GapAckBlock(20, 20))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToLeft) { + Observer({13}); + Observer({12}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToLeftWithOther) { + Observer({12}); + Observer({21}); + Observer({30}); + Observer({20}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 2), // + SackChunk::GapAckBlock(10, 11), // + SackChunk::GapAckBlock(20, 20))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToLRightAndMerges) { + Observer({12}); + Observer({20}); + Observer({22}); + Observer({30}); + Observer({21}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 2), // + SackChunk::GapAckBlock(10, 12), // + SackChunk::GapAckBlock(20, 20))); +} + +TEST_F(DataTrackerTest, GapAckBlockMergesManyBlocksIntoOne) { + Observer({22}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12))); + Observer({30}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(20, 20))); + Observer({24}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(20, 20))); + Observer({28}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(18, 18), // + SackChunk::GapAckBlock(20, 20))); + Observer({26}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(16, 16), // + SackChunk::GapAckBlock(18, 18), // + SackChunk::GapAckBlock(20, 20))); + Observer({29}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(16, 16), // + SackChunk::GapAckBlock(18, 20))); + Observer({23}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 14), // + SackChunk::GapAckBlock(16, 16), // + SackChunk::GapAckBlock(18, 20))); + Observer({27}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 14), // + SackChunk::GapAckBlock(16, 20))); + + Observer({25}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 20))); + Observer({20}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 10), // + SackChunk::GapAckBlock(12, 20))); + Observer({32}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 10), // + SackChunk::GapAckBlock(12, 20), // + SackChunk::GapAckBlock(22, 22))); + Observer({21}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 20), // + SackChunk::GapAckBlock(22, 22))); + Observer({31}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 22))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveBeforeCumAckTsn) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(8)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 4), // + SackChunk::GapAckBlock(10, 12), + SackChunk::GapAckBlock(20, 21))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveBeforeFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(11)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveAtBeginningOfFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(12)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveAtMiddleOfFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + tracker_->HandleForwardTsn(TSN(13)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveAtEndOfFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + tracker_->HandleForwardTsn(TSN(14)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAfterFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(18)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(18)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 4), // + SackChunk::GapAckBlock(12, 13))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightBeforeSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(19)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAtStartOfSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(20)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAtMiddleOfSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(21)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAtEndOfSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(22)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveeFarAfterAllBlocks) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(40)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(40)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, HandoverEmpty) { + HandoverTracker(); + Observer({11}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, + HandoverWhileSendingSackEverySecondPacketWhenThereIsNoPacketLoss) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + + HandoverTracker(); + + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + Observer({13}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({14}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + Observer({15}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, HandoverWhileSendingSackEveryPacketOnPacketLoss) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + Observer({13}); + EXPECT_EQ(tracker_->GetHandoverReadiness(), + HandoverReadinessStatus().Add( + HandoverUnreadinessReason::kDataTrackerTsnBlocksPending)); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + Observer({14}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_EQ(tracker_->GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kDataTrackerTsnBlocksPending)); + Observer({15}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + Observer({16}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + // Fill the hole. + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + // Goes back to every second packet + Observer({17}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + + HandoverTracker(); + + Observer({18}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.cc b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.cc new file mode 100644 index 0000000000..8b316de676 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.cc @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/interleaved_reassembly_streams.h" + +#include <stddef.h> + +#include <cstdint> +#include <functional> +#include <iterator> +#include <map> +#include <numeric> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/types.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +InterleavedReassemblyStreams::InterleavedReassemblyStreams( + absl::string_view log_prefix, + OnAssembledMessage on_assembled_message) + : log_prefix_(log_prefix), on_assembled_message_(on_assembled_message) {} + +size_t InterleavedReassemblyStreams::Stream::TryToAssembleMessage( + UnwrappedMID mid) { + std::map<UnwrappedMID, ChunkMap>::const_iterator it = + chunks_by_mid_.find(mid); + if (it == chunks_by_mid_.end()) { + RTC_DLOG(LS_VERBOSE) << parent_.log_prefix_ << "TryToAssembleMessage " + << *mid.Wrap() << " - no chunks"; + return 0; + } + const ChunkMap& chunks = it->second; + if (!chunks.begin()->second.second.is_beginning || + !chunks.rbegin()->second.second.is_end) { + RTC_DLOG(LS_VERBOSE) << parent_.log_prefix_ << "TryToAssembleMessage " + << *mid.Wrap() << "- missing beginning or end"; + return 0; + } + int64_t fsn_diff = *chunks.rbegin()->first - *chunks.begin()->first; + if (fsn_diff != (static_cast<int64_t>(chunks.size()) - 1)) { + RTC_DLOG(LS_VERBOSE) << parent_.log_prefix_ << "TryToAssembleMessage " + << *mid.Wrap() << "- not all chunks exist (have " + << chunks.size() << ", expect " << (fsn_diff + 1) + << ")"; + return 0; + } + + size_t removed_bytes = AssembleMessage(chunks); + RTC_DLOG(LS_VERBOSE) << parent_.log_prefix_ << "TryToAssembleMessage " + << *mid.Wrap() << " - succeeded and removed " + << removed_bytes; + + chunks_by_mid_.erase(mid); + return removed_bytes; +} + +size_t InterleavedReassemblyStreams::Stream::AssembleMessage( + const ChunkMap& tsn_chunks) { + size_t count = tsn_chunks.size(); + if (count == 1) { + // Fast path - zero-copy + const Data& data = tsn_chunks.begin()->second.second; + size_t payload_size = data.size(); + UnwrappedTSN tsns[1] = {tsn_chunks.begin()->second.first}; + DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + return payload_size; + } + + // Slow path - will need to concatenate the payload. + std::vector<UnwrappedTSN> tsns; + tsns.reserve(count); + + std::vector<uint8_t> payload; + size_t payload_size = absl::c_accumulate( + tsn_chunks, 0, + [](size_t v, const auto& p) { return v + p.second.second.size(); }); + payload.reserve(payload_size); + + for (auto& item : tsn_chunks) { + const UnwrappedTSN tsn = item.second.first; + const Data& data = item.second.second; + tsns.push_back(tsn); + payload.insert(payload.end(), data.payload.begin(), data.payload.end()); + } + + const Data& data = tsn_chunks.begin()->second.second; + + DcSctpMessage message(data.stream_id, data.ppid, std::move(payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + return payload_size; +} + +size_t InterleavedReassemblyStreams::Stream::EraseTo(MID message_id) { + UnwrappedMID unwrapped_mid = mid_unwrapper_.Unwrap(message_id); + + size_t removed_bytes = 0; + auto it = chunks_by_mid_.begin(); + while (it != chunks_by_mid_.end() && it->first <= unwrapped_mid) { + removed_bytes += absl::c_accumulate( + it->second, 0, + [](size_t r2, const auto& q) { return r2 + q.second.second.size(); }); + it = chunks_by_mid_.erase(it); + } + + if (!stream_id_.unordered) { + // For ordered streams, erasing a message might suddenly unblock that queue + // and allow it to deliver any following received messages. + if (unwrapped_mid >= next_mid_) { + next_mid_ = unwrapped_mid.next_value(); + } + + removed_bytes += TryToAssembleMessages(); + } + + return removed_bytes; +} + +int InterleavedReassemblyStreams::Stream::Add(UnwrappedTSN tsn, Data data) { + RTC_DCHECK_EQ(*data.is_unordered, *stream_id_.unordered); + RTC_DCHECK_EQ(*data.stream_id, *stream_id_.stream_id); + int queued_bytes = data.size(); + UnwrappedMID mid = mid_unwrapper_.Unwrap(data.message_id); + FSN fsn = data.fsn; + auto [unused, inserted] = + chunks_by_mid_[mid].emplace(fsn, std::make_pair(tsn, std::move(data))); + if (!inserted) { + return 0; + } + + if (stream_id_.unordered) { + queued_bytes -= TryToAssembleMessage(mid); + } else { + if (mid == next_mid_) { + queued_bytes -= TryToAssembleMessages(); + } + } + + return queued_bytes; +} + +size_t InterleavedReassemblyStreams::Stream::TryToAssembleMessages() { + size_t removed_bytes = 0; + + for (;;) { + size_t removed_bytes_this_iter = TryToAssembleMessage(next_mid_); + if (removed_bytes_this_iter == 0) { + break; + } + + removed_bytes += removed_bytes_this_iter; + next_mid_.Increment(); + } + return removed_bytes; +} + +void InterleavedReassemblyStreams::Stream::AddHandoverState( + DcSctpSocketHandoverState& state) const { + if (stream_id_.unordered) { + DcSctpSocketHandoverState::UnorderedStream state_stream; + state_stream.id = stream_id_.stream_id.value(); + state.rx.unordered_streams.push_back(std::move(state_stream)); + } else { + DcSctpSocketHandoverState::OrderedStream state_stream; + state_stream.id = stream_id_.stream_id.value(); + state_stream.next_ssn = next_mid_.Wrap().value(); + state.rx.ordered_streams.push_back(std::move(state_stream)); + } +} + +InterleavedReassemblyStreams::Stream& +InterleavedReassemblyStreams::GetOrCreateStream(const FullStreamId& stream_id) { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + it = + streams_ + .emplace(std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(stream_id, this)) + .first; + } + return it->second; +} + +int InterleavedReassemblyStreams::Add(UnwrappedTSN tsn, Data data) { + return GetOrCreateStream(FullStreamId(data.is_unordered, data.stream_id)) + .Add(tsn, std::move(data)); +} + +size_t InterleavedReassemblyStreams::HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) { + size_t removed_bytes = 0; + for (const auto& skipped : skipped_streams) { + removed_bytes += + GetOrCreateStream(FullStreamId(skipped.unordered, skipped.stream_id)) + .EraseTo(skipped.message_id); + } + return removed_bytes; +} + +void InterleavedReassemblyStreams::ResetStreams( + rtc::ArrayView<const StreamID> stream_ids) { + if (stream_ids.empty()) { + for (auto& entry : streams_) { + entry.second.Reset(); + } + } else { + for (StreamID stream_id : stream_ids) { + GetOrCreateStream(FullStreamId(IsUnordered(true), stream_id)).Reset(); + GetOrCreateStream(FullStreamId(IsUnordered(false), stream_id)).Reset(); + } + } +} + +HandoverReadinessStatus InterleavedReassemblyStreams::GetHandoverReadiness() + const { + HandoverReadinessStatus status; + for (const auto& [stream_id, stream] : streams_) { + if (stream.has_unassembled_chunks()) { + status.Add( + stream_id.unordered + ? HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks + : HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks); + break; + } + } + return status; +} + +void InterleavedReassemblyStreams::AddHandoverState( + DcSctpSocketHandoverState& state) { + for (const auto& [unused, stream] : streams_) { + stream.AddHandoverState(state); + } +} + +void InterleavedReassemblyStreams::RestoreFromState( + const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(streams_.empty()); + + for (const DcSctpSocketHandoverState::OrderedStream& state : + state.rx.ordered_streams) { + FullStreamId stream_id(IsUnordered(false), StreamID(state.id)); + streams_.emplace( + std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(stream_id, this, MID(state.next_ssn))); + } + for (const DcSctpSocketHandoverState::UnorderedStream& state : + state.rx.unordered_streams) { + FullStreamId stream_id(IsUnordered(true), StreamID(state.id)); + streams_.emplace(std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(stream_id, this)); + } +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.h b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.h new file mode 100644 index 0000000000..a7b67707e9 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.h @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_INTERLEAVED_REASSEMBLY_STREAMS_H_ +#define NET_DCSCTP_RX_INTERLEAVED_REASSEMBLY_STREAMS_H_ + +#include <cstdint> +#include <map> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" + +namespace dcsctp { + +// Handles reassembly of incoming data when interleaved message sending is +// enabled on the association, i.e. when RFC8260 is in use. +class InterleavedReassemblyStreams : public ReassemblyStreams { + public: + InterleavedReassemblyStreams(absl::string_view log_prefix, + OnAssembledMessage on_assembled_message); + + int Add(UnwrappedTSN tsn, Data data) override; + + size_t HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) + override; + + void ResetStreams(rtc::ArrayView<const StreamID> stream_ids) override; + + HandoverReadinessStatus GetHandoverReadiness() const override; + void AddHandoverState(DcSctpSocketHandoverState& state) override; + void RestoreFromState(const DcSctpSocketHandoverState& state) override; + + private: + struct FullStreamId { + const IsUnordered unordered; + const StreamID stream_id; + + FullStreamId(IsUnordered unordered, StreamID stream_id) + : unordered(unordered), stream_id(stream_id) {} + + friend bool operator<(FullStreamId a, FullStreamId b) { + return a.unordered < b.unordered || + (!(a.unordered < b.unordered) && (a.stream_id < b.stream_id)); + } + }; + + class Stream { + public: + Stream(FullStreamId stream_id, + InterleavedReassemblyStreams* parent, + MID next_mid = MID(0)) + : stream_id_(stream_id), + parent_(*parent), + next_mid_(mid_unwrapper_.Unwrap(next_mid)) {} + int Add(UnwrappedTSN tsn, Data data); + size_t EraseTo(MID message_id); + void Reset() { + mid_unwrapper_.Reset(); + next_mid_ = mid_unwrapper_.Unwrap(MID(0)); + } + bool has_unassembled_chunks() const { return !chunks_by_mid_.empty(); } + void AddHandoverState(DcSctpSocketHandoverState& state) const; + + private: + using ChunkMap = std::map<FSN, std::pair<UnwrappedTSN, Data>>; + + // Try to assemble one message identified by `mid`. + // Returns the number of bytes assembled if a message was assembled. + size_t TryToAssembleMessage(UnwrappedMID mid); + size_t AssembleMessage(const ChunkMap& tsn_chunks); + // Try to assemble one or several messages in order from the stream. + // Returns the number of bytes assembled if one or more messages were + // assembled. + size_t TryToAssembleMessages(); + + const FullStreamId stream_id_; + InterleavedReassemblyStreams& parent_; + std::map<UnwrappedMID, ChunkMap> chunks_by_mid_; + UnwrappedMID::Unwrapper mid_unwrapper_; + UnwrappedMID next_mid_; + }; + + Stream& GetOrCreateStream(const FullStreamId& stream_id); + + const std::string log_prefix_; + + // Callback for when a message has been assembled. + const OnAssembledMessage on_assembled_message_; + + // All unordered and ordered streams, managing not-yet-assembled data. + std::map<FullStreamId, Stream> streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_INTERLEAVED_REASSEMBLY_STREAMS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams_test.cc b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams_test.cc new file mode 100644 index 0000000000..df4024ed60 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams_test.cc @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/interleaved_reassembly_streams.h" + +#include <cstdint> +#include <memory> +#include <utility> + +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" +#include "net/dcsctp/testing/data_generator.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; +using ::testing::NiceMock; + +class InterleavedReassemblyStreamsTest : public testing::Test { + protected: + UnwrappedTSN tsn(uint32_t value) { return tsn_.Unwrap(TSN(value)); } + + InterleavedReassemblyStreamsTest() {} + DataGenerator gen_; + UnwrappedTSN::Unwrapper tsn_; +}; + +TEST_F(InterleavedReassemblyStreamsTest, + AddUnorderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + // Adding the end fragment should make it empty again. + EXPECT_EQ(streams.Add(tsn(4), gen_.Unordered({7}, "E")), -6); +} + +TEST_F(InterleavedReassemblyStreamsTest, + AddSimpleOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), -6); +} + +TEST_F(InterleavedReassemblyStreamsTest, + AddMoreComplexOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + Data late = gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + EXPECT_EQ(streams.Add(tsn(2), std::move(late)), -8); +} + +TEST_F(InterleavedReassemblyStreamsTest, + DeleteUnorderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + + IForwardTsnChunk::SkippedStream skipped[] = { + IForwardTsnChunk::SkippedStream(IsUnordered(true), StreamID(1), MID(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), skipped), 6u); +} + +TEST_F(InterleavedReassemblyStreamsTest, + DeleteSimpleOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + + IForwardTsnChunk::SkippedStream skipped[] = { + IForwardTsnChunk::SkippedStream(IsUnordered(false), StreamID(1), MID(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), skipped), 6u); +} + +TEST_F(InterleavedReassemblyStreamsTest, + DeleteManyOrderedMessagesReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // Expire all three messages + IForwardTsnChunk::SkippedStream skipped[] = { + IForwardTsnChunk::SkippedStream(IsUnordered(false), StreamID(1), MID(2))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(8), skipped), 8u); +} + +TEST_F(InterleavedReassemblyStreamsTest, + DeleteOrderedMessageDelivesTwoReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // The first ordered message expire, and the following two are delivered. + IForwardTsnChunk::SkippedStream skipped[] = { + IForwardTsnChunk::SkippedStream(IsUnordered(false), StreamID(1), MID(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(4), skipped), 8u); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.cc b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.cc new file mode 100644 index 0000000000..f72c5cb8c1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.cc @@ -0,0 +1,312 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/reassembly_queue.h" + +#include <stddef.h> + +#include <algorithm> +#include <cstdint> +#include <memory> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/rx/interleaved_reassembly_streams.h" +#include "net/dcsctp/rx/reassembly_streams.h" +#include "net/dcsctp/rx/traditional_reassembly_streams.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace { +std::unique_ptr<ReassemblyStreams> CreateStreams( + absl::string_view log_prefix, + ReassemblyStreams::OnAssembledMessage on_assembled_message, + bool use_message_interleaving) { + if (use_message_interleaving) { + return std::make_unique<InterleavedReassemblyStreams>( + log_prefix, std::move(on_assembled_message)); + } + return std::make_unique<TraditionalReassemblyStreams>( + log_prefix, std::move(on_assembled_message)); +} +} // namespace + +ReassemblyQueue::ReassemblyQueue(absl::string_view log_prefix, + TSN peer_initial_tsn, + size_t max_size_bytes, + bool use_message_interleaving) + : log_prefix_(std::string(log_prefix) + "reasm: "), + max_size_bytes_(max_size_bytes), + watermark_bytes_(max_size_bytes * kHighWatermarkLimit), + last_assembled_tsn_watermark_( + tsn_unwrapper_.Unwrap(TSN(*peer_initial_tsn - 1))), + last_completed_reset_req_seq_nbr_(ReconfigRequestSN(0)), + streams_(CreateStreams( + log_prefix_, + [this](rtc::ArrayView<const UnwrappedTSN> tsns, + DcSctpMessage message) { + AddReassembledMessage(tsns, std::move(message)); + }, + use_message_interleaving)) {} + +void ReassemblyQueue::Add(TSN tsn, Data data) { + RTC_DCHECK(IsConsistent()); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "added tsn=" << *tsn + << ", stream=" << *data.stream_id << ":" + << *data.message_id << ":" << *data.fsn << ", type=" + << (data.is_beginning && data.is_end ? "complete" + : data.is_beginning ? "first" + : data.is_end ? "last" + : "middle"); + + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(tsn); + + if (unwrapped_tsn <= last_assembled_tsn_watermark_ || + delivered_tsns_.find(unwrapped_tsn) != delivered_tsns_.end()) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Chunk has already been delivered - skipping"; + return; + } + + // If a stream reset has been received with a "sender's last assigned tsn" in + // the future, the socket is in "deferred reset processing" mode and must + // buffer chunks until it's exited. + if (deferred_reset_streams_.has_value() && + unwrapped_tsn > + tsn_unwrapper_.Unwrap( + deferred_reset_streams_->req.sender_last_assigned_tsn())) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Deferring chunk with tsn=" << *tsn + << " until cum_ack_tsn=" + << *deferred_reset_streams_->req.sender_last_assigned_tsn(); + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "In this mode, any data arriving with a TSN larger than the + // Sender's Last Assigned TSN for the affected stream(s) MUST be queued + // locally and held until the cumulative acknowledgment point reaches the + // Sender's Last Assigned TSN." + queued_bytes_ += data.size(); + deferred_reset_streams_->deferred_chunks.emplace_back( + std::make_pair(tsn, std::move(data))); + } else { + queued_bytes_ += streams_->Add(unwrapped_tsn, std::move(data)); + } + + // https://tools.ietf.org/html/rfc4960#section-6.9 + // "Note: If the data receiver runs out of buffer space while still + // waiting for more fragments to complete the reassembly of the message, it + // should dispatch part of its inbound message through a partial delivery + // API (see Section 10), freeing some of its receive buffer space so that + // the rest of the message may be received." + + // TODO(boivie): Support EOR flag and partial delivery? + RTC_DCHECK(IsConsistent()); +} + +ReconfigurationResponseParameter::Result ReassemblyQueue::ResetStreams( + const OutgoingSSNResetRequestParameter& req, + TSN cum_tsn_ack) { + RTC_DCHECK(IsConsistent()); + if (deferred_reset_streams_.has_value()) { + // In deferred mode already. + return ReconfigurationResponseParameter::Result::kInProgress; + } else if (req.request_sequence_number() <= + last_completed_reset_req_seq_nbr_) { + // Already performed at some time previously. + return ReconfigurationResponseParameter::Result::kSuccessPerformed; + } + + UnwrappedTSN sla_tsn = tsn_unwrapper_.Unwrap(req.sender_last_assigned_tsn()); + UnwrappedTSN unwrapped_cum_tsn_ack = tsn_unwrapper_.Unwrap(cum_tsn_ack); + + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "If the Sender's Last Assigned TSN is greater than the + // cumulative acknowledgment point, then the endpoint MUST enter "deferred + // reset processing"." + if (sla_tsn > unwrapped_cum_tsn_ack) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "Entering deferred reset processing mode until cum_tsn_ack=" + << *req.sender_last_assigned_tsn(); + deferred_reset_streams_ = absl::make_optional<DeferredResetStreams>(req); + return ReconfigurationResponseParameter::Result::kInProgress; + } + + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "... streams MUST be reset to 0 as the next expected SSN." + streams_->ResetStreams(req.stream_ids()); + last_completed_reset_req_seq_nbr_ = req.request_sequence_number(); + RTC_DCHECK(IsConsistent()); + return ReconfigurationResponseParameter::Result::kSuccessPerformed; +} + +bool ReassemblyQueue::MaybeResetStreamsDeferred(TSN cum_ack_tsn) { + RTC_DCHECK(IsConsistent()); + if (deferred_reset_streams_.has_value()) { + UnwrappedTSN unwrapped_cum_ack_tsn = tsn_unwrapper_.Unwrap(cum_ack_tsn); + UnwrappedTSN unwrapped_sla_tsn = tsn_unwrapper_.Unwrap( + deferred_reset_streams_->req.sender_last_assigned_tsn()); + if (unwrapped_cum_ack_tsn >= unwrapped_sla_tsn) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Leaving deferred reset processing with tsn=" + << *cum_ack_tsn << ", feeding back " + << deferred_reset_streams_->deferred_chunks.size() + << " chunks"; + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "... streams MUST be reset to 0 as the next expected SSN." + streams_->ResetStreams(deferred_reset_streams_->req.stream_ids()); + std::vector<std::pair<TSN, Data>> deferred_chunks = + std::move(deferred_reset_streams_->deferred_chunks); + // The response will not be sent now, but as a reply to the retried + // request, which will come as "in progress" has been sent prior. + last_completed_reset_req_seq_nbr_ = + deferred_reset_streams_->req.request_sequence_number(); + deferred_reset_streams_ = absl::nullopt; + + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "Any queued TSNs (queued at step E2) MUST now be released and processed + // normally." + for (auto& [tsn, data] : deferred_chunks) { + queued_bytes_ -= data.size(); + Add(tsn, std::move(data)); + } + + RTC_DCHECK(IsConsistent()); + return true; + } else { + RTC_DLOG(LS_VERBOSE) << "Staying in deferred reset processing. tsn=" + << *cum_ack_tsn; + } + } + + return false; +} + +std::vector<DcSctpMessage> ReassemblyQueue::FlushMessages() { + std::vector<DcSctpMessage> ret; + reassembled_messages_.swap(ret); + return ret; +} + +void ReassemblyQueue::AddReassembledMessage( + rtc::ArrayView<const UnwrappedTSN> tsns, + DcSctpMessage message) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Assembled message from TSN=[" + << StrJoin(tsns, ",", + [](rtc::StringBuilder& sb, UnwrappedTSN tsn) { + sb << *tsn.Wrap(); + }) + << "], message; stream_id=" << *message.stream_id() + << ", ppid=" << *message.ppid() + << ", payload=" << message.payload().size() << " bytes"; + + for (const UnwrappedTSN tsn : tsns) { + if (tsn <= last_assembled_tsn_watermark_) { + // This can be provoked by a misbehaving peer by sending FORWARD-TSN with + // invalid SSNs, allowing ordered messages to stay in the queue that + // should've been discarded. + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "Message is built from fragments already seen - skipping"; + return; + } else if (tsn == last_assembled_tsn_watermark_.next_value()) { + // Update watermark, or insert into delivered_tsns_ + last_assembled_tsn_watermark_.Increment(); + } else { + delivered_tsns_.insert(tsn); + } + } + + // With new TSNs in delivered_tsns, gaps might be filled. + MaybeMoveLastAssembledWatermarkFurther(); + + reassembled_messages_.emplace_back(std::move(message)); +} + +void ReassemblyQueue::MaybeMoveLastAssembledWatermarkFurther() { + // `delivered_tsns_` contain TSNS when there is a gap between ranges of + // assembled TSNs. `last_assembled_tsn_watermark_` should not be adjacent to + // that list, because if so, it can be moved. + while (!delivered_tsns_.empty() && + *delivered_tsns_.begin() == + last_assembled_tsn_watermark_.next_value()) { + last_assembled_tsn_watermark_.Increment(); + delivered_tsns_.erase(delivered_tsns_.begin()); + } +} + +void ReassemblyQueue::Handle(const AnyForwardTsnChunk& forward_tsn) { + RTC_DCHECK(IsConsistent()); + UnwrappedTSN tsn = tsn_unwrapper_.Unwrap(forward_tsn.new_cumulative_tsn()); + + last_assembled_tsn_watermark_ = std::max(last_assembled_tsn_watermark_, tsn); + delivered_tsns_.erase(delivered_tsns_.begin(), + delivered_tsns_.upper_bound(tsn)); + + MaybeMoveLastAssembledWatermarkFurther(); + + queued_bytes_ -= + streams_->HandleForwardTsn(tsn, forward_tsn.skipped_streams()); + RTC_DCHECK(IsConsistent()); +} + +bool ReassemblyQueue::IsConsistent() const { + // `delivered_tsns_` and `last_assembled_tsn_watermark_` mustn't overlap or be + // adjacent. + if (!delivered_tsns_.empty() && + last_assembled_tsn_watermark_.next_value() >= *delivered_tsns_.begin()) { + return false; + } + + // Allow queued_bytes_ to be larger than max_size_bytes, as it's not actively + // enforced in this class. This comparison will still trigger if queued_bytes_ + // became "negative". + return (queued_bytes_ >= 0 && queued_bytes_ <= 2 * max_size_bytes_); +} + +HandoverReadinessStatus ReassemblyQueue::GetHandoverReadiness() const { + HandoverReadinessStatus status = streams_->GetHandoverReadiness(); + if (!delivered_tsns_.empty()) { + status.Add(HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap); + } + if (deferred_reset_streams_.has_value()) { + status.Add(HandoverUnreadinessReason::kStreamResetDeferred); + } + return status; +} + +void ReassemblyQueue::AddHandoverState(DcSctpSocketHandoverState& state) { + state.rx.last_assembled_tsn = last_assembled_tsn_watermark_.Wrap().value(); + state.rx.last_completed_deferred_reset_req_sn = + last_completed_reset_req_seq_nbr_.value(); + streams_->AddHandoverState(state); +} + +void ReassemblyQueue::RestoreFromState(const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(last_completed_reset_req_seq_nbr_ == ReconfigRequestSN(0)); + + last_assembled_tsn_watermark_ = + tsn_unwrapper_.Unwrap(TSN(state.rx.last_assembled_tsn)); + last_completed_reset_req_seq_nbr_ = + ReconfigRequestSN(state.rx.last_completed_deferred_reset_req_sn); + streams_->RestoreFromState(state); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.h b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.h new file mode 100644 index 0000000000..91f30a3f69 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.h @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_REASSEMBLY_QUEUE_H_ +#define NET_DCSCTP_RX_REASSEMBLY_QUEUE_H_ + +#include <stddef.h> + +#include <cstdint> +#include <memory> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/rx/reassembly_streams.h" + +namespace dcsctp { + +// Contains the received DATA chunks that haven't yet been reassembled, and +// reassembles chunks when possible. +// +// The actual assembly is handled by an implementation of the +// `ReassemblyStreams` interface. +// +// Except for reassembling fragmented messages, this class will also handle two +// less common operations; To handle the receiver-side of partial reliability +// (limited number of retransmissions or limited message lifetime) as well as +// stream resetting, which is used when a sender wishes to close a data channel. +// +// Partial reliability is handled when a FORWARD-TSN or I-FORWARD-TSN chunk is +// received, and it will simply delete any chunks matching the parameters in +// that chunk. This is mainly implemented in ReassemblyStreams. +// +// Resetting streams is handled when a RECONFIG chunks is received, with an +// "Outgoing SSN Reset Request" parameter. That parameter will contain a list of +// streams to reset, and a `sender_last_assigned_tsn`. If this TSN is not yet +// seen, the stream cannot be directly reset, and this class will respond that +// the reset is "deferred". But if this TSN provided is known, the stream can be +// immediately be reset. +// +// The ReassemblyQueue has a maximum size, as it would otherwise be an DoS +// attack vector where a peer could consume all memory of the other peer by +// sending a lot of ordered chunks, but carefully withholding an early one. It +// also has a watermark limit, which the caller can query is the number of bytes +// is above that limit. This is used by the caller to be selective in what to +// add to the reassembly queue, so that it's not exhausted. The caller is +// expected to call `is_full` prior to adding data to the queue and to act +// accordingly if the queue is full. +class ReassemblyQueue { + public: + // When the queue is filled over this fraction (of its maximum size), the + // socket should restrict incoming data to avoid filling up the queue. + static constexpr float kHighWatermarkLimit = 0.9; + + ReassemblyQueue(absl::string_view log_prefix, + TSN peer_initial_tsn, + size_t max_size_bytes, + bool use_message_interleaving = false); + + // Adds a data chunk to the queue, with a `tsn` and other parameters in + // `data`. + void Add(TSN tsn, Data data); + + // Indicates if the reassembly queue has any reassembled messages that can be + // retrieved by calling `FlushMessages`. + bool HasMessages() const { return !reassembled_messages_.empty(); } + + // Returns any reassembled messages. + std::vector<DcSctpMessage> FlushMessages(); + + // Handle a ForwardTSN chunk, when the sender has indicated that the received + // (this class) should forget about some chunks. This is used to implement + // partial reliability. + void Handle(const AnyForwardTsnChunk& forward_tsn); + + // Given the reset stream request and the current cum_tsn_ack, might either + // reset the streams directly (returns kSuccessPerformed), or at a later time, + // by entering the "deferred reset processing" mode (returns kInProgress). + ReconfigurationResponseParameter::Result ResetStreams( + const OutgoingSSNResetRequestParameter& req, + TSN cum_tsn_ack); + + // Given the current (updated) cum_tsn_ack, might leave "defererred reset + // processing" mode and reset streams. Returns true if so. + bool MaybeResetStreamsDeferred(TSN cum_ack_tsn); + + // The number of payload bytes that have been queued. Note that the actual + // memory usage is higher due to additional overhead of tracking received + // data. + size_t queued_bytes() const { return queued_bytes_; } + + // The remaining bytes until the queue has reached the watermark limit. + size_t remaining_bytes() const { return watermark_bytes_ - queued_bytes_; } + + // Indicates if the queue is full. Data should not be added to the queue when + // it's full. + bool is_full() const { return queued_bytes_ >= max_size_bytes_; } + + // Indicates if the queue is above the watermark limit, which is a certain + // percentage of its size. + bool is_above_watermark() const { return queued_bytes_ >= watermark_bytes_; } + + // Returns the watermark limit, in bytes. + size_t watermark_bytes() const { return watermark_bytes_; } + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); + + private: + bool IsConsistent() const; + void AddReassembledMessage(rtc::ArrayView<const UnwrappedTSN> tsns, + DcSctpMessage message); + void MaybeMoveLastAssembledWatermarkFurther(); + + struct DeferredResetStreams { + explicit DeferredResetStreams(OutgoingSSNResetRequestParameter req) + : req(std::move(req)) {} + OutgoingSSNResetRequestParameter req; + std::vector<std::pair<TSN, Data>> deferred_chunks; + }; + + const std::string log_prefix_; + const size_t max_size_bytes_; + const size_t watermark_bytes_; + UnwrappedTSN::Unwrapper tsn_unwrapper_; + + // Whenever a message has been assembled, either increase + // `last_assembled_tsn_watermark_` or - if there are gaps - add the message's + // TSNs into delivered_tsns_ so that messages are not re-delivered on + // duplicate chunks. + UnwrappedTSN last_assembled_tsn_watermark_; + std::set<UnwrappedTSN> delivered_tsns_; + // Messages that have been reassembled, and will be returned by + // `FlushMessages`. + std::vector<DcSctpMessage> reassembled_messages_; + + // If present, "deferred reset processing" mode is active. + absl::optional<DeferredResetStreams> deferred_reset_streams_; + + // Contains the last request sequence number of the + // OutgoingSSNResetRequestParameter that was performed. + ReconfigRequestSN last_completed_reset_req_seq_nbr_; + + // The number of "payload bytes" that are in this queue, in total. + size_t queued_bytes_ = 0; + + // The actual implementation of ReassemblyStreams. + std::unique_ptr<ReassemblyStreams> streams_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_REASSEMBLY_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue_test.cc b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue_test.cc new file mode 100644 index 0000000000..549bc6fce1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue_test.cc @@ -0,0 +1,509 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/reassembly_queue.h" + +#include <stddef.h> + +#include <algorithm> +#include <array> +#include <cstdint> +#include <iterator> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/testing/data_generator.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +// The default maximum size of the Reassembly Queue. +static constexpr size_t kBufferSize = 10000; + +static constexpr StreamID kStreamID(1); +static constexpr SSN kSSN(0); +static constexpr MID kMID(0); +static constexpr FSN kFSN(0); +static constexpr PPID kPPID(53); + +static constexpr std::array<uint8_t, 4> kShortPayload = {1, 2, 3, 4}; +static constexpr std::array<uint8_t, 4> kMessage2Payload = {5, 6, 7, 8}; +static constexpr std::array<uint8_t, 6> kSixBytePayload = {1, 2, 3, 4, 5, 6}; +static constexpr std::array<uint8_t, 8> kMediumPayload1 = {1, 2, 3, 4, + 5, 6, 7, 8}; +static constexpr std::array<uint8_t, 8> kMediumPayload2 = {9, 10, 11, 12, + 13, 14, 15, 16}; +static constexpr std::array<uint8_t, 16> kLongPayload = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + +MATCHER_P3(SctpMessageIs, stream_id, ppid, expected_payload, "") { + if (arg.stream_id() != stream_id) { + *result_listener << "the stream_id is " << *arg.stream_id(); + return false; + } + + if (arg.ppid() != ppid) { + *result_listener << "the ppid is " << *arg.ppid(); + return false; + } + + if (std::vector<uint8_t>(arg.payload().begin(), arg.payload().end()) != + std::vector<uint8_t>(expected_payload.begin(), expected_payload.end())) { + *result_listener << "the payload is wrong"; + return false; + } + return true; +} + +class ReassemblyQueueTest : public testing::Test { + protected: + ReassemblyQueueTest() {} + DataGenerator gen_; +}; + +TEST_F(ReassemblyQueueTest, EmptyQueue) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + EXPECT_FALSE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, SingleUnorderedChunkMessage) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, LargeUnorderedChunkAllPermutations) { + std::vector<uint32_t> tsns = {10, 11, 12, 13}; + rtc::ArrayView<const uint8_t> payload(kLongPayload); + do { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + + for (size_t i = 0; i < tsns.size(); i++) { + auto span = payload.subview((tsns[i] - 10) * 4, 4); + Data::IsBeginning is_beginning(tsns[i] == 10); + Data::IsEnd is_end(tsns[i] == 13); + + reasm.Add(TSN(tsns[i]), + Data(kStreamID, kSSN, kMID, kFSN, kPPID, + std::vector<uint8_t>(span.begin(), span.end()), + is_beginning, is_end, IsUnordered(false))); + if (i < 3) { + EXPECT_FALSE(reasm.HasMessages()); + } else { + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kLongPayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + } + } + } while (std::next_permutation(std::begin(tsns), std::end(tsns))); +} + +TEST_F(ReassemblyQueueTest, SingleOrderedChunkMessage) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); +} + +TEST_F(ReassemblyQueueTest, ManySmallOrderedMessages) { + std::vector<uint32_t> tsns = {10, 11, 12, 13}; + rtc::ArrayView<const uint8_t> payload(kLongPayload); + do { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + for (size_t i = 0; i < tsns.size(); i++) { + auto span = payload.subview((tsns[i] - 10) * 4, 4); + Data::IsBeginning is_beginning(true); + Data::IsEnd is_end(true); + + SSN ssn(static_cast<uint16_t>(tsns[i] - 10)); + reasm.Add(TSN(tsns[i]), + Data(kStreamID, ssn, kMID, kFSN, kPPID, + std::vector<uint8_t>(span.begin(), span.end()), + is_beginning, is_end, IsUnordered(false))); + } + EXPECT_THAT( + reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, payload.subview(0, 4)), + SctpMessageIs(kStreamID, kPPID, payload.subview(4, 4)), + SctpMessageIs(kStreamID, kPPID, payload.subview(8, 4)), + SctpMessageIs(kStreamID, kPPID, payload.subview(12, 4)))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + } while (std::next_permutation(std::begin(tsns), std::end(tsns))); +} + +TEST_F(ReassemblyQueueTest, RetransmissionInLargeOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + reasm.Add(TSN(12), gen_.Ordered({3})); + reasm.Add(TSN(13), gen_.Ordered({4})); + reasm.Add(TSN(14), gen_.Ordered({5})); + reasm.Add(TSN(15), gen_.Ordered({6})); + reasm.Add(TSN(16), gen_.Ordered({7})); + reasm.Add(TSN(17), gen_.Ordered({8})); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + // lost and retransmitted + reasm.Add(TSN(11), gen_.Ordered({2})); + reasm.Add(TSN(18), gen_.Ordered({9})); + reasm.Add(TSN(19), gen_.Ordered({10})); + EXPECT_EQ(reasm.queued_bytes(), 10u); + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Add(TSN(20), gen_.Ordered({11, 12, 13, 14, 15, 16}, "E")); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kLongPayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, ForwardTSNRemoveUnordered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1}, "B")); + reasm.Add(TSN(12), gen_.Unordered({3})); + reasm.Add(TSN(13), gen_.Unordered({4}, "E")); + + reasm.Add(TSN(14), gen_.Unordered({5}, "B")); + reasm.Add(TSN(15), gen_.Unordered({6})); + reasm.Add(TSN(17), gen_.Unordered({8}, "E")); + EXPECT_EQ(reasm.queued_bytes(), 6u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Handle(ForwardTsnChunk(TSN(13), {})); + EXPECT_EQ(reasm.queued_bytes(), 3u); + + // The lost chunk comes, but too late. + reasm.Add(TSN(11), gen_.Unordered({2})); + EXPECT_FALSE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 3u); + + // The second lost chunk comes, message is assembled. + reasm.Add(TSN(16), gen_.Unordered({7})); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, ForwardTSNRemoveOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + reasm.Add(TSN(12), gen_.Ordered({3})); + reasm.Add(TSN(13), gen_.Ordered({4}, "E")); + + reasm.Add(TSN(14), gen_.Ordered({5}, "B")); + reasm.Add(TSN(15), gen_.Ordered({6})); + reasm.Add(TSN(16), gen_.Ordered({7})); + reasm.Add(TSN(17), gen_.Ordered({8}, "E")); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Handle(ForwardTsnChunk( + TSN(13), {ForwardTsnChunk::SkippedStream(kStreamID, kSSN)})); + EXPECT_EQ(reasm.queued_bytes(), 0u); + + // The lost chunk comes, but too late. + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kMessage2Payload))); +} + +TEST_F(ReassemblyQueueTest, ForwardTSNRemoveALotOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + reasm.Add(TSN(12), gen_.Ordered({3})); + reasm.Add(TSN(13), gen_.Ordered({4}, "E")); + + reasm.Add(TSN(15), gen_.Ordered({5}, "B")); + reasm.Add(TSN(16), gen_.Ordered({6})); + reasm.Add(TSN(17), gen_.Ordered({7})); + reasm.Add(TSN(18), gen_.Ordered({8}, "E")); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Handle(ForwardTsnChunk( + TSN(13), {ForwardTsnChunk::SkippedStream(kStreamID, kSSN)})); + EXPECT_EQ(reasm.queued_bytes(), 0u); + + // The lost chunk comes, but too late. + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kMessage2Payload))); +} + +TEST_F(ReassemblyQueueTest, ShouldntDeliverMessagesBeforeInitialTsn) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(5), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, ShouldntRedeliverUnorderedMessages) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, ShouldntRedeliverUnorderedMessagesReallyUnordered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "B")); + EXPECT_EQ(reasm.queued_bytes(), 4u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 4u); + EXPECT_TRUE(reasm.HasMessages()); + + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 4u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, ShouldntDeliverBeforeForwardedTsn) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Handle(ForwardTsnChunk(TSN(12), {})); + + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, NotReadyForHandoverWhenDeliveredTsnsHaveGap) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "B")); + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_EQ( + reasm.GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap) + .Add( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + EXPECT_EQ( + reasm.GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap) + .Add( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + + reasm.Handle(ForwardTsnChunk(TSN(13), {})); + EXPECT_EQ(reasm.GetHandoverReadiness(), HandoverReadinessStatus()); +} + +TEST_F(ReassemblyQueueTest, NotReadyForHandoverWhenResetStreamIsDeferred) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + DataGeneratorOptions opts; + opts.message_id = MID(0); + reasm.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + opts.message_id = MID(1); + reasm.Add(TSN(11), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + EXPECT_THAT(reasm.FlushMessages(), SizeIs(2)); + + reasm.ResetStreams( + OutgoingSSNResetRequestParameter( + ReconfigRequestSN(10), ReconfigRequestSN(3), TSN(13), {StreamID(1)}), + TSN(11)); + EXPECT_EQ(reasm.GetHandoverReadiness(), + HandoverReadinessStatus().Add( + HandoverUnreadinessReason::kStreamResetDeferred)); + + opts.message_id = MID(3); + opts.ppid = PPID(3); + reasm.Add(TSN(13), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm.MaybeResetStreamsDeferred(TSN(11)); + + opts.message_id = MID(2); + opts.ppid = PPID(2); + reasm.Add(TSN(13), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm.MaybeResetStreamsDeferred(TSN(15)); + EXPECT_EQ(reasm.GetHandoverReadiness(), + HandoverReadinessStatus().Add( + HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap)); + + EXPECT_THAT(reasm.FlushMessages(), SizeIs(2)); + EXPECT_EQ(reasm.GetHandoverReadiness(), + HandoverReadinessStatus().Add( + HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap)); + + reasm.Handle(ForwardTsnChunk(TSN(15), {})); + EXPECT_EQ(reasm.GetHandoverReadiness(), HandoverReadinessStatus()); +} + +TEST_F(ReassemblyQueueTest, HandoverInInitialState) { + ReassemblyQueue reasm1("log: ", TSN(10), kBufferSize); + + EXPECT_EQ(reasm1.GetHandoverReadiness(), HandoverReadinessStatus()); + DcSctpSocketHandoverState state; + reasm1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + ReassemblyQueue reasm2("log: ", TSN(100), kBufferSize, + /*use_message_interleaving=*/false); + reasm2.RestoreFromState(state); + + reasm2.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm2.FlushMessages(), SizeIs(1)); +} + +TEST_F(ReassemblyQueueTest, HandoverAfterHavingAssembedOneMessage) { + ReassemblyQueue reasm1("log: ", TSN(10), kBufferSize); + reasm1.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm1.FlushMessages(), SizeIs(1)); + + EXPECT_EQ(reasm1.GetHandoverReadiness(), HandoverReadinessStatus()); + DcSctpSocketHandoverState state; + reasm1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + ReassemblyQueue reasm2("log: ", TSN(100), kBufferSize, + /*use_message_interleaving=*/false); + reasm2.RestoreFromState(state); + + reasm2.Add(TSN(11), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm2.FlushMessages(), SizeIs(1)); +} + +TEST_F(ReassemblyQueueTest, HandleInconsistentForwardTSN) { + // Found when fuzzing. + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + // Add TSN=43, SSN=7. Can't be reassembled as previous SSNs aren't known. + reasm.Add(TSN(43), Data(kStreamID, SSN(7), MID(0), FSN(0), kPPID, + std::vector<uint8_t>(10), Data::IsBeginning(true), + Data::IsEnd(true), IsUnordered(false))); + + // Invalid, as TSN=44 have to have SSN>=7, but peer says 6. + reasm.Handle(ForwardTsnChunk( + TSN(44), {ForwardTsnChunk::SkippedStream(kStreamID, SSN(6))})); + + // Don't assemble SSN=7, as that TSN is skipped. + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, SingleUnorderedChunkMessageInRfc8260) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize, + /*use_message_interleaving=*/true); + reasm.Add(TSN(10), Data(StreamID(1), SSN(0), MID(0), FSN(0), kPPID, + {1, 2, 3, 4}, Data::IsBeginning(true), + Data::IsEnd(true), IsUnordered(true))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); +} + +TEST_F(ReassemblyQueueTest, TwoInterleavedChunks) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize, + /*use_message_interleaving=*/true); + reasm.Add(TSN(10), Data(StreamID(1), SSN(0), MID(0), FSN(0), kPPID, + {1, 2, 3, 4}, Data::IsBeginning(true), + Data::IsEnd(false), IsUnordered(true))); + reasm.Add(TSN(11), Data(StreamID(2), SSN(0), MID(0), FSN(0), kPPID, + {9, 10, 11, 12}, Data::IsBeginning(true), + Data::IsEnd(false), IsUnordered(true))); + EXPECT_EQ(reasm.queued_bytes(), 8u); + reasm.Add(TSN(12), Data(StreamID(1), SSN(0), MID(0), FSN(1), kPPID, + {5, 6, 7, 8}, Data::IsBeginning(false), + Data::IsEnd(true), IsUnordered(true))); + EXPECT_EQ(reasm.queued_bytes(), 4u); + reasm.Add(TSN(13), Data(StreamID(2), SSN(0), MID(0), FSN(1), kPPID, + {13, 14, 15, 16}, Data::IsBeginning(false), + Data::IsEnd(true), IsUnordered(true))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(StreamID(1), kPPID, kMediumPayload1), + SctpMessageIs(StreamID(2), kPPID, kMediumPayload2))); +} + +TEST_F(ReassemblyQueueTest, UnorderedInterleavedMessagesAllPermutations) { + std::vector<int> indexes = {0, 1, 2, 3, 4, 5}; + TSN tsns[] = {TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), TSN(15)}; + StreamID stream_ids[] = {StreamID(1), StreamID(2), StreamID(1), + StreamID(1), StreamID(2), StreamID(2)}; + FSN fsns[] = {FSN(0), FSN(0), FSN(1), FSN(2), FSN(1), FSN(2)}; + rtc::ArrayView<const uint8_t> payload(kSixBytePayload); + do { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize, + /*use_message_interleaving=*/true); + for (int i : indexes) { + auto span = payload.subview(*fsns[i] * 2, 2); + Data::IsBeginning is_beginning(fsns[i] == FSN(0)); + Data::IsEnd is_end(fsns[i] == FSN(2)); + reasm.Add(tsns[i], Data(stream_ids[i], SSN(0), MID(0), fsns[i], kPPID, + std::vector<uint8_t>(span.begin(), span.end()), + is_beginning, is_end, IsUnordered(true))); + } + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), kPPID, kSixBytePayload), + SctpMessageIs(StreamID(2), kPPID, kSixBytePayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + } while (std::next_permutation(std::begin(indexes), std::end(indexes))); +} + +TEST_F(ReassemblyQueueTest, IForwardTSNRemoveALotOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize, + /*use_message_interleaving=*/true); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + gen_.Ordered({2}, ""); + reasm.Add(TSN(12), gen_.Ordered({3}, "")); + reasm.Add(TSN(13), gen_.Ordered({4}, "E")); + reasm.Add(TSN(15), gen_.Ordered({5}, "B")); + reasm.Add(TSN(16), gen_.Ordered({6}, "")); + reasm.Add(TSN(17), gen_.Ordered({7}, "")); + reasm.Add(TSN(18), gen_.Ordered({8}, "E")); + + ASSERT_FALSE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + reasm.Handle( + IForwardTsnChunk(TSN(13), {IForwardTsnChunk::SkippedStream( + IsUnordered(false), kStreamID, MID(0))})); + EXPECT_EQ(reasm.queued_bytes(), 0u); + + // The lost chunk comes, but too late. + ASSERT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kMessage2Payload))); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.cc b/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.cc new file mode 100644 index 0000000000..9fd52fb15d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.cc @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/reassembly_streams.h" + +#include <cstddef> +#include <map> +#include <utility> + +namespace dcsctp { + +ReassembledMessage AssembleMessage(std::map<UnwrappedTSN, Data>::iterator start, + std::map<UnwrappedTSN, Data>::iterator end) { + size_t count = std::distance(start, end); + + if (count == 1) { + // Fast path - zero-copy + Data& data = start->second; + + return ReassembledMessage{ + .tsns = {start->first}, + .message = DcSctpMessage(data.stream_id, data.ppid, + std::move(start->second.payload)), + }; + } + + // Slow path - will need to concatenate the payload. + std::vector<UnwrappedTSN> tsns; + std::vector<uint8_t> payload; + + size_t payload_size = std::accumulate( + start, end, 0, + [](size_t v, const auto& p) { return v + p.second.size(); }); + + tsns.reserve(count); + payload.reserve(payload_size); + for (auto it = start; it != end; ++it) { + Data& data = it->second; + tsns.push_back(it->first); + payload.insert(payload.end(), data.payload.begin(), data.payload.end()); + } + + return ReassembledMessage{ + .tsns = std::move(tsns), + .message = DcSctpMessage(start->second.stream_id, start->second.ppid, + std::move(payload)), + }; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.h b/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.h new file mode 100644 index 0000000000..0ecfac0c0a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_REASSEMBLY_STREAMS_H_ +#define NET_DCSCTP_RX_REASSEMBLY_STREAMS_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <functional> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" +#include "net/dcsctp/public/dcsctp_message.h" + +namespace dcsctp { + +// Implementations of this interface will be called when data is received, when +// data should be skipped/forgotten or when sequence number should be reset. +// +// As a result of these operations - mainly when data is received - the +// implementations of this interface should notify when a message has been +// assembled, by calling the provided callback of type `OnAssembledMessage`. How +// it assembles messages will depend on e.g. if a message was sent on an ordered +// or unordered stream. +// +// Implementations will - for each operation - indicate how much additional +// memory that has been used as a result of performing the operation. This is +// used to limit the maximum amount of memory used, to prevent out-of-memory +// situations. +class ReassemblyStreams { + public: + // This callback will be provided as an argument to the constructor of the + // concrete class implementing this interface and should be called when a + // message has been assembled as well as indicating from which TSNs this + // message was assembled from. + using OnAssembledMessage = + std::function<void(rtc::ArrayView<const UnwrappedTSN> tsns, + DcSctpMessage message)>; + + virtual ~ReassemblyStreams() = default; + + // Adds a data chunk to a stream as identified in `data`. + // If it was the last remaining chunk in a message, reassemble one (or + // several, in case of ordered chunks) messages. + // + // Returns the additional number of bytes added to the queue as a result of + // performing this operation. If this addition resulted in messages being + // assembled and delivered, this may be negative. + virtual int Add(UnwrappedTSN tsn, Data data) = 0; + + // Called for incoming FORWARD-TSN/I-FORWARD-TSN chunks - when the sender + // wishes the received to skip/forget about data up until the provided TSN. + // This is used to implement partial reliability, such as limiting the number + // of retransmissions or the an expiration duration. As a result of skipping + // data, this may result in the implementation being able to assemble messages + // in ordered streams. + // + // Returns the number of bytes removed from the queue as a result of + // this operation. + virtual size_t HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> + skipped_streams) = 0; + + // Called for incoming (possibly deferred) RE_CONFIG chunks asking for + // either a few streams, or all streams (when the list is empty) to be + // reset - to have their next SSN or Message ID to be zero. + virtual void ResetStreams(rtc::ArrayView<const StreamID> stream_ids) = 0; + + virtual HandoverReadinessStatus GetHandoverReadiness() const = 0; + virtual void AddHandoverState(DcSctpSocketHandoverState& state) = 0; + virtual void RestoreFromState(const DcSctpSocketHandoverState& state) = 0; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_REASSEMBLY_STREAMS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc new file mode 100644 index 0000000000..dce6c90131 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc @@ -0,0 +1,348 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/traditional_reassembly_streams.h" + +#include <stddef.h> + +#include <cstdint> +#include <functional> +#include <iterator> +#include <map> +#include <numeric> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace { + +// Given a map (`chunks`) and an iterator to within that map (`iter`), this +// function will return an iterator to the first chunk in that message, which +// has the `is_beginning` flag set. If there are any gaps, or if the beginning +// can't be found, `absl::nullopt` is returned. +absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindBeginning( + const std::map<UnwrappedTSN, Data>& chunks, + std::map<UnwrappedTSN, Data>::iterator iter) { + UnwrappedTSN prev_tsn = iter->first; + for (;;) { + if (iter->second.is_beginning) { + return iter; + } + if (iter == chunks.begin()) { + return absl::nullopt; + } + --iter; + if (iter->first.next_value() != prev_tsn) { + return absl::nullopt; + } + prev_tsn = iter->first; + } +} + +// Given a map (`chunks`) and an iterator to within that map (`iter`), this +// function will return an iterator to the chunk after the last chunk in that +// message, which has the `is_end` flag set. If there are any gaps, or if the +// end can't be found, `absl::nullopt` is returned. +absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindEnd( + std::map<UnwrappedTSN, Data>& chunks, + std::map<UnwrappedTSN, Data>::iterator iter) { + UnwrappedTSN prev_tsn = iter->first; + for (;;) { + if (iter->second.is_end) { + return ++iter; + } + ++iter; + if (iter == chunks.end()) { + return absl::nullopt; + } + if (iter->first != prev_tsn.next_value()) { + return absl::nullopt; + } + prev_tsn = iter->first; + } +} +} // namespace + +TraditionalReassemblyStreams::TraditionalReassemblyStreams( + absl::string_view log_prefix, + OnAssembledMessage on_assembled_message) + : log_prefix_(log_prefix), + on_assembled_message_(std::move(on_assembled_message)) {} + +int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn, + Data data) { + int queued_bytes = data.size(); + auto [it, inserted] = chunks_.emplace(tsn, std::move(data)); + if (!inserted) { + return 0; + } + + queued_bytes -= TryToAssembleMessage(it); + + return queued_bytes; +} + +size_t TraditionalReassemblyStreams::UnorderedStream::TryToAssembleMessage( + ChunkMap::iterator iter) { + // TODO(boivie): This method is O(N) with the number of fragments in a + // message, which can be inefficient for very large values of N. This could be + // optimized by e.g. only trying to assemble a message once _any_ beginning + // and _any_ end has been found. + absl::optional<ChunkMap::iterator> start = FindBeginning(chunks_, iter); + if (!start.has_value()) { + return 0; + } + absl::optional<ChunkMap::iterator> end = FindEnd(chunks_, iter); + if (!end.has_value()) { + return 0; + } + + size_t bytes_assembled = AssembleMessage(*start, *end); + chunks_.erase(*start, *end); + return bytes_assembled; +} + +size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage( + const ChunkMap::iterator start, + const ChunkMap::iterator end) { + size_t count = std::distance(start, end); + + if (count == 1) { + // Fast path - zero-copy + const Data& data = start->second; + size_t payload_size = start->second.size(); + UnwrappedTSN tsns[1] = {start->first}; + DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + return payload_size; + } + + // Slow path - will need to concatenate the payload. + std::vector<UnwrappedTSN> tsns; + std::vector<uint8_t> payload; + + size_t payload_size = std::accumulate( + start, end, 0, + [](size_t v, const auto& p) { return v + p.second.size(); }); + + tsns.reserve(count); + payload.reserve(payload_size); + for (auto it = start; it != end; ++it) { + const Data& data = it->second; + tsns.push_back(it->first); + payload.insert(payload.end(), data.payload.begin(), data.payload.end()); + } + + DcSctpMessage message(start->second.stream_id, start->second.ppid, + std::move(payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + + return payload_size; +} + +size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo( + UnwrappedTSN tsn) { + auto end_iter = chunks_.upper_bound(tsn); + size_t removed_bytes = std::accumulate( + chunks_.begin(), end_iter, 0, + [](size_t r, const auto& p) { return r + p.second.size(); }); + + chunks_.erase(chunks_.begin(), end_iter); + return removed_bytes; +} + +size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessage() { + if (chunks_by_ssn_.empty() || chunks_by_ssn_.begin()->first != next_ssn_) { + return 0; + } + + ChunkMap& chunks = chunks_by_ssn_.begin()->second; + + if (!chunks.begin()->second.is_beginning || !chunks.rbegin()->second.is_end) { + return 0; + } + + uint32_t tsn_diff = + UnwrappedTSN::Difference(chunks.rbegin()->first, chunks.begin()->first); + if (tsn_diff != chunks.size() - 1) { + return 0; + } + + size_t assembled_bytes = AssembleMessage(chunks.begin(), chunks.end()); + chunks_by_ssn_.erase(chunks_by_ssn_.begin()); + next_ssn_.Increment(); + return assembled_bytes; +} + +size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() { + size_t assembled_bytes = 0; + + for (;;) { + size_t assembled_bytes_this_iter = TryToAssembleMessage(); + if (assembled_bytes_this_iter == 0) { + break; + } + assembled_bytes += assembled_bytes_this_iter; + } + return assembled_bytes; +} + +int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn, + Data data) { + int queued_bytes = data.size(); + + UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn); + auto [unused, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data)); + if (!inserted) { + return 0; + } + + if (ssn == next_ssn_) { + queued_bytes -= TryToAssembleMessages(); + } + + return queued_bytes; +} + +size_t TraditionalReassemblyStreams::OrderedStream::EraseTo(SSN ssn) { + UnwrappedSSN unwrapped_ssn = ssn_unwrapper_.Unwrap(ssn); + + auto end_iter = chunks_by_ssn_.upper_bound(unwrapped_ssn); + size_t removed_bytes = std::accumulate( + chunks_by_ssn_.begin(), end_iter, 0, [](size_t r1, const auto& p) { + return r1 + + absl::c_accumulate(p.second, 0, [](size_t r2, const auto& q) { + return r2 + q.second.size(); + }); + }); + chunks_by_ssn_.erase(chunks_by_ssn_.begin(), end_iter); + + if (unwrapped_ssn >= next_ssn_) { + unwrapped_ssn.Increment(); + next_ssn_ = unwrapped_ssn; + } + + removed_bytes += TryToAssembleMessages(); + return removed_bytes; +} + +int TraditionalReassemblyStreams::Add(UnwrappedTSN tsn, Data data) { + if (data.is_unordered) { + auto it = unordered_streams_.try_emplace(data.stream_id, this).first; + return it->second.Add(tsn, std::move(data)); + } + + auto it = ordered_streams_.try_emplace(data.stream_id, this).first; + return it->second.Add(tsn, std::move(data)); +} + +size_t TraditionalReassemblyStreams::HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) { + size_t bytes_removed = 0; + // The `skipped_streams` only cover ordered messages - need to + // iterate all unordered streams manually to remove those chunks. + for (auto& [unused, stream] : unordered_streams_) { + bytes_removed += stream.EraseTo(new_cumulative_ack_tsn); + } + + for (const auto& skipped_stream : skipped_streams) { + auto it = + ordered_streams_.try_emplace(skipped_stream.stream_id, this).first; + bytes_removed += it->second.EraseTo(skipped_stream.ssn); + } + + return bytes_removed; +} + +void TraditionalReassemblyStreams::ResetStreams( + rtc::ArrayView<const StreamID> stream_ids) { + if (stream_ids.empty()) { + for (auto& [stream_id, stream] : ordered_streams_) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Resetting implicit stream_id=" << *stream_id; + stream.Reset(); + } + } else { + for (StreamID stream_id : stream_ids) { + auto it = ordered_streams_.find(stream_id); + if (it != ordered_streams_.end()) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Resetting explicit stream_id=" << *stream_id; + it->second.Reset(); + } + } + } +} + +HandoverReadinessStatus TraditionalReassemblyStreams::GetHandoverReadiness() + const { + HandoverReadinessStatus status; + for (const auto& [unused, stream] : ordered_streams_) { + if (stream.has_unassembled_chunks()) { + status.Add(HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks); + break; + } + } + for (const auto& [unused, stream] : unordered_streams_) { + if (stream.has_unassembled_chunks()) { + status.Add( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks); + break; + } + } + return status; +} + +void TraditionalReassemblyStreams::AddHandoverState( + DcSctpSocketHandoverState& state) { + for (const auto& [stream_id, stream] : ordered_streams_) { + DcSctpSocketHandoverState::OrderedStream state_stream; + state_stream.id = stream_id.value(); + state_stream.next_ssn = stream.next_ssn().value(); + state.rx.ordered_streams.push_back(std::move(state_stream)); + } + for (const auto& [stream_id, unused] : unordered_streams_) { + DcSctpSocketHandoverState::UnorderedStream state_stream; + state_stream.id = stream_id.value(); + state.rx.unordered_streams.push_back(std::move(state_stream)); + } +} + +void TraditionalReassemblyStreams::RestoreFromState( + const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(ordered_streams_.empty()); + RTC_DCHECK(unordered_streams_.empty()); + + for (const DcSctpSocketHandoverState::OrderedStream& state_stream : + state.rx.ordered_streams) { + ordered_streams_.emplace( + std::piecewise_construct, + std::forward_as_tuple(StreamID(state_stream.id)), + std::forward_as_tuple(this, SSN(state_stream.next_ssn))); + } + for (const DcSctpSocketHandoverState::UnorderedStream& state_stream : + state.rx.unordered_streams) { + unordered_streams_.emplace(std::piecewise_construct, + std::forward_as_tuple(StreamID(state_stream.id)), + std::forward_as_tuple(this)); + } +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h new file mode 100644 index 0000000000..4825afd1ba --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_TRADITIONAL_REASSEMBLY_STREAMS_H_ +#define NET_DCSCTP_RX_TRADITIONAL_REASSEMBLY_STREAMS_H_ +#include <stddef.h> +#include <stdint.h> + +#include <map> +#include <string> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" + +namespace dcsctp { + +// Handles reassembly of incoming data when interleaved message sending +// is not enabled on the association, i.e. when RFC8260 is not in use and +// RFC4960 is to be followed. +class TraditionalReassemblyStreams : public ReassemblyStreams { + public: + TraditionalReassemblyStreams(absl::string_view log_prefix, + OnAssembledMessage on_assembled_message); + + int Add(UnwrappedTSN tsn, Data data) override; + + size_t HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) + override; + + void ResetStreams(rtc::ArrayView<const StreamID> stream_ids) override; + + HandoverReadinessStatus GetHandoverReadiness() const override; + void AddHandoverState(DcSctpSocketHandoverState& state) override; + void RestoreFromState(const DcSctpSocketHandoverState& state) override; + + private: + using ChunkMap = std::map<UnwrappedTSN, Data>; + + // Base class for `UnorderedStream` and `OrderedStream`. + class StreamBase { + protected: + explicit StreamBase(TraditionalReassemblyStreams* parent) + : parent_(*parent) {} + + size_t AssembleMessage(ChunkMap::iterator start, ChunkMap::iterator end); + TraditionalReassemblyStreams& parent_; + }; + + // Manages all received data for a specific unordered stream, and assembles + // messages when possible. + class UnorderedStream : StreamBase { + public: + explicit UnorderedStream(TraditionalReassemblyStreams* parent) + : StreamBase(parent) {} + int Add(UnwrappedTSN tsn, Data data); + // Returns the number of bytes removed from the queue. + size_t EraseTo(UnwrappedTSN tsn); + bool has_unassembled_chunks() const { return !chunks_.empty(); } + + private: + // Given an iterator to any chunk within the map, try to assemble a message + // into `reassembled_messages` containing it and - if successful - erase + // those chunks from the stream chunks map. + // + // Returns the number of bytes that were assembled. + size_t TryToAssembleMessage(ChunkMap::iterator iter); + + ChunkMap chunks_; + }; + + // Manages all received data for a specific ordered stream, and assembles + // messages when possible. + class OrderedStream : StreamBase { + public: + explicit OrderedStream(TraditionalReassemblyStreams* parent, + SSN next_ssn = SSN(0)) + : StreamBase(parent), next_ssn_(ssn_unwrapper_.Unwrap(next_ssn)) {} + int Add(UnwrappedTSN tsn, Data data); + size_t EraseTo(SSN ssn); + void Reset() { + ssn_unwrapper_.Reset(); + next_ssn_ = ssn_unwrapper_.Unwrap(SSN(0)); + } + SSN next_ssn() const { return next_ssn_.Wrap(); } + bool has_unassembled_chunks() const { return !chunks_by_ssn_.empty(); } + + private: + // Try to assemble one or several messages in order from the stream. + // Returns the number of bytes assembled if a message was assembled. + size_t TryToAssembleMessage(); + size_t TryToAssembleMessages(); + // This must be an ordered container to be able to iterate in SSN order. + std::map<UnwrappedSSN, ChunkMap> chunks_by_ssn_; + UnwrappedSSN::Unwrapper ssn_unwrapper_; + UnwrappedSSN next_ssn_; + }; + + const std::string log_prefix_; + + // Callback for when a message has been assembled. + const OnAssembledMessage on_assembled_message_; + + // All unordered and ordered streams, managing not-yet-assembled data. + std::map<StreamID, UnorderedStream> unordered_streams_; + std::map<StreamID, OrderedStream> ordered_streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_TRADITIONAL_REASSEMBLY_STREAMS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams_test.cc b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams_test.cc new file mode 100644 index 0000000000..341870442d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams_test.cc @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/traditional_reassembly_streams.h" + +#include <cstdint> +#include <memory> +#include <utility> + +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" +#include "net/dcsctp/testing/data_generator.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::MockFunction; +using ::testing::NiceMock; +using ::testing::Property; + +class TraditionalReassemblyStreamsTest : public testing::Test { + protected: + UnwrappedTSN tsn(uint32_t value) { return tsn_.Unwrap(TSN(value)); } + + TraditionalReassemblyStreamsTest() {} + DataGenerator gen_; + UnwrappedTSN::Unwrapper tsn_; +}; + +TEST_F(TraditionalReassemblyStreamsTest, + AddUnorderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + // Adding the end fragment should make it empty again. + EXPECT_EQ(streams.Add(tsn(4), gen_.Unordered({7}, "E")), -6); +} + +TEST_F(TraditionalReassemblyStreamsTest, + AddSimpleOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), -6); +} + +TEST_F(TraditionalReassemblyStreamsTest, + AddMoreComplexOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + Data late = gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + EXPECT_EQ(streams.Add(tsn(2), std::move(late)), -8); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteUnorderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), {}), 6u); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteSimpleOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), skipped), 6u); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteManyOrderedMessagesReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // Expire all three messages + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(2))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(8), skipped), 8u); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteOrderedMessageDelivesTwoReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // The first ordered message expire, and the following two are delivered. + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(4), skipped), 8u); +} + +TEST_F(TraditionalReassemblyStreamsTest, NoStreamsCanBeHandedOver) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams1("", on_assembled.AsStdFunction()); + EXPECT_TRUE(streams1.GetHandoverReadiness().IsReady()); + + DcSctpSocketHandoverState state; + streams1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction()); + streams2.RestoreFromState(state); + + EXPECT_EQ(streams2.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams2.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams2.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams2.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); +} + +TEST_F(TraditionalReassemblyStreamsTest, + OrderedStreamsCanBeHandedOverWhenNoUnassembledChunksExist) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams1("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams1.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks)); + EXPECT_EQ(streams1.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks)); + EXPECT_EQ(streams1.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks)); + + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams1.HandleForwardTsn(tsn(3), skipped), 6u); + EXPECT_TRUE(streams1.GetHandoverReadiness().IsReady()); + + DcSctpSocketHandoverState state; + streams1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction()); + streams2.RestoreFromState(state); + EXPECT_EQ(streams2.Add(tsn(4), gen_.Ordered({7})), 1); +} + +TEST_F(TraditionalReassemblyStreamsTest, + UnorderedStreamsCanBeHandedOverWhenNoUnassembledChunksExist) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams1("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams1.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ( + streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + EXPECT_EQ(streams1.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ( + streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + EXPECT_EQ(streams1.Add(tsn(3), gen_.Unordered({5, 6})), 2); + EXPECT_EQ( + streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + + EXPECT_EQ(streams1.HandleForwardTsn(tsn(3), {}), 6u); + EXPECT_TRUE(streams1.GetHandoverReadiness().IsReady()); + + DcSctpSocketHandoverState state; + streams1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction()); + streams2.RestoreFromState(state); + EXPECT_EQ(streams2.Add(tsn(4), gen_.Unordered({7})), 1); +} + +TEST_F(TraditionalReassemblyStreamsTest, CanDeleteFirstOrderedMessage) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + EXPECT_CALL(on_assembled, + Call(ElementsAre(tsn(2)), + Property(&DcSctpMessage::payload, ElementsAre(2, 3, 4)))); + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + // Not received, SID=1. TSN=1, SSN=0 + gen_.Ordered({1}, "BE"); + // And deleted (SID=1, TSN=1, SSN=0) + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(1), skipped), 0u); + + // Receive SID=1, TSN=2, SSN=1 + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4}, "BE")), 0); +} + +} // namespace +} // namespace dcsctp |