diff options
Diffstat (limited to 'third_party/libwebrtc/net/dcsctp/tx')
21 files changed, 7683 insertions, 0 deletions
diff --git a/third_party/libwebrtc/net/dcsctp/tx/BUILD.gn b/third_party/libwebrtc/net/dcsctp/tx/BUILD.gn new file mode 100644 index 0000000000..43fd41639e --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/BUILD.gn @@ -0,0 +1,209 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("send_queue") { + deps = [ + "../../../api:array_view", + "../common:internal_types", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../public:types", + ] + sources = [ "send_queue.h" ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] +} + +rtc_library("rr_send_queue") { + deps = [ + ":send_queue", + ":stream_scheduler", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base/containers:flat_map", + "../common:str_join", + "../packet:data", + "../public:socket", + "../public:types", + ] + sources = [ + "rr_send_queue.cc", + "rr_send_queue.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("stream_scheduler") { + deps = [ + ":send_queue", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:strong_alias", + "../../../rtc_base/containers:flat_set", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + ] + sources = [ + "stream_scheduler.cc", + "stream_scheduler.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("retransmission_error_counter") { + deps = [ + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../public:types", + ] + sources = [ + "retransmission_error_counter.cc", + "retransmission_error_counter.h", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_library("retransmission_timeout") { + deps = [ + "../../../rtc_base:checks", + "../public:types", + ] + sources = [ + "retransmission_timeout.cc", + "retransmission_timeout.h", + ] +} + +rtc_library("outstanding_data") { + deps = [ + ":retransmission_timeout", + ":send_queue", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../common:math", + "../common:sequence_numbers", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../public:types", + "../timer", + ] + sources = [ + "outstanding_data.cc", + "outstanding_data.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("retransmission_queue") { + deps = [ + ":outstanding_data", + ":retransmission_timeout", + ":send_queue", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:math", + "../common:sequence_numbers", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../public:types", + "../timer", + ] + sources = [ + "retransmission_queue.cc", + "retransmission_queue.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_source_set("mock_send_queue") { + testonly = true + deps = [ + ":send_queue", + "../../../api:array_view", + "../../../test:test_support", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ "mock_send_queue.h" ] + } + + rtc_library("dcsctp_tx_unittests") { + testonly = true + + deps = [ + ":mock_send_queue", + ":outstanding_data", + ":retransmission_error_counter", + ":retransmission_queue", + ":retransmission_timeout", + ":rr_send_queue", + ":send_queue", + ":stream_scheduler", + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../test:test_support", + "../common:handover_testing", + "../common:math", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../socket:mock_callbacks", + "../socket:mock_callbacks", + "../testing:data_generator", + "../testing:testing_macros", + "../timer", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ + "outstanding_data_test.cc", + "retransmission_error_counter_test.cc", + "retransmission_queue_test.cc", + "retransmission_timeout_test.cc", + "rr_send_queue_test.cc", + "stream_scheduler_test.cc", + ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/tx/mock_send_queue.h b/third_party/libwebrtc/net/dcsctp/tx/mock_send_queue.h new file mode 100644 index 0000000000..0c8f5d141d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/mock_send_queue.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_MOCK_SEND_QUEUE_H_ +#define NET_DCSCTP_TX_MOCK_SEND_QUEUE_H_ + +#include <cstdint> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/tx/send_queue.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockSendQueue : public SendQueue { + public: + MockSendQueue() { + ON_CALL(*this, Produce).WillByDefault([](TimeMs now, size_t max_size) { + return absl::nullopt; + }); + } + + MOCK_METHOD(absl::optional<SendQueue::DataToSend>, + Produce, + (TimeMs now, size_t max_size), + (override)); + MOCK_METHOD(bool, + Discard, + (IsUnordered unordered, StreamID stream_id, MID message_id), + (override)); + MOCK_METHOD(void, PrepareResetStream, (StreamID stream_id), (override)); + MOCK_METHOD(bool, HasStreamsReadyToBeReset, (), (const, override)); + MOCK_METHOD(std::vector<StreamID>, GetStreamsReadyToBeReset, (), (override)); + MOCK_METHOD(void, CommitResetStreams, (), (override)); + MOCK_METHOD(void, RollbackResetStreams, (), (override)); + MOCK_METHOD(void, Reset, (), (override)); + MOCK_METHOD(size_t, buffered_amount, (StreamID stream_id), (const, override)); + MOCK_METHOD(size_t, total_buffered_amount, (), (const, override)); + MOCK_METHOD(size_t, + buffered_amount_low_threshold, + (StreamID stream_id), + (const, override)); + MOCK_METHOD(void, + SetBufferedAmountLowThreshold, + (StreamID stream_id, size_t bytes), + (override)); + MOCK_METHOD(void, EnableMessageInterleaving, (bool enabled), (override)); +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_MOCK_SEND_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.cc b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.cc new file mode 100644 index 0000000000..4f1e863056 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.cc @@ -0,0 +1,543 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/outstanding_data.h" + +#include <algorithm> +#include <set> +#include <utility> +#include <vector> + +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/public/types.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +// The number of times a packet must be NACKed before it's retransmitted. +// See https://tools.ietf.org/html/rfc4960#section-7.2.4 +constexpr uint8_t kNumberOfNacksForRetransmission = 3; + +// Returns how large a chunk will be, serialized, carrying the data +size_t OutstandingData::GetSerializedChunkSize(const Data& data) const { + return RoundUpTo4(data_chunk_header_size_ + data.size()); +} + +void OutstandingData::Item::Ack() { + if (lifecycle_ != Lifecycle::kAbandoned) { + lifecycle_ = Lifecycle::kActive; + } + ack_state_ = AckState::kAcked; +} + +OutstandingData::Item::NackAction OutstandingData::Item::Nack( + bool retransmit_now) { + ack_state_ = AckState::kNacked; + ++nack_count_; + if (!should_be_retransmitted() && !is_abandoned() && + (retransmit_now || nack_count_ >= kNumberOfNacksForRetransmission)) { + // Nacked enough times - it's considered lost. + if (num_retransmissions_ < *max_retransmissions_) { + lifecycle_ = Lifecycle::kToBeRetransmitted; + return NackAction::kRetransmit; + } + Abandon(); + return NackAction::kAbandon; + } + return NackAction::kNothing; +} + +void OutstandingData::Item::MarkAsRetransmitted() { + lifecycle_ = Lifecycle::kActive; + ack_state_ = AckState::kUnacked; + + nack_count_ = 0; + ++num_retransmissions_; +} + +void OutstandingData::Item::Abandon() { + lifecycle_ = Lifecycle::kAbandoned; +} + +bool OutstandingData::Item::has_expired(TimeMs now) const { + return expires_at_ <= now; +} + +bool OutstandingData::IsConsistent() const { + size_t actual_outstanding_bytes = 0; + size_t actual_outstanding_items = 0; + + std::set<UnwrappedTSN> combined_to_be_retransmitted; + combined_to_be_retransmitted.insert(to_be_retransmitted_.begin(), + to_be_retransmitted_.end()); + combined_to_be_retransmitted.insert(to_be_fast_retransmitted_.begin(), + to_be_fast_retransmitted_.end()); + + std::set<UnwrappedTSN> actual_combined_to_be_retransmitted; + for (const auto& [tsn, item] : outstanding_data_) { + if (item.is_outstanding()) { + actual_outstanding_bytes += GetSerializedChunkSize(item.data()); + ++actual_outstanding_items; + } + + if (item.should_be_retransmitted()) { + actual_combined_to_be_retransmitted.insert(tsn); + } + } + + if (outstanding_data_.empty() && + next_tsn_ != last_cumulative_tsn_ack_.next_value()) { + return false; + } + + return actual_outstanding_bytes == outstanding_bytes_ && + actual_outstanding_items == outstanding_items_ && + actual_combined_to_be_retransmitted == combined_to_be_retransmitted; +} + +void OutstandingData::AckChunk(AckInfo& ack_info, + std::map<UnwrappedTSN, Item>::iterator iter) { + if (!iter->second.is_acked()) { + size_t serialized_size = GetSerializedChunkSize(iter->second.data()); + ack_info.bytes_acked += serialized_size; + if (iter->second.is_outstanding()) { + outstanding_bytes_ -= serialized_size; + --outstanding_items_; + } + if (iter->second.should_be_retransmitted()) { + RTC_DCHECK(to_be_fast_retransmitted_.find(iter->first) == + to_be_fast_retransmitted_.end()); + to_be_retransmitted_.erase(iter->first); + } + iter->second.Ack(); + ack_info.highest_tsn_acked = + std::max(ack_info.highest_tsn_acked, iter->first); + } +} + +OutstandingData::AckInfo OutstandingData::HandleSack( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + bool is_in_fast_recovery) { + OutstandingData::AckInfo ack_info(cumulative_tsn_ack); + // Erase all items up to cumulative_tsn_ack. + RemoveAcked(cumulative_tsn_ack, ack_info); + + // ACK packets reported in the gap ack blocks + AckGapBlocks(cumulative_tsn_ack, gap_ack_blocks, ack_info); + + // NACK and possibly mark for retransmit chunks that weren't acked. + NackBetweenAckBlocks(cumulative_tsn_ack, gap_ack_blocks, is_in_fast_recovery, + ack_info); + + RTC_DCHECK(IsConsistent()); + return ack_info; +} + +void OutstandingData::RemoveAcked(UnwrappedTSN cumulative_tsn_ack, + AckInfo& ack_info) { + auto first_unacked = outstanding_data_.upper_bound(cumulative_tsn_ack); + + for (auto iter = outstanding_data_.begin(); iter != first_unacked; ++iter) { + AckChunk(ack_info, iter); + if (iter->second.lifecycle_id().IsSet()) { + RTC_DCHECK(iter->second.data().is_end); + if (iter->second.is_abandoned()) { + ack_info.abandoned_lifecycle_ids.push_back(iter->second.lifecycle_id()); + } else { + ack_info.acked_lifecycle_ids.push_back(iter->second.lifecycle_id()); + } + } + } + + outstanding_data_.erase(outstanding_data_.begin(), first_unacked); + last_cumulative_tsn_ack_ = cumulative_tsn_ack; +} + +void OutstandingData::AckGapBlocks( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + AckInfo& ack_info) { + // Mark all non-gaps as ACKED (but they can't be removed) as (from RFC) + // "SCTP considers the information carried in the Gap Ack Blocks in the + // SACK chunk as advisory.". Note that when NR-SACK is supported, this can be + // handled differently. + + for (auto& block : gap_ack_blocks) { + auto start = outstanding_data_.lower_bound( + UnwrappedTSN::AddTo(cumulative_tsn_ack, block.start)); + auto end = outstanding_data_.upper_bound( + UnwrappedTSN::AddTo(cumulative_tsn_ack, block.end)); + for (auto iter = start; iter != end; ++iter) { + AckChunk(ack_info, iter); + } + } +} + +void OutstandingData::NackBetweenAckBlocks( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + bool is_in_fast_recovery, + OutstandingData::AckInfo& ack_info) { + // Mark everything between the blocks as NACKED/TO_BE_RETRANSMITTED. + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "Mark the DATA chunk(s) with three miss indications for retransmission." + // "For each incoming SACK, miss indications are incremented only for + // missing TSNs prior to the highest TSN newly acknowledged in the SACK." + // + // What this means is that only when there is a increasing stream of data + // received and there are new packets seen (since last time), packets that are + // in-flight and between gaps should be nacked. This means that SCTP relies on + // the T3-RTX-timer to re-send packets otherwise. + UnwrappedTSN max_tsn_to_nack = ack_info.highest_tsn_acked; + if (is_in_fast_recovery && cumulative_tsn_ack > last_cumulative_tsn_ack_) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "If an endpoint is in Fast Recovery and a SACK arrives that advances + // the Cumulative TSN Ack Point, the miss indications are incremented for + // all TSNs reported missing in the SACK." + max_tsn_to_nack = UnwrappedTSN::AddTo( + cumulative_tsn_ack, + gap_ack_blocks.empty() ? 0 : gap_ack_blocks.rbegin()->end); + } + + UnwrappedTSN prev_block_last_acked = cumulative_tsn_ack; + for (auto& block : gap_ack_blocks) { + UnwrappedTSN cur_block_first_acked = + UnwrappedTSN::AddTo(cumulative_tsn_ack, block.start); + for (auto iter = outstanding_data_.upper_bound(prev_block_last_acked); + iter != outstanding_data_.lower_bound(cur_block_first_acked); ++iter) { + if (iter->first <= max_tsn_to_nack) { + ack_info.has_packet_loss |= + NackItem(iter->first, iter->second, /*retransmit_now=*/false, + /*do_fast_retransmit=*/!is_in_fast_recovery); + } + } + prev_block_last_acked = UnwrappedTSN::AddTo(cumulative_tsn_ack, block.end); + } + + // Note that packets are not NACKED which are above the highest gap-ack-block + // (or above the cumulative ack TSN if no gap-ack-blocks) as only packets + // up until the highest_tsn_acked (see above) should be considered when + // NACKing. +} + +bool OutstandingData::NackItem(UnwrappedTSN tsn, + Item& item, + bool retransmit_now, + bool do_fast_retransmit) { + if (item.is_outstanding()) { + outstanding_bytes_ -= GetSerializedChunkSize(item.data()); + --outstanding_items_; + } + + switch (item.Nack(retransmit_now)) { + case Item::NackAction::kNothing: + return false; + case Item::NackAction::kRetransmit: + if (do_fast_retransmit) { + to_be_fast_retransmitted_.insert(tsn); + } else { + to_be_retransmitted_.insert(tsn); + } + RTC_DLOG(LS_VERBOSE) << *tsn.Wrap() << " marked for retransmission"; + break; + case Item::NackAction::kAbandon: + AbandonAllFor(item); + break; + } + return true; +} + +void OutstandingData::AbandonAllFor(const Item& item) { + // Erase all remaining chunks from the producer, if any. + if (discard_from_send_queue_(item.data().is_unordered, item.data().stream_id, + item.data().message_id)) { + // There were remaining chunks to be produced for this message. Since the + // receiver may have already received all chunks (up till now) for this + // message, we can't just FORWARD-TSN to the last fragment in this + // (abandoned) message and start sending a new message, as the receiver will + // then see a new message before the end of the previous one was seen (or + // skipped over). So create a new fragment, representing the end, that the + // received will never see as it is abandoned immediately and used as cum + // TSN in the sent FORWARD-TSN. + UnwrappedTSN tsn = next_tsn_; + next_tsn_.Increment(); + Data message_end(item.data().stream_id, item.data().ssn, + item.data().message_id, item.data().fsn, item.data().ppid, + std::vector<uint8_t>(), Data::IsBeginning(false), + Data::IsEnd(true), item.data().is_unordered); + Item& added_item = + outstanding_data_ + .emplace(std::piecewise_construct, std::forward_as_tuple(tsn), + std::forward_as_tuple(std::move(message_end), TimeMs(0), + MaxRetransmits::NoLimit(), + TimeMs::InfiniteFuture(), + LifecycleId::NotSet())) + .first->second; + // The added chunk shouldn't be included in `outstanding_bytes`, so set it + // as acked. + added_item.Ack(); + RTC_DLOG(LS_VERBOSE) << "Adding unsent end placeholder for message at tsn=" + << *tsn.Wrap(); + } + + for (auto& [tsn, other] : outstanding_data_) { + if (!other.is_abandoned() && + other.data().stream_id == item.data().stream_id && + other.data().is_unordered == item.data().is_unordered && + other.data().message_id == item.data().message_id) { + RTC_DLOG(LS_VERBOSE) << "Marking chunk " << *tsn.Wrap() + << " as abandoned"; + if (other.should_be_retransmitted()) { + to_be_fast_retransmitted_.erase(tsn); + to_be_retransmitted_.erase(tsn); + } + other.Abandon(); + } + } +} + +std::vector<std::pair<TSN, Data>> OutstandingData::ExtractChunksThatCanFit( + std::set<UnwrappedTSN>& chunks, + size_t max_size) { + std::vector<std::pair<TSN, Data>> result; + + for (auto it = chunks.begin(); it != chunks.end();) { + UnwrappedTSN tsn = *it; + auto elem = outstanding_data_.find(tsn); + RTC_DCHECK(elem != outstanding_data_.end()); + Item& item = elem->second; + RTC_DCHECK(item.should_be_retransmitted()); + RTC_DCHECK(!item.is_outstanding()); + RTC_DCHECK(!item.is_abandoned()); + RTC_DCHECK(!item.is_acked()); + + size_t serialized_size = GetSerializedChunkSize(item.data()); + if (serialized_size <= max_size) { + item.MarkAsRetransmitted(); + result.emplace_back(tsn.Wrap(), item.data().Clone()); + max_size -= serialized_size; + outstanding_bytes_ += serialized_size; + ++outstanding_items_; + it = chunks.erase(it); + } else { + ++it; + } + // No point in continuing if the packet is full. + if (max_size <= data_chunk_header_size_) { + break; + } + } + return result; +} + +std::vector<std::pair<TSN, Data>> +OutstandingData::GetChunksToBeFastRetransmitted(size_t max_size) { + std::vector<std::pair<TSN, Data>> result = + ExtractChunksThatCanFit(to_be_fast_retransmitted_, max_size); + + // https://datatracker.ietf.org/doc/html/rfc4960#section-7.2.4 + // "Those TSNs marked for retransmission due to the Fast-Retransmit algorithm + // that did not fit in the sent datagram carrying K other TSNs are also marked + // as ineligible for a subsequent Fast Retransmit. However, as they are + // marked for retransmission they will be retransmitted later on as soon as + // cwnd allows." + if (!to_be_fast_retransmitted_.empty()) { + to_be_retransmitted_.insert(to_be_fast_retransmitted_.begin(), + to_be_fast_retransmitted_.end()); + to_be_fast_retransmitted_.clear(); + } + + RTC_DCHECK(IsConsistent()); + return result; +} + +std::vector<std::pair<TSN, Data>> OutstandingData::GetChunksToBeRetransmitted( + size_t max_size) { + // Chunks scheduled for fast retransmission must be sent first. + RTC_DCHECK(to_be_fast_retransmitted_.empty()); + return ExtractChunksThatCanFit(to_be_retransmitted_, max_size); +} + +void OutstandingData::ExpireOutstandingChunks(TimeMs now) { + for (const auto& [tsn, item] : outstanding_data_) { + // Chunks that are nacked can be expired. Care should be taken not to expire + // unacked (in-flight) chunks as they might have been received, but the SACK + // is either delayed or in-flight and may be received later. + if (item.is_abandoned()) { + // Already abandoned. + } else if (item.is_nacked() && item.has_expired(now)) { + RTC_DLOG(LS_VERBOSE) << "Marking nacked chunk " << *tsn.Wrap() + << " and message " << *item.data().message_id + << " as expired"; + AbandonAllFor(item); + } else { + // A non-expired chunk. No need to iterate any further. + break; + } + } + RTC_DCHECK(IsConsistent()); +} + +UnwrappedTSN OutstandingData::highest_outstanding_tsn() const { + return outstanding_data_.empty() ? last_cumulative_tsn_ack_ + : outstanding_data_.rbegin()->first; +} + +absl::optional<UnwrappedTSN> OutstandingData::Insert( + const Data& data, + TimeMs time_sent, + MaxRetransmits max_retransmissions, + TimeMs expires_at, + LifecycleId lifecycle_id) { + UnwrappedTSN tsn = next_tsn_; + next_tsn_.Increment(); + + // All chunks are always padded to be even divisible by 4. + size_t chunk_size = GetSerializedChunkSize(data); + outstanding_bytes_ += chunk_size; + ++outstanding_items_; + auto it = outstanding_data_ + .emplace(std::piecewise_construct, std::forward_as_tuple(tsn), + std::forward_as_tuple(data.Clone(), time_sent, + max_retransmissions, expires_at, + lifecycle_id)) + .first; + + if (it->second.has_expired(time_sent)) { + // No need to send it - it was expired when it was in the send + // queue. + RTC_DLOG(LS_VERBOSE) << "Marking freshly produced chunk " + << *it->first.Wrap() << " and message " + << *it->second.data().message_id << " as expired"; + AbandonAllFor(it->second); + RTC_DCHECK(IsConsistent()); + return absl::nullopt; + } + + RTC_DCHECK(IsConsistent()); + return tsn; +} + +void OutstandingData::NackAll() { + for (auto& [tsn, item] : outstanding_data_) { + if (!item.is_acked()) { + NackItem(tsn, item, /*retransmit_now=*/true, + /*do_fast_retransmit=*/false); + } + } + RTC_DCHECK(IsConsistent()); +} + +absl::optional<DurationMs> OutstandingData::MeasureRTT(TimeMs now, + UnwrappedTSN tsn) const { + auto it = outstanding_data_.find(tsn); + if (it != outstanding_data_.end() && !it->second.has_been_retransmitted()) { + // https://tools.ietf.org/html/rfc4960#section-6.3.1 + // "Karn's algorithm: RTT measurements MUST NOT be made using + // packets that were retransmitted (and thus for which it is ambiguous + // whether the reply was for the first instance of the chunk or for a + // later instance)" + return now - it->second.time_sent(); + } + return absl::nullopt; +} + +std::vector<std::pair<TSN, OutstandingData::State>> +OutstandingData::GetChunkStatesForTesting() const { + std::vector<std::pair<TSN, State>> states; + states.emplace_back(last_cumulative_tsn_ack_.Wrap(), State::kAcked); + for (const auto& [tsn, item] : outstanding_data_) { + State state; + if (item.is_abandoned()) { + state = State::kAbandoned; + } else if (item.should_be_retransmitted()) { + state = State::kToBeRetransmitted; + } else if (item.is_acked()) { + state = State::kAcked; + } else if (item.is_outstanding()) { + state = State::kInFlight; + } else { + state = State::kNacked; + } + + states.emplace_back(tsn.Wrap(), state); + } + return states; +} + +bool OutstandingData::ShouldSendForwardTsn() const { + if (!outstanding_data_.empty()) { + auto it = outstanding_data_.begin(); + return it->first == last_cumulative_tsn_ack_.next_value() && + it->second.is_abandoned(); + } + return false; +} + +ForwardTsnChunk OutstandingData::CreateForwardTsn() const { + std::map<StreamID, SSN> skipped_per_ordered_stream; + UnwrappedTSN new_cumulative_ack = last_cumulative_tsn_ack_; + + for (const auto& [tsn, item] : outstanding_data_) { + if ((tsn != new_cumulative_ack.next_value()) || !item.is_abandoned()) { + break; + } + new_cumulative_ack = tsn; + if (!item.data().is_unordered && + item.data().ssn > skipped_per_ordered_stream[item.data().stream_id]) { + skipped_per_ordered_stream[item.data().stream_id] = item.data().ssn; + } + } + + std::vector<ForwardTsnChunk::SkippedStream> skipped_streams; + skipped_streams.reserve(skipped_per_ordered_stream.size()); + for (const auto& [stream_id, ssn] : skipped_per_ordered_stream) { + skipped_streams.emplace_back(stream_id, ssn); + } + return ForwardTsnChunk(new_cumulative_ack.Wrap(), std::move(skipped_streams)); +} + +IForwardTsnChunk OutstandingData::CreateIForwardTsn() const { + std::map<std::pair<IsUnordered, StreamID>, MID> skipped_per_stream; + UnwrappedTSN new_cumulative_ack = last_cumulative_tsn_ack_; + + for (const auto& [tsn, item] : outstanding_data_) { + if ((tsn != new_cumulative_ack.next_value()) || !item.is_abandoned()) { + break; + } + new_cumulative_ack = tsn; + std::pair<IsUnordered, StreamID> stream_id = + std::make_pair(item.data().is_unordered, item.data().stream_id); + + if (item.data().message_id > skipped_per_stream[stream_id]) { + skipped_per_stream[stream_id] = item.data().message_id; + } + } + + std::vector<IForwardTsnChunk::SkippedStream> skipped_streams; + skipped_streams.reserve(skipped_per_stream.size()); + for (const auto& [stream, message_id] : skipped_per_stream) { + skipped_streams.emplace_back(stream.first, stream.second, message_id); + } + + return IForwardTsnChunk(new_cumulative_ack.Wrap(), + std::move(skipped_streams)); +} + +void OutstandingData::ResetSequenceNumbers(UnwrappedTSN next_tsn, + UnwrappedTSN last_cumulative_tsn) { + RTC_DCHECK(outstanding_data_.empty()); + RTC_DCHECK(next_tsn_ == last_cumulative_tsn_ack_.next_value()); + RTC_DCHECK(next_tsn == last_cumulative_tsn.next_value()); + next_tsn_ = next_tsn; + last_cumulative_tsn_ack_ = last_cumulative_tsn; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.h b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.h new file mode 100644 index 0000000000..6b4b7121fb --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.h @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_OUTSTANDING_DATA_H_ +#define NET_DCSCTP_TX_OUTSTANDING_DATA_H_ + +#include <map> +#include <set> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// This class keeps track of outstanding data chunks (sent, not yet acked) and +// handles acking, nacking, rescheduling and abandoning. +class OutstandingData { + public: + // State for DATA chunks (message fragments) in the queue - used in tests. + enum class State { + // The chunk has been sent but not received yet (from the sender's point of + // view, as no SACK has been received yet that reference this chunk). + kInFlight, + // A SACK has been received which explicitly marked this chunk as missing - + // it's now NACKED and may be retransmitted if NACKED enough times. + kNacked, + // A chunk that will be retransmitted when possible. + kToBeRetransmitted, + // A SACK has been received which explicitly marked this chunk as received. + kAcked, + // A chunk whose message has expired or has been retransmitted too many + // times (RFC3758). It will not be retransmitted anymore. + kAbandoned, + }; + + // Contains variables scoped to a processing of an incoming SACK. + struct AckInfo { + explicit AckInfo(UnwrappedTSN cumulative_tsn_ack) + : highest_tsn_acked(cumulative_tsn_ack) {} + + // Bytes acked by increasing cumulative_tsn_ack and gap_ack_blocks. + size_t bytes_acked = 0; + + // Indicates if this SACK indicates that packet loss has occurred. Just + // because a packet is missing in the SACK doesn't necessarily mean that + // there is packet loss as that packet might be in-flight and received + // out-of-order. But when it has been reported missing consecutive times, it + // will eventually be considered "lost" and this will be set. + bool has_packet_loss = false; + + // Highest TSN Newly Acknowledged, an SCTP variable. + UnwrappedTSN highest_tsn_acked; + + // The set of lifecycle IDs that were acked using cumulative_tsn_ack. + std::vector<LifecycleId> acked_lifecycle_ids; + // The set of lifecycle IDs that were acked, but had been abandoned. + std::vector<LifecycleId> abandoned_lifecycle_ids; + }; + + OutstandingData( + size_t data_chunk_header_size, + UnwrappedTSN next_tsn, + UnwrappedTSN last_cumulative_tsn_ack, + std::function<bool(IsUnordered, StreamID, MID)> discard_from_send_queue) + : data_chunk_header_size_(data_chunk_header_size), + next_tsn_(next_tsn), + last_cumulative_tsn_ack_(last_cumulative_tsn_ack), + discard_from_send_queue_(std::move(discard_from_send_queue)) {} + + AckInfo HandleSack( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + bool is_in_fast_recovery); + + // Returns as many of the chunks that are eligible for fast retransmissions + // and that would fit in a single packet of `max_size`. The eligible chunks + // that didn't fit will be marked for (normal) retransmission and will not be + // returned if this method is called again. + std::vector<std::pair<TSN, Data>> GetChunksToBeFastRetransmitted( + size_t max_size); + + // Given `max_size` of space left in a packet, which chunks can be added to + // it? + std::vector<std::pair<TSN, Data>> GetChunksToBeRetransmitted(size_t max_size); + + size_t outstanding_bytes() const { return outstanding_bytes_; } + + // Returns the number of DATA chunks that are in-flight. + size_t outstanding_items() const { return outstanding_items_; } + + // Given the current time `now_ms`, expire and abandon outstanding (sent at + // least once) chunks that have a limited lifetime. + void ExpireOutstandingChunks(TimeMs now); + + bool empty() const { return outstanding_data_.empty(); } + + bool has_data_to_be_fast_retransmitted() const { + return !to_be_fast_retransmitted_.empty(); + } + + bool has_data_to_be_retransmitted() const { + return !to_be_retransmitted_.empty() || !to_be_fast_retransmitted_.empty(); + } + + UnwrappedTSN last_cumulative_tsn_ack() const { + return last_cumulative_tsn_ack_; + } + + UnwrappedTSN next_tsn() const { return next_tsn_; } + + UnwrappedTSN highest_outstanding_tsn() const; + + // Schedules `data` to be sent, with the provided partial reliability + // parameters. Returns the TSN if the item was actually added and scheduled to + // be sent, and absl::nullopt if it shouldn't be sent. + absl::optional<UnwrappedTSN> Insert( + const Data& data, + TimeMs time_sent, + MaxRetransmits max_retransmissions = MaxRetransmits::NoLimit(), + TimeMs expires_at = TimeMs::InfiniteFuture(), + LifecycleId lifecycle_id = LifecycleId::NotSet()); + + // Nacks all outstanding data. + void NackAll(); + + // Creates a FORWARD-TSN chunk. + ForwardTsnChunk CreateForwardTsn() const; + + // Creates an I-FORWARD-TSN chunk. + IForwardTsnChunk CreateIForwardTsn() const; + + // Given the current time and a TSN, it returns the measured RTT between when + // the chunk was sent and now. It takes into acccount Karn's algorithm, so if + // the chunk has ever been retransmitted, it will return absl::nullopt. + absl::optional<DurationMs> MeasureRTT(TimeMs now, UnwrappedTSN tsn) const; + + // Returns the internal state of all queued chunks. This is only used in + // unit-tests. + std::vector<std::pair<TSN, State>> GetChunkStatesForTesting() const; + + // Returns true if the next chunk that is not acked by the peer has been + // abandoned, which means that a FORWARD-TSN should be sent. + bool ShouldSendForwardTsn() const; + + // Sets the next TSN to be used. This is used in handover. + void ResetSequenceNumbers(UnwrappedTSN next_tsn, + UnwrappedTSN last_cumulative_tsn); + + private: + // A fragmented message's DATA chunk while in the retransmission queue, and + // its associated metadata. + class Item { + public: + enum class NackAction { + kNothing, + kRetransmit, + kAbandon, + }; + + Item(Data data, + TimeMs time_sent, + MaxRetransmits max_retransmissions, + TimeMs expires_at, + LifecycleId lifecycle_id) + : time_sent_(time_sent), + max_retransmissions_(max_retransmissions), + expires_at_(expires_at), + lifecycle_id_(lifecycle_id), + data_(std::move(data)) {} + + Item(const Item&) = delete; + Item& operator=(const Item&) = delete; + + TimeMs time_sent() const { return time_sent_; } + + const Data& data() const { return data_; } + + // Acks an item. + void Ack(); + + // Nacks an item. If it has been nacked enough times, or if `retransmit_now` + // is set, it might be marked for retransmission. If the item has reached + // its max retransmission value, it will instead be abandoned. The action + // performed is indicated as return value. + NackAction Nack(bool retransmit_now); + + // Prepares the item to be retransmitted. Sets it as outstanding and + // clears all nack counters. + void MarkAsRetransmitted(); + + // Marks this item as abandoned. + void Abandon(); + + bool is_outstanding() const { return ack_state_ == AckState::kUnacked; } + bool is_acked() const { return ack_state_ == AckState::kAcked; } + bool is_nacked() const { return ack_state_ == AckState::kNacked; } + bool is_abandoned() const { return lifecycle_ == Lifecycle::kAbandoned; } + + // Indicates if this chunk should be retransmitted. + bool should_be_retransmitted() const { + return lifecycle_ == Lifecycle::kToBeRetransmitted; + } + // Indicates if this chunk has ever been retransmitted. + bool has_been_retransmitted() const { return num_retransmissions_ > 0; } + + // Given the current time, and the current state of this DATA chunk, it will + // indicate if it has expired (SCTP Partial Reliability Extension). + bool has_expired(TimeMs now) const; + + LifecycleId lifecycle_id() const { return lifecycle_id_; } + + private: + enum class Lifecycle : uint8_t { + // The chunk is alive (sent, received, etc) + kActive, + // The chunk is scheduled to be retransmitted, and will then transition to + // become active. + kToBeRetransmitted, + // The chunk has been abandoned. This is a terminal state. + kAbandoned + }; + enum class AckState : uint8_t { + // The chunk is in-flight. + kUnacked, + // The chunk has been received and acknowledged. + kAcked, + // The chunk has been nacked and is possibly lost. + kNacked + }; + + // NOTE: This data structure has been optimized for size, by ordering fields + // to avoid unnecessary padding. + + // When the packet was sent, and placed in this queue. + const TimeMs time_sent_; + // If the message was sent with a maximum number of retransmissions, this is + // set to that number. The value zero (0) means that it will never be + // retransmitted. + const MaxRetransmits max_retransmissions_; + + // Indicates the life cycle status of this chunk. + Lifecycle lifecycle_ = Lifecycle::kActive; + // Indicates the presence of this chunk, if it's in flight (Unacked), has + // been received (Acked) or is possibly lost (Nacked). + AckState ack_state_ = AckState::kUnacked; + + // The number of times the DATA chunk has been nacked (by having received a + // SACK which doesn't include it). Will be cleared on retransmissions. + uint8_t nack_count_ = 0; + // The number of times the DATA chunk has been retransmitted. + uint16_t num_retransmissions_ = 0; + + // At this exact millisecond, the item is considered expired. If the message + // is not to be expired, this is set to the infinite future. + const TimeMs expires_at_; + + // An optional lifecycle id, which may only be set for the last fragment. + const LifecycleId lifecycle_id_; + + // The actual data to send/retransmit. + const Data data_; + }; + + // Returns how large a chunk will be, serialized, carrying the data + size_t GetSerializedChunkSize(const Data& data) const; + + // Given a `cumulative_tsn_ack` from an incoming SACK, will remove those items + // in the retransmission queue up until this value and will update `ack_info` + // by setting `bytes_acked_by_cumulative_tsn_ack`. + void RemoveAcked(UnwrappedTSN cumulative_tsn_ack, AckInfo& ack_info); + + // Will mark the chunks covered by the `gap_ack_blocks` from an incoming SACK + // as "acked" and update `ack_info` by adding new TSNs to `added_tsns`. + void AckGapBlocks(UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + AckInfo& ack_info); + + // Mark chunks reported as "missing", as "nacked" or "to be retransmitted" + // depending how many times this has happened. Only packets up until + // `ack_info.highest_tsn_acked` (highest TSN newly acknowledged) are + // nacked/retransmitted. The method will set `ack_info.has_packet_loss`. + void NackBetweenAckBlocks( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + bool is_in_fast_recovery, + OutstandingData::AckInfo& ack_info); + + // Process the acknowledgement of the chunk referenced by `iter` and updates + // state in `ack_info` and the object's state. + void AckChunk(AckInfo& ack_info, std::map<UnwrappedTSN, Item>::iterator iter); + + // Helper method to process an incoming nack of an item and perform the + // correct operations given the action indicated when nacking an item (e.g. + // retransmitting or abandoning). The return value indicate if an action was + // performed, meaning that packet loss was detected and acted upon. If + // `do_fast_retransmit` is set and if the item has been nacked sufficiently + // many times so that it should be retransmitted, this will schedule it to be + // "fast retransmitted". This is only done just before going into fast + // recovery. + bool NackItem(UnwrappedTSN tsn, + Item& item, + bool retransmit_now, + bool do_fast_retransmit); + + // Given that a message fragment, `item` has been abandoned, abandon all other + // fragments that share the same message - both never-before-sent fragments + // that are still in the SendQueue and outstanding chunks. + void AbandonAllFor(const OutstandingData::Item& item); + + std::vector<std::pair<TSN, Data>> ExtractChunksThatCanFit( + std::set<UnwrappedTSN>& chunks, + size_t max_size); + + bool IsConsistent() const; + + // The size of the data chunk (DATA/I-DATA) header that is used. + const size_t data_chunk_header_size_; + // Next TSN to used. + UnwrappedTSN next_tsn_; + // The last cumulative TSN ack number. + UnwrappedTSN last_cumulative_tsn_ack_; + // Callback when to discard items from the send queue. + std::function<bool(IsUnordered, StreamID, MID)> discard_from_send_queue_; + + std::map<UnwrappedTSN, Item> outstanding_data_; + // The number of bytes that are in-flight (sent but not yet acked or nacked). + size_t outstanding_bytes_ = 0; + // The number of DATA chunks that are in-flight (sent but not yet acked or + // nacked). + size_t outstanding_items_ = 0; + // Data chunks that are eligible for fast retransmission. + std::set<UnwrappedTSN> to_be_fast_retransmitted_; + // Data chunks that are to be retransmitted. + std::set<UnwrappedTSN> to_be_retransmitted_; +}; +} // namespace dcsctp +#endif // NET_DCSCTP_TX_OUTSTANDING_DATA_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/outstanding_data_test.cc b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data_test.cc new file mode 100644 index 0000000000..cdca40cfef --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data_test.cc @@ -0,0 +1,591 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/outstanding_data.h" + +#include <vector> + +#include "absl/types/optional.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/testing/data_generator.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; +using State = ::dcsctp::OutstandingData::State; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::Return; +using ::testing::StrictMock; + +constexpr TimeMs kNow(42); + +class OutstandingDataTest : public testing::Test { + protected: + OutstandingDataTest() + : gen_(MID(42)), + buf_(DataChunk::kHeaderSize, + unwrapper_.Unwrap(TSN(10)), + unwrapper_.Unwrap(TSN(9)), + on_discard_.AsStdFunction()) {} + + UnwrappedTSN::Unwrapper unwrapper_; + DataGenerator gen_; + StrictMock<MockFunction<bool(IsUnordered, StreamID, MID)>> on_discard_; + OutstandingData buf_; +}; + +TEST_F(OutstandingDataTest, HasInitialState) { + EXPECT_TRUE(buf_.empty()); + EXPECT_EQ(buf_.outstanding_bytes(), 0u); + EXPECT_EQ(buf_.outstanding_items(), 0u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(10)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(9)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked))); + EXPECT_FALSE(buf_.ShouldSendForwardTsn()); +} + +TEST_F(OutstandingDataTest, InsertChunk) { + ASSERT_HAS_VALUE_AND_ASSIGN(UnwrappedTSN tsn, + buf_.Insert(gen_.Ordered({1}, "BE"), kNow)); + + EXPECT_EQ(tsn.Wrap(), TSN(10)); + + EXPECT_EQ(buf_.outstanding_bytes(), DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(buf_.outstanding_items(), 1u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(11)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(10)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(OutstandingDataTest, AcksSingleChunk) { + buf_.Insert(gen_.Ordered({1}, "BE"), kNow); + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(10)), {}, false); + + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(10)); + EXPECT_FALSE(ack.has_packet_loss); + + EXPECT_EQ(buf_.outstanding_bytes(), 0u); + EXPECT_EQ(buf_.outstanding_items(), 0u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(10)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(11)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(10)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked))); +} + +TEST_F(OutstandingDataTest, AcksPreviousChunkDoesntUpdate) { + buf_.Insert(gen_.Ordered({1}, "BE"), kNow); + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), {}, false); + + EXPECT_EQ(buf_.outstanding_bytes(), DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(buf_.outstanding_items(), 1u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(11)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(10)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(OutstandingDataTest, AcksAndNacksWithGapAckBlocks) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow); + buf_.Insert(gen_.Ordered({1}, "E"), kNow); + + std::vector<SackChunk::GapAckBlock> gab = {SackChunk::GapAckBlock(2, 2)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab, false); + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(11)); + EXPECT_FALSE(ack.has_packet_loss); + + EXPECT_EQ(buf_.outstanding_bytes(), 0u); + EXPECT_EQ(buf_.outstanding_items(), 0u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(12)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(11)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked))); +} + +TEST_F(OutstandingDataTest, NacksThreeTimesWithSameTsnDoesntRetransmit) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow); + buf_.Insert(gen_.Ordered({1}, "E"), kNow); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked))); +} + +TEST_F(OutstandingDataTest, NacksThreeTimesResultsInRetransmission) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, "E"), kNow); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(13)); + EXPECT_TRUE(ack.has_packet_loss); + + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked))); + + EXPECT_THAT(buf_.GetChunksToBeFastRetransmitted(1000), + ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(buf_.GetChunksToBeRetransmitted(1000), IsEmpty()); +} + +TEST_F(OutstandingDataTest, NacksThreeTimesResultsInAbandoning) { + static constexpr MaxRetransmits kMaxRetransmissions(0); + buf_.Insert(gen_.Ordered({1}, "B"), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, "E"), kNow, kMaxRetransmissions); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(13)); + EXPECT_TRUE(ack.has_packet_loss); + + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(14)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAbandoned))); +} + +TEST_F(OutstandingDataTest, NacksThreeTimesResultsInAbandoningWithPlaceholder) { + static constexpr MaxRetransmits kMaxRetransmissions(0); + buf_.Insert(gen_.Ordered({1}, "B"), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(true)); + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(13)); + EXPECT_TRUE(ack.has_packet_loss); + + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(15)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAbandoned), // + Pair(TSN(14), State::kAbandoned))); +} + +TEST_F(OutstandingDataTest, ExpiresChunkBeforeItIsInserted) { + static constexpr TimeMs kExpiresAt = kNow + DurationMs(1); + EXPECT_TRUE(buf_.Insert(gen_.Ordered({1}, "B"), kNow, + MaxRetransmits::NoLimit(), kExpiresAt) + .has_value()); + EXPECT_TRUE(buf_.Insert(gen_.Ordered({1}, ""), kNow + DurationMs(0), + MaxRetransmits::NoLimit(), kExpiresAt) + .has_value()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + EXPECT_FALSE(buf_.Insert(gen_.Ordered({1}, "E"), kNow + DurationMs(1), + MaxRetransmits::NoLimit(), kExpiresAt) + .has_value()); + + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(13)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(12)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), + Pair(TSN(12), State::kAbandoned))); +} + +TEST_F(OutstandingDataTest, CanGenerateForwardTsn) { + static constexpr MaxRetransmits kMaxRetransmissions(0); + buf_.Insert(gen_.Ordered({1}, "B"), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, "E"), kNow, kMaxRetransmissions); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + buf_.NackAll(); + + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), + Pair(TSN(12), State::kAbandoned))); + + EXPECT_TRUE(buf_.ShouldSendForwardTsn()); + ForwardTsnChunk chunk = buf_.CreateForwardTsn(); + EXPECT_EQ(chunk.new_cumulative_tsn(), TSN(12)); +} + +TEST_F(OutstandingDataTest, AckWithGapBlocksFromRFC4960Section334) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, "E"), kNow); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + testing::ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight))); + + std::vector<SackChunk::GapAckBlock> gab = {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 5)}; + buf_.HandleSack(unwrapper_.Unwrap(TSN(12)), gab, false); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kNacked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kNacked), // + Pair(TSN(17), State::kAcked))); +} + +TEST_F(OutstandingDataTest, MeasureRTT) { + buf_.Insert(gen_.Ordered({1}, "BE"), kNow); + buf_.Insert(gen_.Ordered({1}, "BE"), kNow + DurationMs(1)); + buf_.Insert(gen_.Ordered({1}, "BE"), kNow + DurationMs(2)); + + static constexpr DurationMs kDuration(123); + ASSERT_HAS_VALUE_AND_ASSIGN( + DurationMs duration, + buf_.MeasureRTT(kNow + kDuration, unwrapper_.Unwrap(TSN(11)))); + + EXPECT_EQ(duration, kDuration - DurationMs(1)); +} + +TEST_F(OutstandingDataTest, MustRetransmitBeforeGettingNackedAgain) { + // This test case verifies that a chunk that has been nacked, and scheduled to + // be retransmitted, doesn't get nacked again until it has been actually sent + // on the wire. + + static constexpr MaxRetransmits kOneRetransmission(1); + for (int tsn = 10; tsn <= 20; ++tsn) { + buf_.Insert(gen_.Ordered({1}, tsn == 10 ? "B" + : tsn == 20 ? "E" + : ""), + kNow, kOneRetransmission); + } + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_TRUE(ack.has_packet_loss); + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + // Don't call GetChunksToBeRetransmitted yet - simulate that the congestion + // window doesn't allow it to be retransmitted yet. It does however get more + // SACKs indicating packet loss. + + std::vector<SackChunk::GapAckBlock> gab4 = {SackChunk::GapAckBlock(2, 5)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab4, false).has_packet_loss); + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab5 = {SackChunk::GapAckBlock(2, 6)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab5, false).has_packet_loss); + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab6 = {SackChunk::GapAckBlock(2, 7)}; + OutstandingData::AckInfo ack2 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab6, false); + + EXPECT_FALSE(ack2.has_packet_loss); + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + // Now it's retransmitted. + EXPECT_THAT(buf_.GetChunksToBeFastRetransmitted(1000), + ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(buf_.GetChunksToBeRetransmitted(1000), IsEmpty()); + + // And obviously lost, as it will get NACKed and abandoned. + std::vector<SackChunk::GapAckBlock> gab7 = {SackChunk::GapAckBlock(2, 8)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab7, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab8 = {SackChunk::GapAckBlock(2, 9)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab8, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + + std::vector<SackChunk::GapAckBlock> gab9 = {SackChunk::GapAckBlock(2, 10)}; + OutstandingData::AckInfo ack3 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab9, false); + + EXPECT_TRUE(ack3.has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); +} + +TEST_F(OutstandingDataTest, CanAbandonChunksMarkedForFastRetransmit) { + // This test is a bit convoluted, and can't really happen with a well behaving + // client, but this was found by fuzzers. This test will verify that a message + // that was both marked as "to be fast retransmitted" and "abandoned" at the + // same time doesn't cause any consistency issues. + + // Add chunks 10-14, but chunk 11 has zero retransmissions. When chunk 10 and + // 11 are NACKed three times, chunk 10 will be marked for retransmission, but + // chunk 11 will be abandoned, which also abandons chunk 10, as it's part of + // the same message. + buf_.Insert(gen_.Ordered({1}, "B"), kNow); // 10 + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); // 11 + buf_.Insert(gen_.Ordered({1}, ""), kNow); // 12 + buf_.Insert(gen_.Ordered({1}, ""), kNow); // 13 + buf_.Insert(gen_.Ordered({1}, "E"), kNow); // 14 + + // ACK 9, 12 + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(3, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + // ACK 9, 12, 13 + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(3, 4)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + + // ACK 9, 12, 13, 14 + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(3, 5)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_TRUE(ack.has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_THAT(buf_.GetChunksToBeFastRetransmitted(1000), IsEmpty()); + EXPECT_THAT(buf_.GetChunksToBeRetransmitted(1000), IsEmpty()); +} + +TEST_F(OutstandingDataTest, LifecyleReturnsAckedItemsInAckInfo) { + buf_.Insert(gen_.Ordered({1}, "BE"), kNow, MaxRetransmits::NoLimit(), + TimeMs::InfiniteFuture(), LifecycleId(42)); + buf_.Insert(gen_.Ordered({1}, "BE"), kNow, MaxRetransmits::NoLimit(), + TimeMs::InfiniteFuture(), LifecycleId(43)); + buf_.Insert(gen_.Ordered({1}, "BE"), kNow, MaxRetransmits::NoLimit(), + TimeMs::InfiniteFuture(), LifecycleId(44)); + + OutstandingData::AckInfo ack1 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(11)), {}, false); + + EXPECT_THAT(ack1.acked_lifecycle_ids, + ElementsAre(LifecycleId(42), LifecycleId(43))); + + OutstandingData::AckInfo ack2 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(12)), {}, false); + + EXPECT_THAT(ack2.acked_lifecycle_ids, ElementsAre(LifecycleId(44))); +} + +TEST_F(OutstandingDataTest, LifecycleReturnsAbandonedNackedThreeTimes) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, "E"), kNow, MaxRetransmits(0), + TimeMs::InfiniteFuture(), LifecycleId(42)); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + OutstandingData::AckInfo ack1 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_TRUE(ack1.has_packet_loss); + EXPECT_THAT(ack1.abandoned_lifecycle_ids, IsEmpty()); + + // This will generate a FORWARD-TSN, which is acked + EXPECT_TRUE(buf_.ShouldSendForwardTsn()); + ForwardTsnChunk chunk = buf_.CreateForwardTsn(); + EXPECT_EQ(chunk.new_cumulative_tsn(), TSN(13)); + + OutstandingData::AckInfo ack2 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(13)), {}, false); + EXPECT_FALSE(ack2.has_packet_loss); + EXPECT_THAT(ack2.abandoned_lifecycle_ids, ElementsAre(LifecycleId(42))); +} + +TEST_F(OutstandingDataTest, LifecycleReturnsAbandonedAfterT3rtxExpired) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, "E"), kNow, MaxRetransmits(0), + TimeMs::InfiniteFuture(), LifecycleId(42)); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + testing::ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 4)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + testing::ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked))); + + // T3-rtx triggered. + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + buf_.NackAll(); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + testing::ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAbandoned))); + + // This will generate a FORWARD-TSN, which is acked + EXPECT_TRUE(buf_.ShouldSendForwardTsn()); + ForwardTsnChunk chunk = buf_.CreateForwardTsn(); + EXPECT_EQ(chunk.new_cumulative_tsn(), TSN(13)); + + OutstandingData::AckInfo ack2 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(13)), {}, false); + EXPECT_FALSE(ack2.has_packet_loss); + EXPECT_THAT(ack2.abandoned_lifecycle_ids, ElementsAre(LifecycleId(42))); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.cc new file mode 100644 index 0000000000..44b20ba2c2 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_error_counter.h" + +#include "absl/strings/string_view.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +bool RetransmissionErrorCounter::Increment(absl::string_view reason) { + ++counter_; + if (limit_.has_value() && counter_ > limit_.value()) { + RTC_DLOG(LS_INFO) << log_prefix_ << reason + << ", too many retransmissions, counter=" << counter_; + return false; + } + + RTC_DLOG(LS_VERBOSE) << log_prefix_ << reason << ", new counter=" << counter_ + << ", max=" << limit_.value_or(-1); + return true; +} + +void RetransmissionErrorCounter::Clear() { + if (counter_ > 0) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "recovered from counter=" << counter_; + counter_ = 0; + } +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.h b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.h new file mode 100644 index 0000000000..18af3d3c4f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RETRANSMISSION_ERROR_COUNTER_H_ +#define NET_DCSCTP_TX_RETRANSMISSION_ERROR_COUNTER_H_ + +#include <functional> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +// The RetransmissionErrorCounter is a simple counter with a limit, and when +// the limit is exceeded, the counter is exhausted and the connection will +// be closed. It's incremented on retransmission errors, such as the T3-RTX +// timer expiring, but also missing heartbeats and stream reset requests. +class RetransmissionErrorCounter { + public: + RetransmissionErrorCounter(absl::string_view log_prefix, + const DcSctpOptions& options) + : log_prefix_(std::string(log_prefix) + "rtx-errors: "), + limit_(options.max_retransmissions) {} + + // Increments the retransmission timer. If the maximum error count has been + // reached, `false` will be returned. + bool Increment(absl::string_view reason); + bool IsExhausted() const { return limit_.has_value() && counter_ > *limit_; } + + // Clears the retransmission errors. + void Clear(); + + // Returns its current value + int value() const { return counter_; } + + private: + const std::string log_prefix_; + const absl::optional<int> limit_; + int counter_ = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RETRANSMISSION_ERROR_COUNTER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter_test.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter_test.cc new file mode 100644 index 0000000000..67bbc0bec5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter_test.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_error_counter.h" + +#include "net/dcsctp/public/dcsctp_options.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(RetransmissionErrorCounterTest, HasInitialValue) { + DcSctpOptions options; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_EQ(counter.value(), 0); +} + +TEST(RetransmissionErrorCounterTest, ReturnsFalseAtMaximumValue) { + DcSctpOptions options; + options.max_retransmissions = 5; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_TRUE(counter.Increment("test")); // 2 + EXPECT_TRUE(counter.Increment("test")); // 3 + EXPECT_TRUE(counter.Increment("test")); // 4 + EXPECT_TRUE(counter.Increment("test")); // 5 + EXPECT_FALSE(counter.Increment("test")); // Too many retransmissions +} + +TEST(RetransmissionErrorCounterTest, CanHandleZeroRetransmission) { + DcSctpOptions options; + options.max_retransmissions = 0; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_FALSE(counter.Increment("test")); // One is too many. +} + +TEST(RetransmissionErrorCounterTest, IsExhaustedAtMaximum) { + DcSctpOptions options; + options.max_retransmissions = 3; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_TRUE(counter.Increment("test")); // 2 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_TRUE(counter.Increment("test")); // 3 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_FALSE(counter.Increment("test")); // Too many retransmissions + EXPECT_TRUE(counter.IsExhausted()); + EXPECT_FALSE(counter.Increment("test")); // One after too many + EXPECT_TRUE(counter.IsExhausted()); +} + +TEST(RetransmissionErrorCounterTest, ClearingCounter) { + DcSctpOptions options; + options.max_retransmissions = 3; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_TRUE(counter.Increment("test")); // 2 + counter.Clear(); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_TRUE(counter.Increment("test")); // 2 + EXPECT_TRUE(counter.Increment("test")); // 3 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_FALSE(counter.Increment("test")); // Too many retransmissions + EXPECT_TRUE(counter.IsExhausted()); +} + +TEST(RetransmissionErrorCounterTest, CanBeLimitless) { + DcSctpOptions options; + options.max_retransmissions = absl::nullopt; + RetransmissionErrorCounter counter("log: ", options); + for (int i = 0; i < 100; ++i) { + EXPECT_TRUE(counter.Increment("test")); + EXPECT_FALSE(counter.IsExhausted()); + } +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.cc new file mode 100644 index 0000000000..36e2a859ba --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.cc @@ -0,0 +1,611 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_queue.h" + +#include <algorithm> +#include <cstdint> +#include <functional> +#include <iterator> +#include <map> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/outstanding_data.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { +namespace { + +// Allow sending only slightly less than an MTU, to account for headers. +constexpr float kMinBytesRequiredToSendFactor = 0.9; +} // namespace + +RetransmissionQueue::RetransmissionQueue( + absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, + TSN my_initial_tsn, + size_t a_rwnd, + SendQueue& send_queue, + std::function<void(DurationMs rtt)> on_new_rtt, + std::function<void()> on_clear_retransmission_counter, + Timer& t3_rtx, + const DcSctpOptions& options, + bool supports_partial_reliability, + bool use_message_interleaving) + : callbacks_(*callbacks), + options_(options), + min_bytes_required_to_send_(options.mtu * kMinBytesRequiredToSendFactor), + partial_reliability_(supports_partial_reliability), + log_prefix_(std::string(log_prefix) + "tx: "), + data_chunk_header_size_(use_message_interleaving + ? IDataChunk::kHeaderSize + : DataChunk::kHeaderSize), + on_new_rtt_(std::move(on_new_rtt)), + on_clear_retransmission_counter_( + std::move(on_clear_retransmission_counter)), + t3_rtx_(t3_rtx), + cwnd_(options_.cwnd_mtus_initial * options_.mtu), + rwnd_(a_rwnd), + // https://tools.ietf.org/html/rfc4960#section-7.2.1 + // "The initial value of ssthresh MAY be arbitrarily high (for + // example, implementations MAY use the size of the receiver advertised + // window)."" + ssthresh_(rwnd_), + partial_bytes_acked_(0), + send_queue_(send_queue), + outstanding_data_( + data_chunk_header_size_, + tsn_unwrapper_.Unwrap(my_initial_tsn), + tsn_unwrapper_.Unwrap(TSN(*my_initial_tsn - 1)), + [this](IsUnordered unordered, StreamID stream_id, MID message_id) { + return send_queue_.Discard(unordered, stream_id, message_id); + }) {} + +bool RetransmissionQueue::IsConsistent() const { + return true; +} + +// Returns how large a chunk will be, serialized, carrying the data +size_t RetransmissionQueue::GetSerializedChunkSize(const Data& data) const { + return RoundUpTo4(data_chunk_header_size_ + data.size()); +} + +void RetransmissionQueue::MaybeExitFastRecovery( + UnwrappedTSN cumulative_tsn_ack) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "When a SACK acknowledges all TSNs up to and including this [fast + // recovery] exit point, Fast Recovery is exited." + if (fast_recovery_exit_tsn_.has_value() && + cumulative_tsn_ack >= *fast_recovery_exit_tsn_) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "exit_point=" << *fast_recovery_exit_tsn_->Wrap() + << " reached - exiting fast recovery"; + fast_recovery_exit_tsn_ = absl::nullopt; + } +} + +void RetransmissionQueue::HandleIncreasedCumulativeTsnAck( + size_t outstanding_bytes, + size_t total_bytes_acked) { + // Allow some margin for classifying as fully utilized, due to e.g. that too + // small packets (less than kMinimumFragmentedPayload) are not sent + + // overhead. + bool is_fully_utilized = outstanding_bytes + options_.mtu >= cwnd_; + size_t old_cwnd = cwnd_; + if (phase() == CongestionAlgorithmPhase::kSlowStart) { + if (is_fully_utilized && !is_in_fast_recovery()) { + // https://tools.ietf.org/html/rfc4960#section-7.2.1 + // "Only when these three conditions are met can the cwnd be + // increased; otherwise, the cwnd MUST not be increased. If these + // conditions are met, then cwnd MUST be increased by, at most, the + // lesser of 1) the total size of the previously outstanding DATA + // chunk(s) acknowledged, and 2) the destination's path MTU." + cwnd_ += std::min(total_bytes_acked, options_.mtu); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "SS increase cwnd=" << cwnd_ + << " (" << old_cwnd << ")"; + } + } else if (phase() == CongestionAlgorithmPhase::kCongestionAvoidance) { + // https://tools.ietf.org/html/rfc4960#section-7.2.2 + // "Whenever cwnd is greater than ssthresh, upon each SACK arrival + // that advances the Cumulative TSN Ack Point, increase + // partial_bytes_acked by the total number of bytes of all new chunks + // acknowledged in that SACK including chunks acknowledged by the new + // Cumulative TSN Ack and by Gap Ack Blocks." + size_t old_pba = partial_bytes_acked_; + partial_bytes_acked_ += total_bytes_acked; + + if (partial_bytes_acked_ >= cwnd_ && is_fully_utilized) { + // https://tools.ietf.org/html/rfc4960#section-7.2.2 + // "When partial_bytes_acked is equal to or greater than cwnd and + // before the arrival of the SACK the sender had cwnd or more bytes of + // data outstanding (i.e., before arrival of the SACK, flightsize was + // greater than or equal to cwnd), increase cwnd by MTU, and reset + // partial_bytes_acked to (partial_bytes_acked - cwnd)." + + // Errata: https://datatracker.ietf.org/doc/html/rfc8540#section-3.12 + partial_bytes_acked_ -= cwnd_; + cwnd_ += options_.mtu; + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "CA increase cwnd=" << cwnd_ + << " (" << old_cwnd << ") ssthresh=" << ssthresh_ + << ", pba=" << partial_bytes_acked_ << " (" + << old_pba << ")"; + } else { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "CA unchanged cwnd=" << cwnd_ + << " (" << old_cwnd << ") ssthresh=" << ssthresh_ + << ", pba=" << partial_bytes_acked_ << " (" + << old_pba << ")"; + } + } +} + +void RetransmissionQueue::HandlePacketLoss(UnwrappedTSN highest_tsn_acked) { + if (!is_in_fast_recovery()) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "If not in Fast Recovery, adjust the ssthresh and cwnd of the + // destination address(es) to which the missing DATA chunks were last + // sent, according to the formula described in Section 7.2.3." + size_t old_cwnd = cwnd_; + size_t old_pba = partial_bytes_acked_; + ssthresh_ = std::max(cwnd_ / 2, options_.cwnd_mtus_min * options_.mtu); + cwnd_ = ssthresh_; + partial_bytes_acked_ = 0; + + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "packet loss detected (not fast recovery). cwnd=" + << cwnd_ << " (" << old_cwnd + << "), ssthresh=" << ssthresh_ + << ", pba=" << partial_bytes_acked_ << " (" << old_pba + << ")"; + + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "If not in Fast Recovery, enter Fast Recovery and mark the highest + // outstanding TSN as the Fast Recovery exit point." + fast_recovery_exit_tsn_ = outstanding_data_.highest_outstanding_tsn(); + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "fast recovery initiated with exit_point=" + << *fast_recovery_exit_tsn_->Wrap(); + } else { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "While in Fast Recovery, the ssthresh and cwnd SHOULD NOT change for + // any destinations due to a subsequent Fast Recovery event (i.e., one + // SHOULD NOT reduce the cwnd further due to a subsequent Fast Retransmit)." + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "packet loss detected (fast recovery). No changes."; + } +} + +void RetransmissionQueue::UpdateReceiverWindow(uint32_t a_rwnd) { + rwnd_ = outstanding_data_.outstanding_bytes() >= a_rwnd + ? 0 + : a_rwnd - outstanding_data_.outstanding_bytes(); +} + +void RetransmissionQueue::StartT3RtxTimerIfOutstandingData() { + // Note: Can't use `outstanding_bytes()` as that one doesn't count chunks to + // be retransmitted. + if (outstanding_data_.empty()) { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Whenever all outstanding data sent to an address have been + // acknowledged, turn off the T3-rtx timer of that address. + // Note: Already stopped in `StopT3RtxTimerOnIncreasedCumulativeTsnAck`." + } else { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Whenever a SACK is received that acknowledges the DATA chunk + // with the earliest outstanding TSN for that address, restart the T3-rtx + // timer for that address with its current RTO (if there is still + // outstanding data on that address)." + // "Whenever a SACK is received missing a TSN that was previously + // acknowledged via a Gap Ack Block, start the T3-rtx for the destination + // address to which the DATA chunk was originally transmitted if it is not + // already running." + if (!t3_rtx_.is_running()) { + t3_rtx_.Start(); + } + } +} + +bool RetransmissionQueue::IsSackValid(const SackChunk& sack) const { + // https://tools.ietf.org/html/rfc4960#section-6.2.1 + // "If Cumulative TSN Ack is less than the Cumulative TSN Ack Point, + // then drop the SACK. Since Cumulative TSN Ack is monotonically increasing, + // a SACK whose Cumulative TSN Ack is less than the Cumulative TSN Ack Point + // indicates an out-of- order SACK." + // + // Note: Important not to drop SACKs with identical TSN to that previously + // received, as the gap ack blocks or dup tsn fields may have changed. + UnwrappedTSN cumulative_tsn_ack = + tsn_unwrapper_.PeekUnwrap(sack.cumulative_tsn_ack()); + if (cumulative_tsn_ack < outstanding_data_.last_cumulative_tsn_ack()) { + // https://tools.ietf.org/html/rfc4960#section-6.2.1 + // "If Cumulative TSN Ack is less than the Cumulative TSN Ack Point, + // then drop the SACK. Since Cumulative TSN Ack is monotonically + // increasing, a SACK whose Cumulative TSN Ack is less than the Cumulative + // TSN Ack Point indicates an out-of- order SACK." + return false; + } else if (cumulative_tsn_ack > outstanding_data_.highest_outstanding_tsn()) { + return false; + } + return true; +} + +bool RetransmissionQueue::HandleSack(TimeMs now, const SackChunk& sack) { + if (!IsSackValid(sack)) { + return false; + } + + UnwrappedTSN old_last_cumulative_tsn_ack = + outstanding_data_.last_cumulative_tsn_ack(); + size_t old_outstanding_bytes = outstanding_data_.outstanding_bytes(); + size_t old_rwnd = rwnd_; + UnwrappedTSN cumulative_tsn_ack = + tsn_unwrapper_.Unwrap(sack.cumulative_tsn_ack()); + + if (sack.gap_ack_blocks().empty()) { + UpdateRTT(now, cumulative_tsn_ack); + } + + // Exit fast recovery before continuing processing, in case it needs to go + // into fast recovery again due to new reported packet loss. + MaybeExitFastRecovery(cumulative_tsn_ack); + + OutstandingData::AckInfo ack_info = outstanding_data_.HandleSack( + cumulative_tsn_ack, sack.gap_ack_blocks(), is_in_fast_recovery()); + + // Add lifecycle events for delivered messages. + for (LifecycleId lifecycle_id : ack_info.acked_lifecycle_ids) { + RTC_DLOG(LS_VERBOSE) << "Triggering OnLifecycleMessageDelivered(" + << lifecycle_id.value() << ")"; + callbacks_.OnLifecycleMessageDelivered(lifecycle_id); + callbacks_.OnLifecycleEnd(lifecycle_id); + } + for (LifecycleId lifecycle_id : ack_info.abandoned_lifecycle_ids) { + RTC_DLOG(LS_VERBOSE) << "Triggering OnLifecycleMessageExpired(" + << lifecycle_id.value() << ", true)"; + callbacks_.OnLifecycleMessageExpired(lifecycle_id, + /*maybe_delivered=*/true); + callbacks_.OnLifecycleEnd(lifecycle_id); + } + + // Update of outstanding_data_ is now done. Congestion control remains. + UpdateReceiverWindow(sack.a_rwnd()); + + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Received SACK, cum_tsn_ack=" + << *cumulative_tsn_ack.Wrap() << " (" + << *old_last_cumulative_tsn_ack.Wrap() + << "), outstanding_bytes=" + << outstanding_data_.outstanding_bytes() << " (" + << old_outstanding_bytes << "), rwnd=" << rwnd_ << " (" + << old_rwnd << ")"; + + if (cumulative_tsn_ack > old_last_cumulative_tsn_ack) { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Whenever a SACK is received that acknowledges the DATA chunk + // with the earliest outstanding TSN for that address, restart the T3-rtx + // timer for that address with its current RTO (if there is still + // outstanding data on that address)." + // Note: It may be started again in a bit further down. + t3_rtx_.Stop(); + + HandleIncreasedCumulativeTsnAck(old_outstanding_bytes, + ack_info.bytes_acked); + } + + if (ack_info.has_packet_loss) { + HandlePacketLoss(ack_info.highest_tsn_acked); + } + + // https://tools.ietf.org/html/rfc4960#section-8.2 + // "When an outstanding TSN is acknowledged [...] the endpoint shall clear + // the error counter ..." + if (ack_info.bytes_acked > 0) { + on_clear_retransmission_counter_(); + } + + StartT3RtxTimerIfOutstandingData(); + RTC_DCHECK(IsConsistent()); + return true; +} + +void RetransmissionQueue::UpdateRTT(TimeMs now, + UnwrappedTSN cumulative_tsn_ack) { + // RTT updating is flawed in SCTP, as explained in e.g. Pedersen J, Griwodz C, + // Halvorsen P (2006) Considerations of SCTP retransmission delays for thin + // streams. + // Due to delayed acknowledgement, the SACK may be sent much later which + // increases the calculated RTT. + // TODO(boivie): Consider occasionally sending DATA chunks with I-bit set and + // use only those packets for measurement. + + absl::optional<DurationMs> rtt = + outstanding_data_.MeasureRTT(now, cumulative_tsn_ack); + + if (rtt.has_value()) { + on_new_rtt_(*rtt); + } +} + +void RetransmissionQueue::HandleT3RtxTimerExpiry() { + size_t old_cwnd = cwnd_; + size_t old_outstanding_bytes = outstanding_bytes(); + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "For the destination address for which the timer expires, adjust + // its ssthresh with rules defined in Section 7.2.3 and set the cwnd <- MTU." + ssthresh_ = std::max(cwnd_ / 2, 4 * options_.mtu); + cwnd_ = 1 * options_.mtu; + // Errata: https://datatracker.ietf.org/doc/html/rfc8540#section-3.11 + partial_bytes_acked_ = 0; + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "For the destination address for which the timer expires, set RTO + // <- RTO * 2 ("back off the timer"). The maximum value discussed in rule C7 + // above (RTO.max) may be used to provide an upper bound to this doubling + // operation." + + // Already done by the Timer implementation. + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "Determine how many of the earliest (i.e., lowest TSN) outstanding + // DATA chunks for the address for which the T3-rtx has expired will fit into + // a single packet" + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "Note: Any DATA chunks that were sent to the address for which the + // T3-rtx timer expired but did not fit in one MTU (rule E3 above) should be + // marked for retransmission and sent as soon as cwnd allows (normally, when a + // SACK arrives)." + outstanding_data_.NackAll(); + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "Start the retransmission timer T3-rtx on the destination address + // to which the retransmission is sent, if rule R1 above indicates to do so." + + // Already done by the Timer implementation. + + RTC_DLOG(LS_INFO) << log_prefix_ << "t3-rtx expired. new cwnd=" << cwnd_ + << " (" << old_cwnd << "), ssthresh=" << ssthresh_ + << ", outstanding_bytes " << outstanding_bytes() << " (" + << old_outstanding_bytes << ")"; + RTC_DCHECK(IsConsistent()); +} + +std::vector<std::pair<TSN, Data>> +RetransmissionQueue::GetChunksForFastRetransmit(size_t bytes_in_packet) { + RTC_DCHECK(outstanding_data_.has_data_to_be_fast_retransmitted()); + RTC_DCHECK(IsDivisibleBy4(bytes_in_packet)); + std::vector<std::pair<TSN, Data>> to_be_sent; + size_t old_outstanding_bytes = outstanding_bytes(); + + to_be_sent = + outstanding_data_.GetChunksToBeFastRetransmitted(bytes_in_packet); + RTC_DCHECK(!to_be_sent.empty()); + + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "4) Restart the T3-rtx timer only if ... the endpoint is retransmitting + // the first outstanding DATA chunk sent to that address." + if (to_be_sent[0].first == + outstanding_data_.last_cumulative_tsn_ack().next_value().Wrap()) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "First outstanding DATA to be retransmitted - restarting T3-RTX"; + t3_rtx_.Stop(); + } + + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Every time a DATA chunk is sent to any address (including a + // retransmission), if the T3-rtx timer of that address is not running, + // start it running so that it will expire after the RTO of that address." + if (!t3_rtx_.is_running()) { + t3_rtx_.Start(); + } + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Fast-retransmitting TSN " + << StrJoin(to_be_sent, ",", + [&](rtc::StringBuilder& sb, + const std::pair<TSN, Data>& c) { + sb << *c.first; + }) + << " - " + << absl::c_accumulate( + to_be_sent, 0, + [&](size_t r, const std::pair<TSN, Data>& d) { + return r + GetSerializedChunkSize(d.second); + }) + << " bytes. outstanding_bytes=" << outstanding_bytes() + << " (" << old_outstanding_bytes << ")"; + + RTC_DCHECK(IsConsistent()); + return to_be_sent; +} + +std::vector<std::pair<TSN, Data>> RetransmissionQueue::GetChunksToSend( + TimeMs now, + size_t bytes_remaining_in_packet) { + // Chunks are always padded to even divisible by four. + RTC_DCHECK(IsDivisibleBy4(bytes_remaining_in_packet)); + + std::vector<std::pair<TSN, Data>> to_be_sent; + size_t old_outstanding_bytes = outstanding_bytes(); + size_t old_rwnd = rwnd_; + + // Calculate the bandwidth budget (how many bytes that is + // allowed to be sent), and fill that up first with chunks that are + // scheduled to be retransmitted. If there is still budget, send new chunks + // (which will have their TSN assigned here.) + size_t max_bytes = + RoundDownTo4(std::min(max_bytes_to_send(), bytes_remaining_in_packet)); + + to_be_sent = outstanding_data_.GetChunksToBeRetransmitted(max_bytes); + max_bytes -= absl::c_accumulate(to_be_sent, 0, + [&](size_t r, const std::pair<TSN, Data>& d) { + return r + GetSerializedChunkSize(d.second); + }); + + while (max_bytes > data_chunk_header_size_) { + RTC_DCHECK(IsDivisibleBy4(max_bytes)); + absl::optional<SendQueue::DataToSend> chunk_opt = + send_queue_.Produce(now, max_bytes - data_chunk_header_size_); + if (!chunk_opt.has_value()) { + break; + } + + size_t chunk_size = GetSerializedChunkSize(chunk_opt->data); + max_bytes -= chunk_size; + rwnd_ -= chunk_size; + + absl::optional<UnwrappedTSN> tsn = outstanding_data_.Insert( + chunk_opt->data, now, + partial_reliability_ ? chunk_opt->max_retransmissions + : MaxRetransmits::NoLimit(), + partial_reliability_ ? chunk_opt->expires_at : TimeMs::InfiniteFuture(), + chunk_opt->lifecycle_id); + + if (tsn.has_value()) { + if (chunk_opt->lifecycle_id.IsSet()) { + RTC_DCHECK(chunk_opt->data.is_end); + callbacks_.OnLifecycleMessageFullySent(chunk_opt->lifecycle_id); + } + to_be_sent.emplace_back(tsn->Wrap(), std::move(chunk_opt->data)); + } + } + + if (!to_be_sent.empty()) { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Every time a DATA chunk is sent to any address (including a + // retransmission), if the T3-rtx timer of that address is not running, + // start it running so that it will expire after the RTO of that address." + if (!t3_rtx_.is_running()) { + t3_rtx_.Start(); + } + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Sending TSN " + << StrJoin(to_be_sent, ",", + [&](rtc::StringBuilder& sb, + const std::pair<TSN, Data>& c) { + sb << *c.first; + }) + << " - " + << absl::c_accumulate( + to_be_sent, 0, + [&](size_t r, const std::pair<TSN, Data>& d) { + return r + GetSerializedChunkSize(d.second); + }) + << " bytes. outstanding_bytes=" << outstanding_bytes() + << " (" << old_outstanding_bytes << "), cwnd=" << cwnd_ + << ", rwnd=" << rwnd_ << " (" << old_rwnd << ")"; + } + RTC_DCHECK(IsConsistent()); + return to_be_sent; +} + +bool RetransmissionQueue::can_send_data() const { + return cwnd_ < options_.avoid_fragmentation_cwnd_mtus * options_.mtu || + max_bytes_to_send() >= min_bytes_required_to_send_; +} + +bool RetransmissionQueue::ShouldSendForwardTsn(TimeMs now) { + if (!partial_reliability_) { + return false; + } + outstanding_data_.ExpireOutstandingChunks(now); + bool ret = outstanding_data_.ShouldSendForwardTsn(); + RTC_DCHECK(IsConsistent()); + return ret; +} + +size_t RetransmissionQueue::max_bytes_to_send() const { + size_t left = outstanding_bytes() >= cwnd_ ? 0 : cwnd_ - outstanding_bytes(); + + if (outstanding_bytes() == 0) { + // https://datatracker.ietf.org/doc/html/rfc4960#section-6.1 + // ... However, regardless of the value of rwnd (including if it is 0), the + // data sender can always have one DATA chunk in flight to the receiver if + // allowed by cwnd (see rule B, below). + return left; + } + + return std::min(rwnd(), left); +} + +void RetransmissionQueue::PrepareResetStream(StreamID stream_id) { + // TODO(boivie): These calls are now only affecting the send queue. The + // packet buffer can also change behavior - for example draining the chunk + // producer and eagerly assign TSNs so that an "Outgoing SSN Reset Request" + // can be sent quickly, with a known `sender_last_assigned_tsn`. + send_queue_.PrepareResetStream(stream_id); +} +bool RetransmissionQueue::HasStreamsReadyToBeReset() const { + return send_queue_.HasStreamsReadyToBeReset(); +} +void RetransmissionQueue::CommitResetStreams() { + send_queue_.CommitResetStreams(); +} +void RetransmissionQueue::RollbackResetStreams() { + send_queue_.RollbackResetStreams(); +} + +HandoverReadinessStatus RetransmissionQueue::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (!outstanding_data_.empty()) { + status.Add(HandoverUnreadinessReason::kRetransmissionQueueOutstandingData); + } + if (fast_recovery_exit_tsn_.has_value()) { + status.Add(HandoverUnreadinessReason::kRetransmissionQueueFastRecovery); + } + if (outstanding_data_.has_data_to_be_retransmitted()) { + status.Add(HandoverUnreadinessReason::kRetransmissionQueueNotEmpty); + } + return status; +} + +void RetransmissionQueue::AddHandoverState(DcSctpSocketHandoverState& state) { + state.tx.next_tsn = next_tsn().value(); + state.tx.rwnd = rwnd_; + state.tx.cwnd = cwnd_; + state.tx.ssthresh = ssthresh_; + state.tx.partial_bytes_acked = partial_bytes_acked_; +} + +void RetransmissionQueue::RestoreFromState( + const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(outstanding_data_.empty()); + RTC_DCHECK(!t3_rtx_.is_running()); + RTC_DCHECK(partial_bytes_acked_ == 0); + + cwnd_ = state.tx.cwnd; + rwnd_ = state.tx.rwnd; + ssthresh_ = state.tx.ssthresh; + partial_bytes_acked_ = state.tx.partial_bytes_acked; + + outstanding_data_.ResetSequenceNumbers( + tsn_unwrapper_.Unwrap(TSN(state.tx.next_tsn)), + tsn_unwrapper_.Unwrap(TSN(state.tx.next_tsn - 1))); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.h b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.h new file mode 100644 index 0000000000..830c0b346d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.h @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RETRANSMISSION_QUEUE_H_ +#define NET_DCSCTP_TX_RETRANSMISSION_QUEUE_H_ + +#include <cstdint> +#include <functional> +#include <map> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/outstanding_data.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "net/dcsctp/tx/send_queue.h" + +namespace dcsctp { + +// The RetransmissionQueue manages all DATA/I-DATA chunks that are in-flight and +// schedules them to be retransmitted if necessary. Chunks are retransmitted +// when they have been lost for a number of consecutive SACKs, or when the +// retransmission timer, `t3_rtx` expires. +// +// As congestion control is tightly connected with the state of transmitted +// packets, that's also managed here to limit the amount of data that is +// in-flight (sent, but not yet acknowledged). +class RetransmissionQueue { + public: + static constexpr size_t kMinimumFragmentedPayload = 10; + using State = OutstandingData::State; + // Creates a RetransmissionQueue which will send data using `my_initial_tsn` + // (or a value from `DcSctpSocketHandoverState` if given) as the first TSN + // to use for sent fragments. It will poll data from `send_queue`. When SACKs + // are received, it will estimate the RTT, and call `on_new_rtt`. When an + // outstanding chunk has been ACKed, it will call + // `on_clear_retransmission_counter` and will also use `t3_rtx`, which is the + // SCTP retransmission timer to manage retransmissions. + RetransmissionQueue(absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, + TSN my_initial_tsn, + size_t a_rwnd, + SendQueue& send_queue, + std::function<void(DurationMs rtt)> on_new_rtt, + std::function<void()> on_clear_retransmission_counter, + Timer& t3_rtx, + const DcSctpOptions& options, + bool supports_partial_reliability = true, + bool use_message_interleaving = false); + + // Handles a received SACK. Returns true if the `sack` was processed and + // false if it was discarded due to received out-of-order and not relevant. + bool HandleSack(TimeMs now, const SackChunk& sack); + + // Handles an expired retransmission timer. + void HandleT3RtxTimerExpiry(); + + bool has_data_to_be_fast_retransmitted() const { + return outstanding_data_.has_data_to_be_fast_retransmitted(); + } + + // Returns a list of chunks to "fast retransmit" that would fit in one SCTP + // packet with `bytes_in_packet` bytes available. The current value + // of `cwnd` is ignored. + std::vector<std::pair<TSN, Data>> GetChunksForFastRetransmit( + size_t bytes_in_packet); + + // Returns a list of chunks to send that would fit in one SCTP packet with + // `bytes_remaining_in_packet` bytes available. This may be further limited by + // the congestion control windows. Note that `ShouldSendForwardTSN` must be + // called prior to this method, to abandon expired chunks, as this method will + // not expire any chunks. + std::vector<std::pair<TSN, Data>> GetChunksToSend( + TimeMs now, + size_t bytes_remaining_in_packet); + + // Returns the internal state of all queued chunks. This is only used in + // unit-tests. + std::vector<std::pair<TSN, OutstandingData::State>> GetChunkStatesForTesting() + const { + return outstanding_data_.GetChunkStatesForTesting(); + } + + // Returns the next TSN that will be allocated for sent DATA chunks. + TSN next_tsn() const { return outstanding_data_.next_tsn().Wrap(); } + + // Returns the size of the congestion window, in bytes. This is the number of + // bytes that may be in-flight. + size_t cwnd() const { return cwnd_; } + + // Overrides the current congestion window size. + void set_cwnd(size_t cwnd) { cwnd_ = cwnd; } + + // Returns the current receiver window size. + size_t rwnd() const { return rwnd_; } + + // Returns the number of bytes of packets that are in-flight. + size_t outstanding_bytes() const { + return outstanding_data_.outstanding_bytes(); + } + + // Returns the number of DATA chunks that are in-flight. + size_t outstanding_items() const { + return outstanding_data_.outstanding_items(); + } + + // Indicates if the congestion control algorithm allows data to be sent. + bool can_send_data() const; + + // Given the current time `now`, it will evaluate if there are chunks that + // have expired and that need to be discarded. It returns true if a + // FORWARD-TSN should be sent. + bool ShouldSendForwardTsn(TimeMs now); + + // Creates a FORWARD-TSN chunk. + ForwardTsnChunk CreateForwardTsn() const { + return outstanding_data_.CreateForwardTsn(); + } + + // Creates an I-FORWARD-TSN chunk. + IForwardTsnChunk CreateIForwardTsn() const { + return outstanding_data_.CreateIForwardTsn(); + } + + // See the SendQueue for a longer description of these methods related + // to stream resetting. + void PrepareResetStream(StreamID stream_id); + bool HasStreamsReadyToBeReset() const; + std::vector<StreamID> GetStreamsReadyToBeReset() const { + return send_queue_.GetStreamsReadyToBeReset(); + } + void CommitResetStreams(); + void RollbackResetStreams(); + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); + + private: + enum class CongestionAlgorithmPhase { + kSlowStart, + kCongestionAvoidance, + }; + + bool IsConsistent() const; + + // Returns how large a chunk will be, serialized, carrying the data + size_t GetSerializedChunkSize(const Data& data) const; + + // Indicates if the congestion control algorithm is in "fast recovery". + bool is_in_fast_recovery() const { + return fast_recovery_exit_tsn_.has_value(); + } + + // Indicates if the provided SACK is valid given what has previously been + // received. If it returns false, the SACK is most likely a duplicate of + // something already seen, so this returning false doesn't necessarily mean + // that the SACK is illegal. + bool IsSackValid(const SackChunk& sack) const; + + // When a SACK chunk is received, this method will be called which _may_ call + // into the `RetransmissionTimeout` to update the RTO. + void UpdateRTT(TimeMs now, UnwrappedTSN cumulative_tsn_ack); + + // If the congestion control is in "fast recovery mode", this may be exited + // now. + void MaybeExitFastRecovery(UnwrappedTSN cumulative_tsn_ack); + + // If chunks have been ACKed, stop the retransmission timer. + void StopT3RtxTimerOnIncreasedCumulativeTsnAck( + UnwrappedTSN cumulative_tsn_ack); + + // Update the congestion control algorithm given as the cumulative ack TSN + // value has increased, as reported in an incoming SACK chunk. + void HandleIncreasedCumulativeTsnAck(size_t outstanding_bytes, + size_t total_bytes_acked); + // Update the congestion control algorithm, given as packet loss has been + // detected, as reported in an incoming SACK chunk. + void HandlePacketLoss(UnwrappedTSN highest_tsn_acked); + // Update the view of the receiver window size. + void UpdateReceiverWindow(uint32_t a_rwnd); + // If there is data sent and not ACKED, ensure that the retransmission timer + // is running. + void StartT3RtxTimerIfOutstandingData(); + + // Returns the current congestion control algorithm phase. + CongestionAlgorithmPhase phase() const { + return (cwnd_ <= ssthresh_) + ? CongestionAlgorithmPhase::kSlowStart + : CongestionAlgorithmPhase::kCongestionAvoidance; + } + + // Returns the number of bytes that may be sent in a single packet according + // to the congestion control algorithm. + size_t max_bytes_to_send() const; + + DcSctpSocketCallbacks& callbacks_; + const DcSctpOptions options_; + // The minimum bytes required to be available in the congestion window to + // allow packets to be sent - to avoid sending too small packets. + const size_t min_bytes_required_to_send_; + // If the peer supports RFC3758 - SCTP Partial Reliability Extension. + const bool partial_reliability_; + const std::string log_prefix_; + // The size of the data chunk (DATA/I-DATA) header that is used. + const size_t data_chunk_header_size_; + // Called when a new RTT measurement has been done + const std::function<void(DurationMs rtt)> on_new_rtt_; + // Called when a SACK has been seen that cleared the retransmission counter. + const std::function<void()> on_clear_retransmission_counter_; + // The retransmission counter. + Timer& t3_rtx_; + // Unwraps TSNs + UnwrappedTSN::Unwrapper tsn_unwrapper_; + + // Congestion Window. Number of bytes that may be in-flight (sent, not acked). + size_t cwnd_; + // Receive Window. Number of bytes available in the receiver's RX buffer. + size_t rwnd_; + // Slow Start Threshold. See RFC4960. + size_t ssthresh_; + // Partial Bytes Acked. See RFC4960. + size_t partial_bytes_acked_; + // If set, fast recovery is enabled until this TSN has been cumulative + // acked. + absl::optional<UnwrappedTSN> fast_recovery_exit_tsn_ = absl::nullopt; + + // The send queue. + SendQueue& send_queue_; + // All the outstanding data chunks that are in-flight and that have not been + // cumulative acked. Note that it also contains chunks that have been acked in + // gap ack blocks. + OutstandingData outstanding_data_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RETRANSMISSION_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue_test.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue_test.cc new file mode 100644 index 0000000000..e62c030bfa --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue_test.cc @@ -0,0 +1,1593 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_queue.h" + +#include <cstddef> +#include <cstdint> +#include <functional> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/data_generator.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/mock_send_queue.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; +using State = ::dcsctp::RetransmissionQueue::State; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::NiceMock; +using ::testing::Pair; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr uint32_t kArwnd = 100000; +constexpr uint32_t kMaxMtu = 1191; + +DcSctpOptions MakeOptions() { + DcSctpOptions options; + options.mtu = kMaxMtu; + return options; +} + +class RetransmissionQueueTest : public testing::Test { + protected: + RetransmissionQueueTest() + : options_(MakeOptions()), + gen_(MID(42)), + timeout_manager_([this]() { return now_; }), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return timeout_manager_.CreateTimeout(precision); + }), + timer_(timer_manager_.CreateTimer( + "test/t3_rtx", + []() { return absl::nullopt; }, + TimerOptions(options_.rto_initial))) {} + + std::function<SendQueue::DataToSend(TimeMs, size_t)> CreateChunk() { + return [this](TimeMs now, size_t max_size) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }; + } + + std::vector<TSN> GetTSNsForFastRetransmit(RetransmissionQueue& queue) { + std::vector<TSN> tsns; + for (const auto& elem : queue.GetChunksForFastRetransmit(10000)) { + tsns.push_back(elem.first); + } + return tsns; + } + + std::vector<TSN> GetSentPacketTSNs(RetransmissionQueue& queue) { + std::vector<TSN> tsns; + for (const auto& elem : queue.GetChunksToSend(now_, 10000)) { + tsns.push_back(elem.first); + } + return tsns; + } + + RetransmissionQueue CreateQueue(bool supports_partial_reliability = true, + bool use_message_interleaving = false) { + return RetransmissionQueue( + "", &callbacks_, TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), + on_clear_retransmission_counter_.AsStdFunction(), *timer_, options_, + supports_partial_reliability, use_message_interleaving); + } + + std::unique_ptr<RetransmissionQueue> CreateQueueByHandover( + RetransmissionQueue& queue) { + EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus()); + DcSctpSocketHandoverState state; + queue.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + auto queue2 = std::make_unique<RetransmissionQueue>( + "", &callbacks_, TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), + on_clear_retransmission_counter_.AsStdFunction(), *timer_, options_, + /*supports_partial_reliability=*/true, + /*use_message_interleaving=*/false); + queue2->RestoreFromState(state); + return queue2; + } + + MockDcSctpSocketCallbacks callbacks_; + DcSctpOptions options_; + DataGenerator gen_; + TimeMs now_ = TimeMs(0); + FakeTimeoutManager timeout_manager_; + TimerManager timer_manager_; + NiceMock<MockFunction<void(DurationMs rtt_ms)>> on_rtt_; + NiceMock<MockFunction<void()>> on_clear_retransmission_counter_; + NiceMock<MockSendQueue> producer_; + std::unique_ptr<Timer> timer_; +}; + +TEST_F(RetransmissionQueueTest, InitialAckedPrevTsn) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, SendOneChunk) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(10))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, SendOneChunkAndAck) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(10))); + + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, SendThreeChunksAndAckTwo) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12))); + + queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, AckWithGapBlocksFromRFC4960Section334) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 5)}, + {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kNacked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kNacked), // + Pair(TSN(17), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, ResendPacketsWhenNackedThreeTimes) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + // Send more chunks, but leave some as gaps to force retransmission after + // three NACKs. + + // Send 18 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(18))); + + // Ack 12, 14-15, 17-18 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 6)}, + {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kNacked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kNacked), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked))); + + // Send 19 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(19))); + + // Ack 12, 14-15, 17-19 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 7)}, + {})); + + // Send 20 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(20))); + + // Ack 12, 14-15, 17-20 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 8)}, + {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kToBeRetransmitted), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kToBeRetransmitted), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kAcked), // + Pair(TSN(20), State::kAcked))); + + // This will trigger "fast retransmit" mode and only chunks 13 and 16 will be + // resent right now. The send queue will not even be queried. + EXPECT_CALL(producer_, Produce).Times(0); + + EXPECT_THAT(GetTSNsForFastRetransmit(queue), + testing::ElementsAre(TSN(13), TSN(16))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kAcked), // + Pair(TSN(20), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, RestartsT3RtxOnRetransmitFirstOutstandingTSN) { + // Verifies that if fast retransmit is retransmitting the first outstanding + // TSN, it will also restart T3-RTX. + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + static constexpr TimeMs kStartTime(100000); + now_ = kStartTime; + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12))); + + // Ack 10, 12, after 100ms. + now_ += DurationMs(100); + queue.HandleSack( + now_, SackChunk(TSN(10), kArwnd, {SackChunk::GapAckBlock(2, 2)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kNacked), // + Pair(TSN(12), State::kAcked))); + + // Send 13 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(13))); + + // Ack 10, 12-13, after 100ms. + now_ += DurationMs(100); + queue.HandleSack( + now_, SackChunk(TSN(10), kArwnd, {SackChunk::GapAckBlock(2, 3)}, {})); + + // Send 14 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(14))); + + // Ack 10, 12-14, after 100 ms. + now_ += DurationMs(100); + queue.HandleSack( + now_, SackChunk(TSN(10), kArwnd, {SackChunk::GapAckBlock(2, 4)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kToBeRetransmitted), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked))); + + // This will trigger "fast retransmit" mode and only chunks 13 and 16 will be + // resent right now. The send queue will not even be queried. + EXPECT_CALL(producer_, Produce).Times(0); + + EXPECT_THAT(GetTSNsForFastRetransmit(queue), testing::ElementsAre(TSN(11))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked))); + + // Verify that the timer was really restarted when fast-retransmitting. The + // timeout is `options_.rto_initial`, so advance the time just before that. + now_ += options_.rto_initial - DurationMs(1); + EXPECT_FALSE(timeout_manager_.GetNextExpiredTimeout().has_value()); + + // And ensure it really is running. + now_ += DurationMs(1); + ASSERT_HAS_VALUE_AND_ASSIGN(TimeoutID timeout, + timeout_manager_.GetNextExpiredTimeout()); + // An expired timeout has to be handled (asserts validate this). + timer_manager_.HandleTimeout(timeout); +} + +TEST_F(RetransmissionQueueTest, CanOnlyProduceTwoPacketsButWantsToSendThree) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }) + .WillOnce([this](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, RetransmitsOnT3Expiry) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + // Will force chunks to be retransmitted + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + + std::vector<std::pair<TSN, Data>> chunks_to_rtx = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_rtx, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, LimitedRetransmissionOnlyWithRfc3758Support) { + RetransmissionQueue queue = + CreateQueue(/*supports_partial_reliability=*/false); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + // Will force chunks to be retransmitted + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); +} // namespace dcsctp + +TEST_F(RetransmissionQueueTest, LimitsRetransmissionsAsUdp) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + // Will force chunks to be retransmitted + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(1); + + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned))); + + std::vector<std::pair<TSN, Data>> chunks_to_rtx = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_rtx, testing::IsEmpty()); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned))); +} + +TEST_F(RetransmissionQueueTest, LimitsRetransmissionsToThreeSends) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(3); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + + // Retransmission 1 + queue.HandleT3RtxTimerExpiry(); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), SizeIs(1)); + + // Retransmission 2 + queue.HandleT3RtxTimerExpiry(); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), SizeIs(1)); + + // Retransmission 3 + queue.HandleT3RtxTimerExpiry(); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), SizeIs(1)); + + // Retransmission 4 - not allowed. + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(1); + queue.HandleT3RtxTimerExpiry(); + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), IsEmpty()); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned))); +} + +TEST_F(RetransmissionQueueTest, RetransmitsWhenSendBufferIsFullT3Expiry) { + RetransmissionQueue queue = CreateQueue(); + static constexpr size_t kCwnd = 1200; + queue.set_cwnd(kCwnd); + EXPECT_EQ(queue.cwnd(), kCwnd); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + std::vector<uint8_t> payload(1000); + EXPECT_CALL(producer_, Produce) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1500); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + EXPECT_EQ(queue.outstanding_bytes(), payload.size() + DataChunk::kHeaderSize); + EXPECT_EQ(queue.outstanding_items(), 1u); + + // Will force chunks to be retransmitted + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + std::vector<std::pair<TSN, Data>> chunks_to_rtx = + queue.GetChunksToSend(now_, 1500); + EXPECT_THAT(chunks_to_rtx, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + EXPECT_EQ(queue.outstanding_bytes(), payload.size() + DataChunk::kHeaderSize); + EXPECT_EQ(queue.outstanding_items(), 1u); +} + +TEST_F(RetransmissionQueueTest, ProducesValidForwardTsn) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Send and ack first chunk (TSN 10) + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight))); + + // Chunk 10 is acked, but the remaining are lost + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(true)); + + queue.HandleT3RtxTimerExpiry(); + + // NOTE: The TSN=13 represents the end fragment. + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + ForwardTsnChunk forward_tsn = queue.CreateForwardTsn(); + EXPECT_EQ(forward_tsn.new_cumulative_tsn(), TSN(13)); + EXPECT_THAT(forward_tsn.skipped_streams(), + UnorderedElementsAre( + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(42)))); +} + +TEST_F(RetransmissionQueueTest, ProducesValidForwardTsnWhenFullySent) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "E")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Send and ack first chunk (TSN 10) + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight))); + + // Chunk 10 is acked, but the remaining are lost + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + ForwardTsnChunk forward_tsn = queue.CreateForwardTsn(); + EXPECT_EQ(forward_tsn.new_cumulative_tsn(), TSN(12)); + EXPECT_THAT(forward_tsn.skipped_streams(), + UnorderedElementsAre( + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(42)))); +} + +TEST_F(RetransmissionQueueTest, ProducesValidIForwardTsn) { + RetransmissionQueue queue = CreateQueue(/*use_message_interleaving=*/true); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(1); + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B", opts)); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(2); + SendQueue::DataToSend dts(gen_.Unordered({1, 2, 3, 4}, "B", opts)); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(3); + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "B", opts)); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(4); + SendQueue::DataToSend dts(gen_.Ordered({13, 14, 15, 16}, "B", opts)); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _), Pair(TSN(13), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + // Chunk 13 is acked, but the remaining are lost + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(4, 4)}, {})); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kNacked), // + Pair(TSN(12), State::kNacked), // + Pair(TSN(13), State::kAcked))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(true)); + EXPECT_CALL(producer_, Discard(IsUnordered(true), StreamID(2), MID(42))) + .WillOnce(Return(true)); + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(3), MID(42))) + .WillOnce(Return(true)); + + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAcked), + // Representing end fragments of stream 1-3 + Pair(TSN(14), State::kAbandoned), // + Pair(TSN(15), State::kAbandoned), // + Pair(TSN(16), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + IForwardTsnChunk forward_tsn1 = queue.CreateIForwardTsn(); + EXPECT_EQ(forward_tsn1.new_cumulative_tsn(), TSN(12)); + EXPECT_THAT( + forward_tsn1.skipped_streams(), + UnorderedElementsAre(IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(1), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(true), StreamID(2), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(3), MID(42)))); + + // When TSN 13 is acked, the placeholder "end fragments" must be skipped as + // well. + + // A receiver is more likely to ack TSN 13, but do it incrementally. + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {})); + + EXPECT_CALL(producer_, Discard).Times(0); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + queue.HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {})); + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAbandoned), // + Pair(TSN(15), State::kAbandoned), // + Pair(TSN(16), State::kAbandoned))); + + IForwardTsnChunk forward_tsn2 = queue.CreateIForwardTsn(); + EXPECT_EQ(forward_tsn2.new_cumulative_tsn(), TSN(16)); + EXPECT_THAT( + forward_tsn2.skipped_streams(), + UnorderedElementsAre(IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(1), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(true), StreamID(2), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(3), MID(42)))); +} + +TEST_F(RetransmissionQueueTest, MeasureRTT) { + RetransmissionQueue queue = CreateQueue(/*use_message_interleaving=*/true); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + now_ = now_ + DurationMs(123); + + EXPECT_CALL(on_rtt_, Call(DurationMs(123))).Times(1); + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); +} + +TEST_F(RetransmissionQueueTest, ValidateCumTsnAtRest) { + RetransmissionQueue queue = CreateQueue(/*use_message_interleaving=*/true); + + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(8), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(9), kArwnd, {}, {}))); + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {}))); +} + +TEST_F(RetransmissionQueueTest, ValidateCumTsnAckOnInflightData) { + RetransmissionQueue queue = CreateQueue(); + + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(8), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(9), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(14), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(15), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(16), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(17), kArwnd, {}, {}))); + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(18), kArwnd, {}, {}))); +} + +TEST_F(RetransmissionQueueTest, HandleGapAckBlocksMatchingNoInflightData) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + // Ack 9, 20-25. This is an invalid SACK, but should still be handled. + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(11, 16)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, HandleInvalidGapAckBlocks) { + RetransmissionQueue queue = CreateQueue(); + + // Nothing produced - nothing in retransmission queue + + // Ack 9, 12-13 + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(3, 4)}, {})); + + // Gap ack blocks are just ignore. + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, GapAckBlocksDoNotMoveCumTsnAck) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + // Ack 9, 10-14. This is actually an invalid ACK as the first gap can't be + // adjacent to the cum-tsn-ack, but it's not strictly forbidden. However, the + // cum-tsn-ack should not move, as the gap-ack-blocks are just advisory. + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(1, 5)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, StaysWithinAvailableSize) { + RetransmissionQueue queue = CreateQueue(); + + // See SctpPacketTest::ReturnsCorrectSpaceAvailableToStayWithinMTU for the + // magic numbers in this test. + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t size) { + EXPECT_EQ(size, 1176 - DataChunk::kHeaderSize); + + std::vector<uint8_t> payload(183); + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }) + .WillOnce([this](TimeMs, size_t size) { + EXPECT_EQ(size, 976 - DataChunk::kHeaderSize); + + std::vector<uint8_t> payload(957); + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1188 - 12); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _))); +} + +TEST_F(RetransmissionQueueTest, AccountsNackedAbandonedChunksAsNotOutstanding) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Send and ack first chunk (TSN 10) + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight))); + EXPECT_EQ(queue.outstanding_bytes(), (16 + 4) * 3u); + EXPECT_EQ(queue.outstanding_items(), 3u); + + // Mark the message as lost. + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(1); + queue.HandleT3RtxTimerExpiry(); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned))); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + // Now ACK those, one at a time. + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {})); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {})); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); +} + +TEST_F(RetransmissionQueueTest, ExpireFromSendQueueWhenPartiallySent) { + RetransmissionQueue queue = CreateQueue(); + DataGeneratorOptions options; + options.stream_id = StreamID(17); + options.message_id = MID(42); + TimeMs test_start = now_; + EXPECT_CALL(producer_, Produce) + .WillOnce([&](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B", options)); + dts.expires_at = TimeMs(test_start + DurationMs(10)); + return dts; + }) + .WillOnce([&](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "", options)); + dts.expires_at = TimeMs(test_start + DurationMs(10)); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 24); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(17), MID(42))) + .WillOnce(Return(true)); + now_ += DurationMs(100); + + EXPECT_THAT(queue.GetChunksToSend(now_, 24), IsEmpty()); + + EXPECT_THAT( + queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // Initial TSN + Pair(TSN(10), State::kAbandoned), // Produced + Pair(TSN(11), State::kAbandoned), // Produced and expired + Pair(TSN(12), State::kAbandoned))); // Placeholder end +} + +TEST_F(RetransmissionQueueTest, LimitsRetransmissionsOnlyWhenNackedThreeTimes) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _), Pair(TSN(13), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 2)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 3)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 4)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); +} + +TEST_F(RetransmissionQueueTest, AbandonsRtxLimit2WhenNackedNineTimes) { + // This is a fairly long test. + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(2); + return dts; + }) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, + ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), Pair(TSN(12), _), + Pair(TSN(13), _), Pair(TSN(14), _), Pair(TSN(15), _), + Pair(TSN(16), _), Pair(TSN(17), _), Pair(TSN(18), _), + Pair(TSN(19), _))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight), // + Pair(TSN(18), State::kInFlight), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + + // Ack TSN [11 to 13] - three nacks for TSN(10), which will retransmit it. + for (int tsn = 11; tsn <= 13; ++tsn) { + queue.HandleSack( + now_, + SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, (tsn - 9))}, {})); + } + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight), // + Pair(TSN(18), State::kInFlight), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_THAT(queue.GetChunksForFastRetransmit(1000), + ElementsAre(Pair(TSN(10), _))); + + // Ack TSN [14 to 16] - three more nacks - second and last retransmission. + for (int tsn = 14; tsn <= 16; ++tsn) { + queue.HandleSack( + now_, + SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, (tsn - 9))}, {})); + } + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kAcked), // + Pair(TSN(17), State::kInFlight), // + Pair(TSN(18), State::kInFlight), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), ElementsAre(Pair(TSN(10), _))); + + // Ack TSN [17 to 18] + for (int tsn = 17; tsn <= 18; ++tsn) { + queue.HandleSack( + now_, + SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, (tsn - 9))}, {})); + } + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kAcked), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + // Ack TSN 19 - three more nacks for TSN 10, no more retransmissions. + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 10)}, {})); + + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), IsEmpty()); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kAcked), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kAcked))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); +} + +TEST_F(RetransmissionQueueTest, CwndRecoversWhenAcking) { + RetransmissionQueue queue = CreateQueue(); + static constexpr size_t kCwnd = 1200; + queue.set_cwnd(kCwnd); + EXPECT_EQ(queue.cwnd(), kCwnd); + + std::vector<uint8_t> payload(1000); + EXPECT_CALL(producer_, Produce) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1500); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + size_t serialized_size = payload.size() + DataChunk::kHeaderSize; + EXPECT_EQ(queue.outstanding_bytes(), serialized_size); + + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_EQ(queue.cwnd(), kCwnd + serialized_size); +} + +// Verifies that it doesn't produce tiny packets, when getting close to +// the full congestion window. +TEST_F(RetransmissionQueueTest, OnlySendsLargePacketsOnLargeCongestionWindow) { + RetransmissionQueue queue = CreateQueue(); + size_t intial_cwnd = options_.avoid_fragmentation_cwnd_mtus * options_.mtu; + queue.set_cwnd(intial_cwnd); + EXPECT_EQ(queue.cwnd(), intial_cwnd); + + // Fill the congestion window almost - leaving 500 bytes. + size_t chunk_size = intial_cwnd - 500; + EXPECT_CALL(producer_, Produce) + .WillOnce([chunk_size, this](TimeMs, size_t) { + return SendQueue::DataToSend( + gen_.Ordered(std::vector<uint8_t>(chunk_size), "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_TRUE(queue.can_send_data()); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 10000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + // To little space left - will not send more. + EXPECT_FALSE(queue.can_send_data()); + + // But when the first chunk is acked, it will continue. + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_TRUE(queue.can_send_data()); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.cwnd(), intial_cwnd + kMaxMtu); +} + +TEST_F(RetransmissionQueueTest, AllowsSmallFragmentsOnSmallCongestionWindow) { + RetransmissionQueue queue = CreateQueue(); + size_t intial_cwnd = + options_.avoid_fragmentation_cwnd_mtus * options_.mtu - 1; + queue.set_cwnd(intial_cwnd); + EXPECT_EQ(queue.cwnd(), intial_cwnd); + + // Fill the congestion window almost - leaving 500 bytes. + size_t chunk_size = intial_cwnd - 500; + EXPECT_CALL(producer_, Produce) + .WillOnce([chunk_size, this](TimeMs, size_t) { + return SendQueue::DataToSend( + gen_.Ordered(std::vector<uint8_t>(chunk_size), "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_TRUE(queue.can_send_data()); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 10000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + // With congestion window under limit, allow small packets to be created. + EXPECT_TRUE(queue.can_send_data()); +} + +TEST_F(RetransmissionQueueTest, ReadyForHandoverWhenHasNoOutstandingData) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1)); + EXPECT_EQ( + queue.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kRetransmissionQueueOutstandingData)); + + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus()); +} + +TEST_F(RetransmissionQueueTest, ReadyForHandoverWhenNothingToRetransmit) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(8)); + EXPECT_EQ( + queue.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kRetransmissionQueueOutstandingData)); + + // Send more chunks, but leave some chunks unacked to force retransmission + // after three NACKs. + + // Send 18 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1)); + + // Ack 12, 14-15, 17-18 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 6)}, + {})); + + // Send 19 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1)); + + // Ack 12, 14-15, 17-19 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 7)}, + {})); + + // Send 20 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1)); + + // Ack 12, 14-15, 17-20 + // This will trigger "fast retransmit" mode and only chunks 13 and 16 will be + // resent right now. The send queue will not even be queried. + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 8)}, + {})); + EXPECT_EQ( + queue.GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kRetransmissionQueueOutstandingData) + .Add(HandoverUnreadinessReason::kRetransmissionQueueFastRecovery) + .Add(HandoverUnreadinessReason::kRetransmissionQueueNotEmpty)); + + // Send "fast retransmit" mode chunks + EXPECT_CALL(producer_, Produce).Times(0); + EXPECT_THAT(GetTSNsForFastRetransmit(queue), SizeIs(2)); + EXPECT_EQ( + queue.GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kRetransmissionQueueOutstandingData) + .Add(HandoverUnreadinessReason::kRetransmissionQueueFastRecovery)); + + // Ack 20 to confirm the retransmission + queue.HandleSack(now_, SackChunk(TSN(20), kArwnd, {}, {})); + EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus()); +} + +TEST_F(RetransmissionQueueTest, HandoverTest) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(2)); + queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {})); + + std::unique_ptr<RetransmissionQueue> handedover_queue = + CreateQueueByHandover(queue); + + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(*handedover_queue), + testing::ElementsAre(TSN(12), TSN(13), TSN(14))); + + handedover_queue->HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {})); + EXPECT_THAT(handedover_queue->GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, CanAlwaysSendOnePacket) { + RetransmissionQueue queue = CreateQueue(); + + // A large payload - enough to not fit two DATA in same packet. + size_t mtu = RoundDownTo4(options_.mtu); + std::vector<uint8_t> payload(mtu - 100); + + EXPECT_CALL(producer_, Produce) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "B")); + }) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "")); + }) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "")); + }) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "")); + }) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "E")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Produce all chunks and put them in the retransmission queue. + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 5 * mtu); + EXPECT_THAT(chunks_to_send, + ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), Pair(TSN(12), _), + Pair(TSN(13), _), Pair(TSN(14), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), + Pair(TSN(13), State::kInFlight), + Pair(TSN(14), State::kInFlight))); + + // Ack 12, and report an empty receiver window (the peer obviously has a + // tiny receive window). + queue.HandleSack( + now_, SackChunk(TSN(9), /*rwnd=*/0, {SackChunk::GapAckBlock(3, 3)}, {})); + + // Force TSN 10 to be retransmitted. + queue.HandleT3RtxTimerExpiry(); + + // Even if the receiver window is empty, it will allow TSN 10 to be sent. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), ElementsAre(Pair(TSN(10), _))); + + // But not more than that, as there now is outstanding data. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), IsEmpty()); + + // Don't ack any new data, and still have receiver window zero. + queue.HandleSack( + now_, SackChunk(TSN(9), /*rwnd=*/0, {SackChunk::GapAckBlock(3, 3)}, {})); + + // There is in-flight data, so new data should not be allowed to be send since + // the receiver window is full. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), IsEmpty()); + + // Ack that packet (no more in-flight data), but still report an empty + // receiver window. + queue.HandleSack( + now_, SackChunk(TSN(10), /*rwnd=*/0, {SackChunk::GapAckBlock(2, 2)}, {})); + + // Then TSN 11 can be sent, as there is no in-flight data. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), ElementsAre(Pair(TSN(11), _))); + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), IsEmpty()); + + // Ack and recover the receiver window + queue.HandleSack(now_, SackChunk(TSN(12), /*rwnd=*/5 * mtu, {}, {})); + + // That will unblock sending remaining chunks. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), ElementsAre(Pair(TSN(13), _))); + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), ElementsAre(Pair(TSN(14), _))); + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), IsEmpty()); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.cc new file mode 100644 index 0000000000..7d8fb9761c --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.cc @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_timeout.h" + +#include <algorithm> +#include <cstdint> + +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +RetransmissionTimeout::RetransmissionTimeout(const DcSctpOptions& options) + : min_rto_(*options.rto_min), + max_rto_(*options.rto_max), + max_rtt_(*options.rtt_max), + min_rtt_variance_(*options.min_rtt_variance), + scaled_srtt_(*options.rto_initial << kRttShift), + rto_(*options.rto_initial) {} + +void RetransmissionTimeout::ObserveRTT(DurationMs measured_rtt) { + const int32_t rtt = *measured_rtt; + + // Unrealistic values will be skipped. If a wrongly measured (or otherwise + // corrupt) value was processed, it could change the state in a way that would + // take a very long time to recover. + if (rtt < 0 || rtt > max_rtt_) { + return; + } + + // From https://tools.ietf.org/html/rfc4960#section-6.3.1, but avoiding + // floating point math by implementing algorithm from "V. Jacobson: Congestion + // avoidance and control", but adapted for SCTP. + if (first_measurement_) { + scaled_srtt_ = rtt << kRttShift; + scaled_rtt_var_ = (rtt / 2) << kRttVarShift; + first_measurement_ = false; + } else { + int32_t rtt_diff = rtt - (scaled_srtt_ >> kRttShift); + scaled_srtt_ += rtt_diff; + if (rtt_diff < 0) { + rtt_diff = -rtt_diff; + } + rtt_diff -= (scaled_rtt_var_ >> kRttVarShift); + scaled_rtt_var_ += rtt_diff; + } + + if (scaled_rtt_var_ < min_rtt_variance_) { + scaled_rtt_var_ = min_rtt_variance_; + } + + rto_ = (scaled_srtt_ >> kRttShift) + scaled_rtt_var_; + + // Clamp RTO between min and max. + rto_ = std::min(std::max(rto_, min_rto_), max_rto_); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.h b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.h new file mode 100644 index 0000000000..01530cb3b5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RETRANSMISSION_TIMEOUT_H_ +#define NET_DCSCTP_TX_RETRANSMISSION_TIMEOUT_H_ + +#include <cstdint> +#include <functional> + +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +// Manages updating of the Retransmission Timeout (RTO) SCTP variable, which is +// used directly as the base timeout for T3-RTX and for other timers, such as +// delayed ack. +// +// When a round-trip-time (RTT) is calculated (outside this class), `Observe` +// is called, which calculates the retransmission timeout (RTO) value. The RTO +// value will become larger if the RTT is high and/or the RTT values are varying +// a lot, which is an indicator of a bad connection. +class RetransmissionTimeout { + public: + static constexpr int kRttShift = 3; + static constexpr int kRttVarShift = 2; + explicit RetransmissionTimeout(const DcSctpOptions& options); + + // To be called when a RTT has been measured, to update the RTO value. + void ObserveRTT(DurationMs measured_rtt); + + // Returns the Retransmission Timeout (RTO) value, in milliseconds. + DurationMs rto() const { return DurationMs(rto_); } + + // Returns the smoothed RTT value, in milliseconds. + DurationMs srtt() const { return DurationMs(scaled_srtt_ >> kRttShift); } + + private: + const int32_t min_rto_; + const int32_t max_rto_; + const int32_t max_rtt_; + const int32_t min_rtt_variance_; + // If this is the first measurement + bool first_measurement_ = true; + // Smoothed Round-Trip Time, shifted by kRttShift + int32_t scaled_srtt_; + // Round-Trip Time Variation, shifted by kRttVarShift + int32_t scaled_rtt_var_ = 0; + // Retransmission Timeout + int32_t rto_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RETRANSMISSION_TIMEOUT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout_test.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout_test.cc new file mode 100644 index 0000000000..b901995e97 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout_test.cc @@ -0,0 +1,180 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_timeout.h" + +#include "net/dcsctp/public/dcsctp_options.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +constexpr DurationMs kMaxRtt = DurationMs(8'000); +constexpr DurationMs kInitialRto = DurationMs(200); +constexpr DurationMs kMaxRto = DurationMs(800); +constexpr DurationMs kMinRto = DurationMs(120); +constexpr DurationMs kMinRttVariance = DurationMs(220); + +DcSctpOptions MakeOptions() { + DcSctpOptions options; + options.rtt_max = kMaxRtt; + options.rto_initial = kInitialRto; + options.rto_max = kMaxRto; + options.rto_min = kMinRto; + options.min_rtt_variance = kMinRttVariance; + return options; +} + +TEST(RetransmissionTimeoutTest, HasValidInitialRto) { + RetransmissionTimeout rto_(MakeOptions()); + EXPECT_EQ(rto_.rto(), kInitialRto); +} + +TEST(RetransmissionTimeoutTest, HasValidInitialSrtt) { + RetransmissionTimeout rto_(MakeOptions()); + EXPECT_EQ(rto_.srtt(), kInitialRto); +} + +TEST(RetransmissionTimeoutTest, NegativeValuesDoNotAffectRTO) { + RetransmissionTimeout rto_(MakeOptions()); + // Initial negative value + rto_.ObserveRTT(DurationMs(-10)); + EXPECT_EQ(rto_.rto(), kInitialRto); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + // Subsequent negative value + rto_.ObserveRTT(DurationMs(-10)); + EXPECT_EQ(*rto_.rto(), 372); +} + +TEST(RetransmissionTimeoutTest, TooLargeValuesDoNotAffectRTO) { + RetransmissionTimeout rto_(MakeOptions()); + // Initial too large value + rto_.ObserveRTT(kMaxRtt + DurationMs(100)); + EXPECT_EQ(rto_.rto(), kInitialRto); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + // Subsequent too large value + rto_.ObserveRTT(kMaxRtt + DurationMs(100)); + EXPECT_EQ(*rto_.rto(), 372); +} + +TEST(RetransmissionTimeoutTest, WillNeverGoBelowMinimumRto) { + RetransmissionTimeout rto_(MakeOptions()); + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(DurationMs(1)); + } + EXPECT_GE(rto_.rto(), kMinRto); +} + +TEST(RetransmissionTimeoutTest, WillNeverGoAboveMaximumRto) { + RetransmissionTimeout rto_(MakeOptions()); + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(kMaxRtt - DurationMs(1)); + // Adding jitter, which would make it RTO be well above RTT. + rto_.ObserveRTT(kMaxRtt - DurationMs(100)); + } + EXPECT_LE(rto_.rto(), kMaxRto); +} + +TEST(RetransmissionTimeoutTest, CalculatesRtoForStableRtt) { + RetransmissionTimeout rto_(MakeOptions()); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + rto_.ObserveRTT(DurationMs(128)); + EXPECT_EQ(*rto_.rto(), 344); + rto_.ObserveRTT(DurationMs(123)); + EXPECT_EQ(*rto_.rto(), 344); + rto_.ObserveRTT(DurationMs(125)); + EXPECT_EQ(*rto_.rto(), 344); + rto_.ObserveRTT(DurationMs(127)); + EXPECT_EQ(*rto_.rto(), 344); +} + +TEST(RetransmissionTimeoutTest, CalculatesRtoForUnstableRtt) { + RetransmissionTimeout rto_(MakeOptions()); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + rto_.ObserveRTT(DurationMs(402)); + EXPECT_EQ(*rto_.rto(), 622); + rto_.ObserveRTT(DurationMs(728)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(89)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(126)); + EXPECT_EQ(*rto_.rto(), 800); +} + +TEST(RetransmissionTimeoutTest, WillStabilizeAfterAWhile) { + RetransmissionTimeout rto_(MakeOptions()); + rto_.ObserveRTT(DurationMs(124)); + rto_.ObserveRTT(DurationMs(402)); + rto_.ObserveRTT(DurationMs(728)); + rto_.ObserveRTT(DurationMs(89)); + rto_.ObserveRTT(DurationMs(126)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(122)); + EXPECT_EQ(*rto_.rto(), 710); + rto_.ObserveRTT(DurationMs(123)); + EXPECT_EQ(*rto_.rto(), 631); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 562); + rto_.ObserveRTT(DurationMs(122)); + EXPECT_EQ(*rto_.rto(), 505); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 454); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 410); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 367); +} + +TEST(RetransmissionTimeoutTest, WillAlwaysStayAboveRTT) { + // In simulations, it's quite common to have a very stable RTT, and having an + // RTO at the same value will cause issues as expiry timers will be scheduled + // to be expire exactly when a packet is supposed to arrive. The RTO must be + // larger than the RTT. In non-simulated environments, this is a non-issue as + // any jitter will increase the RTO. + RetransmissionTimeout rto_(MakeOptions()); + + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(DurationMs(124)); + } + EXPECT_EQ(*rto_.rto(), 344); +} + +TEST(RetransmissionTimeoutTest, CanSpecifySmallerMinimumRttVariance) { + DcSctpOptions options = MakeOptions(); + options.min_rtt_variance = kMinRttVariance - DurationMs(100); + RetransmissionTimeout rto_(options); + + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(DurationMs(124)); + } + EXPECT_EQ(*rto_.rto(), 244); +} + +TEST(RetransmissionTimeoutTest, CanSpecifyLargerMinimumRttVariance) { + DcSctpOptions options = MakeOptions(); + options.min_rtt_variance = kMinRttVariance + DurationMs(100); + RetransmissionTimeout rto_(options); + + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(DurationMs(124)); + } + EXPECT_EQ(*rto_.rto(), 444); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc new file mode 100644 index 0000000000..b1812f0f8a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc @@ -0,0 +1,542 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/rr_send_queue.h" + +#include <cstdint> +#include <deque> +#include <limits> +#include <map> +#include <set> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +RRSendQueue::RRSendQueue(absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, + size_t buffer_size, + size_t mtu, + StreamPriority default_priority, + size_t total_buffered_amount_low_threshold) + : log_prefix_(std::string(log_prefix) + "fcfs: "), + callbacks_(*callbacks), + buffer_size_(buffer_size), + default_priority_(default_priority), + scheduler_(mtu), + total_buffered_amount_( + [this]() { callbacks_.OnTotalBufferedAmountLow(); }) { + total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold); +} + +size_t RRSendQueue::OutgoingStream::bytes_to_send_in_next_message() const { + if (pause_state_ == PauseState::kPaused || + pause_state_ == PauseState::kResetting) { + // The stream has paused (and there is no partially sent message). + return 0; + } + + if (items_.empty()) { + return 0; + } + + return items_.front().remaining_size; +} + +void RRSendQueue::OutgoingStream::AddHandoverState( + DcSctpSocketHandoverState::OutgoingStream& state) const { + state.next_ssn = next_ssn_.value(); + state.next_ordered_mid = next_ordered_mid_.value(); + state.next_unordered_mid = next_unordered_mid_.value(); + state.priority = *scheduler_stream_->priority(); +} + +bool RRSendQueue::IsConsistent() const { + std::set<StreamID> expected_active_streams; + std::set<StreamID> actual_active_streams = + scheduler_.ActiveStreamsForTesting(); + + size_t total_buffered_amount = 0; + for (const auto& [stream_id, stream] : streams_) { + total_buffered_amount += stream.buffered_amount().value(); + if (stream.bytes_to_send_in_next_message() > 0) { + expected_active_streams.emplace(stream_id); + } + } + if (expected_active_streams != actual_active_streams) { + auto fn = [&](rtc::StringBuilder& sb, const auto& p) { sb << *p; }; + RTC_DLOG(LS_ERROR) << "Active streams mismatch, is=[" + << StrJoin(actual_active_streams, ",", fn) + << "], expected=[" + << StrJoin(expected_active_streams, ",", fn) << "]"; + return false; + } + + return total_buffered_amount == total_buffered_amount_.value(); +} + +bool RRSendQueue::OutgoingStream::IsConsistent() const { + size_t bytes = 0; + for (const auto& item : items_) { + bytes += item.remaining_size; + } + return bytes == buffered_amount_.value(); +} + +void RRSendQueue::ThresholdWatcher::Decrease(size_t bytes) { + RTC_DCHECK(bytes <= value_); + size_t old_value = value_; + value_ -= bytes; + + if (old_value > low_threshold_ && value_ <= low_threshold_) { + on_threshold_reached_(); + } +} + +void RRSendQueue::ThresholdWatcher::SetLowThreshold(size_t low_threshold) { + // Betting on https://github.com/w3c/webrtc-pc/issues/2654 being accepted. + if (low_threshold_ < value_ && low_threshold >= value_) { + on_threshold_reached_(); + } + low_threshold_ = low_threshold; +} + +void RRSendQueue::OutgoingStream::Add(DcSctpMessage message, + MessageAttributes attributes) { + bool was_active = bytes_to_send_in_next_message() > 0; + buffered_amount_.Increase(message.payload().size()); + parent_.total_buffered_amount_.Increase(message.payload().size()); + items_.emplace_back(std::move(message), std::move(attributes)); + + if (!was_active) { + scheduler_stream_->MaybeMakeActive(); + } + + RTC_DCHECK(IsConsistent()); +} + +absl::optional<SendQueue::DataToSend> RRSendQueue::OutgoingStream::Produce( + TimeMs now, + size_t max_size) { + RTC_DCHECK(pause_state_ != PauseState::kPaused && + pause_state_ != PauseState::kResetting); + + while (!items_.empty()) { + Item& item = items_.front(); + DcSctpMessage& message = item.message; + + // Allocate Message ID and SSN when the first fragment is sent. + if (!item.message_id.has_value()) { + // Oops, this entire message has already expired. Try the next one. + if (item.attributes.expires_at <= now) { + HandleMessageExpired(item); + items_.pop_front(); + continue; + } + + MID& mid = + item.attributes.unordered ? next_unordered_mid_ : next_ordered_mid_; + item.message_id = mid; + mid = MID(*mid + 1); + } + if (!item.attributes.unordered && !item.ssn.has_value()) { + item.ssn = next_ssn_; + next_ssn_ = SSN(*next_ssn_ + 1); + } + + // Grab the next `max_size` fragment from this message and calculate flags. + rtc::ArrayView<const uint8_t> chunk_payload = + item.message.payload().subview(item.remaining_offset, max_size); + rtc::ArrayView<const uint8_t> message_payload = message.payload(); + Data::IsBeginning is_beginning(chunk_payload.data() == + message_payload.data()); + Data::IsEnd is_end((chunk_payload.data() + chunk_payload.size()) == + (message_payload.data() + message_payload.size())); + + StreamID stream_id = message.stream_id(); + PPID ppid = message.ppid(); + + // Zero-copy the payload if the message fits in a single chunk. + std::vector<uint8_t> payload = + is_beginning && is_end + ? std::move(message).ReleasePayload() + : std::vector<uint8_t>(chunk_payload.begin(), chunk_payload.end()); + + FSN fsn(item.current_fsn); + item.current_fsn = FSN(*item.current_fsn + 1); + buffered_amount_.Decrease(payload.size()); + parent_.total_buffered_amount_.Decrease(payload.size()); + + SendQueue::DataToSend chunk(Data(stream_id, item.ssn.value_or(SSN(0)), + item.message_id.value(), fsn, ppid, + std::move(payload), is_beginning, is_end, + item.attributes.unordered)); + chunk.max_retransmissions = item.attributes.max_retransmissions; + chunk.expires_at = item.attributes.expires_at; + chunk.lifecycle_id = + is_end ? item.attributes.lifecycle_id : LifecycleId::NotSet(); + + if (is_end) { + // The entire message has been sent, and its last data copied to `chunk`, + // so it can safely be discarded. + items_.pop_front(); + + if (pause_state_ == PauseState::kPending) { + RTC_DLOG(LS_VERBOSE) << "Pause state on " << *stream_id + << " is moving from pending to paused"; + pause_state_ = PauseState::kPaused; + } + } else { + item.remaining_offset += chunk_payload.size(); + item.remaining_size -= chunk_payload.size(); + RTC_DCHECK(item.remaining_offset + item.remaining_size == + item.message.payload().size()); + RTC_DCHECK(item.remaining_size > 0); + } + RTC_DCHECK(IsConsistent()); + return chunk; + } + RTC_DCHECK(IsConsistent()); + return absl::nullopt; +} + +void RRSendQueue::OutgoingStream::HandleMessageExpired( + OutgoingStream::Item& item) { + buffered_amount_.Decrease(item.remaining_size); + parent_.total_buffered_amount_.Decrease(item.remaining_size); + if (item.attributes.lifecycle_id.IsSet()) { + RTC_DLOG(LS_VERBOSE) << "Triggering OnLifecycleMessageExpired(" + << item.attributes.lifecycle_id.value() << ", false)"; + + parent_.callbacks_.OnLifecycleMessageExpired(item.attributes.lifecycle_id, + /*maybe_delivered=*/false); + parent_.callbacks_.OnLifecycleEnd(item.attributes.lifecycle_id); + } +} + +bool RRSendQueue::OutgoingStream::Discard(IsUnordered unordered, + MID message_id) { + bool result = false; + if (!items_.empty()) { + Item& item = items_.front(); + if (item.attributes.unordered == unordered && item.message_id.has_value() && + *item.message_id == message_id) { + HandleMessageExpired(item); + items_.pop_front(); + + // Only partially sent messages are discarded, so if a message was + // discarded, then it was the currently sent message. + scheduler_stream_->ForceReschedule(); + + if (pause_state_ == PauseState::kPending) { + pause_state_ = PauseState::kPaused; + scheduler_stream_->MakeInactive(); + } else if (bytes_to_send_in_next_message() == 0) { + scheduler_stream_->MakeInactive(); + } + + // As the item still existed, it had unsent data. + result = true; + } + } + RTC_DCHECK(IsConsistent()); + return result; +} + +void RRSendQueue::OutgoingStream::Pause() { + if (pause_state_ != PauseState::kNotPaused) { + // Already in progress. + return; + } + + bool had_pending_items = !items_.empty(); + + // https://datatracker.ietf.org/doc/html/rfc8831#section-6.7 + // "Closing of a data channel MUST be signaled by resetting the corresponding + // outgoing streams [RFC6525]. This means that if one side decides to close + // the data channel, it resets the corresponding outgoing stream." + // ... "[RFC6525] also guarantees that all the messages are delivered (or + // abandoned) before the stream is reset." + + // A stream is paused when it's about to be reset. In this implementation, + // it will throw away all non-partially send messages - they will be abandoned + // as noted above. This is subject to change. It will however not discard any + // partially sent messages - only whole messages. Partially delivered messages + // (at the time of receiving a Stream Reset command) will always deliver all + // the fragments before actually resetting the stream. + for (auto it = items_.begin(); it != items_.end();) { + if (it->remaining_offset == 0) { + HandleMessageExpired(*it); + it = items_.erase(it); + } else { + ++it; + } + } + + pause_state_ = (items_.empty() || items_.front().remaining_offset == 0) + ? PauseState::kPaused + : PauseState::kPending; + + if (had_pending_items && pause_state_ == PauseState::kPaused) { + RTC_DLOG(LS_VERBOSE) << "Stream " << *stream_id() + << " was previously active, but is now paused."; + scheduler_stream_->MakeInactive(); + } + + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::OutgoingStream::Resume() { + RTC_DCHECK(pause_state_ == PauseState::kResetting); + pause_state_ = PauseState::kNotPaused; + scheduler_stream_->MaybeMakeActive(); + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::OutgoingStream::Reset() { + // This can be called both when an outgoing stream reset has been responded + // to, or when the entire SendQueue is reset due to detecting the peer having + // restarted. The stream may be in any state at this time. + PauseState old_pause_state = pause_state_; + pause_state_ = PauseState::kNotPaused; + next_ordered_mid_ = MID(0); + next_unordered_mid_ = MID(0); + next_ssn_ = SSN(0); + if (!items_.empty()) { + // If this message has been partially sent, reset it so that it will be + // re-sent. + auto& item = items_.front(); + buffered_amount_.Increase(item.message.payload().size() - + item.remaining_size); + parent_.total_buffered_amount_.Increase(item.message.payload().size() - + item.remaining_size); + item.remaining_offset = 0; + item.remaining_size = item.message.payload().size(); + item.message_id = absl::nullopt; + item.ssn = absl::nullopt; + item.current_fsn = FSN(0); + if (old_pause_state == PauseState::kPaused || + old_pause_state == PauseState::kResetting) { + scheduler_stream_->MaybeMakeActive(); + } + } + RTC_DCHECK(IsConsistent()); +} + +bool RRSendQueue::OutgoingStream::has_partially_sent_message() const { + if (items_.empty()) { + return false; + } + return items_.front().message_id.has_value(); +} + +void RRSendQueue::Add(TimeMs now, + DcSctpMessage message, + const SendOptions& send_options) { + RTC_DCHECK(!message.payload().empty()); + // Any limited lifetime should start counting from now - when the message + // has been added to the queue. + + // `expires_at` is the time when it expires. Which is slightly larger than the + // message's lifetime, as the message is alive during its entire lifetime + // (which may be zero). + MessageAttributes attributes = { + .unordered = send_options.unordered, + .max_retransmissions = + send_options.max_retransmissions.has_value() + ? MaxRetransmits(send_options.max_retransmissions.value()) + : MaxRetransmits::NoLimit(), + .expires_at = send_options.lifetime.has_value() + ? now + *send_options.lifetime + DurationMs(1) + : TimeMs::InfiniteFuture(), + .lifecycle_id = send_options.lifecycle_id, + }; + GetOrCreateStreamInfo(message.stream_id()) + .Add(std::move(message), std::move(attributes)); + RTC_DCHECK(IsConsistent()); +} + +bool RRSendQueue::IsFull() const { + return total_buffered_amount() >= buffer_size_; +} + +bool RRSendQueue::IsEmpty() const { + return total_buffered_amount() == 0; +} + +absl::optional<SendQueue::DataToSend> RRSendQueue::Produce(TimeMs now, + size_t max_size) { + return scheduler_.Produce(now, max_size); +} + +bool RRSendQueue::Discard(IsUnordered unordered, + StreamID stream_id, + MID message_id) { + bool has_discarded = + GetOrCreateStreamInfo(stream_id).Discard(unordered, message_id); + + RTC_DCHECK(IsConsistent()); + return has_discarded; +} + +void RRSendQueue::PrepareResetStream(StreamID stream_id) { + GetOrCreateStreamInfo(stream_id).Pause(); + RTC_DCHECK(IsConsistent()); +} + +bool RRSendQueue::HasStreamsReadyToBeReset() const { + for (auto& [unused, stream] : streams_) { + if (stream.IsReadyToBeReset()) { + return true; + } + } + return false; +} +std::vector<StreamID> RRSendQueue::GetStreamsReadyToBeReset() { + RTC_DCHECK(absl::c_count_if(streams_, [](const auto& p) { + return p.second.IsResetting(); + }) == 0); + std::vector<StreamID> ready; + for (auto& [stream_id, stream] : streams_) { + if (stream.IsReadyToBeReset()) { + stream.SetAsResetting(); + ready.push_back(stream_id); + } + } + return ready; +} + +void RRSendQueue::CommitResetStreams() { + RTC_DCHECK(absl::c_count_if(streams_, [](const auto& p) { + return p.second.IsResetting(); + }) > 0); + for (auto& [unused, stream] : streams_) { + if (stream.IsResetting()) { + stream.Reset(); + } + } + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::RollbackResetStreams() { + RTC_DCHECK(absl::c_count_if(streams_, [](const auto& p) { + return p.second.IsResetting(); + }) > 0); + for (auto& [unused, stream] : streams_) { + if (stream.IsResetting()) { + stream.Resume(); + } + } + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::Reset() { + // Recalculate buffered amount, as partially sent messages may have been put + // fully back in the queue. + for (auto& [unused, stream] : streams_) { + stream.Reset(); + } + scheduler_.ForceReschedule(); +} + +size_t RRSendQueue::buffered_amount(StreamID stream_id) const { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + return 0; + } + return it->second.buffered_amount().value(); +} + +size_t RRSendQueue::buffered_amount_low_threshold(StreamID stream_id) const { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + return 0; + } + return it->second.buffered_amount().low_threshold(); +} + +void RRSendQueue::SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) { + GetOrCreateStreamInfo(stream_id).buffered_amount().SetLowThreshold(bytes); +} + +RRSendQueue::OutgoingStream& RRSendQueue::GetOrCreateStreamInfo( + StreamID stream_id) { + auto it = streams_.find(stream_id); + if (it != streams_.end()) { + return it->second; + } + + return streams_ + .emplace( + std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(this, &scheduler_, stream_id, default_priority_, + [this, stream_id]() { + callbacks_.OnBufferedAmountLow(stream_id); + })) + .first->second; +} + +void RRSendQueue::SetStreamPriority(StreamID stream_id, + StreamPriority priority) { + OutgoingStream& stream = GetOrCreateStreamInfo(stream_id); + + stream.SetPriority(priority); + RTC_DCHECK(IsConsistent()); +} + +StreamPriority RRSendQueue::GetStreamPriority(StreamID stream_id) const { + auto stream_it = streams_.find(stream_id); + if (stream_it == streams_.end()) { + return default_priority_; + } + return stream_it->second.priority(); +} + +HandoverReadinessStatus RRSendQueue::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (!IsEmpty()) { + status.Add(HandoverUnreadinessReason::kSendQueueNotEmpty); + } + return status; +} + +void RRSendQueue::AddHandoverState(DcSctpSocketHandoverState& state) { + for (const auto& [stream_id, stream] : streams_) { + DcSctpSocketHandoverState::OutgoingStream state_stream; + state_stream.id = stream_id.value(); + stream.AddHandoverState(state_stream); + state.tx.streams.push_back(std::move(state_stream)); + } +} + +void RRSendQueue::RestoreFromState(const DcSctpSocketHandoverState& state) { + for (const DcSctpSocketHandoverState::OutgoingStream& state_stream : + state.tx.streams) { + StreamID stream_id(state_stream.id); + streams_.emplace( + std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple( + this, &scheduler_, stream_id, StreamPriority(state_stream.priority), + [this, stream_id]() { callbacks_.OnBufferedAmountLow(stream_id); }, + &state_stream)); + } +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.h b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.h new file mode 100644 index 0000000000..e9b8cd2081 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.h @@ -0,0 +1,282 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RR_SEND_QUEUE_H_ +#define NET_DCSCTP_TX_RR_SEND_QUEUE_H_ + +#include <cstdint> +#include <deque> +#include <map> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" +#include "net/dcsctp/tx/stream_scheduler.h" + +namespace dcsctp { + +// The Round Robin SendQueue holds all messages that the client wants to send, +// but that haven't yet been split into chunks and fully sent on the wire. +// +// As defined in https://datatracker.ietf.org/doc/html/rfc8260#section-3.2, +// it will cycle to send messages from different streams. It will send all +// fragments from one message before continuing with a different message on +// possibly a different stream, until support for message interleaving has been +// implemented. +// +// As messages can be (requested to be) sent before the connection is properly +// established, this send queue is always present - even for closed connections. +// +// The send queue may trigger callbacks: +// * `OnBufferedAmountLow`, `OnTotalBufferedAmountLow` +// These will be triggered as defined in their documentation. +// * `OnLifecycleMessageExpired(/*maybe_delivered=*/false)`, `OnLifecycleEnd` +// These will be triggered when messages have been expired, abandoned or +// discarded from the send queue. If a message is fully produced, meaning +// that the last fragment has been produced, the responsibility to send +// lifecycle events is then transferred to the retransmission queue, which +// is the one asking to produce the message. +class RRSendQueue : public SendQueue { + public: + RRSendQueue(absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, + size_t buffer_size, + size_t mtu, + StreamPriority default_priority, + size_t total_buffered_amount_low_threshold); + + // Indicates if the buffer is full. Note that it's up to the caller to ensure + // that the buffer is not full prior to adding new items to it. + bool IsFull() const; + // Indicates if the buffer is empty. + bool IsEmpty() const; + + // Adds the message to be sent using the `send_options` provided. The current + // time should be in `now`. Note that it's the responsibility of the caller to + // ensure that the buffer is not full (by calling `IsFull`) before adding + // messages to it. + void Add(TimeMs now, + DcSctpMessage message, + const SendOptions& send_options = {}); + + // Implementation of `SendQueue`. + absl::optional<DataToSend> Produce(TimeMs now, size_t max_size) override; + bool Discard(IsUnordered unordered, + StreamID stream_id, + MID message_id) override; + void PrepareResetStream(StreamID streams) override; + bool HasStreamsReadyToBeReset() const override; + std::vector<StreamID> GetStreamsReadyToBeReset() override; + void CommitResetStreams() override; + void RollbackResetStreams() override; + void Reset() override; + size_t buffered_amount(StreamID stream_id) const override; + size_t total_buffered_amount() const override { + return total_buffered_amount_.value(); + } + size_t buffered_amount_low_threshold(StreamID stream_id) const override; + void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + void EnableMessageInterleaving(bool enabled) override { + scheduler_.EnableMessageInterleaving(enabled); + } + + void SetStreamPriority(StreamID stream_id, StreamPriority priority); + StreamPriority GetStreamPriority(StreamID stream_id) const; + HandoverReadinessStatus GetHandoverReadiness() const; + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); + + private: + struct MessageAttributes { + IsUnordered unordered; + MaxRetransmits max_retransmissions; + TimeMs expires_at; + LifecycleId lifecycle_id; + }; + + // Represents a value and a "low threshold" that when the value reaches or + // goes under the "low threshold", will trigger `on_threshold_reached` + // callback. + class ThresholdWatcher { + public: + explicit ThresholdWatcher(std::function<void()> on_threshold_reached) + : on_threshold_reached_(std::move(on_threshold_reached)) {} + // Increases the value. + void Increase(size_t bytes) { value_ += bytes; } + // Decreases the value and triggers `on_threshold_reached` if it's at or + // below `low_threshold()`. + void Decrease(size_t bytes); + + size_t value() const { return value_; } + size_t low_threshold() const { return low_threshold_; } + void SetLowThreshold(size_t low_threshold); + + private: + const std::function<void()> on_threshold_reached_; + size_t value_ = 0; + size_t low_threshold_ = 0; + }; + + // Per-stream information. + class OutgoingStream : public StreamScheduler::StreamProducer { + public: + OutgoingStream( + RRSendQueue* parent, + StreamScheduler* scheduler, + StreamID stream_id, + StreamPriority priority, + std::function<void()> on_buffered_amount_low, + const DcSctpSocketHandoverState::OutgoingStream* state = nullptr) + : parent_(*parent), + scheduler_stream_(scheduler->CreateStream(this, stream_id, priority)), + next_unordered_mid_(MID(state ? state->next_unordered_mid : 0)), + next_ordered_mid_(MID(state ? state->next_ordered_mid : 0)), + next_ssn_(SSN(state ? state->next_ssn : 0)), + buffered_amount_(std::move(on_buffered_amount_low)) {} + + StreamID stream_id() const { return scheduler_stream_->stream_id(); } + + // Enqueues a message to this stream. + void Add(DcSctpMessage message, MessageAttributes attributes); + + // Implementing `StreamScheduler::StreamProducer`. + absl::optional<SendQueue::DataToSend> Produce(TimeMs now, + size_t max_size) override; + size_t bytes_to_send_in_next_message() const override; + + const ThresholdWatcher& buffered_amount() const { return buffered_amount_; } + ThresholdWatcher& buffered_amount() { return buffered_amount_; } + + // Discards a partially sent message, see `SendQueue::Discard`. + bool Discard(IsUnordered unordered, MID message_id); + + // Pauses this stream, which is used before resetting it. + void Pause(); + + // Resumes a paused stream. + void Resume(); + + bool IsReadyToBeReset() const { + return pause_state_ == PauseState::kPaused; + } + + bool IsResetting() const { return pause_state_ == PauseState::kResetting; } + + void SetAsResetting() { + RTC_DCHECK(pause_state_ == PauseState::kPaused); + pause_state_ = PauseState::kResetting; + } + + // Resets this stream, meaning MIDs and SSNs are set to zero. + void Reset(); + + // Indicates if this stream has a partially sent message in it. + bool has_partially_sent_message() const; + + StreamPriority priority() const { return scheduler_stream_->priority(); } + void SetPriority(StreamPriority priority) { + scheduler_stream_->SetPriority(priority); + } + + void AddHandoverState( + DcSctpSocketHandoverState::OutgoingStream& state) const; + + private: + // Streams are paused before they can be reset. To reset a stream, the + // socket sends an outgoing stream reset command with the TSN of the last + // fragment of the last message, so that receivers and senders can agree on + // when it stopped. And if the send queue is in the middle of sending a + // message, and without fragments not yet sent and without TSNs allocated to + // them, it will keep sending data until that message has ended. + enum class PauseState { + // The stream is not paused, and not scheduled to be reset. + kNotPaused, + // The stream has requested to be reset/paused but is still producing + // fragments of a message that hasn't ended yet. When it does, it will + // transition to the `kPaused` state. + kPending, + // The stream is fully paused and can be reset. + kPaused, + // The stream has been added to an outgoing stream reset request and a + // response from the peer hasn't been received yet. + kResetting, + }; + + // An enqueued message and metadata. + struct Item { + explicit Item(DcSctpMessage msg, MessageAttributes attributes) + : message(std::move(msg)), + attributes(std::move(attributes)), + remaining_offset(0), + remaining_size(message.payload().size()) {} + DcSctpMessage message; + MessageAttributes attributes; + // The remaining payload (offset and size) to be sent, when it has been + // fragmented. + size_t remaining_offset; + size_t remaining_size; + // If set, an allocated Message ID and SSN. Will be allocated when the + // first fragment is sent. + absl::optional<MID> message_id = absl::nullopt; + absl::optional<SSN> ssn = absl::nullopt; + // The current Fragment Sequence Number, incremented for each fragment. + FSN current_fsn = FSN(0); + }; + + bool IsConsistent() const; + void HandleMessageExpired(OutgoingStream::Item& item); + + RRSendQueue& parent_; + + const std::unique_ptr<StreamScheduler::Stream> scheduler_stream_; + + PauseState pause_state_ = PauseState::kNotPaused; + // MIDs are different for unordered and ordered messages sent on a stream. + MID next_unordered_mid_; + MID next_ordered_mid_; + + SSN next_ssn_; + // Enqueued messages, and metadata. + std::deque<Item> items_; + + // The current amount of buffered data. + ThresholdWatcher buffered_amount_; + }; + + bool IsConsistent() const; + OutgoingStream& GetOrCreateStreamInfo(StreamID stream_id); + absl::optional<DataToSend> Produce( + std::map<StreamID, OutgoingStream>::iterator it, + TimeMs now, + size_t max_size); + + const std::string log_prefix_; + DcSctpSocketCallbacks& callbacks_; + const size_t buffer_size_; + const StreamPriority default_priority_; + StreamScheduler scheduler_; + + // The total amount of buffer data, for all streams. + ThresholdWatcher total_buffered_amount_; + + // All streams, and messages added to those. + std::map<StreamID, OutgoingStream> streams_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RR_SEND_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue_test.cc b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue_test.cc new file mode 100644 index 0000000000..95416b193a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue_test.cc @@ -0,0 +1,866 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/rr_send_queue.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr TimeMs kNow = TimeMs(0); +constexpr StreamID kStreamID(1); +constexpr PPID kPPID(53); +constexpr size_t kMaxQueueSize = 1000; +constexpr StreamPriority kDefaultPriority(10); +constexpr size_t kBufferedAmountLowThreshold = 500; +constexpr size_t kOneFragmentPacketSize = 100; +constexpr size_t kTwoFragmentPacketSize = 101; +constexpr size_t kMtu = 1100; + +class RRSendQueueTest : public testing::Test { + protected: + RRSendQueueTest() + : buf_("log: ", + &callbacks_, + kMaxQueueSize, + kMtu, + kDefaultPriority, + kBufferedAmountLowThreshold) {} + + testing::NiceMock<MockDcSctpSocketCallbacks> callbacks_; + const DcSctpOptions options_; + RRSendQueue buf_; +}; + +TEST_F(RRSendQueueTest, EmptyBuffer) { + EXPECT_TRUE(buf_.IsEmpty()); + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); + EXPECT_FALSE(buf_.IsFull()); +} + +TEST_F(RRSendQueueTest, AddAndGetSingleChunk) { + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, {1, 2, 4, 5, 6})); + + EXPECT_FALSE(buf_.IsEmpty()); + EXPECT_FALSE(buf_.IsFull()); + absl::optional<SendQueue::DataToSend> chunk_opt = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_opt.has_value()); + EXPECT_TRUE(chunk_opt->data.is_beginning); + EXPECT_TRUE(chunk_opt->data.is_end); +} + +TEST_F(RRSendQueueTest, CarveOutBeginningMiddleAndEnd) { + std::vector<uint8_t> payload(60); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_beg = + buf_.Produce(kNow, /*max_size=*/20); + ASSERT_TRUE(chunk_beg.has_value()); + EXPECT_TRUE(chunk_beg->data.is_beginning); + EXPECT_FALSE(chunk_beg->data.is_end); + + absl::optional<SendQueue::DataToSend> chunk_mid = + buf_.Produce(kNow, /*max_size=*/20); + ASSERT_TRUE(chunk_mid.has_value()); + EXPECT_FALSE(chunk_mid->data.is_beginning); + EXPECT_FALSE(chunk_mid->data.is_end); + + absl::optional<SendQueue::DataToSend> chunk_end = + buf_.Produce(kNow, /*max_size=*/20); + ASSERT_TRUE(chunk_end.has_value()); + EXPECT_FALSE(chunk_end->data.is_beginning); + EXPECT_TRUE(chunk_end->data.is_end); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); +} + +TEST_F(RRSendQueueTest, GetChunksFromTwoMessages) { + std::vector<uint8_t> payload(60); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(3), PPID(54), payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(chunk_one->data.ppid, kPPID); + EXPECT_TRUE(chunk_one->data.is_beginning); + EXPECT_TRUE(chunk_one->data.is_end); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_two->data.ppid, PPID(54)); + EXPECT_TRUE(chunk_two->data.is_beginning); + EXPECT_TRUE(chunk_two->data.is_end); +} + +TEST_F(RRSendQueueTest, BufferBecomesFullAndEmptied) { + std::vector<uint8_t> payload(600); + EXPECT_FALSE(buf_.IsFull()); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_FALSE(buf_.IsFull()); + buf_.Add(kNow, DcSctpMessage(StreamID(3), PPID(54), payload)); + EXPECT_TRUE(buf_.IsFull()); + // However, it's still possible to add messages. It's a soft limit, and it + // might be necessary to forcefully add messages due to e.g. external + // fragmentation. + buf_.Add(kNow, DcSctpMessage(StreamID(5), PPID(55), payload)); + EXPECT_TRUE(buf_.IsFull()); + + absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 1000); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(chunk_one->data.ppid, kPPID); + + EXPECT_TRUE(buf_.IsFull()); + + absl::optional<SendQueue::DataToSend> chunk_two = buf_.Produce(kNow, 1000); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_two->data.ppid, PPID(54)); + + EXPECT_FALSE(buf_.IsFull()); + EXPECT_FALSE(buf_.IsEmpty()); + + absl::optional<SendQueue::DataToSend> chunk_three = buf_.Produce(kNow, 1000); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.stream_id, StreamID(5)); + EXPECT_EQ(chunk_three->data.ppid, PPID(55)); + + EXPECT_FALSE(buf_.IsFull()); + EXPECT_TRUE(buf_.IsEmpty()); +} + +TEST_F(RRSendQueueTest, DefaultsToOrderedSend) { + std::vector<uint8_t> payload(20); + + // Default is ordered + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_FALSE(chunk_one->data.is_unordered); + + // Explicitly unordered. + SendOptions opts; + opts.unordered = IsUnordered(true); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload), opts); + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_TRUE(chunk_two->data.is_unordered); +} + +TEST_F(RRSendQueueTest, ProduceWithLifetimeExpiry) { + std::vector<uint8_t> payload(20); + + // Default is no expiry + TimeMs now = kNow; + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload)); + now += DurationMs(1000000); + ASSERT_TRUE(buf_.Produce(now, kOneFragmentPacketSize)); + + SendOptions expires_2_seconds; + expires_2_seconds.lifetime = DurationMs(2000); + + // Add and consume within lifetime + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + now += DurationMs(2000); + ASSERT_TRUE(buf_.Produce(now, kOneFragmentPacketSize)); + + // Add and consume just outside lifetime + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + now += DurationMs(2001); + ASSERT_FALSE(buf_.Produce(now, kOneFragmentPacketSize)); + + // A long time after expiry + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + now += DurationMs(1000000); + ASSERT_FALSE(buf_.Produce(now, kOneFragmentPacketSize)); + + // Expire one message, but produce the second that is not expired. + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + + SendOptions expires_4_seconds; + expires_4_seconds.lifetime = DurationMs(4000); + + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_4_seconds); + now += DurationMs(2001); + + ASSERT_TRUE(buf_.Produce(now, kOneFragmentPacketSize)); + ASSERT_FALSE(buf_.Produce(now, kOneFragmentPacketSize)); +} + +TEST_F(RRSendQueueTest, DiscardPartialPackets) { + std::vector<uint8_t> payload(120); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(2), PPID(54), payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_FALSE(chunk_one->data.is_end); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + buf_.Discard(IsUnordered(false), chunk_one->data.stream_id, + chunk_one->data.message_id); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_FALSE(chunk_two->data.is_end); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(2)); + + absl::optional<SendQueue::DataToSend> chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_TRUE(chunk_three->data.is_end); + EXPECT_EQ(chunk_three->data.stream_id, StreamID(2)); + ASSERT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize)); + + // Calling it again shouldn't cause issues. + buf_.Discard(IsUnordered(false), chunk_one->data.stream_id, + chunk_one->data.message_id); + ASSERT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize)); +} + +TEST_F(RRSendQueueTest, PrepareResetStreamsDiscardsStream) { + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, {1, 2, 3})); + buf_.Add(kNow, DcSctpMessage(StreamID(2), PPID(54), {1, 2, 3, 4, 5})); + EXPECT_EQ(buf_.total_buffered_amount(), 8u); + + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), 5u); + + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(1))); + buf_.CommitResetStreams(); + buf_.PrepareResetStream(StreamID(2)); + EXPECT_EQ(buf_.total_buffered_amount(), 0u); +} + +TEST_F(RRSendQueueTest, PrepareResetStreamsNotPartialPackets) { + std::vector<uint8_t> payload(120); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 2 * payload.size() - 50); + + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size() - 50); +} + +TEST_F(RRSendQueueTest, EnqueuedItemsArePausedDuringStreamReset) { + std::vector<uint8_t> payload(50); + + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), 0u); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); + + EXPECT_TRUE(buf_.HasStreamsReadyToBeReset()); + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(1))); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); + + buf_.CommitResetStreams(); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 0u); +} + +TEST_F(RRSendQueueTest, PausedStreamsStillSendPartialMessagesUntilEnd) { + constexpr size_t kPayloadSize = 100; + constexpr size_t kFragmentSize = 50; + std::vector<uint8_t> payload(kPayloadSize); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kFragmentSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 2 * kPayloadSize - kFragmentSize); + + // This will stop the second message from being sent. + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), 1 * kPayloadSize - kFragmentSize); + + // Should still produce fragments until end of message. + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kFragmentSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 0ul); + + // But shouldn't produce any more messages as the stream is paused. + EXPECT_FALSE(buf_.Produce(kNow, kFragmentSize).has_value()); +} + +TEST_F(RRSendQueueTest, CommittingResetsSSN) { + std::vector<uint8_t> payload(50); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.ssn, SSN(0)); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.ssn, SSN(1)); + + buf_.PrepareResetStream(StreamID(1)); + + // Buffered + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + EXPECT_TRUE(buf_.HasStreamsReadyToBeReset()); + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(1))); + buf_.CommitResetStreams(); + + absl::optional<SendQueue::DataToSend> chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.ssn, SSN(0)); +} + +TEST_F(RRSendQueueTest, CommittingResetsSSNForPausedStreamsOnly) { + std::vector<uint8_t> payload(50); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, StreamID(1)); + EXPECT_EQ(chunk_one->data.ssn, SSN(0)); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_two->data.ssn, SSN(0)); + + buf_.PrepareResetStream(StreamID(3)); + + // Send two more messages - SID 3 will buffer, SID 1 will send. + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, payload)); + + EXPECT_TRUE(buf_.HasStreamsReadyToBeReset()); + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(3))); + + buf_.CommitResetStreams(); + + absl::optional<SendQueue::DataToSend> chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.stream_id, StreamID(1)); + EXPECT_EQ(chunk_three->data.ssn, SSN(1)); + + absl::optional<SendQueue::DataToSend> chunk_four = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_four.has_value()); + EXPECT_EQ(chunk_four->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_four->data.ssn, SSN(0)); +} + +TEST_F(RRSendQueueTest, RollBackResumesSSN) { + std::vector<uint8_t> payload(50); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.ssn, SSN(0)); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.ssn, SSN(1)); + + buf_.PrepareResetStream(StreamID(1)); + + // Buffered + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + EXPECT_TRUE(buf_.HasStreamsReadyToBeReset()); + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(1))); + buf_.RollbackResetStreams(); + + absl::optional<SendQueue::DataToSend> chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.ssn, SSN(2)); +} + +TEST_F(RRSendQueueTest, ReturnsFragmentsForOneMessageBeforeMovingToNext) { + std::vector<uint8_t> payload(200); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, payload)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(2)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(2)); +} + +TEST_F(RRSendQueueTest, ReturnsAlsoSmallFragmentsBeforeMovingToNext) { + std::vector<uint8_t> payload(kTwoFragmentPacketSize); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, payload)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(kOneFragmentPacketSize)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, + SizeIs(kTwoFragmentPacketSize - kOneFragmentPacketSize)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk4.data.payload, + SizeIs(kTwoFragmentPacketSize - kOneFragmentPacketSize)); +} + +TEST_F(RRSendQueueTest, WillCycleInRoundRobinFashionBetweenStreams) { + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1))); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(2))); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(3))); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(4))); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector<uint8_t>(5))); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector<uint8_t>(6))); + buf_.Add(kNow, DcSctpMessage(StreamID(4), kPPID, std::vector<uint8_t>(7))); + buf_.Add(kNow, DcSctpMessage(StreamID(4), kPPID, std::vector<uint8_t>(8))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk2.data.payload, SizeIs(3)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(3)); + EXPECT_THAT(chunk3.data.payload, SizeIs(5)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(4)); + EXPECT_THAT(chunk4.data.payload, SizeIs(7)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk5, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk5.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk5.data.payload, SizeIs(2)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk6, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk6.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk6.data.payload, SizeIs(4)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk7, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk7.data.stream_id, StreamID(3)); + EXPECT_THAT(chunk7.data.payload, SizeIs(6)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk8, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk8.data.stream_id, StreamID(4)); + EXPECT_THAT(chunk8.data.payload, SizeIs(8)); +} + +TEST_F(RRSendQueueTest, DoesntTriggerOnBufferedAmountLowWhenSetToZero) { + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 0u); +} + +TEST_F(RRSendQueueTest, TriggersOnBufferedAmountAtZeroLowWhenSent) { + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 1u); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 0u); +} + +TEST_F(RRSendQueueTest, WillRetriggerOnBufferedAmountLowIfAddingMore) { + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1))); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 1u); + + // Should now trigger again, as buffer_amount went above the threshold. + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(1)); +} + +TEST_F(RRSendQueueTest, OnlyTriggersWhenTransitioningFromAboveToBelowOrEqual) { + buf_.SetBufferedAmountLowThreshold(StreamID(1), 1000); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(10))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 10u); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(10)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 0u); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(20))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 20u); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(20)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 0u); +} + +TEST_F(RRSendQueueTest, WillTriggerOnBufferedAmountLowSetAboveZero) { + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + buf_.SetBufferedAmountLowThreshold(StreamID(1), 700); + + std::vector<uint8_t> payload(1000); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 900u); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 800u); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 700u); + + // Doesn't trigger when reducing even further. + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u); +} + +TEST_F(RRSendQueueTest, WillRetriggerOnBufferedAmountLowSetAboveZero) { + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + buf_.SetBufferedAmountLowThreshold(StreamID(1), 700); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1000))); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, 400)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(400)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(200))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 800u); + + // Will trigger again, as it went above the limit. + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, 200)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(200)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u); +} + +TEST_F(RRSendQueueTest, TriggersOnBufferedAmountLowOnThresholdChanged) { + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(100))); + + // Modifying the threshold, still under buffered_amount, should not trigger. + buf_.SetBufferedAmountLowThreshold(StreamID(1), 50); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 99); + + // When the threshold reaches buffered_amount, it will trigger. + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 100); + + // But not when it's set low again. + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 50); + + // But it will trigger when it overshoots. + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 150); + + // But not when it's set low again. + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 0); +} + +TEST_F(RRSendQueueTest, + OnTotalBufferedAmountLowDoesNotTriggerOnBufferFillingUp) { + EXPECT_CALL(callbacks_, OnTotalBufferedAmountLow).Times(0); + std::vector<uint8_t> payload(kBufferedAmountLowThreshold - 1); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + // Will not trigger if going above but never below. + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, + std::vector<uint8_t>(kOneFragmentPacketSize))); +} + +TEST_F(RRSendQueueTest, TriggersOnTotalBufferedAmountLowWhenCrossing) { + EXPECT_CALL(callbacks_, OnTotalBufferedAmountLow).Times(0); + std::vector<uint8_t> payload(kBufferedAmountLowThreshold); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + // Reaches it. + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, std::vector<uint8_t>(1))); + + // Drain it a bit - will trigger. + EXPECT_CALL(callbacks_, OnTotalBufferedAmountLow).Times(1); + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); +} + +TEST_F(RRSendQueueTest, WillStayInAStreamAsLongAsThatMessageIsSending) { + buf_.Add(kNow, DcSctpMessage(StreamID(5), kPPID, std::vector<uint8_t>(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(5)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + + // Next, it should pick a different stream. + + buf_.Add(kNow, + DcSctpMessage(StreamID(1), kPPID, + std::vector<uint8_t>(kOneFragmentPacketSize * 2))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(kOneFragmentPacketSize)); + + // It should still stay on the Stream1 now, even if might be tempted to switch + // to this stream, as it's the stream following 5. + buf_.Add(kNow, DcSctpMessage(StreamID(6), kPPID, std::vector<uint8_t>(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + + // After stream id 1 is complete, it's time to do stream 6. + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(6)); + EXPECT_THAT(chunk4.data.payload, SizeIs(1)); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); +} + +TEST_F(RRSendQueueTest, StreamsHaveInitialPriority) { + EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), kDefaultPriority); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40))); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), kDefaultPriority); +} + +TEST_F(RRSendQueueTest, CanChangeStreamPriority) { + buf_.SetStreamPriority(StreamID(1), StreamPriority(42)); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), StreamPriority(42)); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40))); + buf_.SetStreamPriority(StreamID(2), StreamPriority(42)); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), StreamPriority(42)); +} + +TEST_F(RRSendQueueTest, WillHandoverPriority) { + buf_.SetStreamPriority(StreamID(1), StreamPriority(42)); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40))); + buf_.SetStreamPriority(StreamID(2), StreamPriority(42)); + + DcSctpSocketHandoverState state; + buf_.AddHandoverState(state); + + RRSendQueue q2("log: ", &callbacks_, kMaxQueueSize, kMtu, kDefaultPriority, + kBufferedAmountLowThreshold); + q2.RestoreFromState(state); + EXPECT_EQ(q2.GetStreamPriority(StreamID(1)), StreamPriority(42)); + EXPECT_EQ(q2.GetStreamPriority(StreamID(2)), StreamPriority(42)); +} + +TEST_F(RRSendQueueTest, WillSendMessagesByPrio) { + buf_.EnableMessageInterleaving(true); + buf_.SetStreamPriority(StreamID(1), StreamPriority(10)); + buf_.SetStreamPriority(StreamID(2), StreamPriority(20)); + buf_.SetStreamPriority(StreamID(3), StreamPriority(30)); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(40))); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(20))); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector<uint8_t>(10))); + std::vector<uint16_t> expected_streams = {3, 2, 2, 1, 1, 1, 1}; + + for (uint16_t stream_num : expected_streams) { + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk, + buf_.Produce(kNow, 10)); + EXPECT_EQ(chunk.data.stream_id, StreamID(stream_num)); + } + EXPECT_FALSE(buf_.Produce(kNow, 1).has_value()); +} + +TEST_F(RRSendQueueTest, WillSendLifecycleExpireWhenExpiredInSendQueue) { + std::vector<uint8_t> payload(kOneFragmentPacketSize); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, payload), + SendOptions{.lifetime = DurationMs(1000), + .lifecycle_id = LifecycleId(1)}); + + EXPECT_CALL(callbacks_, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(callbacks_, OnLifecycleEnd(LifecycleId(1))); + EXPECT_FALSE(buf_.Produce(kNow + DurationMs(1001), kOneFragmentPacketSize) + .has_value()); +} + +TEST_F(RRSendQueueTest, WillSendLifecycleExpireWhenDiscardingDuringPause) { + std::vector<uint8_t> payload(120); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload), + SendOptions{.lifecycle_id = LifecycleId(1)}); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload), + SendOptions{.lifecycle_id = LifecycleId(2)}); + + absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 2 * payload.size() - 50); + + EXPECT_CALL(callbacks_, OnLifecycleMessageExpired(LifecycleId(2), + /*maybe_delivered=*/false)); + EXPECT_CALL(callbacks_, OnLifecycleEnd(LifecycleId(2))); + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size() - 50); +} + +TEST_F(RRSendQueueTest, WillSendLifecycleExpireWhenDiscardingExplicitly) { + std::vector<uint8_t> payload(kOneFragmentPacketSize + 20); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload), + SendOptions{.lifecycle_id = LifecycleId(1)}); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_FALSE(chunk_one->data.is_end); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_CALL(callbacks_, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(callbacks_, OnLifecycleEnd(LifecycleId(1))); + buf_.Discard(IsUnordered(false), chunk_one->data.stream_id, + chunk_one->data.message_id); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/send_queue.h b/third_party/libwebrtc/net/dcsctp/tx/send_queue.h new file mode 100644 index 0000000000..0b96e9041a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/send_queue.h @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_SEND_QUEUE_H_ +#define NET_DCSCTP_TX_SEND_QUEUE_H_ + +#include <cstdint> +#include <limits> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +class SendQueue { + public: + // Container for a data chunk that is produced by the SendQueue + struct DataToSend { + explicit DataToSend(Data data) : data(std::move(data)) {} + // The data to send, including all parameters. + Data data; + + // Partial reliability - RFC3758 + MaxRetransmits max_retransmissions = MaxRetransmits::NoLimit(); + TimeMs expires_at = TimeMs::InfiniteFuture(); + + // Lifecycle - set for the last fragment, and `LifecycleId::NotSet()` for + // all other fragments. + LifecycleId lifecycle_id = LifecycleId::NotSet(); + }; + + virtual ~SendQueue() = default; + + // TODO(boivie): This interface is obviously missing an "Add" function, but + // that is postponed a bit until the story around how to model message + // prioritization, which is important for any advanced stream scheduler, is + // further clarified. + + // Produce a chunk to be sent. + // + // `max_size` refers to how many payload bytes that may be produced, not + // including any headers. + virtual absl::optional<DataToSend> Produce(TimeMs now, size_t max_size) = 0; + + // Discards a partially sent message identified by the parameters `unordered`, + // `stream_id` and `message_id`. The `message_id` comes from the returned + // information when having called `Produce`. A partially sent message means + // that it has had at least one fragment of it returned when `Produce` was + // called prior to calling this method). + // + // This is used when a message has been found to be expired (by the partial + // reliability extension), and the retransmission queue will signal the + // receiver that any partially received message fragments should be skipped. + // This means that any remaining fragments in the Send Queue must be removed + // as well so that they are not sent. + // + // This function returns true if this message had unsent fragments still in + // the queue that were discarded, and false if there were no such fragments. + virtual bool Discard(IsUnordered unordered, + StreamID stream_id, + MID message_id) = 0; + + // Prepares the stream to be reset. This is used to close a WebRTC data + // channel and will be signaled to the other side. + // + // Concretely, it discards all whole (not partly sent) messages in the given + // stream and pauses that stream so that future added messages aren't + // produced until `ResumeStreams` is called. + // + // TODO(boivie): Investigate if it really should discard any message at all. + // RFC8831 only mentions that "[RFC6525] also guarantees that all the messages + // are delivered (or abandoned) before the stream is reset." + // + // This method can be called multiple times to add more streams to be + // reset, and paused while they are resetting. This is the first part of the + // two-phase commit protocol to reset streams, where the caller completes the + // procedure by either calling `CommitResetStreams` or `RollbackResetStreams`. + virtual void PrepareResetStream(StreamID stream_id) = 0; + + // Indicates if there are any streams that are ready to be reset. + virtual bool HasStreamsReadyToBeReset() const = 0; + + // Returns a list of streams that are ready to be included in an outgoing + // stream reset request. Any streams that are returned here must be included + // in an outgoing stream reset request, and there must not be concurrent + // requests. Before calling this method again, you must have called + virtual std::vector<StreamID> GetStreamsReadyToBeReset() = 0; + + // Called to commit to reset the streams returned by + // `GetStreamsReadyToBeReset`. It will reset the stream sequence numbers + // (SSNs) and message identifiers (MIDs) and resume the paused streams. + virtual void CommitResetStreams() = 0; + + // Called to abort the resetting of streams returned by + // `GetStreamsReadyToBeReset`. Will resume the paused streams without + // resetting the stream sequence numbers (SSNs) or message identifiers (MIDs). + // Note that the non-partial messages that were discarded when calling + // `PrepareResetStreams` will not be recovered, to better match the intention + // from the sender to "close the channel". + virtual void RollbackResetStreams() = 0; + + // Resets all message identifier counters (MID, SSN) and makes all partially + // messages be ready to be re-sent in full. This is used when the peer has + // been detected to have restarted and is used to try to minimize the amount + // of data loss. However, data loss cannot be completely guaranteed when a + // peer restarts. + virtual void Reset() = 0; + + // Returns the amount of buffered data. This doesn't include packets that are + // e.g. inflight. + virtual size_t buffered_amount(StreamID stream_id) const = 0; + + // Returns the total amount of buffer data, for all streams. + virtual size_t total_buffered_amount() const = 0; + + // Returns the limit for the `OnBufferedAmountLow` event. Default value is 0. + virtual size_t buffered_amount_low_threshold(StreamID stream_id) const = 0; + + // Sets a limit for the `OnBufferedAmountLow` event. + virtual void SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) = 0; + + // Configures the send queue to support interleaved message sending as + // described in RFC8260. Every send queue starts with this value set as + // disabled, but can later change it when the capabilities of the connection + // have been negotiated. This affects the behavior of the `Produce` method. + virtual void EnableMessageInterleaving(bool enabled) = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_SEND_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.cc b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.cc new file mode 100644 index 0000000000..d1560a75e4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.cc @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/stream_scheduler.h" + +#include <algorithm> + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +void StreamScheduler::Stream::SetPriority(StreamPriority priority) { + priority_ = priority; + inverse_weight_ = InverseWeight(priority); +} + +absl::optional<SendQueue::DataToSend> StreamScheduler::Produce( + TimeMs now, + size_t max_size) { + // For non-interleaved streams, avoid rescheduling while still sending a + // message as it needs to be sent in full. For interleaved messaging, + // reschedule for every I-DATA chunk sent. + bool rescheduling = + enable_message_interleaving_ || !currently_sending_a_message_; + + RTC_LOG(LS_VERBOSE) << "Producing data, rescheduling=" << rescheduling + << ", active=" + << StrJoin(active_streams_, ", ", + [&](rtc::StringBuilder& sb, const auto& p) { + sb << *p->stream_id() << "@" + << *p->next_finish_time(); + }); + + RTC_DCHECK(rescheduling || current_stream_ != nullptr); + + absl::optional<SendQueue::DataToSend> data; + while (!data.has_value() && !active_streams_.empty()) { + if (rescheduling) { + auto it = active_streams_.begin(); + current_stream_ = *it; + RTC_DLOG(LS_VERBOSE) << "Rescheduling to stream " + << *current_stream_->stream_id(); + + active_streams_.erase(it); + current_stream_->ForceMarkInactive(); + } else { + RTC_DLOG(LS_VERBOSE) << "Producing from previous stream: " + << *current_stream_->stream_id(); + RTC_DCHECK(absl::c_any_of(active_streams_, [this](const auto* p) { + return p == current_stream_; + })); + } + + data = current_stream_->Produce(now, max_size); + } + + if (!data.has_value()) { + RTC_DLOG(LS_VERBOSE) + << "There is no stream with data; Can't produce any data."; + RTC_DCHECK(IsConsistent()); + + return absl::nullopt; + } + + RTC_DCHECK(data->data.stream_id == current_stream_->stream_id()); + + RTC_DLOG(LS_VERBOSE) << "Producing DATA, type=" + << (data->data.is_unordered ? "unordered" : "ordered") + << "::" + << (*data->data.is_beginning && *data->data.is_end + ? "complete" + : *data->data.is_beginning ? "first" + : *data->data.is_end ? "last" + : "middle") + << ", stream_id=" << *current_stream_->stream_id() + << ", ppid=" << *data->data.ppid + << ", length=" << data->data.payload.size(); + + currently_sending_a_message_ = !*data->data.is_end; + virtual_time_ = current_stream_->current_time(); + + // One side-effect of rescheduling is that the new stream will not be present + // in `active_streams`. + size_t bytes_to_send_next = current_stream_->bytes_to_send_in_next_message(); + if (rescheduling && bytes_to_send_next > 0) { + current_stream_->MakeActive(bytes_to_send_next); + } else if (!rescheduling && bytes_to_send_next == 0) { + current_stream_->MakeInactive(); + } + + RTC_DCHECK(IsConsistent()); + return data; +} + +StreamScheduler::VirtualTime StreamScheduler::Stream::CalculateFinishTime( + size_t bytes_to_send_next) const { + if (parent_.enable_message_interleaving_) { + // Perform weighted fair queuing scheduling. + return VirtualTime(*current_virtual_time_ + + bytes_to_send_next * *inverse_weight_); + } + + // Perform round-robin scheduling by letting the stream have its next virtual + // finish time in the future. It doesn't matter how far into the future, just + // any positive number so that any other stream that has the same virtual + // finish time as this stream gets to produce their data before revisiting + // this stream. + return VirtualTime(*current_virtual_time_ + 1); +} + +absl::optional<SendQueue::DataToSend> StreamScheduler::Stream::Produce( + TimeMs now, + size_t max_size) { + absl::optional<SendQueue::DataToSend> data = producer_.Produce(now, max_size); + + if (data.has_value()) { + VirtualTime new_current = CalculateFinishTime(data->data.payload.size()); + RTC_DLOG(LS_VERBOSE) << "Virtual time changed: " << *current_virtual_time_ + << " -> " << *new_current; + current_virtual_time_ = new_current; + } + + return data; +} + +bool StreamScheduler::IsConsistent() const { + for (Stream* stream : active_streams_) { + if (stream->next_finish_time_ == VirtualTime::Zero()) { + RTC_DLOG(LS_VERBOSE) << "Stream " << *stream->stream_id() + << " is active, but has no next-finish-time"; + return false; + } + } + return true; +} + +void StreamScheduler::Stream::MaybeMakeActive() { + RTC_DLOG(LS_VERBOSE) << "MaybeMakeActive(" << *stream_id() << ")"; + RTC_DCHECK(next_finish_time_ == VirtualTime::Zero()); + size_t bytes_to_send_next = bytes_to_send_in_next_message(); + if (bytes_to_send_next == 0) { + return; + } + + MakeActive(bytes_to_send_next); +} + +void StreamScheduler::Stream::MakeActive(size_t bytes_to_send_next) { + current_virtual_time_ = parent_.virtual_time_; + RTC_DCHECK_GT(bytes_to_send_next, 0); + VirtualTime next_finish_time = CalculateFinishTime( + std::min(bytes_to_send_next, parent_.max_payload_bytes_)); + RTC_DCHECK_GT(*next_finish_time, 0); + RTC_DLOG(LS_VERBOSE) << "Making stream " << *stream_id() + << " active, expiring at " << *next_finish_time; + RTC_DCHECK(next_finish_time_ == VirtualTime::Zero()); + next_finish_time_ = next_finish_time; + RTC_DCHECK(!absl::c_any_of(parent_.active_streams_, + [this](const auto* p) { return p == this; })); + parent_.active_streams_.emplace(this); +} + +void StreamScheduler::Stream::ForceMarkInactive() { + RTC_DLOG(LS_VERBOSE) << "Making stream " << *stream_id() << " inactive"; + RTC_DCHECK(next_finish_time_ != VirtualTime::Zero()); + next_finish_time_ = VirtualTime::Zero(); +} + +void StreamScheduler::Stream::MakeInactive() { + ForceMarkInactive(); + webrtc::EraseIf(parent_.active_streams_, + [&](const auto* s) { return s == this; }); +} + +std::set<StreamID> StreamScheduler::ActiveStreamsForTesting() const { + std::set<StreamID> stream_ids; + for (const auto& stream : active_streams_) { + stream_ids.insert(stream->stream_id()); + } + return stream_ids; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.h b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.h new file mode 100644 index 0000000000..9c523edbfc --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.h @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_STREAM_SCHEDULER_H_ +#define NET_DCSCTP_TX_STREAM_SCHEDULER_H_ + +#include <algorithm> +#include <cstdint> +#include <deque> +#include <map> +#include <memory> +#include <queue> +#include <set> +#include <string> +#include <utility> + +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/containers/flat_set.h" +#include "rtc_base/strong_alias.h" + +namespace dcsctp { + +// A parameterized stream scheduler. Currently, it implements the round robin +// scheduling algorithm using virtual finish time. It is to be used as a part of +// a send queue and will track all active streams (streams that have any data +// that can be sent). +// +// The stream scheduler works with the concept of associating active streams +// with a "virtual finish time", which is the time when a stream is allowed to +// produce data. Streams are ordered by their virtual finish time, and the +// "current virtual time" will advance to the next following virtual finish time +// whenever a chunk is to be produced. +// +// When message interleaving is enabled, the WFQ - Weighted Fair Queueing - +// scheduling algorithm will be used. And when it's not, round-robin scheduling +// will be used instead. +// +// In the round robin scheduling algorithm, a stream's virtual finish time will +// just increment by one (1) after having produced a chunk, which results in a +// round-robin scheduling. +// +// In WFQ scheduling algorithm, a stream's virtual finish time will be defined +// as the number of bytes in the next fragment to be sent, multiplied by the +// inverse of the stream's priority, meaning that a high priority - or a smaller +// fragment - results in a closer virtual finish time, compared to a stream with +// either a lower priority or a larger fragment to be sent. +class StreamScheduler { + private: + class VirtualTime : public webrtc::StrongAlias<class VirtualTimeTag, double> { + public: + constexpr explicit VirtualTime(const UnderlyingType& v) + : webrtc::StrongAlias<class VirtualTimeTag, double>(v) {} + + static constexpr VirtualTime Zero() { return VirtualTime(0); } + }; + class InverseWeight + : public webrtc::StrongAlias<class InverseWeightTag, double> { + public: + constexpr explicit InverseWeight(StreamPriority priority) + : webrtc::StrongAlias<class InverseWeightTag, double>( + 1.0 / std::max(static_cast<double>(*priority), 0.000001)) {} + }; + + public: + class StreamProducer { + public: + virtual ~StreamProducer() = default; + + // Produces a fragment of data to send. The current wall time is specified + // as `now` and should be used to skip chunks with expired limited lifetime. + // The parameter `max_size` specifies the maximum amount of actual payload + // that may be returned. If these constraints prevents the stream from + // sending some data, `absl::nullopt` should be returned. + virtual absl::optional<SendQueue::DataToSend> Produce(TimeMs now, + size_t max_size) = 0; + + // Returns the number of payload bytes that is scheduled to be sent in the + // next enqueued message, or zero if there are no enqueued messages or if + // the stream has been actively paused. + virtual size_t bytes_to_send_in_next_message() const = 0; + }; + + class Stream { + public: + StreamID stream_id() const { return stream_id_; } + + StreamPriority priority() const { return priority_; } + void SetPriority(StreamPriority priority); + + // Will activate the stream _if_ it has any data to send. That is, if the + // callback to `bytes_to_send_in_next_message` returns non-zero. If the + // callback returns zero, the stream will not be made active. + void MaybeMakeActive(); + + // Will remove the stream from the list of active streams, and will not try + // to produce data from it. To make it active again, call `MaybeMakeActive`. + void MakeInactive(); + + // Make the scheduler move to another message, or another stream. This is + // used to abort the scheduler from continuing producing fragments for the + // current message in case it's deleted. + void ForceReschedule() { parent_.ForceReschedule(); } + + private: + friend class StreamScheduler; + + Stream(StreamScheduler* parent, + StreamProducer* producer, + StreamID stream_id, + StreamPriority priority) + : parent_(*parent), + producer_(*producer), + stream_id_(stream_id), + priority_(priority), + inverse_weight_(priority) {} + + // Produces a message from this stream. This will only be called on streams + // that have data. + absl::optional<SendQueue::DataToSend> Produce(TimeMs now, size_t max_size); + + void MakeActive(size_t bytes_to_send_next); + void ForceMarkInactive(); + + VirtualTime current_time() const { return current_virtual_time_; } + VirtualTime next_finish_time() const { return next_finish_time_; } + size_t bytes_to_send_in_next_message() const { + return producer_.bytes_to_send_in_next_message(); + } + + VirtualTime CalculateFinishTime(size_t bytes_to_send_next) const; + + StreamScheduler& parent_; + StreamProducer& producer_; + const StreamID stream_id_; + StreamPriority priority_; + InverseWeight inverse_weight_; + // This outgoing stream's "current" virtual_time. + VirtualTime current_virtual_time_ = VirtualTime::Zero(); + VirtualTime next_finish_time_ = VirtualTime::Zero(); + }; + + // The `mtu` parameter represents the maximum SCTP packet size, which should + // be the same as `DcSctpOptions::mtu`. + explicit StreamScheduler(size_t mtu) + : max_payload_bytes_(mtu - SctpPacket::kHeaderSize - + IDataChunk::kHeaderSize) {} + + std::unique_ptr<Stream> CreateStream(StreamProducer* producer, + StreamID stream_id, + StreamPriority priority) { + return absl::WrapUnique(new Stream(this, producer, stream_id, priority)); + } + + void EnableMessageInterleaving(bool enabled) { + enable_message_interleaving_ = enabled; + } + + // Makes the scheduler stop producing message from the current stream and + // re-evaluates which stream to produce from. + void ForceReschedule() { currently_sending_a_message_ = false; } + + // Produces a fragment of data to send. The current wall time is specified as + // `now` and will be used to skip chunks with expired limited lifetime. The + // parameter `max_size` specifies the maximum amount of actual payload that + // may be returned. If no data can be produced, `absl::nullopt` is returned. + absl::optional<SendQueue::DataToSend> Produce(TimeMs now, size_t max_size); + + std::set<StreamID> ActiveStreamsForTesting() const; + + private: + struct ActiveStreamComparator { + // Ordered by virtual finish time (primary), stream-id (secondary). + bool operator()(Stream* a, Stream* b) const { + VirtualTime a_vft = a->next_finish_time(); + VirtualTime b_vft = b->next_finish_time(); + if (a_vft == b_vft) { + return a->stream_id() < b->stream_id(); + } + return a_vft < b_vft; + } + }; + + bool IsConsistent() const; + + const size_t max_payload_bytes_; + + // The current virtual time, as defined in the WFQ algorithm. + VirtualTime virtual_time_ = VirtualTime::Zero(); + + // The current stream to send chunks from. + Stream* current_stream_ = nullptr; + + bool enable_message_interleaving_ = false; + + // Indicates if the streams is currently sending a message, and should then + // - if message interleaving is not enabled - continue sending from this + // stream until that message has been sent in full. + bool currently_sending_a_message_ = false; + + // The currently active streams, ordered by virtual finish time. + webrtc::flat_set<Stream*, ActiveStreamComparator> active_streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_STREAM_SCHEDULER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler_test.cc b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler_test.cc new file mode 100644 index 0000000000..58f0bc4690 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler_test.cc @@ -0,0 +1,740 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/stream_scheduler.h" + +#include <vector> + +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/types.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::Return; +using ::testing::StrictMock; + +constexpr size_t kMtu = 1000; +constexpr size_t kPayloadSize = 4; + +MATCHER_P(HasDataWithMid, mid, "") { + if (!arg.has_value()) { + *result_listener << "There was no produced data"; + return false; + } + + if (arg->data.message_id != mid) { + *result_listener << "the produced data had mid " << *arg->data.message_id + << " and not the expected " << *mid; + return false; + } + + return true; +} + +std::function<absl::optional<SendQueue::DataToSend>(TimeMs, size_t)> +CreateChunk(StreamID sid, MID mid, size_t payload_size = kPayloadSize) { + return [sid, mid, payload_size](TimeMs now, size_t max_size) { + return SendQueue::DataToSend(Data( + sid, SSN(0), mid, FSN(0), PPID(42), std::vector<uint8_t>(payload_size), + Data::IsBeginning(true), Data::IsEnd(true), IsUnordered(true))); + }; +} + +std::map<StreamID, size_t> GetPacketCounts(StreamScheduler& scheduler, + size_t packets_to_generate) { + std::map<StreamID, size_t> packet_counts; + for (size_t i = 0; i < packets_to_generate; ++i) { + absl::optional<SendQueue::DataToSend> data = + scheduler.Produce(TimeMs(0), kMtu); + if (data.has_value()) { + ++packet_counts[data->data.stream_id]; + } + } + return packet_counts; +} + +class MockStreamProducer : public StreamScheduler::StreamProducer { + public: + MOCK_METHOD(absl::optional<SendQueue::DataToSend>, + Produce, + (TimeMs, size_t), + (override)); + MOCK_METHOD(size_t, bytes_to_send_in_next_message, (), (const, override)); +}; + +class TestStream { + public: + TestStream(StreamScheduler& scheduler, + StreamID stream_id, + StreamPriority priority, + size_t packet_size = kPayloadSize) { + EXPECT_CALL(producer_, Produce) + .WillRepeatedly(CreateChunk(stream_id, MID(0), packet_size)); + EXPECT_CALL(producer_, bytes_to_send_in_next_message) + .WillRepeatedly(Return(packet_size)); + stream_ = scheduler.CreateStream(&producer_, stream_id, priority); + stream_->MaybeMakeActive(); + } + + StreamScheduler::Stream& stream() { return *stream_; } + + private: + StrictMock<MockStreamProducer> producer_; + std::unique_ptr<StreamScheduler::Stream> stream_; +}; + +// A scheduler without active streams doesn't produce data. +TEST(StreamSchedulerTest, HasNoActiveStreams) { + StreamScheduler scheduler(kMtu); + + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Stream properties can be set and retrieved +TEST(StreamSchedulerTest, CanSetAndGetStreamProperties) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer; + auto stream = + scheduler.CreateStream(&producer, StreamID(1), StreamPriority(2)); + + EXPECT_EQ(stream->stream_id(), StreamID(1)); + EXPECT_EQ(stream->priority(), StreamPriority(2)); + + stream->SetPriority(StreamPriority(0)); + EXPECT_EQ(stream->priority(), StreamPriority(0)); +} + +// A scheduler with a single stream produced packets from it. +TEST(StreamSchedulerTest, CanProduceFromSingleStream) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer; + EXPECT_CALL(producer, Produce).WillOnce(CreateChunk(StreamID(1), MID(0))); + EXPECT_CALL(producer, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(0)); + auto stream = + scheduler.CreateStream(&producer, StreamID(1), StreamPriority(2)); + stream->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(0))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Switches between two streams after every packet. +TEST(StreamSchedulerTest, WillRoundRobinBetweenStreams) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200))) + .WillOnce(CreateChunk(StreamID(2), MID(201))) + .WillOnce(CreateChunk(StreamID(2), MID(202))); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(2)); + stream2->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Switches between two streams after every packet, but keeps producing from the +// same stream when a packet contains of multiple fragments. +TEST(StreamSchedulerTest, WillRoundRobinOnlyWhenFinishedProducingChunk) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce([](...) { + return SendQueue::DataToSend( + Data(StreamID(1), SSN(0), MID(101), FSN(0), PPID(42), + std::vector<uint8_t>(4), Data::IsBeginning(true), + Data::IsEnd(false), IsUnordered(true))); + }) + .WillOnce([](...) { + return SendQueue::DataToSend( + Data(StreamID(1), SSN(0), MID(101), FSN(0), PPID(42), + std::vector<uint8_t>(4), Data::IsBeginning(false), + Data::IsEnd(false), IsUnordered(true))); + }) + .WillOnce([](...) { + return SendQueue::DataToSend( + Data(StreamID(1), SSN(0), MID(101), FSN(0), PPID(42), + std::vector<uint8_t>(4), Data::IsBeginning(false), + Data::IsEnd(true), IsUnordered(true))); + }) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200))) + .WillOnce(CreateChunk(StreamID(2), MID(201))) + .WillOnce(CreateChunk(StreamID(2), MID(202))); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(2)); + stream2->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Deactivates a stream before it has finished producing all packets. +TEST(StreamSchedulerTest, StreamsCanBeMadeInactive) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)); // hints that there is a MID(2) coming. + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + + // ... but the stream is made inactive before it can be produced. + stream1->MakeInactive(); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Resumes a paused stream - makes a stream active after inactivating it. +TEST(StreamSchedulerTest, SingleStreamCanBeResumed) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + // Callbacks are setup so that they hint that there is a MID(2) coming... + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) // When making active again + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + + stream1->MakeInactive(); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); + stream1->MaybeMakeActive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Iterates between streams, where one is suddenly paused and later resumed. +TEST(StreamSchedulerTest, WillRoundRobinWithPausedStream) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200))) + .WillOnce(CreateChunk(StreamID(2), MID(201))) + .WillOnce(CreateChunk(StreamID(2), MID(202))); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(2)); + stream2->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + stream1->MakeInactive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + stream1->MaybeMakeActive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Verifies that packet counts are evenly distributed in round robin scheduling. +TEST(StreamSchedulerTest, WillDistributeRoundRobinPacketsEvenlyTwoStreams) { + StreamScheduler scheduler(kMtu); + TestStream stream1(scheduler, StreamID(1), StreamPriority(1)); + TestStream stream2(scheduler, StreamID(2), StreamPriority(1)); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 10); + EXPECT_EQ(packet_counts[StreamID(1)], 5U); + EXPECT_EQ(packet_counts[StreamID(2)], 5U); +} + +// Verifies that packet counts are evenly distributed among active streams, +// where a stream is suddenly made inactive, two are added, and then the paused +// stream is resumed. +TEST(StreamSchedulerTest, WillDistributeEvenlyWithPausedAndAddedStreams) { + StreamScheduler scheduler(kMtu); + TestStream stream1(scheduler, StreamID(1), StreamPriority(1)); + TestStream stream2(scheduler, StreamID(2), StreamPriority(1)); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 10); + EXPECT_EQ(packet_counts[StreamID(1)], 5U); + EXPECT_EQ(packet_counts[StreamID(2)], 5U); + + stream2.stream().MakeInactive(); + + TestStream stream3(scheduler, StreamID(3), StreamPriority(1)); + TestStream stream4(scheduler, StreamID(4), StreamPriority(1)); + + std::map<StreamID, size_t> counts2 = GetPacketCounts(scheduler, 15); + EXPECT_EQ(counts2[StreamID(1)], 5U); + EXPECT_EQ(counts2[StreamID(2)], 0U); + EXPECT_EQ(counts2[StreamID(3)], 5U); + EXPECT_EQ(counts2[StreamID(4)], 5U); + + stream2.stream().MaybeMakeActive(); + + std::map<StreamID, size_t> counts3 = GetPacketCounts(scheduler, 20); + EXPECT_EQ(counts3[StreamID(1)], 5U); + EXPECT_EQ(counts3[StreamID(2)], 5U); + EXPECT_EQ(counts3[StreamID(3)], 5U); + EXPECT_EQ(counts3[StreamID(4)], 5U); +} + +// Degrades to fair queuing with streams having identical priority. +TEST(StreamSchedulerTest, WillDoFairQueuingWithSamePriority) { + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + constexpr size_t kSmallPacket = 30; + constexpr size_t kLargePacket = 70; + + StrictMock<MockStreamProducer> callback1; + EXPECT_CALL(callback1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100), kSmallPacket)) + .WillOnce(CreateChunk(StreamID(1), MID(101), kSmallPacket)) + .WillOnce(CreateChunk(StreamID(1), MID(102), kSmallPacket)); + EXPECT_CALL(callback1, bytes_to_send_in_next_message) + .WillOnce(Return(kSmallPacket)) // When making active + .WillOnce(Return(kSmallPacket)) + .WillOnce(Return(kSmallPacket)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&callback1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> callback2; + EXPECT_CALL(callback2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200), kLargePacket)) + .WillOnce(CreateChunk(StreamID(2), MID(201), kLargePacket)) + .WillOnce(CreateChunk(StreamID(2), MID(202), kLargePacket)); + EXPECT_CALL(callback2, bytes_to_send_in_next_message) + .WillOnce(Return(kLargePacket)) // When making active + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&callback2, StreamID(2), StreamPriority(2)); + stream2->MaybeMakeActive(); + + // t = 30 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + // t = 60 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + // t = 70 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + // t = 90 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + // t = 140 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + // t = 210 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Will do weighted fair queuing with three streams having different priority. +TEST(StreamSchedulerTest, WillDoWeightedFairQueuingSameSizeDifferentPriority) { + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + StrictMock<MockStreamProducer> callback1; + EXPECT_CALL(callback1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(callback1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + // Priority 125 -> allowed to produce every 1000/125 ~= 80 time units. + auto stream1 = + scheduler.CreateStream(&callback1, StreamID(1), StreamPriority(125)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> callback2; + EXPECT_CALL(callback2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200))) + .WillOnce(CreateChunk(StreamID(2), MID(201))) + .WillOnce(CreateChunk(StreamID(2), MID(202))); + EXPECT_CALL(callback2, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + // Priority 200 -> allowed to produce every 1000/200 ~= 50 time units. + auto stream2 = + scheduler.CreateStream(&callback2, StreamID(2), StreamPriority(200)); + stream2->MaybeMakeActive(); + + StrictMock<MockStreamProducer> callback3; + EXPECT_CALL(callback3, Produce) + .WillOnce(CreateChunk(StreamID(3), MID(300))) + .WillOnce(CreateChunk(StreamID(3), MID(301))) + .WillOnce(CreateChunk(StreamID(3), MID(302))); + EXPECT_CALL(callback3, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + // Priority 500 -> allowed to produce every 1000/500 ~= 20 time units. + auto stream3 = + scheduler.CreateStream(&callback3, StreamID(3), StreamPriority(500)); + stream3->MaybeMakeActive(); + + // t ~= 20 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(300))); + // t ~= 40 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(301))); + // t ~= 50 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + // t ~= 60 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(302))); + // t ~= 80 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + // t ~= 100 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + // t ~= 150 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + // t ~= 160 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + // t ~= 240 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Will do weighted fair queuing with three streams having different priority +// and sending different payload sizes. +TEST(StreamSchedulerTest, WillDoWeightedFairQueuingDifferentSizeAndPriority) { + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + constexpr size_t kSmallPacket = 20; + constexpr size_t kMediumPacket = 50; + constexpr size_t kLargePacket = 70; + + // Stream with priority = 125 -> inverse weight ~=80 + StrictMock<MockStreamProducer> callback1; + EXPECT_CALL(callback1, Produce) + // virtual finish time ~ 0 + 50 * 80 = 4000 + .WillOnce(CreateChunk(StreamID(1), MID(100), kMediumPacket)) + // virtual finish time ~ 4000 + 20 * 80 = 5600 + .WillOnce(CreateChunk(StreamID(1), MID(101), kSmallPacket)) + // virtual finish time ~ 5600 + 70 * 80 = 11200 + .WillOnce(CreateChunk(StreamID(1), MID(102), kLargePacket)); + EXPECT_CALL(callback1, bytes_to_send_in_next_message) + .WillOnce(Return(kMediumPacket)) // When making active + .WillOnce(Return(kSmallPacket)) + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&callback1, StreamID(1), StreamPriority(125)); + stream1->MaybeMakeActive(); + + // Stream with priority = 200 -> inverse weight ~=50 + StrictMock<MockStreamProducer> callback2; + EXPECT_CALL(callback2, Produce) + // virtual finish time ~ 0 + 50 * 50 = 2500 + .WillOnce(CreateChunk(StreamID(2), MID(200), kMediumPacket)) + // virtual finish time ~ 2500 + 70 * 50 = 6000 + .WillOnce(CreateChunk(StreamID(2), MID(201), kLargePacket)) + // virtual finish time ~ 6000 + 20 * 50 = 7000 + .WillOnce(CreateChunk(StreamID(2), MID(202), kSmallPacket)); + EXPECT_CALL(callback2, bytes_to_send_in_next_message) + .WillOnce(Return(kMediumPacket)) // When making active + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(kSmallPacket)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&callback2, StreamID(2), StreamPriority(200)); + stream2->MaybeMakeActive(); + + // Stream with priority = 500 -> inverse weight ~=20 + StrictMock<MockStreamProducer> callback3; + EXPECT_CALL(callback3, Produce) + // virtual finish time ~ 0 + 20 * 20 = 400 + .WillOnce(CreateChunk(StreamID(3), MID(300), kSmallPacket)) + // virtual finish time ~ 400 + 50 * 20 = 1400 + .WillOnce(CreateChunk(StreamID(3), MID(301), kMediumPacket)) + // virtual finish time ~ 1400 + 70 * 20 = 2800 + .WillOnce(CreateChunk(StreamID(3), MID(302), kLargePacket)); + EXPECT_CALL(callback3, bytes_to_send_in_next_message) + .WillOnce(Return(kSmallPacket)) // When making active + .WillOnce(Return(kMediumPacket)) + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(0)); + auto stream3 = + scheduler.CreateStream(&callback3, StreamID(3), StreamPriority(500)); + stream3->MaybeMakeActive(); + + // t ~= 400 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(300))); + // t ~= 1400 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(301))); + // t ~= 2500 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + // t ~= 2800 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(302))); + // t ~= 4000 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + // t ~= 5600 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + // t ~= 6000 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + // t ~= 7000 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + // t ~= 11200 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} +TEST(StreamSchedulerTest, WillDistributeWFQPacketsInTwoStreamsByPriority) { + // A simple test with two streams of different priority, but sending packets + // of identical size. Verifies that the ratio of sent packets represent their + // priority. + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + TestStream stream1(scheduler, StreamID(1), StreamPriority(100), kPayloadSize); + TestStream stream2(scheduler, StreamID(2), StreamPriority(200), kPayloadSize); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 15); + EXPECT_EQ(packet_counts[StreamID(1)], 5U); + EXPECT_EQ(packet_counts[StreamID(2)], 10U); +} + +TEST(StreamSchedulerTest, WillDistributeWFQPacketsInFourStreamsByPriority) { + // Same as `WillDistributeWFQPacketsInTwoStreamsByPriority` but with more + // streams. + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + TestStream stream1(scheduler, StreamID(1), StreamPriority(100), kPayloadSize); + TestStream stream2(scheduler, StreamID(2), StreamPriority(200), kPayloadSize); + TestStream stream3(scheduler, StreamID(3), StreamPriority(300), kPayloadSize); + TestStream stream4(scheduler, StreamID(4), StreamPriority(400), kPayloadSize); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 50); + EXPECT_EQ(packet_counts[StreamID(1)], 5U); + EXPECT_EQ(packet_counts[StreamID(2)], 10U); + EXPECT_EQ(packet_counts[StreamID(3)], 15U); + EXPECT_EQ(packet_counts[StreamID(4)], 20U); +} + +TEST(StreamSchedulerTest, WillDistributeFromTwoStreamsFairly) { + // A simple test with two streams of different priority, but sending packets + // of different size. Verifies that the ratio of total packet payload + // represent their priority. + // In this example, + // * stream1 has priority 100 and sends packets of size 8 + // * stream2 has priority 400 and sends packets of size 4 + // With round robin, stream1 would get twice as many payload bytes on the wire + // as stream2, but with WFQ and a 4x priority increase, stream2 should 4x as + // many payload bytes on the wire. That translates to stream2 getting 8x as + // many packets on the wire as they are half as large. + StreamScheduler scheduler(kMtu); + // Enable WFQ scheduler. + scheduler.EnableMessageInterleaving(true); + + TestStream stream1(scheduler, StreamID(1), StreamPriority(100), + /*packet_size=*/8); + TestStream stream2(scheduler, StreamID(2), StreamPriority(400), + /*packet_size=*/4); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 90); + EXPECT_EQ(packet_counts[StreamID(1)], 10U); + EXPECT_EQ(packet_counts[StreamID(2)], 80U); +} + +TEST(StreamSchedulerTest, WillDistributeFromFourStreamsFairly) { + // Same as `WillDistributeWeightedFairFromTwoStreamsFairly` but more + // complicated. + StreamScheduler scheduler(kMtu); + // Enable WFQ scheduler. + scheduler.EnableMessageInterleaving(true); + + TestStream stream1(scheduler, StreamID(1), StreamPriority(100), + /*packet_size=*/10); + TestStream stream2(scheduler, StreamID(2), StreamPriority(200), + /*packet_size=*/10); + TestStream stream3(scheduler, StreamID(3), StreamPriority(200), + /*packet_size=*/20); + TestStream stream4(scheduler, StreamID(4), StreamPriority(400), + /*packet_size=*/30); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 80); + // 15 packets * 10 bytes = 150 bytes at priority 100. + EXPECT_EQ(packet_counts[StreamID(1)], 15U); + // 30 packets * 10 bytes = 300 bytes at priority 200. + EXPECT_EQ(packet_counts[StreamID(2)], 30U); + // 15 packets * 20 bytes = 300 bytes at priority 200. + EXPECT_EQ(packet_counts[StreamID(3)], 15U); + // 20 packets * 30 bytes = 600 bytes at priority 400. + EXPECT_EQ(packet_counts[StreamID(4)], 20U); +} + +// Sending large messages with small MTU will fragment the messages and produce +// a first fragment not larger than the MTU, and will then not first send from +// the stream with the smallest message, as their first fragment will be equally +// small for both streams. See `LargeMessageWithLargeMtu` for the same test, but +// with a larger MTU. +TEST(StreamSchedulerTest, SendLargeMessageWithSmallMtu) { + StreamScheduler scheduler(100 + SctpPacket::kHeaderSize + + IDataChunk::kHeaderSize); + scheduler.EnableMessageInterleaving(true); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(0), 100)) + .WillOnce(CreateChunk(StreamID(1), MID(0), 100)); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(200)) // When making active + .WillOnce(Return(100)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(1)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(1), 100)) + .WillOnce(CreateChunk(StreamID(2), MID(1), 50)); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(150)) // When making active + .WillOnce(Return(50)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(1)); + stream2->MaybeMakeActive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(0))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(1))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(1))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(0))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Sending large messages with large MTU will not fragment messages and will +// send the message first from the stream that has the smallest message. +TEST(StreamSchedulerTest, SendLargeMessageWithLargeMtu) { + StreamScheduler scheduler(200 + SctpPacket::kHeaderSize + + IDataChunk::kHeaderSize); + scheduler.EnableMessageInterleaving(true); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(0), 200)); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(200)) // When making active + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(1)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(1), 150)); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(150)) // When making active + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(1)); + stream2->MaybeMakeActive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(1))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(0))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +} // namespace +} // namespace dcsctp |