summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc')
-rw-r--r--third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc348
1 files changed, 348 insertions, 0 deletions
diff --git a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc
new file mode 100644
index 0000000000..dce6c90131
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc
@@ -0,0 +1,348 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/rx/traditional_reassembly_streams.h"
+
+#include <stddef.h>
+
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <map>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/types/optional.h"
+#include "api/array_view.h"
+#include "net/dcsctp/common/sequence_numbers.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
+#include "net/dcsctp/packet/data.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "rtc_base/logging.h"
+
+namespace dcsctp {
+namespace {
+
+// Given a map (`chunks`) and an iterator to within that map (`iter`), this
+// function will return an iterator to the first chunk in that message, which
+// has the `is_beginning` flag set. If there are any gaps, or if the beginning
+// can't be found, `absl::nullopt` is returned.
+absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindBeginning(
+ const std::map<UnwrappedTSN, Data>& chunks,
+ std::map<UnwrappedTSN, Data>::iterator iter) {
+ UnwrappedTSN prev_tsn = iter->first;
+ for (;;) {
+ if (iter->second.is_beginning) {
+ return iter;
+ }
+ if (iter == chunks.begin()) {
+ return absl::nullopt;
+ }
+ --iter;
+ if (iter->first.next_value() != prev_tsn) {
+ return absl::nullopt;
+ }
+ prev_tsn = iter->first;
+ }
+}
+
+// Given a map (`chunks`) and an iterator to within that map (`iter`), this
+// function will return an iterator to the chunk after the last chunk in that
+// message, which has the `is_end` flag set. If there are any gaps, or if the
+// end can't be found, `absl::nullopt` is returned.
+absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindEnd(
+ std::map<UnwrappedTSN, Data>& chunks,
+ std::map<UnwrappedTSN, Data>::iterator iter) {
+ UnwrappedTSN prev_tsn = iter->first;
+ for (;;) {
+ if (iter->second.is_end) {
+ return ++iter;
+ }
+ ++iter;
+ if (iter == chunks.end()) {
+ return absl::nullopt;
+ }
+ if (iter->first != prev_tsn.next_value()) {
+ return absl::nullopt;
+ }
+ prev_tsn = iter->first;
+ }
+}
+} // namespace
+
+TraditionalReassemblyStreams::TraditionalReassemblyStreams(
+ absl::string_view log_prefix,
+ OnAssembledMessage on_assembled_message)
+ : log_prefix_(log_prefix),
+ on_assembled_message_(std::move(on_assembled_message)) {}
+
+int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn,
+ Data data) {
+ int queued_bytes = data.size();
+ auto [it, inserted] = chunks_.emplace(tsn, std::move(data));
+ if (!inserted) {
+ return 0;
+ }
+
+ queued_bytes -= TryToAssembleMessage(it);
+
+ return queued_bytes;
+}
+
+size_t TraditionalReassemblyStreams::UnorderedStream::TryToAssembleMessage(
+ ChunkMap::iterator iter) {
+ // TODO(boivie): This method is O(N) with the number of fragments in a
+ // message, which can be inefficient for very large values of N. This could be
+ // optimized by e.g. only trying to assemble a message once _any_ beginning
+ // and _any_ end has been found.
+ absl::optional<ChunkMap::iterator> start = FindBeginning(chunks_, iter);
+ if (!start.has_value()) {
+ return 0;
+ }
+ absl::optional<ChunkMap::iterator> end = FindEnd(chunks_, iter);
+ if (!end.has_value()) {
+ return 0;
+ }
+
+ size_t bytes_assembled = AssembleMessage(*start, *end);
+ chunks_.erase(*start, *end);
+ return bytes_assembled;
+}
+
+size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage(
+ const ChunkMap::iterator start,
+ const ChunkMap::iterator end) {
+ size_t count = std::distance(start, end);
+
+ if (count == 1) {
+ // Fast path - zero-copy
+ const Data& data = start->second;
+ size_t payload_size = start->second.size();
+ UnwrappedTSN tsns[1] = {start->first};
+ DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload));
+ parent_.on_assembled_message_(tsns, std::move(message));
+ return payload_size;
+ }
+
+ // Slow path - will need to concatenate the payload.
+ std::vector<UnwrappedTSN> tsns;
+ std::vector<uint8_t> payload;
+
+ size_t payload_size = std::accumulate(
+ start, end, 0,
+ [](size_t v, const auto& p) { return v + p.second.size(); });
+
+ tsns.reserve(count);
+ payload.reserve(payload_size);
+ for (auto it = start; it != end; ++it) {
+ const Data& data = it->second;
+ tsns.push_back(it->first);
+ payload.insert(payload.end(), data.payload.begin(), data.payload.end());
+ }
+
+ DcSctpMessage message(start->second.stream_id, start->second.ppid,
+ std::move(payload));
+ parent_.on_assembled_message_(tsns, std::move(message));
+
+ return payload_size;
+}
+
+size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo(
+ UnwrappedTSN tsn) {
+ auto end_iter = chunks_.upper_bound(tsn);
+ size_t removed_bytes = std::accumulate(
+ chunks_.begin(), end_iter, 0,
+ [](size_t r, const auto& p) { return r + p.second.size(); });
+
+ chunks_.erase(chunks_.begin(), end_iter);
+ return removed_bytes;
+}
+
+size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessage() {
+ if (chunks_by_ssn_.empty() || chunks_by_ssn_.begin()->first != next_ssn_) {
+ return 0;
+ }
+
+ ChunkMap& chunks = chunks_by_ssn_.begin()->second;
+
+ if (!chunks.begin()->second.is_beginning || !chunks.rbegin()->second.is_end) {
+ return 0;
+ }
+
+ uint32_t tsn_diff =
+ UnwrappedTSN::Difference(chunks.rbegin()->first, chunks.begin()->first);
+ if (tsn_diff != chunks.size() - 1) {
+ return 0;
+ }
+
+ size_t assembled_bytes = AssembleMessage(chunks.begin(), chunks.end());
+ chunks_by_ssn_.erase(chunks_by_ssn_.begin());
+ next_ssn_.Increment();
+ return assembled_bytes;
+}
+
+size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() {
+ size_t assembled_bytes = 0;
+
+ for (;;) {
+ size_t assembled_bytes_this_iter = TryToAssembleMessage();
+ if (assembled_bytes_this_iter == 0) {
+ break;
+ }
+ assembled_bytes += assembled_bytes_this_iter;
+ }
+ return assembled_bytes;
+}
+
+int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn,
+ Data data) {
+ int queued_bytes = data.size();
+
+ UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn);
+ auto [unused, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data));
+ if (!inserted) {
+ return 0;
+ }
+
+ if (ssn == next_ssn_) {
+ queued_bytes -= TryToAssembleMessages();
+ }
+
+ return queued_bytes;
+}
+
+size_t TraditionalReassemblyStreams::OrderedStream::EraseTo(SSN ssn) {
+ UnwrappedSSN unwrapped_ssn = ssn_unwrapper_.Unwrap(ssn);
+
+ auto end_iter = chunks_by_ssn_.upper_bound(unwrapped_ssn);
+ size_t removed_bytes = std::accumulate(
+ chunks_by_ssn_.begin(), end_iter, 0, [](size_t r1, const auto& p) {
+ return r1 +
+ absl::c_accumulate(p.second, 0, [](size_t r2, const auto& q) {
+ return r2 + q.second.size();
+ });
+ });
+ chunks_by_ssn_.erase(chunks_by_ssn_.begin(), end_iter);
+
+ if (unwrapped_ssn >= next_ssn_) {
+ unwrapped_ssn.Increment();
+ next_ssn_ = unwrapped_ssn;
+ }
+
+ removed_bytes += TryToAssembleMessages();
+ return removed_bytes;
+}
+
+int TraditionalReassemblyStreams::Add(UnwrappedTSN tsn, Data data) {
+ if (data.is_unordered) {
+ auto it = unordered_streams_.try_emplace(data.stream_id, this).first;
+ return it->second.Add(tsn, std::move(data));
+ }
+
+ auto it = ordered_streams_.try_emplace(data.stream_id, this).first;
+ return it->second.Add(tsn, std::move(data));
+}
+
+size_t TraditionalReassemblyStreams::HandleForwardTsn(
+ UnwrappedTSN new_cumulative_ack_tsn,
+ rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) {
+ size_t bytes_removed = 0;
+ // The `skipped_streams` only cover ordered messages - need to
+ // iterate all unordered streams manually to remove those chunks.
+ for (auto& [unused, stream] : unordered_streams_) {
+ bytes_removed += stream.EraseTo(new_cumulative_ack_tsn);
+ }
+
+ for (const auto& skipped_stream : skipped_streams) {
+ auto it =
+ ordered_streams_.try_emplace(skipped_stream.stream_id, this).first;
+ bytes_removed += it->second.EraseTo(skipped_stream.ssn);
+ }
+
+ return bytes_removed;
+}
+
+void TraditionalReassemblyStreams::ResetStreams(
+ rtc::ArrayView<const StreamID> stream_ids) {
+ if (stream_ids.empty()) {
+ for (auto& [stream_id, stream] : ordered_streams_) {
+ RTC_DLOG(LS_VERBOSE) << log_prefix_
+ << "Resetting implicit stream_id=" << *stream_id;
+ stream.Reset();
+ }
+ } else {
+ for (StreamID stream_id : stream_ids) {
+ auto it = ordered_streams_.find(stream_id);
+ if (it != ordered_streams_.end()) {
+ RTC_DLOG(LS_VERBOSE)
+ << log_prefix_ << "Resetting explicit stream_id=" << *stream_id;
+ it->second.Reset();
+ }
+ }
+ }
+}
+
+HandoverReadinessStatus TraditionalReassemblyStreams::GetHandoverReadiness()
+ const {
+ HandoverReadinessStatus status;
+ for (const auto& [unused, stream] : ordered_streams_) {
+ if (stream.has_unassembled_chunks()) {
+ status.Add(HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks);
+ break;
+ }
+ }
+ for (const auto& [unused, stream] : unordered_streams_) {
+ if (stream.has_unassembled_chunks()) {
+ status.Add(
+ HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks);
+ break;
+ }
+ }
+ return status;
+}
+
+void TraditionalReassemblyStreams::AddHandoverState(
+ DcSctpSocketHandoverState& state) {
+ for (const auto& [stream_id, stream] : ordered_streams_) {
+ DcSctpSocketHandoverState::OrderedStream state_stream;
+ state_stream.id = stream_id.value();
+ state_stream.next_ssn = stream.next_ssn().value();
+ state.rx.ordered_streams.push_back(std::move(state_stream));
+ }
+ for (const auto& [stream_id, unused] : unordered_streams_) {
+ DcSctpSocketHandoverState::UnorderedStream state_stream;
+ state_stream.id = stream_id.value();
+ state.rx.unordered_streams.push_back(std::move(state_stream));
+ }
+}
+
+void TraditionalReassemblyStreams::RestoreFromState(
+ const DcSctpSocketHandoverState& state) {
+ // Validate that the component is in pristine state.
+ RTC_DCHECK(ordered_streams_.empty());
+ RTC_DCHECK(unordered_streams_.empty());
+
+ for (const DcSctpSocketHandoverState::OrderedStream& state_stream :
+ state.rx.ordered_streams) {
+ ordered_streams_.emplace(
+ std::piecewise_construct,
+ std::forward_as_tuple(StreamID(state_stream.id)),
+ std::forward_as_tuple(this, SSN(state_stream.next_ssn)));
+ }
+ for (const DcSctpSocketHandoverState::UnorderedStream& state_stream :
+ state.rx.unordered_streams) {
+ unordered_streams_.emplace(std::piecewise_construct,
+ std::forward_as_tuple(StreamID(state_stream.id)),
+ std::forward_as_tuple(this));
+ }
+}
+
+} // namespace dcsctp