/*
 *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS. All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
#include "net/dcsctp/rx/traditional_reassembly_streams.h"

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "net/dcsctp/common/sequence_numbers.h"
#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
#include "net/dcsctp/packet/data.h"
#include "net/dcsctp/public/dcsctp_message.h"
#include "rtc_base/logging.h"

namespace dcsctp {
namespace {

// Given a map (`chunks`) and an iterator to within that map (`iter`), this
// function will return an iterator to the first chunk in that message, which
// has the `is_beginning` flag set. If there are any gaps, or if the beginning
// can't be found, `absl::nullopt` is returned.
absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindBeginning(
    const std::map<UnwrappedTSN, Data>& chunks,
    std::map<UnwrappedTSN, Data>::iterator iter) {
  UnwrappedTSN prev_tsn = iter->first;
  for (;;) {
    if (iter->second.is_beginning) {
      return iter;
    }
    if (iter == chunks.begin()) {
      // Walked off the front of the map without seeing a beginning fragment.
      return absl::nullopt;
    }
    --iter;
    // The map is ordered by TSN; a non-consecutive predecessor means a
    // fragment of this message is still missing.
    if (iter->first.next_value() != prev_tsn) {
      return absl::nullopt;
    }
    prev_tsn = iter->first;
  }
}

// Given a map (`chunks`) and an iterator to within that map (`iter`), this
// function will return an iterator to the chunk after the last chunk in that
// message, which has the `is_end` flag set. If there are any gaps, or if the
// end can't be found, `absl::nullopt` is returned.
absl::optional::iterator> FindEnd( std::map& chunks, std::map::iterator iter) { UnwrappedTSN prev_tsn = iter->first; for (;;) { if (iter->second.is_end) { return ++iter; } ++iter; if (iter == chunks.end()) { return absl::nullopt; } if (iter->first != prev_tsn.next_value()) { return absl::nullopt; } prev_tsn = iter->first; } } } // namespace TraditionalReassemblyStreams::TraditionalReassemblyStreams( absl::string_view log_prefix, OnAssembledMessage on_assembled_message) : log_prefix_(log_prefix), on_assembled_message_(std::move(on_assembled_message)) {} int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn, Data data) { int queued_bytes = data.size(); auto [it, inserted] = chunks_.emplace(tsn, std::move(data)); if (!inserted) { return 0; } queued_bytes -= TryToAssembleMessage(it); return queued_bytes; } size_t TraditionalReassemblyStreams::UnorderedStream::TryToAssembleMessage( ChunkMap::iterator iter) { // TODO(boivie): This method is O(N) with the number of fragments in a // message, which can be inefficient for very large values of N. This could be // optimized by e.g. only trying to assemble a message once _any_ beginning // and _any_ end has been found. 
absl::optional start = FindBeginning(chunks_, iter); if (!start.has_value()) { return 0; } absl::optional end = FindEnd(chunks_, iter); if (!end.has_value()) { return 0; } size_t bytes_assembled = AssembleMessage(*start, *end); chunks_.erase(*start, *end); return bytes_assembled; } size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage( const ChunkMap::iterator start, const ChunkMap::iterator end) { size_t count = std::distance(start, end); if (count == 1) { // Fast path - zero-copy const Data& data = start->second; size_t payload_size = start->second.size(); UnwrappedTSN tsns[1] = {start->first}; DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload)); parent_.on_assembled_message_(tsns, std::move(message)); return payload_size; } // Slow path - will need to concatenate the payload. std::vector tsns; std::vector payload; size_t payload_size = std::accumulate( start, end, 0, [](size_t v, const auto& p) { return v + p.second.size(); }); tsns.reserve(count); payload.reserve(payload_size); for (auto it = start; it != end; ++it) { const Data& data = it->second; tsns.push_back(it->first); payload.insert(payload.end(), data.payload.begin(), data.payload.end()); } DcSctpMessage message(start->second.stream_id, start->second.ppid, std::move(payload)); parent_.on_assembled_message_(tsns, std::move(message)); return payload_size; } size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo( UnwrappedTSN tsn) { auto end_iter = chunks_.upper_bound(tsn); size_t removed_bytes = std::accumulate( chunks_.begin(), end_iter, 0, [](size_t r, const auto& p) { return r + p.second.size(); }); chunks_.erase(chunks_.begin(), end_iter); return removed_bytes; } size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessage() { if (chunks_by_ssn_.empty() || chunks_by_ssn_.begin()->first != next_ssn_) { return 0; } ChunkMap& chunks = chunks_by_ssn_.begin()->second; if (!chunks.begin()->second.is_beginning || !chunks.rbegin()->second.is_end) { return 0; 
} uint32_t tsn_diff = UnwrappedTSN::Difference(chunks.rbegin()->first, chunks.begin()->first); if (tsn_diff != chunks.size() - 1) { return 0; } size_t assembled_bytes = AssembleMessage(chunks.begin(), chunks.end()); chunks_by_ssn_.erase(chunks_by_ssn_.begin()); next_ssn_.Increment(); return assembled_bytes; } size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() { size_t assembled_bytes = 0; for (;;) { size_t assembled_bytes_this_iter = TryToAssembleMessage(); if (assembled_bytes_this_iter == 0) { break; } assembled_bytes += assembled_bytes_this_iter; } return assembled_bytes; } int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn, Data data) { int queued_bytes = data.size(); UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn); auto [unused, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data)); if (!inserted) { return 0; } if (ssn == next_ssn_) { queued_bytes -= TryToAssembleMessages(); } return queued_bytes; } size_t TraditionalReassemblyStreams::OrderedStream::EraseTo(SSN ssn) { UnwrappedSSN unwrapped_ssn = ssn_unwrapper_.Unwrap(ssn); auto end_iter = chunks_by_ssn_.upper_bound(unwrapped_ssn); size_t removed_bytes = std::accumulate( chunks_by_ssn_.begin(), end_iter, 0, [](size_t r1, const auto& p) { return r1 + absl::c_accumulate(p.second, 0, [](size_t r2, const auto& q) { return r2 + q.second.size(); }); }); chunks_by_ssn_.erase(chunks_by_ssn_.begin(), end_iter); if (unwrapped_ssn >= next_ssn_) { unwrapped_ssn.Increment(); next_ssn_ = unwrapped_ssn; } removed_bytes += TryToAssembleMessages(); return removed_bytes; } int TraditionalReassemblyStreams::Add(UnwrappedTSN tsn, Data data) { if (data.is_unordered) { auto it = unordered_streams_.try_emplace(data.stream_id, this).first; return it->second.Add(tsn, std::move(data)); } auto it = ordered_streams_.try_emplace(data.stream_id, this).first; return it->second.Add(tsn, std::move(data)); } size_t TraditionalReassemblyStreams::HandleForwardTsn( UnwrappedTSN 
new_cumulative_ack_tsn, rtc::ArrayView skipped_streams) { size_t bytes_removed = 0; // The `skipped_streams` only cover ordered messages - need to // iterate all unordered streams manually to remove those chunks. for (auto& [unused, stream] : unordered_streams_) { bytes_removed += stream.EraseTo(new_cumulative_ack_tsn); } for (const auto& skipped_stream : skipped_streams) { auto it = ordered_streams_.try_emplace(skipped_stream.stream_id, this).first; bytes_removed += it->second.EraseTo(skipped_stream.ssn); } return bytes_removed; } void TraditionalReassemblyStreams::ResetStreams( rtc::ArrayView stream_ids) { if (stream_ids.empty()) { for (auto& [stream_id, stream] : ordered_streams_) { RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Resetting implicit stream_id=" << *stream_id; stream.Reset(); } } else { for (StreamID stream_id : stream_ids) { auto it = ordered_streams_.find(stream_id); if (it != ordered_streams_.end()) { RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Resetting explicit stream_id=" << *stream_id; it->second.Reset(); } } } } HandoverReadinessStatus TraditionalReassemblyStreams::GetHandoverReadiness() const { HandoverReadinessStatus status; for (const auto& [unused, stream] : ordered_streams_) { if (stream.has_unassembled_chunks()) { status.Add(HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks); break; } } for (const auto& [unused, stream] : unordered_streams_) { if (stream.has_unassembled_chunks()) { status.Add( HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks); break; } } return status; } void TraditionalReassemblyStreams::AddHandoverState( DcSctpSocketHandoverState& state) { for (const auto& [stream_id, stream] : ordered_streams_) { DcSctpSocketHandoverState::OrderedStream state_stream; state_stream.id = stream_id.value(); state_stream.next_ssn = stream.next_ssn().value(); state.rx.ordered_streams.push_back(std::move(state_stream)); } for (const auto& [stream_id, unused] : unordered_streams_) { 
DcSctpSocketHandoverState::UnorderedStream state_stream; state_stream.id = stream_id.value(); state.rx.unordered_streams.push_back(std::move(state_stream)); } } void TraditionalReassemblyStreams::RestoreFromState( const DcSctpSocketHandoverState& state) { // Validate that the component is in pristine state. RTC_DCHECK(ordered_streams_.empty()); RTC_DCHECK(unordered_streams_.empty()); for (const DcSctpSocketHandoverState::OrderedStream& state_stream : state.rx.ordered_streams) { ordered_streams_.emplace( std::piecewise_construct, std::forward_as_tuple(StreamID(state_stream.id)), std::forward_as_tuple(this, SSN(state_stream.next_ssn))); } for (const DcSctpSocketHandoverState::UnorderedStream& state_stream : state.rx.unordered_streams) { unordered_streams_.emplace(std::piecewise_construct, std::forward_as_tuple(StreamID(state_stream.id)), std::forward_as_tuple(this)); } } } // namespace dcsctp