summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/net/dcsctp
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-15 03:35:49 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-15 03:35:49 +0000
commitd8bbc7858622b6d9c278469aab701ca0b609cddf (patch)
treeeff41dc61d9f714852212739e6b3738b82a2af87 /third_party/libwebrtc/net/dcsctp
parentReleasing progress-linux version 125.0.3-1~progress7.99u1. (diff)
downloadfirefox-d8bbc7858622b6d9c278469aab701ca0b609cddf.tar.xz
firefox-d8bbc7858622b6d9c278469aab701ca0b609cddf.zip
Merging upstream version 126.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/libwebrtc/net/dcsctp')
-rw-r--r--third_party/libwebrtc/net/dcsctp/public/dcsctp_socket.h11
-rw-r--r--third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket.h8
-rw-r--r--third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc57
-rw-r--r--third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h6
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/BUILD.gn2
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc119
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h17
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc130
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h10
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc172
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc52
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/state_cookie.h28
-rw-r--r--third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc14
-rw-r--r--third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc5
14 files changed, 456 insertions, 175 deletions
diff --git a/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket.h b/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket.h
index d0a81eaeb2..9989ae8d43 100644
--- a/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket.h
+++ b/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket.h
@@ -13,6 +13,7 @@
#include <cstdint>
#include <memory>
#include <utility>
+#include <vector>
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
@@ -577,6 +578,16 @@ class DcSctpSocketInterface {
virtual SendStatus Send(DcSctpMessage message,
const SendOptions& send_options) = 0;
+ // Sends the messages `messages` using the provided send options.
+ // Sending a message is an asynchronous operation, and the `OnError` callback
+ // may be invoked to indicate any errors in sending the message.
+ //
+ // This has identical semantics to Send, except that it may coalesce many
+ // messages into a single SCTP packet if they would fit.
+ virtual std::vector<SendStatus> SendMany(
+ rtc::ArrayView<DcSctpMessage> messages,
+ const SendOptions& send_options) = 0;
+
// Resetting streams is an asynchronous operation and the results will
// be notified using `DcSctpSocketCallbacks::OnStreamsResetDone()` on success
// and `DcSctpSocketCallbacks::OnStreamsResetFailed()` on failure. Note that
diff --git a/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket.h b/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket.h
index 0fd572bd94..c71c3ae16f 100644
--- a/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket.h
+++ b/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket.h
@@ -10,6 +10,8 @@
#ifndef NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_H_
#define NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_H_
+#include <vector>
+
#include "net/dcsctp/public/dcsctp_socket.h"
#include "test/gmock.h"
@@ -56,6 +58,12 @@ class MockDcSctpSocket : public DcSctpSocketInterface {
(DcSctpMessage message, const SendOptions& send_options),
(override));
+ MOCK_METHOD(std::vector<SendStatus>,
+ SendMany,
+ (rtc::ArrayView<DcSctpMessage> messages,
+ const SendOptions& send_options),
+ (override));
+
MOCK_METHOD(ResetStreamsStatus,
ResetStreams,
(rtc::ArrayView<const StreamID> outgoing_streams),
diff --git a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc
index dce6c90131..c94691f0db 100644
--- a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc
+++ b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc
@@ -86,6 +86,11 @@ TraditionalReassemblyStreams::TraditionalReassemblyStreams(
int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn,
Data data) {
+ if (data.is_beginning && data.is_end) {
+ // Fastpath for already assembled chunks.
+ AssembleMessage(tsn, std::move(data));
+ return 0;
+ }
int queued_bytes = data.size();
auto [it, inserted] = chunks_.emplace(tsn, std::move(data));
if (!inserted) {
@@ -124,12 +129,7 @@ size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage(
if (count == 1) {
// Fast path - zero-copy
- const Data& data = start->second;
- size_t payload_size = start->second.size();
- UnwrappedTSN tsns[1] = {start->first};
- DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload));
- parent_.on_assembled_message_(tsns, std::move(message));
- return payload_size;
+ return AssembleMessage(start->first, std::move(start->second));
}
// Slow path - will need to concatenate the payload.
@@ -155,6 +155,17 @@ size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage(
return payload_size;
}
+size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage(
+ UnwrappedTSN tsn,
+ Data data) {
+ // Fast path - zero-copy
+ size_t payload_size = data.size();
+ UnwrappedTSN tsns[1] = {tsn};
+ DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload));
+ parent_.on_assembled_message_(tsns, std::move(message));
+ return payload_size;
+}
+
size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo(
UnwrappedTSN tsn) {
auto end_iter = chunks_.upper_bound(tsn);
@@ -202,20 +213,40 @@ size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() {
return assembled_bytes;
}
+size_t
+TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessagesFastpath(
+ UnwrappedSSN ssn,
+ UnwrappedTSN tsn,
+ Data data) {
+ RTC_DCHECK(ssn == next_ssn_);
+ size_t assembled_bytes = 0;
+ if (data.is_beginning && data.is_end) {
+ assembled_bytes += AssembleMessage(tsn, std::move(data));
+ next_ssn_.Increment();
+ } else {
+ size_t queued_bytes = data.size();
+ auto [iter, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data));
+ if (!inserted) {
+ // Not actually assembled, but deduplicated meaning queued size doesn't
+ // include this message.
+ return queued_bytes;
+ }
+ }
+ return assembled_bytes + TryToAssembleMessages();
+}
+
int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn,
Data data) {
int queued_bytes = data.size();
-
UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn);
- auto [unused, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data));
+ if (ssn == next_ssn_) {
+ return queued_bytes -
+ TryToAssembleMessagesFastpath(ssn, tsn, std::move(data));
+ }
+ auto [iter, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data));
if (!inserted) {
return 0;
}
-
- if (ssn == next_ssn_) {
- queued_bytes -= TryToAssembleMessages();
- }
-
return queued_bytes;
}
diff --git a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h
index d355c599ae..9214a9bc9a 100644
--- a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h
+++ b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h
@@ -55,6 +55,7 @@ class TraditionalReassemblyStreams : public ReassemblyStreams {
: parent_(*parent) {}
size_t AssembleMessage(ChunkMap::iterator start, ChunkMap::iterator end);
+ size_t AssembleMessage(UnwrappedTSN tsn, Data data);
TraditionalReassemblyStreams& parent_;
};
@@ -101,6 +102,11 @@ class TraditionalReassemblyStreams : public ReassemblyStreams {
// Returns the number of bytes assembled if a message was assembled.
size_t TryToAssembleMessage();
size_t TryToAssembleMessages();
+ // Same as above but when inserting the first complete message avoid
+ // insertion into the map.
+ size_t TryToAssembleMessagesFastpath(UnwrappedSSN ssn,
+ UnwrappedTSN tsn,
+ Data data);
// This must be an ordered container to be able to iterate in SSN order.
std::map<UnwrappedSSN, ChunkMap> chunks_by_ssn_;
UnwrappedSSN::Unwrapper ssn_unwrapper_;
diff --git a/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn
index 04f61e5b72..406593e23b 100644
--- a/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn
+++ b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn
@@ -140,7 +140,6 @@ rtc_library("dcsctp_socket") {
"../../../api:make_ref_counted",
"../../../api:refcountedbase",
"../../../api:scoped_refptr",
- "../../../api:sequence_checker",
"../../../api/task_queue:task_queue",
"../../../rtc_base:checks",
"../../../rtc_base:logging",
@@ -178,6 +177,7 @@ rtc_library("dcsctp_socket") {
"//third_party/abseil-cpp/absl/memory",
"//third_party/abseil-cpp/absl/strings",
"//third_party/abseil-cpp/absl/types:optional",
+ "//third_party/abseil-cpp/absl/types:variant",
]
}
diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc
index 0a24020167..549a592b8d 100644
--- a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc
+++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc
@@ -12,31 +12,6 @@
#include "api/make_ref_counted.h"
namespace dcsctp {
-namespace {
-// A wrapper around the move-only DcSctpMessage, to let it be captured in a
-// lambda.
-class MessageDeliverer {
- public:
- explicit MessageDeliverer(DcSctpMessage&& message)
- : state_(rtc::make_ref_counted<State>(std::move(message))) {}
-
- void Deliver(DcSctpSocketCallbacks& c) {
- // Really ensure that it's only called once.
- RTC_DCHECK(!state_->has_delivered);
- state_->has_delivered = true;
- c.OnMessageReceived(std::move(state_->message));
- }
-
- private:
- struct State : public webrtc::RefCountInterface {
- explicit State(DcSctpMessage&& m)
- : has_delivered(false), message(std::move(m)) {}
- bool has_delivered;
- DcSctpMessage message;
- };
- rtc::scoped_refptr<State> state_;
-};
-} // namespace
void CallbackDeferrer::Prepare() {
RTC_DCHECK(!prepared_);
@@ -48,12 +23,16 @@ void CallbackDeferrer::TriggerDeferred() {
// callback, and that might result in adding new callbacks to this instance,
// and the vector can't be modified while iterated on.
RTC_DCHECK(prepared_);
- std::vector<std::function<void(DcSctpSocketCallbacks & cb)>> deferred;
- deferred.swap(deferred_);
prepared_ = false;
-
- for (auto& cb : deferred) {
- cb(underlying_);
+ if (deferred_.empty()) {
+ return;
+ }
+ std::vector<std::pair<Callback, CallbackData>> deferred;
+ // Reserve a small buffer to prevent too much reallocation on growth.
+ deferred.reserve(8);
+ deferred.swap(deferred_);
+ for (auto& [cb, data] : deferred) {
+ cb(std::move(data), underlying_);
}
}
@@ -84,40 +63,57 @@ uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) {
void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
- [deliverer = MessageDeliverer(std::move(message))](
- DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); });
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ return cb.OnMessageReceived(absl::get<DcSctpMessage>(std::move(data)));
+ },
+ std::move(message));
}
void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
- [error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
- cb.OnError(error, message);
- });
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ Error error = absl::get<Error>(std::move(data));
+ return cb.OnError(error.error, error.message);
+ },
+ Error{error, std::string(message)});
}
void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
- [error, message = std::string(message)](DcSctpSocketCallbacks& cb) {
- cb.OnAborted(error, message);
- });
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ Error error = absl::get<Error>(std::move(data));
+ return cb.OnAborted(error.error, error.message);
+ },
+ Error{error, std::string(message)});
}
void CallbackDeferrer::OnConnected() {
RTC_DCHECK(prepared_);
- deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); });
+ deferred_.emplace_back(
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ return cb.OnConnected();
+ },
+ absl::monostate{});
}
void CallbackDeferrer::OnClosed() {
RTC_DCHECK(prepared_);
- deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); });
+ deferred_.emplace_back(
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ return cb.OnClosed();
+ },
+ absl::monostate{});
}
void CallbackDeferrer::OnConnectionRestarted() {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
- [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); });
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ return cb.OnConnectionRestarted();
+ },
+ absl::monostate{});
}
void CallbackDeferrer::OnStreamsResetFailed(
@@ -125,42 +121,53 @@ void CallbackDeferrer::OnStreamsResetFailed(
absl::string_view reason) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
- [streams = std::vector<StreamID>(outgoing_streams.begin(),
- outgoing_streams.end()),
- reason = std::string(reason)](DcSctpSocketCallbacks& cb) {
- cb.OnStreamsResetFailed(streams, reason);
- });
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ StreamReset stream_reset = absl::get<StreamReset>(std::move(data));
+ return cb.OnStreamsResetFailed(stream_reset.streams,
+ stream_reset.message);
+ },
+ StreamReset{{outgoing_streams.begin(), outgoing_streams.end()},
+ std::string(reason)});
}
void CallbackDeferrer::OnStreamsResetPerformed(
rtc::ArrayView<const StreamID> outgoing_streams) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
- [streams = std::vector<StreamID>(outgoing_streams.begin(),
- outgoing_streams.end())](
- DcSctpSocketCallbacks& cb) { cb.OnStreamsResetPerformed(streams); });
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ StreamReset stream_reset = absl::get<StreamReset>(std::move(data));
+ return cb.OnStreamsResetPerformed(stream_reset.streams);
+ },
+ StreamReset{{outgoing_streams.begin(), outgoing_streams.end()}});
}
void CallbackDeferrer::OnIncomingStreamsReset(
rtc::ArrayView<const StreamID> incoming_streams) {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
- [streams = std::vector<StreamID>(incoming_streams.begin(),
- incoming_streams.end())](
- DcSctpSocketCallbacks& cb) { cb.OnIncomingStreamsReset(streams); });
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ StreamReset stream_reset = absl::get<StreamReset>(std::move(data));
+ return cb.OnIncomingStreamsReset(stream_reset.streams);
+ },
+ StreamReset{{incoming_streams.begin(), incoming_streams.end()}});
}
void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) {
RTC_DCHECK(prepared_);
- deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) {
- cb.OnBufferedAmountLow(stream_id);
- });
+ deferred_.emplace_back(
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ return cb.OnBufferedAmountLow(absl::get<StreamID>(std::move(data)));
+ },
+ stream_id);
}
void CallbackDeferrer::OnTotalBufferedAmountLow() {
RTC_DCHECK(prepared_);
deferred_.emplace_back(
- [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); });
+ +[](CallbackData data, DcSctpSocketCallbacks& cb) {
+ return cb.OnTotalBufferedAmountLow();
+ },
+ absl::monostate{});
}
void CallbackDeferrer::OnLifecycleMessageExpired(LifecycleId lifecycle_id,
diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h
index 6659e87155..9d9fbcef06 100644
--- a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h
+++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h
@@ -18,6 +18,7 @@
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/variant.h"
#include "api/array_view.h"
#include "api/ref_counted_base.h"
#include "api/scoped_refptr.h"
@@ -89,12 +90,26 @@ class CallbackDeferrer : public DcSctpSocketCallbacks {
void OnLifecycleEnd(LifecycleId lifecycle_id) override;
private:
+ struct Error {
+ ErrorKind error;
+ std::string message;
+ };
+ struct StreamReset {
+ std::vector<StreamID> streams;
+ std::string message;
+ };
+ // Use a pre-sized variant for storage to avoid double heap allocation. This
+ // variant can hold all cases of stored data.
+ using CallbackData = absl::
+ variant<absl::monostate, DcSctpMessage, Error, StreamReset, StreamID>;
+ using Callback = void (*)(CallbackData, DcSctpSocketCallbacks&);
+
void Prepare();
void TriggerDeferred();
DcSctpSocketCallbacks& underlying_;
bool prepared_ = false;
- std::vector<std::function<void(DcSctpSocketCallbacks& cb)>> deferred_;
+ std::vector<std::pair<Callback, CallbackData>> deferred_;
};
} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc
index f0f9590943..98cd34a111 100644
--- a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc
+++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc
@@ -296,20 +296,14 @@ void DcSctpSocket::SendInit() {
packet_sender_.Send(b, /*write_checksum=*/true);
}
-void DcSctpSocket::MakeConnectionParameters() {
- VerificationTag new_verification_tag(
- callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag));
- TSN initial_tsn(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn));
- connect_params_.initial_tsn = initial_tsn;
- connect_params_.verification_tag = new_verification_tag;
-}
-
void DcSctpSocket::Connect() {
- RTC_DCHECK_RUN_ON(&thread_checker_);
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (state_ == State::kClosed) {
- MakeConnectionParameters();
+ connect_params_.initial_tsn =
+ TSN(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn));
+ connect_params_.verification_tag = VerificationTag(
+ callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag));
RTC_DLOG(LS_INFO)
<< log_prefix()
<< rtc::StringFormat(
@@ -348,7 +342,6 @@ void DcSctpSocket::CreateTransmissionControlBlock(
}
void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
- RTC_DCHECK_RUN_ON(&thread_checker_);
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (state_ != State::kClosed) {
@@ -391,7 +384,6 @@ void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) {
}
void DcSctpSocket::Shutdown() {
- RTC_DCHECK_RUN_ON(&thread_checker_);
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (tcb_ != nullptr) {
@@ -420,7 +412,6 @@ void DcSctpSocket::Shutdown() {
}
void DcSctpSocket::Close() {
- RTC_DCHECK_RUN_ON(&thread_checker_);
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (state_ != State::kClosed) {
@@ -468,20 +459,51 @@ void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) {
void DcSctpSocket::SetStreamPriority(StreamID stream_id,
StreamPriority priority) {
- RTC_DCHECK_RUN_ON(&thread_checker_);
send_queue_.SetStreamPriority(stream_id, priority);
}
StreamPriority DcSctpSocket::GetStreamPriority(StreamID stream_id) const {
- RTC_DCHECK_RUN_ON(&thread_checker_);
return send_queue_.GetStreamPriority(stream_id);
}
SendStatus DcSctpSocket::Send(DcSctpMessage message,
const SendOptions& send_options) {
- RTC_DCHECK_RUN_ON(&thread_checker_);
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
- LifecycleId lifecycle_id = send_options.lifecycle_id;
+ SendStatus send_status = InternalSend(message, send_options);
+ if (send_status != SendStatus::kSuccess)
+ return send_status;
+ Timestamp now = callbacks_.Now();
+ ++metrics_.tx_messages_count;
+ send_queue_.Add(now, std::move(message), send_options);
+ if (tcb_ != nullptr)
+ tcb_->SendBufferedPackets(now);
+ RTC_DCHECK(IsConsistent());
+ return SendStatus::kSuccess;
+}
+std::vector<SendStatus> DcSctpSocket::SendMany(
+ rtc::ArrayView<DcSctpMessage> messages,
+ const SendOptions& send_options) {
+ CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
+ Timestamp now = callbacks_.Now();
+ std::vector<SendStatus> send_statuses;
+ send_statuses.reserve(messages.size());
+ for (DcSctpMessage& message : messages) {
+ SendStatus send_status = InternalSend(message, send_options);
+ send_statuses.push_back(send_status);
+ if (send_status != SendStatus::kSuccess)
+ continue;
+ ++metrics_.tx_messages_count;
+ send_queue_.Add(now, std::move(message), send_options);
+ }
+ if (tcb_ != nullptr)
+ tcb_->SendBufferedPackets(now);
+ RTC_DCHECK(IsConsistent());
+ return send_statuses;
+}
+
+SendStatus DcSctpSocket::InternalSend(const DcSctpMessage& message,
+ const SendOptions& send_options) {
+ LifecycleId lifecycle_id = send_options.lifecycle_id;
if (message.payload().empty()) {
if (lifecycle_id.IsSet()) {
callbacks_.OnLifecycleEnd(lifecycle_id);
@@ -519,21 +541,11 @@ SendStatus DcSctpSocket::Send(DcSctpMessage message,
"Unable to send message as the send queue is full");
return SendStatus::kErrorResourceExhaustion;
}
-
- Timestamp now = callbacks_.Now();
- ++metrics_.tx_messages_count;
- send_queue_.Add(now, std::move(message), send_options);
- if (tcb_ != nullptr) {
- tcb_->SendBufferedPackets(now);
- }
-
- RTC_DCHECK(IsConsistent());
return SendStatus::kSuccess;
}
ResetStreamsStatus DcSctpSocket::ResetStreams(
rtc::ArrayView<const StreamID> outgoing_streams) {
- RTC_DCHECK_RUN_ON(&thread_checker_);
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (tcb_ == nullptr) {
@@ -555,7 +567,6 @@ ResetStreamsStatus DcSctpSocket::ResetStreams(
}
SocketState DcSctpSocket::state() const {
- RTC_DCHECK_RUN_ON(&thread_checker_);
switch (state_) {
case State::kClosed:
return SocketState::kClosed;
@@ -573,29 +584,23 @@ SocketState DcSctpSocket::state() const {
}
void DcSctpSocket::SetMaxMessageSize(size_t max_message_size) {
- RTC_DCHECK_RUN_ON(&thread_checker_);
options_.max_message_size = max_message_size;
}
size_t DcSctpSocket::buffered_amount(StreamID stream_id) const {
- RTC_DCHECK_RUN_ON(&thread_checker_);
return send_queue_.buffered_amount(stream_id);
}
size_t DcSctpSocket::buffered_amount_low_threshold(StreamID stream_id) const {
- RTC_DCHECK_RUN_ON(&thread_checker_);
return send_queue_.buffered_amount_low_threshold(stream_id);
}
void DcSctpSocket::SetBufferedAmountLowThreshold(StreamID stream_id,
size_t bytes) {
- RTC_DCHECK_RUN_ON(&thread_checker_);
send_queue_.SetBufferedAmountLowThreshold(stream_id, bytes);
}
absl::optional<Metrics> DcSctpSocket::GetMetrics() const {
- RTC_DCHECK_RUN_ON(&thread_checker_);
-
if (tcb_ == nullptr) {
return absl::nullopt;
}
@@ -750,7 +755,6 @@ bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) {
}
void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) {
- RTC_DCHECK_RUN_ON(&thread_checker_);
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
timer_manager_.HandleTimeout(timeout_id);
@@ -764,7 +768,6 @@ void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) {
}
void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) {
- RTC_DCHECK_RUN_ON(&thread_checker_);
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
++metrics_.rx_packets_count;
@@ -1153,11 +1156,16 @@ void DcSctpSocket::HandleInit(const CommonHeader& header,
}
TieTag tie_tag(0);
+ VerificationTag my_verification_tag;
+ TSN my_initial_tsn;
if (state_ == State::kClosed) {
RTC_DLOG(LS_VERBOSE) << log_prefix()
<< "Received Init in closed state (normal)";
- MakeConnectionParameters();
+ my_verification_tag = VerificationTag(
+ callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag));
+ my_initial_tsn =
+ TSN(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn));
} else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) {
// https://tools.ietf.org/html/rfc4960#section-5.2.1
// "This usually indicates an initialization collision, i.e., each
@@ -1170,6 +1178,8 @@ void DcSctpSocket::HandleInit(const CommonHeader& header,
// endpoint) was sent."
RTC_DLOG(LS_VERBOSE) << log_prefix()
<< "Received Init indicating simultaneous connections";
+ my_verification_tag = connect_params_.verification_tag;
+ my_initial_tsn = connect_params_.initial_tsn;
} else {
RTC_DCHECK(tcb_ != nullptr);
// https://tools.ietf.org/html/rfc4960#section-5.2.2
@@ -1184,17 +1194,16 @@ void DcSctpSocket::HandleInit(const CommonHeader& header,
<< "Received Init indicating restarted connection";
// Create a new verification tag - different from the previous one.
for (int tries = 0; tries < 10; ++tries) {
- connect_params_.verification_tag = VerificationTag(
+ my_verification_tag = VerificationTag(
callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag));
- if (connect_params_.verification_tag != tcb_->my_verification_tag()) {
+ if (my_verification_tag != tcb_->my_verification_tag()) {
break;
}
}
// Make the initial TSN make a large jump, so that there is no overlap
// with the old and new association.
- connect_params_.initial_tsn =
- TSN(*tcb_->retransmission_queue().next_tsn() + 1000000);
+ my_initial_tsn = TSN(*tcb_->retransmission_queue().next_tsn() + 1000000);
tie_tag = tcb_->tie_tag();
}
@@ -1204,8 +1213,8 @@ void DcSctpSocket::HandleInit(const CommonHeader& header,
"Proceeding with connection. my_verification_tag=%08x, "
"my_initial_tsn=%u, peer_verification_tag=%08x, "
"peer_initial_tsn=%u",
- *connect_params_.verification_tag, *connect_params_.initial_tsn,
- *chunk->initiate_tag(), *chunk->initial_tsn());
+ *my_verification_tag, *my_initial_tsn, *chunk->initiate_tag(),
+ *chunk->initial_tsn());
Capabilities capabilities =
ComputeCapabilities(options_, chunk->nbr_outbound_streams(),
@@ -1214,16 +1223,17 @@ void DcSctpSocket::HandleInit(const CommonHeader& header,
SctpPacket::Builder b(chunk->initiate_tag(), options_);
Parameters::Builder params_builder =
Parameters::Builder().Add(StateCookieParameter(
- StateCookie(chunk->initiate_tag(), chunk->initial_tsn(),
- chunk->a_rwnd(), tie_tag, capabilities)
+ StateCookie(chunk->initiate_tag(), my_verification_tag,
+ chunk->initial_tsn(), my_initial_tsn, chunk->a_rwnd(),
+ tie_tag, capabilities)
.Serialize()));
AddCapabilityParameters(options_, params_builder);
- InitAckChunk init_ack(/*initiate_tag=*/connect_params_.verification_tag,
+ InitAckChunk init_ack(/*initiate_tag=*/my_verification_tag,
options_.max_receiver_window_buffer_size,
options_.announced_maximum_outgoing_streams,
options_.announced_maximum_incoming_streams,
- connect_params_.initial_tsn, params_builder.Build());
+ my_initial_tsn, params_builder.Build());
b.Add(init_ack);
// If the peer has signaled that it supports zero checksum, INIT-ACK can then
// have its checksum as zero.
@@ -1309,13 +1319,13 @@ void DcSctpSocket::HandleCookieEcho(
return;
}
} else {
- if (header.verification_tag != connect_params_.verification_tag) {
+ if (header.verification_tag != cookie->my_tag()) {
callbacks_.OnError(
ErrorKind::kParseFailed,
rtc::StringFormat(
"Received CookieEcho with invalid verification tag: %08x, "
"expected %08x",
- *header.verification_tag, *connect_params_.verification_tag));
+ *header.verification_tag, *cookie->my_tag()));
return;
}
}
@@ -1340,10 +1350,10 @@ void DcSctpSocket::HandleCookieEcho(
// send queue is already re-configured, and shouldn't be reset.
send_queue_.Reset();
- CreateTransmissionControlBlock(
- cookie->capabilities(), connect_params_.verification_tag,
- connect_params_.initial_tsn, cookie->initiate_tag(),
- cookie->initial_tsn(), cookie->a_rwnd(), MakeTieTag(callbacks_));
+ CreateTransmissionControlBlock(cookie->capabilities(), cookie->my_tag(),
+ cookie->my_initial_tsn(), cookie->peer_tag(),
+ cookie->peer_initial_tsn(), cookie->a_rwnd(),
+ MakeTieTag(callbacks_));
}
SctpPacket::Builder b = tcb_->PacketBuilder();
@@ -1363,13 +1373,13 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header,
<< *tcb_->my_verification_tag()
<< ", peer_tag=" << *header.verification_tag
<< ", tcb_tag=" << *tcb_->peer_verification_tag()
- << ", cookie_tag=" << *cookie.initiate_tag()
+ << ", peer_tag=" << *cookie.peer_tag()
<< ", local_tie_tag=" << *tcb_->tie_tag()
<< ", peer_tie_tag=" << *cookie.tie_tag();
// https://tools.ietf.org/html/rfc4960#section-5.2.4
// "Handle a COOKIE ECHO when a TCB Exists"
if (header.verification_tag != tcb_->my_verification_tag() &&
- tcb_->peer_verification_tag() != cookie.initiate_tag() &&
+ tcb_->peer_verification_tag() != cookie.peer_tag() &&
cookie.tie_tag() == tcb_->tie_tag()) {
// "A) In this case, the peer may have restarted."
if (state_ == State::kShutdownAckSent) {
@@ -1377,7 +1387,7 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header,
// that the peer has restarted ... it MUST NOT set up a new association
// but instead resend the SHUTDOWN ACK and send an ERROR chunk with a
// "Cookie Received While Shutting Down" error cause to its peer."
- SctpPacket::Builder b(cookie.initiate_tag(), options_);
+ SctpPacket::Builder b(cookie.peer_tag(), options_);
b.Add(ShutdownAckChunk());
b.Add(ErrorChunk(Parameters::Builder()
.Add(CookieReceivedWhileShuttingDownCause())
@@ -1394,7 +1404,7 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header,
tcb_ = nullptr;
callbacks_.OnConnectionRestarted();
} else if (header.verification_tag == tcb_->my_verification_tag() &&
- tcb_->peer_verification_tag() != cookie.initiate_tag()) {
+ tcb_->peer_verification_tag() != cookie.peer_tag()) {
// TODO(boivie): Handle the peer_tag == 0?
// "B) In this case, both sides may be attempting to start an
// association at about the same time, but the peer endpoint started its
@@ -1404,7 +1414,7 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header,
<< "Received COOKIE-ECHO indicating simultaneous connections";
tcb_ = nullptr;
} else if (header.verification_tag != tcb_->my_verification_tag() &&
- tcb_->peer_verification_tag() == cookie.initiate_tag() &&
+ tcb_->peer_verification_tag() == cookie.peer_tag() &&
cookie.tie_tag() == TieTag(0)) {
// "C) In this case, the local endpoint's cookie has arrived late.
// Before it arrived, the local endpoint sent an INIT and received an
@@ -1417,7 +1427,7 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header,
<< "Received COOKIE-ECHO indicating a late COOKIE-ECHO. Discarding";
return false;
} else if (header.verification_tag == tcb_->my_verification_tag() &&
- tcb_->peer_verification_tag() == cookie.initiate_tag()) {
+ tcb_->peer_verification_tag() == cookie.peer_tag()) {
// "D) When both local and remote tags match, the endpoint should enter
// the ESTABLISHED state, if it is in the COOKIE-ECHOED state. It
// should stop any cookie timer that may be running and send a COOKIE
@@ -1761,7 +1771,6 @@ void DcSctpSocket::SendShutdownAck() {
}
HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const {
- RTC_DCHECK_RUN_ON(&thread_checker_);
HandoverReadinessStatus status;
if (state_ != State::kClosed && state_ != State::kEstablished) {
status.Add(HandoverUnreadinessReason::kWrongConnectionState);
@@ -1775,7 +1784,6 @@ HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const {
absl::optional<DcSctpSocketHandoverState>
DcSctpSocket::GetHandoverStateAndClose() {
- RTC_DCHECK_RUN_ON(&thread_checker_);
CallbackDeferrer::ScopedDeferrer deferrer(callbacks_);
if (!GetHandoverReadiness().IsReady()) {
diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h
index deb6ee23e7..c65571a923 100644
--- a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h
+++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h
@@ -14,10 +14,10 @@
#include <memory>
#include <string>
#include <utility>
+#include <vector>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
-#include "api/sequence_checker.h"
#include "net/dcsctp/packet/chunk/abort_chunk.h"
#include "net/dcsctp/packet/chunk/chunk.h"
#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
@@ -91,6 +91,8 @@ class DcSctpSocket : public DcSctpSocketInterface {
void Close() override;
SendStatus Send(DcSctpMessage message,
const SendOptions& send_options) override;
+ std::vector<SendStatus> SendMany(rtc::ArrayView<DcSctpMessage> messages,
+ const SendOptions& send_options) override;
ResetStreamsStatus ResetStreams(
rtc::ArrayView<const StreamID> outgoing_streams) override;
SocketState state() const override;
@@ -148,8 +150,6 @@ class DcSctpSocket : public DcSctpSocketInterface {
// Changes the socket state, given a `reason` (for debugging/logging).
void SetState(State state, absl::string_view reason);
- // Fills in `connect_params` with random verification tag and initial TSN.
- void MakeConnectionParameters();
// Closes the association. Note that the TCB will not be valid past this call.
void InternalClose(ErrorKind error, absl::string_view message);
// Closes the association, because of too many retransmission errors.
@@ -167,6 +167,9 @@ class DcSctpSocket : public DcSctpSocketInterface {
void MaybeSendShutdownOnPacketReceived(const SctpPacket& packet);
// If there are streams pending to be reset, send a request to reset them.
void MaybeSendResetStreamsRequest();
+ // Performs internal processing shared between Send and SendMany.
+ SendStatus InternalSend(const DcSctpMessage& message,
+ const SendOptions& send_options);
// Sends a INIT chunk.
void SendInit();
// Sends a SHUTDOWN chunk.
@@ -267,7 +270,6 @@ class DcSctpSocket : public DcSctpSocketInterface {
const std::string log_prefix_;
const std::unique_ptr<PacketObserver> packet_observer_;
- RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_;
Metrics metrics_;
DcSctpOptions options_;
diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc
index dc76b80a37..413516bae0 100644
--- a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc
+++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc
@@ -66,6 +66,7 @@ namespace {
using ::testing::_;
using ::testing::AllOf;
using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
using ::testing::Eq;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
@@ -1561,6 +1562,33 @@ TEST(DcSctpSocketTest, SetMaxMessageSize) {
EXPECT_EQ(a.socket.options().max_message_size, 42u);
}
+TEST_P(DcSctpSocketParametrizedTest, SendManyMessages) {
+ SocketUnderTest a("A");
+ auto z = std::make_unique<SocketUnderTest>("Z");
+
+ ConnectSockets(a, *z);
+ z = MaybeHandoverSocket(std::move(z));
+
+ static constexpr int kIterations = 100;
+ std::vector<DcSctpMessage> messages;
+ std::vector<SendStatus> statuses;
+ for (int i = 0; i < kIterations; ++i) {
+ messages.push_back(DcSctpMessage(StreamID(1), PPID(53), {1, 2}));
+ statuses.push_back(SendStatus::kSuccess);
+ }
+ EXPECT_THAT(a.socket.SendMany(messages, {}), ElementsAreArray(statuses));
+
+ ExchangeMessages(a, *z);
+
+ for (int i = 0; i < kIterations; ++i) {
+ EXPECT_TRUE(z->cb.ConsumeReceivedMessage().has_value());
+ }
+
+ EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value());
+
+ MaybeHandoverSocketAndSendMessage(a, std::move(z));
+}
+
TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) {
SocketUnderTest a("A");
auto z = std::make_unique<SocketUnderTest>("Z");
@@ -3061,5 +3089,149 @@ TEST(DcSctpSocketTest, HandlesForwardTsnOutOfOrderWithStreamResetting) {
testing::Optional(Property(&DcSctpMessage::ppid, PPID(53))));
}
+TEST(DcSctpSocketTest, ResentInitHasSameParameters) {
+ // If an INIT chunk has to be resent (due to INIT_ACK not received in time),
+ // the resent INIT must have the same properties as the original one.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+ auto packet_1 = a.cb.ConsumeSentPacket();
+
+ // Times out, INIT is re-sent.
+ AdvanceTime(a, z, a.options.t1_init_timeout.ToTimeDelta());
+ auto packet_2 = a.cb.ConsumeSentPacket();
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet_1,
+ SctpPacket::Parse(packet_1, z.options));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ InitChunk init_chunk_1,
+ InitChunk::Parse(init_packet_1.descriptors()[0].data));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet_2,
+ SctpPacket::Parse(packet_2, z.options));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ InitChunk init_chunk_2,
+ InitChunk::Parse(init_packet_2.descriptors()[0].data));
+
+ EXPECT_EQ(init_chunk_1.initial_tsn(), init_chunk_2.initial_tsn());
+ EXPECT_EQ(init_chunk_1.initiate_tag(), init_chunk_2.initiate_tag());
+}
+
+TEST(DcSctpSocketTest, ResentInitAckHasDifferentParameters) {
+ // For every INIT, an INIT_ACK is produced. Verify that the socket doesn't
+ // maintain any state by ensuring that two created INIT_ACKs for the same
+ // received INIT are different.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Connect();
+ auto packet_1 = a.cb.ConsumeSentPacket();
+ EXPECT_THAT(packet_1, HasChunks(ElementsAre(IsChunkType(InitChunk::kType))));
+
+ z.socket.ReceivePacket(packet_1);
+ auto packet_2 = z.cb.ConsumeSentPacket();
+ z.socket.ReceivePacket(packet_1);
+ auto packet_3 = z.cb.ConsumeSentPacket();
+
+ EXPECT_THAT(packet_2,
+ HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType))));
+ EXPECT_THAT(packet_3,
+ HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType))));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_ack_packet_1,
+ SctpPacket::Parse(packet_2, z.options));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ InitAckChunk init_ack_chunk_1,
+ InitAckChunk::Parse(init_ack_packet_1.descriptors()[0].data));
+
+ ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_ack_packet_2,
+ SctpPacket::Parse(packet_3, z.options));
+ ASSERT_HAS_VALUE_AND_ASSIGN(
+ InitAckChunk init_ack_chunk_2,
+ InitAckChunk::Parse(init_ack_packet_2.descriptors()[0].data));
+
+ EXPECT_NE(init_ack_chunk_1.initiate_tag(), init_ack_chunk_2.initiate_tag());
+ EXPECT_NE(init_ack_chunk_1.initial_tsn(), init_ack_chunk_2.initial_tsn());
+}
+
+TEST(DcSctpSocketResendInitTest, ConnectionCanContinueFromFirstInitAck) {
+ // If an INIT chunk has to be resent (due to INIT_ACK not received in time),
+ // another INIT will be sent, and if both INITs were actually received, both
+ // will be responded to by an INIT_ACK. While these two INIT_ACKs may have
+ // different parameters, the connection must be able to finish with the cookie
+ // (as replied to using COOKIE_ECHO) from either INIT_ACK.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Connect();
+ auto init_1 = a.cb.ConsumeSentPacket();
+
+ // Times out, INIT is re-sent.
+ AdvanceTime(a, z, a.options.t1_init_timeout.ToTimeDelta());
+ auto init_2 = a.cb.ConsumeSentPacket();
+
+ EXPECT_THAT(init_1, HasChunks(ElementsAre(IsChunkType(InitChunk::kType))));
+ EXPECT_THAT(init_2, HasChunks(ElementsAre(IsChunkType(InitChunk::kType))));
+
+ z.socket.ReceivePacket(init_1);
+ z.socket.ReceivePacket(init_2);
+ auto init_ack_1 = z.cb.ConsumeSentPacket();
+ auto init_ack_2 = z.cb.ConsumeSentPacket();
+ EXPECT_THAT(init_ack_1,
+ HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType))));
+ EXPECT_THAT(init_ack_2,
+ HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType))));
+
+ a.socket.ReceivePacket(init_ack_1);
+ // Then let the rest continue.
+ ExchangeMessages(a, z);
+
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), SizeIs(kLargeMessageSize));
+}
+
+TEST(DcSctpSocketResendInitTest, ConnectionCanContinueFromSecondInitAck) {
+ // Just as above, but discarding the first INIT_ACK.
+ SocketUnderTest a("A");
+ SocketUnderTest z("Z");
+
+ a.socket.Send(DcSctpMessage(StreamID(1), PPID(53),
+ std::vector<uint8_t>(kLargeMessageSize)),
+ kSendOptions);
+ a.socket.Connect();
+ auto init_1 = a.cb.ConsumeSentPacket();
+
+ // Times out, INIT is re-sent.
+ AdvanceTime(a, z, a.options.t1_init_timeout.ToTimeDelta());
+ auto init_2 = a.cb.ConsumeSentPacket();
+
+ EXPECT_THAT(init_1, HasChunks(ElementsAre(IsChunkType(InitChunk::kType))));
+ EXPECT_THAT(init_2, HasChunks(ElementsAre(IsChunkType(InitChunk::kType))));
+
+ z.socket.ReceivePacket(init_1);
+ z.socket.ReceivePacket(init_2);
+ auto init_ack_1 = z.cb.ConsumeSentPacket();
+ auto init_ack_2 = z.cb.ConsumeSentPacket();
+ EXPECT_THAT(init_ack_1,
+ HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType))));
+ EXPECT_THAT(init_ack_2,
+ HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType))));
+
+ a.socket.ReceivePacket(init_ack_2);
+ // Then let the rest continue.
+ ExchangeMessages(a, z);
+
+ absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage();
+ ASSERT_TRUE(msg.has_value());
+ EXPECT_EQ(msg->stream_id(), StreamID(1));
+ EXPECT_THAT(msg->payload(), SizeIs(kLargeMessageSize));
+}
+
} // namespace
} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc
index 624d783a3b..c5ed1d8620 100644
--- a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc
+++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc
@@ -32,17 +32,19 @@ std::vector<uint8_t> StateCookie::Serialize() {
BoundedByteWriter<kCookieSize> buffer(cookie);
buffer.Store32<0>(kMagic1);
buffer.Store32<4>(kMagic2);
- buffer.Store32<8>(*initiate_tag_);
- buffer.Store32<12>(*initial_tsn_);
- buffer.Store32<16>(a_rwnd_);
- buffer.Store32<20>(static_cast<uint32_t>(*tie_tag_ >> 32));
- buffer.Store32<24>(static_cast<uint32_t>(*tie_tag_));
- buffer.Store8<28>(capabilities_.partial_reliability);
- buffer.Store8<29>(capabilities_.message_interleaving);
- buffer.Store8<30>(capabilities_.reconfig);
- buffer.Store16<32>(capabilities_.negotiated_maximum_incoming_streams);
- buffer.Store16<34>(capabilities_.negotiated_maximum_outgoing_streams);
- buffer.Store8<36>(capabilities_.zero_checksum);
+ buffer.Store32<8>(*peer_tag_);
+ buffer.Store32<12>(*my_tag_);
+ buffer.Store32<16>(*peer_initial_tsn_);
+ buffer.Store32<20>(*my_initial_tsn_);
+ buffer.Store32<24>(a_rwnd_);
+ buffer.Store32<28>(static_cast<uint32_t>(*tie_tag_ >> 32));
+ buffer.Store32<32>(static_cast<uint32_t>(*tie_tag_));
+ buffer.Store8<36>(capabilities_.partial_reliability);
+ buffer.Store8<37>(capabilities_.message_interleaving);
+ buffer.Store8<38>(capabilities_.reconfig);
+ buffer.Store16<40>(capabilities_.negotiated_maximum_incoming_streams);
+ buffer.Store16<42>(capabilities_.negotiated_maximum_outgoing_streams);
+ buffer.Store8<44>(capabilities_.zero_checksum);
return cookie;
}
@@ -62,23 +64,25 @@ absl::optional<StateCookie> StateCookie::Deserialize(
return absl::nullopt;
}
- VerificationTag verification_tag(buffer.Load32<8>());
- TSN initial_tsn(buffer.Load32<12>());
- uint32_t a_rwnd = buffer.Load32<16>();
- uint32_t tie_tag_upper = buffer.Load32<20>();
- uint32_t tie_tag_lower = buffer.Load32<24>();
+ VerificationTag peer_tag(buffer.Load32<8>());
+ VerificationTag my_tag(buffer.Load32<12>());
+ TSN peer_initial_tsn(buffer.Load32<16>());
+ TSN my_initial_tsn(buffer.Load32<20>());
+ uint32_t a_rwnd = buffer.Load32<24>();
+ uint32_t tie_tag_upper = buffer.Load32<28>();
+ uint32_t tie_tag_lower = buffer.Load32<32>();
TieTag tie_tag(static_cast<uint64_t>(tie_tag_upper) << 32 |
static_cast<uint64_t>(tie_tag_lower));
Capabilities capabilities;
- capabilities.partial_reliability = buffer.Load8<28>() != 0;
- capabilities.message_interleaving = buffer.Load8<29>() != 0;
- capabilities.reconfig = buffer.Load8<30>() != 0;
- capabilities.negotiated_maximum_incoming_streams = buffer.Load16<32>();
- capabilities.negotiated_maximum_outgoing_streams = buffer.Load16<34>();
- capabilities.zero_checksum = buffer.Load8<36>() != 0;
+ capabilities.partial_reliability = buffer.Load8<36>() != 0;
+ capabilities.message_interleaving = buffer.Load8<37>() != 0;
+ capabilities.reconfig = buffer.Load8<38>() != 0;
+ capabilities.negotiated_maximum_incoming_streams = buffer.Load16<40>();
+ capabilities.negotiated_maximum_outgoing_streams = buffer.Load16<42>();
+ capabilities.zero_checksum = buffer.Load8<44>() != 0;
- return StateCookie(verification_tag, initial_tsn, a_rwnd, tie_tag,
- capabilities);
+ return StateCookie(peer_tag, my_tag, peer_initial_tsn, my_initial_tsn, a_rwnd,
+ tie_tag, capabilities);
}
} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h
index 34cd6d3690..b94eedafd4 100644
--- a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h
+++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h
@@ -27,15 +27,19 @@ namespace dcsctp {
// Do not trust anything in it; no pointers or anything like that.
class StateCookie {
public:
- static constexpr size_t kCookieSize = 37;
+ static constexpr size_t kCookieSize = 45;
- StateCookie(VerificationTag initiate_tag,
- TSN initial_tsn,
+ StateCookie(VerificationTag peer_tag,
+ VerificationTag my_tag,
+ TSN peer_initial_tsn,
+ TSN my_initial_tsn,
uint32_t a_rwnd,
TieTag tie_tag,
Capabilities capabilities)
- : initiate_tag_(initiate_tag),
- initial_tsn_(initial_tsn),
+ : peer_tag_(peer_tag),
+ my_tag_(my_tag),
+ peer_initial_tsn_(peer_initial_tsn),
+ my_initial_tsn_(my_initial_tsn),
a_rwnd_(a_rwnd),
tie_tag_(tie_tag),
capabilities_(capabilities) {}
@@ -47,15 +51,21 @@ class StateCookie {
static absl::optional<StateCookie> Deserialize(
rtc::ArrayView<const uint8_t> cookie);
- VerificationTag initiate_tag() const { return initiate_tag_; }
- TSN initial_tsn() const { return initial_tsn_; }
+ VerificationTag peer_tag() const { return peer_tag_; }
+ VerificationTag my_tag() const { return my_tag_; }
+ TSN peer_initial_tsn() const { return peer_initial_tsn_; }
+ TSN my_initial_tsn() const { return my_initial_tsn_; }
uint32_t a_rwnd() const { return a_rwnd_; }
TieTag tie_tag() const { return tie_tag_; }
const Capabilities& capabilities() const { return capabilities_; }
private:
- const VerificationTag initiate_tag_;
- const TSN initial_tsn_;
+ // Also called "Tag_A" in RFC4960.
+ const VerificationTag peer_tag_;
+ // Also called "Tag_Z" in RFC4960.
+ const VerificationTag my_tag_;
+ const TSN peer_initial_tsn_;
+ const TSN my_initial_tsn_;
const uint32_t a_rwnd_;
const TieTag tie_tag_;
const Capabilities capabilities_;
diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc
index 19be71a1ca..806ea2024b 100644
--- a/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc
+++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc
@@ -24,14 +24,18 @@ TEST(StateCookieTest, SerializeAndDeserialize) {
.zero_checksum = true,
.negotiated_maximum_incoming_streams = 123,
.negotiated_maximum_outgoing_streams = 234};
- StateCookie cookie(VerificationTag(123), TSN(456),
+ StateCookie cookie(/*peer_tag=*/VerificationTag(123),
+ /*my_tag=*/VerificationTag(321),
+ /*peer_initial_tsn=*/TSN(456), /*my_initial_tsn=*/TSN(654),
/*a_rwnd=*/789, TieTag(101112), capabilities);
std::vector<uint8_t> serialized = cookie.Serialize();
EXPECT_THAT(serialized, SizeIs(StateCookie::kCookieSize));
ASSERT_HAS_VALUE_AND_ASSIGN(StateCookie deserialized,
StateCookie::Deserialize(serialized));
- EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123));
- EXPECT_EQ(deserialized.initial_tsn(), TSN(456));
+ EXPECT_EQ(deserialized.peer_tag(), VerificationTag(123));
+ EXPECT_EQ(deserialized.my_tag(), VerificationTag(321));
+ EXPECT_EQ(deserialized.peer_initial_tsn(), TSN(456));
+ EXPECT_EQ(deserialized.my_initial_tsn(), TSN(654));
EXPECT_EQ(deserialized.a_rwnd(), 789u);
EXPECT_EQ(deserialized.tie_tag(), TieTag(101112));
EXPECT_TRUE(deserialized.capabilities().partial_reliability);
@@ -48,7 +52,9 @@ TEST(StateCookieTest, ValidateMagicValue) {
Capabilities capabilities = {.partial_reliability = true,
.message_interleaving = false,
.reconfig = true};
- StateCookie cookie(VerificationTag(123), TSN(456),
+ StateCookie cookie(/*peer_tag=*/VerificationTag(123),
+ /*my_tag=*/VerificationTag(321),
+ /*peer_initial_tsn=*/TSN(456), /*my_initial_tsn=*/TSN(654),
/*a_rwnd=*/789, TieTag(101112), capabilities);
std::vector<uint8_t> serialized = cookie.Serialize();
ASSERT_THAT(serialized, SizeIs(StateCookie::kCookieSize));
diff --git a/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc
index 7cbead296c..3e682fdca6 100644
--- a/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc
+++ b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc
@@ -373,8 +373,9 @@ void RRSendQueue::Add(Timestamp now,
: Timestamp::PlusInfinity(),
.lifecycle_id = send_options.lifecycle_id,
};
- GetOrCreateStreamInfo(message.stream_id())
- .Add(std::move(message), std::move(attributes));
+ StreamID stream_id = message.stream_id();
+ GetOrCreateStreamInfo(stream_id).Add(std::move(message),
+ std::move(attributes));
RTC_DCHECK(IsConsistent());
}