From d8bbc7858622b6d9c278469aab701ca0b609cddf Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 15 May 2024 05:35:49 +0200 Subject: Merging upstream version 126.0. Signed-off-by: Daniel Baumann --- third_party/libwebrtc/net/dcsctp/socket/BUILD.gn | 2 +- .../net/dcsctp/socket/callback_deferrer.cc | 119 +++++++------- .../net/dcsctp/socket/callback_deferrer.h | 17 +- .../libwebrtc/net/dcsctp/socket/dcsctp_socket.cc | 130 ++++++++-------- .../libwebrtc/net/dcsctp/socket/dcsctp_socket.h | 10 +- .../net/dcsctp/socket/dcsctp_socket_test.cc | 172 +++++++++++++++++++++ .../libwebrtc/net/dcsctp/socket/state_cookie.cc | 52 ++++--- .../libwebrtc/net/dcsctp/socket/state_cookie.h | 28 ++-- .../net/dcsctp/socket/state_cookie_test.cc | 14 +- 9 files changed, 384 insertions(+), 160 deletions(-) (limited to 'third_party/libwebrtc/net/dcsctp/socket') diff --git a/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn index 04f61e5b72..406593e23b 100644 --- a/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn +++ b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn @@ -140,7 +140,6 @@ rtc_library("dcsctp_socket") { "../../../api:make_ref_counted", "../../../api:refcountedbase", "../../../api:scoped_refptr", - "../../../api:sequence_checker", "../../../api/task_queue:task_queue", "../../../rtc_base:checks", "../../../rtc_base:logging", @@ -178,6 +177,7 @@ rtc_library("dcsctp_socket") { "//third_party/abseil-cpp/absl/memory", "//third_party/abseil-cpp/absl/strings", "//third_party/abseil-cpp/absl/types:optional", + "//third_party/abseil-cpp/absl/types:variant", ] } diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc index 0a24020167..549a592b8d 100644 --- a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc +++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc @@ -12,31 +12,6 @@ #include "api/make_ref_counted.h" namespace dcsctp { -namespace { -// A wrapper around the move-only DcSctpMessage, to let it be captured in a -// lambda. -class MessageDeliverer { - public: - explicit MessageDeliverer(DcSctpMessage&& message) - : state_(rtc::make_ref_counted(std::move(message))) {} - - void Deliver(DcSctpSocketCallbacks& c) { - // Really ensure that it's only called once. - RTC_DCHECK(!state_->has_delivered); - state_->has_delivered = true; - c.OnMessageReceived(std::move(state_->message)); - } - - private: - struct State : public webrtc::RefCountInterface { - explicit State(DcSctpMessage&& m) - : has_delivered(false), message(std::move(m)) {} - bool has_delivered; - DcSctpMessage message; - }; - rtc::scoped_refptr state_; -}; -} // namespace void CallbackDeferrer::Prepare() { RTC_DCHECK(!prepared_); @@ -48,12 +23,16 @@ void CallbackDeferrer::TriggerDeferred() { // callback, and that might result in adding new callbacks to this instance, // and the vector can't be modified while iterated on. RTC_DCHECK(prepared_); - std::vector> deferred; - deferred.swap(deferred_); prepared_ = false; - - for (auto& cb : deferred) { - cb(underlying_); + if (deferred_.empty()) { + return; + } + std::vector> deferred; + // Reserve a small buffer to prevent too much reallocation on growth. + deferred.reserve(8); + deferred.swap(deferred_); + for (auto& [cb, data] : deferred) { + cb(std::move(data), underlying_); } } @@ -84,40 +63,57 @@ uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) { void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) { RTC_DCHECK(prepared_); deferred_.emplace_back( - [deliverer = MessageDeliverer(std::move(message))]( - DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); }); + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + return cb.OnMessageReceived(absl::get(std::move(data))); + }, + std::move(message)); } void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) { RTC_DCHECK(prepared_); deferred_.emplace_back( - [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { - cb.OnError(error, message); - }); + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + Error error = absl::get(std::move(data)); + return cb.OnError(error.error, error.message); + }, + Error{error, std::string(message)}); } void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) { RTC_DCHECK(prepared_); deferred_.emplace_back( - [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { - cb.OnAborted(error, message); - }); + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + Error error = absl::get(std::move(data)); + return cb.OnAborted(error.error, error.message); + }, + Error{error, std::string(message)}); } void CallbackDeferrer::OnConnected() { RTC_DCHECK(prepared_); - deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); }); + deferred_.emplace_back( + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + return cb.OnConnected(); + }, + absl::monostate{}); } void CallbackDeferrer::OnClosed() { RTC_DCHECK(prepared_); - deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); }); + deferred_.emplace_back( + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + return cb.OnClosed(); + }, + absl::monostate{}); } void CallbackDeferrer::OnConnectionRestarted() { RTC_DCHECK(prepared_); deferred_.emplace_back( - [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); }); + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + return cb.OnConnectionRestarted(); + }, + absl::monostate{}); } void CallbackDeferrer::OnStreamsResetFailed( @@ -125,42 +121,53 @@ void CallbackDeferrer::OnStreamsResetFailed( absl::string_view reason) { RTC_DCHECK(prepared_); deferred_.emplace_back( - [streams = std::vector(outgoing_streams.begin(), - outgoing_streams.end()), - reason = std::string(reason)](DcSctpSocketCallbacks& cb) { - cb.OnStreamsResetFailed(streams, reason); - }); + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + StreamReset stream_reset = absl::get(std::move(data)); + return cb.OnStreamsResetFailed(stream_reset.streams, + stream_reset.message); + }, + StreamReset{{outgoing_streams.begin(), outgoing_streams.end()}, + std::string(reason)}); } void CallbackDeferrer::OnStreamsResetPerformed( rtc::ArrayView outgoing_streams) { RTC_DCHECK(prepared_); deferred_.emplace_back( - [streams = std::vector(outgoing_streams.begin(), - outgoing_streams.end())]( - DcSctpSocketCallbacks& cb) { cb.OnStreamsResetPerformed(streams); }); + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + StreamReset stream_reset = absl::get(std::move(data)); + return cb.OnStreamsResetPerformed(stream_reset.streams); + }, + StreamReset{{outgoing_streams.begin(), outgoing_streams.end()}}); } void CallbackDeferrer::OnIncomingStreamsReset( rtc::ArrayView incoming_streams) { RTC_DCHECK(prepared_); deferred_.emplace_back( - [streams = std::vector(incoming_streams.begin(), - incoming_streams.end())]( - DcSctpSocketCallbacks& cb) { cb.OnIncomingStreamsReset(streams); }); + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + StreamReset stream_reset = absl::get(std::move(data)); + return cb.OnIncomingStreamsReset(stream_reset.streams); + }, + StreamReset{{incoming_streams.begin(), incoming_streams.end()}}); } void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) { RTC_DCHECK(prepared_); - deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) { - cb.OnBufferedAmountLow(stream_id); - }); + deferred_.emplace_back( + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + return cb.OnBufferedAmountLow(absl::get(std::move(data))); + }, + stream_id); } void CallbackDeferrer::OnTotalBufferedAmountLow() { RTC_DCHECK(prepared_); deferred_.emplace_back( - [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); }); + +[](CallbackData data, DcSctpSocketCallbacks& cb) { + return cb.OnTotalBufferedAmountLow(); + }, + absl::monostate{}); } void CallbackDeferrer::OnLifecycleMessageExpired(LifecycleId lifecycle_id, diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h index 6659e87155..9d9fbcef06 100644 --- a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h +++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h @@ -18,6 +18,7 @@ #include #include "absl/strings/string_view.h" +#include "absl/types/variant.h" #include "api/array_view.h" #include "api/ref_counted_base.h" #include "api/scoped_refptr.h" @@ -89,12 +90,26 @@ class CallbackDeferrer : public DcSctpSocketCallbacks { void OnLifecycleEnd(LifecycleId lifecycle_id) override; private: + struct Error { + ErrorKind error; + std::string message; + }; + struct StreamReset { + std::vector streams; + std::string message; + }; + // Use a pre-sized variant for storage to avoid double heap allocation. This + // variant can hold all cases of stored data. + using CallbackData = absl:: + variant; + using Callback = void (*)(CallbackData, DcSctpSocketCallbacks&); + void Prepare(); void TriggerDeferred(); DcSctpSocketCallbacks& underlying_; bool prepared_ = false; - std::vector> deferred_; + std::vector> deferred_; }; } // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc index f0f9590943..98cd34a111 100644 --- a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc @@ -296,20 +296,14 @@ void DcSctpSocket::SendInit() { packet_sender_.Send(b, /*write_checksum=*/true); } -void DcSctpSocket::MakeConnectionParameters() { - VerificationTag new_verification_tag( - callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); - TSN initial_tsn(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn)); - connect_params_.initial_tsn = initial_tsn; - connect_params_.verification_tag = new_verification_tag; -} - void DcSctpSocket::Connect() { - RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); if (state_ == State::kClosed) { - MakeConnectionParameters(); + connect_params_.initial_tsn = + TSN(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn)); + connect_params_.verification_tag = VerificationTag( + callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); RTC_DLOG(LS_INFO) << log_prefix() << rtc::StringFormat( @@ -348,7 +342,6 @@ void DcSctpSocket::CreateTransmissionControlBlock( } void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { - RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); if (state_ != State::kClosed) { @@ -391,7 +384,6 @@ void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { } void DcSctpSocket::Shutdown() { - RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); if (tcb_ != nullptr) { @@ -420,7 +412,6 @@ void DcSctpSocket::Shutdown() { } void DcSctpSocket::Close() { - RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); if (state_ != State::kClosed) { @@ -468,20 +459,51 @@ void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) { void DcSctpSocket::SetStreamPriority(StreamID stream_id, StreamPriority priority) { - RTC_DCHECK_RUN_ON(&thread_checker_); send_queue_.SetStreamPriority(stream_id, priority); } StreamPriority DcSctpSocket::GetStreamPriority(StreamID stream_id) const { - RTC_DCHECK_RUN_ON(&thread_checker_); return send_queue_.GetStreamPriority(stream_id); } SendStatus DcSctpSocket::Send(DcSctpMessage message, const SendOptions& send_options) { - RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); - LifecycleId lifecycle_id = send_options.lifecycle_id; + SendStatus send_status = InternalSend(message, send_options); + if (send_status != SendStatus::kSuccess) + return send_status; + Timestamp now = callbacks_.Now(); + ++metrics_.tx_messages_count; + send_queue_.Add(now, std::move(message), send_options); + if (tcb_ != nullptr) + tcb_->SendBufferedPackets(now); + RTC_DCHECK(IsConsistent()); + return SendStatus::kSuccess; +} +std::vector DcSctpSocket::SendMany( + rtc::ArrayView messages, + const SendOptions& send_options) { + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + Timestamp now = callbacks_.Now(); + std::vector send_statuses; + send_statuses.reserve(messages.size()); + for (DcSctpMessage& message : messages) { + SendStatus send_status = InternalSend(message, send_options); + send_statuses.push_back(send_status); + if (send_status != SendStatus::kSuccess) + continue; + ++metrics_.tx_messages_count; + send_queue_.Add(now, std::move(message), send_options); + } + if (tcb_ != nullptr) + tcb_->SendBufferedPackets(now); + RTC_DCHECK(IsConsistent()); + return send_statuses; +} + +SendStatus DcSctpSocket::InternalSend(const DcSctpMessage& message, + const SendOptions& send_options) { + LifecycleId lifecycle_id = send_options.lifecycle_id; if (message.payload().empty()) { if (lifecycle_id.IsSet()) { callbacks_.OnLifecycleEnd(lifecycle_id); @@ -519,21 +541,11 @@ SendStatus DcSctpSocket::Send(DcSctpMessage message, "Unable to send message as the send queue is full"); return SendStatus::kErrorResourceExhaustion; } - - Timestamp now = callbacks_.Now(); - ++metrics_.tx_messages_count; - send_queue_.Add(now, std::move(message), send_options); - if (tcb_ != nullptr) { - tcb_->SendBufferedPackets(now); - } - - RTC_DCHECK(IsConsistent()); return SendStatus::kSuccess; } ResetStreamsStatus DcSctpSocket::ResetStreams( rtc::ArrayView outgoing_streams) { - RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); if (tcb_ == nullptr) { @@ -555,7 +567,6 @@ ResetStreamsStatus DcSctpSocket::ResetStreams( } SocketState DcSctpSocket::state() const { - RTC_DCHECK_RUN_ON(&thread_checker_); switch (state_) { case State::kClosed: return SocketState::kClosed; @@ -573,29 +584,23 @@ SocketState DcSctpSocket::state() const { } void DcSctpSocket::SetMaxMessageSize(size_t max_message_size) { - RTC_DCHECK_RUN_ON(&thread_checker_); options_.max_message_size = max_message_size; } size_t DcSctpSocket::buffered_amount(StreamID stream_id) const { - RTC_DCHECK_RUN_ON(&thread_checker_); return send_queue_.buffered_amount(stream_id); } size_t DcSctpSocket::buffered_amount_low_threshold(StreamID stream_id) const { - RTC_DCHECK_RUN_ON(&thread_checker_); return send_queue_.buffered_amount_low_threshold(stream_id); } void DcSctpSocket::SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) { - RTC_DCHECK_RUN_ON(&thread_checker_); send_queue_.SetBufferedAmountLowThreshold(stream_id, bytes); } absl::optional DcSctpSocket::GetMetrics() const { - RTC_DCHECK_RUN_ON(&thread_checker_); - if (tcb_ == nullptr) { return absl::nullopt; } @@ -750,7 +755,6 @@ bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) { } void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) { - RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); timer_manager_.HandleTimeout(timeout_id); @@ -764,7 +768,6 @@ void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) { } void DcSctpSocket::ReceivePacket(rtc::ArrayView data) { - RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); ++metrics_.rx_packets_count; @@ -1153,11 +1156,16 @@ void DcSctpSocket::HandleInit(const CommonHeader& header, } TieTag tie_tag(0); + VerificationTag my_verification_tag; + TSN my_initial_tsn; if (state_ == State::kClosed) { RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received Init in closed state (normal)"; - MakeConnectionParameters(); + my_verification_tag = VerificationTag( + callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); + my_initial_tsn = + TSN(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn)); } else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) { // https://tools.ietf.org/html/rfc4960#section-5.2.1 // "This usually indicates an initialization collision, i.e., each @@ -1170,6 +1178,8 @@ void DcSctpSocket::HandleInit(const CommonHeader& header, // endpoint) was sent." RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received Init indicating simultaneous connections"; + my_verification_tag = connect_params_.verification_tag; + my_initial_tsn = connect_params_.initial_tsn; } else { RTC_DCHECK(tcb_ != nullptr); // https://tools.ietf.org/html/rfc4960#section-5.2.2 @@ -1184,17 +1194,16 @@ void DcSctpSocket::HandleInit(const CommonHeader& header, << "Received Init indicating restarted connection"; // Create a new verification tag - different from the previous one. for (int tries = 0; tries < 10; ++tries) { - connect_params_.verification_tag = VerificationTag( + my_verification_tag = VerificationTag( callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); - if (connect_params_.verification_tag != tcb_->my_verification_tag()) { + if (my_verification_tag != tcb_->my_verification_tag()) { break; } } // Make the initial TSN make a large jump, so that there is no overlap // with the old and new association. - connect_params_.initial_tsn = - TSN(*tcb_->retransmission_queue().next_tsn() + 1000000); + my_initial_tsn = TSN(*tcb_->retransmission_queue().next_tsn() + 1000000); tie_tag = tcb_->tie_tag(); } @@ -1204,8 +1213,8 @@ void DcSctpSocket::HandleInit(const CommonHeader& header, "Proceeding with connection. my_verification_tag=%08x, " "my_initial_tsn=%u, peer_verification_tag=%08x, " "peer_initial_tsn=%u", - *connect_params_.verification_tag, *connect_params_.initial_tsn, - *chunk->initiate_tag(), *chunk->initial_tsn()); + *my_verification_tag, *my_initial_tsn, *chunk->initiate_tag(), + *chunk->initial_tsn()); Capabilities capabilities = ComputeCapabilities(options_, chunk->nbr_outbound_streams(), @@ -1214,16 +1223,17 @@ void DcSctpSocket::HandleInit(const CommonHeader& header, SctpPacket::Builder b(chunk->initiate_tag(), options_); Parameters::Builder params_builder = Parameters::Builder().Add(StateCookieParameter( - StateCookie(chunk->initiate_tag(), chunk->initial_tsn(), - chunk->a_rwnd(), tie_tag, capabilities) + StateCookie(chunk->initiate_tag(), my_verification_tag, + chunk->initial_tsn(), my_initial_tsn, chunk->a_rwnd(), + tie_tag, capabilities) .Serialize())); AddCapabilityParameters(options_, params_builder); - InitAckChunk init_ack(/*initiate_tag=*/connect_params_.verification_tag, + InitAckChunk init_ack(/*initiate_tag=*/my_verification_tag, options_.max_receiver_window_buffer_size, options_.announced_maximum_outgoing_streams, options_.announced_maximum_incoming_streams, - connect_params_.initial_tsn, params_builder.Build()); + my_initial_tsn, params_builder.Build()); b.Add(init_ack); // If the peer has signaled that it supports zero checksum, INIT-ACK can then // have its checksum as zero. @@ -1309,13 +1319,13 @@ void DcSctpSocket::HandleCookieEcho( return; } } else { - if (header.verification_tag != connect_params_.verification_tag) { + if (header.verification_tag != cookie->my_tag()) { callbacks_.OnError( ErrorKind::kParseFailed, rtc::StringFormat( "Received CookieEcho with invalid verification tag: %08x, " "expected %08x", - *header.verification_tag, *connect_params_.verification_tag)); + *header.verification_tag, *cookie->my_tag())); return; } } @@ -1340,10 +1350,10 @@ void DcSctpSocket::HandleCookieEcho( // send queue is already re-configured, and shouldn't be reset. send_queue_.Reset(); - CreateTransmissionControlBlock( - cookie->capabilities(), connect_params_.verification_tag, - connect_params_.initial_tsn, cookie->initiate_tag(), - cookie->initial_tsn(), cookie->a_rwnd(), MakeTieTag(callbacks_)); + CreateTransmissionControlBlock(cookie->capabilities(), cookie->my_tag(), + cookie->my_initial_tsn(), cookie->peer_tag(), + cookie->peer_initial_tsn(), cookie->a_rwnd(), + MakeTieTag(callbacks_)); } SctpPacket::Builder b = tcb_->PacketBuilder(); @@ -1363,13 +1373,13 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header, << *tcb_->my_verification_tag() << ", peer_tag=" << *header.verification_tag << ", tcb_tag=" << *tcb_->peer_verification_tag() - << ", cookie_tag=" << *cookie.initiate_tag() + << ", peer_tag=" << *cookie.peer_tag() << ", local_tie_tag=" << *tcb_->tie_tag() << ", peer_tie_tag=" << *cookie.tie_tag(); // https://tools.ietf.org/html/rfc4960#section-5.2.4 // "Handle a COOKIE ECHO when a TCB Exists" if (header.verification_tag != tcb_->my_verification_tag() && - tcb_->peer_verification_tag() != cookie.initiate_tag() && + tcb_->peer_verification_tag() != cookie.peer_tag() && cookie.tie_tag() == tcb_->tie_tag()) { // "A) In this case, the peer may have restarted." if (state_ == State::kShutdownAckSent) { @@ -1377,7 +1387,7 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header, // that the peer has restarted ... it MUST NOT set up a new association // but instead resend the SHUTDOWN ACK and send an ERROR chunk with a // "Cookie Received While Shutting Down" error cause to its peer." - SctpPacket::Builder b(cookie.initiate_tag(), options_); + SctpPacket::Builder b(cookie.peer_tag(), options_); b.Add(ShutdownAckChunk()); b.Add(ErrorChunk(Parameters::Builder() .Add(CookieReceivedWhileShuttingDownCause()) @@ -1394,7 +1404,7 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header, tcb_ = nullptr; callbacks_.OnConnectionRestarted(); } else if (header.verification_tag == tcb_->my_verification_tag() && - tcb_->peer_verification_tag() != cookie.initiate_tag()) { + tcb_->peer_verification_tag() != cookie.peer_tag()) { // TODO(boivie): Handle the peer_tag == 0? // "B) In this case, both sides may be attempting to start an // association at about the same time, but the peer endpoint started its @@ -1404,7 +1414,7 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header, << "Received COOKIE-ECHO indicating simultaneous connections"; tcb_ = nullptr; } else if (header.verification_tag != tcb_->my_verification_tag() && - tcb_->peer_verification_tag() == cookie.initiate_tag() && + tcb_->peer_verification_tag() == cookie.peer_tag() && cookie.tie_tag() == TieTag(0)) { // "C) In this case, the local endpoint's cookie has arrived late. // Before it arrived, the local endpoint sent an INIT and received an @@ -1417,7 +1427,7 @@ bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header, << "Received COOKIE-ECHO indicating a late COOKIE-ECHO. Discarding"; return false; } else if (header.verification_tag == tcb_->my_verification_tag() && - tcb_->peer_verification_tag() == cookie.initiate_tag()) { + tcb_->peer_verification_tag() == cookie.peer_tag()) { // "D) When both local and remote tags match, the endpoint should enter // the ESTABLISHED state, if it is in the COOKIE-ECHOED state. It // should stop any cookie timer that may be running and send a COOKIE @@ -1761,7 +1771,6 @@ void DcSctpSocket::SendShutdownAck() { } HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const { - RTC_DCHECK_RUN_ON(&thread_checker_); HandoverReadinessStatus status; if (state_ != State::kClosed && state_ != State::kEstablished) { status.Add(HandoverUnreadinessReason::kWrongConnectionState); @@ -1775,7 +1784,6 @@ HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const { absl::optional DcSctpSocket::GetHandoverStateAndClose() { - RTC_DCHECK_RUN_ON(&thread_checker_); CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); if (!GetHandoverReadiness().IsReady()) { diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h index deb6ee23e7..c65571a923 100644 --- a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h @@ -14,10 +14,10 @@ #include #include #include +#include #include "absl/strings/string_view.h" #include "api/array_view.h" -#include "api/sequence_checker.h" #include "net/dcsctp/packet/chunk/abort_chunk.h" #include "net/dcsctp/packet/chunk/chunk.h" #include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" @@ -91,6 +91,8 @@ class DcSctpSocket : public DcSctpSocketInterface { void Close() override; SendStatus Send(DcSctpMessage message, const SendOptions& send_options) override; + std::vector SendMany(rtc::ArrayView messages, + const SendOptions& send_options) override; ResetStreamsStatus ResetStreams( rtc::ArrayView outgoing_streams) override; SocketState state() const override; @@ -148,8 +150,6 @@ class DcSctpSocket : public DcSctpSocketInterface { // Changes the socket state, given a `reason` (for debugging/logging). void SetState(State state, absl::string_view reason); - // Fills in `connect_params` with random verification tag and initial TSN. - void MakeConnectionParameters(); // Closes the association. Note that the TCB will not be valid past this call. void InternalClose(ErrorKind error, absl::string_view message); // Closes the association, because of too many retransmission errors. @@ -167,6 +167,9 @@ class DcSctpSocket : public DcSctpSocketInterface { void MaybeSendShutdownOnPacketReceived(const SctpPacket& packet); // If there are streams pending to be reset, send a request to reset them. void MaybeSendResetStreamsRequest(); + // Performs internal processing shared between Send and SendMany. + SendStatus InternalSend(const DcSctpMessage& message, + const SendOptions& send_options); // Sends a INIT chunk. void SendInit(); // Sends a SHUTDOWN chunk. @@ -267,7 +270,6 @@ class DcSctpSocket : public DcSctpSocketInterface { const std::string log_prefix_; const std::unique_ptr packet_observer_; - RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_; Metrics metrics_; DcSctpOptions options_; diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc index dc76b80a37..413516bae0 100644 --- a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc @@ -66,6 +66,7 @@ namespace { using ::testing::_; using ::testing::AllOf; using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::IsEmpty; @@ -1561,6 +1562,33 @@ TEST(DcSctpSocketTest, SetMaxMessageSize) { EXPECT_EQ(a.socket.options().max_message_size, 42u); } +TEST_P(DcSctpSocketParametrizedTest, SendManyMessages) { + SocketUnderTest a("A"); + auto z = std::make_unique("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + static constexpr int kIterations = 100; + std::vector messages; + std::vector statuses; + for (int i = 0; i < kIterations; ++i) { + messages.push_back(DcSctpMessage(StreamID(1), PPID(53), {1, 2})); + statuses.push_back(SendStatus::kSuccess); + } + EXPECT_THAT(a.socket.SendMany(messages, {}), ElementsAreArray(statuses)); + + ExchangeMessages(a, *z); + + for (int i = 0; i < kIterations; ++i) { + EXPECT_TRUE(z->cb.ConsumeReceivedMessage().has_value()); + } + + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) { SocketUnderTest a("A"); auto z = std::make_unique("Z"); @@ -3061,5 +3089,149 @@ TEST(DcSctpSocketTest, HandlesForwardTsnOutOfOrderWithStreamResetting) { testing::Optional(Property(&DcSctpMessage::ppid, PPID(53)))); } +TEST(DcSctpSocketTest, ResentInitHasSameParameters) { + // If an INIT chunk has to be resent (due to INIT_ACK not received in time), + // the resent INIT must have the same properties as the original one. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + auto packet_1 = a.cb.ConsumeSentPacket(); + + // Times out, INIT is re-sent. + AdvanceTime(a, z, a.options.t1_init_timeout.ToTimeDelta()); + auto packet_2 = a.cb.ConsumeSentPacket(); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet_1, + SctpPacket::Parse(packet_1, z.options)); + ASSERT_HAS_VALUE_AND_ASSIGN( + InitChunk init_chunk_1, + InitChunk::Parse(init_packet_1.descriptors()[0].data)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet_2, + SctpPacket::Parse(packet_2, z.options)); + ASSERT_HAS_VALUE_AND_ASSIGN( + InitChunk init_chunk_2, + InitChunk::Parse(init_packet_2.descriptors()[0].data)); + + EXPECT_EQ(init_chunk_1.initial_tsn(), init_chunk_2.initial_tsn()); + EXPECT_EQ(init_chunk_1.initiate_tag(), init_chunk_2.initiate_tag()); +} + +TEST(DcSctpSocketTest, ResentInitAckHasDifferentParameters) { + // For every INIT, an INIT_ACK is produced. Verify that the socket doesn't + // maintain any state by ensuring that two created INIT_ACKs for the same + // received INIT are different. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + auto packet_1 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet_1, HasChunks(ElementsAre(IsChunkType(InitChunk::kType)))); + + z.socket.ReceivePacket(packet_1); + auto packet_2 = z.cb.ConsumeSentPacket(); + z.socket.ReceivePacket(packet_1); + auto packet_3 = z.cb.ConsumeSentPacket(); + + EXPECT_THAT(packet_2, + HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType)))); + EXPECT_THAT(packet_3, + HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType)))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_ack_packet_1, + SctpPacket::Parse(packet_2, z.options)); + ASSERT_HAS_VALUE_AND_ASSIGN( + InitAckChunk init_ack_chunk_1, + InitAckChunk::Parse(init_ack_packet_1.descriptors()[0].data)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_ack_packet_2, + SctpPacket::Parse(packet_3, z.options)); + ASSERT_HAS_VALUE_AND_ASSIGN( + InitAckChunk init_ack_chunk_2, + InitAckChunk::Parse(init_ack_packet_2.descriptors()[0].data)); + + EXPECT_NE(init_ack_chunk_1.initiate_tag(), init_ack_chunk_2.initiate_tag()); + EXPECT_NE(init_ack_chunk_1.initial_tsn(), init_ack_chunk_2.initial_tsn()); +} + +TEST(DcSctpSocketResendInitTest, ConnectionCanContinueFromFirstInitAck) { + // If an INIT chunk has to be resent (due to INIT_ACK not received in time), + // another INIT will be sent, and if both INITs were actually received, both + // will be responded to by an INIT_ACK. While these two INIT_ACKs may have + // different parameters, the connection must be able to finish with the cookie + // (as replied to using COOKIE_ECHO) from either INIT_ACK. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector(kLargeMessageSize)), + kSendOptions); + a.socket.Connect(); + auto init_1 = a.cb.ConsumeSentPacket(); + + // Times out, INIT is re-sent. + AdvanceTime(a, z, a.options.t1_init_timeout.ToTimeDelta()); + auto init_2 = a.cb.ConsumeSentPacket(); + + EXPECT_THAT(init_1, HasChunks(ElementsAre(IsChunkType(InitChunk::kType)))); + EXPECT_THAT(init_2, HasChunks(ElementsAre(IsChunkType(InitChunk::kType)))); + + z.socket.ReceivePacket(init_1); + z.socket.ReceivePacket(init_2); + auto init_ack_1 = z.cb.ConsumeSentPacket(); + auto init_ack_2 = z.cb.ConsumeSentPacket(); + EXPECT_THAT(init_ack_1, + HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType)))); + EXPECT_THAT(init_ack_2, + HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType)))); + + a.socket.ReceivePacket(init_ack_1); + // Then let the rest continue. + ExchangeMessages(a, z); + + absl::optional msg = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), SizeIs(kLargeMessageSize)); +} + +TEST(DcSctpSocketResendInitTest, ConnectionCanContinueFromSecondInitAck) { + // Just as above, but discarding the first INIT_ACK. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector(kLargeMessageSize)), + kSendOptions); + a.socket.Connect(); + auto init_1 = a.cb.ConsumeSentPacket(); + + // Times out, INIT is re-sent. + AdvanceTime(a, z, a.options.t1_init_timeout.ToTimeDelta()); + auto init_2 = a.cb.ConsumeSentPacket(); + + EXPECT_THAT(init_1, HasChunks(ElementsAre(IsChunkType(InitChunk::kType)))); + EXPECT_THAT(init_2, HasChunks(ElementsAre(IsChunkType(InitChunk::kType)))); + + z.socket.ReceivePacket(init_1); + z.socket.ReceivePacket(init_2); + auto init_ack_1 = z.cb.ConsumeSentPacket(); + auto init_ack_2 = z.cb.ConsumeSentPacket(); + EXPECT_THAT(init_ack_1, + HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType)))); + EXPECT_THAT(init_ack_2, + HasChunks(ElementsAre(IsChunkType(InitAckChunk::kType)))); + + a.socket.ReceivePacket(init_ack_2); + // Then let the rest continue. + ExchangeMessages(a, z); + + absl::optional msg = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), SizeIs(kLargeMessageSize)); +} + } // namespace } // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc index 624d783a3b..c5ed1d8620 100644 --- a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc +++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc @@ -32,17 +32,19 @@ std::vector StateCookie::Serialize() { BoundedByteWriter buffer(cookie); buffer.Store32<0>(kMagic1); buffer.Store32<4>(kMagic2); - buffer.Store32<8>(*initiate_tag_); - buffer.Store32<12>(*initial_tsn_); - buffer.Store32<16>(a_rwnd_); - buffer.Store32<20>(static_cast(*tie_tag_ >> 32)); - buffer.Store32<24>(static_cast(*tie_tag_)); - buffer.Store8<28>(capabilities_.partial_reliability); - buffer.Store8<29>(capabilities_.message_interleaving); - buffer.Store8<30>(capabilities_.reconfig); - buffer.Store16<32>(capabilities_.negotiated_maximum_incoming_streams); - buffer.Store16<34>(capabilities_.negotiated_maximum_outgoing_streams); - buffer.Store8<36>(capabilities_.zero_checksum); + buffer.Store32<8>(*peer_tag_); + buffer.Store32<12>(*my_tag_); + buffer.Store32<16>(*peer_initial_tsn_); + buffer.Store32<20>(*my_initial_tsn_); + buffer.Store32<24>(a_rwnd_); + buffer.Store32<28>(static_cast(*tie_tag_ >> 32)); + buffer.Store32<32>(static_cast(*tie_tag_)); + buffer.Store8<36>(capabilities_.partial_reliability); + buffer.Store8<37>(capabilities_.message_interleaving); + buffer.Store8<38>(capabilities_.reconfig); + buffer.Store16<40>(capabilities_.negotiated_maximum_incoming_streams); + buffer.Store16<42>(capabilities_.negotiated_maximum_outgoing_streams); + buffer.Store8<44>(capabilities_.zero_checksum); return cookie; } @@ -62,23 +64,25 @@ absl::optional StateCookie::Deserialize( return absl::nullopt; } - VerificationTag verification_tag(buffer.Load32<8>()); - TSN initial_tsn(buffer.Load32<12>()); - uint32_t a_rwnd = buffer.Load32<16>(); - uint32_t tie_tag_upper = buffer.Load32<20>(); - uint32_t tie_tag_lower = buffer.Load32<24>(); + VerificationTag peer_tag(buffer.Load32<8>()); + VerificationTag my_tag(buffer.Load32<12>()); + TSN peer_initial_tsn(buffer.Load32<16>()); + TSN my_initial_tsn(buffer.Load32<20>()); + uint32_t a_rwnd = buffer.Load32<24>(); + uint32_t tie_tag_upper = buffer.Load32<28>(); + uint32_t tie_tag_lower = buffer.Load32<32>(); TieTag tie_tag(static_cast(tie_tag_upper) << 32 | static_cast(tie_tag_lower)); Capabilities capabilities; - capabilities.partial_reliability = buffer.Load8<28>() != 0; - capabilities.message_interleaving = buffer.Load8<29>() != 0; - capabilities.reconfig = buffer.Load8<30>() != 0; - capabilities.negotiated_maximum_incoming_streams = buffer.Load16<32>(); - capabilities.negotiated_maximum_outgoing_streams = buffer.Load16<34>(); - capabilities.zero_checksum = buffer.Load8<36>() != 0; + capabilities.partial_reliability = buffer.Load8<36>() != 0; + capabilities.message_interleaving = buffer.Load8<37>() != 0; + capabilities.reconfig = buffer.Load8<38>() != 0; + capabilities.negotiated_maximum_incoming_streams = buffer.Load16<40>(); + capabilities.negotiated_maximum_outgoing_streams = buffer.Load16<42>(); + capabilities.zero_checksum = buffer.Load8<44>() != 0; - return StateCookie(verification_tag, initial_tsn, a_rwnd, tie_tag, - capabilities); + return StateCookie(peer_tag, my_tag, peer_initial_tsn, my_initial_tsn, a_rwnd, + tie_tag, capabilities); } } // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h index 34cd6d3690..b94eedafd4 100644 --- a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h +++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h @@ -27,15 +27,19 @@ namespace dcsctp { // Do not trust anything in it; no pointers or anything like that. class StateCookie { public: - static constexpr size_t kCookieSize = 37; + static constexpr size_t kCookieSize = 45; - StateCookie(VerificationTag initiate_tag, - TSN initial_tsn, + StateCookie(VerificationTag peer_tag, + VerificationTag my_tag, + TSN peer_initial_tsn, + TSN my_initial_tsn, uint32_t a_rwnd, TieTag tie_tag, Capabilities capabilities) - : initiate_tag_(initiate_tag), - initial_tsn_(initial_tsn), + : peer_tag_(peer_tag), + my_tag_(my_tag), + peer_initial_tsn_(peer_initial_tsn), + my_initial_tsn_(my_initial_tsn), a_rwnd_(a_rwnd), tie_tag_(tie_tag), capabilities_(capabilities) {} @@ -47,15 +51,21 @@ class StateCookie { static absl::optional Deserialize( rtc::ArrayView cookie); - VerificationTag initiate_tag() const { return initiate_tag_; } - TSN initial_tsn() const { return initial_tsn_; } + VerificationTag peer_tag() const { return peer_tag_; } + VerificationTag my_tag() const { return my_tag_; } + TSN peer_initial_tsn() const { return peer_initial_tsn_; } + TSN my_initial_tsn() const { return my_initial_tsn_; } uint32_t a_rwnd() const { return a_rwnd_; } TieTag tie_tag() const { return tie_tag_; } const Capabilities& capabilities() const { return capabilities_; } private: - const VerificationTag initiate_tag_; - const TSN initial_tsn_; + // Also called "Tag_A" in RFC4960. + const VerificationTag peer_tag_; + // Also called "Tag_Z" in RFC4960. + const VerificationTag my_tag_; + const TSN peer_initial_tsn_; + const TSN my_initial_tsn_; const uint32_t a_rwnd_; const TieTag tie_tag_; const Capabilities capabilities_; diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc index 19be71a1ca..806ea2024b 100644 --- a/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc +++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc @@ -24,14 +24,18 @@ TEST(StateCookieTest, SerializeAndDeserialize) { .zero_checksum = true, .negotiated_maximum_incoming_streams = 123, .negotiated_maximum_outgoing_streams = 234}; - StateCookie cookie(VerificationTag(123), TSN(456), + StateCookie cookie(/*peer_tag=*/VerificationTag(123), + /*my_tag=*/VerificationTag(321), + /*peer_initial_tsn=*/TSN(456), /*my_initial_tsn=*/TSN(654), /*a_rwnd=*/789, TieTag(101112), capabilities); std::vector serialized = cookie.Serialize(); EXPECT_THAT(serialized, SizeIs(StateCookie::kCookieSize)); ASSERT_HAS_VALUE_AND_ASSIGN(StateCookie deserialized, StateCookie::Deserialize(serialized)); - EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123)); - EXPECT_EQ(deserialized.initial_tsn(), TSN(456)); + EXPECT_EQ(deserialized.peer_tag(), VerificationTag(123)); + EXPECT_EQ(deserialized.my_tag(), VerificationTag(321)); + EXPECT_EQ(deserialized.peer_initial_tsn(), TSN(456)); + EXPECT_EQ(deserialized.my_initial_tsn(), TSN(654)); EXPECT_EQ(deserialized.a_rwnd(), 789u); EXPECT_EQ(deserialized.tie_tag(), TieTag(101112)); EXPECT_TRUE(deserialized.capabilities().partial_reliability); @@ -48,7 +52,9 @@ TEST(StateCookieTest, ValidateMagicValue) { Capabilities capabilities = {.partial_reliability = true, .message_interleaving = false, .reconfig = true}; - StateCookie cookie(VerificationTag(123), TSN(456), + StateCookie cookie(/*peer_tag=*/VerificationTag(123), + /*my_tag=*/VerificationTag(321), + /*peer_initial_tsn=*/TSN(456), /*my_initial_tsn=*/TSN(654), /*a_rwnd=*/789, TieTag(101112), capabilities); std::vector serialized = cookie.Serialize(); ASSERT_THAT(serialized, SizeIs(StateCookie::kCookieSize)); -- cgit v1.2.3