diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
commit | 43a97878ce14b72f0981164f87f2e35e14151312 (patch) | |
tree | 620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/libwebrtc/net/dcsctp | |
parent | Initial commit. (diff) | |
download | firefox-43a97878ce14b72f0981164f87f2e35e14151312.tar.xz firefox-43a97878ce14b72f0981164f87f2e35e14151312.zip |
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/libwebrtc/net/dcsctp')
253 files changed, 35398 insertions, 0 deletions
diff --git a/third_party/libwebrtc/net/dcsctp/BUILD.gn b/third_party/libwebrtc/net/dcsctp/BUILD.gn new file mode 100644 index 0000000000..8b38a65ca1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/BUILD.gn @@ -0,0 +1,26 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../webrtc.gni") + +if (rtc_include_tests) { + rtc_test("dcsctp_unittests") { + testonly = true + deps = [ + "../../test:test_main", + "common:dcsctp_common_unittests", + "fuzzers:dcsctp_fuzzers_unittests", + "packet:dcsctp_packet_unittests", + "public:dcsctp_public_unittests", + "rx:dcsctp_rx_unittests", + "socket:dcsctp_socket_unittests", + "timer:dcsctp_timer_unittests", + "tx:dcsctp_tx_unittests", + ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/OWNERS b/third_party/libwebrtc/net/dcsctp/OWNERS new file mode 100644 index 0000000000..06a0f86179 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/OWNERS @@ -0,0 +1,2 @@ +boivie@webrtc.org +orphis@webrtc.org diff --git a/third_party/libwebrtc/net/dcsctp/common/BUILD.gn b/third_party/libwebrtc/net/dcsctp/common/BUILD.gn new file mode 100644 index 0000000000..251ebaaf91 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/BUILD.gn @@ -0,0 +1,64 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("internal_types") { + deps = [ + "../../../rtc_base:strong_alias", + "../public:types", + ] + sources = [ "internal_types.h" ] +} + +rtc_source_set("math") { + deps = [] + sources = [ "math.h" ] +} + +rtc_source_set("sequence_numbers") { + deps = [ ":internal_types" ] + sources = [ "sequence_numbers.h" ] +} + +rtc_source_set("str_join") { + deps = [ "../../../rtc_base:stringutils" ] + sources = [ "str_join.h" ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_common_unittests") { + testonly = true + + defines = [] + deps = [ + ":math", + ":sequence_numbers", + ":str_join", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../test:test_support", + ] + sources = [ + "math_test.cc", + "sequence_numbers_test.cc", + "str_join_test.cc", + ] + } +} + +rtc_library("handover_testing") { + deps = [ "../public:socket" ] + testonly = true + sources = [ + "handover_testing.cc", + "handover_testing.h", + ] +} diff --git a/third_party/libwebrtc/net/dcsctp/common/handover_testing.cc b/third_party/libwebrtc/net/dcsctp/common/handover_testing.cc new file mode 100644 index 0000000000..1081766ea5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/handover_testing.cc @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/common/handover_testing.h" + +namespace dcsctp { +namespace { +// Default transformer function does nothing - dcSCTP does not implement +// state serialization that could be tested by setting +// `g_handover_state_transformer_for_test`. +void NoTransformation(DcSctpSocketHandoverState*) {} +} // namespace + +void (*g_handover_state_transformer_for_test)(DcSctpSocketHandoverState*) = + NoTransformation; +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/common/handover_testing.h b/third_party/libwebrtc/net/dcsctp/common/handover_testing.h new file mode 100644 index 0000000000..396016afec --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/handover_testing.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_HANDOVER_TESTING_H_ +#define NET_DCSCTP_COMMON_HANDOVER_TESTING_H_ + +#include "net/dcsctp/public/dcsctp_handover_state.h" + +namespace dcsctp { +// This global function is to facilitate testing of the socket handover state +// (`DcSctpSocketHandoverState`) serialization. dcSCTP library users have to +// implement state serialization if it's needed. To test the serialization one +// can set a custom `g_handover_state_transformer_for_test` at startup, link to +// the dcSCTP tests and run the resulting binary. Custom function can serialize +// and deserialize the passed state. All dcSCTP handover tests call +// `g_handover_state_transformer_for_test`. If some part of the state is +// serialized incorrectly or is forgotten, high chance that it will fail the +// tests. +extern void (*g_handover_state_transformer_for_test)( + DcSctpSocketHandoverState*); +} // namespace dcsctp + +#endif // NET_DCSCTP_COMMON_HANDOVER_TESTING_H_ diff --git a/third_party/libwebrtc/net/dcsctp/common/internal_types.h b/third_party/libwebrtc/net/dcsctp/common/internal_types.h new file mode 100644 index 0000000000..2354b92cc4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/internal_types.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_INTERNAL_TYPES_H_ +#define NET_DCSCTP_COMMON_INTERNAL_TYPES_H_ + +#include <functional> +#include <utility> + +#include "net/dcsctp/public/types.h" +#include "rtc_base/strong_alias.h" + +namespace dcsctp { + +// Stream Sequence Number (SSN) +using SSN = webrtc::StrongAlias<class SSNTag, uint16_t>; + +// Message Identifier (MID) +using MID = webrtc::StrongAlias<class MIDTag, uint32_t>; + +// Fragment Sequence Number (FSN) +using FSN = webrtc::StrongAlias<class FSNTag, uint32_t>; + +// Transmission Sequence Number (TSN) +using TSN = webrtc::StrongAlias<class TSNTag, uint32_t>; + +// Reconfiguration Request Sequence Number +using ReconfigRequestSN = + webrtc::StrongAlias<class ReconfigRequestSNTag, uint32_t>; + +// Verification Tag, used for packet validation. +using VerificationTag = webrtc::StrongAlias<class VerificationTagTag, uint32_t>; + +// Tie Tag, used as a nonce when connecting. +using TieTag = webrtc::StrongAlias<class TieTagTag, uint64_t>; + +} // namespace dcsctp +#endif // NET_DCSCTP_COMMON_INTERNAL_TYPES_H_ diff --git a/third_party/libwebrtc/net/dcsctp/common/math.h b/third_party/libwebrtc/net/dcsctp/common/math.h new file mode 100644 index 0000000000..12f690ed57 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/math.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_MATH_H_ +#define NET_DCSCTP_COMMON_MATH_H_ + +namespace dcsctp { + +// Rounds up `val` to the nearest value that is divisible by four. Frequently +// used to e.g. pad chunks or parameters to an even 32-bit offset. +template <typename IntType> +IntType RoundUpTo4(IntType val) { + return (val + 3) & ~3; +} + +// Similarly, rounds down `val` to the nearest value that is divisible by four. +template <typename IntType> +IntType RoundDownTo4(IntType val) { + return val & ~3; +} + +// Returns true if `val` is divisible by four. +template <typename IntType> +bool IsDivisibleBy4(IntType val) { + return (val & 3) == 0; +} + +} // namespace dcsctp + +#endif // NET_DCSCTP_COMMON_MATH_H_ diff --git a/third_party/libwebrtc/net/dcsctp/common/math_test.cc b/third_party/libwebrtc/net/dcsctp/common/math_test.cc new file mode 100644 index 0000000000..f95dfbdb55 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/math_test.cc @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/common/math.h" + +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(MathUtilTest, CanRoundUpTo4) { + // Signed numbers + EXPECT_EQ(RoundUpTo4(static_cast<int>(-5)), -4); + EXPECT_EQ(RoundUpTo4(static_cast<int>(-4)), -4); + EXPECT_EQ(RoundUpTo4(static_cast<int>(-3)), 0); + EXPECT_EQ(RoundUpTo4(static_cast<int>(-2)), 0); + EXPECT_EQ(RoundUpTo4(static_cast<int>(-1)), 0); + EXPECT_EQ(RoundUpTo4(static_cast<int>(0)), 0); + EXPECT_EQ(RoundUpTo4(static_cast<int>(1)), 4); + EXPECT_EQ(RoundUpTo4(static_cast<int>(2)), 4); + EXPECT_EQ(RoundUpTo4(static_cast<int>(3)), 4); + EXPECT_EQ(RoundUpTo4(static_cast<int>(4)), 4); + EXPECT_EQ(RoundUpTo4(static_cast<int>(5)), 8); + EXPECT_EQ(RoundUpTo4(static_cast<int>(6)), 8); + EXPECT_EQ(RoundUpTo4(static_cast<int>(7)), 8); + EXPECT_EQ(RoundUpTo4(static_cast<int>(8)), 8); + EXPECT_EQ(RoundUpTo4(static_cast<int64_t>(10000000000)), 10000000000); + EXPECT_EQ(RoundUpTo4(static_cast<int64_t>(10000000001)), 10000000004); + + // Unsigned numbers + EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(0)), 0u); + EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(1)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(2)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(3)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(4)), 4u); + EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(5)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(6)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(7)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast<unsigned int>(8)), 8u); + EXPECT_EQ(RoundUpTo4(static_cast<uint64_t>(10000000000)), 10000000000u); + EXPECT_EQ(RoundUpTo4(static_cast<uint64_t>(10000000001)), 10000000004u); +} + +TEST(MathUtilTest, CanRoundDownTo4) { + // Signed numbers + EXPECT_EQ(RoundDownTo4(static_cast<int>(-5)), -8); + EXPECT_EQ(RoundDownTo4(static_cast<int>(-4)), -4); + EXPECT_EQ(RoundDownTo4(static_cast<int>(-3)), -4); + EXPECT_EQ(RoundDownTo4(static_cast<int>(-2)), -4); + EXPECT_EQ(RoundDownTo4(static_cast<int>(-1)), -4); + EXPECT_EQ(RoundDownTo4(static_cast<int>(0)), 0); + EXPECT_EQ(RoundDownTo4(static_cast<int>(1)), 0); + EXPECT_EQ(RoundDownTo4(static_cast<int>(2)), 0); + EXPECT_EQ(RoundDownTo4(static_cast<int>(3)), 0); + EXPECT_EQ(RoundDownTo4(static_cast<int>(4)), 4); + EXPECT_EQ(RoundDownTo4(static_cast<int>(5)), 4); + EXPECT_EQ(RoundDownTo4(static_cast<int>(6)), 4); + EXPECT_EQ(RoundDownTo4(static_cast<int>(7)), 4); + EXPECT_EQ(RoundDownTo4(static_cast<int>(8)), 8); + EXPECT_EQ(RoundDownTo4(static_cast<int64_t>(10000000000)), 10000000000); + EXPECT_EQ(RoundDownTo4(static_cast<int64_t>(10000000001)), 10000000000); + + // Unsigned numbers + EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(0)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(1)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(2)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(3)), 0u); + EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(4)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(5)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(6)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(7)), 4u); + EXPECT_EQ(RoundDownTo4(static_cast<unsigned int>(8)), 8u); + EXPECT_EQ(RoundDownTo4(static_cast<uint64_t>(10000000000)), 10000000000u); + EXPECT_EQ(RoundDownTo4(static_cast<uint64_t>(10000000001)), 10000000000u); +} + +TEST(MathUtilTest, IsDivisibleBy4) { + // Signed numbers + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(-4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(-3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(-2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(-1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(0)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(5)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(6)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(7)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<int>(8)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast<int64_t>(10000000000)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast<int64_t>(10000000001)), false); + + // Unsigned numbers + EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(0)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(1)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(2)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(3)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(4)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(5)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(6)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(7)), false); + EXPECT_EQ(IsDivisibleBy4(static_cast<unsigned int>(8)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast<uint64_t>(10000000000)), true); + EXPECT_EQ(IsDivisibleBy4(static_cast<uint64_t>(10000000001)), false); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/common/sequence_numbers.h b/third_party/libwebrtc/net/dcsctp/common/sequence_numbers.h new file mode 100644 index 0000000000..919fc5014a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/sequence_numbers.h @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_SEQUENCE_NUMBERS_H_ +#define NET_DCSCTP_COMMON_SEQUENCE_NUMBERS_H_ + +#include <cstdint> +#include <limits> +#include <utility> + +#include "net/dcsctp/common/internal_types.h" + +namespace dcsctp { + +// UnwrappedSequenceNumber handles wrapping sequence numbers and unwraps them to +// an int64_t value space, to allow wrapped sequence numbers to be easily +// compared for ordering. +// +// Sequence numbers are expected to be monotonically increasing, but they do not +// need to be unwrapped in order, as long as the difference to the previous one +// is not larger than half the range of the wrapped sequence number. +// +// The WrappedType must be a webrtc::StrongAlias type. +template <typename WrappedType> +class UnwrappedSequenceNumber { + public: + static_assert( + !std::numeric_limits<typename WrappedType::UnderlyingType>::is_signed, + "The wrapped type must be unsigned"); + static_assert( + std::numeric_limits<typename WrappedType::UnderlyingType>::max() < + std::numeric_limits<int64_t>::max(), + "The wrapped type must be less than the int64_t value space"); + + // The unwrapper is a sort of factory and converts wrapped sequence numbers to + // unwrapped ones. + class Unwrapper { + public: + Unwrapper() : largest_(kValueLimit) {} + Unwrapper(const Unwrapper&) = default; + Unwrapper& operator=(const Unwrapper&) = default; + + // Given a wrapped `value`, and with knowledge of its current last seen + // largest number, will return a value that can be compared using normal + // operators, such as less-than, greater-than etc. + // + // This will also update the Unwrapper's state, to track the last seen + // largest value. + UnwrappedSequenceNumber<WrappedType> Unwrap(WrappedType value) { + WrappedType wrapped_largest = + static_cast<WrappedType>(largest_ % kValueLimit); + int64_t result = largest_ + Delta(value, wrapped_largest); + if (largest_ < result) { + largest_ = result; + } + return UnwrappedSequenceNumber<WrappedType>(result); + } + + // Similar to `Unwrap`, but will not update the Unwrappers's internal state. + UnwrappedSequenceNumber<WrappedType> PeekUnwrap(WrappedType value) const { + WrappedType uint32_largest = + static_cast<WrappedType>(largest_ % kValueLimit); + int64_t result = largest_ + Delta(value, uint32_largest); + return UnwrappedSequenceNumber<WrappedType>(result); + } + + // Resets the Unwrapper to its pristine state. Used when a sequence number + // is to be reset to zero. + void Reset() { largest_ = kValueLimit; } + + private: + static int64_t Delta(WrappedType value, WrappedType prev_value) { + static constexpr typename WrappedType::UnderlyingType kBreakpoint = + kValueLimit / 2; + typename WrappedType::UnderlyingType diff = *value - *prev_value; + diff %= kValueLimit; + if (diff < kBreakpoint) { + return static_cast<int64_t>(diff); + } + return static_cast<int64_t>(diff) - kValueLimit; + } + + int64_t largest_; + }; + + // Returns the wrapped value this type represents. + WrappedType Wrap() const { + return static_cast<WrappedType>(value_ % kValueLimit); + } + + template <typename H> + friend H AbslHashValue(H state, + const UnwrappedSequenceNumber<WrappedType>& hash) { + return H::combine(std::move(state), hash.value_); + } + + bool operator==(const UnwrappedSequenceNumber<WrappedType>& other) const { + return value_ == other.value_; + } + bool operator!=(const UnwrappedSequenceNumber<WrappedType>& other) const { + return value_ != other.value_; + } + bool operator<(const UnwrappedSequenceNumber<WrappedType>& other) const { + return value_ < other.value_; + } + bool operator>(const UnwrappedSequenceNumber<WrappedType>& other) const { + return value_ > other.value_; + } + bool operator>=(const UnwrappedSequenceNumber<WrappedType>& other) const { + return value_ >= other.value_; + } + bool operator<=(const UnwrappedSequenceNumber<WrappedType>& other) const { + return value_ <= other.value_; + } + + // Increments the value. + void Increment() { ++value_; } + + // Returns the next value relative to this sequence number. + UnwrappedSequenceNumber<WrappedType> next_value() const { + return UnwrappedSequenceNumber<WrappedType>(value_ + 1); + } + + // Returns a new sequence number based on `value`, and adding `delta` (which + // may be negative). + static UnwrappedSequenceNumber<WrappedType> AddTo( + UnwrappedSequenceNumber<WrappedType> value, + int delta) { + return UnwrappedSequenceNumber<WrappedType>(value.value_ + delta); + } + + // Returns the absolute difference between `lhs` and `rhs`. + static typename WrappedType::UnderlyingType Difference( + UnwrappedSequenceNumber<WrappedType> lhs, + UnwrappedSequenceNumber<WrappedType> rhs) { + return (lhs.value_ > rhs.value_) ? (lhs.value_ - rhs.value_) + : (rhs.value_ - lhs.value_); + } + + private: + explicit UnwrappedSequenceNumber(int64_t value) : value_(value) {} + static constexpr int64_t kValueLimit = + static_cast<int64_t>(1) + << std::numeric_limits<typename WrappedType::UnderlyingType>::digits; + + int64_t value_; +}; + +// Unwrapped Transmission Sequence Numbers (TSN) +using UnwrappedTSN = UnwrappedSequenceNumber<TSN>; + +// Unwrapped Stream Sequence Numbers (SSN) +using UnwrappedSSN = UnwrappedSequenceNumber<SSN>; + +// Unwrapped Message Identifier (MID) +using UnwrappedMID = UnwrappedSequenceNumber<MID>; + +} // namespace dcsctp + +#endif // NET_DCSCTP_COMMON_SEQUENCE_NUMBERS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/common/sequence_numbers_test.cc b/third_party/libwebrtc/net/dcsctp/common/sequence_numbers_test.cc new file mode 100644 index 0000000000..c4842f089e --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/sequence_numbers_test.cc @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/common/sequence_numbers.h" + +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +using Wrapped = webrtc::StrongAlias<class WrappedTag, uint16_t>; +using TestSequence = UnwrappedSequenceNumber<Wrapped>; + +TEST(SequenceNumbersTest, SimpleUnwrapping) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0)); + TestSequence s1 = unwrapper.Unwrap(Wrapped(1)); + TestSequence s2 = unwrapper.Unwrap(Wrapped(2)); + TestSequence s3 = unwrapper.Unwrap(Wrapped(3)); + + EXPECT_LT(s0, s1); + EXPECT_LT(s0, s2); + EXPECT_LT(s0, s3); + EXPECT_LT(s1, s2); + EXPECT_LT(s1, s3); + EXPECT_LT(s2, s3); + + EXPECT_EQ(TestSequence::Difference(s1, s0), 1); + EXPECT_EQ(TestSequence::Difference(s2, s0), 2); + EXPECT_EQ(TestSequence::Difference(s3, s0), 3); + + EXPECT_GT(s1, s0); + EXPECT_GT(s2, s0); + EXPECT_GT(s3, s0); + EXPECT_GT(s2, s1); + EXPECT_GT(s3, s1); + EXPECT_GT(s3, s2); + + s0.Increment(); + EXPECT_EQ(s0, s1); + s1.Increment(); + EXPECT_EQ(s1, s2); + s2.Increment(); + EXPECT_EQ(s2, s3); + + EXPECT_EQ(TestSequence::AddTo(s0, 2), s3); +} + +TEST(SequenceNumbersTest, MidValueUnwrapping) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0x7FFE)); + TestSequence s1 = unwrapper.Unwrap(Wrapped(0x7FFF)); + TestSequence s2 = unwrapper.Unwrap(Wrapped(0x8000)); + TestSequence s3 = unwrapper.Unwrap(Wrapped(0x8001)); + + EXPECT_LT(s0, s1); + EXPECT_LT(s0, s2); + EXPECT_LT(s0, s3); + EXPECT_LT(s1, s2); + EXPECT_LT(s1, s3); + EXPECT_LT(s2, s3); + + EXPECT_EQ(TestSequence::Difference(s1, s0), 1); + EXPECT_EQ(TestSequence::Difference(s2, s0), 2); + EXPECT_EQ(TestSequence::Difference(s3, s0), 3); + + EXPECT_GT(s1, s0); + EXPECT_GT(s2, s0); + EXPECT_GT(s3, s0); + EXPECT_GT(s2, s1); + EXPECT_GT(s3, s1); + EXPECT_GT(s3, s2); + + s0.Increment(); + EXPECT_EQ(s0, s1); + s1.Increment(); + EXPECT_EQ(s1, s2); + s2.Increment(); + EXPECT_EQ(s2, s3); + + EXPECT_EQ(TestSequence::AddTo(s0, 2), s3); +} + +TEST(SequenceNumbersTest, WrappedUnwrapping) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0xFFFE)); + TestSequence s1 = unwrapper.Unwrap(Wrapped(0xFFFF)); + TestSequence s2 = unwrapper.Unwrap(Wrapped(0x0000)); + TestSequence s3 = unwrapper.Unwrap(Wrapped(0x0001)); + + EXPECT_LT(s0, s1); + EXPECT_LT(s0, s2); + EXPECT_LT(s0, s3); + EXPECT_LT(s1, s2); + EXPECT_LT(s1, s3); + EXPECT_LT(s2, s3); + + EXPECT_EQ(TestSequence::Difference(s1, s0), 1); + EXPECT_EQ(TestSequence::Difference(s2, s0), 2); + EXPECT_EQ(TestSequence::Difference(s3, s0), 3); + + EXPECT_GT(s1, s0); + EXPECT_GT(s2, s0); + EXPECT_GT(s3, s0); + EXPECT_GT(s2, s1); + EXPECT_GT(s3, s1); + EXPECT_GT(s3, s2); + + s0.Increment(); + EXPECT_EQ(s0, s1); + s1.Increment(); + EXPECT_EQ(s1, s2); + s2.Increment(); + EXPECT_EQ(s2, s3); + + EXPECT_EQ(TestSequence::AddTo(s0, 2), s3); +} + +TEST(SequenceNumbersTest, WrapAroundAFewTimes) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0)); + TestSequence prev = s0; + + for (uint32_t i = 1; i < 65536 * 3; i++) { + uint16_t wrapped = static_cast<uint16_t>(i); + TestSequence si = unwrapper.Unwrap(Wrapped(wrapped)); + + EXPECT_LT(s0, si); + EXPECT_LT(prev, si); + prev = si; + } +} + +TEST(SequenceNumbersTest, IncrementIsSameAsWrapped) { + TestSequence::Unwrapper unwrapper; + + TestSequence s0 = unwrapper.Unwrap(Wrapped(0)); + + for (uint32_t i = 1; i < 65536 * 2; i++) { + uint16_t wrapped = static_cast<uint16_t>(i); + TestSequence si = unwrapper.Unwrap(Wrapped(wrapped)); + + s0.Increment(); + EXPECT_EQ(s0, si); + } +} + +TEST(SequenceNumbersTest, UnwrappingLargerNumberIsAlwaysLarger) { + TestSequence::Unwrapper unwrapper; + + for (uint32_t i = 1; i < 65536 * 2; i++) { + uint16_t wrapped = static_cast<uint16_t>(i); + TestSequence si = unwrapper.Unwrap(Wrapped(wrapped)); + + EXPECT_GT(unwrapper.Unwrap(Wrapped(wrapped + 1)), si); + EXPECT_GT(unwrapper.Unwrap(Wrapped(wrapped + 5)), si); + EXPECT_GT(unwrapper.Unwrap(Wrapped(wrapped + 10)), si); + EXPECT_GT(unwrapper.Unwrap(Wrapped(wrapped + 100)), si); + } +} + +TEST(SequenceNumbersTest, UnwrappingSmallerNumberIsAlwaysSmaller) { + TestSequence::Unwrapper unwrapper; + + for (uint32_t i = 1; i < 65536 * 2; i++) { + uint16_t wrapped = static_cast<uint16_t>(i); + TestSequence si = unwrapper.Unwrap(Wrapped(wrapped)); + + EXPECT_LT(unwrapper.Unwrap(Wrapped(wrapped - 1)), si); + EXPECT_LT(unwrapper.Unwrap(Wrapped(wrapped - 5)), si); + EXPECT_LT(unwrapper.Unwrap(Wrapped(wrapped - 10)), si); + EXPECT_LT(unwrapper.Unwrap(Wrapped(wrapped - 100)), si); + } +} + +TEST(SequenceNumbersTest, DifferenceIsAbsolute) { + TestSequence::Unwrapper unwrapper; + + TestSequence this_value = unwrapper.Unwrap(Wrapped(10)); + TestSequence other_value = TestSequence::AddTo(this_value, 100); + + EXPECT_EQ(TestSequence::Difference(this_value, other_value), 100); + EXPECT_EQ(TestSequence::Difference(other_value, this_value), 100); + + TestSequence minus_value = TestSequence::AddTo(this_value, -100); + + EXPECT_EQ(TestSequence::Difference(this_value, minus_value), 100); + EXPECT_EQ(TestSequence::Difference(minus_value, this_value), 100); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/common/str_join.h b/third_party/libwebrtc/net/dcsctp/common/str_join.h new file mode 100644 index 0000000000..04517827b7 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/str_join.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_COMMON_STR_JOIN_H_ +#define NET_DCSCTP_COMMON_STR_JOIN_H_ + +#include <string> + +#include "absl/strings/string_view.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +template <typename Range> +std::string StrJoin(const Range& seq, absl::string_view delimiter) { + rtc::StringBuilder sb; + int idx = 0; + + for (const typename Range::value_type& elem : seq) { + if (idx > 0) { + sb << delimiter; + } + sb << elem; + + ++idx; + } + return sb.Release(); +} + +template <typename Range, typename Functor> +std::string StrJoin(const Range& seq, + absl::string_view delimiter, + const Functor& fn) { + rtc::StringBuilder sb; + int idx = 0; + + for (const typename Range::value_type& elem : seq) { + if (idx > 0) { + sb << delimiter; + } + fn(sb, elem); + + ++idx; + } + return sb.Release(); +} + +} // namespace dcsctp + +#endif // NET_DCSCTP_COMMON_STR_JOIN_H_ diff --git a/third_party/libwebrtc/net/dcsctp/common/str_join_test.cc b/third_party/libwebrtc/net/dcsctp/common/str_join_test.cc new file mode 100644 index 0000000000..dbfd92c1cf --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/common/str_join_test.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/common/str_join.h" + +#include <string> +#include <utility> +#include <vector> + +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(StrJoinTest, CanJoinStringsFromVector) { + std::vector<std::string> strings = {"Hello", "World"}; + std::string s = StrJoin(strings, " "); + EXPECT_EQ(s, "Hello World"); +} + +TEST(StrJoinTest, CanJoinNumbersFromArray) { + std::array<int, 3> numbers = {1, 2, 3}; + std::string s = StrJoin(numbers, ","); + EXPECT_EQ(s, "1,2,3"); +} + +TEST(StrJoinTest, CanFormatElementsWhileJoining) { + std::vector<std::pair<std::string, std::string>> pairs = { + {"hello", "world"}, {"foo", "bar"}, {"fum", "gazonk"}}; + std::string s = StrJoin(pairs, ",", + [&](rtc::StringBuilder& sb, + const std::pair<std::string, std::string>& p) { + sb << p.first << "=" << p.second; + }); + EXPECT_EQ(s, "hello=world,foo=bar,fum=gazonk"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/fuzzers/BUILD.gn b/third_party/libwebrtc/net/dcsctp/fuzzers/BUILD.gn new file mode 100644 index 0000000000..302c828684 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/fuzzers/BUILD.gn @@ -0,0 +1,50 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_library("dcsctp_fuzzers") { + testonly = true + deps = [ + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../common:math", + "../packet:chunk", + "../packet:error_cause", + "../packet:parameter", + "../public:socket", + "../public:types", + "../socket:dcsctp_socket", + ] + sources = [ + "dcsctp_fuzzers.cc", + "dcsctp_fuzzers.h", + ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_fuzzers_unittests") { + testonly = true + + deps = [ + ":dcsctp_fuzzers", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:logging", + "../../../test:test_support", + "../packet:sctp_packet", + "../public:socket", + "../socket:dcsctp_socket", + "../testing:testing_macros", + ] + sources = [ "dcsctp_fuzzers_test.cc" ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.cc b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.cc new file mode 100644 index 0000000000..e8fcacffa0 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.cc @@ -0,0 +1,461 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h" + +#include <string> +#include <utility> +#include <vector> + +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/socket/dcsctp_socket.h" +#include "net/dcsctp/socket/state_cookie.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace dcsctp_fuzzers { +namespace { +static constexpr int kRandomValue = FuzzerCallbacks::kRandomValue; +static constexpr size_t kMinInputLength = 5; +static constexpr size_t kMaxInputLength = 1024; + +// A starting state for the socket, when fuzzing. +enum class StartingState : int { + kConnectNotCalled, + // When socket initiating Connect + kConnectCalled, + kReceivedInitAck, + kReceivedCookieAck, + // When socket initiating Shutdown + kShutdownCalled, + kReceivedShutdownAck, + // When peer socket initiated Connect + kReceivedInit, + kReceivedCookieEcho, + // When peer initiated Shutdown + kReceivedShutdown, + kReceivedShutdownComplete, + kNumberOfStates, +}; + +// State about the current fuzzing iteration +class FuzzState { + public: + explicit FuzzState(rtc::ArrayView<const uint8_t> data) : data_(data) {} + + uint8_t GetByte() { + uint8_t value = 0; + if (offset_ < data_.size()) { + value = data_[offset_]; + ++offset_; + } + return value; + } + + TSN GetNextTSN() { return TSN(tsn_++); } + MID GetNextMID() { return MID(mid_++); } + + bool empty() const { return offset_ >= data_.size(); } + + private: + uint32_t tsn_ = kRandomValue; + uint32_t mid_ = 0; + rtc::ArrayView<const uint8_t> data_; + size_t offset_ = 0; +}; + +void SetSocketState(DcSctpSocketInterface& socket, + FuzzerCallbacks& socket_cb, + StartingState state) { + // We'll use another temporary peer socket for the establishment. + FuzzerCallbacks peer_cb; + DcSctpSocket peer("peer", peer_cb, nullptr, {}); + + switch (state) { + case StartingState::kConnectNotCalled: + return; + case StartingState::kConnectCalled: + socket.Connect(); + return; + case StartingState::kReceivedInitAck: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + return; + case StartingState::kReceivedCookieAck: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + return; + case StartingState::kShutdownCalled: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + socket.Shutdown(); + return; + case StartingState::kReceivedShutdownAck: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + socket.Shutdown(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_ACK + return; + case StartingState::kReceivedInit: + peer.Connect(); + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT + return; + case StartingState::kReceivedCookieEcho: + peer.Connect(); + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT_ACK + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ECHO + return; + case StartingState::kReceivedShutdown: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + peer.Shutdown(); + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN + return; + case StartingState::kReceivedShutdownComplete: + socket.Connect(); + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK + peer.Shutdown(); + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN + peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN_ACK + socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_COMPLETE + return; + case StartingState::kNumberOfStates: + RTC_CHECK(false); + return; + } +} + +void MakeDataChunk(FuzzState& state, SctpPacket::Builder& b) { + DataChunk::Options options; + options.is_unordered = IsUnordered(state.GetByte() != 0); + options.is_beginning = Data::IsBeginning(state.GetByte() != 0); + options.is_end = Data::IsEnd(state.GetByte() != 0); + b.Add(DataChunk(state.GetNextTSN(), StreamID(state.GetByte()), + SSN(state.GetByte()), PPID(53), std::vector<uint8_t>(10), + options)); +} + +void MakeInitChunk(FuzzState& state, SctpPacket::Builder& b) { + Parameters::Builder builder; + builder.Add(ForwardTsnSupportedParameter()); + + b.Add(InitChunk(VerificationTag(kRandomValue), 10000, 1000, 1000, + TSN(kRandomValue), builder.Build())); +} + +void MakeInitAckChunk(FuzzState& state, SctpPacket::Builder& b) { + Parameters::Builder builder; + builder.Add(ForwardTsnSupportedParameter()); + + uint8_t state_cookie[] = {1, 2, 3, 4, 5}; + Parameters::Builder params_builder = + Parameters::Builder().Add(StateCookieParameter(state_cookie)); + + b.Add(InitAckChunk(VerificationTag(kRandomValue), 10000, 1000, 1000, + TSN(kRandomValue), builder.Build())); +} + +void MakeSackChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector<SackChunk::GapAckBlock> gap_ack_blocks; + uint16_t last_end = 0; + while (gap_ack_blocks.size() < 20) { + uint8_t delta_start = state.GetByte(); + if (delta_start < 0x80) { + break; + } + uint8_t delta_end = state.GetByte(); + + uint16_t start = last_end + delta_start; + uint16_t end = start + delta_end; + last_end = end; + gap_ack_blocks.emplace_back(start, end); + } + + TSN cum_ack_tsn(kRandomValue + state.GetByte()); + b.Add(SackChunk(cum_ack_tsn, 10000, std::move(gap_ack_blocks), {})); +} + +void MakeHeartbeatRequestChunk(FuzzState& state, SctpPacket::Builder& b) { + uint8_t info[] = {1, 2, 3, 4, 5}; + b.Add(HeartbeatRequestChunk( + Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build())); +} + +void MakeHeartbeatAckChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector<uint8_t> info(8); + b.Add(HeartbeatRequestChunk( + Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build())); +} + +void MakeAbortChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(AbortChunk( + /*filled_in_verification_tag=*/true, + Parameters::Builder().Add(UserInitiatedAbortCause("Fuzzing")).Build())); +} + +void MakeErrorChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(ErrorChunk( + Parameters::Builder().Add(ProtocolViolationCause("Fuzzing")).Build())); +} + +void MakeCookieEchoChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector<uint8_t> cookie(StateCookie::kCookieSize); + b.Add(CookieEchoChunk(cookie)); +} + +void MakeCookieAckChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(CookieAckChunk()); +} + +void MakeShutdownChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(ShutdownChunk(state.GetNextTSN())); +} + +void MakeShutdownAckChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(ShutdownAckChunk()); +} + +void MakeShutdownCompleteChunk(FuzzState& state, SctpPacket::Builder& b) { + b.Add(ShutdownCompleteChunk(false)); +} + +void MakeReConfigChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector<StreamID> streams = {StreamID(state.GetByte())}; + Parameters::Builder params_builder = + Parameters::Builder().Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(kRandomValue), ReconfigRequestSN(kRandomValue), + state.GetNextTSN(), streams)); + b.Add(ReConfigChunk(params_builder.Build())); +} + +void MakeForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector<ForwardTsnChunk::SkippedStream> skipped_streams; + for (;;) { + uint8_t stream = state.GetByte(); + if (skipped_streams.size() > 20 || stream < 0x80) { + break; + } + skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte())); + } + b.Add(ForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams))); +} + +void MakeIDataChunk(FuzzState& state, SctpPacket::Builder& b) { + DataChunk::Options options; + options.is_unordered = IsUnordered(state.GetByte() != 0); + options.is_beginning = Data::IsBeginning(state.GetByte() != 0); + options.is_end = Data::IsEnd(state.GetByte() != 0); + b.Add(IDataChunk(state.GetNextTSN(), StreamID(state.GetByte()), + state.GetNextMID(), PPID(53), FSN(0), + std::vector<uint8_t>(10), options)); +} + +void MakeIForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) { + std::vector<ForwardTsnChunk::SkippedStream> skipped_streams; + for (;;) { + uint8_t stream = state.GetByte(); + if (skipped_streams.size() > 20 || stream < 0x80) { + break; + } + skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte())); + } + b.Add(IForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams))); +} + +class RandomFuzzedChunk : public Chunk { + public: + explicit RandomFuzzedChunk(FuzzState& state) : state_(state) {} + + void SerializeTo(std::vector<uint8_t>& out) const override { + size_t bytes = state_.GetByte(); + for (size_t i = 0; i < bytes; ++i) { + out.push_back(state_.GetByte()); + } + } + + std::string ToString() const override { return std::string("RANDOM_FUZZED"); } + + private: + FuzzState& state_; +}; + +void MakeChunkWithRandomContent(FuzzState& state, SctpPacket::Builder& b) { + b.Add(RandomFuzzedChunk(state)); +} + +std::vector<uint8_t> GeneratePacket(FuzzState& state) { + DcSctpOptions options; + // Setting a fixed limit to not be dependent on the defaults, which may + // change. + options.mtu = 2048; + SctpPacket::Builder builder(VerificationTag(kRandomValue), options); + + // The largest expected serialized chunk, as created by fuzzers. + static constexpr size_t kMaxChunkSize = 256; + + for (int i = 0; i < 5 && builder.bytes_remaining() > kMaxChunkSize; ++i) { + switch (state.GetByte()) { + case 1: + MakeDataChunk(state, builder); + break; + case 2: + MakeInitChunk(state, builder); + break; + case 3: + MakeInitAckChunk(state, builder); + break; + case 4: + MakeSackChunk(state, builder); + break; + case 5: + MakeHeartbeatRequestChunk(state, builder); + break; + case 6: + MakeHeartbeatAckChunk(state, builder); + break; + case 7: + MakeAbortChunk(state, builder); + break; + case 8: + MakeErrorChunk(state, builder); + break; + case 9: + MakeCookieEchoChunk(state, builder); + break; + case 10: + MakeCookieAckChunk(state, builder); + break; + case 11: + MakeShutdownChunk(state, builder); + break; + case 12: + MakeShutdownAckChunk(state, builder); + break; + case 13: + MakeShutdownCompleteChunk(state, builder); + break; + case 14: + MakeReConfigChunk(state, builder); + break; + case 15: + MakeForwardTsnChunk(state, builder); + break; + case 16: + MakeIDataChunk(state, builder); + break; + case 17: + MakeIForwardTsnChunk(state, builder); + break; + case 18: + MakeChunkWithRandomContent(state, builder); + break; + default: + break; + } + } + std::vector<uint8_t> packet = builder.Build(); + return packet; +} +} // namespace + +void FuzzSocket(DcSctpSocketInterface& socket, + FuzzerCallbacks& cb, + rtc::ArrayView<const uint8_t> data) { + if (data.size() < kMinInputLength || data.size() > kMaxInputLength) { + return; + } + if (data[0] >= static_cast<int>(StartingState::kNumberOfStates)) { + return; + } + + // Set the socket in a specified valid starting state + SetSocketState(socket, cb, static_cast<StartingState>(data[0])); + + FuzzState state(data.subview(1)); + + while (!state.empty()) { + switch (state.GetByte()) { + case 1: + // Generate a valid SCTP packet (based on fuzz data) and "receive it". + socket.ReceivePacket(GeneratePacket(state)); + break; + case 2: + socket.Connect(); + break; + case 3: + socket.Shutdown(); + break; + case 4: + socket.Close(); + break; + case 5: { + StreamID streams[] = {StreamID(state.GetByte())}; + socket.ResetStreams(streams); + } break; + case 6: { + uint8_t flags = state.GetByte(); + SendOptions options; + options.unordered = IsUnordered(flags & 0x01); + options.max_retransmissions = + (flags & 0x02) != 0 ? absl::make_optional(0) : absl::nullopt; + options.lifecycle_id = LifecycleId(42); + size_t payload_exponent = (flags >> 2) % 16; + size_t payload_size = static_cast<size_t>(1) << payload_exponent; + socket.Send(DcSctpMessage(StreamID(state.GetByte()), PPID(53), + std::vector<uint8_t>(payload_size)), + options); + break; + } + case 7: { + // Expire an active timeout/timer. + uint8_t timeout_idx = state.GetByte(); + absl::optional<TimeoutID> timeout_id = cb.ExpireTimeout(timeout_idx); + if (timeout_id.has_value()) { + socket.HandleTimeout(*timeout_id); + } + break; + } + default: + break; + } + } +} +} // namespace dcsctp_fuzzers +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.h b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.h new file mode 100644 index 0000000000..90cfa35099 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ +#define NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ + +#include <deque> +#include <memory> +#include <set> +#include <vector> + +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/dcsctp_socket.h" + +namespace dcsctp { +namespace dcsctp_fuzzers { + +// A fake timeout used during fuzzing. +class FuzzerTimeout : public Timeout { + public: + explicit FuzzerTimeout(std::set<TimeoutID>& active_timeouts) + : active_timeouts_(active_timeouts) {} + + void Start(DurationMs duration_ms, TimeoutID timeout_id) override { + // Start is only allowed to be called on stopped or expired timeouts. + if (timeout_id_.has_value()) { + // It has been started before, but maybe it expired. Ensure that it's not + // running at least. + RTC_DCHECK(active_timeouts_.find(*timeout_id_) == active_timeouts_.end()); + } + timeout_id_ = timeout_id; + RTC_DCHECK(active_timeouts_.insert(timeout_id).second); + } + + void Stop() override { + // Stop is only allowed to be called on active timeouts. Not stopped or + // expired. + RTC_DCHECK(timeout_id_.has_value()); + RTC_DCHECK(active_timeouts_.erase(*timeout_id_) == 1); + timeout_id_ = absl::nullopt; + } + + // A set of all active timeouts, managed by `FuzzerCallbacks`. + std::set<TimeoutID>& active_timeouts_; + // If present, the timout is active and will expire reported as `timeout_id`. + absl::optional<TimeoutID> timeout_id_; +}; + +class FuzzerCallbacks : public DcSctpSocketCallbacks { + public: + static constexpr int kRandomValue = 42; + void SendPacket(rtc::ArrayView<const uint8_t> data) override { + sent_packets_.emplace_back(std::vector<uint8_t>(data.begin(), data.end())); + } + std::unique_ptr<Timeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) override { + // The fuzzer timeouts don't implement |precision|. + return std::make_unique<FuzzerTimeout>(active_timeouts_); + } + TimeMs TimeMillis() override { return TimeMs(42); } + uint32_t GetRandomInt(uint32_t low, uint32_t high) override { + return kRandomValue; + } + void OnMessageReceived(DcSctpMessage message) override {} + void OnError(ErrorKind error, absl::string_view message) override {} + void OnAborted(ErrorKind error, absl::string_view message) override {} + void OnConnected() override {} + void OnClosed() override {} + void OnConnectionRestarted() override {} + void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason) override {} + void OnStreamsResetPerformed( + rtc::ArrayView<const StreamID> outgoing_streams) override {} + void OnIncomingStreamsReset( + rtc::ArrayView<const StreamID> incoming_streams) override {} + + std::vector<uint8_t> ConsumeSentPacket() { + if (sent_packets_.empty()) { + return {}; + } + std::vector<uint8_t> ret = sent_packets_.front(); + sent_packets_.pop_front(); + return ret; + } + + // Given an index among the active timeouts, will expire that one. + absl::optional<TimeoutID> ExpireTimeout(size_t index) { + if (index < active_timeouts_.size()) { + auto it = active_timeouts_.begin(); + std::advance(it, index); + TimeoutID timeout_id = *it; + active_timeouts_.erase(it); + return timeout_id; + } + return absl::nullopt; + } + + private: + // Needs to be ordered, to allow fuzzers to expire timers. + std::set<TimeoutID> active_timeouts_; + std::deque<std::vector<uint8_t>> sent_packets_; +}; + +// Given some fuzzing `data` will send packets to the socket as well as calling +// API methods. +void FuzzSocket(DcSctpSocketInterface& socket, + FuzzerCallbacks& cb, + rtc::ArrayView<const uint8_t> data); + +} // namespace dcsctp_fuzzers +} // namespace dcsctp +#endif // NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc new file mode 100644 index 0000000000..c7d2cd7c99 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h" + +#include "api/array_view.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/socket/dcsctp_socket.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "rtc_base/logging.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace dcsctp_fuzzers { +namespace { + +// This is a testbed where fuzzed data that cause issues can be evaluated and +// crashes reproduced. Use `xxd -i ./crash-abc` to generate `data` below. +TEST(DcsctpFuzzersTest, PassesTestbed) { + uint8_t data[] = {0x07, 0x09, 0x00, 0x01, 0x11, 0xff, 0xff}; + + FuzzerCallbacks cb; + DcSctpOptions options; + options.disable_checksum_verification = true; + DcSctpSocket socket("A", cb, nullptr, options); + + FuzzSocket(socket, cb, data); +} + +} // namespace +} // namespace dcsctp_fuzzers +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/BUILD.gn b/third_party/libwebrtc/net/dcsctp/packet/BUILD.gn new file mode 100644 index 0000000000..08bdb0f5a5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/BUILD.gn @@ -0,0 +1,331 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +group("packet") { + deps = [ ":bounded_io" ] +} + +rtc_source_set("bounded_io") { + deps = [ + "../../../api:array_view", + "../../../rtc_base:checks", + ] + sources = [ + "bounded_byte_reader.h", + "bounded_byte_writer.h", + ] +} + +rtc_library("tlv_trait") { + deps = [ + ":bounded_io", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings:strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + sources = [ + "tlv_trait.cc", + "tlv_trait.h", + ] +} + +rtc_source_set("data") { + deps = [ + "../../../rtc_base:checks", + "../common:internal_types", + "../public:types", + ] + sources = [ "data.h" ] +} + +rtc_library("crc32c") { + deps = [ + "../../../api:array_view", + "../../../rtc_base:checks", + "//third_party/crc32c", + ] + sources = [ + "crc32c.cc", + "crc32c.h", + ] +} + +rtc_library("parameter") { + deps = [ + ":bounded_io", + ":data", + ":tlv_trait", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:internal_types", + "../common:math", + "../common:str_join", + "../public:types", + ] + sources = [ + "parameter/add_incoming_streams_request_parameter.cc", + "parameter/add_incoming_streams_request_parameter.h", + "parameter/add_outgoing_streams_request_parameter.cc", + "parameter/add_outgoing_streams_request_parameter.h", + "parameter/forward_tsn_supported_parameter.cc", + "parameter/forward_tsn_supported_parameter.h", + "parameter/heartbeat_info_parameter.cc", + "parameter/heartbeat_info_parameter.h", + "parameter/incoming_ssn_reset_request_parameter.cc", + "parameter/incoming_ssn_reset_request_parameter.h", + "parameter/outgoing_ssn_reset_request_parameter.cc", + "parameter/outgoing_ssn_reset_request_parameter.h", + "parameter/parameter.cc", + "parameter/parameter.h", + "parameter/reconfiguration_response_parameter.cc", + "parameter/reconfiguration_response_parameter.h", + "parameter/ssn_tsn_reset_request_parameter.cc", + "parameter/ssn_tsn_reset_request_parameter.h", + "parameter/state_cookie_parameter.cc", + "parameter/state_cookie_parameter.h", + "parameter/supported_extensions_parameter.cc", + "parameter/supported_extensions_parameter.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("error_cause") { + deps = [ + ":data", + ":parameter", + ":tlv_trait", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:internal_types", + "../common:math", + "../common:str_join", + "../packet:bounded_io", + "../public:types", + ] + sources = [ + "error_cause/cookie_received_while_shutting_down_cause.cc", + "error_cause/cookie_received_while_shutting_down_cause.h", + "error_cause/error_cause.cc", + "error_cause/error_cause.h", + "error_cause/invalid_mandatory_parameter_cause.cc", + "error_cause/invalid_mandatory_parameter_cause.h", + "error_cause/invalid_stream_identifier_cause.cc", + "error_cause/invalid_stream_identifier_cause.h", + "error_cause/missing_mandatory_parameter_cause.cc", + "error_cause/missing_mandatory_parameter_cause.h", + "error_cause/no_user_data_cause.cc", + "error_cause/no_user_data_cause.h", + "error_cause/out_of_resource_error_cause.cc", + "error_cause/out_of_resource_error_cause.h", + "error_cause/protocol_violation_cause.cc", + "error_cause/protocol_violation_cause.h", + "error_cause/restart_of_an_association_with_new_address_cause.cc", + "error_cause/restart_of_an_association_with_new_address_cause.h", + "error_cause/stale_cookie_error_cause.cc", + "error_cause/stale_cookie_error_cause.h", + "error_cause/unrecognized_chunk_type_cause.cc", + "error_cause/unrecognized_chunk_type_cause.h", + "error_cause/unrecognized_parameter_cause.cc", + "error_cause/unrecognized_parameter_cause.h", + "error_cause/unresolvable_address_cause.cc", + "error_cause/unresolvable_address_cause.h", + "error_cause/user_initiated_abort_cause.cc", + "error_cause/user_initiated_abort_cause.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("chunk") { + deps = [ + ":data", + ":error_cause", + ":parameter", + ":tlv_trait", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:math", + "../common:str_join", + "../packet:bounded_io", + ] + sources = [ + "chunk/abort_chunk.cc", + "chunk/abort_chunk.h", + "chunk/chunk.cc", + "chunk/chunk.h", + "chunk/cookie_ack_chunk.cc", + "chunk/cookie_ack_chunk.h", + "chunk/cookie_echo_chunk.cc", + "chunk/cookie_echo_chunk.h", + "chunk/data_chunk.cc", + "chunk/data_chunk.h", + "chunk/data_common.h", + "chunk/error_chunk.cc", + "chunk/error_chunk.h", + "chunk/forward_tsn_chunk.cc", + "chunk/forward_tsn_chunk.h", + "chunk/forward_tsn_common.h", + "chunk/heartbeat_ack_chunk.cc", + "chunk/heartbeat_ack_chunk.h", + "chunk/heartbeat_request_chunk.cc", + "chunk/heartbeat_request_chunk.h", + "chunk/idata_chunk.cc", + "chunk/idata_chunk.h", + "chunk/iforward_tsn_chunk.cc", + "chunk/iforward_tsn_chunk.h", + "chunk/init_ack_chunk.cc", + "chunk/init_ack_chunk.h", + "chunk/init_chunk.cc", + "chunk/init_chunk.h", + "chunk/reconfig_chunk.cc", + "chunk/reconfig_chunk.h", + "chunk/sack_chunk.cc", + "chunk/sack_chunk.h", + "chunk/shutdown_ack_chunk.cc", + "chunk/shutdown_ack_chunk.h", + "chunk/shutdown_chunk.cc", + "chunk/shutdown_chunk.h", + "chunk/shutdown_complete_chunk.cc", + "chunk/shutdown_complete_chunk.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("chunk_validators") { + deps = [ + ":chunk", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + ] + sources = [ + "chunk_validators.cc", + "chunk_validators.h", + ] +} + +rtc_library("sctp_packet") { + deps = [ + ":bounded_io", + ":chunk", + ":crc32c", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:internal_types", + "../common:math", + "../public:types", + ] + sources = [ + "sctp_packet.cc", + "sctp_packet.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/memory:memory", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_packet_unittests") { + testonly = true + + deps = [ + ":bounded_io", + ":chunk", + ":chunk_validators", + ":crc32c", + ":error_cause", + ":parameter", + ":sctp_packet", + ":tlv_trait", + "../../../api:array_view", + "../../../rtc_base:buffer", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../test:test_support", + "../common:internal_types", + "../common:math", + "../public:types", + "../testing:testing_macros", + ] + sources = [ + "bounded_byte_reader_test.cc", + "bounded_byte_writer_test.cc", + "chunk/abort_chunk_test.cc", + "chunk/cookie_ack_chunk_test.cc", + "chunk/cookie_echo_chunk_test.cc", + "chunk/data_chunk_test.cc", + "chunk/error_chunk_test.cc", + "chunk/forward_tsn_chunk_test.cc", + "chunk/heartbeat_ack_chunk_test.cc", + "chunk/heartbeat_request_chunk_test.cc", + "chunk/idata_chunk_test.cc", + "chunk/iforward_tsn_chunk_test.cc", + "chunk/init_ack_chunk_test.cc", + "chunk/init_chunk_test.cc", + "chunk/reconfig_chunk_test.cc", + "chunk/sack_chunk_test.cc", + "chunk/shutdown_ack_chunk_test.cc", + "chunk/shutdown_chunk_test.cc", + "chunk/shutdown_complete_chunk_test.cc", + "chunk_validators_test.cc", + "crc32c_test.cc", + "error_cause/cookie_received_while_shutting_down_cause_test.cc", + "error_cause/invalid_mandatory_parameter_cause_test.cc", + "error_cause/invalid_stream_identifier_cause_test.cc", + "error_cause/missing_mandatory_parameter_cause_test.cc", + "error_cause/no_user_data_cause_test.cc", + "error_cause/out_of_resource_error_cause_test.cc", + "error_cause/protocol_violation_cause_test.cc", + "error_cause/restart_of_an_association_with_new_address_cause_test.cc", + "error_cause/stale_cookie_error_cause_test.cc", + "error_cause/unrecognized_chunk_type_cause_test.cc", + "error_cause/unrecognized_parameter_cause_test.cc", + "error_cause/unresolvable_address_cause_test.cc", + "error_cause/user_initiated_abort_cause_test.cc", + "parameter/add_incoming_streams_request_parameter_test.cc", + "parameter/add_outgoing_streams_request_parameter_test.cc", + "parameter/forward_tsn_supported_parameter_test.cc", + "parameter/incoming_ssn_reset_request_parameter_test.cc", + "parameter/outgoing_ssn_reset_request_parameter_test.cc", + "parameter/parameter_test.cc", + "parameter/reconfiguration_response_parameter_test.cc", + "parameter/ssn_tsn_reset_request_parameter_test.cc", + "parameter/state_cookie_parameter_test.cc", + "parameter/supported_extensions_parameter_test.cc", + "sctp_packet_test.cc", + "tlv_trait_test.cc", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_reader.h b/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_reader.h new file mode 100644 index 0000000000..603ed6ac33 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_reader.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef NET_DCSCTP_PACKET_BOUNDED_BYTE_READER_H_ +#define NET_DCSCTP_PACKET_BOUNDED_BYTE_READER_H_ + +#include <cstdint> + +#include "api/array_view.h" + +namespace dcsctp { + +// TODO(boivie): These generic functions - and possibly this entire class - +// could be a candidate to have added to rtc_base/. They should use compiler +// intrinsics as well. +namespace internal { +// Loads a 8-bit unsigned word at `data`. +inline uint8_t LoadBigEndian8(const uint8_t* data) { + return data[0]; +} + +// Loads a 16-bit unsigned word at `data`. +inline uint16_t LoadBigEndian16(const uint8_t* data) { + return (data[0] << 8) | data[1]; +} + +// Loads a 32-bit unsigned word at `data`. +inline uint32_t LoadBigEndian32(const uint8_t* data) { + return (data[0] << 24) | (data[1] << 16) | (data[2] << 8) | data[3]; +} +} // namespace internal + +// BoundedByteReader wraps an ArrayView and divides it into two parts; A fixed +// size - which is the template parameter - and a variable size, which is what +// remains in `data` after the `FixedSize`. +// +// The BoundedByteReader provides methods to load/read big endian numbers from +// the FixedSize portion of the buffer, and these are read with static bounds +// checking, to avoid out-of-bounds accesses without a run-time penalty. +// +// The variable sized portion can either be used to create sub-readers, which +// themselves would provide compile-time bounds-checking, or the entire variable +// sized portion can be retrieved as an ArrayView. +template <int FixedSize> +class BoundedByteReader { + public: + explicit BoundedByteReader(rtc::ArrayView<const uint8_t> data) : data_(data) { + RTC_CHECK(data.size() >= FixedSize); + } + + template <size_t offset> + uint8_t Load8() const { + static_assert(offset + sizeof(uint8_t) <= FixedSize, "Out-of-bounds"); + return internal::LoadBigEndian8(&data_[offset]); + } + + template <size_t offset> + uint16_t Load16() const { + static_assert(offset + sizeof(uint16_t) <= FixedSize, "Out-of-bounds"); + static_assert((offset % sizeof(uint16_t)) == 0, "Unaligned access"); + return internal::LoadBigEndian16(&data_[offset]); + } + + template <size_t offset> + uint32_t Load32() const { + static_assert(offset + sizeof(uint32_t) <= FixedSize, "Out-of-bounds"); + static_assert((offset % sizeof(uint32_t)) == 0, "Unaligned access"); + return internal::LoadBigEndian32(&data_[offset]); + } + + template <size_t SubSize> + BoundedByteReader<SubSize> sub_reader(size_t variable_offset) const { + RTC_CHECK(FixedSize + variable_offset + SubSize <= data_.size()); + + rtc::ArrayView<const uint8_t> sub_span = + data_.subview(FixedSize + variable_offset, SubSize); + return BoundedByteReader<SubSize>(sub_span); + } + + size_t variable_data_size() const { return data_.size() - FixedSize; } + + rtc::ArrayView<const uint8_t> variable_data() const { + return data_.subview(FixedSize, data_.size() - FixedSize); + } + + private: + const rtc::ArrayView<const uint8_t> data_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_BOUNDED_BYTE_READER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_reader_test.cc b/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_reader_test.cc new file mode 100644 index 0000000000..2fb4a86785 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_reader_test.cc @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "net/dcsctp/packet/bounded_byte_reader.h" + +#include "api/array_view.h" +#include "rtc_base/buffer.h" +#include "rtc_base/checks.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(BoundedByteReaderTest, CanLoadData) { + uint8_t data[14] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4}; + + BoundedByteReader<8> reader(data); + EXPECT_EQ(reader.variable_data_size(), 6U); + EXPECT_EQ(reader.Load32<0>(), 0x01020304U); + EXPECT_EQ(reader.Load32<4>(), 0x05060708U); + EXPECT_EQ(reader.Load16<4>(), 0x0506U); + EXPECT_EQ(reader.Load8<4>(), 0x05U); + EXPECT_EQ(reader.Load8<5>(), 0x06U); + + BoundedByteReader<6> sub = reader.sub_reader<6>(0); + EXPECT_EQ(sub.Load16<0>(), 0x0900U); + EXPECT_EQ(sub.Load32<0>(), 0x09000102U); + EXPECT_EQ(sub.Load16<4>(), 0x0304U); + + EXPECT_THAT(reader.variable_data(), ElementsAre(9, 0, 1, 2, 3, 4)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_writer.h b/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_writer.h new file mode 100644 index 0000000000..467f26800b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_writer.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#ifndef NET_DCSCTP_PACKET_BOUNDED_BYTE_WRITER_H_ +#define NET_DCSCTP_PACKET_BOUNDED_BYTE_WRITER_H_ + +#include <algorithm> + +#include "api/array_view.h" + +namespace dcsctp { + +// TODO(boivie): These generic functions - and possibly this entire class - +// could be a candidate to have added to rtc_base/. They should use compiler +// intrinsics as well. +namespace internal { +// Stores a 8-bit unsigned word at `data`. +inline void StoreBigEndian8(uint8_t* data, uint8_t val) { + data[0] = val; +} + +// Stores a 16-bit unsigned word at `data`. +inline void StoreBigEndian16(uint8_t* data, uint16_t val) { + data[0] = val >> 8; + data[1] = val; +} + +// Stores a 32-bit unsigned word at `data`. +inline void StoreBigEndian32(uint8_t* data, uint32_t val) { + data[0] = val >> 24; + data[1] = val >> 16; + data[2] = val >> 8; + data[3] = val; +} +} // namespace internal + +// BoundedByteWriter wraps an ArrayView and divides it into two parts; A fixed +// size - which is the template parameter - and a variable size, which is what +// remains in `data` after the `FixedSize`. +// +// The BoundedByteWriter provides methods to write big endian numbers to the +// FixedSize portion of the buffer, and these are written with static bounds +// checking, to avoid out-of-bounds accesses without a run-time penalty. +// +// The variable sized portion can either be used to create sub-writers, which +// themselves would provide compile-time bounds-checking, or data can be copied +// to it. +template <int FixedSize> +class BoundedByteWriter { + public: + explicit BoundedByteWriter(rtc::ArrayView<uint8_t> data) : data_(data) { + RTC_CHECK(data.size() >= FixedSize); + } + + template <size_t offset> + void Store8(uint8_t value) { + static_assert(offset + sizeof(uint8_t) <= FixedSize, "Out-of-bounds"); + internal::StoreBigEndian8(&data_[offset], value); + } + + template <size_t offset> + void Store16(uint16_t value) { + static_assert(offset + sizeof(uint16_t) <= FixedSize, "Out-of-bounds"); + static_assert((offset % sizeof(uint16_t)) == 0, "Unaligned access"); + internal::StoreBigEndian16(&data_[offset], value); + } + + template <size_t offset> + void Store32(uint32_t value) { + static_assert(offset + sizeof(uint32_t) <= FixedSize, "Out-of-bounds"); + static_assert((offset % sizeof(uint32_t)) == 0, "Unaligned access"); + internal::StoreBigEndian32(&data_[offset], value); + } + + template <size_t SubSize> + BoundedByteWriter<SubSize> sub_writer(size_t variable_offset) { + RTC_CHECK(FixedSize + variable_offset + SubSize <= data_.size()); + + return BoundedByteWriter<SubSize>( + data_.subview(FixedSize + variable_offset, SubSize)); + } + + void CopyToVariableData(rtc::ArrayView<const uint8_t> source) { + memcpy(data_.data() + FixedSize, source.data(), + std::min(source.size(), data_.size() - FixedSize)); + } + + private: + rtc::ArrayView<uint8_t> data_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_BOUNDED_BYTE_WRITER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_writer_test.cc b/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_writer_test.cc new file mode 100644 index 0000000000..3cea0a2f7c --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/bounded_byte_writer_test.cc @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "net/dcsctp/packet/bounded_byte_writer.h" + +#include <vector> + +#include "api/array_view.h" +#include "rtc_base/buffer.h" +#include "rtc_base/checks.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(BoundedByteWriterTest, CanWriteData) { + std::vector<uint8_t> data(14); + + BoundedByteWriter<8> writer(data); + writer.Store32<0>(0x01020304); + writer.Store16<4>(0x0506); + writer.Store8<6>(0x07); + writer.Store8<7>(0x08); + + uint8_t variable_data[] = {0, 0, 0, 0, 3, 0}; + writer.CopyToVariableData(variable_data); + + BoundedByteWriter<6> sub = writer.sub_writer<6>(0); + sub.Store32<0>(0x09000000); + sub.Store16<2>(0x0102); + + BoundedByteWriter<2> sub2 = writer.sub_writer<2>(4); + sub2.Store8<1>(0x04); + + EXPECT_THAT(data, ElementsAre(1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/abort_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/abort_chunk.cc new file mode 100644 index 0000000000..8348eb96a9 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/abort_chunk.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/abort_chunk.h" + +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.7 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 6 |Reserved |T| Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / zero or more Error Causes / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int AbortChunk::kType; + +absl::optional<AbortChunk> AbortChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + absl::optional<Parameters> error_causes = + Parameters::Parse(reader->variable_data()); + if (!error_causes.has_value()) { + return absl::nullopt; + } + uint8_t flags = reader->Load8<1>(); + bool filled_in_verification_tag = (flags & (1 << kFlagsBitT)) == 0; + return AbortChunk(filled_in_verification_tag, *std::move(error_causes)); +} + +void AbortChunk::SerializeTo(std::vector<uint8_t>& out) const { + rtc::ArrayView<const uint8_t> error_causes = error_causes_.data(); + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, error_causes.size()); + writer.Store8<1>(filled_in_verification_tag_ ? 0 : (1 << kFlagsBitT)); + writer.CopyToVariableData(error_causes); +} + +std::string AbortChunk::ToString() const { + return "ABORT"; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/abort_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/abort_chunk.h new file mode 100644 index 0000000000..1408a75e80 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/abort_chunk.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_ABORT_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_ABORT_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.7 +struct AbortChunkConfig : ChunkConfig { + static constexpr int kType = 6; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class AbortChunk : public Chunk, public TLVTrait<AbortChunkConfig> { + public: + static constexpr int kType = AbortChunkConfig::kType; + + AbortChunk(bool filled_in_verification_tag, Parameters error_causes) + : filled_in_verification_tag_(filled_in_verification_tag), + error_causes_(std::move(error_causes)) {} + + AbortChunk(AbortChunk&& other) = default; + AbortChunk& operator=(AbortChunk&& other) = default; + + static absl::optional<AbortChunk> Parse(rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + bool filled_in_verification_tag() const { + return filled_in_verification_tag_; + } + + const Parameters& error_causes() const { return error_causes_; } + + private: + static constexpr int kFlagsBitT = 0; + bool filled_in_verification_tag_; + Parameters error_causes_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_ABORT_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/abort_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/abort_chunk_test.cc new file mode 100644 index 0000000000..c1f3a4d5b9 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/abort_chunk_test.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/abort_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(AbortChunkTest, FromCapture) { + /* + ABORT chunk + Chunk type: ABORT (6) + Chunk flags: 0x00 + Chunk length: 8 + User initiated ABORT cause + */ + + uint8_t data[] = {0x06, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04}; + + ASSERT_HAS_VALUE_AND_ASSIGN(AbortChunk chunk, AbortChunk::Parse(data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + chunk.error_causes().get<UserInitiatedAbortCause>()); + + EXPECT_EQ(cause.upper_layer_abort_reason(), ""); +} + +TEST(AbortChunkTest, SerializeAndDeserialize) { + AbortChunk chunk(/*filled_in_verification_tag=*/true, + Parameters::Builder() + .Add(UserInitiatedAbortCause("Close called")) + .Build()); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(AbortChunk deserialized, + AbortChunk::Parse(serialized)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + deserialized.error_causes().get<UserInitiatedAbortCause>()); + + EXPECT_EQ(cause.upper_layer_abort_reason(), "Close called"); +} + +// Validates that AbortChunk doesn't make any alignment assumptions. +TEST(AbortChunkTest, SerializeAndDeserializeOneChar) { + AbortChunk chunk( + /*filled_in_verification_tag=*/true, + Parameters::Builder().Add(UserInitiatedAbortCause("!")).Build()); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(AbortChunk deserialized, + AbortChunk::Parse(serialized)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + deserialized.error_causes().get<UserInitiatedAbortCause>()); + + EXPECT_EQ(cause.upper_layer_abort_reason(), "!"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/chunk.cc new file mode 100644 index 0000000000..832ab82288 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/chunk.cc @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/chunk.h" + +#include <cstdint> +#include <memory> +#include <utility> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +template <class Chunk> +bool ParseAndPrint(uint8_t chunk_type, + rtc::ArrayView<const uint8_t> data, + rtc::StringBuilder& sb) { + if (chunk_type == Chunk::kType) { + absl::optional<Chunk> c = Chunk::Parse(data); + if (c.has_value()) { + sb << c->ToString(); + } else { + sb << "Failed to parse chunk of type " << chunk_type; + } + return true; + } + return false; +} + +std::string DebugConvertChunkToString(rtc::ArrayView<const uint8_t> data) { + rtc::StringBuilder sb; + + if (data.empty()) { + sb << "Failed to parse chunk due to empty data"; + } else { + uint8_t chunk_type = data[0]; + if (!ParseAndPrint<DataChunk>(chunk_type, data, sb) && + !ParseAndPrint<InitChunk>(chunk_type, data, sb) && + !ParseAndPrint<InitAckChunk>(chunk_type, data, sb) && + !ParseAndPrint<SackChunk>(chunk_type, data, sb) && + !ParseAndPrint<HeartbeatRequestChunk>(chunk_type, data, sb) && + !ParseAndPrint<HeartbeatAckChunk>(chunk_type, data, sb) && + !ParseAndPrint<AbortChunk>(chunk_type, data, sb) && + !ParseAndPrint<ErrorChunk>(chunk_type, data, sb) && + !ParseAndPrint<CookieEchoChunk>(chunk_type, data, sb) && + !ParseAndPrint<CookieAckChunk>(chunk_type, data, sb) && + !ParseAndPrint<ShutdownChunk>(chunk_type, data, sb) && + !ParseAndPrint<ShutdownAckChunk>(chunk_type, data, sb) && + !ParseAndPrint<ShutdownCompleteChunk>(chunk_type, data, sb) && + !ParseAndPrint<ReConfigChunk>(chunk_type, data, sb) && + !ParseAndPrint<ForwardTsnChunk>(chunk_type, data, sb) && + !ParseAndPrint<IDataChunk>(chunk_type, data, sb) && + !ParseAndPrint<IForwardTsnChunk>(chunk_type, data, sb)) { + sb << "Unhandled chunk type: " << static_cast<int>(chunk_type); + } + } + return sb.Release(); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/chunk.h new file mode 100644 index 0000000000..687aa1daa1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/chunk.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_CHUNK_H_ + +#include <stddef.h> +#include <sys/types.h> + +#include <cstdint> +#include <iterator> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// Base class for all SCTP chunks +class Chunk { + public: + Chunk() {} + virtual ~Chunk() = default; + + // Chunks can contain data payloads that shouldn't be copied unnecessarily. + Chunk(Chunk&& other) = default; + Chunk& operator=(Chunk&& other) = default; + Chunk(const Chunk&) = delete; + Chunk& operator=(const Chunk&) = delete; + + // Serializes the chunk to `out`, growing it as necessary. + virtual void SerializeTo(std::vector<uint8_t>& out) const = 0; + + // Returns a human readable description of this chunk and its parameters. + virtual std::string ToString() const = 0; +}; + +// Introspects the chunk in `data` and returns a human readable textual +// representation of it, to be used in debugging. +std::string DebugConvertChunkToString(rtc::ArrayView<const uint8_t> data); + +struct ChunkConfig { + static constexpr int kTypeSizeInBytes = 1; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_ack_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_ack_chunk.cc new file mode 100644 index 0000000000..4839969ccf --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_ack_chunk.cc @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" + +#include <stdint.h> + +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.12 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 11 |Chunk Flags | Length = 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int CookieAckChunk::kType; + +absl::optional<CookieAckChunk> CookieAckChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return CookieAckChunk(); +} + +void CookieAckChunk::SerializeTo(std::vector<uint8_t>& out) const { + AllocateTLV(out); +} + +std::string CookieAckChunk::ToString() const { + return "COOKIE-ACK"; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_ack_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_ack_chunk.h new file mode 100644 index 0000000000..f7d4a33f7d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_ack_chunk.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_COOKIE_ACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_COOKIE_ACK_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.12 +struct CookieAckChunkConfig : ChunkConfig { + static constexpr int kType = 11; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class CookieAckChunk : public Chunk, public TLVTrait<CookieAckChunkConfig> { + public: + static constexpr int kType = CookieAckChunkConfig::kType; + + CookieAckChunk() {} + + static absl::optional<CookieAckChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_COOKIE_ACK_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_ack_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_ack_chunk_test.cc new file mode 100644 index 0000000000..3f560c6fef --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_ack_chunk_test.cc @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(CookieAckChunkTest, FromCapture) { + /* + COOKIE_ACK chunk + Chunk type: COOKIE_ACK (11) + Chunk flags: 0x00 + Chunk length: 4 + */ + + uint8_t data[] = {0x0b, 0x00, 0x00, 0x04}; + + EXPECT_TRUE(CookieAckChunk::Parse(data).has_value()); +} + +TEST(CookieAckChunkTest, SerializeAndDeserialize) { + CookieAckChunk chunk; + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(CookieAckChunk deserialized, + CookieAckChunk::Parse(serialized)); + EXPECT_EQ(deserialized.ToString(), "COOKIE-ACK"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_echo_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_echo_chunk.cc new file mode 100644 index 0000000000..a01d0b13c4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_echo_chunk.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.11 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 10 |Chunk Flags | Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Cookie / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int CookieEchoChunk::kType; + +absl::optional<CookieEchoChunk> CookieEchoChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return CookieEchoChunk(reader->variable_data()); +} + +void CookieEchoChunk::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, cookie_.size()); + writer.CopyToVariableData(cookie_); +} + +std::string CookieEchoChunk::ToString() const { + return "COOKIE-ECHO"; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_echo_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_echo_chunk.h new file mode 100644 index 0000000000..8cb80527f8 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_echo_chunk.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_COOKIE_ECHO_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_COOKIE_ECHO_CHUNK_H_ +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.11 +struct CookieEchoChunkConfig : ChunkConfig { + static constexpr int kType = 10; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class CookieEchoChunk : public Chunk, public TLVTrait<CookieEchoChunkConfig> { + public: + static constexpr int kType = CookieEchoChunkConfig::kType; + + explicit CookieEchoChunk(rtc::ArrayView<const uint8_t> cookie) + : cookie_(cookie.begin(), cookie.end()) {} + + static absl::optional<CookieEchoChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + rtc::ArrayView<const uint8_t> cookie() const { return cookie_; } + + private: + std::vector<uint8_t> cookie_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_COOKIE_ECHO_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_echo_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_echo_chunk_test.cc new file mode 100644 index 0000000000..d06e0a6439 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/cookie_echo_chunk_test.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(CookieEchoChunkTest, FromCapture) { + /* + COOKIE_ECHO chunk (Cookie length: 256 bytes) + Chunk type: COOKIE_ECHO (10) + Chunk flags: 0x00 + Chunk length: 260 + Cookie: 12345678 + */ + + uint8_t data[] = {0x0a, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78}; + + ASSERT_HAS_VALUE_AND_ASSIGN(CookieEchoChunk chunk, + CookieEchoChunk::Parse(data)); + + EXPECT_THAT(chunk.cookie(), ElementsAre(0x12, 0x34, 0x56, 0x78)); +} + +TEST(CookieEchoChunkTest, SerializeAndDeserialize) { + uint8_t cookie[] = {1, 2, 3, 4}; + CookieEchoChunk chunk(cookie); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(CookieEchoChunk deserialized, + CookieEchoChunk::Parse(serialized)); + + EXPECT_THAT(deserialized.cookie(), ElementsAre(1, 2, 3, 4)); + EXPECT_EQ(deserialized.ToString(), "COOKIE-ECHO"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/data_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/data_chunk.cc new file mode 100644 index 0000000000..769be2db91 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/data_chunk.cc @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/data_chunk.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 0 | Reserved|U|B|E| Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier S | Stream Sequence Number n | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Payload Protocol Identifier | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / User Data (seq n of Stream S) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int DataChunk::kType; + +absl::optional<DataChunk> DataChunk::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + uint8_t flags = reader->Load8<1>(); + TSN tsn(reader->Load32<4>()); + StreamID stream_identifier(reader->Load16<8>()); + SSN ssn(reader->Load16<10>()); + PPID ppid(reader->Load32<12>()); + + Options options; + options.is_end = Data::IsEnd((flags & (1 << kFlagsBitEnd)) != 0); + options.is_beginning = + Data::IsBeginning((flags & (1 << kFlagsBitBeginning)) != 0); + options.is_unordered = IsUnordered((flags & (1 << kFlagsBitUnordered)) != 0); + options.immediate_ack = + ImmediateAckFlag((flags & (1 << kFlagsBitImmediateAck)) != 0); + + return DataChunk(tsn, stream_identifier, ssn, ppid, + std::vector<uint8_t>(reader->variable_data().begin(), + reader->variable_data().end()), + options); +} + +void DataChunk::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, payload().size()); + + writer.Store8<1>( + (*options().is_end ? (1 << kFlagsBitEnd) : 0) | + (*options().is_beginning ? (1 << kFlagsBitBeginning) : 0) | + (*options().is_unordered ? (1 << kFlagsBitUnordered) : 0) | + (*options().immediate_ack ? (1 << kFlagsBitImmediateAck) : 0)); + writer.Store32<4>(*tsn()); + writer.Store16<8>(*stream_id()); + writer.Store16<10>(*ssn()); + writer.Store32<12>(*ppid()); + + writer.CopyToVariableData(payload()); +} + +std::string DataChunk::ToString() const { + rtc::StringBuilder sb; + sb << "DATA, type=" << (options().is_unordered ? "unordered" : "ordered") + << "::" + << (*options().is_beginning && *options().is_end + ? "complete" + : *options().is_beginning ? "first" + : *options().is_end ? "last" : "middle") + << ", tsn=" << *tsn() << ", sid=" << *stream_id() << ", ssn=" << *ssn() + << ", ppid=" << *ppid() << ", length=" << payload().size(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/data_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/data_chunk.h new file mode 100644 index 0000000000..12bb05f2c4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/data_chunk.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_DATA_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_DATA_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.1 +struct DataChunkConfig : ChunkConfig { + static constexpr int kType = 0; + static constexpr size_t kHeaderSize = 16; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class DataChunk : public AnyDataChunk, public TLVTrait<DataChunkConfig> { + public: + static constexpr int kType = DataChunkConfig::kType; + + // Exposed to allow the retransmission queue to make room for the correct + // header size. + static constexpr size_t kHeaderSize = DataChunkConfig::kHeaderSize; + + DataChunk(TSN tsn, + StreamID stream_id, + SSN ssn, + PPID ppid, + std::vector<uint8_t> payload, + const Options& options) + : AnyDataChunk(tsn, + stream_id, + ssn, + MID(0), + FSN(0), + ppid, + std::move(payload), + options) {} + + DataChunk(TSN tsn, Data&& data, bool immediate_ack) + : AnyDataChunk(tsn, std::move(data), immediate_ack) {} + + static absl::optional<DataChunk> Parse(rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_DATA_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/data_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/data_chunk_test.cc new file mode 100644 index 0000000000..def99ceb23 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/data_chunk_test.cc @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/data_chunk.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(DataChunkTest, FromCapture) { + /* + DATA chunk(ordered, complete segment, TSN: 1426601532, SID: 2, SSN: 1, + PPID: 53, payload length: 4 bytes) + Chunk type: DATA (0) + Chunk flags: 0x03 + Chunk length: 20 + Transmission sequence number: 1426601532 + Stream identifier: 0x0002 + Stream sequence number: 1 + Payload protocol identifier: WebRTC Binary (53) + */ + + uint8_t data[] = {0x00, 0x03, 0x00, 0x14, 0x55, 0x08, 0x36, 0x3c, 0x00, 0x02, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x35, 0x00, 0x01, 0x02, 0x03}; + + ASSERT_HAS_VALUE_AND_ASSIGN(DataChunk chunk, DataChunk::Parse(data)); + EXPECT_EQ(*chunk.tsn(), 1426601532u); + EXPECT_EQ(*chunk.stream_id(), 2u); + EXPECT_EQ(*chunk.ssn(), 1u); + EXPECT_EQ(*chunk.ppid(), 53u); + EXPECT_TRUE(*chunk.options().is_beginning); + EXPECT_TRUE(*chunk.options().is_end); + EXPECT_FALSE(*chunk.options().is_unordered); + EXPECT_FALSE(*chunk.options().immediate_ack); + EXPECT_THAT(chunk.payload(), ElementsAre(0x0, 0x1, 0x2, 0x3)); +} + +TEST(DataChunkTest, SerializeAndDeserialize) { + DataChunk chunk(TSN(123), StreamID(456), SSN(789), PPID(9090), + /*payload=*/{1, 2, 3, 4, 5}, + /*options=*/{}); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(DataChunk deserialized, + DataChunk::Parse(serialized)); + EXPECT_EQ(*chunk.tsn(), 123u); + EXPECT_EQ(*chunk.stream_id(), 456u); + EXPECT_EQ(*chunk.ssn(), 789u); + EXPECT_EQ(*chunk.ppid(), 9090u); + EXPECT_THAT(chunk.payload(), ElementsAre(1, 2, 3, 4, 5)); + + EXPECT_EQ(deserialized.ToString(), + "DATA, type=ordered::middle, tsn=123, sid=456, ssn=789, ppid=9090, " + "length=5"); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/data_common.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/data_common.h new file mode 100644 index 0000000000..b67efeee1e --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/data_common.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_DATA_COMMON_H_ +#define NET_DCSCTP_PACKET_CHUNK_DATA_COMMON_H_ +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/data.h" + +namespace dcsctp { + +// Base class for DataChunk and IDataChunk +class AnyDataChunk : public Chunk { + public: + // Represents the "immediate ack" flag on DATA/I-DATA, from RFC7053. + using ImmediateAckFlag = webrtc::StrongAlias<class ImmediateAckFlagTag, bool>; + + // Data chunk options. + // See https://tools.ietf.org/html/rfc4960#section-3.3.1 + struct Options { + Data::IsEnd is_end = Data::IsEnd(false); + Data::IsBeginning is_beginning = Data::IsBeginning(false); + IsUnordered is_unordered = IsUnordered(false); + ImmediateAckFlag immediate_ack = ImmediateAckFlag(false); + }; + + TSN tsn() const { return tsn_; } + + Options options() const { + Options options; + options.is_end = data_.is_end; + options.is_beginning = data_.is_beginning; + options.is_unordered = data_.is_unordered; + options.immediate_ack = immediate_ack_; + return options; + } + + StreamID stream_id() const { return data_.stream_id; } + SSN ssn() const { return data_.ssn; } + MID message_id() const { return data_.message_id; } + FSN fsn() const { return data_.fsn; } + PPID ppid() const { return data_.ppid; } + rtc::ArrayView<const uint8_t> payload() const { return data_.payload; } + + // Extracts the Data from the chunk, as a destructive action. + Data extract() && { return std::move(data_); } + + AnyDataChunk(TSN tsn, + StreamID stream_id, + SSN ssn, + MID message_id, + FSN fsn, + PPID ppid, + std::vector<uint8_t> payload, + const Options& options) + : tsn_(tsn), + data_(stream_id, + ssn, + message_id, + fsn, + ppid, + std::move(payload), + options.is_beginning, + options.is_end, + options.is_unordered), + immediate_ack_(options.immediate_ack) {} + + AnyDataChunk(TSN tsn, Data data, bool immediate_ack) + : tsn_(tsn), data_(std::move(data)), immediate_ack_(immediate_ack) {} + + protected: + // Bits in `flags` header field. + static constexpr int kFlagsBitEnd = 0; + static constexpr int kFlagsBitBeginning = 1; + static constexpr int kFlagsBitUnordered = 2; + static constexpr int kFlagsBitImmediateAck = 3; + + private: + TSN tsn_; + Data data_; + ImmediateAckFlag immediate_ack_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_DATA_COMMON_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/error_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/error_chunk.cc new file mode 100644 index 0000000000..baac0c5588 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/error_chunk.cc @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/error_chunk.h" + +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 9 | Chunk Flags | Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / one or more Error Causes / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ErrorChunk::kType; + +absl::optional<ErrorChunk> ErrorChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + absl::optional<Parameters> error_causes = + Parameters::Parse(reader->variable_data()); + if (!error_causes.has_value()) { + return absl::nullopt; + } + return ErrorChunk(*std::move(error_causes)); +} + +void ErrorChunk::SerializeTo(std::vector<uint8_t>& out) const { + rtc::ArrayView<const uint8_t> error_causes = error_causes_.data(); + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, error_causes.size()); + writer.CopyToVariableData(error_causes); +} + +std::string ErrorChunk::ToString() const { + return "ERROR"; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/error_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/error_chunk.h new file mode 100644 index 0000000000..96122cff6a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/error_chunk.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_ERROR_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_ERROR_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10 +struct ErrorChunkConfig : ChunkConfig { + static constexpr int kType = 9; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 4; +}; + +class ErrorChunk : public Chunk, public TLVTrait<ErrorChunkConfig> { + public: + static constexpr int kType = ErrorChunkConfig::kType; + + explicit ErrorChunk(Parameters error_causes) + : error_causes_(std::move(error_causes)) {} + + ErrorChunk(ErrorChunk&& other) = default; + ErrorChunk& operator=(ErrorChunk&& other) = default; + + static absl::optional<ErrorChunk> Parse(rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + const Parameters& error_causes() const { return error_causes_; } + + private: + Parameters error_causes_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_ERROR_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/error_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/error_chunk_test.cc new file mode 100644 index 0000000000..f2b8be1edc --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/error_chunk_test.cc @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/error_chunk.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(ErrorChunkTest, FromCapture) { + /* + ERROR chunk + Chunk type: ERROR (9) + Chunk flags: 0x00 + Chunk length: 12 + Unrecognized chunk type cause (Type: 73 (unknown)) + */ + + uint8_t data[] = {0x09, 0x00, 0x00, 0x0c, 0x00, 0x06, + 0x00, 0x08, 0x49, 0x00, 0x00, 0x04}; + + ASSERT_HAS_VALUE_AND_ASSIGN(ErrorChunk chunk, ErrorChunk::Parse(data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + UnrecognizedChunkTypeCause cause, + chunk.error_causes().get<UnrecognizedChunkTypeCause>()); + + EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04)); +} + +TEST(ErrorChunkTest, SerializeAndDeserialize) { + ErrorChunk chunk(Parameters::Builder() + .Add(UnrecognizedChunkTypeCause({1, 2, 3, 4})) + .Build()); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ErrorChunk deserialized, + ErrorChunk::Parse(serialized)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UnrecognizedChunkTypeCause cause, + deserialized.error_causes().get<UnrecognizedChunkTypeCause>()); + + EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(1, 2, 3, 4)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_chunk.cc new file mode 100644 index 0000000000..e432114c50 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_chunk.cc @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" + +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc3758#section-3.2 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 192 | Flags = 0x00 | Length = Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | New Cumulative TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream-1 | Stream Sequence-1 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ / +// / \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream-N | Stream Sequence-N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ForwardTsnChunk::kType; + +absl::optional<ForwardTsnChunk> ForwardTsnChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + TSN new_cumulative_tsn(reader->Load32<4>()); + + size_t streams_skipped = + reader->variable_data_size() / kSkippedStreamBufferSize; + + std::vector<SkippedStream> skipped_streams; + skipped_streams.reserve(streams_skipped); + for (size_t i = 0; i < streams_skipped; ++i) { + BoundedByteReader<kSkippedStreamBufferSize> sub_reader = + reader->sub_reader<kSkippedStreamBufferSize>(i * + kSkippedStreamBufferSize); + + StreamID stream_id(sub_reader.Load16<0>()); + SSN ssn(sub_reader.Load16<2>()); + skipped_streams.emplace_back(stream_id, ssn); + } + return ForwardTsnChunk(new_cumulative_tsn, std::move(skipped_streams)); +} + +void ForwardTsnChunk::SerializeTo(std::vector<uint8_t>& out) const { + rtc::ArrayView<const SkippedStream> skipped = skipped_streams(); + size_t variable_size = skipped.size() * kSkippedStreamBufferSize; + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*new_cumulative_tsn()); + for (size_t i = 0; i < skipped.size(); ++i) { + BoundedByteWriter<kSkippedStreamBufferSize> sub_writer = + writer.sub_writer<kSkippedStreamBufferSize>(i * + kSkippedStreamBufferSize); + sub_writer.Store16<0>(*skipped[i].stream_id); + sub_writer.Store16<2>(*skipped[i].ssn); + } +} + +std::string ForwardTsnChunk::ToString() const { + rtc::StringBuilder sb; + sb << "FORWARD-TSN, new_cumulative_tsn=" << *new_cumulative_tsn(); + for (const auto& skipped : skipped_streams()) { + sb << ", skip " << skipped.stream_id.value() << ":" << *skipped.ssn; + } + return sb.str(); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_chunk.h new file mode 100644 index 0000000000..b9ef666f41 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_chunk.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc3758#section-3.2 +struct ForwardTsnChunkConfig : ChunkConfig { + static constexpr int kType = 192; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 4; +}; + +class ForwardTsnChunk : public AnyForwardTsnChunk, + public TLVTrait<ForwardTsnChunkConfig> { + public: + static constexpr int kType = ForwardTsnChunkConfig::kType; + + ForwardTsnChunk(TSN new_cumulative_tsn, + std::vector<SkippedStream> skipped_streams) + : AnyForwardTsnChunk(new_cumulative_tsn, std::move(skipped_streams)) {} + + static absl::optional<ForwardTsnChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + private: + static constexpr size_t kSkippedStreamBufferSize = 4; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_chunk_test.cc new file mode 100644 index 0000000000..51f97f2396 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_chunk_test.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(ForwardTsnChunkTest, FromCapture) { + /* + FORWARD_TSN chunk(Cumulative TSN: 1905748778) + Chunk type: FORWARD_TSN (192) + Chunk flags: 0x00 + Chunk length: 8 + New cumulative TSN: 1905748778 + */ + + uint8_t data[] = {0xc0, 0x00, 0x00, 0x08, 0x71, 0x97, 0x6b, 0x2a}; + + ASSERT_HAS_VALUE_AND_ASSIGN(ForwardTsnChunk chunk, + ForwardTsnChunk::Parse(data)); + EXPECT_EQ(*chunk.new_cumulative_tsn(), 1905748778u); +} + +TEST(ForwardTsnChunkTest, SerializeAndDeserialize) { + ForwardTsnChunk chunk( + TSN(123), {ForwardTsnChunk::SkippedStream(StreamID(1), SSN(23)), + ForwardTsnChunk::SkippedStream(StreamID(42), SSN(99))}); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ForwardTsnChunk deserialized, + ForwardTsnChunk::Parse(serialized)); + EXPECT_EQ(*deserialized.new_cumulative_tsn(), 123u); + EXPECT_THAT( + deserialized.skipped_streams(), + ElementsAre(ForwardTsnChunk::SkippedStream(StreamID(1), SSN(23)), + ForwardTsnChunk::SkippedStream(StreamID(42), SSN(99)))); + + EXPECT_EQ(deserialized.ToString(), + "FORWARD-TSN, new_cumulative_tsn=123, skip 1:23, skip 42:99"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_common.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_common.h new file mode 100644 index 0000000000..37bd2aafff --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/forward_tsn_common.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_COMMON_H_ +#define NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_COMMON_H_ +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" + +namespace dcsctp { + +// Base class for both ForwardTsnChunk and IForwardTsnChunk +class AnyForwardTsnChunk : public Chunk { + public: + struct SkippedStream { + SkippedStream(StreamID stream_id, SSN ssn) + : stream_id(stream_id), ssn(ssn), unordered(false), message_id(0) {} + SkippedStream(IsUnordered unordered, StreamID stream_id, MID message_id) + : stream_id(stream_id), + ssn(0), + unordered(unordered), + message_id(message_id) {} + + StreamID stream_id; + + // Set for FORWARD_TSN + SSN ssn; + + // Set for I-FORWARD_TSN + IsUnordered unordered; + MID message_id; + + bool operator==(const SkippedStream& other) const { + return stream_id == other.stream_id && ssn == other.ssn && + unordered == other.unordered && message_id == other.message_id; + } + }; + + AnyForwardTsnChunk(TSN new_cumulative_tsn, + std::vector<SkippedStream> skipped_streams) + : new_cumulative_tsn_(new_cumulative_tsn), + skipped_streams_(std::move(skipped_streams)) {} + + TSN new_cumulative_tsn() const { return new_cumulative_tsn_; } + + rtc::ArrayView<const SkippedStream> skipped_streams() const { + return skipped_streams_; + } + + private: + TSN new_cumulative_tsn_; + std::vector<SkippedStream> skipped_streams_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_FORWARD_TSN_COMMON_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_ack_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_ack_chunk.cc new file mode 100644 index 0000000000..3cbcd09c75 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_ack_chunk.cc @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" + +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.6 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 5 | Chunk Flags | Heartbeat Ack Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Heartbeat Information TLV (Variable-Length) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int HeartbeatAckChunk::kType; + +absl::optional<HeartbeatAckChunk> HeartbeatAckChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + absl::optional<Parameters> parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + return HeartbeatAckChunk(*std::move(parameters)); +} + +void HeartbeatAckChunk::SerializeTo(std::vector<uint8_t>& out) const { + rtc::ArrayView<const uint8_t> parameters = parameters_.data(); + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, parameters.size()); + writer.CopyToVariableData(parameters); +} + +std::string HeartbeatAckChunk::ToString() const { + return "HEARTBEAT-ACK"; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_ack_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_ack_chunk.h new file mode 100644 index 0000000000..a6479f78b0 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_ack_chunk.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_ACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_ACK_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.6 +struct HeartbeatAckChunkConfig : ChunkConfig { + static constexpr int kType = 5; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class HeartbeatAckChunk : public Chunk, + public TLVTrait<HeartbeatAckChunkConfig> { + public: + static constexpr int kType = HeartbeatAckChunkConfig::kType; + + explicit HeartbeatAckChunk(Parameters parameters) + : parameters_(std::move(parameters)) {} + + HeartbeatAckChunk(HeartbeatAckChunk&& other) = default; + HeartbeatAckChunk& operator=(HeartbeatAckChunk&& other) = default; + + static absl::optional<HeartbeatAckChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + const Parameters& parameters() const { return parameters_; } + + absl::optional<HeartbeatInfoParameter> info() const { + return parameters_.get<HeartbeatInfoParameter>(); + } + + private: + Parameters parameters_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_ACK_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_ack_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_ack_chunk_test.cc new file mode 100644 index 0000000000..e4d0dd1489 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_ack_chunk_test.cc @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" + +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(HeartbeatAckChunkTest, FromCapture) { + /* + HEARTBEAT_ACK chunk (Information: 40 bytes) + Chunk type: HEARTBEAT_ACK (5) + Chunk flags: 0x00 + Chunk length: 44 + Heartbeat info parameter (Information: 36 bytes) + Parameter type: Heartbeat info (0x0001) + Parameter length: 40 + Heartbeat information: ad2436603726070000000000000000007b1000000100… + */ + + uint8_t data[] = {0x05, 0x00, 0x00, 0x2c, 0x00, 0x01, 0x00, 0x28, 0xad, + 0x24, 0x36, 0x60, 0x37, 0x26, 0x07, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x10, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatAckChunk chunk, + HeartbeatAckChunk::Parse(data)); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info, chunk.info()); + + EXPECT_THAT( + info.info(), + ElementsAre(0xad, 0x24, 0x36, 0x60, 0x37, 0x26, 0x07, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x10, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)); +} + +TEST(HeartbeatAckChunkTest, SerializeAndDeserialize) { + uint8_t info_data[] = {1, 2, 3, 4}; + Parameters parameters = + Parameters::Builder().Add(HeartbeatInfoParameter(info_data)).Build(); + HeartbeatAckChunk chunk(std::move(parameters)); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatAckChunk deserialized, + HeartbeatAckChunk::Parse(serialized)); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info, deserialized.info()); + + EXPECT_THAT(info.info(), ElementsAre(1, 2, 3, 4)); + + EXPECT_EQ(deserialized.ToString(), "HEARTBEAT-ACK"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_request_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_request_chunk.cc new file mode 100644 index 0000000000..d759d6b16d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_request_chunk.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" + +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.5 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 4 | Chunk Flags | Heartbeat Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Heartbeat Information TLV (Variable-Length) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int HeartbeatRequestChunk::kType; + +absl::optional<HeartbeatRequestChunk> HeartbeatRequestChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + absl::optional<Parameters> parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + return HeartbeatRequestChunk(*std::move(parameters)); +} + +void HeartbeatRequestChunk::SerializeTo(std::vector<uint8_t>& out) const { + rtc::ArrayView<const uint8_t> parameters = parameters_.data(); + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, parameters.size()); + writer.CopyToVariableData(parameters); +} + +std::string HeartbeatRequestChunk::ToString() const { + return "HEARTBEAT"; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_request_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_request_chunk.h new file mode 100644 index 0000000000..fe2ce19504 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_request_chunk.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_REQUEST_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_REQUEST_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { +// https://tools.ietf.org/html/rfc4960#section-3.3.5 +struct HeartbeatRequestChunkConfig : ChunkConfig { + static constexpr int kType = 4; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class HeartbeatRequestChunk : public Chunk, + public TLVTrait<HeartbeatRequestChunkConfig> { + public: + static constexpr int kType = HeartbeatRequestChunkConfig::kType; + + explicit HeartbeatRequestChunk(Parameters parameters) + : parameters_(std::move(parameters)) {} + + HeartbeatRequestChunk(HeartbeatRequestChunk&& other) = default; + HeartbeatRequestChunk& operator=(HeartbeatRequestChunk&& other) = default; + + static absl::optional<HeartbeatRequestChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + const Parameters& parameters() const { return parameters_; } + Parameters extract_parameters() && { return std::move(parameters_); } + absl::optional<HeartbeatInfoParameter> info() const { + return parameters_.get<HeartbeatInfoParameter>(); + } + + private: + Parameters parameters_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_HEARTBEAT_REQUEST_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_request_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_request_chunk_test.cc new file mode 100644 index 0000000000..94911fe28b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/heartbeat_request_chunk_test.cc @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" + +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(HeartbeatRequestChunkTest, FromCapture) { + /* + HEARTBEAT chunk (Information: 40 bytes) + Chunk type: HEARTBEAT (4) + Chunk flags: 0x00 + Chunk length: 44 + Heartbeat info parameter (Information: 36 bytes) + Parameter type: Heartbeat info (0x0001) + Parameter length: 40 + Heartbeat information: ad2436603726070000000000000000007b10000001… + */ + + uint8_t data[] = {0x04, 0x00, 0x00, 0x2c, 0x00, 0x01, 0x00, 0x28, 0xad, + 0x24, 0x36, 0x60, 0x37, 0x26, 0x07, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x10, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatRequestChunk chunk, + HeartbeatRequestChunk::Parse(data)); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info, chunk.info()); + + EXPECT_THAT( + info.info(), + ElementsAre(0xad, 0x24, 0x36, 0x60, 0x37, 0x26, 0x07, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x10, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00)); +} + +TEST(HeartbeatRequestChunkTest, SerializeAndDeserialize) { + uint8_t info_data[] = {1, 2, 3, 4}; + Parameters parameters = + Parameters::Builder().Add(HeartbeatInfoParameter(info_data)).Build(); + HeartbeatRequestChunk chunk(std::move(parameters)); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatRequestChunk deserialized, + HeartbeatRequestChunk::Parse(serialized)); + + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info, deserialized.info()); + + EXPECT_THAT(info.info(), ElementsAre(1, 2, 3, 4)); + + EXPECT_EQ(deserialized.ToString(), "HEARTBEAT"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/idata_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/idata_chunk.cc new file mode 100644 index 0000000000..378c527909 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/idata_chunk.cc @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/idata_chunk.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc8260#section-2.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 64 | Res |I|U|B|E| Length = Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Message Identifier | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Payload Protocol Identifier / Fragment Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / User Data / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int IDataChunk::kType; + +absl::optional<IDataChunk> IDataChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + uint8_t flags = reader->Load8<1>(); + TSN tsn(reader->Load32<4>()); + StreamID stream_identifier(reader->Load16<8>()); + MID message_id(reader->Load32<12>()); + uint32_t ppid_or_fsn = reader->Load32<16>(); + + Options options; + options.is_end = Data::IsEnd((flags & (1 << kFlagsBitEnd)) != 0); + options.is_beginning = + Data::IsBeginning((flags & (1 << kFlagsBitBeginning)) != 0); + options.is_unordered = IsUnordered((flags & (1 << kFlagsBitUnordered)) != 0); + options.immediate_ack = + ImmediateAckFlag((flags & (1 << kFlagsBitImmediateAck)) != 0); + + return IDataChunk(tsn, stream_identifier, message_id, + PPID(options.is_beginning ? ppid_or_fsn : 0), + FSN(options.is_beginning ? 0 : ppid_or_fsn), + std::vector<uint8_t>(reader->variable_data().begin(), + reader->variable_data().end()), + options); +} + +void IDataChunk::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, payload().size()); + + writer.Store8<1>( + (*options().is_end ? (1 << kFlagsBitEnd) : 0) | + (*options().is_beginning ? (1 << kFlagsBitBeginning) : 0) | + (*options().is_unordered ? (1 << kFlagsBitUnordered) : 0) | + (*options().immediate_ack ? (1 << kFlagsBitImmediateAck) : 0)); + writer.Store32<4>(*tsn()); + writer.Store16<8>(*stream_id()); + writer.Store32<12>(*message_id()); + writer.Store32<16>(options().is_beginning ? *ppid() : *fsn()); + writer.CopyToVariableData(payload()); +} + +std::string IDataChunk::ToString() const { + rtc::StringBuilder sb; + sb << "I-DATA, type=" << (options().is_unordered ? "unordered" : "ordered") + << "::" + << (*options().is_beginning && *options().is_end + ? "complete" + : *options().is_beginning ? "first" + : *options().is_end ? "last" : "middle") + << ", tsn=" << *tsn() << ", stream_id=" << *stream_id() + << ", message_id=" << *message_id(); + + if (*options().is_beginning) { + sb << ", ppid=" << *ppid(); + } else { + sb << ", fsn=" << *fsn(); + } + sb << ", length=" << payload().size(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/idata_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/idata_chunk.h new file mode 100644 index 0000000000..8cdf2a1fc4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/idata_chunk.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_IDATA_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_IDATA_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc8260#section-2.1 +struct IDataChunkConfig : ChunkConfig { + static constexpr int kType = 64; + static constexpr size_t kHeaderSize = 20; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class IDataChunk : public AnyDataChunk, public TLVTrait<IDataChunkConfig> { + public: + static constexpr int kType = IDataChunkConfig::kType; + + // Exposed to allow the retransmission queue to make room for the correct + // header size. + static constexpr size_t kHeaderSize = IDataChunkConfig::kHeaderSize; + IDataChunk(TSN tsn, + StreamID stream_id, + MID message_id, + PPID ppid, + FSN fsn, + std::vector<uint8_t> payload, + const Options& options) + : AnyDataChunk(tsn, + stream_id, + SSN(0), + message_id, + fsn, + ppid, + std::move(payload), + options) {} + + explicit IDataChunk(TSN tsn, Data&& data, bool immediate_ack) + : AnyDataChunk(tsn, std::move(data), immediate_ack) {} + + static absl::optional<IDataChunk> Parse(rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_IDATA_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/idata_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/idata_chunk_test.cc new file mode 100644 index 0000000000..fea492d71e --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/idata_chunk_test.cc @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/idata_chunk.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(IDataChunkTest, AtBeginningFromCapture) { + /* + I_DATA chunk(ordered, first segment, TSN: 2487901653, SID: 1, MID: 0, + payload length: 1180 bytes) + Chunk type: I_DATA (64) + Chunk flags: 0x02 + Chunk length: 1200 + Transmission sequence number: 2487901653 + Stream identifier: 0x0001 + Reserved: 0 + Message identifier: 0 + Payload protocol identifier: WebRTC Binary (53) + Reassembled Message in frame: 39 + */ + + uint8_t data[] = {0x40, 0x02, 0x00, 0x15, 0x94, 0x4a, 0x5d, 0xd5, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x35, 0x01, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(IDataChunk chunk, IDataChunk::Parse(data)); + EXPECT_EQ(*chunk.tsn(), 2487901653); + EXPECT_EQ(*chunk.stream_id(), 1); + EXPECT_EQ(*chunk.message_id(), 0u); + EXPECT_EQ(*chunk.ppid(), 53u); + EXPECT_EQ(*chunk.fsn(), 0u); // Not provided (so set to zero) +} + +TEST(IDataChunkTest, AtBeginningSerializeAndDeserialize) { + IDataChunk::Options options; + options.is_beginning = Data::IsBeginning(true); + IDataChunk chunk(TSN(123), StreamID(456), MID(789), PPID(53), FSN(0), {1}, + options); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(IDataChunk deserialized, + IDataChunk::Parse(serialized)); + EXPECT_EQ(*deserialized.tsn(), 123u); + EXPECT_EQ(*deserialized.stream_id(), 456u); + EXPECT_EQ(*deserialized.message_id(), 789u); + EXPECT_EQ(*deserialized.ppid(), 53u); + EXPECT_EQ(*deserialized.fsn(), 0u); + + EXPECT_EQ(deserialized.ToString(), + "I-DATA, type=ordered::first, tsn=123, stream_id=456, " + "message_id=789, ppid=53, length=1"); +} + +TEST(IDataChunkTest, InMiddleFromCapture) { + /* + I_DATA chunk(ordered, last segment, TSN: 2487901706, SID: 3, MID: 1, + FSN: 8, payload length: 560 bytes) + Chunk type: I_DATA (64) + Chunk flags: 0x01 + Chunk length: 580 + Transmission sequence number: 2487901706 + Stream identifier: 0x0003 + Reserved: 0 + Message identifier: 1 + Fragment sequence number: 8 + Reassembled SCTP Fragments (10000 bytes, 9 fragments): + */ + + uint8_t data[] = {0x40, 0x01, 0x00, 0x15, 0x94, 0x4a, 0x5e, 0x0a, + 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x08, 0x01, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(IDataChunk chunk, IDataChunk::Parse(data)); + EXPECT_EQ(*chunk.tsn(), 2487901706); + EXPECT_EQ(*chunk.stream_id(), 3u); + EXPECT_EQ(*chunk.message_id(), 1u); + EXPECT_EQ(*chunk.ppid(), 0u); // Not provided (so set to zero) + EXPECT_EQ(*chunk.fsn(), 8u); +} + +TEST(IDataChunkTest, InMiddleSerializeAndDeserialize) { + IDataChunk chunk(TSN(123), StreamID(456), MID(789), PPID(0), FSN(101112), + {1, 2, 3}, /*options=*/{}); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(IDataChunk deserialized, + IDataChunk::Parse(serialized)); + EXPECT_EQ(*deserialized.tsn(), 123u); + EXPECT_EQ(*deserialized.stream_id(), 456u); + EXPECT_EQ(*deserialized.message_id(), 789u); + EXPECT_EQ(*deserialized.ppid(), 0u); + EXPECT_EQ(*deserialized.fsn(), 101112u); + EXPECT_THAT(deserialized.payload(), ElementsAre(1, 2, 3)); + + EXPECT_EQ(deserialized.ToString(), + "I-DATA, type=ordered::middle, tsn=123, stream_id=456, " + "message_id=789, fsn=101112, length=3"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/iforward_tsn_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/iforward_tsn_chunk.cc new file mode 100644 index 0000000000..a647a8bf8a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/iforward_tsn_chunk.cc @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" + +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc8260#section-2.3.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 194 | Flags = 0x00 | Length = Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | New Cumulative TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier | Reserved |U| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Message Identifier | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier | Reserved |U| +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Message Identifier | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int IForwardTsnChunk::kType; + +absl::optional<IForwardTsnChunk> IForwardTsnChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + TSN new_cumulative_tsn(reader->Load32<4>()); + + size_t streams_skipped = + reader->variable_data_size() / kSkippedStreamBufferSize; + std::vector<SkippedStream> skipped_streams; + skipped_streams.reserve(streams_skipped); + size_t offset = 0; + for (size_t i = 0; i < streams_skipped; ++i) { + BoundedByteReader<kSkippedStreamBufferSize> sub_reader = + reader->sub_reader<kSkippedStreamBufferSize>(offset); + + StreamID stream_id(sub_reader.Load16<0>()); + IsUnordered unordered(sub_reader.Load8<3>() & 0x01); + MID message_id(sub_reader.Load32<4>()); + skipped_streams.emplace_back(unordered, stream_id, message_id); + offset += kSkippedStreamBufferSize; + } + RTC_DCHECK(offset == reader->variable_data_size()); + return IForwardTsnChunk(new_cumulative_tsn, std::move(skipped_streams)); +} + +void IForwardTsnChunk::SerializeTo(std::vector<uint8_t>& out) const { + rtc::ArrayView<const SkippedStream> skipped = skipped_streams(); + size_t variable_size = skipped.size() * kSkippedStreamBufferSize; + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*new_cumulative_tsn()); + size_t offset = 0; + for (size_t i = 0; i < skipped.size(); ++i) { + BoundedByteWriter<kSkippedStreamBufferSize> sub_writer = + writer.sub_writer<kSkippedStreamBufferSize>(offset); + + sub_writer.Store16<0>(*skipped[i].stream_id); + sub_writer.Store8<3>(skipped[i].unordered ? 1 : 0); + sub_writer.Store32<4>(*skipped[i].message_id); + offset += kSkippedStreamBufferSize; + } + RTC_DCHECK(offset == variable_size); +} + +std::string IForwardTsnChunk::ToString() const { + rtc::StringBuilder sb; + sb << "I-FORWARD-TSN, new_cumulative_tsn=" << *new_cumulative_tsn(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/iforward_tsn_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/iforward_tsn_chunk.h new file mode 100644 index 0000000000..54d23f7a83 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/iforward_tsn_chunk.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_IFORWARD_TSN_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_IFORWARD_TSN_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc8260#section-2.3.1 +struct IForwardTsnChunkConfig : ChunkConfig { + static constexpr int kType = 194; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 8; +}; + +class IForwardTsnChunk : public AnyForwardTsnChunk, + public TLVTrait<IForwardTsnChunkConfig> { + public: + static constexpr int kType = IForwardTsnChunkConfig::kType; + + IForwardTsnChunk(TSN new_cumulative_tsn, + std::vector<SkippedStream> skipped_streams) + : AnyForwardTsnChunk(new_cumulative_tsn, std::move(skipped_streams)) {} + + static absl::optional<IForwardTsnChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + private: + static constexpr size_t kSkippedStreamBufferSize = 8; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_IFORWARD_TSN_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/iforward_tsn_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/iforward_tsn_chunk_test.cc new file mode 100644 index 0000000000..6a89433be1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/iforward_tsn_chunk_test.cc @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(IForwardTsnChunkTest, FromCapture) { + /* + I_FORWARD_TSN chunk(Cumulative TSN: 3094631148) + Chunk type: I_FORWARD_TSN (194) + Chunk flags: 0x00 + Chunk length: 16 + New cumulative TSN: 3094631148 + Stream identifier: 1 + Flags: 0x0000 + Message identifier: 2 + */ + + uint8_t data[] = {0xc2, 0x00, 0x00, 0x10, 0xb8, 0x74, 0x52, 0xec, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02}; + + ASSERT_HAS_VALUE_AND_ASSIGN(IForwardTsnChunk chunk, + IForwardTsnChunk::Parse(data)); + EXPECT_EQ(*chunk.new_cumulative_tsn(), 3094631148u); + EXPECT_THAT(chunk.skipped_streams(), + ElementsAre(IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(1), MID(2)))); +} + +TEST(IForwardTsnChunkTest, SerializeAndDeserialize) { + IForwardTsnChunk chunk( + TSN(123), {IForwardTsnChunk::SkippedStream(IsUnordered(false), + StreamID(1), MID(23)), + IForwardTsnChunk::SkippedStream(IsUnordered(true), + StreamID(42), MID(99))}); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(IForwardTsnChunk deserialized, + IForwardTsnChunk::Parse(serialized)); + EXPECT_EQ(*deserialized.new_cumulative_tsn(), 123u); + EXPECT_THAT(deserialized.skipped_streams(), + ElementsAre(IForwardTsnChunk::SkippedStream(IsUnordered(false), + StreamID(1), MID(23)), + IForwardTsnChunk::SkippedStream( + IsUnordered(true), StreamID(42), MID(99)))); + + EXPECT_EQ(deserialized.ToString(), "I-FORWARD-TSN, new_cumulative_tsn=123"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/init_ack_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_ack_chunk.cc new file mode 100644 index 0000000000..c7ef9da1f1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_ack_chunk.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" + +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_format.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.3 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 2 | Chunk Flags | Chunk Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Initiate Tag | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Advertised Receiver Window Credit | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of Outbound Streams | Number of Inbound Streams | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Initial TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Optional/Variable-Length Parameters / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int InitAckChunk::kType; + +absl::optional<InitAckChunk> InitAckChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + VerificationTag initiate_tag(reader->Load32<4>()); + uint32_t a_rwnd = reader->Load32<8>(); + uint16_t nbr_outbound_streams = reader->Load16<12>(); + uint16_t nbr_inbound_streams = reader->Load16<14>(); + TSN initial_tsn(reader->Load32<16>()); + absl::optional<Parameters> parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + return InitAckChunk(initiate_tag, a_rwnd, nbr_outbound_streams, + nbr_inbound_streams, initial_tsn, *std::move(parameters)); +} + +void InitAckChunk::SerializeTo(std::vector<uint8_t>& out) const { + rtc::ArrayView<const uint8_t> parameters = parameters_.data(); + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, parameters.size()); + + writer.Store32<4>(*initiate_tag_); + writer.Store32<8>(a_rwnd_); + writer.Store16<12>(nbr_outbound_streams_); + writer.Store16<14>(nbr_inbound_streams_); + writer.Store32<16>(*initial_tsn_); + writer.CopyToVariableData(parameters); +} + +std::string InitAckChunk::ToString() const { + return rtc::StringFormat("INIT_ACK, initiate_tag=0x%0x, initial_tsn=%u", + *initiate_tag(), *initial_tsn()); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/init_ack_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_ack_chunk.h new file mode 100644 index 0000000000..6fcf64b2eb --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_ack_chunk.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_INIT_ACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_INIT_ACK_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.3 +struct InitAckChunkConfig : ChunkConfig { + static constexpr int kType = 2; + static constexpr size_t kHeaderSize = 20; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class InitAckChunk : public Chunk, public TLVTrait<InitAckChunkConfig> { + public: + static constexpr int kType = InitAckChunkConfig::kType; + + InitAckChunk(VerificationTag initiate_tag, + uint32_t a_rwnd, + uint16_t nbr_outbound_streams, + uint16_t nbr_inbound_streams, + TSN initial_tsn, + Parameters parameters) + : initiate_tag_(initiate_tag), + a_rwnd_(a_rwnd), + nbr_outbound_streams_(nbr_outbound_streams), + nbr_inbound_streams_(nbr_inbound_streams), + initial_tsn_(initial_tsn), + parameters_(std::move(parameters)) {} + + InitAckChunk(InitAckChunk&& other) = default; + InitAckChunk& operator=(InitAckChunk&& other) = default; + + static absl::optional<InitAckChunk> Parse(rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + VerificationTag initiate_tag() const { return initiate_tag_; } + uint32_t a_rwnd() const { return a_rwnd_; } + uint16_t nbr_outbound_streams() const { return nbr_outbound_streams_; } + uint16_t nbr_inbound_streams() const { return nbr_inbound_streams_; } + TSN initial_tsn() const { return initial_tsn_; } + const Parameters& parameters() const { return parameters_; } + + private: + VerificationTag initiate_tag_; + uint32_t a_rwnd_; + uint16_t nbr_outbound_streams_; + uint16_t nbr_inbound_streams_; + TSN initial_tsn_; + Parameters parameters_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_INIT_ACK_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/init_ack_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_ack_chunk_test.cc new file mode 100644 index 0000000000..184ade747d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_ack_chunk_test.cc @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" + +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(InitAckChunkTest, FromCapture) { + /* + INIT_ACK chunk (Outbound streams: 1000, inbound streams: 2048) + Chunk type: INIT_ACK (2) + Chunk flags: 0x00 + Chunk length: 292 + Initiate tag: 0x579c2f98 + Advertised receiver window credit (a_rwnd): 131072 + Number of outbound streams: 1000 + Number of inbound streams: 2048 + Initial TSN: 1670811335 + Forward TSN supported parameter + Parameter type: Forward TSN supported (0xc000) + Parameter length: 4 + Supported Extensions parameter (Supported types: FORWARD_TSN, RE_CONFIG) + Parameter type: Supported Extensions (0x8008) + Parameter length: 6 + Supported chunk type: FORWARD_TSN (192) + Supported chunk type: RE_CONFIG (130) + Parameter padding: 0000 + State cookie parameter (Cookie length: 256 bytes) + Parameter type: State cookie (0x0007) + Parameter length: 260 + State cookie: 4b414d452d42534420312e310000000096b8386000000000… + */ + + uint8_t data[] = { + 0x02, 0x00, 0x01, 0x24, 0x57, 0x9c, 0x2f, 0x98, 0x00, 0x02, 0x00, 0x00, + 0x03, 0xe8, 0x08, 0x00, 0x63, 0x96, 0x8e, 0xc7, 0xc0, 0x00, 0x00, 0x04, + 0x80, 0x08, 0x00, 0x06, 0xc0, 0x82, 0x00, 0x00, 0x00, 0x07, 0x01, 0x04, + 0x4b, 0x41, 0x4d, 0x45, 0x2d, 0x42, 0x53, 0x44, 0x20, 0x31, 0x2e, 0x31, + 0x00, 0x00, 0x00, 0x00, 0x96, 0xb8, 0x38, 0x60, 0x00, 0x00, 0x00, 0x00, + 0x52, 0x5a, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x60, 0xea, 0x00, 0x00, + 0xb5, 0xaa, 0x19, 0xea, 0x31, 0xef, 0xa4, 0x2b, 0x90, 0x16, 0x7a, 0xde, + 0x57, 0x9c, 0x2f, 0x98, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x5a, 0xde, 0x7a, 0x16, 0x90, + 0x00, 0x02, 0x00, 0x00, 0x03, 0xe8, 0x03, 0xe8, 0x25, 0x0d, 0x37, 0xe8, + 0x80, 0x00, 0x00, 0x04, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, + 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, + 0xab, 0x31, 0x44, 0x62, 0x12, 0x1a, 0x15, 0x13, 0xfd, 0x5a, 0x5f, 0x69, + 0xef, 0xaa, 0x06, 0xe9, 0xab, 0xd7, 0x48, 0xcc, 0x3b, 0xd1, 0x4b, 0x60, + 0xed, 0x7f, 0xa6, 0x44, 0xce, 0x4d, 0xd2, 0xad, 0x80, 0x04, 0x00, 0x06, + 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, + 0x02, 0x00, 0x01, 0x24, 0x57, 0x9c, 0x2f, 0x98, 0x00, 0x02, 0x00, 0x00, + 0x03, 0xe8, 0x08, 0x00, 0x63, 0x96, 0x8e, 0xc7, 0xc0, 0x00, 0x00, 0x04, + 0x80, 0x08, 0x00, 0x06, 0xc0, 0x82, 0x00, 0x00, 0x51, 0x95, 0x01, 0x88, + 0x0d, 0x80, 0x7b, 0x19, 0xe7, 0xf9, 0xc6, 0x18, 0x5c, 0x4a, 0xbf, 0x39, + 0x32, 0xe5, 0x63, 0x8e}; + + ASSERT_HAS_VALUE_AND_ASSIGN(InitAckChunk chunk, InitAckChunk::Parse(data)); + + EXPECT_EQ(chunk.initiate_tag(), VerificationTag(0x579c2f98u)); + EXPECT_EQ(chunk.a_rwnd(), 131072u); + EXPECT_EQ(chunk.nbr_outbound_streams(), 1000u); + EXPECT_EQ(chunk.nbr_inbound_streams(), 2048u); + EXPECT_EQ(chunk.initial_tsn(), TSN(1670811335u)); + EXPECT_TRUE( + chunk.parameters().get<ForwardTsnSupportedParameter>().has_value()); + EXPECT_TRUE( + chunk.parameters().get<SupportedExtensionsParameter>().has_value()); + EXPECT_TRUE(chunk.parameters().get<StateCookieParameter>().has_value()); +} + +TEST(InitAckChunkTest, SerializeAndDeserialize) { + uint8_t state_cookie[] = {1, 2, 3, 4, 5}; + Parameters parameters = + Parameters::Builder().Add(StateCookieParameter(state_cookie)).Build(); + InitAckChunk chunk(VerificationTag(123), /*a_rwnd=*/456, + /*nbr_outbound_streams=*/65535, + /*nbr_inbound_streams=*/65534, /*initial_tsn=*/TSN(789), + /*parameters=*/std::move(parameters)); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(InitAckChunk deserialized, + InitAckChunk::Parse(serialized)); + + EXPECT_EQ(chunk.initiate_tag(), VerificationTag(123u)); + EXPECT_EQ(chunk.a_rwnd(), 456u); + EXPECT_EQ(chunk.nbr_outbound_streams(), 65535u); + EXPECT_EQ(chunk.nbr_inbound_streams(), 65534u); + EXPECT_EQ(chunk.initial_tsn(), TSN(789u)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + StateCookieParameter cookie, + deserialized.parameters().get<StateCookieParameter>()); + EXPECT_THAT(cookie.data(), ElementsAre(1, 2, 3, 4, 5)); + EXPECT_EQ(deserialized.ToString(), + "INIT_ACK, initiate_tag=0x7b, initial_tsn=789"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/init_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_chunk.cc new file mode 100644 index 0000000000..8030107072 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_chunk.cc @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/init_chunk.h" + +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_format.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.2 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 1 | Chunk Flags | Chunk Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Initiate Tag | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Advertised Receiver Window Credit (a_rwnd) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of Outbound Streams | Number of Inbound Streams | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Initial TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Optional/Variable-Length Parameters / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int InitChunk::kType; + +absl::optional<InitChunk> InitChunk::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + VerificationTag initiate_tag(reader->Load32<4>()); + uint32_t a_rwnd = reader->Load32<8>(); + uint16_t nbr_outbound_streams = reader->Load16<12>(); + uint16_t nbr_inbound_streams = reader->Load16<14>(); + TSN initial_tsn(reader->Load32<16>()); + + absl::optional<Parameters> parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + return InitChunk(initiate_tag, a_rwnd, nbr_outbound_streams, + nbr_inbound_streams, initial_tsn, *std::move(parameters)); +} + +void InitChunk::SerializeTo(std::vector<uint8_t>& out) const { + rtc::ArrayView<const uint8_t> parameters = parameters_.data(); + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, parameters.size()); + + writer.Store32<4>(*initiate_tag_); + writer.Store32<8>(a_rwnd_); + writer.Store16<12>(nbr_outbound_streams_); + writer.Store16<14>(nbr_inbound_streams_); + writer.Store32<16>(*initial_tsn_); + + writer.CopyToVariableData(parameters); +} + +std::string InitChunk::ToString() const { + return rtc::StringFormat("INIT, initiate_tag=0x%0x, initial_tsn=%u", + *initiate_tag(), *initial_tsn()); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/init_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_chunk.h new file mode 100644 index 0000000000..38f9994caa --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_chunk.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_INIT_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_INIT_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.2 +struct InitChunkConfig : ChunkConfig { + static constexpr int kType = 1; + static constexpr size_t kHeaderSize = 20; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class InitChunk : public Chunk, public TLVTrait<InitChunkConfig> { + public: + static constexpr int kType = InitChunkConfig::kType; + + InitChunk(VerificationTag initiate_tag, + uint32_t a_rwnd, + uint16_t nbr_outbound_streams, + uint16_t nbr_inbound_streams, + TSN initial_tsn, + Parameters parameters) + : initiate_tag_(initiate_tag), + a_rwnd_(a_rwnd), + nbr_outbound_streams_(nbr_outbound_streams), + nbr_inbound_streams_(nbr_inbound_streams), + initial_tsn_(initial_tsn), + parameters_(std::move(parameters)) {} + + InitChunk(InitChunk&& other) = default; + InitChunk& operator=(InitChunk&& other) = default; + + static absl::optional<InitChunk> Parse(rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + VerificationTag initiate_tag() const { return initiate_tag_; } + uint32_t a_rwnd() const { return a_rwnd_; } + uint16_t nbr_outbound_streams() const { return nbr_outbound_streams_; } + uint16_t nbr_inbound_streams() const { return nbr_inbound_streams_; } + TSN initial_tsn() const { return initial_tsn_; } + const Parameters& parameters() const { return parameters_; } + + private: + VerificationTag initiate_tag_; + uint32_t a_rwnd_; + uint16_t nbr_outbound_streams_; + uint16_t nbr_inbound_streams_; + TSN initial_tsn_; + Parameters parameters_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_INIT_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/init_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_chunk_test.cc new file mode 100644 index 0000000000..bd36d6fdf8 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/init_chunk_test.cc @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/init_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(InitChunkTest, FromCapture) { + /* + INIT chunk (Outbound streams: 1000, inbound streams: 1000) + Chunk type: INIT (1) + Chunk flags: 0x00 + Chunk length: 90 + Initiate tag: 0xde7a1690 + Advertised receiver window credit (a_rwnd): 131072 + Number of outbound streams: 1000 + Number of inbound streams: 1000 + Initial TSN: 621623272 + ECN parameter + Parameter type: ECN (0x8000) + Parameter length: 4 + Forward TSN supported parameter + Parameter type: Forward TSN supported (0xc000) + Parameter length: 4 + Supported Extensions parameter (Supported types: FORWARD_TSN, AUTH, + ASCONF, ASCONF_ACK, RE_CONFIG) Parameter type: Supported Extensions (0x8008) + Parameter length: 9 + Supported chunk type: FORWARD_TSN (192) + Supported chunk type: AUTH (15) + Supported chunk type: ASCONF (193) + Supported chunk type: ASCONF_ACK (128) + Supported chunk type: RE_CONFIG (130) + Parameter padding: 000000 + Random parameter + Parameter type: Random (0x8002) + Parameter length: 36 + Random number: ab314462121a1513fd5a5f69efaa06e9abd748cc3bd14b60… + Requested HMAC Algorithm parameter (Supported HMACs: SHA-1) + Parameter type: Requested HMAC Algorithm (0x8004) + Parameter length: 6 + HMAC identifier: SHA-1 (1) + Parameter padding: 0000 + Authenticated Chunk list parameter (Chunk types to be authenticated: + ASCONF_ACK, ASCONF) Parameter type: Authenticated Chunk list (0x8003) + Parameter length: 6 + Chunk type: ASCONF_ACK (128) + Chunk type: ASCONF (193) + */ + + uint8_t data[] = { + 0x01, 0x00, 0x00, 0x5a, 0xde, 0x7a, 0x16, 0x90, 0x00, 0x02, 0x00, 0x00, + 0x03, 0xe8, 0x03, 0xe8, 0x25, 0x0d, 0x37, 0xe8, 0x80, 0x00, 0x00, 0x04, + 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, + 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xab, 0x31, 0x44, 0x62, + 0x12, 0x1a, 0x15, 0x13, 0xfd, 0x5a, 0x5f, 0x69, 0xef, 0xaa, 0x06, 0xe9, + 0xab, 0xd7, 0x48, 0xcc, 0x3b, 0xd1, 0x4b, 0x60, 0xed, 0x7f, 0xa6, 0x44, + 0xce, 0x4d, 0xd2, 0xad, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, + 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(InitChunk chunk, InitChunk::Parse(data)); + + EXPECT_EQ(chunk.initiate_tag(), VerificationTag(0xde7a1690)); + EXPECT_EQ(chunk.a_rwnd(), 131072u); + EXPECT_EQ(chunk.nbr_outbound_streams(), 1000u); + EXPECT_EQ(chunk.nbr_inbound_streams(), 1000u); + EXPECT_EQ(chunk.initial_tsn(), TSN(621623272u)); + EXPECT_TRUE( + chunk.parameters().get<ForwardTsnSupportedParameter>().has_value()); + EXPECT_TRUE( + chunk.parameters().get<SupportedExtensionsParameter>().has_value()); +} + +TEST(InitChunkTest, SerializeAndDeserialize) { + InitChunk chunk(VerificationTag(123), /*a_rwnd=*/456, + /*nbr_outbound_streams=*/65535, + /*nbr_inbound_streams=*/65534, /*initial_tsn=*/TSN(789), + /*parameters=*/Parameters::Builder().Build()); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(InitChunk deserialized, + InitChunk::Parse(serialized)); + + EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123u)); + EXPECT_EQ(deserialized.a_rwnd(), 456u); + EXPECT_EQ(deserialized.nbr_outbound_streams(), 65535u); + EXPECT_EQ(deserialized.nbr_inbound_streams(), 65534u); + EXPECT_EQ(deserialized.initial_tsn(), TSN(789u)); + EXPECT_EQ(deserialized.ToString(), + "INIT, initiate_tag=0x7b, initial_tsn=789"); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/reconfig_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/reconfig_chunk.cc new file mode 100644 index 0000000000..f39f3b619f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/reconfig_chunk.cc @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" + +#include <stdint.h> + +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-3.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 130 | Chunk Flags | Chunk Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Re-configuration Parameter / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Re-configuration Parameter (optional) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ReConfigChunk::kType; + +absl::optional<ReConfigChunk> ReConfigChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + absl::optional<Parameters> parameters = + Parameters::Parse(reader->variable_data()); + if (!parameters.has_value()) { + return absl::nullopt; + } + + return ReConfigChunk(*std::move(parameters)); +} + +void ReConfigChunk::SerializeTo(std::vector<uint8_t>& out) const { + rtc::ArrayView<const uint8_t> parameters = parameters_.data(); + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, parameters.size()); + writer.CopyToVariableData(parameters); +} + +std::string ReConfigChunk::ToString() const { + return "RE-CONFIG"; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/reconfig_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/reconfig_chunk.h new file mode 100644 index 0000000000..9d2539a515 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/reconfig_chunk.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_RECONFIG_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_RECONFIG_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-3.1 +struct ReConfigChunkConfig : ChunkConfig { + static constexpr int kType = 130; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class ReConfigChunk : public Chunk, public TLVTrait<ReConfigChunkConfig> { + public: + static constexpr int kType = ReConfigChunkConfig::kType; + + explicit ReConfigChunk(Parameters parameters) + : parameters_(std::move(parameters)) {} + + static absl::optional<ReConfigChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + const Parameters& parameters() const { return parameters_; } + Parameters extract_parameters() { return std::move(parameters_); } + + private: + Parameters parameters_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_RECONFIG_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/reconfig_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/reconfig_chunk_test.cc new file mode 100644 index 0000000000..dbf40ff8c0 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/reconfig_chunk_test.cc @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::SizeIs; + +TEST(ReConfigChunkTest, FromCapture) { + /* + RE_CONFIG chunk + Chunk type: RE_CONFIG (130) + Chunk flags: 0x00 + Chunk length: 22 + Outgoing SSN reset request parameter + Parameter type: Outgoing SSN reset request (0x000d) + Parameter length: 18 + Re-configuration request sequence number: 2270550051 + Re-configuration response sequence number: 1905748638 + Senders last assigned TSN: 2270550066 + Stream Identifier: 6 + Chunk padding: 0000 + */ + + uint8_t data[] = {0x82, 0x00, 0x00, 0x16, 0x00, 0x0d, 0x00, 0x12, + 0x87, 0x55, 0xd8, 0x23, 0x71, 0x97, 0x6a, 0x9e, + 0x87, 0x55, 0xd8, 0x32, 0x00, 0x06, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk chunk, ReConfigChunk::Parse(data)); + + const Parameters& parameters = chunk.parameters(); + EXPECT_THAT(parameters.descriptors(), SizeIs(1)); + ParameterDescriptor desc = parameters.descriptors()[0]; + ASSERT_EQ(desc.type, OutgoingSSNResetRequestParameter::kType); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + OutgoingSSNResetRequestParameter::Parse(desc.data)); + EXPECT_EQ(*req.request_sequence_number(), 2270550051u); + EXPECT_EQ(*req.response_sequence_number(), 1905748638u); + EXPECT_EQ(*req.sender_last_assigned_tsn(), 2270550066u); + EXPECT_THAT(req.stream_ids(), ElementsAre(StreamID(6))); +} + +TEST(ReConfigChunkTest, SerializeAndDeserialize) { + Parameters::Builder params_builder = + Parameters::Builder().Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(123), ReconfigRequestSN(456), TSN(789), + {StreamID(42), StreamID(43)})); + + ReConfigChunk chunk(params_builder.Build()); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk deserialized, + ReConfigChunk::Parse(serialized)); + + const Parameters& parameters = deserialized.parameters(); + EXPECT_THAT(parameters.descriptors(), SizeIs(1)); + ParameterDescriptor desc = parameters.descriptors()[0]; + ASSERT_EQ(desc.type, OutgoingSSNResetRequestParameter::kType); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + OutgoingSSNResetRequestParameter::Parse(desc.data)); + EXPECT_EQ(*req.request_sequence_number(), 123u); + EXPECT_EQ(*req.response_sequence_number(), 456u); + EXPECT_EQ(*req.sender_last_assigned_tsn(), 789u); + EXPECT_THAT(req.stream_ids(), ElementsAre(StreamID(42), StreamID(43))); + + EXPECT_EQ(deserialized.ToString(), "RE-CONFIG"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/sack_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/sack_chunk.cc new file mode 100644 index 0000000000..d80e430082 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/sack_chunk.cc @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/sack_chunk.h" + +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.4 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 3 |Chunk Flags | Chunk Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cumulative TSN Ack | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Advertised Receiver Window Credit (a_rwnd) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of Gap Ack Blocks = N | Number of Duplicate TSNs = X | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Gap Ack Block #1 Start | Gap Ack Block #1 End | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / / +// \ ... \ +// / / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Gap Ack Block #N Start | Gap Ack Block #N End | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Duplicate TSN 1 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / / +// \ ... \ +// / / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Duplicate TSN X | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int SackChunk::kType; + +absl::optional<SackChunk> SackChunk::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + TSN tsn_ack(reader->Load32<4>()); + uint32_t a_rwnd = reader->Load32<8>(); + uint16_t nbr_of_gap_blocks = reader->Load16<12>(); + uint16_t nbr_of_dup_tsns = reader->Load16<14>(); + + if (reader->variable_data_size() != nbr_of_gap_blocks * kGapAckBlockSize + + nbr_of_dup_tsns * kDupTsnBlockSize) { + RTC_DLOG(LS_WARNING) << "Invalid number of gap blocks or duplicate TSNs"; + return absl::nullopt; + } + + std::vector<GapAckBlock> gap_ack_blocks; + gap_ack_blocks.reserve(nbr_of_gap_blocks); + size_t offset = 0; + for (int i = 0; i < nbr_of_gap_blocks; ++i) { + BoundedByteReader<kGapAckBlockSize> sub_reader = + reader->sub_reader<kGapAckBlockSize>(offset); + + uint16_t start = sub_reader.Load16<0>(); + uint16_t end = sub_reader.Load16<2>(); + gap_ack_blocks.emplace_back(start, end); + offset += kGapAckBlockSize; + } + + std::set<TSN> duplicate_tsns; + for (int i = 0; i < nbr_of_dup_tsns; ++i) { + BoundedByteReader<kDupTsnBlockSize> sub_reader = + reader->sub_reader<kDupTsnBlockSize>(offset); + + duplicate_tsns.insert(TSN(sub_reader.Load32<0>())); + offset += kDupTsnBlockSize; + } + RTC_DCHECK(offset == reader->variable_data_size()); + + return SackChunk(tsn_ack, a_rwnd, gap_ack_blocks, duplicate_tsns); +} + +void SackChunk::SerializeTo(std::vector<uint8_t>& out) const { + int nbr_of_gap_blocks = gap_ack_blocks_.size(); + int nbr_of_dup_tsns = duplicate_tsns_.size(); + size_t variable_size = + nbr_of_gap_blocks * kGapAckBlockSize + nbr_of_dup_tsns * kDupTsnBlockSize; + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*cumulative_tsn_ack_); + writer.Store32<8>(a_rwnd_); + writer.Store16<12>(nbr_of_gap_blocks); + writer.Store16<14>(nbr_of_dup_tsns); + + size_t offset = 0; + for (int i = 0; i < nbr_of_gap_blocks; ++i) { + BoundedByteWriter<kGapAckBlockSize> sub_writer = + writer.sub_writer<kGapAckBlockSize>(offset); + + sub_writer.Store16<0>(gap_ack_blocks_[i].start); + sub_writer.Store16<2>(gap_ack_blocks_[i].end); + offset += kGapAckBlockSize; + } + + for (TSN tsn : duplicate_tsns_) { + BoundedByteWriter<kDupTsnBlockSize> sub_writer = + writer.sub_writer<kDupTsnBlockSize>(offset); + + sub_writer.Store32<0>(*tsn); + offset += kDupTsnBlockSize; + } + + RTC_DCHECK(offset == variable_size); +} + +std::string SackChunk::ToString() const { + rtc::StringBuilder sb; + sb << "SACK, cum_ack_tsn=" << *cumulative_tsn_ack() + << ", a_rwnd=" << a_rwnd(); + for (const GapAckBlock& gap : gap_ack_blocks_) { + uint32_t first = *cumulative_tsn_ack_ + gap.start; + uint32_t last = *cumulative_tsn_ack_ + gap.end; + sb << ", gap=" << first << "--" << last; + } + if (!duplicate_tsns_.empty()) { + sb << ", dup_tsns=" + << StrJoin(duplicate_tsns(), ",", + [](rtc::StringBuilder& sb, TSN tsn) { sb << *tsn; }); + } + + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/sack_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/sack_chunk.h new file mode 100644 index 0000000000..e6758fa332 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/sack_chunk.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_SACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_SACK_CHUNK_H_ +#include <stddef.h> + +#include <cstdint> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.4 +struct SackChunkConfig : ChunkConfig { + static constexpr int kType = 3; + static constexpr size_t kHeaderSize = 16; + static constexpr size_t kVariableLengthAlignment = 4; +}; + +class SackChunk : public Chunk, public TLVTrait<SackChunkConfig> { + public: + static constexpr int kType = SackChunkConfig::kType; + + struct GapAckBlock { + GapAckBlock(uint16_t start, uint16_t end) : start(start), end(end) {} + + uint16_t start; + uint16_t end; + + bool operator==(const GapAckBlock& other) const { + return start == other.start && end == other.end; + } + }; + + SackChunk(TSN cumulative_tsn_ack, + uint32_t a_rwnd, + std::vector<GapAckBlock> gap_ack_blocks, + std::set<TSN> duplicate_tsns) + : cumulative_tsn_ack_(cumulative_tsn_ack), + a_rwnd_(a_rwnd), + gap_ack_blocks_(std::move(gap_ack_blocks)), + duplicate_tsns_(std::move(duplicate_tsns)) {} + static absl::optional<SackChunk> Parse(rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + TSN cumulative_tsn_ack() const { return cumulative_tsn_ack_; } + uint32_t a_rwnd() const { return a_rwnd_; } + rtc::ArrayView<const GapAckBlock> gap_ack_blocks() const { + return gap_ack_blocks_; + } + const std::set<TSN>& duplicate_tsns() const { return duplicate_tsns_; } + + private: + static constexpr size_t kGapAckBlockSize = 4; + static constexpr size_t kDupTsnBlockSize = 4; + + const TSN cumulative_tsn_ack_; + const uint32_t a_rwnd_; + std::vector<GapAckBlock> gap_ack_blocks_; + std::set<TSN> duplicate_tsns_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_SACK_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/sack_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/sack_chunk_test.cc new file mode 100644 index 0000000000..9122945308 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/sack_chunk_test.cc @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/sack_chunk.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(SackChunkTest, FromCapture) { + /* + SACK chunk (Cumulative TSN: 916312075, a_rwnd: 126323, + gaps: 2, duplicate TSNs: 1) + Chunk type: SACK (3) + Chunk flags: 0x00 + Chunk length: 28 + Cumulative TSN ACK: 916312075 + Advertised receiver window credit (a_rwnd): 126323 + Number of gap acknowledgement blocks: 2 + Number of duplicated TSNs: 1 + Gap Acknowledgement for TSN 916312077 to 916312081 + Gap Acknowledgement for TSN 916312083 to 916312083 + [Number of TSNs in gap acknowledgement blocks: 6] + Duplicate TSN: 916312081 + + */ + + uint8_t data[] = {0x03, 0x00, 0x00, 0x1c, 0x36, 0x9d, 0xd0, 0x0b, 0x00, 0x01, + 0xed, 0x73, 0x00, 0x02, 0x00, 0x01, 0x00, 0x02, 0x00, 0x06, + 0x00, 0x08, 0x00, 0x08, 0x36, 0x9d, 0xd0, 0x11}; + + ASSERT_HAS_VALUE_AND_ASSIGN(SackChunk chunk, SackChunk::Parse(data)); + + TSN cum_ack_tsn(916312075); + EXPECT_EQ(chunk.cumulative_tsn_ack(), cum_ack_tsn); + EXPECT_EQ(chunk.a_rwnd(), 126323u); + EXPECT_THAT( + chunk.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock( + static_cast<uint16_t>(916312077 - *cum_ack_tsn), + static_cast<uint16_t>(916312081 - *cum_ack_tsn)), + SackChunk::GapAckBlock( + static_cast<uint16_t>(916312083 - *cum_ack_tsn), + static_cast<uint16_t>(916312083 - *cum_ack_tsn)))); + EXPECT_THAT(chunk.duplicate_tsns(), ElementsAre(TSN(916312081))); +} + +TEST(SackChunkTest, SerializeAndDeserialize) { + SackChunk chunk(TSN(123), /*a_rwnd=*/456, {SackChunk::GapAckBlock(2, 3)}, + {TSN(1), TSN(2), TSN(3)}); + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(SackChunk deserialized, + SackChunk::Parse(serialized)); + + EXPECT_EQ(*deserialized.cumulative_tsn_ack(), 123u); + EXPECT_EQ(deserialized.a_rwnd(), 456u); + EXPECT_THAT(deserialized.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 3))); + EXPECT_THAT(deserialized.duplicate_tsns(), + ElementsAre(TSN(1), TSN(2), TSN(3))); + + EXPECT_EQ(deserialized.ToString(), + "SACK, cum_ack_tsn=123, a_rwnd=456, gap=125--126, dup_tsns=1,2,3"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_ack_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_ack_chunk.cc new file mode 100644 index 0000000000..d42aceead4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_ack_chunk.cc @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" + +#include <stdint.h> + +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.9 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 8 |Chunk Flags | Length = 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ShutdownAckChunk::kType; + +absl::optional<ShutdownAckChunk> ShutdownAckChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return ShutdownAckChunk(); +} + +void ShutdownAckChunk::SerializeTo(std::vector<uint8_t>& out) const { + AllocateTLV(out); +} + +std::string ShutdownAckChunk::ToString() const { + return "SHUTDOWN-ACK"; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_ack_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_ack_chunk.h new file mode 100644 index 0000000000..29c1a98be6 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_ack_chunk.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_ACK_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_ACK_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.9 +struct ShutdownAckChunkConfig : ChunkConfig { + static constexpr int kType = 8; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class ShutdownAckChunk : public Chunk, public TLVTrait<ShutdownAckChunkConfig> { + public: + static constexpr int kType = ShutdownAckChunkConfig::kType; + + ShutdownAckChunk() {} + + static absl::optional<ShutdownAckChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_ACK_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_ack_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_ack_chunk_test.cc new file mode 100644 index 0000000000..ef04ea9892 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_ack_chunk_test.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" + +#include <stdint.h> + +#include <vector> + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(ShutdownAckChunkTest, FromCapture) { + /* + SHUTDOWN_ACK chunk + Chunk type: SHUTDOWN_ACK (8) + Chunk flags: 0x00 + Chunk length: 4 + */ + + uint8_t data[] = {0x08, 0x00, 0x00, 0x04}; + + EXPECT_TRUE(ShutdownAckChunk::Parse(data).has_value()); +} + +TEST(ShutdownAckChunkTest, SerializeAndDeserialize) { + ShutdownAckChunk chunk; + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + EXPECT_TRUE(ShutdownAckChunk::Parse(serialized).has_value()); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_chunk.cc new file mode 100644 index 0000000000..59f806f7f7 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_chunk.cc @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.8 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 7 | Chunk Flags | Length = 8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cumulative TSN Ack | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ShutdownChunk::kType; + +absl::optional<ShutdownChunk> ShutdownChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + TSN cumulative_tsn_ack(reader->Load32<4>()); + return ShutdownChunk(cumulative_tsn_ack); +} + +void ShutdownChunk::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out); + writer.Store32<4>(*cumulative_tsn_ack_); +} + +std::string ShutdownChunk::ToString() const { + return "SHUTDOWN"; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_chunk.h new file mode 100644 index 0000000000..8148cca286 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_chunk.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.8 +struct ShutdownChunkConfig : ChunkConfig { + static constexpr int kType = 7; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class ShutdownChunk : public Chunk, public TLVTrait<ShutdownChunkConfig> { + public: + static constexpr int kType = ShutdownChunkConfig::kType; + + explicit ShutdownChunk(TSN cumulative_tsn_ack) + : cumulative_tsn_ack_(cumulative_tsn_ack) {} + + static absl::optional<ShutdownChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + TSN cumulative_tsn_ack() const { return cumulative_tsn_ack_; } + + private: + TSN cumulative_tsn_ack_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_chunk_test.cc new file mode 100644 index 0000000000..16d147ca83 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_chunk_test.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { +TEST(ShutdownChunkTest, FromCapture) { + /* + SHUTDOWN chunk (Cumulative TSN ack: 101831101) + Chunk type: SHUTDOWN (7) + Chunk flags: 0x00 + Chunk length: 8 + Cumulative TSN Ack: 101831101 + */ + + uint8_t data[] = {0x07, 0x00, 0x00, 0x08, 0x06, 0x11, 0xd1, 0xbd}; + + ASSERT_HAS_VALUE_AND_ASSIGN(ShutdownChunk chunk, ShutdownChunk::Parse(data)); + EXPECT_EQ(chunk.cumulative_tsn_ack(), TSN(101831101u)); +} + +TEST(ShutdownChunkTest, SerializeAndDeserialize) { + ShutdownChunk chunk(TSN(12345678)); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ShutdownChunk deserialized, + ShutdownChunk::Parse(serialized)); + + EXPECT_EQ(deserialized.cumulative_tsn_ack(), TSN(12345678u)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_complete_chunk.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_complete_chunk.cc new file mode 100644 index 0000000000..3f54857437 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_complete_chunk.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.13 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 14 |Reserved |T| Length = 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ShutdownCompleteChunk::kType; + +absl::optional<ShutdownCompleteChunk> ShutdownCompleteChunk::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + uint8_t flags = reader->Load8<1>(); + bool tag_reflected = (flags & (1 << kFlagsBitT)) != 0; + return ShutdownCompleteChunk(tag_reflected); +} + +void ShutdownCompleteChunk::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out); + writer.Store8<1>(tag_reflected_ ? (1 << kFlagsBitT) : 0); +} + +std::string ShutdownCompleteChunk::ToString() const { + return "SHUTDOWN-COMPLETE"; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_complete_chunk.h b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_complete_chunk.h new file mode 100644 index 0000000000..46d28e88dc --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_complete_chunk.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_COMPLETE_CHUNK_H_ +#define NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_COMPLETE_CHUNK_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.13 +struct ShutdownCompleteChunkConfig : ChunkConfig { + static constexpr int kType = 14; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class ShutdownCompleteChunk : public Chunk, + public TLVTrait<ShutdownCompleteChunkConfig> { + public: + static constexpr int kType = ShutdownCompleteChunkConfig::kType; + + explicit ShutdownCompleteChunk(bool tag_reflected) + : tag_reflected_(tag_reflected) {} + + static absl::optional<ShutdownCompleteChunk> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + bool tag_reflected() const { return tag_reflected_; } + + private: + static constexpr int kFlagsBitT = 0; + bool tag_reflected_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_SHUTDOWN_COMPLETE_CHUNK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_complete_chunk_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_complete_chunk_test.cc new file mode 100644 index 0000000000..253900d5cd --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk/shutdown_complete_chunk_test.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" + +#include <stdint.h> + +#include <vector> + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(ShutdownCompleteChunkTest, FromCapture) { + /* + SHUTDOWN_COMPLETE chunk + Chunk type: SHUTDOWN_COMPLETE (14) + Chunk flags: 0x00 + Chunk length: 4 + */ + + uint8_t data[] = {0x0e, 0x00, 0x00, 0x04}; + + EXPECT_TRUE(ShutdownCompleteChunk::Parse(data).has_value()); +} + +TEST(ShutdownCompleteChunkTest, SerializeAndDeserialize) { + ShutdownCompleteChunk chunk(/*tag_reflected=*/false); + + std::vector<uint8_t> serialized; + chunk.SerializeTo(serialized); + + EXPECT_TRUE(ShutdownCompleteChunk::Parse(serialized).has_value()); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk_validators.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk_validators.cc new file mode 100644 index 0000000000..48d351827e --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk_validators.cc @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk_validators.h" + +#include <algorithm> +#include <utility> +#include <vector> + +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +SackChunk ChunkValidators::Clean(SackChunk&& sack) { + if (Validate(sack)) { + return std::move(sack); + } + + RTC_DLOG(LS_WARNING) << "Received SACK is malformed; cleaning it"; + + std::vector<SackChunk::GapAckBlock> gap_ack_blocks; + gap_ack_blocks.reserve(sack.gap_ack_blocks().size()); + + // First: Only keep blocks that are sane + for (const SackChunk::GapAckBlock& gap_ack_block : sack.gap_ack_blocks()) { + if (gap_ack_block.end > gap_ack_block.start) { + gap_ack_blocks.emplace_back(gap_ack_block); + } + } + + // Not more than at most one remaining? Exit early. + if (gap_ack_blocks.size() <= 1) { + return SackChunk(sack.cumulative_tsn_ack(), sack.a_rwnd(), + std::move(gap_ack_blocks), sack.duplicate_tsns()); + } + + // Sort the intervals by their start value, to aid in the merging below. + absl::c_sort(gap_ack_blocks, [&](const SackChunk::GapAckBlock& a, + const SackChunk::GapAckBlock& b) { + return a.start < b.start; + }); + + // Merge overlapping ranges. + std::vector<SackChunk::GapAckBlock> merged; + merged.reserve(gap_ack_blocks.size()); + merged.push_back(gap_ack_blocks[0]); + + for (size_t i = 1; i < gap_ack_blocks.size(); ++i) { + if (merged.back().end + 1 >= gap_ack_blocks[i].start) { + merged.back().end = std::max(merged.back().end, gap_ack_blocks[i].end); + } else { + merged.push_back(gap_ack_blocks[i]); + } + } + + return SackChunk(sack.cumulative_tsn_ack(), sack.a_rwnd(), std::move(merged), + sack.duplicate_tsns()); +} + +bool ChunkValidators::Validate(const SackChunk& sack) { + if (sack.gap_ack_blocks().empty()) { + return true; + } + + // Ensure that gap-ack-blocks are sorted, has an "end" that is not before + // "start" and are non-overlapping and non-adjacent. + uint16_t prev_end = 0; + for (const SackChunk::GapAckBlock& gap_ack_block : sack.gap_ack_blocks()) { + if (gap_ack_block.end < gap_ack_block.start) { + return false; + } + if (gap_ack_block.start <= (prev_end + 1)) { + return false; + } + prev_end = gap_ack_block.end; + } + return true; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk_validators.h b/third_party/libwebrtc/net/dcsctp/packet/chunk_validators.h new file mode 100644 index 0000000000..b11848a162 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk_validators.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CHUNK_VALIDATORS_H_ +#define NET_DCSCTP_PACKET_CHUNK_VALIDATORS_H_ + +#include "net/dcsctp/packet/chunk/sack_chunk.h" + +namespace dcsctp { +// Validates and cleans SCTP chunks. +class ChunkValidators { + public: + // Given a SackChunk, will return `true` if it's valid, and `false` if not. + static bool Validate(const SackChunk& sack); + + // Given a SackChunk, it will return a cleaned and validated variant of it. + // RFC4960 doesn't say anything about validity of SACKs or if the Gap ACK + // blocks must be sorted, and non-overlapping. While they always are in + // well-behaving implementations, this can't be relied on. + // + // This method internally calls `Validate`, which means that you can always + // pass a SackChunk to this method (valid or not), and use the results. + static SackChunk Clean(SackChunk&& sack); +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CHUNK_VALIDATORS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/chunk_validators_test.cc b/third_party/libwebrtc/net/dcsctp/packet/chunk_validators_test.cc new file mode 100644 index 0000000000..d59fd4ec48 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/chunk_validators_test.cc @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/chunk_validators.h" + +#include <utility> + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +TEST(ChunkValidatorsTest, NoGapAckBlocksAreValid) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + /*gap_ack_blocks=*/{}, {}); + + EXPECT_TRUE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + EXPECT_THAT(clean.gap_ack_blocks(), IsEmpty()); +} + +TEST(ChunkValidatorsTest, OneValidAckBlock) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, {SackChunk::GapAckBlock(2, 3)}, {}); + + EXPECT_TRUE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST(ChunkValidatorsTest, TwoValidAckBlocks) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(2, 3), SackChunk::GapAckBlock(5, 6)}, + {}); + + EXPECT_TRUE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + EXPECT_THAT( + clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 3), SackChunk::GapAckBlock(5, 6))); +} + +TEST(ChunkValidatorsTest, OneInvalidAckBlock) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, {SackChunk::GapAckBlock(1, 2)}, {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + // It's not strictly valid, but due to the renegable nature of gap ack blocks, + // the cum_ack_tsn can't simply be moved. + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(1, 2))); +} + +TEST(ChunkValidatorsTest, RemovesInvalidGapAckBlockFromSack) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(2, 3), SackChunk::GapAckBlock(6, 4)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST(ChunkValidatorsTest, SortsGapAckBlocksInOrder) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(6, 7), SackChunk::GapAckBlock(3, 4)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT( + clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 4), SackChunk::GapAckBlock(6, 7))); +} + +TEST(ChunkValidatorsTest, MergesAdjacentBlocks) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 4), SackChunk::GapAckBlock(5, 6)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 6))); +} + +TEST(ChunkValidatorsTest, MergesOverlappingByOne) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 4), SackChunk::GapAckBlock(4, 5)}, + {}); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 5))); +} + +TEST(ChunkValidatorsTest, MergesOverlappingByMore) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 10), SackChunk::GapAckBlock(4, 5)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 10))); +} + +TEST(ChunkValidatorsTest, MergesBlocksStartingWithSameStartOffset) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 7), SackChunk::GapAckBlock(3, 5), + SackChunk::GapAckBlock(3, 9)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 9))); +} + +TEST(ChunkValidatorsTest, MergesBlocksPartiallyOverlapping) { + SackChunk sack(TSN(123), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(3, 7), SackChunk::GapAckBlock(5, 9)}, + {}); + + EXPECT_FALSE(ChunkValidators::Validate(sack)); + + SackChunk clean = ChunkValidators::Clean(std::move(sack)); + + EXPECT_THAT(clean.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(3, 9))); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/crc32c.cc b/third_party/libwebrtc/net/dcsctp/packet/crc32c.cc new file mode 100644 index 0000000000..e3f0dc1d19 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/crc32c.cc @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/crc32c.h" + +#include <cstdint> + +#include "third_party/crc32c/src/include/crc32c/crc32c.h" + +namespace dcsctp { + +uint32_t GenerateCrc32C(rtc::ArrayView<const uint8_t> data) { + uint32_t crc32c = crc32c_value(data.data(), data.size()); + + // Byte swapping for little endian byte order: + uint8_t byte0 = crc32c; + uint8_t byte1 = crc32c >> 8; + uint8_t byte2 = crc32c >> 16; + uint8_t byte3 = crc32c >> 24; + crc32c = ((byte0 << 24) | (byte1 << 16) | (byte2 << 8) | byte3); + return crc32c; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/crc32c.h b/third_party/libwebrtc/net/dcsctp/packet/crc32c.h new file mode 100644 index 0000000000..a969e1b26b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/crc32c.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_CRC32C_H_ +#define NET_DCSCTP_PACKET_CRC32C_H_ + +#include <cstdint> + +#include "api/array_view.h" + +namespace dcsctp { + +// Generates the CRC32C checksum of `data`. +uint32_t GenerateCrc32C(rtc::ArrayView<const uint8_t> data); + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_CRC32C_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/crc32c_test.cc b/third_party/libwebrtc/net/dcsctp/packet/crc32c_test.cc new file mode 100644 index 0000000000..0821c4ef75 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/crc32c_test.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/crc32c.h" + +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +constexpr std::array<const uint8_t, 0> kEmpty = {}; +constexpr std::array<const uint8_t, 1> kZero = {0}; +constexpr std::array<const uint8_t, 4> kManyZeros = {0, 0, 0, 0}; +constexpr std::array<const uint8_t, 4> kShort = {1, 2, 3, 4}; +constexpr std::array<const uint8_t, 8> kLong = {1, 2, 3, 4, 5, 6, 7, 8}; +// https://tools.ietf.org/html/rfc3720#appendix-B.4 +constexpr std::array<const uint8_t, 32> k32Zeros = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; +constexpr std::array<const uint8_t, 32> k32Ones = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; +constexpr std::array<const uint8_t, 32> k32Incrementing = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; +constexpr std::array<const uint8_t, 32> k32Decrementing = { + 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; +constexpr std::array<const uint8_t, 48> kISCSICommandPDU = { + 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +}; + +TEST(Crc32Test, TestVectors) { + EXPECT_EQ(GenerateCrc32C(kEmpty), 0U); + EXPECT_EQ(GenerateCrc32C(kZero), 0x51537d52U); + EXPECT_EQ(GenerateCrc32C(kManyZeros), 0xc74b6748U); + EXPECT_EQ(GenerateCrc32C(kShort), 0xf48c3029U); + EXPECT_EQ(GenerateCrc32C(kLong), 0x811f8946U); + // https://tools.ietf.org/html/rfc3720#appendix-B.4 + EXPECT_EQ(GenerateCrc32C(k32Zeros), 0xaa36918aU); + EXPECT_EQ(GenerateCrc32C(k32Ones), 0x43aba862U); + EXPECT_EQ(GenerateCrc32C(k32Incrementing), 0x4e79dd46U); + EXPECT_EQ(GenerateCrc32C(k32Decrementing), 0x5cdb3f11U); + EXPECT_EQ(GenerateCrc32C(kISCSICommandPDU), 0x563a96d9U); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/data.h b/third_party/libwebrtc/net/dcsctp/packet/data.h new file mode 100644 index 0000000000..c1754ed59a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/data.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_DATA_H_ +#define NET_DCSCTP_PACKET_DATA_H_ + +#include <cstdint> +#include <utility> +#include <vector> + +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// Represents data that is either received and extracted from a DATA/I-DATA +// chunk, or data that is supposed to be sent, and wrapped in a DATA/I-DATA +// chunk (depending on peer capabilities). +// +// The data wrapped in this structure is actually the same as the DATA/I-DATA +// chunk (actually the union of them), but to avoid having all components be +// aware of the implementation details of the different chunks, this abstraction +// is used instead. A notable difference is also that it doesn't carry a +// Transmission Sequence Number (TSN), as that is not known when a chunk is +// created (assigned late, just when sending), and that the TSNs in DATA/I-DATA +// are wrapped numbers, and within the library, unwrapped sequence numbers are +// preferably used. +struct Data { + // Indicates if a chunk is the first in a fragmented message and maps to the + // "beginning" flag in DATA/I-DATA chunk. + using IsBeginning = webrtc::StrongAlias<class IsBeginningTag, bool>; + + // Indicates if a chunk is the last in a fragmented message and maps to the + // "end" flag in DATA/I-DATA chunk. + using IsEnd = webrtc::StrongAlias<class IsEndTag, bool>; + + Data(StreamID stream_id, + SSN ssn, + MID message_id, + FSN fsn, + PPID ppid, + std::vector<uint8_t> payload, + IsBeginning is_beginning, + IsEnd is_end, + IsUnordered is_unordered) + : stream_id(stream_id), + ssn(ssn), + message_id(message_id), + fsn(fsn), + ppid(ppid), + payload(std::move(payload)), + is_beginning(is_beginning), + is_end(is_end), + is_unordered(is_unordered) {} + + // Move-only, to avoid accidental copies. + Data(Data&& other) = default; + Data& operator=(Data&& other) = default; + + // Creates a copy of this `Data` object. + Data Clone() const { + return Data(stream_id, ssn, message_id, fsn, ppid, payload, is_beginning, + is_end, is_unordered); + } + + // The size of this data, which translates to the size of its payload. + size_t size() const { return payload.size(); } + + // Stream Identifier. + StreamID stream_id; + + // Stream Sequence Number (SSN), per stream, for ordered chunks. Defined by + // RFC4960 and used only in DATA chunks (not I-DATA). + SSN ssn; + + // Message Identifier (MID) per stream and ordered/unordered. Defined by + // RFC8260, and used together with options.is_unordered and stream_id to + // uniquely identify a message. Used only in I-DATA chunks (not DATA). + MID message_id; + // Fragment Sequence Number (FSN) per stream and ordered/unordered, as above. + FSN fsn; + + // Payload Protocol Identifier (PPID). + PPID ppid; + + // The actual data payload. + std::vector<uint8_t> payload; + + // If this data represents the first, last or a middle chunk. + IsBeginning is_beginning; + IsEnd is_end; + // If this data is sent/received unordered. + IsUnordered is_unordered; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_DATA_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.cc new file mode 100644 index 0000000000..ef67c2a49f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h" + +#include <stdint.h> + +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.10 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=10 | Cause Length=4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int CookieReceivedWhileShuttingDownCause::kType; + +absl::optional<CookieReceivedWhileShuttingDownCause> +CookieReceivedWhileShuttingDownCause::Parse( + rtc::ArrayView<const uint8_t> data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return CookieReceivedWhileShuttingDownCause(); +} + +void CookieReceivedWhileShuttingDownCause::SerializeTo( + std::vector<uint8_t>& out) const { + AllocateTLV(out); +} + +std::string CookieReceivedWhileShuttingDownCause::ToString() const { + return "Cookie Received While Shutting Down"; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h new file mode 100644 index 0000000000..362f181fba --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_COOKIE_RECEIVED_WHILE_SHUTTING_DOWN_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_COOKIE_RECEIVED_WHILE_SHUTTING_DOWN_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.10 +struct CookieReceivedWhileShuttingDownCauseConfig : public ParameterConfig { + static constexpr int kType = 10; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class CookieReceivedWhileShuttingDownCause + : public Parameter, + public TLVTrait<CookieReceivedWhileShuttingDownCauseConfig> { + public: + static constexpr int kType = + CookieReceivedWhileShuttingDownCauseConfig::kType; + + CookieReceivedWhileShuttingDownCause() {} + + static absl::optional<CookieReceivedWhileShuttingDownCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_COOKIE_RECEIVED_WHILE_SHUTTING_DOWN_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause_test.cc new file mode 100644 index 0000000000..afb8364c32 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause_test.cc @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(CookieReceivedWhileShuttingDownCauseTest, SerializeAndDeserialize) { + CookieReceivedWhileShuttingDownCause parameter; + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + CookieReceivedWhileShuttingDownCause deserialized, + CookieReceivedWhileShuttingDownCause::Parse(serialized)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/error_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/error_cause.cc new file mode 100644 index 0000000000..dcd07472ed --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/error_cause.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/error_cause.h" + +#include <stddef.h> + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h" +#include "net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h" +#include "net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h" +#include "net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h" +#include "net/dcsctp/packet/error_cause/no_user_data_cause.h" +#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h" +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" +#include "net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h" +#include "net/dcsctp/packet/error_cause/stale_cookie_error_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h" +#include "net/dcsctp/packet/error_cause/unresolvable_address_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +template <class ErrorCause> +bool ParseAndPrint(ParameterDescriptor descriptor, rtc::StringBuilder& sb) { + if (descriptor.type == ErrorCause::kType) { + absl::optional<ErrorCause> p = ErrorCause::Parse(descriptor.data); + if (p.has_value()) { + sb << p->ToString(); + } else { + sb << "Failed to parse error cause of type " << ErrorCause::kType; + } + return true; + } + return false; +} + +std::string ErrorCausesToString(const Parameters& parameters) { + rtc::StringBuilder sb; + + std::vector<ParameterDescriptor> descriptors = parameters.descriptors(); + for (size_t i = 0; i < descriptors.size(); ++i) { + if (i > 0) { + sb << "\n"; + } + + const ParameterDescriptor& d = descriptors[i]; + if (!ParseAndPrint<InvalidStreamIdentifierCause>(d, sb) && + !ParseAndPrint<MissingMandatoryParameterCause>(d, sb) && + !ParseAndPrint<StaleCookieErrorCause>(d, sb) && + !ParseAndPrint<OutOfResourceErrorCause>(d, sb) && + !ParseAndPrint<UnresolvableAddressCause>(d, sb) && + !ParseAndPrint<UnrecognizedChunkTypeCause>(d, sb) && + !ParseAndPrint<InvalidMandatoryParameterCause>(d, sb) && + !ParseAndPrint<UnrecognizedParametersCause>(d, sb) && + !ParseAndPrint<NoUserDataCause>(d, sb) && + !ParseAndPrint<CookieReceivedWhileShuttingDownCause>(d, sb) && + !ParseAndPrint<RestartOfAnAssociationWithNewAddressesCause>(d, sb) && + !ParseAndPrint<UserInitiatedAbortCause>(d, sb) && + !ParseAndPrint<ProtocolViolationCause>(d, sb)) { + sb << "Unhandled parameter of type: " << d.type; + } + } + + return sb.Release(); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/error_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/error_cause.h new file mode 100644 index 0000000000..fa2bf81478 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/error_cause.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_ERROR_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_ERROR_CAUSE_H_ + +#include <stddef.h> + +#include <cstdint> +#include <iosfwd> +#include <memory> +#include <string> +#include <type_traits> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// Converts the Error Causes in `parameters` to a human readable string, +// to be used in error reporting and logging. +std::string ErrorCausesToString(const Parameters& parameters); + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_ERROR_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.cc new file mode 100644 index 0000000000..0187544226 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h" + +#include <stdint.h> + +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.7 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=7 | Cause Length=4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int InvalidMandatoryParameterCause::kType; + +absl::optional<InvalidMandatoryParameterCause> +InvalidMandatoryParameterCause::Parse(rtc::ArrayView<const uint8_t> data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return InvalidMandatoryParameterCause(); +} + +void InvalidMandatoryParameterCause::SerializeTo( + std::vector<uint8_t>& out) const { + AllocateTLV(out); +} + +std::string InvalidMandatoryParameterCause::ToString() const { + return "Invalid Mandatory Parameter"; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h new file mode 100644 index 0000000000..e192b5a42f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_MANDATORY_PARAMETER_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_MANDATORY_PARAMETER_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.7 +struct InvalidMandatoryParameterCauseConfig : public ParameterConfig { + static constexpr int kType = 7; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class InvalidMandatoryParameterCause + : public Parameter, + public TLVTrait<InvalidMandatoryParameterCauseConfig> { + public: + static constexpr int kType = InvalidMandatoryParameterCauseConfig::kType; + + InvalidMandatoryParameterCause() {} + + static absl::optional<InvalidMandatoryParameterCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_MANDATORY_PARAMETER_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause_test.cc new file mode 100644 index 0000000000..3d532d09b1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause_test.cc @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/invalid_mandatory_parameter_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(InvalidMandatoryParameterCauseTest, SerializeAndDeserialize) { + InvalidMandatoryParameterCause parameter; + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + InvalidMandatoryParameterCause deserialized, + InvalidMandatoryParameterCause::Parse(serialized)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.cc new file mode 100644 index 0000000000..b2ddd6f4ef --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.cc @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.1 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=1 | Cause Length=8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Identifier | (Reserved) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int InvalidStreamIdentifierCause::kType; + +absl::optional<InvalidStreamIdentifierCause> +InvalidStreamIdentifierCause::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + StreamID stream_id(reader->Load16<4>()); + return InvalidStreamIdentifierCause(stream_id); +} + +void InvalidStreamIdentifierCause::SerializeTo( + std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out); + + writer.Store16<4>(*stream_id_); +} + +std::string InvalidStreamIdentifierCause::ToString() const { + rtc::StringBuilder sb; + sb << "Invalid Stream Identifier, stream_id=" << *stream_id_; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h new file mode 100644 index 0000000000..b7dfe177b8 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_STREAM_IDENTIFIER_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_STREAM_IDENTIFIER_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.1 +struct InvalidStreamIdentifierCauseConfig : public ParameterConfig { + static constexpr int kType = 1; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class InvalidStreamIdentifierCause + : public Parameter, + public TLVTrait<InvalidStreamIdentifierCauseConfig> { + public: + static constexpr int kType = InvalidStreamIdentifierCauseConfig::kType; + + explicit InvalidStreamIdentifierCause(StreamID stream_id) + : stream_id_(stream_id) {} + + static absl::optional<InvalidStreamIdentifierCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + StreamID stream_id() const { return stream_id_; } + + private: + StreamID stream_id_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_INVALID_STREAM_IDENTIFIER_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause_test.cc new file mode 100644 index 0000000000..a282ce5ee8 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/invalid_stream_identifier_cause_test.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/invalid_stream_identifier_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(InvalidStreamIdentifierCauseTest, SerializeAndDeserialize) { + InvalidStreamIdentifierCause parameter(StreamID(1)); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(InvalidStreamIdentifierCause deserialized, + InvalidStreamIdentifierCause::Parse(serialized)); + + EXPECT_EQ(*deserialized.stream_id(), 1); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.cc new file mode 100644 index 0000000000..b89f86e43e --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.cc @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h" + +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.2 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=2 | Cause Length=8+N*2 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of missing params=N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Missing Param Type #1 | Missing Param Type #2 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Missing Param Type #N-1 | Missing Param Type #N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int MissingMandatoryParameterCause::kType; + +absl::optional<MissingMandatoryParameterCause> +MissingMandatoryParameterCause::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + uint32_t count = reader->Load32<4>(); + if (reader->variable_data_size() / kMissingParameterSize != count) { + RTC_DLOG(LS_WARNING) << "Invalid number of missing parameters"; + return absl::nullopt; + } + + std::vector<uint16_t> missing_parameter_types; + missing_parameter_types.reserve(count); + for (uint32_t i = 0; i < count; ++i) { + BoundedByteReader<kMissingParameterSize> sub_reader = + reader->sub_reader<kMissingParameterSize>(i * kMissingParameterSize); + + missing_parameter_types.push_back(sub_reader.Load16<0>()); + } + return MissingMandatoryParameterCause(missing_parameter_types); +} + +void MissingMandatoryParameterCause::SerializeTo( + std::vector<uint8_t>& out) const { + size_t variable_size = + missing_parameter_types_.size() * kMissingParameterSize; + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(missing_parameter_types_.size()); + + for (size_t i = 0; i < missing_parameter_types_.size(); ++i) { + BoundedByteWriter<kMissingParameterSize> sub_writer = + writer.sub_writer<kMissingParameterSize>(i * kMissingParameterSize); + + sub_writer.Store16<0>(missing_parameter_types_[i]); + } +} + +std::string MissingMandatoryParameterCause::ToString() const { + rtc::StringBuilder sb; + sb << "Missing Mandatory Parameter, missing_parameter_types=" + << StrJoin(missing_parameter_types_, ","); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h new file mode 100644 index 0000000000..4435424295 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_MISSING_MANDATORY_PARAMETER_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_MISSING_MANDATORY_PARAMETER_CAUSE_H_ +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.2 +struct MissingMandatoryParameterCauseConfig : public ParameterConfig { + static constexpr int kType = 2; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 2; +}; + +class MissingMandatoryParameterCause + : public Parameter, + public TLVTrait<MissingMandatoryParameterCauseConfig> { + public: + static constexpr int kType = MissingMandatoryParameterCauseConfig::kType; + + explicit MissingMandatoryParameterCause( + rtc::ArrayView<const uint16_t> missing_parameter_types) + : missing_parameter_types_(missing_parameter_types.begin(), + missing_parameter_types.end()) {} + + static absl::optional<MissingMandatoryParameterCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + rtc::ArrayView<const uint16_t> missing_parameter_types() const { + return missing_parameter_types_; + } + + private: + static constexpr size_t kMissingParameterSize = 2; + std::vector<uint16_t> missing_parameter_types_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_MISSING_MANDATORY_PARAMETER_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause_test.cc new file mode 100644 index 0000000000..1c526ff0e2 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause_test.cc @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/missing_mandatory_parameter_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +TEST(MissingMandatoryParameterCauseTest, SerializeAndDeserialize) { + uint16_t parameter_types[] = {1, 2, 3}; + MissingMandatoryParameterCause parameter(parameter_types); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + MissingMandatoryParameterCause deserialized, + MissingMandatoryParameterCause::Parse(serialized)); + + EXPECT_THAT(deserialized.missing_parameter_types(), ElementsAre(1, 2, 3)); +} + +TEST(MissingMandatoryParameterCauseTest, HandlesDeserializeZeroParameters) { + uint8_t serialized[] = {0, 2, 0, 8, 0, 0, 0, 0}; + + ASSERT_HAS_VALUE_AND_ASSIGN( + MissingMandatoryParameterCause deserialized, + MissingMandatoryParameterCause::Parse(serialized)); + + EXPECT_THAT(deserialized.missing_parameter_types(), IsEmpty()); +} + +TEST(MissingMandatoryParameterCauseTest, HandlesOverflowParameterCount) { + // 0x80000004 * 2 = 2**32 + 8 -> if overflow, would validate correctly. + uint8_t serialized[] = {0, 2, 0, 8, 0x80, 0x00, 0x00, 0x04}; + + EXPECT_FALSE(MissingMandatoryParameterCause::Parse(serialized).has_value()); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/no_user_data_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/no_user_data_cause.cc new file mode 100644 index 0000000000..2853915b0c --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/no_user_data_cause.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/no_user_data_cause.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.9 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=9 | Cause Length=8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / TSN value / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int NoUserDataCause::kType; + +absl::optional<NoUserDataCause> NoUserDataCause::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + TSN tsn(reader->Load32<4>()); + return NoUserDataCause(tsn); +} + +void NoUserDataCause::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out); + writer.Store32<4>(*tsn_); +} + +std::string NoUserDataCause::ToString() const { + rtc::StringBuilder sb; + sb << "No User Data, tsn=" << *tsn_; + return sb.Release(); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/no_user_data_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/no_user_data_cause.h new file mode 100644 index 0000000000..1087dcc97c --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/no_user_data_cause.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_NO_USER_DATA_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_NO_USER_DATA_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.9 +struct NoUserDataCauseConfig : public ParameterConfig { + static constexpr int kType = 9; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class NoUserDataCause : public Parameter, + public TLVTrait<NoUserDataCauseConfig> { + public: + static constexpr int kType = NoUserDataCauseConfig::kType; + + explicit NoUserDataCause(TSN tsn) : tsn_(tsn) {} + + static absl::optional<NoUserDataCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + TSN tsn() const { return tsn_; } + + private: + TSN tsn_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_NO_USER_DATA_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/no_user_data_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/no_user_data_cause_test.cc new file mode 100644 index 0000000000..0a535bf4fa --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/no_user_data_cause_test.cc @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/no_user_data_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(NoUserDataCauseTest, SerializeAndDeserialize) { + NoUserDataCause parameter(TSN(123)); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(NoUserDataCause deserialized, + NoUserDataCause::Parse(serialized)); + + EXPECT_EQ(*deserialized.tsn(), 123u); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/out_of_resource_error_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/out_of_resource_error_cause.cc new file mode 100644 index 0000000000..e5c7c0e787 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/out_of_resource_error_cause.cc @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h" + +#include <stdint.h> + +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.4 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=4 | Cause Length=4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int OutOfResourceErrorCause::kType; + +absl::optional<OutOfResourceErrorCause> OutOfResourceErrorCause::Parse( + rtc::ArrayView<const uint8_t> data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return OutOfResourceErrorCause(); +} + +void OutOfResourceErrorCause::SerializeTo(std::vector<uint8_t>& out) const { + AllocateTLV(out); +} + +std::string OutOfResourceErrorCause::ToString() const { + return "Out Of Resource"; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/out_of_resource_error_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/out_of_resource_error_cause.h new file mode 100644 index 0000000000..fc798ca4ac --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/out_of_resource_error_cause.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_OUT_OF_RESOURCE_ERROR_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_OUT_OF_RESOURCE_ERROR_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.4 +struct OutOfResourceParameterConfig : public ParameterConfig { + static constexpr int kType = 4; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class OutOfResourceErrorCause : public Parameter, + public TLVTrait<OutOfResourceParameterConfig> { + public: + static constexpr int kType = OutOfResourceParameterConfig::kType; + + OutOfResourceErrorCause() {} + + static absl::optional<OutOfResourceErrorCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_OUT_OF_RESOURCE_ERROR_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/out_of_resource_error_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/out_of_resource_error_cause_test.cc new file mode 100644 index 0000000000..501fc201cd --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/out_of_resource_error_cause_test.cc @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(OutOfResourceErrorCauseTest, SerializeAndDeserialize) { + OutOfResourceErrorCause parameter; + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(OutOfResourceErrorCause deserialized, + OutOfResourceErrorCause::Parse(serialized)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/protocol_violation_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/protocol_violation_cause.cc new file mode 100644 index 0000000000..1b8d423afb --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/protocol_violation_cause.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.13 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=13 | Cause Length=Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Additional Information / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ProtocolViolationCause::kType; + +absl::optional<ProtocolViolationCause> ProtocolViolationCause::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return ProtocolViolationCause( + std::string(reinterpret_cast<const char*>(reader->variable_data().data()), + reader->variable_data().size())); +} + +void ProtocolViolationCause::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = + AllocateTLV(out, additional_information_.size()); + writer.CopyToVariableData(rtc::MakeArrayView( + reinterpret_cast<const uint8_t*>(additional_information_.data()), + additional_information_.size())); +} + +std::string ProtocolViolationCause::ToString() const { + rtc::StringBuilder sb; + sb << "Protocol Violation, additional_information=" + << additional_information_; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/protocol_violation_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/protocol_violation_cause.h new file mode 100644 index 0000000000..3081e1f28c --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/protocol_violation_cause.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_PROTOCOL_VIOLATION_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_PROTOCOL_VIOLATION_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.13 +struct ProtocolViolationCauseConfig : public ParameterConfig { + static constexpr int kType = 13; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class ProtocolViolationCause : public Parameter, + public TLVTrait<ProtocolViolationCauseConfig> { + public: + static constexpr int kType = ProtocolViolationCauseConfig::kType; + + explicit ProtocolViolationCause(absl::string_view additional_information) + : additional_information_(additional_information) {} + + static absl::optional<ProtocolViolationCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + absl::string_view additional_information() const { + return additional_information_; + } + + private: + std::string additional_information_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_PROTOCOL_VIOLATION_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/protocol_violation_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/protocol_violation_cause_test.cc new file mode 100644 index 0000000000..902d867091 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/protocol_violation_cause_test.cc @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +TEST(ProtocolViolationCauseTest, EmptyReason) { + Parameters causes = + Parameters::Builder().Add(ProtocolViolationCause("")).Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters deserialized, + Parameters::Parse(causes.data())); + ASSERT_THAT(deserialized.descriptors(), SizeIs(1)); + EXPECT_EQ(deserialized.descriptors()[0].type, ProtocolViolationCause::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + ProtocolViolationCause cause, + ProtocolViolationCause::Parse(deserialized.descriptors()[0].data)); + + EXPECT_EQ(cause.additional_information(), ""); +} + +TEST(ProtocolViolationCauseTest, SetReason) { + Parameters causes = Parameters::Builder() + .Add(ProtocolViolationCause("Reason goes here")) + .Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters deserialized, + Parameters::Parse(causes.data())); + ASSERT_THAT(deserialized.descriptors(), SizeIs(1)); + EXPECT_EQ(deserialized.descriptors()[0].type, ProtocolViolationCause::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + ProtocolViolationCause cause, + ProtocolViolationCause::Parse(deserialized.descriptors()[0].data)); + + EXPECT_EQ(cause.additional_information(), "Reason goes here"); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.cc new file mode 100644 index 0000000000..abe5de6211 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.11 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=11 | Cause Length=Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / New Address TLVs / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int RestartOfAnAssociationWithNewAddressesCause::kType; + +absl::optional<RestartOfAnAssociationWithNewAddressesCause> +RestartOfAnAssociationWithNewAddressesCause::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return RestartOfAnAssociationWithNewAddressesCause(reader->variable_data()); +} + +void RestartOfAnAssociationWithNewAddressesCause::SerializeTo( + std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = + AllocateTLV(out, new_address_tlvs_.size()); + writer.CopyToVariableData(new_address_tlvs_); +} + +std::string RestartOfAnAssociationWithNewAddressesCause::ToString() const { + return "Restart of an Association with New Addresses"; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h new file mode 100644 index 0000000000..a1cccdc8a1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_RESTART_OF_AN_ASSOCIATION_WITH_NEW_ADDRESS_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_RESTART_OF_AN_ASSOCIATION_WITH_NEW_ADDRESS_CAUSE_H_ +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.11 +struct RestartOfAnAssociationWithNewAddressesCauseConfig + : public ParameterConfig { + static constexpr int kType = 11; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class RestartOfAnAssociationWithNewAddressesCause + : public Parameter, + public TLVTrait<RestartOfAnAssociationWithNewAddressesCauseConfig> { + public: + static constexpr int kType = + RestartOfAnAssociationWithNewAddressesCauseConfig::kType; + + explicit RestartOfAnAssociationWithNewAddressesCause( + rtc::ArrayView<const uint8_t> new_address_tlvs) + : new_address_tlvs_(new_address_tlvs.begin(), new_address_tlvs.end()) {} + + static absl::optional<RestartOfAnAssociationWithNewAddressesCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + rtc::ArrayView<const uint8_t> new_address_tlvs() const { + return new_address_tlvs_; + } + + private: + std::vector<uint8_t> new_address_tlvs_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_RESTART_OF_AN_ASSOCIATION_WITH_NEW_ADDRESS_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause_test.cc new file mode 100644 index 0000000000..b8ab8b6803 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause_test.cc @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/restart_of_an_association_with_new_address_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(RestartOfAnAssociationWithNewAddressesCauseTest, SerializeAndDeserialize) { + uint8_t data[] = {1, 2, 3}; + RestartOfAnAssociationWithNewAddressesCause parameter(data); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + RestartOfAnAssociationWithNewAddressesCause deserialized, + RestartOfAnAssociationWithNewAddressesCause::Parse(serialized)); + + EXPECT_THAT(deserialized.new_address_tlvs(), ElementsAre(1, 2, 3)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/stale_cookie_error_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/stale_cookie_error_cause.cc new file mode 100644 index 0000000000..d77d8488f1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/stale_cookie_error_cause.cc @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/stale_cookie_error_cause.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.3 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=3 | Cause Length=8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Measure of Staleness (usec.) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int StaleCookieErrorCause::kType; + +absl::optional<StaleCookieErrorCause> StaleCookieErrorCause::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + uint32_t staleness_us = reader->Load32<4>(); + return StaleCookieErrorCause(staleness_us); +} + +void StaleCookieErrorCause::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out); + writer.Store32<4>(staleness_us_); +} + +std::string StaleCookieErrorCause::ToString() const { + rtc::StringBuilder sb; + sb << "Stale Cookie Error, staleness_us=" << staleness_us_; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/stale_cookie_error_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/stale_cookie_error_cause.h new file mode 100644 index 0000000000..d8b7b5b5bd --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/stale_cookie_error_cause.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_STALE_COOKIE_ERROR_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_STALE_COOKIE_ERROR_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.3 +struct StaleCookieParameterConfig : public ParameterConfig { + static constexpr int kType = 3; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class StaleCookieErrorCause : public Parameter, + public TLVTrait<StaleCookieParameterConfig> { + public: + static constexpr int kType = StaleCookieParameterConfig::kType; + + explicit StaleCookieErrorCause(uint32_t staleness_us) + : staleness_us_(staleness_us) {} + + static absl::optional<StaleCookieErrorCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + uint16_t staleness_us() const { return staleness_us_; } + + private: + uint32_t staleness_us_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_STALE_COOKIE_ERROR_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/stale_cookie_error_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/stale_cookie_error_cause_test.cc new file mode 100644 index 0000000000..c0d1ac1c58 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/stale_cookie_error_cause_test.cc @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/stale_cookie_error_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(StaleCookieErrorCauseTest, SerializeAndDeserialize) { + StaleCookieErrorCause parameter(123); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(StaleCookieErrorCause deserialized, + StaleCookieErrorCause::Parse(serialized)); + + EXPECT_EQ(deserialized.staleness_us(), 123); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.cc new file mode 100644 index 0000000000..04b960d992 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.6 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=6 | Cause Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Unrecognized Chunk / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int UnrecognizedChunkTypeCause::kType; + +absl::optional<UnrecognizedChunkTypeCause> UnrecognizedChunkTypeCause::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + std::vector<uint8_t> unrecognized_chunk(reader->variable_data().begin(), + reader->variable_data().end()); + return UnrecognizedChunkTypeCause(std::move(unrecognized_chunk)); +} + +void UnrecognizedChunkTypeCause::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = + AllocateTLV(out, unrecognized_chunk_.size()); + writer.CopyToVariableData(unrecognized_chunk_); +} + +std::string UnrecognizedChunkTypeCause::ToString() const { + rtc::StringBuilder sb; + sb << "Unrecognized Chunk Type, chunk_type="; + if (!unrecognized_chunk_.empty()) { + sb << static_cast<int>(unrecognized_chunk_[0]); + } else { + sb << "<missing>"; + } + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h new file mode 100644 index 0000000000..26d3d3b8f9 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_CHUNK_TYPE_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_CHUNK_TYPE_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.6 +struct UnrecognizedChunkTypeCauseConfig : public ParameterConfig { + static constexpr int kType = 6; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class UnrecognizedChunkTypeCause + : public Parameter, + public TLVTrait<UnrecognizedChunkTypeCauseConfig> { + public: + static constexpr int kType = UnrecognizedChunkTypeCauseConfig::kType; + + explicit UnrecognizedChunkTypeCause(std::vector<uint8_t> unrecognized_chunk) + : unrecognized_chunk_(std::move(unrecognized_chunk)) {} + + static absl::optional<UnrecognizedChunkTypeCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + rtc::ArrayView<const uint8_t> unrecognized_chunk() const { + return unrecognized_chunk_; + } + + private: + std::vector<uint8_t> unrecognized_chunk_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_CHUNK_TYPE_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause_test.cc new file mode 100644 index 0000000000..baff852f40 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause_test.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(UnrecognizedChunkTypeCauseTest, SerializeAndDeserialize) { + UnrecognizedChunkTypeCause parameter({1, 2, 3}); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(UnrecognizedChunkTypeCause deserialized, + UnrecognizedChunkTypeCause::Parse(serialized)); + + EXPECT_THAT(deserialized.unrecognized_chunk(), ElementsAre(1, 2, 3)); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.cc new file mode 100644 index 0000000000..80001a9eae --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.8 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=8 | Cause Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Unrecognized Parameters / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int UnrecognizedParametersCause::kType; + +absl::optional<UnrecognizedParametersCause> UnrecognizedParametersCause::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return UnrecognizedParametersCause(reader->variable_data()); +} + +void UnrecognizedParametersCause::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = + AllocateTLV(out, unrecognized_parameters_.size()); + writer.CopyToVariableData(unrecognized_parameters_); +} + +std::string UnrecognizedParametersCause::ToString() const { + return "Unrecognized Parameters"; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h new file mode 100644 index 0000000000..ebec5ed4c3 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_PARAMETER_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_PARAMETER_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.8 +struct UnrecognizedParametersCauseConfig : public ParameterConfig { + static constexpr int kType = 8; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class UnrecognizedParametersCause + : public Parameter, + public TLVTrait<UnrecognizedParametersCauseConfig> { + public: + static constexpr int kType = UnrecognizedParametersCauseConfig::kType; + + explicit UnrecognizedParametersCause( + rtc::ArrayView<const uint8_t> unrecognized_parameters) + : unrecognized_parameters_(unrecognized_parameters.begin(), + unrecognized_parameters.end()) {} + + static absl::optional<UnrecognizedParametersCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + rtc::ArrayView<const uint8_t> unrecognized_parameters() const { + return unrecognized_parameters_; + } + + private: + std::vector<uint8_t> unrecognized_parameters_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_UNRECOGNIZED_PARAMETER_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_parameter_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_parameter_cause_test.cc new file mode 100644 index 0000000000..0449599ca6 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unrecognized_parameter_cause_test.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unrecognized_parameter_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(UnrecognizedParametersCauseTest, SerializeAndDeserialize) { + uint8_t unrecognized_parameters[] = {1, 2, 3}; + UnrecognizedParametersCause parameter(unrecognized_parameters); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(UnrecognizedParametersCause deserialized, + UnrecognizedParametersCause::Parse(serialized)); + + EXPECT_THAT(deserialized.unrecognized_parameters(), ElementsAre(1, 2, 3)); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/unresolvable_address_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unresolvable_address_cause.cc new file mode 100644 index 0000000000..8108d31aa7 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unresolvable_address_cause.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unresolvable_address_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.5 + +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=5 | Cause Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Unresolvable Address / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int UnresolvableAddressCause::kType; + +absl::optional<UnresolvableAddressCause> UnresolvableAddressCause::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return UnresolvableAddressCause(reader->variable_data()); +} + +void UnresolvableAddressCause::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = + AllocateTLV(out, unresolvable_address_.size()); + writer.CopyToVariableData(unresolvable_address_); +} + +std::string UnresolvableAddressCause::ToString() const { + return "Unresolvable Address"; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/unresolvable_address_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unresolvable_address_cause.h new file mode 100644 index 0000000000..c63b3779ef --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unresolvable_address_cause.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_UNRESOLVABLE_ADDRESS_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_UNRESOLVABLE_ADDRESS_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.5 +struct UnresolvableAddressCauseConfig : public ParameterConfig { + static constexpr int kType = 5; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class UnresolvableAddressCause + : public Parameter, + public TLVTrait<UnresolvableAddressCauseConfig> { + public: + static constexpr int kType = UnresolvableAddressCauseConfig::kType; + + explicit UnresolvableAddressCause( + rtc::ArrayView<const uint8_t> unresolvable_address) + : unresolvable_address_(unresolvable_address.begin(), + unresolvable_address.end()) {} + + static absl::optional<UnresolvableAddressCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + rtc::ArrayView<const uint8_t> unresolvable_address() const { + return unresolvable_address_; + } + + private: + std::vector<uint8_t> unresolvable_address_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_UNRESOLVABLE_ADDRESS_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/unresolvable_address_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unresolvable_address_cause_test.cc new file mode 100644 index 0000000000..688730e6b3 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/unresolvable_address_cause_test.cc @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/unresolvable_address_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(UnresolvableAddressCauseTest, SerializeAndDeserialize) { + uint8_t unresolvable_address[] = {1, 2, 3}; + UnresolvableAddressCause parameter(unresolvable_address); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(UnresolvableAddressCause deserialized, + UnresolvableAddressCause::Parse(serialized)); + + EXPECT_THAT(deserialized.unresolvable_address(), ElementsAre(1, 2, 3)); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/user_initiated_abort_cause.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/user_initiated_abort_cause.cc new file mode 100644 index 0000000000..da99aacbfa --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/user_initiated_abort_cause.cc @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.12 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cause Code=12 | Cause Length=Variable | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / Upper Layer Abort Reason / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int UserInitiatedAbortCause::kType; + +absl::optional<UserInitiatedAbortCause> UserInitiatedAbortCause::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + if (reader->variable_data().empty()) { + return UserInitiatedAbortCause(""); + } + return UserInitiatedAbortCause( + std::string(reinterpret_cast<const char*>(reader->variable_data().data()), + reader->variable_data().size())); +} + +void UserInitiatedAbortCause::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = + AllocateTLV(out, upper_layer_abort_reason_.size()); + writer.CopyToVariableData(rtc::MakeArrayView( + reinterpret_cast<const uint8_t*>(upper_layer_abort_reason_.data()), + upper_layer_abort_reason_.size())); +} + +std::string UserInitiatedAbortCause::ToString() const { + rtc::StringBuilder sb; + sb << "User-Initiated Abort, reason=" << upper_layer_abort_reason_; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/user_initiated_abort_cause.h b/third_party/libwebrtc/net/dcsctp/packet/error_cause/user_initiated_abort_cause.h new file mode 100644 index 0000000000..9eb16657b4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/user_initiated_abort_cause.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_ERROR_CAUSE_USER_INITIATED_ABORT_CAUSE_H_ +#define NET_DCSCTP_PACKET_ERROR_CAUSE_USER_INITIATED_ABORT_CAUSE_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.10.12 +struct UserInitiatedAbortCauseConfig : public ParameterConfig { + static constexpr int kType = 12; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class UserInitiatedAbortCause : public Parameter, + public TLVTrait<UserInitiatedAbortCauseConfig> { + public: + static constexpr int kType = UserInitiatedAbortCauseConfig::kType; + + explicit UserInitiatedAbortCause(absl::string_view upper_layer_abort_reason) + : upper_layer_abort_reason_(upper_layer_abort_reason) {} + + static absl::optional<UserInitiatedAbortCause> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + absl::string_view upper_layer_abort_reason() const { + return upper_layer_abort_reason_; + } + + private: + std::string upper_layer_abort_reason_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_ERROR_CAUSE_USER_INITIATED_ABORT_CAUSE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/error_cause/user_initiated_abort_cause_test.cc b/third_party/libwebrtc/net/dcsctp/packet/error_cause/user_initiated_abort_cause_test.cc new file mode 100644 index 0000000000..250959e3df --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/error_cause/user_initiated_abort_cause_test.cc @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +TEST(UserInitiatedAbortCauseTest, EmptyReason) { + Parameters causes = + Parameters::Builder().Add(UserInitiatedAbortCause("")).Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters deserialized, + Parameters::Parse(causes.data())); + ASSERT_THAT(deserialized.descriptors(), SizeIs(1)); + EXPECT_EQ(deserialized.descriptors()[0].type, UserInitiatedAbortCause::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + UserInitiatedAbortCause::Parse(deserialized.descriptors()[0].data)); + + EXPECT_EQ(cause.upper_layer_abort_reason(), ""); +} + +TEST(UserInitiatedAbortCauseTest, SetReason) { + Parameters causes = Parameters::Builder() + .Add(UserInitiatedAbortCause("User called Close")) + .Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters deserialized, + Parameters::Parse(causes.data())); + ASSERT_THAT(deserialized.descriptors(), SizeIs(1)); + EXPECT_EQ(deserialized.descriptors()[0].type, UserInitiatedAbortCause::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + UserInitiatedAbortCause::Parse(deserialized.descriptors()[0].data)); + + EXPECT_EQ(cause.upper_layer_abort_reason(), "User called Close"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.cc new file mode 100644 index 0000000000..c33e3e11f6 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.cc @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.6 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 18 | Parameter Length = 12 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of new streams | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int AddIncomingStreamsRequestParameter::kType; + +absl::optional<AddIncomingStreamsRequestParameter> +AddIncomingStreamsRequestParameter::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + uint16_t nbr_of_new_streams = reader->Load16<8>(); + + return AddIncomingStreamsRequestParameter(request_sequence_number, + nbr_of_new_streams); +} + +void AddIncomingStreamsRequestParameter::SerializeTo( + std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out); + writer.Store32<4>(*request_sequence_number_); + writer.Store16<8>(nbr_of_new_streams_); +} + +std::string AddIncomingStreamsRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Add Incoming Streams Request, req_seq_nbr=" + << *request_sequence_number(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h new file mode 100644 index 0000000000..3859eb3f7e --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_ADD_INCOMING_STREAMS_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_ADD_INCOMING_STREAMS_REQUEST_PARAMETER_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.6 +struct AddIncomingStreamsRequestParameterConfig : ParameterConfig { + static constexpr int kType = 18; + static constexpr size_t kHeaderSize = 12; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class AddIncomingStreamsRequestParameter + : public Parameter, + public TLVTrait<AddIncomingStreamsRequestParameterConfig> { + public: + static constexpr int kType = AddIncomingStreamsRequestParameterConfig::kType; + + explicit AddIncomingStreamsRequestParameter( + ReconfigRequestSN request_sequence_number, + uint16_t nbr_of_new_streams) + : request_sequence_number_(request_sequence_number), + nbr_of_new_streams_(nbr_of_new_streams) {} + + static absl::optional<AddIncomingStreamsRequestParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + uint16_t nbr_of_new_streams() const { return nbr_of_new_streams_; } + + private: + ReconfigRequestSN request_sequence_number_; + uint16_t nbr_of_new_streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_ADD_INCOMING_STREAMS_REQUEST_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter_test.cc new file mode 100644 index 0000000000..a29257a8f8 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_incoming_streams_request_parameter_test.cc @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(AddIncomingStreamsRequestParameterTest, SerializeAndDeserialize) { + AddIncomingStreamsRequestParameter parameter(ReconfigRequestSN(1), 2); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + AddIncomingStreamsRequestParameter deserialized, + AddIncomingStreamsRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); + EXPECT_EQ(deserialized.nbr_of_new_streams(), 2u); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.cc new file mode 100644 index 0000000000..4787ee9718 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.cc @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.5 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 17 | Parameter Length = 12 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Number of new streams | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int AddOutgoingStreamsRequestParameter::kType; + +absl::optional<AddOutgoingStreamsRequestParameter> +AddOutgoingStreamsRequestParameter::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + uint16_t nbr_of_new_streams = reader->Load16<8>(); + + return AddOutgoingStreamsRequestParameter(request_sequence_number, + nbr_of_new_streams); +} + +void AddOutgoingStreamsRequestParameter::SerializeTo( + std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out); + writer.Store32<4>(*request_sequence_number_); + writer.Store16<8>(nbr_of_new_streams_); +} + +std::string AddOutgoingStreamsRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Add Outgoing Streams Request, req_seq_nbr=" + << *request_sequence_number(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h new file mode 100644 index 0000000000..01e8f91cfa --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_ADD_OUTGOING_STREAMS_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_ADD_OUTGOING_STREAMS_REQUEST_PARAMETER_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.5 +struct AddOutgoingStreamsRequestParameterConfig : ParameterConfig { + static constexpr int kType = 17; + static constexpr size_t kHeaderSize = 12; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class AddOutgoingStreamsRequestParameter + : public Parameter, + public TLVTrait<AddOutgoingStreamsRequestParameterConfig> { + public: + static constexpr int kType = AddOutgoingStreamsRequestParameterConfig::kType; + + explicit AddOutgoingStreamsRequestParameter( + ReconfigRequestSN request_sequence_number, + uint16_t nbr_of_new_streams) + : request_sequence_number_(request_sequence_number), + nbr_of_new_streams_(nbr_of_new_streams) {} + + static absl::optional<AddOutgoingStreamsRequestParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + uint16_t nbr_of_new_streams() const { return nbr_of_new_streams_; } + + private: + ReconfigRequestSN request_sequence_number_; + uint16_t nbr_of_new_streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_ADD_OUTGOING_STREAMS_REQUEST_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter_test.cc new file mode 100644 index 0000000000..d0303b1ba8 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter_test.cc @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(AddOutgoingStreamsRequestParameterTest, SerializeAndDeserialize) { + AddOutgoingStreamsRequestParameter parameter(ReconfigRequestSN(1), 2); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + AddOutgoingStreamsRequestParameter deserialized, + AddOutgoingStreamsRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); + EXPECT_EQ(deserialized.nbr_of_new_streams(), 2u); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.cc new file mode 100644 index 0000000000..7dd8e1923f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" + +#include <stdint.h> + +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc3758#section-3.1 + +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 49152 | Parameter Length = 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ForwardTsnSupportedParameter::kType; + +absl::optional<ForwardTsnSupportedParameter> +ForwardTsnSupportedParameter::Parse(rtc::ArrayView<const uint8_t> data) { + if (!ParseTLV(data).has_value()) { + return absl::nullopt; + } + return ForwardTsnSupportedParameter(); +} + +void ForwardTsnSupportedParameter::SerializeTo( + std::vector<uint8_t>& out) const { + AllocateTLV(out); +} + +std::string ForwardTsnSupportedParameter::ToString() const { + return "Forward TSN Supported"; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h new file mode 100644 index 0000000000..d4cff4ac21 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_FORWARD_TSN_SUPPORTED_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_FORWARD_TSN_SUPPORTED_PARAMETER_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc3758#section-3.1 +struct ForwardTsnSupportedParameterConfig : ParameterConfig { + static constexpr int kType = 49152; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class ForwardTsnSupportedParameter + : public Parameter, + public TLVTrait<ForwardTsnSupportedParameterConfig> { + public: + static constexpr int kType = ForwardTsnSupportedParameterConfig::kType; + + ForwardTsnSupportedParameter() {} + + static absl::optional<ForwardTsnSupportedParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_FORWARD_TSN_SUPPORTED_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/forward_tsn_supported_parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/forward_tsn_supported_parameter_test.cc new file mode 100644 index 0000000000..fb4f983fae --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/forward_tsn_supported_parameter_test.cc @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" + +namespace dcsctp { +namespace { + +TEST(ForwardTsnSupportedParameterTest, SerializeAndDeserialize) { + ForwardTsnSupportedParameter parameter; + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(ForwardTsnSupportedParameter deserialized, + ForwardTsnSupportedParameter::Parse(serialized)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/heartbeat_info_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/heartbeat_info_parameter.cc new file mode 100644 index 0000000000..918976d305 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/heartbeat_info_parameter.cc @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.5 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type = 4 | Chunk Flags | Heartbeat Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// \ \ +// / Heartbeat Information TLV (Variable-Length) / +// \ \ +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int HeartbeatInfoParameter::kType; + +absl::optional<HeartbeatInfoParameter> HeartbeatInfoParameter::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return HeartbeatInfoParameter(reader->variable_data()); +} + +void HeartbeatInfoParameter::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, info_.size()); + writer.CopyToVariableData(info_); +} + +std::string HeartbeatInfoParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Heartbeat Info parameter (info_length=" << info_.size() << ")"; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/heartbeat_info_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/heartbeat_info_parameter.h new file mode 100644 index 0000000000..ec503a94b2 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/heartbeat_info_parameter.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_HEARTBEAT_INFO_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_HEARTBEAT_INFO_PARAMETER_H_ +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.5 +struct HeartbeatInfoParameterConfig : ParameterConfig { + static constexpr int kType = 1; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class HeartbeatInfoParameter : public Parameter, + public TLVTrait<HeartbeatInfoParameterConfig> { + public: + static constexpr int kType = HeartbeatInfoParameterConfig::kType; + + explicit HeartbeatInfoParameter(rtc::ArrayView<const uint8_t> info) + : info_(info.begin(), info.end()) {} + + static absl::optional<HeartbeatInfoParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + rtc::ArrayView<const uint8_t> info() const { return info_; } + + private: + std::vector<uint8_t> info_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_HEARTBEAT_INFO_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.cc new file mode 100644 index 0000000000..6191adfe9d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.cc @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" + +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.2 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 14 | Parameter Length = 8 + 2 * N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number 1 (optional) | Stream Number 2 (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / ...... / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number N-1 (optional) | Stream Number N (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int IncomingSSNResetRequestParameter::kType; + +absl::optional<IncomingSSNResetRequestParameter> +IncomingSSNResetRequestParameter::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + + size_t stream_count = reader->variable_data_size() / kStreamIdSize; + std::vector<StreamID> stream_ids; + stream_ids.reserve(stream_count); + for (size_t i = 0; i < stream_count; ++i) { + BoundedByteReader<kStreamIdSize> sub_reader = + reader->sub_reader<kStreamIdSize>(i * kStreamIdSize); + + stream_ids.push_back(StreamID(sub_reader.Load16<0>())); + } + + return IncomingSSNResetRequestParameter(request_sequence_number, + std::move(stream_ids)); +} + +void IncomingSSNResetRequestParameter::SerializeTo( + std::vector<uint8_t>& out) const { + size_t variable_size = stream_ids_.size() * kStreamIdSize; + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*request_sequence_number_); + + for (size_t i = 0; i < stream_ids_.size(); ++i) { + BoundedByteWriter<kStreamIdSize> sub_writer = + writer.sub_writer<kStreamIdSize>(i * kStreamIdSize); + sub_writer.Store16<0>(*stream_ids_[i]); + } +} + +std::string IncomingSSNResetRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Incoming SSN Reset Request, req_seq_nbr=" + << *request_sequence_number(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h new file mode 100644 index 0000000000..18963efafc --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_INCOMING_SSN_RESET_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_INCOMING_SSN_RESET_REQUEST_PARAMETER_H_ +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.2 +struct IncomingSSNResetRequestParameterConfig : ParameterConfig { + static constexpr int kType = 14; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 2; +}; + +class IncomingSSNResetRequestParameter + : public Parameter, + public TLVTrait<IncomingSSNResetRequestParameterConfig> { + public: + static constexpr int kType = IncomingSSNResetRequestParameterConfig::kType; + + explicit IncomingSSNResetRequestParameter( + ReconfigRequestSN request_sequence_number, + std::vector<StreamID> stream_ids) + : request_sequence_number_(request_sequence_number), + stream_ids_(std::move(stream_ids)) {} + + static absl::optional<IncomingSSNResetRequestParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + rtc::ArrayView<const StreamID> stream_ids() const { return stream_ids_; } + + private: + static constexpr size_t kStreamIdSize = sizeof(uint16_t); + + ReconfigRequestSN request_sequence_number_; + std::vector<StreamID> stream_ids_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_INCOMING_SSN_RESET_REQUEST_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter_test.cc new file mode 100644 index 0000000000..17793f6638 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter_test.cc @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(IncomingSSNResetRequestParameterTest, SerializeAndDeserialize) { + IncomingSSNResetRequestParameter parameter( + ReconfigRequestSN(1), {StreamID(2), StreamID(3), StreamID(4)}); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + IncomingSSNResetRequestParameter deserialized, + IncomingSSNResetRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); + EXPECT_THAT(deserialized.stream_ids(), + ElementsAre(StreamID(2), StreamID(3), StreamID(4))); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.cc new file mode 100644 index 0000000000..c25a2426be --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.cc @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" + +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/types.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.1 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 13 | Parameter Length = 16 + 2 * N | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Response Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Sender's Last Assigned TSN | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number 1 (optional) | Stream Number 2 (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// / ...... / +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Stream Number N-1 (optional) | Stream Number N (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int OutgoingSSNResetRequestParameter::kType; + +absl::optional<OutgoingSSNResetRequestParameter> +OutgoingSSNResetRequestParameter::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + ReconfigRequestSN response_sequence_number(reader->Load32<8>()); + TSN sender_last_assigned_tsn(reader->Load32<12>()); + + size_t stream_count = reader->variable_data_size() / kStreamIdSize; + std::vector<StreamID> stream_ids; + stream_ids.reserve(stream_count); + for (size_t i = 0; i < stream_count; ++i) { + BoundedByteReader<kStreamIdSize> sub_reader = + reader->sub_reader<kStreamIdSize>(i * kStreamIdSize); + + stream_ids.push_back(StreamID(sub_reader.Load16<0>())); + } + + return OutgoingSSNResetRequestParameter( + request_sequence_number, response_sequence_number, + sender_last_assigned_tsn, std::move(stream_ids)); +} + +void OutgoingSSNResetRequestParameter::SerializeTo( + std::vector<uint8_t>& out) const { + size_t variable_size = stream_ids_.size() * kStreamIdSize; + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*request_sequence_number_); + writer.Store32<8>(*response_sequence_number_); + writer.Store32<12>(*sender_last_assigned_tsn_); + + for (size_t i = 0; i < stream_ids_.size(); ++i) { + BoundedByteWriter<kStreamIdSize> sub_writer = + writer.sub_writer<kStreamIdSize>(i * kStreamIdSize); + sub_writer.Store16<0>(*stream_ids_[i]); + } +} + +std::string OutgoingSSNResetRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Outgoing SSN Reset Request, req_seq_nbr=" << *request_sequence_number() + << ", resp_seq_nbr=" << *response_sequence_number() + << ", sender_last_asg_tsn=" << *sender_last_assigned_tsn(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h new file mode 100644 index 0000000000..6eb44e079f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_OUTGOING_SSN_RESET_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_OUTGOING_SSN_RESET_REQUEST_PARAMETER_H_ +#include <stddef.h> +#include <stdint.h> + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.1 +struct OutgoingSSNResetRequestParameterConfig : ParameterConfig { + static constexpr int kType = 13; + static constexpr size_t kHeaderSize = 16; + static constexpr size_t kVariableLengthAlignment = 2; +}; + +class OutgoingSSNResetRequestParameter + : public Parameter, + public TLVTrait<OutgoingSSNResetRequestParameterConfig> { + public: + static constexpr int kType = OutgoingSSNResetRequestParameterConfig::kType; + + explicit OutgoingSSNResetRequestParameter( + ReconfigRequestSN request_sequence_number, + ReconfigRequestSN response_sequence_number, + TSN sender_last_assigned_tsn, + std::vector<StreamID> stream_ids) + : request_sequence_number_(request_sequence_number), + response_sequence_number_(response_sequence_number), + sender_last_assigned_tsn_(sender_last_assigned_tsn), + stream_ids_(std::move(stream_ids)) {} + + static absl::optional<OutgoingSSNResetRequestParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + ReconfigRequestSN response_sequence_number() const { + return response_sequence_number_; + } + TSN sender_last_assigned_tsn() const { return sender_last_assigned_tsn_; } + rtc::ArrayView<const StreamID> stream_ids() const { return stream_ids_; } + + private: + static constexpr size_t kStreamIdSize = sizeof(uint16_t); + + ReconfigRequestSN request_sequence_number_; + ReconfigRequestSN response_sequence_number_; + TSN sender_last_assigned_tsn_; + std::vector<StreamID> stream_ids_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_OUTGOING_SSN_RESET_REQUEST_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter_test.cc new file mode 100644 index 0000000000..dae73c2fba --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter_test.cc @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(OutgoingSSNResetRequestParameterTest, SerializeAndDeserialize) { + OutgoingSSNResetRequestParameter parameter( + ReconfigRequestSN(1), ReconfigRequestSN(2), TSN(3), + {StreamID(4), StreamID(5), StreamID(6)}); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter deserialized, + OutgoingSSNResetRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); + EXPECT_EQ(*deserialized.response_sequence_number(), 2u); + EXPECT_EQ(*deserialized.sender_last_assigned_tsn(), 3u); + EXPECT_THAT(deserialized.stream_ids(), + ElementsAre(StreamID(4), StreamID(5), StreamID(6))); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/parameter.cc new file mode 100644 index 0000000000..b3b2bffef7 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/parameter.cc @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/parameter.h" + +#include <stddef.h> + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +constexpr size_t kParameterHeaderSize = 4; + +Parameters::Builder& Parameters::Builder::Add(const Parameter& p) { + // https://tools.ietf.org/html/rfc4960#section-3.2.1 + // "If the length of the parameter is not a multiple of 4 bytes, the sender + // pads the parameter at the end (i.e., after the Parameter Value field) with + // all zero bytes." + if (data_.size() % 4 != 0) { + data_.resize(RoundUpTo4(data_.size())); + } + + p.SerializeTo(data_); + return *this; +} + +std::vector<ParameterDescriptor> Parameters::descriptors() const { + rtc::ArrayView<const uint8_t> span(data_); + std::vector<ParameterDescriptor> result; + while (!span.empty()) { + BoundedByteReader<kParameterHeaderSize> header(span); + uint16_t type = header.Load16<0>(); + uint16_t length = header.Load16<2>(); + result.emplace_back(type, span.subview(0, length)); + size_t length_with_padding = RoundUpTo4(length); + if (length_with_padding > span.size()) { + break; + } + span = span.subview(length_with_padding); + } + return result; +} + +absl::optional<Parameters> Parameters::Parse( + rtc::ArrayView<const uint8_t> data) { + // Validate the parameter descriptors + rtc::ArrayView<const uint8_t> span(data); + while (!span.empty()) { + if (span.size() < kParameterHeaderSize) { + RTC_DLOG(LS_WARNING) << "Insufficient parameter length"; + return absl::nullopt; + } + BoundedByteReader<kParameterHeaderSize> header(span); + uint16_t length = header.Load16<2>(); + if (length < kParameterHeaderSize || length > span.size()) { + RTC_DLOG(LS_WARNING) << "Invalid parameter length field"; + return absl::nullopt; + } + size_t length_with_padding = RoundUpTo4(length); + if (length_with_padding > span.size()) { + break; + } + span = span.subview(length_with_padding); + } + return Parameters(std::vector<uint8_t>(data.begin(), data.end())); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/parameter.h new file mode 100644 index 0000000000..e8fa67c8f7 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/parameter.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_PARAMETER_H_ + +#include <stddef.h> + +#include <algorithm> +#include <cstdint> +#include <iterator> +#include <memory> +#include <string> +#include <type_traits> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +class Parameter { + public: + Parameter() {} + virtual ~Parameter() = default; + + Parameter(const Parameter& other) = default; + Parameter& operator=(const Parameter& other) = default; + + virtual void SerializeTo(std::vector<uint8_t>& out) const = 0; + virtual std::string ToString() const = 0; +}; + +struct ParameterDescriptor { + ParameterDescriptor(uint16_t type, rtc::ArrayView<const uint8_t> data) + : type(type), data(data) {} + uint16_t type; + rtc::ArrayView<const uint8_t> data; +}; + +class Parameters { + public: + class Builder { + public: + Builder() {} + Builder& Add(const Parameter& p); + Parameters Build() { return Parameters(std::move(data_)); } + + private: + std::vector<uint8_t> data_; + }; + + static absl::optional<Parameters> Parse(rtc::ArrayView<const uint8_t> data); + + Parameters() {} + Parameters(Parameters&& other) = default; + Parameters& operator=(Parameters&& other) = default; + + rtc::ArrayView<const uint8_t> data() const { return data_; } + std::vector<ParameterDescriptor> descriptors() const; + + template <typename P> + absl::optional<P> get() const { + static_assert(std::is_base_of<Parameter, P>::value, + "Template parameter not derived from Parameter"); + for (const auto& p : descriptors()) { + if (p.type == P::kType) { + return P::Parse(p.data); + } + } + return absl::nullopt; + } + + private: + explicit Parameters(std::vector<uint8_t> data) : data_(std::move(data)) {} + std::vector<uint8_t> data_; +}; + +struct ParameterConfig { + static constexpr int kTypeSizeInBytes = 2; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/parameter_test.cc new file mode 100644 index 0000000000..467e324592 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/parameter_test.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/parameter.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::SizeIs; + +TEST(ParameterTest, SerializeDeserializeParameter) { + Parameters parameters = + Parameters::Builder() + .Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(123), + ReconfigRequestSN(456), + TSN(789), {StreamID(42)})) + .Build(); + + rtc::ArrayView<const uint8_t> serialized = parameters.data(); + + ASSERT_HAS_VALUE_AND_ASSIGN(Parameters parsed, Parameters::Parse(serialized)); + auto descriptors = parsed.descriptors(); + ASSERT_THAT(descriptors, SizeIs(1)); + EXPECT_THAT(descriptors[0].type, OutgoingSSNResetRequestParameter::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter parsed_param, + OutgoingSSNResetRequestParameter::Parse(descriptors[0].data)); + EXPECT_EQ(*parsed_param.request_sequence_number(), 123u); + EXPECT_EQ(*parsed_param.response_sequence_number(), 456u); + EXPECT_EQ(*parsed_param.sender_last_assigned_tsn(), 789u); + EXPECT_THAT(parsed_param.stream_ids(), ElementsAre(StreamID(42))); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/reconfiguration_response_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/reconfiguration_response_parameter.cc new file mode 100644 index 0000000000..fafb204acc --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/reconfiguration_response_parameter.cc @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" + +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.4 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 16 | Parameter Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Response Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Result | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Sender's Next TSN (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Receiver's Next TSN (optional) | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int ReconfigurationResponseParameter::kType; + +absl::string_view ToString(ReconfigurationResponseParameter::Result result) { + switch (result) { + case ReconfigurationResponseParameter::Result::kSuccessNothingToDo: + return "Success: nothing to do"; + case ReconfigurationResponseParameter::Result::kSuccessPerformed: + return "Success: performed"; + case ReconfigurationResponseParameter::Result::kDenied: + return "Denied"; + case ReconfigurationResponseParameter::Result::kErrorWrongSSN: + return "Error: wrong ssn"; + case ReconfigurationResponseParameter::Result:: + kErrorRequestAlreadyInProgress: + return "Error: request already in progress"; + case ReconfigurationResponseParameter::Result::kErrorBadSequenceNumber: + return "Error: bad sequence number"; + case ReconfigurationResponseParameter::Result::kInProgress: + return "In progress"; + } +} + +absl::optional<ReconfigurationResponseParameter> +ReconfigurationResponseParameter::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + ReconfigRequestSN response_sequence_number(reader->Load32<4>()); + Result result; + uint32_t result_nbr = reader->Load32<8>(); + switch (result_nbr) { + case 0: + result = ReconfigurationResponseParameter::Result::kSuccessNothingToDo; + break; + case 1: + result = ReconfigurationResponseParameter::Result::kSuccessPerformed; + break; + case 2: + result = ReconfigurationResponseParameter::Result::kDenied; + break; + case 3: + result = ReconfigurationResponseParameter::Result::kErrorWrongSSN; + break; + case 4: + result = ReconfigurationResponseParameter::Result:: + kErrorRequestAlreadyInProgress; + break; + case 5: + result = + ReconfigurationResponseParameter::Result::kErrorBadSequenceNumber; + break; + case 6: + result = ReconfigurationResponseParameter::Result::kInProgress; + break; + default: + RTC_DLOG(LS_WARNING) << "Invalid reconfig response result: " + << result_nbr; + return absl::nullopt; + } + + if (reader->variable_data().empty()) { + return ReconfigurationResponseParameter(response_sequence_number, result); + } else if (reader->variable_data_size() != kNextTsnHeaderSize) { + RTC_DLOG(LS_WARNING) << "Invalid parameter size"; + return absl::nullopt; + } + + BoundedByteReader<kNextTsnHeaderSize> sub_reader = + reader->sub_reader<kNextTsnHeaderSize>(0); + + TSN sender_next_tsn(sub_reader.Load32<0>()); + TSN receiver_next_tsn(sub_reader.Load32<4>()); + + return ReconfigurationResponseParameter(response_sequence_number, result, + sender_next_tsn, receiver_next_tsn); +} + +void ReconfigurationResponseParameter::SerializeTo( + std::vector<uint8_t>& out) const { + size_t variable_size = + (sender_next_tsn().has_value() ? kNextTsnHeaderSize : 0); + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, variable_size); + + writer.Store32<4>(*response_sequence_number_); + uint32_t result_nbr = + static_cast<std::underlying_type<Result>::type>(result_); + writer.Store32<8>(result_nbr); + + if (sender_next_tsn().has_value()) { + BoundedByteWriter<kNextTsnHeaderSize> sub_writer = + writer.sub_writer<kNextTsnHeaderSize>(0); + + sub_writer.Store32<0>(sender_next_tsn_.has_value() ? **sender_next_tsn_ + : 0); + sub_writer.Store32<4>(receiver_next_tsn_.has_value() ? **receiver_next_tsn_ + : 0); + } +} + +std::string ReconfigurationResponseParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Re-configuration Response, resp_seq_nbr=" + << *response_sequence_number(); + return sb.Release(); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/reconfiguration_response_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/reconfiguration_response_parameter.h new file mode 100644 index 0000000000..c5a68acb33 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/reconfiguration_response_parameter.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_RECONFIGURATION_RESPONSE_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_RECONFIGURATION_RESPONSE_PARAMETER_H_ +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.4 +struct ReconfigurationResponseParameterConfig : ParameterConfig { + static constexpr int kType = 16; + static constexpr size_t kHeaderSize = 12; + static constexpr size_t kVariableLengthAlignment = 4; +}; + +class ReconfigurationResponseParameter + : public Parameter, + public TLVTrait<ReconfigurationResponseParameterConfig> { + public: + static constexpr int kType = ReconfigurationResponseParameterConfig::kType; + + enum class Result { + kSuccessNothingToDo = 0, + kSuccessPerformed = 1, + kDenied = 2, + kErrorWrongSSN = 3, + kErrorRequestAlreadyInProgress = 4, + kErrorBadSequenceNumber = 5, + kInProgress = 6, + }; + + ReconfigurationResponseParameter(ReconfigRequestSN response_sequence_number, + Result result) + : response_sequence_number_(response_sequence_number), + result_(result), + sender_next_tsn_(absl::nullopt), + receiver_next_tsn_(absl::nullopt) {} + + explicit ReconfigurationResponseParameter( + ReconfigRequestSN response_sequence_number, + Result result, + TSN sender_next_tsn, + TSN receiver_next_tsn) + : response_sequence_number_(response_sequence_number), + result_(result), + sender_next_tsn_(sender_next_tsn), + receiver_next_tsn_(receiver_next_tsn) {} + + static absl::optional<ReconfigurationResponseParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + ReconfigRequestSN response_sequence_number() const { + return response_sequence_number_; + } + Result result() const { return result_; } + absl::optional<TSN> sender_next_tsn() const { return sender_next_tsn_; } + absl::optional<TSN> receiver_next_tsn() const { return receiver_next_tsn_; } + + private: + static constexpr size_t kNextTsnHeaderSize = 8; + ReconfigRequestSN response_sequence_number_; + Result result_; + absl::optional<TSN> sender_next_tsn_; + absl::optional<TSN> receiver_next_tsn_; +}; + +absl::string_view ToString(ReconfigurationResponseParameter::Result result); + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_RECONFIGURATION_RESPONSE_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/reconfiguration_response_parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/reconfiguration_response_parameter_test.cc new file mode 100644 index 0000000000..8125d93cd0 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/reconfiguration_response_parameter_test.cc @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(ReconfigurationResponseParameterTest, SerializeAndDeserializeFirstForm) { + ReconfigurationResponseParameter parameter( + ReconfigRequestSN(1), + ReconfigurationResponseParameter::Result::kSuccessPerformed); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + ReconfigurationResponseParameter deserialized, + ReconfigurationResponseParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.response_sequence_number(), 1u); + EXPECT_EQ(deserialized.result(), + ReconfigurationResponseParameter::Result::kSuccessPerformed); + EXPECT_EQ(deserialized.sender_next_tsn(), absl::nullopt); + EXPECT_EQ(deserialized.receiver_next_tsn(), absl::nullopt); +} + +TEST(ReconfigurationResponseParameterTest, + SerializeAndDeserializeFirstFormSecondForm) { + ReconfigurationResponseParameter parameter( + ReconfigRequestSN(1), + ReconfigurationResponseParameter::Result::kSuccessPerformed, TSN(2), + TSN(3)); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN( + ReconfigurationResponseParameter deserialized, + ReconfigurationResponseParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.response_sequence_number(), 1u); + EXPECT_EQ(deserialized.result(), + ReconfigurationResponseParameter::Result::kSuccessPerformed); + EXPECT_TRUE(deserialized.sender_next_tsn().has_value()); + EXPECT_EQ(**deserialized.sender_next_tsn(), 2u); + EXPECT_TRUE(deserialized.receiver_next_tsn().has_value()); + EXPECT_EQ(**deserialized.receiver_next_tsn(), 3u); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.cc new file mode 100644 index 0000000000..d656e0db8f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.cc @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.3 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 15 | Parameter Length = 8 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Re-configuration Request Sequence Number | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int SSNTSNResetRequestParameter::kType; + +absl::optional<SSNTSNResetRequestParameter> SSNTSNResetRequestParameter::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + ReconfigRequestSN request_sequence_number(reader->Load32<4>()); + + return SSNTSNResetRequestParameter(request_sequence_number); +} + +void SSNTSNResetRequestParameter::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out); + writer.Store32<4>(*request_sequence_number_); +} + +std::string SSNTSNResetRequestParameter::ToString() const { + rtc::StringBuilder sb; + sb << "SSN/TSN Reset Request, req_seq_nbr=" << *request_sequence_number(); + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h new file mode 100644 index 0000000000..e31d7ebe8f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_SSN_TSN_RESET_REQUEST_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_SSN_TSN_RESET_REQUEST_PARAMETER_H_ +#include <stddef.h> +#include <stdint.h> + +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc6525#section-4.3 +struct SSNTSNResetRequestParameterConfig : ParameterConfig { + static constexpr int kType = 15; + static constexpr size_t kHeaderSize = 8; + static constexpr size_t kVariableLengthAlignment = 0; +}; + +class SSNTSNResetRequestParameter + : public Parameter, + public TLVTrait<SSNTSNResetRequestParameterConfig> { + public: + static constexpr int kType = SSNTSNResetRequestParameterConfig::kType; + + explicit SSNTSNResetRequestParameter( + ReconfigRequestSN request_sequence_number) + : request_sequence_number_(request_sequence_number) {} + + static absl::optional<SSNTSNResetRequestParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + ReconfigRequestSN request_sequence_number() const { + return request_sequence_number_; + } + + private: + ReconfigRequestSN request_sequence_number_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_SSN_TSN_RESET_REQUEST_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter_test.cc new file mode 100644 index 0000000000..eeb973cbcb --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter_test.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(SSNTSNResetRequestParameterTest, SerializeAndDeserialize) { + SSNTSNResetRequestParameter parameter(ReconfigRequestSN(1)); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(SSNTSNResetRequestParameter deserialized, + SSNTSNResetRequestParameter::Parse(serialized)); + + EXPECT_EQ(*deserialized.request_sequence_number(), 1u); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/state_cookie_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/state_cookie_parameter.cc new file mode 100644 index 0000000000..9777aa6667 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/state_cookie_parameter.cc @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" + +#include <stdint.h> + +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.3.1 + +constexpr int StateCookieParameter::kType; + +absl::optional<StateCookieParameter> StateCookieParameter::Parse( + rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + return StateCookieParameter(reader->variable_data()); +} + +void StateCookieParameter::SerializeTo(std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, data_.size()); + writer.CopyToVariableData(data_); +} + +std::string StateCookieParameter::ToString() const { + rtc::StringBuilder sb; + sb << "State Cookie parameter (cookie_length=" << data_.size() << ")"; + return sb.Release(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/state_cookie_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/state_cookie_parameter.h new file mode 100644 index 0000000000..f4355495e2 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/state_cookie_parameter.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_STATE_COOKIE_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_STATE_COOKIE_PARAMETER_H_ +#include <stddef.h> +#include <stdint.h> + +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc4960#section-3.3.3.1 +struct StateCookieParameterConfig : ParameterConfig { + static constexpr int kType = 7; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class StateCookieParameter : public Parameter, + public TLVTrait<StateCookieParameterConfig> { + public: + static constexpr int kType = StateCookieParameterConfig::kType; + + explicit StateCookieParameter(rtc::ArrayView<const uint8_t> data) + : data_(data.begin(), data.end()) {} + + static absl::optional<StateCookieParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + rtc::ArrayView<const uint8_t> data() const { return data_; } + + private: + std::vector<uint8_t> data_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_STATE_COOKIE_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/state_cookie_parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/state_cookie_parameter_test.cc new file mode 100644 index 0000000000..bcca38b586 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/state_cookie_parameter_test.cc @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" + +#include <stdint.h> + +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(StateCookieParameterTest, SerializeAndDeserialize) { + uint8_t cookie[] = {1, 2, 3}; + StateCookieParameter parameter(cookie); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(StateCookieParameter deserialized, + StateCookieParameter::Parse(serialized)); + + EXPECT_THAT(deserialized.data(), ElementsAre(1, 2, 3)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/supported_extensions_parameter.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/supported_extensions_parameter.cc new file mode 100644 index 0000000000..6a8fb214de --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/supported_extensions_parameter.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc5061#section-4.2.7 + +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Parameter Type = 0x8008 | Parameter Length | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | CHUNK TYPE 1 | CHUNK TYPE 2 | CHUNK TYPE 3 | CHUNK TYPE 4 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | .... | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | CHUNK TYPE N | PAD | PAD | PAD | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +constexpr int SupportedExtensionsParameter::kType; + +absl::optional<SupportedExtensionsParameter> +SupportedExtensionsParameter::Parse(rtc::ArrayView<const uint8_t> data) { + absl::optional<BoundedByteReader<kHeaderSize>> reader = ParseTLV(data); + if (!reader.has_value()) { + return absl::nullopt; + } + + std::vector<uint8_t> chunk_types(reader->variable_data().begin(), + reader->variable_data().end()); + return SupportedExtensionsParameter(std::move(chunk_types)); +} + +void SupportedExtensionsParameter::SerializeTo( + std::vector<uint8_t>& out) const { + BoundedByteWriter<kHeaderSize> writer = AllocateTLV(out, chunk_types_.size()); + writer.CopyToVariableData(chunk_types_); +} + +std::string SupportedExtensionsParameter::ToString() const { + rtc::StringBuilder sb; + sb << "Supported Extensions (" << StrJoin(chunk_types_, ", ") << ")"; + return sb.Release(); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/supported_extensions_parameter.h b/third_party/libwebrtc/net/dcsctp/packet/parameter/supported_extensions_parameter.h new file mode 100644 index 0000000000..5689fd8035 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/supported_extensions_parameter.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_PARAMETER_SUPPORTED_EXTENSIONS_PARAMETER_H_ +#define NET_DCSCTP_PACKET_PARAMETER_SUPPORTED_EXTENSIONS_PARAMETER_H_ +#include <stddef.h> + +#include <algorithm> +#include <cstdint> +#include <iterator> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" + +namespace dcsctp { + +// https://tools.ietf.org/html/rfc5061#section-4.2.7 +struct SupportedExtensionsParameterConfig : ParameterConfig { + static constexpr int kType = 0x8008; + static constexpr size_t kHeaderSize = 4; + static constexpr size_t kVariableLengthAlignment = 1; +}; + +class SupportedExtensionsParameter + : public Parameter, + public TLVTrait<SupportedExtensionsParameterConfig> { + public: + static constexpr int kType = SupportedExtensionsParameterConfig::kType; + + explicit SupportedExtensionsParameter(std::vector<uint8_t> chunk_types) + : chunk_types_(std::move(chunk_types)) {} + + static absl::optional<SupportedExtensionsParameter> Parse( + rtc::ArrayView<const uint8_t> data); + + void SerializeTo(std::vector<uint8_t>& out) const override; + std::string ToString() const override; + + bool supports(uint8_t chunk_type) const { + return std::find(chunk_types_.begin(), chunk_types_.end(), chunk_type) != + chunk_types_.end(); + } + + rtc::ArrayView<const uint8_t> chunk_types() const { return chunk_types_; } + + private: + std::vector<uint8_t> chunk_types_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_PARAMETER_SUPPORTED_EXTENSIONS_PARAMETER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/parameter/supported_extensions_parameter_test.cc b/third_party/libwebrtc/net/dcsctp/packet/parameter/supported_extensions_parameter_test.cc new file mode 100644 index 0000000000..c870af2e70 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/parameter/supported_extensions_parameter_test.cc @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; + +TEST(SupportedExtensionsParameterTest, SerializeAndDeserialize) { + SupportedExtensionsParameter parameter({1, 2, 3}); + + std::vector<uint8_t> serialized; + parameter.SerializeTo(serialized); + + ASSERT_HAS_VALUE_AND_ASSIGN(SupportedExtensionsParameter deserialized, + SupportedExtensionsParameter::Parse(serialized)); + + EXPECT_THAT(deserialized.chunk_types(), ElementsAre(1, 2, 3)); + EXPECT_TRUE(deserialized.supports(1)); + EXPECT_TRUE(deserialized.supports(2)); + EXPECT_TRUE(deserialized.supports(3)); + EXPECT_FALSE(deserialized.supports(4)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/sctp_packet.cc b/third_party/libwebrtc/net/dcsctp/packet/sctp_packet.cc new file mode 100644 index 0000000000..cc66235122 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/sctp_packet.cc @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/sctp_packet.h" + +#include <stddef.h> + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/crc32c.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_format.h" + +namespace dcsctp { +namespace { +constexpr size_t kMaxUdpPacketSize = 65535; +constexpr size_t kChunkTlvHeaderSize = 4; +constexpr size_t kExpectedDescriptorCount = 4; +} // namespace + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Source Port Number | Destination Port Number | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Verification Tag | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Checksum | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ + +SctpPacket::Builder::Builder(VerificationTag verification_tag, + const DcSctpOptions& options) + : verification_tag_(verification_tag), + source_port_(options.local_port), + dest_port_(options.remote_port), + max_packet_size_(RoundDownTo4(options.mtu)) {} + +SctpPacket::Builder& SctpPacket::Builder::Add(const Chunk& chunk) { + if (out_.empty()) { + out_.reserve(max_packet_size_); + out_.resize(SctpPacket::kHeaderSize); + BoundedByteWriter<kHeaderSize> buffer(out_); + buffer.Store16<0>(source_port_); + buffer.Store16<2>(dest_port_); + buffer.Store32<4>(*verification_tag_); + // Checksum is at offset 8 - written when calling Build(); + } + RTC_DCHECK(IsDivisibleBy4(out_.size())); + + chunk.SerializeTo(out_); + if (out_.size() % 4 != 0) { + out_.resize(RoundUpTo4(out_.size())); + } + + RTC_DCHECK(out_.size() <= max_packet_size_) + << "Exceeded max size, data=" << out_.size() + << ", max_size=" << max_packet_size_; + return *this; +} + +size_t SctpPacket::Builder::bytes_remaining() const { + if (out_.empty()) { + // The packet header (CommonHeader) hasn't been written yet: + return max_packet_size_ - kHeaderSize; + } else if (out_.size() > max_packet_size_) { + RTC_DCHECK_NOTREACHED() << "Exceeded max size, data=" << out_.size() + << ", max_size=" << max_packet_size_; + return 0; + } + return max_packet_size_ - out_.size(); +} + +std::vector<uint8_t> SctpPacket::Builder::Build() { + std::vector<uint8_t> out; + out_.swap(out); + + if (!out.empty()) { + uint32_t crc = GenerateCrc32C(out); + BoundedByteWriter<kHeaderSize>(out).Store32<8>(crc); + } + + RTC_DCHECK(out.size() <= max_packet_size_) + << "Exceeded max size, data=" << out.size() + << ", max_size=" << max_packet_size_; + + return out; +} + +absl::optional<SctpPacket> SctpPacket::Parse( + rtc::ArrayView<const uint8_t> data, + bool disable_checksum_verification) { + if (data.size() < kHeaderSize + kChunkTlvHeaderSize || + data.size() > kMaxUdpPacketSize) { + RTC_DLOG(LS_WARNING) << "Invalid packet size"; + return absl::nullopt; + } + + BoundedByteReader<kHeaderSize> reader(data); + + CommonHeader common_header; + common_header.source_port = reader.Load16<0>(); + common_header.destination_port = reader.Load16<2>(); + common_header.verification_tag = VerificationTag(reader.Load32<4>()); + common_header.checksum = reader.Load32<8>(); + + // Create a copy of the packet, which will be held by this object. + std::vector<uint8_t> data_copy = + std::vector<uint8_t>(data.begin(), data.end()); + + // Verify the checksum. The checksum field must be zero when that's done. + BoundedByteWriter<kHeaderSize>(data_copy).Store32<8>(0); + uint32_t calculated_checksum = GenerateCrc32C(data_copy); + if (!disable_checksum_verification && + calculated_checksum != common_header.checksum) { + RTC_DLOG(LS_WARNING) << rtc::StringFormat( + "Invalid packet checksum, packet_checksum=0x%08x, " + "calculated_checksum=0x%08x", + common_header.checksum, calculated_checksum); + return absl::nullopt; + } + // Restore the checksum in the header. + BoundedByteWriter<kHeaderSize>(data_copy).Store32<8>(common_header.checksum); + + // Validate and parse the chunk headers in the message. + /* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Chunk Type | Chunk Flags | Chunk Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + + std::vector<ChunkDescriptor> descriptors; + descriptors.reserve(kExpectedDescriptorCount); + rtc::ArrayView<const uint8_t> descriptor_data = + rtc::ArrayView<const uint8_t>(data_copy).subview(kHeaderSize); + while (!descriptor_data.empty()) { + if (descriptor_data.size() < kChunkTlvHeaderSize) { + RTC_DLOG(LS_WARNING) << "Too small chunk"; + return absl::nullopt; + } + BoundedByteReader<kChunkTlvHeaderSize> chunk_header(descriptor_data); + uint8_t type = chunk_header.Load8<0>(); + uint8_t flags = chunk_header.Load8<1>(); + uint16_t length = chunk_header.Load16<2>(); + uint16_t padded_length = RoundUpTo4(length); + if (padded_length > descriptor_data.size()) { + RTC_DLOG(LS_WARNING) << "Too large chunk. length=" << length + << ", remaining=" << descriptor_data.size(); + return absl::nullopt; + } else if (padded_length < kChunkTlvHeaderSize) { + RTC_DLOG(LS_WARNING) << "Too small chunk. length=" << length; + return absl::nullopt; + } + descriptors.emplace_back(type, flags, + descriptor_data.subview(0, padded_length)); + descriptor_data = descriptor_data.subview(padded_length); + } + + // Note that iterators (and pointer) are guaranteed to be stable when moving a + // std::vector, and `descriptors` have pointers to within `data_copy`. + return SctpPacket(common_header, std::move(data_copy), + std::move(descriptors)); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/sctp_packet.h b/third_party/libwebrtc/net/dcsctp/packet/sctp_packet.h new file mode 100644 index 0000000000..4c6234e0c9 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/sctp_packet.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_SCTP_PACKET_H_ +#define NET_DCSCTP_PACKET_SCTP_PACKET_H_ + +#include <stddef.h> + +#include <cstdint> +#include <functional> +#include <memory> +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +// The "Common Header", which every SCTP packet starts with, and is described in +// https://tools.ietf.org/html/rfc4960#section-3.1. +struct CommonHeader { + uint16_t source_port; + uint16_t destination_port; + VerificationTag verification_tag; + uint32_t checksum; +}; + +// Represents an immutable (received or to-be-sent) SCTP packet. +class SctpPacket { + public: + static constexpr size_t kHeaderSize = 12; + + struct ChunkDescriptor { + ChunkDescriptor(uint8_t type, + uint8_t flags, + rtc::ArrayView<const uint8_t> data) + : type(type), flags(flags), data(data) {} + uint8_t type; + uint8_t flags; + rtc::ArrayView<const uint8_t> data; + }; + + SctpPacket(SctpPacket&& other) = default; + SctpPacket& operator=(SctpPacket&& other) = default; + SctpPacket(const SctpPacket&) = delete; + SctpPacket& operator=(const SctpPacket&) = delete; + + // Used for building SctpPacket, as those are immutable. + class Builder { + public: + Builder(VerificationTag verification_tag, const DcSctpOptions& options); + + Builder(Builder&& other) = default; + Builder& operator=(Builder&& other) = default; + + // Adds a chunk to the to-be-built SCTP packet. + Builder& Add(const Chunk& chunk); + + // The number of bytes remaining in the packet for chunk storage until the + // packet reaches its maximum size. + size_t bytes_remaining() const; + + // Indicates if any packets have been added to the builder. + bool empty() const { return out_.empty(); } + + // Returns the payload of the build SCTP packet. The Builder will be cleared + // after having called this function, and can be used to build a new packet. + std::vector<uint8_t> Build(); + + private: + VerificationTag verification_tag_; + uint16_t source_port_; + uint16_t dest_port_; + // The maximum packet size is always even divisible by four, as chunks are + // always padded to a size even divisible by four. + size_t max_packet_size_; + std::vector<uint8_t> out_; + }; + + // Parses `data` as an SCTP packet and returns it if it validates. + static absl::optional<SctpPacket> Parse( + rtc::ArrayView<const uint8_t> data, + bool disable_checksum_verification = false); + + // Returns the SCTP common header. + const CommonHeader& common_header() const { return common_header_; } + + // Returns the chunks (types and offsets) within the packet. + rtc::ArrayView<const ChunkDescriptor> descriptors() const { + return descriptors_; + } + + private: + SctpPacket(const CommonHeader& common_header, + std::vector<uint8_t> data, + std::vector<ChunkDescriptor> descriptors) + : common_header_(common_header), + data_(std::move(data)), + descriptors_(std::move(descriptors)) {} + + CommonHeader common_header_; + + // As the `descriptors_` refer to offset within data, and since SctpPacket is + // movable, `data` needs to be pointer stable, which it is according to + // http://www.open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2321 + std::vector<uint8_t> data_; + // The chunks and their offsets within `data_ `. + std::vector<ChunkDescriptor> descriptors_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_SCTP_PACKET_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/sctp_packet_test.cc b/third_party/libwebrtc/net/dcsctp/packet/sctp_packet_test.cc new file mode 100644 index 0000000000..7438315eec --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/sctp_packet_test.cc @@ -0,0 +1,342 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/sctp_packet.h" + +#include <cstdint> +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +constexpr VerificationTag kVerificationTag = VerificationTag(0x12345678); + +TEST(SctpPacketTest, DeserializeSimplePacketFromCapture) { + /* + Stream Control Transmission Protocol, Src Port: 5000 (5000), Dst Port: 5000 + (5000) Source port: 5000 Destination port: 5000 Verification tag: 0x00000000 + [Association index: 1] + Checksum: 0xaa019d33 [unverified] + [Checksum Status: Unverified] + INIT chunk (Outbound streams: 1000, inbound streams: 1000) + Chunk type: INIT (1) + Chunk flags: 0x00 + Chunk length: 90 + Initiate tag: 0x0eddca08 + Advertised receiver window credit (a_rwnd): 131072 + Number of outbound streams: 1000 + Number of inbound streams: 1000 + Initial TSN: 1426601527 + ECN parameter + Parameter type: ECN (0x8000) + Parameter length: 4 + Forward TSN supported parameter + Parameter type: Forward TSN supported (0xc000) + Parameter length: 4 + Supported Extensions parameter (Supported types: FORWARD_TSN, AUTH, + ASCONF, ASCONF_ACK, RE_CONFIG) Parameter type: Supported Extensions + (0x8008) Parameter length: 9 Supported chunk type: FORWARD_TSN (192) Supported + chunk type: AUTH (15) Supported chunk type: ASCONF (193) Supported chunk type: + ASCONF_ACK (128) Supported chunk type: RE_CONFIG (130) Parameter padding: + 000000 Random parameter Parameter type: Random (0x8002) Parameter length: 36 + Random number: c5a86155090e6f420050634cc8d6b908dfd53e17c99cb143… + Requested HMAC Algorithm parameter (Supported HMACs: SHA-1) + Parameter type: Requested HMAC Algorithm (0x8004) + Parameter length: 6 + HMAC identifier: SHA-1 (1) + Parameter padding: 0000 + Authenticated Chunk list parameter (Chunk types to be authenticated: + ASCONF_ACK, ASCONF) Parameter type: Authenticated Chunk list + (0x8003) Parameter length: 6 Chunk type: ASCONF_ACK (128) Chunk type: ASCONF + (193) Chunk padding: 0000 + */ + + uint8_t data[] = { + 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0xaa, 0x01, 0x9d, 0x33, + 0x01, 0x00, 0x00, 0x5a, 0x0e, 0xdd, 0xca, 0x08, 0x00, 0x02, 0x00, 0x00, + 0x03, 0xe8, 0x03, 0xe8, 0x55, 0x08, 0x36, 0x37, 0x80, 0x00, 0x00, 0x04, + 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, + 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xc5, 0xa8, 0x61, 0x55, + 0x09, 0x0e, 0x6f, 0x42, 0x00, 0x50, 0x63, 0x4c, 0xc8, 0xd6, 0xb9, 0x08, + 0xdf, 0xd5, 0x3e, 0x17, 0xc9, 0x9c, 0xb1, 0x43, 0x28, 0x4e, 0xaf, 0x64, + 0x68, 0x2a, 0xc2, 0x97, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, + 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(data)); + EXPECT_EQ(packet.common_header().source_port, 5000); + EXPECT_EQ(packet.common_header().destination_port, 5000); + EXPECT_EQ(packet.common_header().verification_tag, VerificationTag(0)); + EXPECT_EQ(packet.common_header().checksum, 0xaa019d33); + + EXPECT_THAT(packet.descriptors(), SizeIs(1)); + EXPECT_EQ(packet.descriptors()[0].type, InitChunk::kType); + ASSERT_HAS_VALUE_AND_ASSIGN(InitChunk init, + InitChunk::Parse(packet.descriptors()[0].data)); + EXPECT_EQ(init.initial_tsn(), TSN(1426601527)); +} + +TEST(SctpPacketTest, DeserializePacketWithTwoChunks) { + /* + Stream Control Transmission Protocol, Src Port: 1234 (1234), + Dst Port: 4321 (4321) + Source port: 1234 + Destination port: 4321 + Verification tag: 0x697e3a4e + [Association index: 3] + Checksum: 0xc06e8b36 [unverified] + [Checksum Status: Unverified] + COOKIE_ACK chunk + Chunk type: COOKIE_ACK (11) + Chunk flags: 0x00 + Chunk length: 4 + SACK chunk (Cumulative TSN: 2930332242, a_rwnd: 131072, + gaps: 0, duplicate TSNs: 0) + Chunk type: SACK (3) + Chunk flags: 0x00 + Chunk length: 16 + Cumulative TSN ACK: 2930332242 + Advertised receiver window credit (a_rwnd): 131072 + Number of gap acknowledgement blocks: 0 + Number of duplicated TSNs: 0 + */ + + uint8_t data[] = {0x04, 0xd2, 0x10, 0xe1, 0x69, 0x7e, 0x3a, 0x4e, + 0xc0, 0x6e, 0x8b, 0x36, 0x0b, 0x00, 0x00, 0x04, + 0x03, 0x00, 0x00, 0x10, 0xae, 0xa9, 0x52, 0x52, + 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(data)); + EXPECT_EQ(packet.common_header().source_port, 1234); + EXPECT_EQ(packet.common_header().destination_port, 4321); + EXPECT_EQ(packet.common_header().verification_tag, + VerificationTag(0x697e3a4eu)); + EXPECT_EQ(packet.common_header().checksum, 0xc06e8b36u); + + EXPECT_THAT(packet.descriptors(), SizeIs(2)); + EXPECT_EQ(packet.descriptors()[0].type, CookieAckChunk::kType); + EXPECT_EQ(packet.descriptors()[1].type, SackChunk::kType); + ASSERT_HAS_VALUE_AND_ASSIGN( + CookieAckChunk cookie_ack, + CookieAckChunk::Parse(packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN(SackChunk sack, + SackChunk::Parse(packet.descriptors()[1].data)); +} + +TEST(SctpPacketTest, DeserializePacketWithWrongChecksum) { + /* + Stream Control Transmission Protocol, Src Port: 5000 (5000), + Dst Port: 5000 (5000) + Source port: 5000 + Destination port: 5000 + Verification tag: 0x0eddca08 + [Association index: 1] + Checksum: 0x2a81f531 [unverified] + [Checksum Status: Unverified] + SACK chunk (Cumulative TSN: 1426601536, a_rwnd: 131072, + gaps: 0, duplicate TSNs: 0) + Chunk type: SACK (3) + Chunk flags: 0x00 + Chunk length: 16 + Cumulative TSN ACK: 1426601536 + Advertised receiver window credit (a_rwnd): 131072 + Number of gap acknowledgement blocks: 0 + Number of duplicated TSNs: 0 + */ + + uint8_t data[] = {0x13, 0x88, 0x13, 0x88, 0x0e, 0xdd, 0xca, 0x08, 0x2a, 0x81, + 0xf5, 0x31, 0x03, 0x00, 0x00, 0x10, 0x55, 0x08, 0x36, 0x40, + 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + EXPECT_FALSE(SctpPacket::Parse(data).has_value()); +} + +TEST(SctpPacketTest, DeserializePacketDontValidateChecksum) { + /* + Stream Control Transmission Protocol, Src Port: 5000 (5000), + Dst Port: 5000 (5000) + Source port: 5000 + Destination port: 5000 + Verification tag: 0x0eddca08 + [Association index: 1] + Checksum: 0x2a81f531 [unverified] + [Checksum Status: Unverified] + SACK chunk (Cumulative TSN: 1426601536, a_rwnd: 131072, + gaps: 0, duplicate TSNs: 0) + Chunk type: SACK (3) + Chunk flags: 0x00 + Chunk length: 16 + Cumulative TSN ACK: 1426601536 + Advertised receiver window credit (a_rwnd): 131072 + Number of gap acknowledgement blocks: 0 + Number of duplicated TSNs: 0 + */ + + uint8_t data[] = {0x13, 0x88, 0x13, 0x88, 0x0e, 0xdd, 0xca, 0x08, 0x2a, 0x81, + 0xf5, 0x31, 0x03, 0x00, 0x00, 0x10, 0x55, 0x08, 0x36, 0x40, + 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + + ASSERT_HAS_VALUE_AND_ASSIGN( + SctpPacket packet, + SctpPacket::Parse(data, /*disable_checksum_verification=*/true)); + EXPECT_EQ(packet.common_header().source_port, 5000); + EXPECT_EQ(packet.common_header().destination_port, 5000); + EXPECT_EQ(packet.common_header().verification_tag, + VerificationTag(0x0eddca08u)); + EXPECT_EQ(packet.common_header().checksum, 0x2a81f531u); +} + +TEST(SctpPacketTest, SerializeAndDeserializeSingleChunk) { + SctpPacket::Builder b(kVerificationTag, {}); + InitChunk init(/*initiate_tag=*/VerificationTag(123), /*a_rwnd=*/456, + /*nbr_outbound_streams=*/65535, + /*nbr_inbound_streams=*/65534, /*initial_tsn=*/TSN(789), + /*parameters=*/Parameters()); + + b.Add(init); + std::vector<uint8_t> serialized = b.Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(serialized)); + + EXPECT_EQ(packet.common_header().verification_tag, kVerificationTag); + + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + EXPECT_EQ(packet.descriptors()[0].type, InitChunk::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN(InitChunk deserialized, + InitChunk::Parse(packet.descriptors()[0].data)); + EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123)); + EXPECT_EQ(deserialized.a_rwnd(), 456u); + EXPECT_EQ(deserialized.nbr_outbound_streams(), 65535u); + EXPECT_EQ(deserialized.nbr_inbound_streams(), 65534u); + EXPECT_EQ(deserialized.initial_tsn(), TSN(789)); +} + +TEST(SctpPacketTest, SerializeAndDeserializeThreeChunks) { + SctpPacket::Builder b(kVerificationTag, {}); + b.Add(SackChunk(/*cumulative_tsn_ack=*/TSN(999), /*a_rwnd=*/456, + {SackChunk::GapAckBlock(2, 3)}, + /*duplicate_tsns=*/{TSN(1), TSN(2), TSN(3)})); + b.Add(DataChunk(TSN(123), StreamID(456), SSN(789), PPID(9090), + /*payload=*/{1, 2, 3, 4, 5}, + /*options=*/{})); + b.Add(DataChunk(TSN(124), StreamID(654), SSN(987), PPID(909), + /*payload=*/{5, 4, 3, 3, 1}, + /*options=*/{})); + + std::vector<uint8_t> serialized = b.Build(); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(serialized)); + + EXPECT_EQ(packet.common_header().verification_tag, kVerificationTag); + + ASSERT_THAT(packet.descriptors(), SizeIs(3)); + EXPECT_EQ(packet.descriptors()[0].type, SackChunk::kType); + EXPECT_EQ(packet.descriptors()[1].type, DataChunk::kType); + EXPECT_EQ(packet.descriptors()[2].type, DataChunk::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN(SackChunk sack, + SackChunk::Parse(packet.descriptors()[0].data)); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(999)); + EXPECT_EQ(sack.a_rwnd(), 456u); + + ASSERT_HAS_VALUE_AND_ASSIGN(DataChunk data1, + DataChunk::Parse(packet.descriptors()[1].data)); + EXPECT_EQ(data1.tsn(), TSN(123)); + + ASSERT_HAS_VALUE_AND_ASSIGN(DataChunk data2, + DataChunk::Parse(packet.descriptors()[2].data)); + EXPECT_EQ(data2.tsn(), TSN(124)); +} + +TEST(SctpPacketTest, ParseAbortWithEmptyCause) { + SctpPacket::Builder b(kVerificationTag, {}); + b.Add(AbortChunk( + /*filled_in_verification_tag=*/true, + Parameters::Builder().Add(UserInitiatedAbortCause("")).Build())); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(b.Build())); + + EXPECT_EQ(packet.common_header().verification_tag, kVerificationTag); + + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + EXPECT_EQ(packet.descriptors()[0].type, AbortChunk::kType); + + ASSERT_HAS_VALUE_AND_ASSIGN(AbortChunk abort, + AbortChunk::Parse(packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UserInitiatedAbortCause cause, + abort.error_causes().get<UserInitiatedAbortCause>()); + EXPECT_EQ(cause.upper_layer_abort_reason(), ""); +} + +TEST(SctpPacketTest, DetectPacketWithZeroSizeChunk) { + uint8_t data[] = {0xff, 0xff, 0xff, 0xff, 0xff, 0x0a, 0x0a, 0x0a, 0x5c, + 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x00, 0x00, 0x00}; + + EXPECT_FALSE(SctpPacket::Parse(data, true).has_value()); +} + +TEST(SctpPacketTest, ReturnsCorrectSpaceAvailableToStayWithinMTU) { + DcSctpOptions options; + options.mtu = 1191; + + SctpPacket::Builder builder(VerificationTag(123), options); + + // Chunks will be padded to an even 4 bytes, so the maximum packet size should + // be rounded down. + const size_t kMaxPacketSize = RoundDownTo4(options.mtu); + EXPECT_EQ(kMaxPacketSize, 1188u); + + const size_t kSctpHeaderSize = 12; + EXPECT_EQ(builder.bytes_remaining(), kMaxPacketSize - kSctpHeaderSize); + EXPECT_EQ(builder.bytes_remaining(), 1176u); + + // Add a smaller packet first. + DataChunk::Options data_options; + + std::vector<uint8_t> payload1(183); + builder.Add( + DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload1, data_options)); + + size_t chunk1_size = RoundUpTo4(DataChunk::kHeaderSize + payload1.size()); + EXPECT_EQ(builder.bytes_remaining(), + kMaxPacketSize - kSctpHeaderSize - chunk1_size); + EXPECT_EQ(builder.bytes_remaining(), 976u); // Hand-calculated. + + std::vector<uint8_t> payload2(957); + builder.Add( + DataChunk(TSN(1), StreamID(1), SSN(0), PPID(53), payload2, data_options)); + + size_t chunk2_size = RoundUpTo4(DataChunk::kHeaderSize + payload2.size()); + EXPECT_EQ(builder.bytes_remaining(), + kMaxPacketSize - kSctpHeaderSize - chunk1_size - chunk2_size); + EXPECT_EQ(builder.bytes_remaining(), 0u); // Hand-calculated. +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/tlv_trait.cc b/third_party/libwebrtc/net/dcsctp/packet/tlv_trait.cc new file mode 100644 index 0000000000..493b6a4613 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/tlv_trait.cc @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/tlv_trait.h" + +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace tlv_trait_impl { +void ReportInvalidSize(size_t actual_size, size_t expected_size) { + RTC_DLOG(LS_WARNING) << "Invalid size (" << actual_size + << ", expected minimum " << expected_size << " bytes)"; +} + +void ReportInvalidType(int actual_type, int expected_type) { + RTC_DLOG(LS_WARNING) << "Invalid type (" << actual_type << ", expected " + << expected_type << ")"; +} + +void ReportInvalidFixedLengthField(size_t value, size_t expected) { + RTC_DLOG(LS_WARNING) << "Invalid length field (" << value << ", expected " + << expected << " bytes)"; +} + +void ReportInvalidVariableLengthField(size_t value, size_t available) { + RTC_DLOG(LS_WARNING) << "Invalid length field (" << value << ", available " + << available << " bytes)"; +} + +void ReportInvalidPadding(size_t padding_bytes) { + RTC_DLOG(LS_WARNING) << "Invalid padding (" << padding_bytes << " bytes)"; +} + +void ReportInvalidLengthMultiple(size_t length, size_t alignment) { + RTC_DLOG(LS_WARNING) << "Invalid length field (" << length + << ", expected an even multiple of " << alignment + << " bytes)"; +} +} // namespace tlv_trait_impl +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/packet/tlv_trait.h b/third_party/libwebrtc/net/dcsctp/packet/tlv_trait.h new file mode 100644 index 0000000000..a3c728efd7 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/tlv_trait.h @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PACKET_TLV_TRAIT_H_ +#define NET_DCSCTP_PACKET_TLV_TRAIT_H_ + +#include <stdint.h> +#include <string.h> + +#include <algorithm> +#include <cstddef> +#include <cstdint> +#include <string> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" + +namespace dcsctp { +namespace tlv_trait_impl { +// Logging functions, only to be used by TLVTrait, which is a templated class. +void ReportInvalidSize(size_t actual_size, size_t expected_size); +void ReportInvalidType(int actual_type, int expected_type); +void ReportInvalidFixedLengthField(size_t value, size_t expected); +void ReportInvalidVariableLengthField(size_t value, size_t available); +void ReportInvalidPadding(size_t padding_bytes); +void ReportInvalidLengthMultiple(size_t length, size_t alignment); +} // namespace tlv_trait_impl + +// Various entities in SCTP are padded data blocks, with a type and length +// field at fixed offsets, all stored in a 4-byte header. +// +// See e.g. https://tools.ietf.org/html/rfc4960#section-3.2 and +// https://tools.ietf.org/html/rfc4960#section-3.2.1 +// +// These are helper classes for writing and parsing that data, which in SCTP is +// called Type-Length-Value, or TLV. +// +// This templated class is configurable - a struct passed in as template +// parameter with the following expected members: +// * kType - The type field's value +// * kTypeSizeInBytes - The type field's width in bytes. +// Either 1 or 2. +// * kHeaderSize - The fixed size header +// * kVariableLengthAlignment - The size alignment on the variable data. Set +// to zero (0) if no variable data is used. +// +// This class is to be used as a trait +// (https://en.wikipedia.org/wiki/Trait_(computer_programming)) that adds a few +// public and protected members and which a class inherits from when it +// represents a type-length-value object. +template <typename Config> +class TLVTrait { + private: + static constexpr size_t kTlvHeaderSize = 4; + + protected: + static constexpr size_t kHeaderSize = Config::kHeaderSize; + + static_assert(Config::kTypeSizeInBytes == 1 || Config::kTypeSizeInBytes == 2, + "kTypeSizeInBytes must be 1 or 2"); + static_assert(Config::kHeaderSize >= kTlvHeaderSize, + "HeaderSize must be >= 4 bytes"); + static_assert((Config::kHeaderSize % 4 == 0), + "kHeaderSize must be an even multiple of 4 bytes"); + static_assert((Config::kVariableLengthAlignment == 0 || + Config::kVariableLengthAlignment == 1 || + Config::kVariableLengthAlignment == 2 || + Config::kVariableLengthAlignment == 4 || + Config::kVariableLengthAlignment == 8), + "kVariableLengthAlignment must be an allowed value"); + + // Validates the data with regards to size, alignment and type. + // If valid, returns a bounded buffer. + static absl::optional<BoundedByteReader<Config::kHeaderSize>> ParseTLV( + rtc::ArrayView<const uint8_t> data) { + if (data.size() < Config::kHeaderSize) { + tlv_trait_impl::ReportInvalidSize(data.size(), Config::kHeaderSize); + return absl::nullopt; + } + BoundedByteReader<kTlvHeaderSize> tlv_header(data); + + const int type = (Config::kTypeSizeInBytes == 1) + ? tlv_header.template Load8<0>() + : tlv_header.template Load16<0>(); + + if (type != Config::kType) { + tlv_trait_impl::ReportInvalidType(type, Config::kType); + return absl::nullopt; + } + const uint16_t length = tlv_header.template Load16<2>(); + if (Config::kVariableLengthAlignment == 0) { + // Don't expect any variable length data at all. + if (length != Config::kHeaderSize || data.size() != Config::kHeaderSize) { + tlv_trait_impl::ReportInvalidFixedLengthField(length, + Config::kHeaderSize); + return absl::nullopt; + } + } else { + // Expect variable length data - verify its size alignment. + if (length > data.size() || length < Config::kHeaderSize) { + tlv_trait_impl::ReportInvalidVariableLengthField(length, data.size()); + return absl::nullopt; + } + const size_t padding = data.size() - length; + if (padding > 3) { + // https://tools.ietf.org/html/rfc4960#section-3.2 + // "This padding MUST NOT be more than 3 bytes in total" + tlv_trait_impl::ReportInvalidPadding(padding); + return absl::nullopt; + } + if (!ValidateLengthAlignment(length, Config::kVariableLengthAlignment)) { + tlv_trait_impl::ReportInvalidLengthMultiple( + length, Config::kVariableLengthAlignment); + return absl::nullopt; + } + } + return BoundedByteReader<Config::kHeaderSize>(data.subview(0, length)); + } + + // Allocates space for data with a static header size, as defined by + // `Config::kHeaderSize` and a variable footer, as defined by `variable_size` + // (which may be 0) and writes the type and length in the header. + static BoundedByteWriter<Config::kHeaderSize> AllocateTLV( + std::vector<uint8_t>& out, + size_t variable_size = 0) { + const size_t offset = out.size(); + const size_t size = Config::kHeaderSize + variable_size; + out.resize(offset + size); + + BoundedByteWriter<kTlvHeaderSize> tlv_header( + rtc::ArrayView<uint8_t>(out.data() + offset, kTlvHeaderSize)); + if (Config::kTypeSizeInBytes == 1) { + tlv_header.template Store8<0>(static_cast<uint8_t>(Config::kType)); + } else { + tlv_header.template Store16<0>(Config::kType); + } + tlv_header.template Store16<2>(size); + + return BoundedByteWriter<Config::kHeaderSize>( + rtc::ArrayView<uint8_t>(out.data() + offset, size)); + } + + private: + static bool ValidateLengthAlignment(uint16_t length, size_t alignment) { + // This is to avoid MSVC believing there could be a "mod by zero", when it + // certainly can't. + if (alignment == 0) { + return true; + } + return (length % alignment) == 0; + } +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PACKET_TLV_TRAIT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/packet/tlv_trait_test.cc b/third_party/libwebrtc/net/dcsctp/packet/tlv_trait_test.cc new file mode 100644 index 0000000000..a0dd1a1136 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/packet/tlv_trait_test.cc @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/packet/tlv_trait.h" + +#include <vector> + +#include "api/array_view.h" +#include "rtc_base/buffer.h" +#include "rtc_base/checks.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::SizeIs; + +struct OneByteTypeConfig { + static constexpr int kTypeSizeInBytes = 1; + static constexpr int kType = 0x49; + static constexpr size_t kHeaderSize = 12; + static constexpr int kVariableLengthAlignment = 4; +}; + +class OneByteChunk : public TLVTrait<OneByteTypeConfig> { + public: + static constexpr size_t kVariableSize = 4; + + void SerializeTo(std::vector<uint8_t>& out) { + BoundedByteWriter<OneByteTypeConfig::kHeaderSize> writer = + AllocateTLV(out, kVariableSize); + writer.Store32<4>(0x01020304); + writer.Store16<8>(0x0506); + writer.Store16<10>(0x0708); + + uint8_t variable_data[kVariableSize] = {0xDE, 0xAD, 0xBE, 0xEF}; + writer.CopyToVariableData(rtc::ArrayView<const uint8_t>(variable_data)); + } + + static absl::optional<BoundedByteReader<OneByteTypeConfig::kHeaderSize>> + Parse(rtc::ArrayView<const uint8_t> data) { + return ParseTLV(data); + } +}; + +TEST(TlvDataTest, CanWriteOneByteTypeTlvs) { + std::vector<uint8_t> out; + OneByteChunk().SerializeTo(out); + + EXPECT_THAT(out, SizeIs(OneByteTypeConfig::kHeaderSize + + OneByteChunk::kVariableSize)); + EXPECT_THAT(out, ElementsAre(0x49, 0x00, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF)); +} + +TEST(TlvDataTest, CanReadOneByteTypeTlvs) { + uint8_t data[] = {0x49, 0x00, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF}; + + absl::optional<BoundedByteReader<OneByteTypeConfig::kHeaderSize>> reader = + OneByteChunk::Parse(data); + ASSERT_TRUE(reader.has_value()); + EXPECT_EQ(reader->Load32<4>(), 0x01020304U); + EXPECT_EQ(reader->Load16<8>(), 0x0506U); + EXPECT_EQ(reader->Load16<10>(), 0x0708U); + EXPECT_THAT(reader->variable_data(), ElementsAre(0xDE, 0xAD, 0xBE, 0xEF)); +} + +struct TwoByteTypeConfig { + static constexpr int kTypeSizeInBytes = 2; + static constexpr int kType = 31337; + static constexpr size_t kHeaderSize = 8; + static constexpr int kVariableLengthAlignment = 2; +}; + +class TwoByteChunk : public TLVTrait<TwoByteTypeConfig> { + public: + static constexpr size_t kVariableSize = 8; + + void SerializeTo(std::vector<uint8_t>& out) { + BoundedByteWriter<TwoByteTypeConfig::kHeaderSize> writer = + AllocateTLV(out, kVariableSize); + writer.Store32<4>(0x01020304U); + + uint8_t variable_data[] = {0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF}; + writer.CopyToVariableData(rtc::ArrayView<const uint8_t>(variable_data)); + } + + static absl::optional<BoundedByteReader<TwoByteTypeConfig::kHeaderSize>> + Parse(rtc::ArrayView<const uint8_t> data) { + return ParseTLV(data); + } +}; + +TEST(TlvDataTest, CanWriteTwoByteTypeTlvs) { + std::vector<uint8_t> out; + + TwoByteChunk().SerializeTo(out); + + EXPECT_THAT(out, SizeIs(TwoByteTypeConfig::kHeaderSize + + TwoByteChunk::kVariableSize)); + EXPECT_THAT(out, ElementsAre(0x7A, 0x69, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF)); +} + +TEST(TlvDataTest, CanReadTwoByteTypeTlvs) { + uint8_t data[] = {0x7A, 0x69, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF}; + + absl::optional<BoundedByteReader<TwoByteTypeConfig::kHeaderSize>> reader = + TwoByteChunk::Parse(data); + EXPECT_TRUE(reader.has_value()); + EXPECT_EQ(reader->Load32<4>(), 0x01020304U); + EXPECT_THAT(reader->variable_data(), + ElementsAre(0x05, 0x06, 0x07, 0x08, 0xDE, 0xAD, 0xBE, 0xEF)); +} + +TEST(TlvDataTest, CanHandleInvalidLengthSmallerThanFixedSize) { + // Has 'length=6', which is below the kHeaderSize of 8. + uint8_t data[] = {0x7A, 0x69, 0x00, 0x06, 0x01, 0x02, 0x03, 0x04}; + + EXPECT_FALSE(TwoByteChunk::Parse(data).has_value()); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/public/BUILD.gn b/third_party/libwebrtc/net/dcsctp/public/BUILD.gn new file mode 100644 index 0000000000..6cb289bf5b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/BUILD.gn @@ -0,0 +1,103 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("types") { + deps = [ + "../../../api:array_view", + "../../../rtc_base:strong_alias", + ] + sources = [ + "dcsctp_message.h", + "dcsctp_options.h", + "types.h", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] +} + +rtc_source_set("socket") { + deps = [ + ":types", + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:strong_alias", + ] + sources = [ + "dcsctp_handover_state.cc", + "dcsctp_handover_state.h", + "dcsctp_socket.h", + "packet_observer.h", + "timeout.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_source_set("factory") { + deps = [ + ":socket", + ":types", + "../socket:dcsctp_socket", + ] + sources = [ + "dcsctp_socket_factory.cc", + "dcsctp_socket_factory.h", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_source_set("mocks") { + testonly = true + sources = [ + "mock_dcsctp_socket.h", + "mock_dcsctp_socket_factory.h", + ] + deps = [ + ":factory", + ":socket", + "../../../test:test_support", + ] +} + +rtc_source_set("utils") { + deps = [ + ":socket", + ":types", + "../../../api:array_view", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../socket:dcsctp_socket", + ] + sources = [ + "text_pcap_packet_observer.cc", + "text_pcap_packet_observer.h", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_public_unittests") { + testonly = true + + deps = [ + ":mocks", + ":types", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../test:test_support", + ] + sources = [ + "mock_dcsctp_socket_test.cc", + "types_test.cc", + ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/public/dcsctp_handover_state.cc b/third_party/libwebrtc/net/dcsctp/public/dcsctp_handover_state.cc new file mode 100644 index 0000000000..6a1bd06eba --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/dcsctp_handover_state.cc @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/public/dcsctp_handover_state.h" + +#include <string> + +#include "absl/strings/string_view.h" + +namespace dcsctp { +namespace { +constexpr absl::string_view HandoverUnreadinessReasonToString( + HandoverUnreadinessReason reason) { + switch (reason) { + case HandoverUnreadinessReason::kWrongConnectionState: + return "WRONG_CONNECTION_STATE"; + case HandoverUnreadinessReason::kSendQueueNotEmpty: + return "SEND_QUEUE_NOT_EMPTY"; + case HandoverUnreadinessReason::kDataTrackerTsnBlocksPending: + return "DATA_TRACKER_TSN_BLOCKS_PENDING"; + case HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap: + return "REASSEMBLY_QUEUE_DELIVERED_TSN_GAP"; + case HandoverUnreadinessReason::kStreamResetDeferred: + return "STREAM_RESET_DEFERRED"; + case HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks: + return "ORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS"; + case HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks: + return "UNORDERED_STREAM_HAS_UNASSEMBLED_CHUNKS"; + case HandoverUnreadinessReason::kRetransmissionQueueOutstandingData: + return "RETRANSMISSION_QUEUE_OUTSTANDING_DATA"; + case HandoverUnreadinessReason::kRetransmissionQueueFastRecovery: + return "RETRANSMISSION_QUEUE_FAST_RECOVERY"; + case HandoverUnreadinessReason::kRetransmissionQueueNotEmpty: + return "RETRANSMISSION_QUEUE_NOT_EMPTY"; + case HandoverUnreadinessReason::kPendingStreamReset: + return "PENDING_STREAM_RESET"; + case HandoverUnreadinessReason::kPendingStreamResetRequest: + return "PENDING_STREAM_RESET_REQUEST"; + } +} +} // namespace + +std::string HandoverReadinessStatus::ToString() const { + std::string result; + for (uint32_t bit = 1; + bit <= static_cast<uint32_t>(HandoverUnreadinessReason::kMax); + bit *= 2) { + auto flag = static_cast<HandoverUnreadinessReason>(bit); + if (Contains(flag)) { + if (!result.empty()) { + result.append(","); + } + absl::string_view s = HandoverUnreadinessReasonToString(flag); + result.append(s.data(), s.size()); + } + } + if (result.empty()) { + result = "READY"; + } + return result; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/public/dcsctp_handover_state.h b/third_party/libwebrtc/net/dcsctp/public/dcsctp_handover_state.h new file mode 100644 index 0000000000..36fc37ba89 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/dcsctp_handover_state.h @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_DCSCTP_HANDOVER_STATE_H_ +#define NET_DCSCTP_PUBLIC_DCSCTP_HANDOVER_STATE_H_ + +#include <cstdint> +#include <string> +#include <vector> + +#include "rtc_base/strong_alias.h" + +namespace dcsctp { + +// Stores state snapshot of a dcSCTP socket. The snapshot can be used to +// recreate the socket - possibly in another process. This state should be +// treaded as opaque - the calling client should not inspect or alter it except +// for serialization. Serialization is not provided by dcSCTP. If needed it has +// to be implemented in the calling client. +struct DcSctpSocketHandoverState { + enum class SocketState { + kClosed, + kConnected, + }; + SocketState socket_state = SocketState::kClosed; + + uint32_t my_verification_tag = 0; + uint32_t my_initial_tsn = 0; + uint32_t peer_verification_tag = 0; + uint32_t peer_initial_tsn = 0; + uint64_t tie_tag = 0; + + struct Capabilities { + bool partial_reliability = false; + bool message_interleaving = false; + bool reconfig = false; + }; + Capabilities capabilities; + + struct OutgoingStream { + uint32_t id = 0; + uint32_t next_ssn = 0; + uint32_t next_unordered_mid = 0; + uint32_t next_ordered_mid = 0; + uint16_t priority = 0; + }; + struct Transmission { + uint32_t next_tsn = 0; + uint32_t next_reset_req_sn = 0; + uint32_t cwnd = 0; + uint32_t rwnd = 0; + uint32_t ssthresh = 0; + uint32_t partial_bytes_acked = 0; + std::vector<OutgoingStream> streams; + }; + Transmission tx; + + struct OrderedStream { + uint32_t id = 0; + uint32_t next_ssn = 0; + }; + struct UnorderedStream { + uint32_t id = 0; + }; + struct Receive { + bool seen_packet = false; + uint32_t last_cumulative_acked_tsn = 0; + uint32_t last_assembled_tsn = 0; + uint32_t last_completed_deferred_reset_req_sn = 0; + uint32_t last_completed_reset_req_sn = 0; + std::vector<OrderedStream> ordered_streams; + std::vector<UnorderedStream> unordered_streams; + }; + Receive rx; +}; + +// A list of possible reasons for a socket to be not ready for handover. +enum class HandoverUnreadinessReason : uint32_t { + kWrongConnectionState = 1, + kSendQueueNotEmpty = 2, + kPendingStreamResetRequest = 4, + kDataTrackerTsnBlocksPending = 8, + kPendingStreamReset = 16, + kReassemblyQueueDeliveredTSNsGap = 32, + kStreamResetDeferred = 64, + kOrderedStreamHasUnassembledChunks = 128, + kUnorderedStreamHasUnassembledChunks = 256, + kRetransmissionQueueOutstandingData = 512, + kRetransmissionQueueFastRecovery = 1024, + kRetransmissionQueueNotEmpty = 2048, + kMax = kRetransmissionQueueNotEmpty, +}; + +// Return value of `DcSctpSocketInterface::GetHandoverReadiness`. Set of +// `HandoverUnreadinessReason` bits. When no bit is set, the socket is in the +// state in which a snapshot of the state can be made by +// `GetHandoverStateAndClose()`. +class HandoverReadinessStatus + : public webrtc::StrongAlias<class HandoverReadinessStatusTag, uint32_t> { + public: + // Constructs an empty `HandoverReadinessStatus` which represents ready state. + constexpr HandoverReadinessStatus() + : webrtc::StrongAlias<class HandoverReadinessStatusTag, uint32_t>(0) {} + // Constructs status object that contains a single reason for not being + // handover ready. + constexpr explicit HandoverReadinessStatus(HandoverUnreadinessReason reason) + : webrtc::StrongAlias<class HandoverReadinessStatusTag, uint32_t>( + static_cast<uint32_t>(reason)) {} + + // Convenience methods + constexpr bool IsReady() const { return value() == 0; } + constexpr bool Contains(HandoverUnreadinessReason reason) const { + return value() & static_cast<uint32_t>(reason); + } + HandoverReadinessStatus& Add(HandoverUnreadinessReason reason) { + return Add(HandoverReadinessStatus(reason)); + } + HandoverReadinessStatus& Add(HandoverReadinessStatus status) { + value() |= status.value(); + return *this; + } + std::string ToString() const; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_DCSCTP_HANDOVER_STATE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/dcsctp_message.h b/third_party/libwebrtc/net/dcsctp/public/dcsctp_message.h new file mode 100644 index 0000000000..38e6763916 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/dcsctp_message.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_DCSCTP_MESSAGE_H_ +#define NET_DCSCTP_PUBLIC_DCSCTP_MESSAGE_H_ + +#include <cstdint> +#include <utility> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// An SCTP message is a group of bytes sent and received as a whole on a +// specified stream identifier (`stream_id`), and with a payload protocol +// identifier (`ppid`). +class DcSctpMessage { + public: + DcSctpMessage(StreamID stream_id, PPID ppid, std::vector<uint8_t> payload) + : stream_id_(stream_id), ppid_(ppid), payload_(std::move(payload)) {} + + DcSctpMessage(DcSctpMessage&& other) = default; + DcSctpMessage& operator=(DcSctpMessage&& other) = default; + DcSctpMessage(const DcSctpMessage&) = delete; + DcSctpMessage& operator=(const DcSctpMessage&) = delete; + + // The stream identifier to which the message is sent. + StreamID stream_id() const { return stream_id_; } + + // The payload protocol identifier (ppid) associated with the message. + PPID ppid() const { return ppid_; } + + // The payload of the message. + rtc::ArrayView<const uint8_t> payload() const { return payload_; } + + // When destructing the message, extracts the payload. + std::vector<uint8_t> ReleasePayload() && { return std::move(payload_); } + + private: + StreamID stream_id_; + PPID ppid_; + std::vector<uint8_t> payload_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_DCSCTP_MESSAGE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/dcsctp_options.h b/third_party/libwebrtc/net/dcsctp/public/dcsctp_options.h new file mode 100644 index 0000000000..4511bed4a4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/dcsctp_options.h @@ -0,0 +1,201 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_DCSCTP_OPTIONS_H_ +#define NET_DCSCTP_PUBLIC_DCSCTP_OPTIONS_H_ + +#include <stddef.h> +#include <stdint.h> + +#include "absl/types/optional.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { +struct DcSctpOptions { + // The largest safe SCTP packet. Starting from the minimum guaranteed MTU + // value of 1280 for IPv6 (which may not support fragmentation), take off 85 + // bytes for DTLS/TURN/TCP/IP and ciphertext overhead. + // + // Additionally, it's possible that TURN adds an additional 4 bytes of + // overhead after a channel has been established, so an additional 4 bytes is + // subtracted + // + // 1280 IPV6 MTU + // -40 IPV6 header + // -8 UDP + // -24 GCM Cipher + // -13 DTLS record header + // -4 TURN ChannelData + // = 1191 bytes. + static constexpr size_t kMaxSafeMTUSize = 1191; + + // The local port for which the socket is supposed to be bound to. Incoming + // packets will be verified that they are sent to this port number and all + // outgoing packets will have this port number as source port. + int local_port = 5000; + + // The remote port to send packets to. All outgoing packets will have this + // port number as destination port. + int remote_port = 5000; + + // The announced maximum number of incoming streams. Note that this value is + // constant and can't be currently increased in run-time as "Add Incoming + // Streams Request" in RFC6525 isn't supported. + // + // The socket implementation doesn't have any per-stream fixed costs, which is + // why the default value is set to be the maximum value. + uint16_t announced_maximum_incoming_streams = 65535; + + // The announced maximum number of outgoing streams. Note that this value is + // constant and can't be currently increased in run-time as "Add Outgoing + // Streams Request" in RFC6525 isn't supported. + // + // The socket implementation doesn't have any per-stream fixed costs, which is + // why the default value is set to be the maximum value. + uint16_t announced_maximum_outgoing_streams = 65535; + + // Maximum SCTP packet size. The library will limit the size of generated + // packets to be less than or equal to this number. This does not include any + // overhead of DTLS, TURN, UDP or IP headers. + size_t mtu = kMaxSafeMTUSize; + + // The largest allowed message payload to be sent. Messages will be rejected + // if their payload is larger than this value. Note that this doesn't affect + // incoming messages, which may larger than this value (but smaller than + // `max_receiver_window_buffer_size`). + size_t max_message_size = 256 * 1024; + + // The default stream priority, if not overridden by + // `SctpSocket::SetStreamPriority`. The default value is selected to be + // compatible with https://www.w3.org/TR/webrtc-priority/, section 4.2-4.3. + StreamPriority default_stream_priority = StreamPriority(256); + + // Maximum received window buffer size. This should be a bit larger than the + // largest sized message you want to be able to receive. This essentially + // limits the memory usage on the receive side. Note that memory is allocated + // dynamically, and this represents the maximum amount of buffered data. The + // actual memory usage of the library will be smaller in normal operation, and + // will be larger than this due to other allocations and overhead if the + // buffer is fully utilized. + size_t max_receiver_window_buffer_size = 5 * 1024 * 1024; + + // Maximum send buffer size. It will not be possible to queue more data than + // this before sending it. + size_t max_send_buffer_size = 2'000'000; + + // A threshold that, when the amount of data in the send buffer goes below + // this value, will trigger `DcSctpCallbacks::OnTotalBufferedAmountLow`. + size_t total_buffered_amount_low_threshold = 1'800'000; + + // Max allowed RTT value. When the RTT is measured and it's found to be larger + // than this value, it will be discarded and not used for e.g. any RTO + // calculation. The default value is an extreme maximum but can be adapted + // to better match the environment. + DurationMs rtt_max = DurationMs(60'000); + + // Initial RTO value. + DurationMs rto_initial = DurationMs(500); + + // Maximum RTO value. + DurationMs rto_max = DurationMs(60'000); + + // Minimum RTO value. This must be larger than an expected peer delayed ack + // timeout. + DurationMs rto_min = DurationMs(400); + + // T1-init timeout. + DurationMs t1_init_timeout = DurationMs(1000); + + // T1-cookie timeout. + DurationMs t1_cookie_timeout = DurationMs(1000); + + // T2-shutdown timeout. + DurationMs t2_shutdown_timeout = DurationMs(1000); + + // For t1-init, t1-cookie, t2-shutdown, t3-rtx, this value - if set - will be + // the upper bound on how large the exponentially backed off timeout can + // become. The lower the duration, the faster the connection can recover on + // transient network issues. Setting this value may require changing + // `max_retransmissions` and `max_init_retransmits` to ensure that the + // connection is not closed too quickly. + absl::optional<DurationMs> max_timer_backoff_duration = absl::nullopt; + + // Hearbeat interval (on idle connections only). Set to zero to disable. + DurationMs heartbeat_interval = DurationMs(30000); + + // The maximum time when a SACK will be sent from the arrival of an + // unacknowledged packet. Whatever is smallest of RTO/2 and this will be used. + DurationMs delayed_ack_max_timeout = DurationMs(200); + + // The minimum limit for the measured RTT variance + // + // Setting this below the expected delayed ack timeout (+ margin) of the peer + // might result in unnecessary retransmissions, as the maximum time it takes + // to ACK a DATA chunk is typically RTT + ATO (delayed ack timeout), and when + // the SCTP channel is quite idle, and heartbeats dominate the source of RTT + // measurement, the RTO would converge with the smoothed RTT (SRTT). The + // default ATO is 200ms in usrsctp, and a 20ms (10%) margin would include the + // processing time of received packets and the clock granularity when setting + // the delayed ack timer on the peer. + // + // This is described for TCP in + // https://datatracker.ietf.org/doc/html/rfc6298#section-4. + DurationMs min_rtt_variance = DurationMs(220); + + // The initial congestion window size, in number of MTUs. + // See https://tools.ietf.org/html/rfc4960#section-7.2.1 which defaults at ~3 + // and https://research.google/pubs/pub36640/ which argues for at least ten + // segments. + size_t cwnd_mtus_initial = 10; + + // The minimum congestion window size, in number of MTUs, upon detection of + // packet loss by SACK. Note that if the retransmission timer expires, the + // congestion window will be as small as one MTU. See + // https://tools.ietf.org/html/rfc4960#section-7.2.3. + size_t cwnd_mtus_min = 4; + + // When the congestion window is at or above this number of MTUs, the + // congestion control algorithm will avoid filling the congestion window + // fully, if that results in fragmenting large messages into quite small + // packets. When the congestion window is smaller than this option, it will + // aim to fill the congestion window as much as it can, even if it results in + // creating small fragmented packets. + size_t avoid_fragmentation_cwnd_mtus = 6; + + // The number of packets that may be sent at once. This is limited to avoid + // bursts that too quickly fill the send buffer. Typically in a a socket in + // its "slow start" phase (when it sends as much as it can), it will send + // up to three packets for every SACK received, so the default limit is set + // just above that, and then mostly applicable for (but not limited to) fast + // retransmission scenarios. + int max_burst = 4; + + // Maximum Data Retransmit Attempts (per DATA chunk). Set to absl::nullopt for + // no limit. + absl::optional<int> max_retransmissions = 10; + + // Max.Init.Retransmits (https://tools.ietf.org/html/rfc4960#section-15). Set + // to absl::nullopt for no limit. + absl::optional<int> max_init_retransmits = 8; + + // RFC3758 Partial Reliability Extension + bool enable_partial_reliability = true; + + // RFC8260 Stream Schedulers and User Message Interleaving + bool enable_message_interleaving = false; + + // If RTO should be added to heartbeat_interval + bool heartbeat_interval_include_rtt = true; + + // Disables SCTP packet crc32 verification. Useful when running with fuzzers. + bool disable_checksum_verification = false; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_DCSCTP_OPTIONS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket.h b/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket.h new file mode 100644 index 0000000000..8506397581 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket.h @@ -0,0 +1,610 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_H_ +#define NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_H_ + +#include <cstdint> +#include <memory> +#include <utility> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/public/timeout.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// The socket/association state +enum class SocketState { + // The socket is closed. + kClosed, + // The socket has initiated a connection, which is not yet established. Note + // that for incoming connections and for reconnections when the socket is + // already connected, the socket will not transition to this state. + kConnecting, + // The socket is connected, and the connection is established. + kConnected, + // The socket is shutting down, and the connection is not yet closed. + kShuttingDown, +}; + +// Send options for sending messages +struct SendOptions { + // If the message should be sent with unordered message delivery. + IsUnordered unordered = IsUnordered(false); + + // If set, will discard messages that haven't been correctly sent and + // received before the lifetime has expired. This is only available if the + // peer supports Partial Reliability Extension (RFC3758). + absl::optional<DurationMs> lifetime = absl::nullopt; + + // If set, limits the number of retransmissions. This is only available + // if the peer supports Partial Reliability Extension (RFC3758). + absl::optional<size_t> max_retransmissions = absl::nullopt; + + // If set, will generate lifecycle events for this message. See e.g. + // `DcSctpSocketCallbacks::OnLifecycleMessageFullySent`. This value is decided + // by the client and the library will provide it to all lifecycle callbacks. + LifecycleId lifecycle_id = LifecycleId::NotSet(); +}; + +enum class ErrorKind { + // Indicates that no error has occurred. This will never be the case when + // `OnError` or `OnAborted` is called. + kNoError, + // There have been too many retries or timeouts, and the library has given up. + kTooManyRetries, + // A command was received that is only possible to execute when the socket is + // connected, which it is not. + kNotConnected, + // Parsing of the command or its parameters failed. + kParseFailed, + // Commands are received in the wrong sequence, which indicates a + // synchronisation mismatch between the peers. + kWrongSequence, + // The peer has reported an issue using ERROR or ABORT command. + kPeerReported, + // The peer has performed a protocol violation. + kProtocolViolation, + // The receive or send buffers have been exhausted. + kResourceExhaustion, + // The client has performed an invalid operation. + kUnsupportedOperation, +}; + +inline constexpr absl::string_view ToString(ErrorKind error) { + switch (error) { + case ErrorKind::kNoError: + return "NO_ERROR"; + case ErrorKind::kTooManyRetries: + return "TOO_MANY_RETRIES"; + case ErrorKind::kNotConnected: + return "NOT_CONNECTED"; + case ErrorKind::kParseFailed: + return "PARSE_FAILED"; + case ErrorKind::kWrongSequence: + return "WRONG_SEQUENCE"; + case ErrorKind::kPeerReported: + return "PEER_REPORTED"; + case ErrorKind::kProtocolViolation: + return "PROTOCOL_VIOLATION"; + case ErrorKind::kResourceExhaustion: + return "RESOURCE_EXHAUSTION"; + case ErrorKind::kUnsupportedOperation: + return "UNSUPPORTED_OPERATION"; + } +} + +enum class SendStatus { + // The message was enqueued successfully. As sending the message is done + // asynchronously, this is no guarantee that the message has been actually + // sent. + kSuccess, + // The message was rejected as the payload was empty (which is not allowed in + // SCTP). + kErrorMessageEmpty, + // The message was rejected as the payload was larger than what has been set + // as `DcSctpOptions.max_message_size`. + kErrorMessageTooLarge, + // The message could not be enqueued as the socket is out of resources. This + // mainly indicates that the send queue is full. + kErrorResourceExhaustion, + // The message could not be sent as the socket is shutting down. + kErrorShuttingDown, +}; + +inline constexpr absl::string_view ToString(SendStatus error) { + switch (error) { + case SendStatus::kSuccess: + return "SUCCESS"; + case SendStatus::kErrorMessageEmpty: + return "ERROR_MESSAGE_EMPTY"; + case SendStatus::kErrorMessageTooLarge: + return "ERROR_MESSAGE_TOO_LARGE"; + case SendStatus::kErrorResourceExhaustion: + return "ERROR_RESOURCE_EXHAUSTION"; + case SendStatus::kErrorShuttingDown: + return "ERROR_SHUTTING_DOWN"; + } +} + +// Return value of ResetStreams. +enum class ResetStreamsStatus { + // If the connection is not yet established, this will be returned. + kNotConnected, + // Indicates that ResetStreams operation has been successfully initiated. + kPerformed, + // Indicates that ResetStreams has failed as it's not supported by the peer. + kNotSupported, +}; + +inline constexpr absl::string_view ToString(ResetStreamsStatus error) { + switch (error) { + case ResetStreamsStatus::kNotConnected: + return "NOT_CONNECTED"; + case ResetStreamsStatus::kPerformed: + return "PERFORMED"; + case ResetStreamsStatus::kNotSupported: + return "NOT_SUPPORTED"; + } +} + +// Return value of DcSctpSocketCallbacks::SendPacketWithStatus. +enum class SendPacketStatus { + // Indicates that the packet was successfully sent. As sending is unreliable, + // there are no guarantees that the packet was actually delivered. + kSuccess, + // The packet was not sent due to a temporary failure, such as the local send + // buffer becoming exhausted. This return value indicates that the socket will + // recover and sending that packet can be retried at a later time. + kTemporaryFailure, + // The packet was not sent due to other reasons. + kError, +}; + +// Represent known SCTP implementations. +enum class SctpImplementation { + // There is not enough information toto determine any SCTP implementation. + kUnknown, + // This implementation. + kDcsctp, + // https://github.com/sctplab/usrsctp. + kUsrSctp, + // Any other implementation. + kOther, +}; + +inline constexpr absl::string_view ToString(SctpImplementation implementation) { + switch (implementation) { + case SctpImplementation::kUnknown: + return "unknown"; + case SctpImplementation::kDcsctp: + return "dcsctp"; + case SctpImplementation::kUsrSctp: + return "usrsctp"; + case SctpImplementation::kOther: + return "other"; + } +} + +// Tracked metrics, which is the return value of GetMetrics. Optional members +// will be unset when they are not yet known. +struct Metrics { + // Transmission stats and metrics. + + // Number of packets sent. + size_t tx_packets_count = 0; + + // Number of messages requested to be sent. + size_t tx_messages_count = 0; + + // The current congestion window (cwnd) in bytes, corresponding to spinfo_cwnd + // defined in RFC6458. + size_t cwnd_bytes = 0; + + // Smoothed round trip time, corresponding to spinfo_srtt defined in RFC6458. + int srtt_ms = 0; + + // Number of data items in the retransmission queue that haven’t been + // acked/nacked yet and are in-flight. Corresponding to sstat_unackdata + // defined in RFC6458. This may be an approximation when there are messages in + // the send queue that haven't been fragmented/packetized yet. + size_t unack_data_count = 0; + + // Receive stats and metrics. + + // Number of packets received. + size_t rx_packets_count = 0; + + // Number of messages received. + size_t rx_messages_count = 0; + + // The peer’s last announced receiver window size, corresponding to + // sstat_rwnd defined in RFC6458. + uint32_t peer_rwnd_bytes = 0; + + // Returns the detected SCTP implementation of the peer. As this is not + // explicitly signalled during the connection establishment, heuristics is + // used to analyze e.g. the state cookie in the INIT-ACK chunk. + SctpImplementation peer_implementation = SctpImplementation::kUnknown; + + // Indicates if RFC8260 User Message Interleaving has been negotiated by both + // peers. + bool uses_message_interleaving = false; +}; + +// Callbacks that the DcSctpSocket will call synchronously to the owning +// client. It is allowed to call back into the library from callbacks that start +// with "On". It has been explicitly documented when it's not allowed to call +// back into this library from within a callback. +// +// Theses callbacks are only synchronously triggered as a result of the client +// calling a public method in `DcSctpSocketInterface`. +class DcSctpSocketCallbacks { + public: + virtual ~DcSctpSocketCallbacks() = default; + + // Called when the library wants the packet serialized as `data` to be sent. + // + // TODO(bugs.webrtc.org/12943): This method is deprecated, see + // `SendPacketWithStatus`. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual void SendPacket(rtc::ArrayView<const uint8_t> data) {} + + // Called when the library wants the packet serialized as `data` to be sent. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual SendPacketStatus SendPacketWithStatus( + rtc::ArrayView<const uint8_t> data) { + SendPacket(data); + return SendPacketStatus::kSuccess; + } + + // Called when the library wants to create a Timeout. The callback must return + // an object that implements that interface. + // + // Low precision tasks are scheduled more efficiently by using leeway to + // reduce Idle Wake Ups and is the preferred precision whenever possible. High + // precision timeouts do not have this leeway, but is still limited by OS + // timer precision. At the time of writing, kLow's additional leeway may be up + // to 17 ms, but please see webrtc::TaskQueueBase::DelayPrecision for + // up-to-date information. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual std::unique_ptr<Timeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) { + // TODO(hbos): When dependencies have migrated to this new signature, make + // this pure virtual and delete the other version. + return CreateTimeout(); + } + // TODO(hbos): When dependencies have migrated to the other signature, delete + // this version. + virtual std::unique_ptr<Timeout> CreateTimeout() { + return CreateTimeout(webrtc::TaskQueueBase::DelayPrecision::kLow); + } + + // Returns the current time in milliseconds (from any epoch). + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual TimeMs TimeMillis() = 0; + + // Called when the library needs a random number uniformly distributed between + // `low` (inclusive) and `high` (exclusive). The random numbers used by the + // library are not used for cryptographic purposes. There are no requirements + // that the random number generator must be secure. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual uint32_t GetRandomInt(uint32_t low, uint32_t high) = 0; + + // Triggered when the outgoing message buffer is empty, meaning that there are + // no more queued messages, but there can still be packets in-flight or to be + // retransmitted. (in contrast to SCTP_SENDER_DRY_EVENT). + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + ABSL_DEPRECATED("Use OnTotalBufferedAmountLow instead") + virtual void NotifyOutgoingMessageBufferEmpty() {} + + // Called when the library has received an SCTP message in full and delivers + // it to the upper layer. + // + // It is allowed to call into this library from within this callback. + virtual void OnMessageReceived(DcSctpMessage message) = 0; + + // Triggered when an non-fatal error is reported by either this library or + // from the other peer (by sending an ERROR command). These should be logged, + // but no other action need to be taken as the association is still viable. + // + // It is allowed to call into this library from within this callback. + virtual void OnError(ErrorKind error, absl::string_view message) = 0; + + // Triggered when the socket has aborted - either as decided by this socket + // due to e.g. too many retransmission attempts, or by the peer when + // receiving an ABORT command. No other callbacks will be done after this + // callback, unless reconnecting. + // + // It is allowed to call into this library from within this callback. + virtual void OnAborted(ErrorKind error, absl::string_view message) = 0; + + // Called when calling `Connect` succeeds, but also for incoming successful + // connection attempts. + // + // It is allowed to call into this library from within this callback. + virtual void OnConnected() = 0; + + // Called when the socket is closed in a controlled way. No other + // callbacks will be done after this callback, unless reconnecting. + // + // It is allowed to call into this library from within this callback. + virtual void OnClosed() = 0; + + // On connection restarted (by peer). This is just a notification, and the + // association is expected to work fine after this call, but there could have + // been packet loss as a result of restarting the association. + // + // It is allowed to call into this library from within this callback. + virtual void OnConnectionRestarted() = 0; + + // Indicates that a stream reset request has failed. + // + // It is allowed to call into this library from within this callback. + virtual void OnStreamsResetFailed( + rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason) = 0; + + // Indicates that a stream reset request has been performed. + // + // It is allowed to call into this library from within this callback. + virtual void OnStreamsResetPerformed( + rtc::ArrayView<const StreamID> outgoing_streams) = 0; + + // When a peer has reset some of its outgoing streams, this will be called. An + // empty list indicates that all streams have been reset. + // + // It is allowed to call into this library from within this callback. + virtual void OnIncomingStreamsReset( + rtc::ArrayView<const StreamID> incoming_streams) = 0; + + // Will be called when the amount of data buffered to be sent falls to or + // below the threshold set when calling `SetBufferedAmountLowThreshold`. + // + // It is allowed to call into this library from within this callback. + virtual void OnBufferedAmountLow(StreamID stream_id) {} + + // Will be called when the total amount of data buffered (in the entire send + // buffer, for all streams) falls to or below the threshold specified in + // `DcSctpOptions::total_buffered_amount_low_threshold`. + virtual void OnTotalBufferedAmountLow() {} + + // == Lifecycle Events == + // + // If a `lifecycle_id` is provided as `SendOptions`, lifecycle callbacks will + // be triggered as the message is processed by the library. + // + // The possible transitions are shown in the graph below: + // + // DcSctpSocket::Send ────────────────────────┐ + // │ │ + // │ │ + // v v + // OnLifecycleMessageFullySent ───────> OnLifecycleMessageExpired + // │ │ + // │ │ + // v v + // OnLifeCycleMessageDelivered ────────────> OnLifecycleEnd + + // OnLifecycleMessageFullySent will be called when a message has been fully + // sent, meaning that the last fragment has been produced from the send queue + // and sent on the network. Note that this will trigger at most once per + // message even if the message was retransmitted due to packet loss. + // + // This is a lifecycle event. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual void OnLifecycleMessageFullySent(LifecycleId lifecycle_id) {} + + // OnLifecycleMessageExpired will be called when a message has expired. If it + // was expired with data remaining in the send queue that had not been sent + // ever, `maybe_delivered` will be set to false. If `maybe_delivered` is true, + // the message has at least once been sent and may have been correctly + // received by the peer, but it has expired before the receiver managed to + // acknowledge it. This means that if `maybe_delivered` is true, it's unknown + // if the message was lost or was delivered, and if `maybe_delivered` is + // false, it's guaranteed to not be delivered. + // + // It's guaranteed that `OnLifecycleMessageDelivered` is not called if this + // callback has triggered. + // + // This is a lifecycle event. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual void OnLifecycleMessageExpired(LifecycleId lifecycle_id, + bool maybe_delivered) {} + + // OnLifecycleMessageDelivered will be called when a non-expired message has + // been acknowledged by the peer as delivered. + // + // Note that this will trigger only when the peer moves its cumulative TSN ack + // beyond this message, and will not fire for messages acked using + // gap-ack-blocks as those are renegable. This means that this may fire a bit + // later than the message was actually first "acked" by the peer, as - + // according to the protocol - those acks may be unacked later by the client. + // + // It's guaranteed that `OnLifecycleMessageExpired` is not called if this + // callback has triggered. + // + // This is a lifecycle event. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual void OnLifecycleMessageDelivered(LifecycleId lifecycle_id) {} + + // OnLifecycleEnd will be called when a lifecycle event has reached its end. + // It will be called when processing of a message is complete, no matter how + // it completed. It will be called after all other lifecycle events, if any. + // + // Note that it's possible that this callback triggers without any other + // lifecycle callbacks having been called before in case of errors, such as + // attempting to send an empty message or failing to enqueue a message if the + // send queue is full. + // + // NOTE: When the socket is deallocated, there will be no `OnLifecycleEnd` + // callbacks sent for messages that were enqueued. But as long as the socket + // is alive, `OnLifecycleEnd` callbacks are guaranteed to be sent as messages + // are either expired or successfully acknowledged. + // + // This is a lifecycle event. + // + // Note that it's NOT ALLOWED to call into this library from within this + // callback. + virtual void OnLifecycleEnd(LifecycleId lifecycle_id) {} +}; + +// The DcSctpSocket implementation implements the following interface. +// This class is thread-compatible. +class DcSctpSocketInterface { + public: + virtual ~DcSctpSocketInterface() = default; + + // To be called when an incoming SCTP packet is to be processed. + virtual void ReceivePacket(rtc::ArrayView<const uint8_t> data) = 0; + + // To be called when a timeout has expired. The `timeout_id` is provided + // when the timeout was initiated. + virtual void HandleTimeout(TimeoutID timeout_id) = 0; + + // Connects the socket. This is an asynchronous operation, and + // `DcSctpSocketCallbacks::OnConnected` will be called on success. + virtual void Connect() = 0; + + // Puts this socket to the state in which the original socket was when its + // `DcSctpSocketHandoverState` was captured by `GetHandoverStateAndClose`. + // `RestoreFromState` is allowed only on the closed socket. + // `DcSctpSocketCallbacks::OnConnected` will be called if a connected socket + // state is restored. + // `DcSctpSocketCallbacks::OnError` will be called on error. + virtual void RestoreFromState(const DcSctpSocketHandoverState& state) = 0; + + // Gracefully shutdowns the socket and sends all outstanding data. This is an + // asynchronous operation and `DcSctpSocketCallbacks::OnClosed` will be called + // on success. + virtual void Shutdown() = 0; + + // Closes the connection non-gracefully. Will send ABORT if the connection is + // not already closed. No callbacks will be made after Close() has returned. + virtual void Close() = 0; + + // The socket state. + virtual SocketState state() const = 0; + + // The options it was created with. + virtual const DcSctpOptions& options() const = 0; + + // Update the options max_message_size. + virtual void SetMaxMessageSize(size_t max_message_size) = 0; + + // Sets the priority of an outgoing stream. The initial value, when not set, + // is `DcSctpOptions::default_stream_priority`. + virtual void SetStreamPriority(StreamID stream_id, + StreamPriority priority) = 0; + + // Returns the currently set priority for an outgoing stream. The initial + // value, when not set, is `DcSctpOptions::default_stream_priority`. + virtual StreamPriority GetStreamPriority(StreamID stream_id) const = 0; + + // Sends the message `message` using the provided send options. + // Sending a message is an asynchronous operation, and the `OnError` callback + // may be invoked to indicate any errors in sending the message. + // + // The association does not have to be established before calling this method. + // If it's called before there is an established association, the message will + // be queued. + virtual SendStatus Send(DcSctpMessage message, + const SendOptions& send_options) = 0; + + // Resetting streams is an asynchronous operation and the results will + // be notified using `DcSctpSocketCallbacks::OnStreamsResetDone()` on success + // and `DcSctpSocketCallbacks::OnStreamsResetFailed()` on failure. Note that + // only outgoing streams can be reset. + // + // When it's known that the peer has reset its own outgoing streams, + // `DcSctpSocketCallbacks::OnIncomingStreamReset` is called. + // + // Note that resetting a stream will also remove all queued messages on those + // streams, but will ensure that the currently sent message (if any) is fully + // sent before closing the stream. + // + // Resetting streams can only be done on an established association that + // supports stream resetting. Calling this method on e.g. a closed association + // or streams that don't support resetting will not perform any operation. + virtual ResetStreamsStatus ResetStreams( + rtc::ArrayView<const StreamID> outgoing_streams) = 0; + + // Returns the number of bytes of data currently queued to be sent on a given + // stream. + virtual size_t buffered_amount(StreamID stream_id) const = 0; + + // Returns the number of buffered outgoing bytes that is considered "low" for + // a given stream. See `SetBufferedAmountLowThreshold`. + virtual size_t buffered_amount_low_threshold(StreamID stream_id) const = 0; + + // Used to specify the number of bytes of buffered outgoing data that is + // considered "low" for a given stream, which will trigger an + // OnBufferedAmountLow event. The default value is zero (0). + virtual void SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) = 0; + + // Retrieves the latest metrics. If the socket is not fully connected, + // `absl::nullopt` will be returned. + virtual absl::optional<Metrics> GetMetrics() const = 0; + + // Returns empty bitmask if the socket is in the state in which a snapshot of + // the state can be made by `GetHandoverStateAndClose()`. Return value is + // invalidated by a call to any non-const method. + virtual HandoverReadinessStatus GetHandoverReadiness() const = 0; + + // Collects a snapshot of the socket state that can be used to reconstruct + // this socket in another process. On success this socket object is closed + // synchronously and no callbacks will be made after the method has returned. + // The method fails if the socket is not in a state ready for handover. + // nullopt indicates the failure. `DcSctpSocketCallbacks::OnClosed` will be + // called on success. + virtual absl::optional<DcSctpSocketHandoverState> + GetHandoverStateAndClose() = 0; + + // Returns the detected SCTP implementation of the peer. As this is not + // explicitly signalled during the connection establishment, heuristics is + // used to analyze e.g. the state cookie in the INIT-ACK chunk. + // + // If this method is called too early (before + // `DcSctpSocketCallbacks::OnConnected` has triggered), this will likely + // return `SctpImplementation::kUnknown`. + ABSL_DEPRECATED("See Metrics::peer_implementation instead") + virtual SctpImplementation peer_implementation() const { + return SctpImplementation::kUnknown; + } +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket_factory.cc b/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket_factory.cc new file mode 100644 index 0000000000..ebcb5553e3 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket_factory.cc @@ -0,0 +1,34 @@ +/* + * Copyright 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "net/dcsctp/public/dcsctp_socket_factory.h" + +#include <memory> +#include <utility> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/socket/dcsctp_socket.h" + +namespace dcsctp { + +DcSctpSocketFactory::~DcSctpSocketFactory() = default; + +std::unique_ptr<DcSctpSocketInterface> DcSctpSocketFactory::Create( + absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr<PacketObserver> packet_observer, + const DcSctpOptions& options) { + return std::make_unique<DcSctpSocket>(log_prefix, callbacks, + std::move(packet_observer), options); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket_factory.h b/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket_factory.h new file mode 100644 index 0000000000..ca429d3275 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/dcsctp_socket_factory.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_FACTORY_H_ +#define NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_FACTORY_H_ + +#include <memory> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/packet_observer.h" + +namespace dcsctp { +class DcSctpSocketFactory { + public: + virtual ~DcSctpSocketFactory(); + virtual std::unique_ptr<DcSctpSocketInterface> Create( + absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr<PacketObserver> packet_observer, + const DcSctpOptions& options); +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_DCSCTP_SOCKET_FACTORY_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket.h b/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket.h new file mode 100644 index 0000000000..0fd572bd94 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_H_ +#define NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_H_ + +#include "net/dcsctp/public/dcsctp_socket.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockDcSctpSocket : public DcSctpSocketInterface { + public: + MOCK_METHOD(void, + ReceivePacket, + (rtc::ArrayView<const uint8_t> data), + (override)); + + MOCK_METHOD(void, HandleTimeout, (TimeoutID timeout_id), (override)); + + MOCK_METHOD(void, Connect, (), (override)); + + MOCK_METHOD(void, + RestoreFromState, + (const DcSctpSocketHandoverState&), + (override)); + + MOCK_METHOD(void, Shutdown, (), (override)); + + MOCK_METHOD(void, Close, (), (override)); + + MOCK_METHOD(SocketState, state, (), (const, override)); + + MOCK_METHOD(const DcSctpOptions&, options, (), (const, override)); + + MOCK_METHOD(void, SetMaxMessageSize, (size_t max_message_size), (override)); + + MOCK_METHOD(void, + SetStreamPriority, + (StreamID stream_id, StreamPriority priority), + (override)); + + MOCK_METHOD(StreamPriority, + GetStreamPriority, + (StreamID stream_id), + (const, override)); + + MOCK_METHOD(SendStatus, + Send, + (DcSctpMessage message, const SendOptions& send_options), + (override)); + + MOCK_METHOD(ResetStreamsStatus, + ResetStreams, + (rtc::ArrayView<const StreamID> outgoing_streams), + (override)); + + MOCK_METHOD(size_t, buffered_amount, (StreamID stream_id), (const, override)); + + MOCK_METHOD(size_t, + buffered_amount_low_threshold, + (StreamID stream_id), + (const, override)); + + MOCK_METHOD(void, + SetBufferedAmountLowThreshold, + (StreamID stream_id, size_t bytes), + (override)); + + MOCK_METHOD(absl::optional<Metrics>, GetMetrics, (), (const, override)); + + MOCK_METHOD(HandoverReadinessStatus, + GetHandoverReadiness, + (), + (const, override)); + MOCK_METHOD(absl::optional<DcSctpSocketHandoverState>, + GetHandoverStateAndClose, + (), + (override)); +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket_factory.h b/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket_factory.h new file mode 100644 index 0000000000..61f05577f2 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket_factory.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_FACTORY_H_ +#define NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_FACTORY_H_ + +#include <memory> + +#include "net/dcsctp/public/dcsctp_socket_factory.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockDcSctpSocketFactory : public DcSctpSocketFactory { + public: + MOCK_METHOD(std::unique_ptr<DcSctpSocketInterface>, + Create, + (absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr<PacketObserver> packet_observer, + const DcSctpOptions& options), + (override)); +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_MOCK_DCSCTP_SOCKET_FACTORY_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket_test.cc b/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket_test.cc new file mode 100644 index 0000000000..57013e4ce2 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/mock_dcsctp_socket_test.cc @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/public/mock_dcsctp_socket.h" + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { + +// This test exists to ensure that all methods are mocked correctly, and to +// generate compiler errors if they are not. +TEST(MockDcSctpSocketTest, CanInstantiateAndConnect) { + testing::StrictMock<MockDcSctpSocket> socket; + + EXPECT_CALL(socket, Connect); + + socket.Connect(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/public/packet_observer.h b/third_party/libwebrtc/net/dcsctp/public/packet_observer.h new file mode 100644 index 0000000000..fe7567824f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/packet_observer.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_PACKET_OBSERVER_H_ +#define NET_DCSCTP_PUBLIC_PACKET_OBSERVER_H_ + +#include <stdint.h> + +#include "api/array_view.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// A PacketObserver can be attached to a socket and will be called for +// all sent and received packets. +class PacketObserver { + public: + virtual ~PacketObserver() = default; + // Called when a packet is sent, with the current time (in milliseconds) as + // `now`, and the packet payload as `payload`. + virtual void OnSentPacket(TimeMs now, + rtc::ArrayView<const uint8_t> payload) = 0; + + // Called when a packet is received, with the current time (in milliseconds) + // as `now`, and the packet payload as `payload`. + virtual void OnReceivedPacket(TimeMs now, + rtc::ArrayView<const uint8_t> payload) = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_PACKET_OBSERVER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/text_pcap_packet_observer.cc b/third_party/libwebrtc/net/dcsctp/public/text_pcap_packet_observer.cc new file mode 100644 index 0000000000..2b13060190 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/text_pcap_packet_observer.cc @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/public/text_pcap_packet_observer.h" + +#include "api/array_view.h" +#include "net/dcsctp/public/types.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +void TextPcapPacketObserver::OnSentPacket( + dcsctp::TimeMs now, + rtc::ArrayView<const uint8_t> payload) { + PrintPacket("O ", name_, now, payload); +} + +void TextPcapPacketObserver::OnReceivedPacket( + dcsctp::TimeMs now, + rtc::ArrayView<const uint8_t> payload) { + PrintPacket("I ", name_, now, payload); +} + +void TextPcapPacketObserver::PrintPacket( + absl::string_view prefix, + absl::string_view socket_name, + dcsctp::TimeMs now, + rtc::ArrayView<const uint8_t> payload) { + rtc::StringBuilder s; + s << "\n" << prefix; + int64_t remaining = *now % (24 * 60 * 60 * 1000); + int hours = remaining / (60 * 60 * 1000); + remaining = remaining % (60 * 60 * 1000); + int minutes = remaining / (60 * 1000); + remaining = remaining % (60 * 1000); + int seconds = remaining / 1000; + int ms = remaining % 1000; + s.AppendFormat("%02d:%02d:%02d.%03d", hours, minutes, seconds, ms); + s << " 0000"; + for (uint8_t byte : payload) { + s.AppendFormat(" %02x", byte); + } + s << " # SCTP_PACKET " << socket_name; + RTC_LOG(LS_VERBOSE) << s.str(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/public/text_pcap_packet_observer.h b/third_party/libwebrtc/net/dcsctp/public/text_pcap_packet_observer.h new file mode 100644 index 0000000000..0685771ccf --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/text_pcap_packet_observer.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_TEXT_PCAP_PACKET_OBSERVER_H_ +#define NET_DCSCTP_PUBLIC_TEXT_PCAP_PACKET_OBSERVER_H_ + +#include <string> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// Print outs all sent and received packets to the logs, at LS_VERBOSE severity. +class TextPcapPacketObserver : public dcsctp::PacketObserver { + public: + explicit TextPcapPacketObserver(absl::string_view name) : name_(name) {} + + // Implementation of `dcsctp::PacketObserver`. + void OnSentPacket(dcsctp::TimeMs now, + rtc::ArrayView<const uint8_t> payload) override; + + void OnReceivedPacket(dcsctp::TimeMs now, + rtc::ArrayView<const uint8_t> payload) override; + + // Prints a packet to the log. Exposed to allow it to be used in compatibility + // tests suites that don't use PacketObserver. + static void PrintPacket(absl::string_view prefix, + absl::string_view socket_name, + dcsctp::TimeMs now, + rtc::ArrayView<const uint8_t> payload); + + private: + const std::string name_; +}; + +} // namespace dcsctp +#endif // NET_DCSCTP_PUBLIC_TEXT_PCAP_PACKET_OBSERVER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/timeout.h b/third_party/libwebrtc/net/dcsctp/public/timeout.h new file mode 100644 index 0000000000..64ba351093 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/timeout.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_TIMEOUT_H_ +#define NET_DCSCTP_PUBLIC_TIMEOUT_H_ + +#include <cstdint> + +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// A very simple timeout that can be started and stopped. When started, +// it will be given a unique `timeout_id` which should be provided to +// `DcSctpSocket::HandleTimeout` when it expires. +class Timeout { + public: + virtual ~Timeout() = default; + + // Called to start time timeout, with the duration in milliseconds as + // `duration` and with the timeout identifier as `timeout_id`, which - if + // the timeout expires - shall be provided to `DcSctpSocket::HandleTimeout`. + // + // `Start` and `Stop` will always be called in pairs. In other words will + // ´Start` never be called twice, without a call to `Stop` in between. + virtual void Start(DurationMs duration, TimeoutID timeout_id) = 0; + + // Called to stop the running timeout. + // + // `Start` and `Stop` will always be called in pairs. In other words will + // ´Start` never be called twice, without a call to `Stop` in between. + // + // `Stop` will always be called prior to releasing this object. + virtual void Stop() = 0; + + // Called to restart an already running timeout, with the `duration` and + // `timeout_id` parameters as described in `Start`. This can be overridden by + // the implementation to restart it more efficiently. + virtual void Restart(DurationMs duration, TimeoutID timeout_id) { + Stop(); + Start(duration, timeout_id); + } +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_TIMEOUT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/types.h b/third_party/libwebrtc/net/dcsctp/public/types.h new file mode 100644 index 0000000000..d0725620d8 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/types.h @@ -0,0 +1,143 @@ +/* + * Copyright 2019 The Chromium Authors. All rights reserved. + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_PUBLIC_TYPES_H_ +#define NET_DCSCTP_PUBLIC_TYPES_H_ + +#include <cstdint> +#include <limits> + +#include "rtc_base/strong_alias.h" + +namespace dcsctp { + +// Stream Identifier +using StreamID = webrtc::StrongAlias<class StreamIDTag, uint16_t>; + +// Payload Protocol Identifier (PPID) +using PPID = webrtc::StrongAlias<class PPIDTag, uint32_t>; + +// Timeout Identifier +using TimeoutID = webrtc::StrongAlias<class TimeoutTag, uint64_t>; + +// Indicates if a message is allowed to be received out-of-order compared to +// other messages on the same stream. +using IsUnordered = webrtc::StrongAlias<class IsUnorderedTag, bool>; + +// Stream priority, where higher values indicate higher priority. The meaning of +// this value and how it's used depends on the stream scheduler. +using StreamPriority = webrtc::StrongAlias<class StreamPriorityTag, uint16_t>; + +// Duration, as milliseconds. Overflows after 24 days. +class DurationMs : public webrtc::StrongAlias<class DurationMsTag, int32_t> { + public: + constexpr explicit DurationMs(const UnderlyingType& v) + : webrtc::StrongAlias<class DurationMsTag, int32_t>(v) {} + + // Convenience methods for working with time. + constexpr DurationMs& operator+=(DurationMs d) { + value_ += d.value_; + return *this; + } + constexpr DurationMs& operator-=(DurationMs d) { + value_ -= d.value_; + return *this; + } + template <typename T> + constexpr DurationMs& operator*=(T factor) { + value_ *= factor; + return *this; + } +}; + +constexpr inline DurationMs operator+(DurationMs lhs, DurationMs rhs) { + return lhs += rhs; +} +constexpr inline DurationMs operator-(DurationMs lhs, DurationMs rhs) { + return lhs -= rhs; +} +template <typename T> +constexpr inline DurationMs operator*(DurationMs lhs, T rhs) { + return lhs *= rhs; +} +template <typename T> +constexpr inline DurationMs operator*(T lhs, DurationMs rhs) { + return rhs *= lhs; +} +constexpr inline int32_t operator/(DurationMs lhs, DurationMs rhs) { + return lhs.value() / rhs.value(); +} + +// Represents time, in milliseconds since a client-defined epoch. +class TimeMs : public webrtc::StrongAlias<class TimeMsTag, int64_t> { + public: + constexpr explicit TimeMs(const UnderlyingType& v) + : webrtc::StrongAlias<class TimeMsTag, int64_t>(v) {} + + // Convenience methods for working with time. + constexpr TimeMs& operator+=(DurationMs d) { + value_ += *d; + return *this; + } + constexpr TimeMs& operator-=(DurationMs d) { + value_ -= *d; + return *this; + } + + static constexpr TimeMs InfiniteFuture() { + return TimeMs(std::numeric_limits<int64_t>::max()); + } +}; + +constexpr inline TimeMs operator+(TimeMs lhs, DurationMs rhs) { + return lhs += rhs; +} +constexpr inline TimeMs operator+(DurationMs lhs, TimeMs rhs) { + return rhs += lhs; +} +constexpr inline TimeMs operator-(TimeMs lhs, DurationMs rhs) { + return lhs -= rhs; +} +constexpr inline DurationMs operator-(TimeMs lhs, TimeMs rhs) { + return DurationMs(*lhs - *rhs); +} + +// The maximum number of times the socket should attempt to retransmit a +// message which fails the first time in unreliable mode. +class MaxRetransmits + : public webrtc::StrongAlias<class MaxRetransmitsTag, uint16_t> { + public: + constexpr explicit MaxRetransmits(const UnderlyingType& v) + : webrtc::StrongAlias<class MaxRetransmitsTag, uint16_t>(v) {} + + // There should be no limit - the message should be sent reliably. + static constexpr MaxRetransmits NoLimit() { + return MaxRetransmits(std::numeric_limits<uint16_t>::max()); + } +}; + +// An identifier that can be set on sent messages, and picked by the sending +// client. If different from `::NotSet()`, lifecycle events will be generated, +// and eventually `DcSctpSocketCallbacks::OnLifecycleEnd` will be called to +// indicate that the lifecycle isn't tracked any longer. The value zero (0) is +// not a valid lifecycle identifier, and will be interpreted as not having it +// set. +class LifecycleId : public webrtc::StrongAlias<class LifecycleIdTag, uint64_t> { + public: + constexpr explicit LifecycleId(const UnderlyingType& v) + : webrtc::StrongAlias<class LifecycleIdTag, uint64_t>(v) {} + + constexpr bool IsSet() const { return value_ != 0; } + + static constexpr LifecycleId NotSet() { return LifecycleId(0); } +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_PUBLIC_TYPES_H_ diff --git a/third_party/libwebrtc/net/dcsctp/public/types_test.cc b/third_party/libwebrtc/net/dcsctp/public/types_test.cc new file mode 100644 index 0000000000..d3d1240751 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/public/types_test.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/public/types.h" + +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(TypesTest, DurationOperators) { + DurationMs d1(10); + DurationMs d2(25); + EXPECT_EQ(d1 + d2, DurationMs(35)); + EXPECT_EQ(d2 - d1, DurationMs(15)); + + d1 += d2; + EXPECT_EQ(d1, DurationMs(35)); + + d1 -= DurationMs(5); + EXPECT_EQ(d1, DurationMs(30)); + + d1 *= 1.5; + EXPECT_EQ(d1, DurationMs(45)); + + EXPECT_EQ(DurationMs(10) * 2, DurationMs(20)); +} + +TEST(TypesTest, TimeOperators) { + EXPECT_EQ(TimeMs(250) + DurationMs(100), TimeMs(350)); + EXPECT_EQ(DurationMs(250) + TimeMs(100), TimeMs(350)); + EXPECT_EQ(TimeMs(250) - DurationMs(100), TimeMs(150)); + EXPECT_EQ(TimeMs(250) - TimeMs(100), DurationMs(150)); + + TimeMs t1(150); + t1 -= DurationMs(50); + EXPECT_EQ(t1, TimeMs(100)); + t1 += DurationMs(200); + EXPECT_EQ(t1, TimeMs(300)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/BUILD.gn b/third_party/libwebrtc/net/dcsctp/rx/BUILD.gn new file mode 100644 index 0000000000..8ef60dcd5f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/BUILD.gn @@ -0,0 +1,149 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_library("data_tracker") { + deps = [ + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../timer", + ] + sources = [ + "data_tracker.cc", + "data_tracker.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_source_set("reassembly_streams") { + deps = [ + "../../../api:array_view", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../public:types", + ] + sources = [ "reassembly_streams.h" ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_library("interleaved_reassembly_streams") { + deps = [ + ":reassembly_streams", + "../../../api:array_view", + "../../../rtc_base", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:types", + ] + sources = [ + "interleaved_reassembly_streams.cc", + "interleaved_reassembly_streams.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} +rtc_library("traditional_reassembly_streams") { + deps = [ + ":reassembly_streams", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:types", + ] + sources = [ + "traditional_reassembly_streams.cc", + "traditional_reassembly_streams.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("reassembly_queue") { + deps = [ + ":interleaved_reassembly_streams", + ":reassembly_streams", + ":traditional_reassembly_streams", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../common:internal_types", + "../common:sequence_numbers", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../packet:parameter", + "../public:socket", + "../public:types", + ] + sources = [ + "reassembly_queue.cc", + "reassembly_queue.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_rx_unittests") { + testonly = true + + deps = [ + ":data_tracker", + ":interleaved_reassembly_streams", + ":reassembly_queue", + ":reassembly_streams", + ":traditional_reassembly_streams", + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../test:test_support", + "../common:handover_testing", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../public:types", + "../testing:data_generator", + "../timer", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ + "data_tracker_test.cc", + "interleaved_reassembly_streams_test.cc", + "reassembly_queue_test.cc", + "traditional_reassembly_streams_test.cc", + ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/rx/data_tracker.cc b/third_party/libwebrtc/net/dcsctp/rx/data_tracker.cc new file mode 100644 index 0000000000..1f2e43f7f5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/data_tracker.cc @@ -0,0 +1,386 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/data_tracker.h" + +#include <algorithm> +#include <cstdint> +#include <iterator> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/timer/timer.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +constexpr size_t DataTracker::kMaxDuplicateTsnReported; +constexpr size_t DataTracker::kMaxGapAckBlocksReported; + +bool DataTracker::AdditionalTsnBlocks::Add(UnwrappedTSN tsn) { + // Find any block to expand. It will look for any block that includes (also + // when expanded) the provided `tsn`. It will return the block that is greater + // than, or equal to `tsn`. + auto it = absl::c_lower_bound( + blocks_, tsn, [&](const TsnRange& elem, const UnwrappedTSN& t) { + return elem.last.next_value() < t; + }); + + if (it == blocks_.end()) { + // No matching block found. There is no greater than, or equal block - which + // means that this TSN is greater than any block. It can then be inserted at + // the end. + blocks_.emplace_back(tsn, tsn); + return true; + } + + if (tsn >= it->first && tsn <= it->last) { + // It's already in this block. + return false; + } + + if (it->last.next_value() == tsn) { + // This block can be expanded to the right, or merged with the next. + auto next_it = it + 1; + if (next_it != blocks_.end() && tsn.next_value() == next_it->first) { + // Expanding it would make it adjacent to next block - merge those. + it->last = next_it->last; + blocks_.erase(next_it); + return true; + } + + // Expand to the right + it->last = tsn; + return true; + } + + if (it->first == tsn.next_value()) { + // This block can be expanded to the left. Merging to the left would've been + // covered by the above "merge to the right". Both blocks (expand a + // right-most block to the left and expand a left-most block to the right) + // would match, but the left-most would be returned by std::lower_bound. + RTC_DCHECK(it == blocks_.begin() || (it - 1)->last.next_value() != tsn); + + // Expand to the left. + it->first = tsn; + return true; + } + + // Need to create a new block in the middle. + blocks_.emplace(it, tsn, tsn); + return true; +} + +void DataTracker::AdditionalTsnBlocks::EraseTo(UnwrappedTSN tsn) { + // Find the block that is greater than or equals `tsn`. + auto it = absl::c_lower_bound( + blocks_, tsn, [&](const TsnRange& elem, const UnwrappedTSN& t) { + return elem.last < t; + }); + + // The block that is found is greater or equal (or possibly ::end, when no + // block is greater or equal). All blocks before this block can be safely + // removed. the TSN might be within this block, so possibly truncate it. + bool tsn_is_within_block = it != blocks_.end() && tsn >= it->first; + blocks_.erase(blocks_.begin(), it); + + if (tsn_is_within_block) { + blocks_.front().first = tsn.next_value(); + } +} + +void DataTracker::AdditionalTsnBlocks::PopFront() { + RTC_DCHECK(!blocks_.empty()); + blocks_.erase(blocks_.begin()); +} + +bool DataTracker::IsTSNValid(TSN tsn) const { + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.PeekUnwrap(tsn); + + // Note that this method doesn't return `false` for old DATA chunks, as those + // are actually valid, and receiving those may affect the generated SACK + // response (by setting "duplicate TSNs"). + + uint32_t difference = + UnwrappedTSN::Difference(unwrapped_tsn, last_cumulative_acked_tsn_); + if (difference > kMaxAcceptedOutstandingFragments) { + return false; + } + return true; +} + +bool DataTracker::Observe(TSN tsn, + AnyDataChunk::ImmediateAckFlag immediate_ack) { + bool is_duplicate = false; + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(tsn); + + // IsTSNValid must be called prior to calling this method. + RTC_DCHECK( + UnwrappedTSN::Difference(unwrapped_tsn, last_cumulative_acked_tsn_) <= + kMaxAcceptedOutstandingFragments); + + // Old chunk already seen before? + if (unwrapped_tsn <= last_cumulative_acked_tsn_) { + if (duplicate_tsns_.size() < kMaxDuplicateTsnReported) { + duplicate_tsns_.insert(unwrapped_tsn.Wrap()); + } + // https://datatracker.ietf.org/doc/html/rfc4960#section-6.2 + // "When a packet arrives with duplicate DATA chunk(s) and with no new DATA + // chunk(s), the endpoint MUST immediately send a SACK with no delay. If a + // packet arrives with duplicate DATA chunk(s) bundled with new DATA chunks, + // the endpoint MAY immediately send a SACK." + UpdateAckState(AckState::kImmediate, "duplicate data"); + is_duplicate = true; + } else { + if (unwrapped_tsn == last_cumulative_acked_tsn_.next_value()) { + last_cumulative_acked_tsn_ = unwrapped_tsn; + // The cumulative acked tsn may be moved even further, if a gap was + // filled. + if (!additional_tsn_blocks_.empty() && + additional_tsn_blocks_.front().first == + last_cumulative_acked_tsn_.next_value()) { + last_cumulative_acked_tsn_ = additional_tsn_blocks_.front().last; + additional_tsn_blocks_.PopFront(); + } + } else { + bool inserted = additional_tsn_blocks_.Add(unwrapped_tsn); + if (!inserted) { + // Already seen before. + if (duplicate_tsns_.size() < kMaxDuplicateTsnReported) { + duplicate_tsns_.insert(unwrapped_tsn.Wrap()); + } + // https://datatracker.ietf.org/doc/html/rfc4960#section-6.2 + // "When a packet arrives with duplicate DATA chunk(s) and with no new + // DATA chunk(s), the endpoint MUST immediately send a SACK with no + // delay. If a packet arrives with duplicate DATA chunk(s) bundled with + // new DATA chunks, the endpoint MAY immediately send a SACK." + // No need to do this. SACKs are sent immediately on packet loss below. + is_duplicate = true; + } + } + } + + // https://tools.ietf.org/html/rfc4960#section-6.7 + // "Upon the reception of a new DATA chunk, an endpoint shall examine the + // continuity of the TSNs received. If the endpoint detects a gap in + // the received DATA chunk sequence, it SHOULD send a SACK with Gap Ack + // Blocks immediately. The data receiver continues sending a SACK after + // receipt of each SCTP packet that doesn't fill the gap." + if (!additional_tsn_blocks_.empty()) { + UpdateAckState(AckState::kImmediate, "packet loss"); + } + + // https://tools.ietf.org/html/rfc7053#section-5.2 + // "Upon receipt of an SCTP packet containing a DATA chunk with the I + // bit set, the receiver SHOULD NOT delay the sending of the corresponding + // SACK chunk, i.e., the receiver SHOULD immediately respond with the + // corresponding SACK chunk." + if (*immediate_ack) { + UpdateAckState(AckState::kImmediate, "immediate-ack bit set"); + } + + if (!seen_packet_) { + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "After the reception of the first DATA chunk in an association the + // endpoint MUST immediately respond with a SACK to acknowledge the DATA + // chunk." + seen_packet_ = true; + UpdateAckState(AckState::kImmediate, "first DATA chunk"); + } + + // https://tools.ietf.org/html/rfc4960#section-6.2 + // "Specifically, an acknowledgement SHOULD be generated for at least + // every second packet (not every second DATA chunk) received, and SHOULD be + // generated within 200 ms of the arrival of any unacknowledged DATA chunk." + if (ack_state_ == AckState::kIdle) { + UpdateAckState(AckState::kBecomingDelayed, "received DATA when idle"); + } else if (ack_state_ == AckState::kDelayed) { + UpdateAckState(AckState::kImmediate, "received DATA when already delayed"); + } + return !is_duplicate; +} + +void DataTracker::HandleForwardTsn(TSN new_cumulative_ack) { + // ForwardTSN is sent to make the receiver (this socket) "forget" about partly + // received (or not received at all) data, up until `new_cumulative_ack`. + + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(new_cumulative_ack); + UnwrappedTSN prev_last_cum_ack_tsn = last_cumulative_acked_tsn_; + + // Old chunk already seen before? + if (unwrapped_tsn <= last_cumulative_acked_tsn_) { + // https://tools.ietf.org/html/rfc3758#section-3.6 + // "Note, if the "New Cumulative TSN" value carried in the arrived + // FORWARD TSN chunk is found to be behind or at the current cumulative TSN + // point, the data receiver MUST treat this FORWARD TSN as out-of-date and + // MUST NOT update its Cumulative TSN. The receiver SHOULD send a SACK to + // its peer (the sender of the FORWARD TSN) since such a duplicate may + // indicate the previous SACK was lost in the network." + UpdateAckState(AckState::kImmediate, + "FORWARD_TSN new_cumulative_tsn was behind"); + return; + } + + // https://tools.ietf.org/html/rfc3758#section-3.6 + // "When a FORWARD TSN chunk arrives, the data receiver MUST first update + // its cumulative TSN point to the value carried in the FORWARD TSN chunk, and + // then MUST further advance its cumulative TSN point locally if possible, as + // shown by the following example..." + + // The `new_cumulative_ack` will become the current + // `last_cumulative_acked_tsn_`, and if there have been prior "gaps" that are + // now overlapping with the new value, remove them. + last_cumulative_acked_tsn_ = unwrapped_tsn; + additional_tsn_blocks_.EraseTo(unwrapped_tsn); + + // See if the `last_cumulative_acked_tsn_` can be moved even further: + if (!additional_tsn_blocks_.empty() && + additional_tsn_blocks_.front().first == + last_cumulative_acked_tsn_.next_value()) { + last_cumulative_acked_tsn_ = additional_tsn_blocks_.front().last; + additional_tsn_blocks_.PopFront(); + } + + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "FORWARD_TSN, cum_ack_tsn=" + << *prev_last_cum_ack_tsn.Wrap() << "->" + << *new_cumulative_ack << "->" + << *last_cumulative_acked_tsn_.Wrap(); + + // https://tools.ietf.org/html/rfc3758#section-3.6 + // "Any time a FORWARD TSN chunk arrives, for the purposes of sending a + // SACK, the receiver MUST follow the same rules as if a DATA chunk had been + // received (i.e., follow the delayed sack rules specified in ..." + if (ack_state_ == AckState::kIdle) { + UpdateAckState(AckState::kBecomingDelayed, + "received FORWARD_TSN when idle"); + } else if (ack_state_ == AckState::kDelayed) { + UpdateAckState(AckState::kImmediate, + "received FORWARD_TSN when already delayed"); + } +} + +SackChunk DataTracker::CreateSelectiveAck(size_t a_rwnd) { + // Note that in SCTP, the receiver side is allowed to discard received data + // and signal that to the sender, but only chunks that have previously been + // reported in the gap-ack-blocks. However, this implementation will never do + // that. So this SACK produced is more like a NR-SACK as explained in + // https://ieeexplore.ieee.org/document/4697037 and which there is an RFC + // draft at https://tools.ietf.org/html/draft-tuexen-tsvwg-sctp-multipath-17. + std::set<TSN> duplicate_tsns; + duplicate_tsns_.swap(duplicate_tsns); + + return SackChunk(last_cumulative_acked_tsn_.Wrap(), a_rwnd, + CreateGapAckBlocks(), std::move(duplicate_tsns)); +} + +std::vector<SackChunk::GapAckBlock> DataTracker::CreateGapAckBlocks() const { + const auto& blocks = additional_tsn_blocks_.blocks(); + std::vector<SackChunk::GapAckBlock> gap_ack_blocks; + gap_ack_blocks.reserve(std::min(blocks.size(), kMaxGapAckBlocksReported)); + for (size_t i = 0; i < blocks.size() && i < kMaxGapAckBlocksReported; ++i) { + auto start_diff = + UnwrappedTSN::Difference(blocks[i].first, last_cumulative_acked_tsn_); + auto end_diff = + UnwrappedTSN::Difference(blocks[i].last, last_cumulative_acked_tsn_); + gap_ack_blocks.emplace_back(static_cast<uint16_t>(start_diff), + static_cast<uint16_t>(end_diff)); + } + + return gap_ack_blocks; +} + +bool DataTracker::ShouldSendAck(bool also_if_delayed) { + if (ack_state_ == AckState::kImmediate || + (also_if_delayed && (ack_state_ == AckState::kBecomingDelayed || + ack_state_ == AckState::kDelayed))) { + UpdateAckState(AckState::kIdle, "sending SACK"); + return true; + } + + return false; +} + +bool DataTracker::will_increase_cum_ack_tsn(TSN tsn) const { + UnwrappedTSN unwrapped = tsn_unwrapper_.PeekUnwrap(tsn); + return unwrapped == last_cumulative_acked_tsn_.next_value(); +} + +void DataTracker::ForceImmediateSack() { + ack_state_ = AckState::kImmediate; +} + +void DataTracker::HandleDelayedAckTimerExpiry() { + UpdateAckState(AckState::kImmediate, "delayed ack timer expired"); +} + +void DataTracker::ObservePacketEnd() { + if (ack_state_ == AckState::kBecomingDelayed) { + UpdateAckState(AckState::kDelayed, "packet end"); + } +} + +void DataTracker::UpdateAckState(AckState new_state, absl::string_view reason) { + if (new_state != ack_state_) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "State changed from " + << ToString(ack_state_) << " to " + << ToString(new_state) << " due to " << reason; + if (ack_state_ == AckState::kDelayed) { + delayed_ack_timer_.Stop(); + } else if (new_state == AckState::kDelayed) { + delayed_ack_timer_.Start(); + } + ack_state_ = new_state; + } +} + +absl::string_view DataTracker::ToString(AckState ack_state) { + switch (ack_state) { + case AckState::kIdle: + return "IDLE"; + case AckState::kBecomingDelayed: + return "BECOMING_DELAYED"; + case AckState::kDelayed: + return "DELAYED"; + case AckState::kImmediate: + return "IMMEDIATE"; + } +} + +HandoverReadinessStatus DataTracker::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (!additional_tsn_blocks_.empty()) { + status.Add(HandoverUnreadinessReason::kDataTrackerTsnBlocksPending); + } + return status; +} + +void DataTracker::AddHandoverState(DcSctpSocketHandoverState& state) { + state.rx.last_cumulative_acked_tsn = last_cumulative_acked_tsn().value(); + state.rx.seen_packet = seen_packet_; +} + +void DataTracker::RestoreFromState(const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(additional_tsn_blocks_.empty()); + RTC_DCHECK(duplicate_tsns_.empty()); + RTC_DCHECK(!seen_packet_); + + seen_packet_ = state.rx.seen_packet; + last_cumulative_acked_tsn_ = + tsn_unwrapper_.Unwrap(TSN(state.rx.last_cumulative_acked_tsn)); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/data_tracker.h b/third_party/libwebrtc/net/dcsctp/rx/data_tracker.h new file mode 100644 index 0000000000..ea077a9b57 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/data_tracker.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_DATA_TRACKER_H_ +#define NET_DCSCTP_RX_DATA_TRACKER_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <cstdint> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" +#include "net/dcsctp/timer/timer.h" + +namespace dcsctp { + +// Keeps track of received DATA chunks and handles all logic for _when_ to +// create SACKs and also _how_ to generate them. +// +// It only uses TSNs to track delivery and doesn't need to be aware of streams. +// +// SACKs are optimally sent every second packet on connections with no packet +// loss. When packet loss is detected, it's sent for every packet. When SACKs +// are not sent directly, a timer is used to send a SACK delayed (by RTO/2, or +// 200ms, whatever is smallest). +class DataTracker { + public: + // The maximum number of duplicate TSNs that will be reported in a SACK. + static constexpr size_t kMaxDuplicateTsnReported = 20; + // The maximum number of gap-ack-blocks that will be reported in a SACK. + static constexpr size_t kMaxGapAckBlocksReported = 20; + + // The maximum number of accepted in-flight DATA chunks. This indicates the + // maximum difference from this buffer's last cumulative ack TSN, and any + // received data. Data received beyond this limit will be dropped, which will + // force the transmitter to send data that actually increases the last + // cumulative acked TSN. + static constexpr uint32_t kMaxAcceptedOutstandingFragments = 100000; + + DataTracker(absl::string_view log_prefix, + Timer* delayed_ack_timer, + TSN peer_initial_tsn) + : log_prefix_(std::string(log_prefix) + "dtrack: "), + seen_packet_(false), + delayed_ack_timer_(*delayed_ack_timer), + last_cumulative_acked_tsn_( + tsn_unwrapper_.Unwrap(TSN(*peer_initial_tsn - 1))) {} + + // Indicates if the provided TSN is valid. If this return false, the data + // should be dropped and not added to any other buffers, which essentially + // means that there is intentional packet loss. + bool IsTSNValid(TSN tsn) const; + + // Call for every incoming data chunk. Returns `true` if `tsn` was seen for + // the first time, and `false` if it has been seen before (a duplicate `tsn`). + bool Observe(TSN tsn, + AnyDataChunk::ImmediateAckFlag immediate_ack = + AnyDataChunk::ImmediateAckFlag(false)); + // Called at the end of processing an SCTP packet. + void ObservePacketEnd(); + + // Called for incoming FORWARD-TSN/I-FORWARD-TSN chunks + void HandleForwardTsn(TSN new_cumulative_ack); + + // Indicates if a SACK should be sent. There may be other reasons to send a + // SACK, but if this function indicates so, it should be sent as soon as + // possible. Calling this function will make it clear a flag so that if it's + // called again, it will probably return false. + // + // If the delayed ack timer is running, this method will return false _unless_ + // `also_if_delayed` is set to true. Then it will return true as well. + bool ShouldSendAck(bool also_if_delayed = false); + + // Returns the last cumulative ack TSN - the last seen data chunk's TSN + // value before any packet loss was detected. + TSN last_cumulative_acked_tsn() const { + return TSN(last_cumulative_acked_tsn_.Wrap()); + } + + // Returns true if the received `tsn` would increase the cumulative ack TSN. + bool will_increase_cum_ack_tsn(TSN tsn) const; + + // Forces `ShouldSendSack` to return true. + void ForceImmediateSack(); + + // Note that this will clear `duplicates_`, so every SackChunk that is + // consumed must be sent. + SackChunk CreateSelectiveAck(size_t a_rwnd); + + void HandleDelayedAckTimerExpiry(); + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); + + private: + enum class AckState { + // No need to send an ACK. + kIdle, + + // Has received data chunks (but not yet end of packet). + kBecomingDelayed, + + // Has received data chunks and the end of a packet. Delayed ack timer is + // running and a SACK will be sent on expiry, or if DATA is sent, or after + // next packet with data. + kDelayed, + + // Send a SACK immediately after handling this packet. + kImmediate, + }; + + // Represents ranges of TSNs that have been received that are not directly + // following the last cumulative acked TSN. This information is returned to + // the sender in the "gap ack blocks" in the SACK chunk. The blocks are always + // non-overlapping and non-adjacent. + class AdditionalTsnBlocks { + public: + // Represents an inclusive range of received TSNs, i.e. [first, last]. + struct TsnRange { + TsnRange(UnwrappedTSN first, UnwrappedTSN last) + : first(first), last(last) {} + UnwrappedTSN first; + UnwrappedTSN last; + }; + + // Adds a TSN to the set. This will try to expand any existing block and + // might merge blocks to ensure that all blocks are non-adjacent. If a + // current block can't be expanded, a new block is created. + // + // The return value indicates if `tsn` was added. If false is returned, the + // `tsn` was already represented in one of the blocks. + bool Add(UnwrappedTSN tsn); + + // Erases all TSNs up to, and including `tsn`. This will remove all blocks + // that are completely below `tsn` and may truncate a block where `tsn` is + // within that block. In that case, the frontmost block's start TSN will be + // the next following tsn after `tsn`. + void EraseTo(UnwrappedTSN tsn); + + // Removes the first block. Must not be called on an empty set. + void PopFront(); + + const std::vector<TsnRange>& blocks() const { return blocks_; } + + bool empty() const { return blocks_.empty(); } + + const TsnRange& front() const { return blocks_.front(); } + + private: + // A sorted vector of non-overlapping and non-adjacent blocks. + std::vector<TsnRange> blocks_; + }; + + std::vector<SackChunk::GapAckBlock> CreateGapAckBlocks() const; + void UpdateAckState(AckState new_state, absl::string_view reason); + static absl::string_view ToString(AckState ack_state); + + const std::string log_prefix_; + // If a packet has ever been seen. + bool seen_packet_; + Timer& delayed_ack_timer_; + AckState ack_state_ = AckState::kIdle; + UnwrappedTSN::Unwrapper tsn_unwrapper_; + + // All TSNs up until (and including) this value have been seen. + UnwrappedTSN last_cumulative_acked_tsn_; + // Received TSNs that are not directly following `last_cumulative_acked_tsn_`. + AdditionalTsnBlocks additional_tsn_blocks_; + std::set<TSN> duplicate_tsns_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_DATA_TRACKER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/data_tracker_test.cc b/third_party/libwebrtc/net/dcsctp/rx/data_tracker_test.cc new file mode 100644 index 0000000000..f74dd6eb0b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/data_tracker_test.cc @@ -0,0 +1,739 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/data_tracker.h" + +#include <cstdint> +#include <initializer_list> +#include <memory> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "net/dcsctp/timer/timer.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr size_t kArwnd = 10000; +constexpr TSN kInitialTSN(11); + +class DataTrackerTest : public testing::Test { + protected: + DataTrackerTest() + : timeout_manager_([this]() { return now_; }), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return timeout_manager_.CreateTimeout(precision); + }), + timer_(timer_manager_.CreateTimer( + "test/delayed_ack", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))), + tracker_( + std::make_unique<DataTracker>("log: ", timer_.get(), kInitialTSN)) { + } + + void Observer(std::initializer_list<uint32_t> tsns, + bool expect_as_duplicate = false) { + for (const uint32_t tsn : tsns) { + if (expect_as_duplicate) { + EXPECT_FALSE( + tracker_->Observe(TSN(tsn), AnyDataChunk::ImmediateAckFlag(false))); + } else { + EXPECT_TRUE( + tracker_->Observe(TSN(tsn), AnyDataChunk::ImmediateAckFlag(false))); + } + } + } + + void HandoverTracker() { + EXPECT_TRUE(tracker_->GetHandoverReadiness().IsReady()); + DcSctpSocketHandoverState state; + tracker_->AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + tracker_ = + std::make_unique<DataTracker>("log: ", timer_.get(), kInitialTSN); + tracker_->RestoreFromState(state); + } + + TimeMs now_ = TimeMs(0); + FakeTimeoutManager timeout_manager_; + TimerManager timer_manager_; + std::unique_ptr<Timer> timer_; + std::unique_ptr<DataTracker> tracker_; +}; + +TEST_F(DataTrackerTest, Empty) { + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ObserverSingleInOrderPacket) { + Observer({11}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ObserverManyInOrderMovesCumulativeTsnAck) { + Observer({11, 12, 13}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(13)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ObserveOutOfOrderMovesCumulativeTsnAck) { + Observer({12, 13, 14, 11}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, SingleGap) { + Observer({12}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ExampleFromRFC4960Section334) { + Observer({11, 12, 14, 15, 17}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(12)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 5))); + EXPECT_THAT(sack.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, AckAlreadyReceivedChunk) { + Observer({11}); + SackChunk sack1 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack1.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack1.gap_ack_blocks(), IsEmpty()); + + // Receive old chunk + Observer({8}, /*expect_as_duplicate=*/true); + SackChunk sack2 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack2.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack2.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, DoubleSendRetransmittedChunk) { + Observer({11, 13, 14, 15}); + SackChunk sack1 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack1.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack1.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 4))); + + // Fill in the hole. + Observer({12, 16, 17, 18}); + SackChunk sack2 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack2.cumulative_tsn_ack(), TSN(18)); + EXPECT_THAT(sack2.gap_ack_blocks(), IsEmpty()); + + // Receive chunk 12 again. + Observer({12}, /*expect_as_duplicate=*/true); + Observer({19, 20, 21}); + SackChunk sack3 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack3.cumulative_tsn_ack(), TSN(21)); + EXPECT_THAT(sack3.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ForwardTsnSimple) { + // Messages (11, 12, 13), (14, 15) - first message expires. + Observer({11, 12, 15}); + + tracker_->HandleForwardTsn(TSN(13)); + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(13)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); +} + +TEST_F(DataTrackerTest, ForwardTsnSkipsFromGapBlock) { + // Messages (11, 12, 13), (14, 15) - first message expires. + Observer({11, 12, 14}); + + tracker_->HandleForwardTsn(TSN(13)); + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, ExampleFromRFC3758) { + tracker_->HandleForwardTsn(TSN(102)); + + Observer({102}, /*expect_as_duplicate=*/true); + Observer({104, 105, 107}); + + tracker_->HandleForwardTsn(TSN(103)); + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(105)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); +} + +TEST_F(DataTrackerTest, EmptyAllAcks) { + Observer({11, 13, 14, 15}); + + tracker_->HandleForwardTsn(TSN(100)); + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(100)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, SetsArwndCorrectly) { + SackChunk sack1 = tracker_->CreateSelectiveAck(/*a_rwnd=*/100); + EXPECT_EQ(sack1.a_rwnd(), 100u); + + SackChunk sack2 = tracker_->CreateSelectiveAck(/*a_rwnd=*/101); + EXPECT_EQ(sack2.a_rwnd(), 101u); +} + +TEST_F(DataTrackerTest, WillIncreaseCumAckTsn) { + EXPECT_EQ(tracker_->last_cumulative_acked_tsn(), TSN(10)); + EXPECT_FALSE(tracker_->will_increase_cum_ack_tsn(TSN(10))); + EXPECT_TRUE(tracker_->will_increase_cum_ack_tsn(TSN(11))); + EXPECT_FALSE(tracker_->will_increase_cum_ack_tsn(TSN(12))); + + Observer({11, 12, 13, 14, 15}); + EXPECT_EQ(tracker_->last_cumulative_acked_tsn(), TSN(15)); + EXPECT_FALSE(tracker_->will_increase_cum_ack_tsn(TSN(15))); + EXPECT_TRUE(tracker_->will_increase_cum_ack_tsn(TSN(16))); + EXPECT_FALSE(tracker_->will_increase_cum_ack_tsn(TSN(17))); +} + +TEST_F(DataTrackerTest, ForceShouldSendSackImmediately) { + EXPECT_FALSE(tracker_->ShouldSendAck()); + + tracker_->ForceImmediateSack(); + + EXPECT_TRUE(tracker_->ShouldSendAck()); +} + +TEST_F(DataTrackerTest, WillAcceptValidTSNs) { + // The initial TSN is always one more than the last, which is our base. + TSN last_tsn = TSN(*kInitialTSN - 1); + int limit = static_cast<int>(DataTracker::kMaxAcceptedOutstandingFragments); + + for (int i = -limit; i <= limit; ++i) { + EXPECT_TRUE(tracker_->IsTSNValid(TSN(*last_tsn + i))); + } +} + +TEST_F(DataTrackerTest, WillNotAcceptInvalidTSNs) { + // The initial TSN is always one more than the last, which is our base. + TSN last_tsn = TSN(*kInitialTSN - 1); + + size_t limit = DataTracker::kMaxAcceptedOutstandingFragments; + EXPECT_FALSE(tracker_->IsTSNValid(TSN(*last_tsn + limit + 1))); + EXPECT_FALSE(tracker_->IsTSNValid(TSN(*last_tsn - (limit + 1)))); + EXPECT_FALSE(tracker_->IsTSNValid(TSN(*last_tsn + 0x8000000))); + EXPECT_FALSE(tracker_->IsTSNValid(TSN(*last_tsn - 0x8000000))); +} + +TEST_F(DataTrackerTest, ReportSingleDuplicateTsns) { + Observer({11, 12}); + Observer({11}, /*expect_as_duplicate=*/true); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(12)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), UnorderedElementsAre(TSN(11))); +} + +TEST_F(DataTrackerTest, ReportMultipleDuplicateTsns) { + Observer({11, 12, 13, 14}); + Observer({12, 13, 12, 13}, /*expect_as_duplicate=*/true); + Observer({15, 16}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(16)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), UnorderedElementsAre(TSN(12), TSN(13))); +} + +TEST_F(DataTrackerTest, ReportDuplicateTsnsInGapAckBlocks) { + Observer({11, /*12,*/ 13, 14}); + Observer({13, 14}, /*expect_as_duplicate=*/true); + Observer({15, 16}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 5))); + EXPECT_THAT(sack.duplicate_tsns(), UnorderedElementsAre(TSN(13), TSN(14))); +} + +TEST_F(DataTrackerTest, ClearsDuplicateTsnsAfterCreatingSack) { + Observer({11, 12, 13, 14}); + Observer({12, 13, 12, 13}, /*expect_as_duplicate=*/true); + Observer({15, 16}); + SackChunk sack1 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack1.cumulative_tsn_ack(), TSN(16)); + EXPECT_THAT(sack1.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack1.duplicate_tsns(), UnorderedElementsAre(TSN(12), TSN(13))); + + Observer({17}); + SackChunk sack2 = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack2.cumulative_tsn_ack(), TSN(17)); + EXPECT_THAT(sack2.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack2.duplicate_tsns(), IsEmpty()); +} + +TEST_F(DataTrackerTest, LimitsNumberOfDuplicatesReported) { + for (size_t i = 0; i < DataTracker::kMaxDuplicateTsnReported + 10; ++i) { + TSN tsn(11 + i); + tracker_->Observe(tsn, AnyDataChunk::ImmediateAckFlag(false)); + tracker_->Observe(tsn, AnyDataChunk::ImmediateAckFlag(false)); + } + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); + EXPECT_THAT(sack.duplicate_tsns(), + SizeIs(DataTracker::kMaxDuplicateTsnReported)); +} + +TEST_F(DataTrackerTest, LimitsNumberOfGapAckBlocksReported) { + for (size_t i = 0; i < DataTracker::kMaxGapAckBlocksReported + 10; ++i) { + TSN tsn(11 + i * 2); + tracker_->Observe(tsn, AnyDataChunk::ImmediateAckFlag(false)); + } + + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), + SizeIs(DataTracker::kMaxGapAckBlocksReported)); +} + +TEST_F(DataTrackerTest, SendsSackForFirstPacketObserved) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, SendsSackEverySecondPacketWhenThereIsNoPacketLoss) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + Observer({13}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({14}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + Observer({15}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, SendsSackEveryPacketOnPacketLoss) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({13}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({14}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({15}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({16}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + // Fill the hole. + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + // Goes back to every second packet + Observer({17}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({18}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, SendsSackOnDuplicateDataChunks) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({11}, /*expect_as_duplicate=*/true); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + // Goes back to every second packet + Observer({13}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + // Duplicate again + Observer({12}, /*expect_as_duplicate=*/true); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, GapAckBlockAddSingleBlock) { + Observer({12}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); +} + +TEST_F(DataTrackerTest, GapAckBlockAddsAnother) { + Observer({12}); + Observer({14}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2), + SackChunk::GapAckBlock(4, 4))); +} + +TEST_F(DataTrackerTest, GapAckBlockAddsDuplicate) { + Observer({12}); + Observer({12}, /*expect_as_duplicate=*/true); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 2))); + EXPECT_THAT(sack.duplicate_tsns(), ElementsAre(TSN(12))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToRight) { + Observer({12}); + Observer({13}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToRightWithOther) { + Observer({12}); + Observer({20}); + Observer({30}); + Observer({21}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 2), // + SackChunk::GapAckBlock(10, 11), // + SackChunk::GapAckBlock(20, 20))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToLeft) { + Observer({13}); + Observer({12}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), ElementsAre(SackChunk::GapAckBlock(2, 3))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToLeftWithOther) { + Observer({12}); + Observer({21}); + Observer({30}); + Observer({20}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 2), // + SackChunk::GapAckBlock(10, 11), // + SackChunk::GapAckBlock(20, 20))); +} + +TEST_F(DataTrackerTest, GapAckBlockExpandsToLRightAndMerges) { + Observer({12}); + Observer({20}); + Observer({22}); + Observer({30}); + Observer({21}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(sack.gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 2), // + SackChunk::GapAckBlock(10, 12), // + SackChunk::GapAckBlock(20, 20))); +} + +TEST_F(DataTrackerTest, GapAckBlockMergesManyBlocksIntoOne) { + Observer({22}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12))); + Observer({30}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(20, 20))); + Observer({24}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(20, 20))); + Observer({28}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(18, 18), // + SackChunk::GapAckBlock(20, 20))); + Observer({26}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(16, 16), // + SackChunk::GapAckBlock(18, 18), // + SackChunk::GapAckBlock(20, 20))); + Observer({29}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 12), // + SackChunk::GapAckBlock(14, 14), // + SackChunk::GapAckBlock(16, 16), // + SackChunk::GapAckBlock(18, 20))); + Observer({23}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 14), // + SackChunk::GapAckBlock(16, 16), // + SackChunk::GapAckBlock(18, 20))); + Observer({27}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 14), // + SackChunk::GapAckBlock(16, 20))); + + Observer({25}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(12, 20))); + Observer({20}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 10), // + SackChunk::GapAckBlock(12, 20))); + Observer({32}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 10), // + SackChunk::GapAckBlock(12, 20), // + SackChunk::GapAckBlock(22, 22))); + Observer({21}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 20), // + SackChunk::GapAckBlock(22, 22))); + Observer({31}); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(10, 22))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveBeforeCumAckTsn) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(8)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(10)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 4), // + SackChunk::GapAckBlock(10, 12), + SackChunk::GapAckBlock(20, 21))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveBeforeFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(11)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveAtBeginningOfFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(12)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveAtMiddleOfFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + tracker_->HandleForwardTsn(TSN(13)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveAtEndOfFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + tracker_->HandleForwardTsn(TSN(14)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(14)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(6, 8), // + SackChunk::GapAckBlock(16, 17))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAfterFirstBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(18)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(18)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(2, 4), // + SackChunk::GapAckBlock(12, 13))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightBeforeSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(19)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAtStartOfSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(20)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAtMiddleOfSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(21)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveRightAtEndOfSecondBlock) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(22)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(22)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), + ElementsAre(SackChunk::GapAckBlock(8, 9))); +} + +TEST_F(DataTrackerTest, GapAckBlockRemoveeFarAfterAllBlocks) { + Observer({12, 13, 14, 20, 21, 22, 30, 31}); + + tracker_->HandleForwardTsn(TSN(40)); + EXPECT_EQ(tracker_->CreateSelectiveAck(kArwnd).cumulative_tsn_ack(), TSN(40)); + EXPECT_THAT(tracker_->CreateSelectiveAck(kArwnd).gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, HandoverEmpty) { + HandoverTracker(); + Observer({11}); + SackChunk sack = tracker_->CreateSelectiveAck(kArwnd); + EXPECT_EQ(sack.cumulative_tsn_ack(), TSN(11)); + EXPECT_THAT(sack.gap_ack_blocks(), IsEmpty()); +} + +TEST_F(DataTrackerTest, + HandoverWhileSendingSackEverySecondPacketWhenThereIsNoPacketLoss) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + + HandoverTracker(); + + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + Observer({13}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); + Observer({14}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); + Observer({15}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_FALSE(timer_->is_running()); +} + +TEST_F(DataTrackerTest, HandoverWhileSendingSackEveryPacketOnPacketLoss) { + Observer({11}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + Observer({13}); + EXPECT_EQ(tracker_->GetHandoverReadiness(), + HandoverReadinessStatus().Add( + HandoverUnreadinessReason::kDataTrackerTsnBlocksPending)); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + Observer({14}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + EXPECT_EQ(tracker_->GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kDataTrackerTsnBlocksPending)); + Observer({15}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + Observer({16}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + // Fill the hole. + Observer({12}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + // Goes back to every second packet + Observer({17}); + tracker_->ObservePacketEnd(); + EXPECT_TRUE(tracker_->ShouldSendAck()); + + HandoverTracker(); + + Observer({18}); + tracker_->ObservePacketEnd(); + EXPECT_FALSE(tracker_->ShouldSendAck()); + EXPECT_TRUE(timer_->is_running()); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.cc b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.cc new file mode 100644 index 0000000000..8b316de676 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.cc @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/interleaved_reassembly_streams.h" + +#include <stddef.h> + +#include <cstdint> +#include <functional> +#include <iterator> +#include <map> +#include <numeric> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/types.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +InterleavedReassemblyStreams::InterleavedReassemblyStreams( + absl::string_view log_prefix, + OnAssembledMessage on_assembled_message) + : log_prefix_(log_prefix), on_assembled_message_(on_assembled_message) {} + +size_t InterleavedReassemblyStreams::Stream::TryToAssembleMessage( + UnwrappedMID mid) { + std::map<UnwrappedMID, ChunkMap>::const_iterator it = + chunks_by_mid_.find(mid); + if (it == chunks_by_mid_.end()) { + RTC_DLOG(LS_VERBOSE) << parent_.log_prefix_ << "TryToAssembleMessage " + << *mid.Wrap() << " - no chunks"; + return 0; + } + const ChunkMap& chunks = it->second; + if (!chunks.begin()->second.second.is_beginning || + !chunks.rbegin()->second.second.is_end) { + RTC_DLOG(LS_VERBOSE) << parent_.log_prefix_ << "TryToAssembleMessage " + << *mid.Wrap() << "- missing beginning or end"; + return 0; + } + int64_t fsn_diff = *chunks.rbegin()->first - *chunks.begin()->first; + if (fsn_diff != (static_cast<int64_t>(chunks.size()) - 1)) { + RTC_DLOG(LS_VERBOSE) << parent_.log_prefix_ << "TryToAssembleMessage " + << *mid.Wrap() << "- not all chunks exist (have " + << chunks.size() << ", expect " << (fsn_diff + 1) + << ")"; + return 0; + } + + size_t removed_bytes = AssembleMessage(chunks); + RTC_DLOG(LS_VERBOSE) << parent_.log_prefix_ << "TryToAssembleMessage " + << *mid.Wrap() << " - succeeded and removed " + << removed_bytes; + + chunks_by_mid_.erase(mid); + return removed_bytes; +} + +size_t InterleavedReassemblyStreams::Stream::AssembleMessage( + const ChunkMap& tsn_chunks) { + size_t count = tsn_chunks.size(); + if (count == 1) { + // Fast path - zero-copy + const Data& data = tsn_chunks.begin()->second.second; + size_t payload_size = data.size(); + UnwrappedTSN tsns[1] = {tsn_chunks.begin()->second.first}; + DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + return payload_size; + } + + // Slow path - will need to concatenate the payload. + std::vector<UnwrappedTSN> tsns; + tsns.reserve(count); + + std::vector<uint8_t> payload; + size_t payload_size = absl::c_accumulate( + tsn_chunks, 0, + [](size_t v, const auto& p) { return v + p.second.second.size(); }); + payload.reserve(payload_size); + + for (auto& item : tsn_chunks) { + const UnwrappedTSN tsn = item.second.first; + const Data& data = item.second.second; + tsns.push_back(tsn); + payload.insert(payload.end(), data.payload.begin(), data.payload.end()); + } + + const Data& data = tsn_chunks.begin()->second.second; + + DcSctpMessage message(data.stream_id, data.ppid, std::move(payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + return payload_size; +} + +size_t InterleavedReassemblyStreams::Stream::EraseTo(MID message_id) { + UnwrappedMID unwrapped_mid = mid_unwrapper_.Unwrap(message_id); + + size_t removed_bytes = 0; + auto it = chunks_by_mid_.begin(); + while (it != chunks_by_mid_.end() && it->first <= unwrapped_mid) { + removed_bytes += absl::c_accumulate( + it->second, 0, + [](size_t r2, const auto& q) { return r2 + q.second.second.size(); }); + it = chunks_by_mid_.erase(it); + } + + if (!stream_id_.unordered) { + // For ordered streams, erasing a message might suddenly unblock that queue + // and allow it to deliver any following received messages. + if (unwrapped_mid >= next_mid_) { + next_mid_ = unwrapped_mid.next_value(); + } + + removed_bytes += TryToAssembleMessages(); + } + + return removed_bytes; +} + +int InterleavedReassemblyStreams::Stream::Add(UnwrappedTSN tsn, Data data) { + RTC_DCHECK_EQ(*data.is_unordered, *stream_id_.unordered); + RTC_DCHECK_EQ(*data.stream_id, *stream_id_.stream_id); + int queued_bytes = data.size(); + UnwrappedMID mid = mid_unwrapper_.Unwrap(data.message_id); + FSN fsn = data.fsn; + auto [unused, inserted] = + chunks_by_mid_[mid].emplace(fsn, std::make_pair(tsn, std::move(data))); + if (!inserted) { + return 0; + } + + if (stream_id_.unordered) { + queued_bytes -= TryToAssembleMessage(mid); + } else { + if (mid == next_mid_) { + queued_bytes -= TryToAssembleMessages(); + } + } + + return queued_bytes; +} + +size_t InterleavedReassemblyStreams::Stream::TryToAssembleMessages() { + size_t removed_bytes = 0; + + for (;;) { + size_t removed_bytes_this_iter = TryToAssembleMessage(next_mid_); + if (removed_bytes_this_iter == 0) { + break; + } + + removed_bytes += removed_bytes_this_iter; + next_mid_.Increment(); + } + return removed_bytes; +} + +void InterleavedReassemblyStreams::Stream::AddHandoverState( + DcSctpSocketHandoverState& state) const { + if (stream_id_.unordered) { + DcSctpSocketHandoverState::UnorderedStream state_stream; + state_stream.id = stream_id_.stream_id.value(); + state.rx.unordered_streams.push_back(std::move(state_stream)); + } else { + DcSctpSocketHandoverState::OrderedStream state_stream; + state_stream.id = stream_id_.stream_id.value(); + state_stream.next_ssn = next_mid_.Wrap().value(); + state.rx.ordered_streams.push_back(std::move(state_stream)); + } +} + +InterleavedReassemblyStreams::Stream& +InterleavedReassemblyStreams::GetOrCreateStream(const FullStreamId& stream_id) { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + it = + streams_ + .emplace(std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(stream_id, this)) + .first; + } + return it->second; +} + +int InterleavedReassemblyStreams::Add(UnwrappedTSN tsn, Data data) { + return GetOrCreateStream(FullStreamId(data.is_unordered, data.stream_id)) + .Add(tsn, std::move(data)); +} + +size_t InterleavedReassemblyStreams::HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) { + size_t removed_bytes = 0; + for (const auto& skipped : skipped_streams) { + removed_bytes += + GetOrCreateStream(FullStreamId(skipped.unordered, skipped.stream_id)) + .EraseTo(skipped.message_id); + } + return removed_bytes; +} + +void InterleavedReassemblyStreams::ResetStreams( + rtc::ArrayView<const StreamID> stream_ids) { + if (stream_ids.empty()) { + for (auto& entry : streams_) { + entry.second.Reset(); + } + } else { + for (StreamID stream_id : stream_ids) { + GetOrCreateStream(FullStreamId(IsUnordered(true), stream_id)).Reset(); + GetOrCreateStream(FullStreamId(IsUnordered(false), stream_id)).Reset(); + } + } +} + +HandoverReadinessStatus InterleavedReassemblyStreams::GetHandoverReadiness() + const { + HandoverReadinessStatus status; + for (const auto& [stream_id, stream] : streams_) { + if (stream.has_unassembled_chunks()) { + status.Add( + stream_id.unordered + ? HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks + : HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks); + break; + } + } + return status; +} + +void InterleavedReassemblyStreams::AddHandoverState( + DcSctpSocketHandoverState& state) { + for (const auto& [unused, stream] : streams_) { + stream.AddHandoverState(state); + } +} + +void InterleavedReassemblyStreams::RestoreFromState( + const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(streams_.empty()); + + for (const DcSctpSocketHandoverState::OrderedStream& state : + state.rx.ordered_streams) { + FullStreamId stream_id(IsUnordered(false), StreamID(state.id)); + streams_.emplace( + std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(stream_id, this, MID(state.next_ssn))); + } + for (const DcSctpSocketHandoverState::UnorderedStream& state : + state.rx.unordered_streams) { + FullStreamId stream_id(IsUnordered(true), StreamID(state.id)); + streams_.emplace(std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(stream_id, this)); + } +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.h b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.h new file mode 100644 index 0000000000..a7b67707e9 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams.h @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_INTERLEAVED_REASSEMBLY_STREAMS_H_ +#define NET_DCSCTP_RX_INTERLEAVED_REASSEMBLY_STREAMS_H_ + +#include <cstdint> +#include <map> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" + +namespace dcsctp { + +// Handles reassembly of incoming data when interleaved message sending is +// enabled on the association, i.e. when RFC8260 is in use. +class InterleavedReassemblyStreams : public ReassemblyStreams { + public: + InterleavedReassemblyStreams(absl::string_view log_prefix, + OnAssembledMessage on_assembled_message); + + int Add(UnwrappedTSN tsn, Data data) override; + + size_t HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) + override; + + void ResetStreams(rtc::ArrayView<const StreamID> stream_ids) override; + + HandoverReadinessStatus GetHandoverReadiness() const override; + void AddHandoverState(DcSctpSocketHandoverState& state) override; + void RestoreFromState(const DcSctpSocketHandoverState& state) override; + + private: + struct FullStreamId { + const IsUnordered unordered; + const StreamID stream_id; + + FullStreamId(IsUnordered unordered, StreamID stream_id) + : unordered(unordered), stream_id(stream_id) {} + + friend bool operator<(FullStreamId a, FullStreamId b) { + return a.unordered < b.unordered || + (!(a.unordered < b.unordered) && (a.stream_id < b.stream_id)); + } + }; + + class Stream { + public: + Stream(FullStreamId stream_id, + InterleavedReassemblyStreams* parent, + MID next_mid = MID(0)) + : stream_id_(stream_id), + parent_(*parent), + next_mid_(mid_unwrapper_.Unwrap(next_mid)) {} + int Add(UnwrappedTSN tsn, Data data); + size_t EraseTo(MID message_id); + void Reset() { + mid_unwrapper_.Reset(); + next_mid_ = mid_unwrapper_.Unwrap(MID(0)); + } + bool has_unassembled_chunks() const { return !chunks_by_mid_.empty(); } + void AddHandoverState(DcSctpSocketHandoverState& state) const; + + private: + using ChunkMap = std::map<FSN, std::pair<UnwrappedTSN, Data>>; + + // Try to assemble one message identified by `mid`. + // Returns the number of bytes assembled if a message was assembled. + size_t TryToAssembleMessage(UnwrappedMID mid); + size_t AssembleMessage(const ChunkMap& tsn_chunks); + // Try to assemble one or several messages in order from the stream. + // Returns the number of bytes assembled if one or more messages were + // assembled. + size_t TryToAssembleMessages(); + + const FullStreamId stream_id_; + InterleavedReassemblyStreams& parent_; + std::map<UnwrappedMID, ChunkMap> chunks_by_mid_; + UnwrappedMID::Unwrapper mid_unwrapper_; + UnwrappedMID next_mid_; + }; + + Stream& GetOrCreateStream(const FullStreamId& stream_id); + + const std::string log_prefix_; + + // Callback for when a message has been assembled. + const OnAssembledMessage on_assembled_message_; + + // All unordered and ordered streams, managing not-yet-assembled data. + std::map<FullStreamId, Stream> streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_INTERLEAVED_REASSEMBLY_STREAMS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams_test.cc b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams_test.cc new file mode 100644 index 0000000000..df4024ed60 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/interleaved_reassembly_streams_test.cc @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/interleaved_reassembly_streams.h" + +#include <cstdint> +#include <memory> +#include <utility> + +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" +#include "net/dcsctp/testing/data_generator.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; +using ::testing::NiceMock; + +class InterleavedReassemblyStreamsTest : public testing::Test { + protected: + UnwrappedTSN tsn(uint32_t value) { return tsn_.Unwrap(TSN(value)); } + + InterleavedReassemblyStreamsTest() {} + DataGenerator gen_; + UnwrappedTSN::Unwrapper tsn_; +}; + +TEST_F(InterleavedReassemblyStreamsTest, + AddUnorderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + // Adding the end fragment should make it empty again. + EXPECT_EQ(streams.Add(tsn(4), gen_.Unordered({7}, "E")), -6); +} + +TEST_F(InterleavedReassemblyStreamsTest, + AddSimpleOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), -6); +} + +TEST_F(InterleavedReassemblyStreamsTest, + AddMoreComplexOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + Data late = gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + EXPECT_EQ(streams.Add(tsn(2), std::move(late)), -8); +} + +TEST_F(InterleavedReassemblyStreamsTest, + DeleteUnorderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + + IForwardTsnChunk::SkippedStream skipped[] = { + IForwardTsnChunk::SkippedStream(IsUnordered(true), StreamID(1), MID(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), skipped), 6u); +} + +TEST_F(InterleavedReassemblyStreamsTest, + DeleteSimpleOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + + IForwardTsnChunk::SkippedStream skipped[] = { + IForwardTsnChunk::SkippedStream(IsUnordered(false), StreamID(1), MID(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), skipped), 6u); +} + +TEST_F(InterleavedReassemblyStreamsTest, + DeleteManyOrderedMessagesReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // Expire all three messages + IForwardTsnChunk::SkippedStream skipped[] = { + IForwardTsnChunk::SkippedStream(IsUnordered(false), StreamID(1), MID(2))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(8), skipped), 8u); +} + +TEST_F(InterleavedReassemblyStreamsTest, + DeleteOrderedMessageDelivesTwoReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + InterleavedReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // The first ordered message expire, and the following two are delivered. + IForwardTsnChunk::SkippedStream skipped[] = { + IForwardTsnChunk::SkippedStream(IsUnordered(false), StreamID(1), MID(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(4), skipped), 8u); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.cc b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.cc new file mode 100644 index 0000000000..f72c5cb8c1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.cc @@ -0,0 +1,312 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/reassembly_queue.h" + +#include <stddef.h> + +#include <algorithm> +#include <cstdint> +#include <memory> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/rx/interleaved_reassembly_streams.h" +#include "net/dcsctp/rx/reassembly_streams.h" +#include "net/dcsctp/rx/traditional_reassembly_streams.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace { +std::unique_ptr<ReassemblyStreams> CreateStreams( + absl::string_view log_prefix, + ReassemblyStreams::OnAssembledMessage on_assembled_message, + bool use_message_interleaving) { + if (use_message_interleaving) { + return std::make_unique<InterleavedReassemblyStreams>( + log_prefix, std::move(on_assembled_message)); + } + return std::make_unique<TraditionalReassemblyStreams>( + log_prefix, std::move(on_assembled_message)); +} +} // namespace + +ReassemblyQueue::ReassemblyQueue(absl::string_view log_prefix, + TSN peer_initial_tsn, + size_t max_size_bytes, + bool use_message_interleaving) + : log_prefix_(std::string(log_prefix) + "reasm: "), + max_size_bytes_(max_size_bytes), + watermark_bytes_(max_size_bytes * kHighWatermarkLimit), + last_assembled_tsn_watermark_( + tsn_unwrapper_.Unwrap(TSN(*peer_initial_tsn - 1))), + last_completed_reset_req_seq_nbr_(ReconfigRequestSN(0)), + streams_(CreateStreams( + log_prefix_, + [this](rtc::ArrayView<const UnwrappedTSN> tsns, + DcSctpMessage message) { + AddReassembledMessage(tsns, std::move(message)); + }, + use_message_interleaving)) {} + +void ReassemblyQueue::Add(TSN tsn, Data data) { + RTC_DCHECK(IsConsistent()); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "added tsn=" << *tsn + << ", stream=" << *data.stream_id << ":" + << *data.message_id << ":" << *data.fsn << ", type=" + << (data.is_beginning && data.is_end ? "complete" + : data.is_beginning ? "first" + : data.is_end ? "last" + : "middle"); + + UnwrappedTSN unwrapped_tsn = tsn_unwrapper_.Unwrap(tsn); + + if (unwrapped_tsn <= last_assembled_tsn_watermark_ || + delivered_tsns_.find(unwrapped_tsn) != delivered_tsns_.end()) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Chunk has already been delivered - skipping"; + return; + } + + // If a stream reset has been received with a "sender's last assigned tsn" in + // the future, the socket is in "deferred reset processing" mode and must + // buffer chunks until it's exited. + if (deferred_reset_streams_.has_value() && + unwrapped_tsn > + tsn_unwrapper_.Unwrap( + deferred_reset_streams_->req.sender_last_assigned_tsn())) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Deferring chunk with tsn=" << *tsn + << " until cum_ack_tsn=" + << *deferred_reset_streams_->req.sender_last_assigned_tsn(); + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "In this mode, any data arriving with a TSN larger than the + // Sender's Last Assigned TSN for the affected stream(s) MUST be queued + // locally and held until the cumulative acknowledgment point reaches the + // Sender's Last Assigned TSN." + queued_bytes_ += data.size(); + deferred_reset_streams_->deferred_chunks.emplace_back( + std::make_pair(tsn, std::move(data))); + } else { + queued_bytes_ += streams_->Add(unwrapped_tsn, std::move(data)); + } + + // https://tools.ietf.org/html/rfc4960#section-6.9 + // "Note: If the data receiver runs out of buffer space while still + // waiting for more fragments to complete the reassembly of the message, it + // should dispatch part of its inbound message through a partial delivery + // API (see Section 10), freeing some of its receive buffer space so that + // the rest of the message may be received." + + // TODO(boivie): Support EOR flag and partial delivery? + RTC_DCHECK(IsConsistent()); +} + +ReconfigurationResponseParameter::Result ReassemblyQueue::ResetStreams( + const OutgoingSSNResetRequestParameter& req, + TSN cum_tsn_ack) { + RTC_DCHECK(IsConsistent()); + if (deferred_reset_streams_.has_value()) { + // In deferred mode already. + return ReconfigurationResponseParameter::Result::kInProgress; + } else if (req.request_sequence_number() <= + last_completed_reset_req_seq_nbr_) { + // Already performed at some time previously. + return ReconfigurationResponseParameter::Result::kSuccessPerformed; + } + + UnwrappedTSN sla_tsn = tsn_unwrapper_.Unwrap(req.sender_last_assigned_tsn()); + UnwrappedTSN unwrapped_cum_tsn_ack = tsn_unwrapper_.Unwrap(cum_tsn_ack); + + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "If the Sender's Last Assigned TSN is greater than the + // cumulative acknowledgment point, then the endpoint MUST enter "deferred + // reset processing"." + if (sla_tsn > unwrapped_cum_tsn_ack) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "Entering deferred reset processing mode until cum_tsn_ack=" + << *req.sender_last_assigned_tsn(); + deferred_reset_streams_ = absl::make_optional<DeferredResetStreams>(req); + return ReconfigurationResponseParameter::Result::kInProgress; + } + + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "... streams MUST be reset to 0 as the next expected SSN." + streams_->ResetStreams(req.stream_ids()); + last_completed_reset_req_seq_nbr_ = req.request_sequence_number(); + RTC_DCHECK(IsConsistent()); + return ReconfigurationResponseParameter::Result::kSuccessPerformed; +} + +bool ReassemblyQueue::MaybeResetStreamsDeferred(TSN cum_ack_tsn) { + RTC_DCHECK(IsConsistent()); + if (deferred_reset_streams_.has_value()) { + UnwrappedTSN unwrapped_cum_ack_tsn = tsn_unwrapper_.Unwrap(cum_ack_tsn); + UnwrappedTSN unwrapped_sla_tsn = tsn_unwrapper_.Unwrap( + deferred_reset_streams_->req.sender_last_assigned_tsn()); + if (unwrapped_cum_ack_tsn >= unwrapped_sla_tsn) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Leaving deferred reset processing with tsn=" + << *cum_ack_tsn << ", feeding back " + << deferred_reset_streams_->deferred_chunks.size() + << " chunks"; + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "... streams MUST be reset to 0 as the next expected SSN." + streams_->ResetStreams(deferred_reset_streams_->req.stream_ids()); + std::vector<std::pair<TSN, Data>> deferred_chunks = + std::move(deferred_reset_streams_->deferred_chunks); + // The response will not be sent now, but as a reply to the retried + // request, which will come as "in progress" has been sent prior. + last_completed_reset_req_seq_nbr_ = + deferred_reset_streams_->req.request_sequence_number(); + deferred_reset_streams_ = absl::nullopt; + + // https://tools.ietf.org/html/rfc6525#section-5.2.2 + // "Any queued TSNs (queued at step E2) MUST now be released and processed + // normally." + for (auto& [tsn, data] : deferred_chunks) { + queued_bytes_ -= data.size(); + Add(tsn, std::move(data)); + } + + RTC_DCHECK(IsConsistent()); + return true; + } else { + RTC_DLOG(LS_VERBOSE) << "Staying in deferred reset processing. tsn=" + << *cum_ack_tsn; + } + } + + return false; +} + +std::vector<DcSctpMessage> ReassemblyQueue::FlushMessages() { + std::vector<DcSctpMessage> ret; + reassembled_messages_.swap(ret); + return ret; +} + +void ReassemblyQueue::AddReassembledMessage( + rtc::ArrayView<const UnwrappedTSN> tsns, + DcSctpMessage message) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Assembled message from TSN=[" + << StrJoin(tsns, ",", + [](rtc::StringBuilder& sb, UnwrappedTSN tsn) { + sb << *tsn.Wrap(); + }) + << "], message; stream_id=" << *message.stream_id() + << ", ppid=" << *message.ppid() + << ", payload=" << message.payload().size() << " bytes"; + + for (const UnwrappedTSN tsn : tsns) { + if (tsn <= last_assembled_tsn_watermark_) { + // This can be provoked by a misbehaving peer by sending FORWARD-TSN with + // invalid SSNs, allowing ordered messages to stay in the queue that + // should've been discarded. + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "Message is built from fragments already seen - skipping"; + return; + } else if (tsn == last_assembled_tsn_watermark_.next_value()) { + // Update watermark, or insert into delivered_tsns_ + last_assembled_tsn_watermark_.Increment(); + } else { + delivered_tsns_.insert(tsn); + } + } + + // With new TSNs in delivered_tsns, gaps might be filled. + MaybeMoveLastAssembledWatermarkFurther(); + + reassembled_messages_.emplace_back(std::move(message)); +} + +void ReassemblyQueue::MaybeMoveLastAssembledWatermarkFurther() { + // `delivered_tsns_` contain TSNS when there is a gap between ranges of + // assembled TSNs. `last_assembled_tsn_watermark_` should not be adjacent to + // that list, because if so, it can be moved. + while (!delivered_tsns_.empty() && + *delivered_tsns_.begin() == + last_assembled_tsn_watermark_.next_value()) { + last_assembled_tsn_watermark_.Increment(); + delivered_tsns_.erase(delivered_tsns_.begin()); + } +} + +void ReassemblyQueue::Handle(const AnyForwardTsnChunk& forward_tsn) { + RTC_DCHECK(IsConsistent()); + UnwrappedTSN tsn = tsn_unwrapper_.Unwrap(forward_tsn.new_cumulative_tsn()); + + last_assembled_tsn_watermark_ = std::max(last_assembled_tsn_watermark_, tsn); + delivered_tsns_.erase(delivered_tsns_.begin(), + delivered_tsns_.upper_bound(tsn)); + + MaybeMoveLastAssembledWatermarkFurther(); + + queued_bytes_ -= + streams_->HandleForwardTsn(tsn, forward_tsn.skipped_streams()); + RTC_DCHECK(IsConsistent()); +} + +bool ReassemblyQueue::IsConsistent() const { + // `delivered_tsns_` and `last_assembled_tsn_watermark_` mustn't overlap or be + // adjacent. + if (!delivered_tsns_.empty() && + last_assembled_tsn_watermark_.next_value() >= *delivered_tsns_.begin()) { + return false; + } + + // Allow queued_bytes_ to be larger than max_size_bytes, as it's not actively + // enforced in this class. This comparison will still trigger if queued_bytes_ + // became "negative". + return (queued_bytes_ >= 0 && queued_bytes_ <= 2 * max_size_bytes_); +} + +HandoverReadinessStatus ReassemblyQueue::GetHandoverReadiness() const { + HandoverReadinessStatus status = streams_->GetHandoverReadiness(); + if (!delivered_tsns_.empty()) { + status.Add(HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap); + } + if (deferred_reset_streams_.has_value()) { + status.Add(HandoverUnreadinessReason::kStreamResetDeferred); + } + return status; +} + +void ReassemblyQueue::AddHandoverState(DcSctpSocketHandoverState& state) { + state.rx.last_assembled_tsn = last_assembled_tsn_watermark_.Wrap().value(); + state.rx.last_completed_deferred_reset_req_sn = + last_completed_reset_req_seq_nbr_.value(); + streams_->AddHandoverState(state); +} + +void ReassemblyQueue::RestoreFromState(const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(last_completed_reset_req_seq_nbr_ == ReconfigRequestSN(0)); + + last_assembled_tsn_watermark_ = + tsn_unwrapper_.Unwrap(TSN(state.rx.last_assembled_tsn)); + last_completed_reset_req_seq_nbr_ = + ReconfigRequestSN(state.rx.last_completed_deferred_reset_req_sn); + streams_->RestoreFromState(state); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.h b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.h new file mode 100644 index 0000000000..91f30a3f69 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue.h @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_REASSEMBLY_QUEUE_H_ +#define NET_DCSCTP_RX_REASSEMBLY_QUEUE_H_ + +#include <stddef.h> + +#include <cstdint> +#include <memory> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/rx/reassembly_streams.h" + +namespace dcsctp { + +// Contains the received DATA chunks that haven't yet been reassembled, and +// reassembles chunks when possible. +// +// The actual assembly is handled by an implementation of the +// `ReassemblyStreams` interface. +// +// Except for reassembling fragmented messages, this class will also handle two +// less common operations; To handle the receiver-side of partial reliability +// (limited number of retransmissions or limited message lifetime) as well as +// stream resetting, which is used when a sender wishes to close a data channel. +// +// Partial reliability is handled when a FORWARD-TSN or I-FORWARD-TSN chunk is +// received, and it will simply delete any chunks matching the parameters in +// that chunk. This is mainly implemented in ReassemblyStreams. +// +// Resetting streams is handled when a RECONFIG chunks is received, with an +// "Outgoing SSN Reset Request" parameter. That parameter will contain a list of +// streams to reset, and a `sender_last_assigned_tsn`. If this TSN is not yet +// seen, the stream cannot be directly reset, and this class will respond that +// the reset is "deferred". But if this TSN provided is known, the stream can be +// immediately be reset. +// +// The ReassemblyQueue has a maximum size, as it would otherwise be an DoS +// attack vector where a peer could consume all memory of the other peer by +// sending a lot of ordered chunks, but carefully withholding an early one. It +// also has a watermark limit, which the caller can query is the number of bytes +// is above that limit. This is used by the caller to be selective in what to +// add to the reassembly queue, so that it's not exhausted. The caller is +// expected to call `is_full` prior to adding data to the queue and to act +// accordingly if the queue is full. +class ReassemblyQueue { + public: + // When the queue is filled over this fraction (of its maximum size), the + // socket should restrict incoming data to avoid filling up the queue. + static constexpr float kHighWatermarkLimit = 0.9; + + ReassemblyQueue(absl::string_view log_prefix, + TSN peer_initial_tsn, + size_t max_size_bytes, + bool use_message_interleaving = false); + + // Adds a data chunk to the queue, with a `tsn` and other parameters in + // `data`. + void Add(TSN tsn, Data data); + + // Indicates if the reassembly queue has any reassembled messages that can be + // retrieved by calling `FlushMessages`. + bool HasMessages() const { return !reassembled_messages_.empty(); } + + // Returns any reassembled messages. + std::vector<DcSctpMessage> FlushMessages(); + + // Handle a ForwardTSN chunk, when the sender has indicated that the received + // (this class) should forget about some chunks. This is used to implement + // partial reliability. + void Handle(const AnyForwardTsnChunk& forward_tsn); + + // Given the reset stream request and the current cum_tsn_ack, might either + // reset the streams directly (returns kSuccessPerformed), or at a later time, + // by entering the "deferred reset processing" mode (returns kInProgress). + ReconfigurationResponseParameter::Result ResetStreams( + const OutgoingSSNResetRequestParameter& req, + TSN cum_tsn_ack); + + // Given the current (updated) cum_tsn_ack, might leave "defererred reset + // processing" mode and reset streams. Returns true if so. + bool MaybeResetStreamsDeferred(TSN cum_ack_tsn); + + // The number of payload bytes that have been queued. Note that the actual + // memory usage is higher due to additional overhead of tracking received + // data. + size_t queued_bytes() const { return queued_bytes_; } + + // The remaining bytes until the queue has reached the watermark limit. + size_t remaining_bytes() const { return watermark_bytes_ - queued_bytes_; } + + // Indicates if the queue is full. Data should not be added to the queue when + // it's full. + bool is_full() const { return queued_bytes_ >= max_size_bytes_; } + + // Indicates if the queue is above the watermark limit, which is a certain + // percentage of its size. + bool is_above_watermark() const { return queued_bytes_ >= watermark_bytes_; } + + // Returns the watermark limit, in bytes. + size_t watermark_bytes() const { return watermark_bytes_; } + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); + + private: + bool IsConsistent() const; + void AddReassembledMessage(rtc::ArrayView<const UnwrappedTSN> tsns, + DcSctpMessage message); + void MaybeMoveLastAssembledWatermarkFurther(); + + struct DeferredResetStreams { + explicit DeferredResetStreams(OutgoingSSNResetRequestParameter req) + : req(std::move(req)) {} + OutgoingSSNResetRequestParameter req; + std::vector<std::pair<TSN, Data>> deferred_chunks; + }; + + const std::string log_prefix_; + const size_t max_size_bytes_; + const size_t watermark_bytes_; + UnwrappedTSN::Unwrapper tsn_unwrapper_; + + // Whenever a message has been assembled, either increase + // `last_assembled_tsn_watermark_` or - if there are gaps - add the message's + // TSNs into delivered_tsns_ so that messages are not re-delivered on + // duplicate chunks. + UnwrappedTSN last_assembled_tsn_watermark_; + std::set<UnwrappedTSN> delivered_tsns_; + // Messages that have been reassembled, and will be returned by + // `FlushMessages`. + std::vector<DcSctpMessage> reassembled_messages_; + + // If present, "deferred reset processing" mode is active. + absl::optional<DeferredResetStreams> deferred_reset_streams_; + + // Contains the last request sequence number of the + // OutgoingSSNResetRequestParameter that was performed. + ReconfigRequestSN last_completed_reset_req_seq_nbr_; + + // The number of "payload bytes" that are in this queue, in total. + size_t queued_bytes_ = 0; + + // The actual implementation of ReassemblyStreams. + std::unique_ptr<ReassemblyStreams> streams_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_REASSEMBLY_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue_test.cc b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue_test.cc new file mode 100644 index 0000000000..549bc6fce1 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_queue_test.cc @@ -0,0 +1,509 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/reassembly_queue.h" + +#include <stddef.h> + +#include <algorithm> +#include <array> +#include <cstdint> +#include <iterator> +#include <vector> + +#include "api/array_view.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/testing/data_generator.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +// The default maximum size of the Reassembly Queue. +static constexpr size_t kBufferSize = 10000; + +static constexpr StreamID kStreamID(1); +static constexpr SSN kSSN(0); +static constexpr MID kMID(0); +static constexpr FSN kFSN(0); +static constexpr PPID kPPID(53); + +static constexpr std::array<uint8_t, 4> kShortPayload = {1, 2, 3, 4}; +static constexpr std::array<uint8_t, 4> kMessage2Payload = {5, 6, 7, 8}; +static constexpr std::array<uint8_t, 6> kSixBytePayload = {1, 2, 3, 4, 5, 6}; +static constexpr std::array<uint8_t, 8> kMediumPayload1 = {1, 2, 3, 4, + 5, 6, 7, 8}; +static constexpr std::array<uint8_t, 8> kMediumPayload2 = {9, 10, 11, 12, + 13, 14, 15, 16}; +static constexpr std::array<uint8_t, 16> kLongPayload = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + +MATCHER_P3(SctpMessageIs, stream_id, ppid, expected_payload, "") { + if (arg.stream_id() != stream_id) { + *result_listener << "the stream_id is " << *arg.stream_id(); + return false; + } + + if (arg.ppid() != ppid) { + *result_listener << "the ppid is " << *arg.ppid(); + return false; + } + + if (std::vector<uint8_t>(arg.payload().begin(), arg.payload().end()) != + std::vector<uint8_t>(expected_payload.begin(), expected_payload.end())) { + *result_listener << "the payload is wrong"; + return false; + } + return true; +} + +class ReassemblyQueueTest : public testing::Test { + protected: + ReassemblyQueueTest() {} + DataGenerator gen_; +}; + +TEST_F(ReassemblyQueueTest, EmptyQueue) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + EXPECT_FALSE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, SingleUnorderedChunkMessage) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, LargeUnorderedChunkAllPermutations) { + std::vector<uint32_t> tsns = {10, 11, 12, 13}; + rtc::ArrayView<const uint8_t> payload(kLongPayload); + do { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + + for (size_t i = 0; i < tsns.size(); i++) { + auto span = payload.subview((tsns[i] - 10) * 4, 4); + Data::IsBeginning is_beginning(tsns[i] == 10); + Data::IsEnd is_end(tsns[i] == 13); + + reasm.Add(TSN(tsns[i]), + Data(kStreamID, kSSN, kMID, kFSN, kPPID, + std::vector<uint8_t>(span.begin(), span.end()), + is_beginning, is_end, IsUnordered(false))); + if (i < 3) { + EXPECT_FALSE(reasm.HasMessages()); + } else { + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kLongPayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + } + } + } while (std::next_permutation(std::begin(tsns), std::end(tsns))); +} + +TEST_F(ReassemblyQueueTest, SingleOrderedChunkMessage) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); +} + +TEST_F(ReassemblyQueueTest, ManySmallOrderedMessages) { + std::vector<uint32_t> tsns = {10, 11, 12, 13}; + rtc::ArrayView<const uint8_t> payload(kLongPayload); + do { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + for (size_t i = 0; i < tsns.size(); i++) { + auto span = payload.subview((tsns[i] - 10) * 4, 4); + Data::IsBeginning is_beginning(true); + Data::IsEnd is_end(true); + + SSN ssn(static_cast<uint16_t>(tsns[i] - 10)); + reasm.Add(TSN(tsns[i]), + Data(kStreamID, ssn, kMID, kFSN, kPPID, + std::vector<uint8_t>(span.begin(), span.end()), + is_beginning, is_end, IsUnordered(false))); + } + EXPECT_THAT( + reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, payload.subview(0, 4)), + SctpMessageIs(kStreamID, kPPID, payload.subview(4, 4)), + SctpMessageIs(kStreamID, kPPID, payload.subview(8, 4)), + SctpMessageIs(kStreamID, kPPID, payload.subview(12, 4)))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + } while (std::next_permutation(std::begin(tsns), std::end(tsns))); +} + +TEST_F(ReassemblyQueueTest, RetransmissionInLargeOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + reasm.Add(TSN(12), gen_.Ordered({3})); + reasm.Add(TSN(13), gen_.Ordered({4})); + reasm.Add(TSN(14), gen_.Ordered({5})); + reasm.Add(TSN(15), gen_.Ordered({6})); + reasm.Add(TSN(16), gen_.Ordered({7})); + reasm.Add(TSN(17), gen_.Ordered({8})); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + // lost and retransmitted + reasm.Add(TSN(11), gen_.Ordered({2})); + reasm.Add(TSN(18), gen_.Ordered({9})); + reasm.Add(TSN(19), gen_.Ordered({10})); + EXPECT_EQ(reasm.queued_bytes(), 10u); + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Add(TSN(20), gen_.Ordered({11, 12, 13, 14, 15, 16}, "E")); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kLongPayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, ForwardTSNRemoveUnordered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1}, "B")); + reasm.Add(TSN(12), gen_.Unordered({3})); + reasm.Add(TSN(13), gen_.Unordered({4}, "E")); + + reasm.Add(TSN(14), gen_.Unordered({5}, "B")); + reasm.Add(TSN(15), gen_.Unordered({6})); + reasm.Add(TSN(17), gen_.Unordered({8}, "E")); + EXPECT_EQ(reasm.queued_bytes(), 6u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Handle(ForwardTsnChunk(TSN(13), {})); + EXPECT_EQ(reasm.queued_bytes(), 3u); + + // The lost chunk comes, but too late. + reasm.Add(TSN(11), gen_.Unordered({2})); + EXPECT_FALSE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 3u); + + // The second lost chunk comes, message is assembled. + reasm.Add(TSN(16), gen_.Unordered({7})); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 0u); +} + +TEST_F(ReassemblyQueueTest, ForwardTSNRemoveOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + reasm.Add(TSN(12), gen_.Ordered({3})); + reasm.Add(TSN(13), gen_.Ordered({4}, "E")); + + reasm.Add(TSN(14), gen_.Ordered({5}, "B")); + reasm.Add(TSN(15), gen_.Ordered({6})); + reasm.Add(TSN(16), gen_.Ordered({7})); + reasm.Add(TSN(17), gen_.Ordered({8}, "E")); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Handle(ForwardTsnChunk( + TSN(13), {ForwardTsnChunk::SkippedStream(kStreamID, kSSN)})); + EXPECT_EQ(reasm.queued_bytes(), 0u); + + // The lost chunk comes, but too late. + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kMessage2Payload))); +} + +TEST_F(ReassemblyQueueTest, ForwardTSNRemoveALotOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + reasm.Add(TSN(12), gen_.Ordered({3})); + reasm.Add(TSN(13), gen_.Ordered({4}, "E")); + + reasm.Add(TSN(15), gen_.Ordered({5}, "B")); + reasm.Add(TSN(16), gen_.Ordered({6})); + reasm.Add(TSN(17), gen_.Ordered({7})); + reasm.Add(TSN(18), gen_.Ordered({8}, "E")); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Handle(ForwardTsnChunk( + TSN(13), {ForwardTsnChunk::SkippedStream(kStreamID, kSSN)})); + EXPECT_EQ(reasm.queued_bytes(), 0u); + + // The lost chunk comes, but too late. + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kMessage2Payload))); +} + +TEST_F(ReassemblyQueueTest, ShouldntDeliverMessagesBeforeInitialTsn) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(5), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, ShouldntRedeliverUnorderedMessages) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, ShouldntRedeliverUnorderedMessagesReallyUnordered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "B")); + EXPECT_EQ(reasm.queued_bytes(), 4u); + + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 4u); + EXPECT_TRUE(reasm.HasMessages()); + + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 4u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, ShouldntDeliverBeforeForwardedTsn) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Handle(ForwardTsnChunk(TSN(12), {})); + + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, NotReadyForHandoverWhenDeliveredTsnsHaveGap) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + reasm.Add(TSN(10), gen_.Unordered({1, 2, 3, 4}, "B")); + EXPECT_FALSE(reasm.HasMessages()); + + reasm.Add(TSN(12), gen_.Unordered({1, 2, 3, 4}, "BE")); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_EQ( + reasm.GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap) + .Add( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); + EXPECT_EQ( + reasm.GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap) + .Add( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + + reasm.Handle(ForwardTsnChunk(TSN(13), {})); + EXPECT_EQ(reasm.GetHandoverReadiness(), HandoverReadinessStatus()); +} + +TEST_F(ReassemblyQueueTest, NotReadyForHandoverWhenResetStreamIsDeferred) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + DataGeneratorOptions opts; + opts.message_id = MID(0); + reasm.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + opts.message_id = MID(1); + reasm.Add(TSN(11), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + EXPECT_THAT(reasm.FlushMessages(), SizeIs(2)); + + reasm.ResetStreams( + OutgoingSSNResetRequestParameter( + ReconfigRequestSN(10), ReconfigRequestSN(3), TSN(13), {StreamID(1)}), + TSN(11)); + EXPECT_EQ(reasm.GetHandoverReadiness(), + HandoverReadinessStatus().Add( + HandoverUnreadinessReason::kStreamResetDeferred)); + + opts.message_id = MID(3); + opts.ppid = PPID(3); + reasm.Add(TSN(13), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm.MaybeResetStreamsDeferred(TSN(11)); + + opts.message_id = MID(2); + opts.ppid = PPID(2); + reasm.Add(TSN(13), gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm.MaybeResetStreamsDeferred(TSN(15)); + EXPECT_EQ(reasm.GetHandoverReadiness(), + HandoverReadinessStatus().Add( + HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap)); + + EXPECT_THAT(reasm.FlushMessages(), SizeIs(2)); + EXPECT_EQ(reasm.GetHandoverReadiness(), + HandoverReadinessStatus().Add( + HandoverUnreadinessReason::kReassemblyQueueDeliveredTSNsGap)); + + reasm.Handle(ForwardTsnChunk(TSN(15), {})); + EXPECT_EQ(reasm.GetHandoverReadiness(), HandoverReadinessStatus()); +} + +TEST_F(ReassemblyQueueTest, HandoverInInitialState) { + ReassemblyQueue reasm1("log: ", TSN(10), kBufferSize); + + EXPECT_EQ(reasm1.GetHandoverReadiness(), HandoverReadinessStatus()); + DcSctpSocketHandoverState state; + reasm1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + ReassemblyQueue reasm2("log: ", TSN(100), kBufferSize, + /*use_message_interleaving=*/false); + reasm2.RestoreFromState(state); + + reasm2.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm2.FlushMessages(), SizeIs(1)); +} + +TEST_F(ReassemblyQueueTest, HandoverAfterHavingAssembedOneMessage) { + ReassemblyQueue reasm1("log: ", TSN(10), kBufferSize); + reasm1.Add(TSN(10), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm1.FlushMessages(), SizeIs(1)); + + EXPECT_EQ(reasm1.GetHandoverReadiness(), HandoverReadinessStatus()); + DcSctpSocketHandoverState state; + reasm1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + ReassemblyQueue reasm2("log: ", TSN(100), kBufferSize, + /*use_message_interleaving=*/false); + reasm2.RestoreFromState(state); + + reasm2.Add(TSN(11), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm2.FlushMessages(), SizeIs(1)); +} + +TEST_F(ReassemblyQueueTest, HandleInconsistentForwardTSN) { + // Found when fuzzing. + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize); + // Add TSN=43, SSN=7. Can't be reassembled as previous SSNs aren't known. + reasm.Add(TSN(43), Data(kStreamID, SSN(7), MID(0), FSN(0), kPPID, + std::vector<uint8_t>(10), Data::IsBeginning(true), + Data::IsEnd(true), IsUnordered(false))); + + // Invalid, as TSN=44 have to have SSN>=7, but peer says 6. + reasm.Handle(ForwardTsnChunk( + TSN(44), {ForwardTsnChunk::SkippedStream(kStreamID, SSN(6))})); + + // Don't assemble SSN=7, as that TSN is skipped. + EXPECT_FALSE(reasm.HasMessages()); +} + +TEST_F(ReassemblyQueueTest, SingleUnorderedChunkMessageInRfc8260) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize, + /*use_message_interleaving=*/true); + reasm.Add(TSN(10), Data(StreamID(1), SSN(0), MID(0), FSN(0), kPPID, + {1, 2, 3, 4}, Data::IsBeginning(true), + Data::IsEnd(true), IsUnordered(true))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kShortPayload))); +} + +TEST_F(ReassemblyQueueTest, TwoInterleavedChunks) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize, + /*use_message_interleaving=*/true); + reasm.Add(TSN(10), Data(StreamID(1), SSN(0), MID(0), FSN(0), kPPID, + {1, 2, 3, 4}, Data::IsBeginning(true), + Data::IsEnd(false), IsUnordered(true))); + reasm.Add(TSN(11), Data(StreamID(2), SSN(0), MID(0), FSN(0), kPPID, + {9, 10, 11, 12}, Data::IsBeginning(true), + Data::IsEnd(false), IsUnordered(true))); + EXPECT_EQ(reasm.queued_bytes(), 8u); + reasm.Add(TSN(12), Data(StreamID(1), SSN(0), MID(0), FSN(1), kPPID, + {5, 6, 7, 8}, Data::IsBeginning(false), + Data::IsEnd(true), IsUnordered(true))); + EXPECT_EQ(reasm.queued_bytes(), 4u); + reasm.Add(TSN(13), Data(StreamID(2), SSN(0), MID(0), FSN(1), kPPID, + {13, 14, 15, 16}, Data::IsBeginning(false), + Data::IsEnd(true), IsUnordered(true))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(StreamID(1), kPPID, kMediumPayload1), + SctpMessageIs(StreamID(2), kPPID, kMediumPayload2))); +} + +TEST_F(ReassemblyQueueTest, UnorderedInterleavedMessagesAllPermutations) { + std::vector<int> indexes = {0, 1, 2, 3, 4, 5}; + TSN tsns[] = {TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), TSN(15)}; + StreamID stream_ids[] = {StreamID(1), StreamID(2), StreamID(1), + StreamID(1), StreamID(2), StreamID(2)}; + FSN fsns[] = {FSN(0), FSN(0), FSN(1), FSN(2), FSN(1), FSN(2)}; + rtc::ArrayView<const uint8_t> payload(kSixBytePayload); + do { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize, + /*use_message_interleaving=*/true); + for (int i : indexes) { + auto span = payload.subview(*fsns[i] * 2, 2); + Data::IsBeginning is_beginning(fsns[i] == FSN(0)); + Data::IsEnd is_end(fsns[i] == FSN(2)); + reasm.Add(tsns[i], Data(stream_ids[i], SSN(0), MID(0), fsns[i], kPPID, + std::vector<uint8_t>(span.begin(), span.end()), + is_beginning, is_end, IsUnordered(true))); + } + EXPECT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), kPPID, kSixBytePayload), + SctpMessageIs(StreamID(2), kPPID, kSixBytePayload))); + EXPECT_EQ(reasm.queued_bytes(), 0u); + } while (std::next_permutation(std::begin(indexes), std::end(indexes))); +} + +TEST_F(ReassemblyQueueTest, IForwardTSNRemoveALotOrdered) { + ReassemblyQueue reasm("log: ", TSN(10), kBufferSize, + /*use_message_interleaving=*/true); + reasm.Add(TSN(10), gen_.Ordered({1}, "B")); + gen_.Ordered({2}, ""); + reasm.Add(TSN(12), gen_.Ordered({3}, "")); + reasm.Add(TSN(13), gen_.Ordered({4}, "E")); + reasm.Add(TSN(15), gen_.Ordered({5}, "B")); + reasm.Add(TSN(16), gen_.Ordered({6}, "")); + reasm.Add(TSN(17), gen_.Ordered({7}, "")); + reasm.Add(TSN(18), gen_.Ordered({8}, "E")); + + ASSERT_FALSE(reasm.HasMessages()); + EXPECT_EQ(reasm.queued_bytes(), 7u); + + reasm.Handle( + IForwardTsnChunk(TSN(13), {IForwardTsnChunk::SkippedStream( + IsUnordered(false), kStreamID, MID(0))})); + EXPECT_EQ(reasm.queued_bytes(), 0u); + + // The lost chunk comes, but too late. + ASSERT_TRUE(reasm.HasMessages()); + EXPECT_THAT(reasm.FlushMessages(), + ElementsAre(SctpMessageIs(kStreamID, kPPID, kMessage2Payload))); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.cc b/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.cc new file mode 100644 index 0000000000..9fd52fb15d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.cc @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/reassembly_streams.h" + +#include <cstddef> +#include <map> +#include <utility> + +namespace dcsctp { + +ReassembledMessage AssembleMessage(std::map<UnwrappedTSN, Data>::iterator start, + std::map<UnwrappedTSN, Data>::iterator end) { + size_t count = std::distance(start, end); + + if (count == 1) { + // Fast path - zero-copy + Data& data = start->second; + + return ReassembledMessage{ + .tsns = {start->first}, + .message = DcSctpMessage(data.stream_id, data.ppid, + std::move(start->second.payload)), + }; + } + + // Slow path - will need to concatenate the payload. + std::vector<UnwrappedTSN> tsns; + std::vector<uint8_t> payload; + + size_t payload_size = std::accumulate( + start, end, 0, + [](size_t v, const auto& p) { return v + p.second.size(); }); + + tsns.reserve(count); + payload.reserve(payload_size); + for (auto it = start; it != end; ++it) { + Data& data = it->second; + tsns.push_back(it->first); + payload.insert(payload.end(), data.payload.begin(), data.payload.end()); + } + + return ReassembledMessage{ + .tsns = std::move(tsns), + .message = DcSctpMessage(start->second.stream_id, start->second.ppid, + std::move(payload)), + }; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.h b/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.h new file mode 100644 index 0000000000..0ecfac0c0a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/reassembly_streams.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_REASSEMBLY_STREAMS_H_ +#define NET_DCSCTP_RX_REASSEMBLY_STREAMS_H_ + +#include <stddef.h> +#include <stdint.h> + +#include <functional> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" +#include "net/dcsctp/public/dcsctp_message.h" + +namespace dcsctp { + +// Implementations of this interface will be called when data is received, when +// data should be skipped/forgotten or when sequence number should be reset. +// +// As a result of these operations - mainly when data is received - the +// implementations of this interface should notify when a message has been +// assembled, by calling the provided callback of type `OnAssembledMessage`. How +// it assembles messages will depend on e.g. if a message was sent on an ordered +// or unordered stream. +// +// Implementations will - for each operation - indicate how much additional +// memory that has been used as a result of performing the operation. This is +// used to limit the maximum amount of memory used, to prevent out-of-memory +// situations. +class ReassemblyStreams { + public: + // This callback will be provided as an argument to the constructor of the + // concrete class implementing this interface and should be called when a + // message has been assembled as well as indicating from which TSNs this + // message was assembled from. + using OnAssembledMessage = + std::function<void(rtc::ArrayView<const UnwrappedTSN> tsns, + DcSctpMessage message)>; + + virtual ~ReassemblyStreams() = default; + + // Adds a data chunk to a stream as identified in `data`. + // If it was the last remaining chunk in a message, reassemble one (or + // several, in case of ordered chunks) messages. + // + // Returns the additional number of bytes added to the queue as a result of + // performing this operation. If this addition resulted in messages being + // assembled and delivered, this may be negative. + virtual int Add(UnwrappedTSN tsn, Data data) = 0; + + // Called for incoming FORWARD-TSN/I-FORWARD-TSN chunks - when the sender + // wishes the received to skip/forget about data up until the provided TSN. + // This is used to implement partial reliability, such as limiting the number + // of retransmissions or the an expiration duration. As a result of skipping + // data, this may result in the implementation being able to assemble messages + // in ordered streams. + // + // Returns the number of bytes removed from the queue as a result of + // this operation. + virtual size_t HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> + skipped_streams) = 0; + + // Called for incoming (possibly deferred) RE_CONFIG chunks asking for + // either a few streams, or all streams (when the list is empty) to be + // reset - to have their next SSN or Message ID to be zero. + virtual void ResetStreams(rtc::ArrayView<const StreamID> stream_ids) = 0; + + virtual HandoverReadinessStatus GetHandoverReadiness() const = 0; + virtual void AddHandoverState(DcSctpSocketHandoverState& state) = 0; + virtual void RestoreFromState(const DcSctpSocketHandoverState& state) = 0; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_REASSEMBLY_STREAMS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc new file mode 100644 index 0000000000..dce6c90131 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.cc @@ -0,0 +1,348 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/traditional_reassembly_streams.h" + +#include <stddef.h> + +#include <cstdint> +#include <functional> +#include <iterator> +#include <map> +#include <numeric> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace { + +// Given a map (`chunks`) and an iterator to within that map (`iter`), this +// function will return an iterator to the first chunk in that message, which +// has the `is_beginning` flag set. If there are any gaps, or if the beginning +// can't be found, `absl::nullopt` is returned. +absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindBeginning( + const std::map<UnwrappedTSN, Data>& chunks, + std::map<UnwrappedTSN, Data>::iterator iter) { + UnwrappedTSN prev_tsn = iter->first; + for (;;) { + if (iter->second.is_beginning) { + return iter; + } + if (iter == chunks.begin()) { + return absl::nullopt; + } + --iter; + if (iter->first.next_value() != prev_tsn) { + return absl::nullopt; + } + prev_tsn = iter->first; + } +} + +// Given a map (`chunks`) and an iterator to within that map (`iter`), this +// function will return an iterator to the chunk after the last chunk in that +// message, which has the `is_end` flag set. If there are any gaps, or if the +// end can't be found, `absl::nullopt` is returned. +absl::optional<std::map<UnwrappedTSN, Data>::iterator> FindEnd( + std::map<UnwrappedTSN, Data>& chunks, + std::map<UnwrappedTSN, Data>::iterator iter) { + UnwrappedTSN prev_tsn = iter->first; + for (;;) { + if (iter->second.is_end) { + return ++iter; + } + ++iter; + if (iter == chunks.end()) { + return absl::nullopt; + } + if (iter->first != prev_tsn.next_value()) { + return absl::nullopt; + } + prev_tsn = iter->first; + } +} +} // namespace + +TraditionalReassemblyStreams::TraditionalReassemblyStreams( + absl::string_view log_prefix, + OnAssembledMessage on_assembled_message) + : log_prefix_(log_prefix), + on_assembled_message_(std::move(on_assembled_message)) {} + +int TraditionalReassemblyStreams::UnorderedStream::Add(UnwrappedTSN tsn, + Data data) { + int queued_bytes = data.size(); + auto [it, inserted] = chunks_.emplace(tsn, std::move(data)); + if (!inserted) { + return 0; + } + + queued_bytes -= TryToAssembleMessage(it); + + return queued_bytes; +} + +size_t TraditionalReassemblyStreams::UnorderedStream::TryToAssembleMessage( + ChunkMap::iterator iter) { + // TODO(boivie): This method is O(N) with the number of fragments in a + // message, which can be inefficient for very large values of N. This could be + // optimized by e.g. only trying to assemble a message once _any_ beginning + // and _any_ end has been found. + absl::optional<ChunkMap::iterator> start = FindBeginning(chunks_, iter); + if (!start.has_value()) { + return 0; + } + absl::optional<ChunkMap::iterator> end = FindEnd(chunks_, iter); + if (!end.has_value()) { + return 0; + } + + size_t bytes_assembled = AssembleMessage(*start, *end); + chunks_.erase(*start, *end); + return bytes_assembled; +} + +size_t TraditionalReassemblyStreams::StreamBase::AssembleMessage( + const ChunkMap::iterator start, + const ChunkMap::iterator end) { + size_t count = std::distance(start, end); + + if (count == 1) { + // Fast path - zero-copy + const Data& data = start->second; + size_t payload_size = start->second.size(); + UnwrappedTSN tsns[1] = {start->first}; + DcSctpMessage message(data.stream_id, data.ppid, std::move(data.payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + return payload_size; + } + + // Slow path - will need to concatenate the payload. + std::vector<UnwrappedTSN> tsns; + std::vector<uint8_t> payload; + + size_t payload_size = std::accumulate( + start, end, 0, + [](size_t v, const auto& p) { return v + p.second.size(); }); + + tsns.reserve(count); + payload.reserve(payload_size); + for (auto it = start; it != end; ++it) { + const Data& data = it->second; + tsns.push_back(it->first); + payload.insert(payload.end(), data.payload.begin(), data.payload.end()); + } + + DcSctpMessage message(start->second.stream_id, start->second.ppid, + std::move(payload)); + parent_.on_assembled_message_(tsns, std::move(message)); + + return payload_size; +} + +size_t TraditionalReassemblyStreams::UnorderedStream::EraseTo( + UnwrappedTSN tsn) { + auto end_iter = chunks_.upper_bound(tsn); + size_t removed_bytes = std::accumulate( + chunks_.begin(), end_iter, 0, + [](size_t r, const auto& p) { return r + p.second.size(); }); + + chunks_.erase(chunks_.begin(), end_iter); + return removed_bytes; +} + +size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessage() { + if (chunks_by_ssn_.empty() || chunks_by_ssn_.begin()->first != next_ssn_) { + return 0; + } + + ChunkMap& chunks = chunks_by_ssn_.begin()->second; + + if (!chunks.begin()->second.is_beginning || !chunks.rbegin()->second.is_end) { + return 0; + } + + uint32_t tsn_diff = + UnwrappedTSN::Difference(chunks.rbegin()->first, chunks.begin()->first); + if (tsn_diff != chunks.size() - 1) { + return 0; + } + + size_t assembled_bytes = AssembleMessage(chunks.begin(), chunks.end()); + chunks_by_ssn_.erase(chunks_by_ssn_.begin()); + next_ssn_.Increment(); + return assembled_bytes; +} + +size_t TraditionalReassemblyStreams::OrderedStream::TryToAssembleMessages() { + size_t assembled_bytes = 0; + + for (;;) { + size_t assembled_bytes_this_iter = TryToAssembleMessage(); + if (assembled_bytes_this_iter == 0) { + break; + } + assembled_bytes += assembled_bytes_this_iter; + } + return assembled_bytes; +} + +int TraditionalReassemblyStreams::OrderedStream::Add(UnwrappedTSN tsn, + Data data) { + int queued_bytes = data.size(); + + UnwrappedSSN ssn = ssn_unwrapper_.Unwrap(data.ssn); + auto [unused, inserted] = chunks_by_ssn_[ssn].emplace(tsn, std::move(data)); + if (!inserted) { + return 0; + } + + if (ssn == next_ssn_) { + queued_bytes -= TryToAssembleMessages(); + } + + return queued_bytes; +} + +size_t TraditionalReassemblyStreams::OrderedStream::EraseTo(SSN ssn) { + UnwrappedSSN unwrapped_ssn = ssn_unwrapper_.Unwrap(ssn); + + auto end_iter = chunks_by_ssn_.upper_bound(unwrapped_ssn); + size_t removed_bytes = std::accumulate( + chunks_by_ssn_.begin(), end_iter, 0, [](size_t r1, const auto& p) { + return r1 + + absl::c_accumulate(p.second, 0, [](size_t r2, const auto& q) { + return r2 + q.second.size(); + }); + }); + chunks_by_ssn_.erase(chunks_by_ssn_.begin(), end_iter); + + if (unwrapped_ssn >= next_ssn_) { + unwrapped_ssn.Increment(); + next_ssn_ = unwrapped_ssn; + } + + removed_bytes += TryToAssembleMessages(); + return removed_bytes; +} + +int TraditionalReassemblyStreams::Add(UnwrappedTSN tsn, Data data) { + if (data.is_unordered) { + auto it = unordered_streams_.try_emplace(data.stream_id, this).first; + return it->second.Add(tsn, std::move(data)); + } + + auto it = ordered_streams_.try_emplace(data.stream_id, this).first; + return it->second.Add(tsn, std::move(data)); +} + +size_t TraditionalReassemblyStreams::HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) { + size_t bytes_removed = 0; + // The `skipped_streams` only cover ordered messages - need to + // iterate all unordered streams manually to remove those chunks. + for (auto& [unused, stream] : unordered_streams_) { + bytes_removed += stream.EraseTo(new_cumulative_ack_tsn); + } + + for (const auto& skipped_stream : skipped_streams) { + auto it = + ordered_streams_.try_emplace(skipped_stream.stream_id, this).first; + bytes_removed += it->second.EraseTo(skipped_stream.ssn); + } + + return bytes_removed; +} + +void TraditionalReassemblyStreams::ResetStreams( + rtc::ArrayView<const StreamID> stream_ids) { + if (stream_ids.empty()) { + for (auto& [stream_id, stream] : ordered_streams_) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Resetting implicit stream_id=" << *stream_id; + stream.Reset(); + } + } else { + for (StreamID stream_id : stream_ids) { + auto it = ordered_streams_.find(stream_id); + if (it != ordered_streams_.end()) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Resetting explicit stream_id=" << *stream_id; + it->second.Reset(); + } + } + } +} + +HandoverReadinessStatus TraditionalReassemblyStreams::GetHandoverReadiness() + const { + HandoverReadinessStatus status; + for (const auto& [unused, stream] : ordered_streams_) { + if (stream.has_unassembled_chunks()) { + status.Add(HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks); + break; + } + } + for (const auto& [unused, stream] : unordered_streams_) { + if (stream.has_unassembled_chunks()) { + status.Add( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks); + break; + } + } + return status; +} + +void TraditionalReassemblyStreams::AddHandoverState( + DcSctpSocketHandoverState& state) { + for (const auto& [stream_id, stream] : ordered_streams_) { + DcSctpSocketHandoverState::OrderedStream state_stream; + state_stream.id = stream_id.value(); + state_stream.next_ssn = stream.next_ssn().value(); + state.rx.ordered_streams.push_back(std::move(state_stream)); + } + for (const auto& [stream_id, unused] : unordered_streams_) { + DcSctpSocketHandoverState::UnorderedStream state_stream; + state_stream.id = stream_id.value(); + state.rx.unordered_streams.push_back(std::move(state_stream)); + } +} + +void TraditionalReassemblyStreams::RestoreFromState( + const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(ordered_streams_.empty()); + RTC_DCHECK(unordered_streams_.empty()); + + for (const DcSctpSocketHandoverState::OrderedStream& state_stream : + state.rx.ordered_streams) { + ordered_streams_.emplace( + std::piecewise_construct, + std::forward_as_tuple(StreamID(state_stream.id)), + std::forward_as_tuple(this, SSN(state_stream.next_ssn))); + } + for (const DcSctpSocketHandoverState::UnorderedStream& state_stream : + state.rx.unordered_streams) { + unordered_streams_.emplace(std::piecewise_construct, + std::forward_as_tuple(StreamID(state_stream.id)), + std::forward_as_tuple(this)); + } +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h new file mode 100644 index 0000000000..4825afd1ba --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams.h @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_RX_TRADITIONAL_REASSEMBLY_STREAMS_H_ +#define NET_DCSCTP_RX_TRADITIONAL_REASSEMBLY_STREAMS_H_ +#include <stddef.h> +#include <stdint.h> + +#include <map> +#include <string> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" + +namespace dcsctp { + +// Handles reassembly of incoming data when interleaved message sending +// is not enabled on the association, i.e. when RFC8260 is not in use and +// RFC4960 is to be followed. +class TraditionalReassemblyStreams : public ReassemblyStreams { + public: + TraditionalReassemblyStreams(absl::string_view log_prefix, + OnAssembledMessage on_assembled_message); + + int Add(UnwrappedTSN tsn, Data data) override; + + size_t HandleForwardTsn( + UnwrappedTSN new_cumulative_ack_tsn, + rtc::ArrayView<const AnyForwardTsnChunk::SkippedStream> skipped_streams) + override; + + void ResetStreams(rtc::ArrayView<const StreamID> stream_ids) override; + + HandoverReadinessStatus GetHandoverReadiness() const override; + void AddHandoverState(DcSctpSocketHandoverState& state) override; + void RestoreFromState(const DcSctpSocketHandoverState& state) override; + + private: + using ChunkMap = std::map<UnwrappedTSN, Data>; + + // Base class for `UnorderedStream` and `OrderedStream`. + class StreamBase { + protected: + explicit StreamBase(TraditionalReassemblyStreams* parent) + : parent_(*parent) {} + + size_t AssembleMessage(ChunkMap::iterator start, ChunkMap::iterator end); + TraditionalReassemblyStreams& parent_; + }; + + // Manages all received data for a specific unordered stream, and assembles + // messages when possible. + class UnorderedStream : StreamBase { + public: + explicit UnorderedStream(TraditionalReassemblyStreams* parent) + : StreamBase(parent) {} + int Add(UnwrappedTSN tsn, Data data); + // Returns the number of bytes removed from the queue. + size_t EraseTo(UnwrappedTSN tsn); + bool has_unassembled_chunks() const { return !chunks_.empty(); } + + private: + // Given an iterator to any chunk within the map, try to assemble a message + // into `reassembled_messages` containing it and - if successful - erase + // those chunks from the stream chunks map. + // + // Returns the number of bytes that were assembled. + size_t TryToAssembleMessage(ChunkMap::iterator iter); + + ChunkMap chunks_; + }; + + // Manages all received data for a specific ordered stream, and assembles + // messages when possible. + class OrderedStream : StreamBase { + public: + explicit OrderedStream(TraditionalReassemblyStreams* parent, + SSN next_ssn = SSN(0)) + : StreamBase(parent), next_ssn_(ssn_unwrapper_.Unwrap(next_ssn)) {} + int Add(UnwrappedTSN tsn, Data data); + size_t EraseTo(SSN ssn); + void Reset() { + ssn_unwrapper_.Reset(); + next_ssn_ = ssn_unwrapper_.Unwrap(SSN(0)); + } + SSN next_ssn() const { return next_ssn_.Wrap(); } + bool has_unassembled_chunks() const { return !chunks_by_ssn_.empty(); } + + private: + // Try to assemble one or several messages in order from the stream. + // Returns the number of bytes assembled if a message was assembled. + size_t TryToAssembleMessage(); + size_t TryToAssembleMessages(); + // This must be an ordered container to be able to iterate in SSN order. + std::map<UnwrappedSSN, ChunkMap> chunks_by_ssn_; + UnwrappedSSN::Unwrapper ssn_unwrapper_; + UnwrappedSSN next_ssn_; + }; + + const std::string log_prefix_; + + // Callback for when a message has been assembled. + const OnAssembledMessage on_assembled_message_; + + // All unordered and ordered streams, managing not-yet-assembled data. + std::map<StreamID, UnorderedStream> unordered_streams_; + std::map<StreamID, OrderedStream> ordered_streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_RX_TRADITIONAL_REASSEMBLY_STREAMS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams_test.cc b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams_test.cc new file mode 100644 index 0000000000..341870442d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/rx/traditional_reassembly_streams_test.cc @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/rx/traditional_reassembly_streams.h" + +#include <cstdint> +#include <memory> +#include <utility> + +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/rx/reassembly_streams.h" +#include "net/dcsctp/testing/data_generator.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::MockFunction; +using ::testing::NiceMock; +using ::testing::Property; + +class TraditionalReassemblyStreamsTest : public testing::Test { + protected: + UnwrappedTSN tsn(uint32_t value) { return tsn_.Unwrap(TSN(value)); } + + TraditionalReassemblyStreamsTest() {} + DataGenerator gen_; + UnwrappedTSN::Unwrapper tsn_; +}; + +TEST_F(TraditionalReassemblyStreamsTest, + AddUnorderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + // Adding the end fragment should make it empty again. + EXPECT_EQ(streams.Add(tsn(4), gen_.Unordered({7}, "E")), -6); +} + +TEST_F(TraditionalReassemblyStreamsTest, + AddSimpleOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), -6); +} + +TEST_F(TraditionalReassemblyStreamsTest, + AddMoreComplexOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + Data late = gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + EXPECT_EQ(streams.Add(tsn(2), std::move(late)), -8); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteUnorderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Unordered({5, 6})), 2); + + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), {}), 6u); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteSimpleOrderedMessageReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(3), skipped), 6u); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteManyOrderedMessagesReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // Expire all three messages + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(2))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(8), skipped), 8u); +} + +TEST_F(TraditionalReassemblyStreamsTest, + DeleteOrderedMessageDelivesTwoReturnsCorrectSize) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + gen_.Ordered({2, 3, 4}); + EXPECT_EQ(streams.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams.Add(tsn(4), gen_.Ordered({7}, "E")), 1); + + EXPECT_EQ(streams.Add(tsn(5), gen_.Ordered({1}, "BE")), 1); + EXPECT_EQ(streams.Add(tsn(6), gen_.Ordered({5, 6}, "B")), 2); + EXPECT_EQ(streams.Add(tsn(7), gen_.Ordered({7}, "E")), 1); + + // The first ordered message expire, and the following two are delivered. + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(4), skipped), 8u); +} + +TEST_F(TraditionalReassemblyStreamsTest, NoStreamsCanBeHandedOver) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams1("", on_assembled.AsStdFunction()); + EXPECT_TRUE(streams1.GetHandoverReadiness().IsReady()); + + DcSctpSocketHandoverState state; + streams1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction()); + streams2.RestoreFromState(state); + + EXPECT_EQ(streams2.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams2.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams2.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ(streams2.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); +} + +TEST_F(TraditionalReassemblyStreamsTest, + OrderedStreamsCanBeHandedOverWhenNoUnassembledChunksExist) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams1("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams1.Add(tsn(1), gen_.Ordered({1}, "B")), 1); + EXPECT_EQ(streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks)); + EXPECT_EQ(streams1.Add(tsn(2), gen_.Ordered({2, 3, 4})), 3); + EXPECT_EQ(streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks)); + EXPECT_EQ(streams1.Add(tsn(3), gen_.Ordered({5, 6})), 2); + EXPECT_EQ(streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kOrderedStreamHasUnassembledChunks)); + + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams1.HandleForwardTsn(tsn(3), skipped), 6u); + EXPECT_TRUE(streams1.GetHandoverReadiness().IsReady()); + + DcSctpSocketHandoverState state; + streams1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction()); + streams2.RestoreFromState(state); + EXPECT_EQ(streams2.Add(tsn(4), gen_.Ordered({7})), 1); +} + +TEST_F(TraditionalReassemblyStreamsTest, + UnorderedStreamsCanBeHandedOverWhenNoUnassembledChunksExist) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + + TraditionalReassemblyStreams streams1("", on_assembled.AsStdFunction()); + + EXPECT_EQ(streams1.Add(tsn(1), gen_.Unordered({1}, "B")), 1); + EXPECT_EQ( + streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + EXPECT_EQ(streams1.Add(tsn(2), gen_.Unordered({2, 3, 4})), 3); + EXPECT_EQ( + streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + EXPECT_EQ(streams1.Add(tsn(3), gen_.Unordered({5, 6})), 2); + EXPECT_EQ( + streams1.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kUnorderedStreamHasUnassembledChunks)); + + EXPECT_EQ(streams1.HandleForwardTsn(tsn(3), {}), 6u); + EXPECT_TRUE(streams1.GetHandoverReadiness().IsReady()); + + DcSctpSocketHandoverState state; + streams1.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + TraditionalReassemblyStreams streams2("", on_assembled.AsStdFunction()); + streams2.RestoreFromState(state); + EXPECT_EQ(streams2.Add(tsn(4), gen_.Unordered({7})), 1); +} + +TEST_F(TraditionalReassemblyStreamsTest, CanDeleteFirstOrderedMessage) { + NiceMock<MockFunction<ReassemblyStreams::OnAssembledMessage>> on_assembled; + EXPECT_CALL(on_assembled, + Call(ElementsAre(tsn(2)), + Property(&DcSctpMessage::payload, ElementsAre(2, 3, 4)))); + + TraditionalReassemblyStreams streams("", on_assembled.AsStdFunction()); + + // Not received, SID=1. TSN=1, SSN=0 + gen_.Ordered({1}, "BE"); + // And deleted (SID=1, TSN=1, SSN=0) + ForwardTsnChunk::SkippedStream skipped[] = { + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(0))}; + EXPECT_EQ(streams.HandleForwardTsn(tsn(1), skipped), 0u); + + // Receive SID=1, TSN=2, SSN=1 + EXPECT_EQ(streams.Add(tsn(2), gen_.Ordered({2, 3, 4}, "BE")), 0); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn new file mode 100644 index 0000000000..92ce413d0d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/BUILD.gn @@ -0,0 +1,278 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("context") { + sources = [ "context.h" ] + deps = [ + "../common:internal_types", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_library("heartbeat_handler") { + deps = [ + ":context", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../packet:bounded_io", + "../packet:chunk", + "../packet:parameter", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../timer", + ] + sources = [ + "heartbeat_handler.cc", + "heartbeat_handler.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("stream_reset_handler") { + deps = [ + ":context", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base/containers:flat_set", + "../common:internal_types", + "../common:str_join", + "../packet:chunk", + "../packet:parameter", + "../packet:sctp_packet", + "../packet:tlv_trait", + "../public:socket", + "../public:types", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../timer", + "../tx:retransmission_queue", + ] + sources = [ + "stream_reset_handler.cc", + "stream_reset_handler.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("packet_sender") { + deps = [ + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../timer", + ] + sources = [ + "packet_sender.cc", + "packet_sender.h", + ] + absl_deps = [] +} + +rtc_library("transmission_control_block") { + deps = [ + ":context", + ":heartbeat_handler", + ":packet_sender", + ":stream_reset_handler", + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../timer", + "../tx:retransmission_error_counter", + "../tx:retransmission_queue", + "../tx:retransmission_timeout", + "../tx:send_queue", + ] + sources = [ + "capabilities.h", + "transmission_control_block.cc", + "transmission_control_block.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("dcsctp_socket") { + deps = [ + ":context", + ":heartbeat_handler", + ":packet_sender", + ":stream_reset_handler", + ":transmission_control_block", + "../../../api:array_view", + "../../../api:make_ref_counted", + "../../../api:refcountedbase", + "../../../api:scoped_refptr", + "../../../api:sequence_checker", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:internal_types", + "../packet:bounded_io", + "../packet:chunk", + "../packet:chunk_validators", + "../packet:data", + "../packet:error_cause", + "../packet:parameter", + "../packet:sctp_packet", + "../packet:tlv_trait", + "../public:socket", + "../public:types", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../timer", + "../tx:retransmission_error_counter", + "../tx:retransmission_queue", + "../tx:retransmission_timeout", + "../tx:rr_send_queue", + "../tx:send_queue", + ] + sources = [ + "callback_deferrer.cc", + "callback_deferrer.h", + "dcsctp_socket.cc", + "dcsctp_socket.h", + "state_cookie.cc", + "state_cookie.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/functional:bind_front", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_source_set("mock_callbacks") { + testonly = true + sources = [ "mock_dcsctp_socket_callbacks.h" ] + deps = [ + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:logging", + "../../../rtc_base:random", + "../../../test:test_support", + "../public:socket", + "../public:types", + "../timer", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + } + + rtc_source_set("mock_context") { + testonly = true + sources = [ "mock_context.h" ] + deps = [ + ":context", + ":mock_callbacks", + "../../../test:test_support", + "../common:internal_types", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + } + + rtc_library("dcsctp_socket_unittests") { + testonly = true + + deps = [ + ":dcsctp_socket", + ":heartbeat_handler", + ":mock_callbacks", + ":mock_context", + ":packet_sender", + ":stream_reset_handler", + "../../../api:array_view", + "../../../api:create_network_emulation_manager", + "../../../api:network_emulation_manager_api", + "../../../api/task_queue", + "../../../api/task_queue:pending_task_safety_flag", + "../../../api/units:time_delta", + "../../../call:simulated_network", + "../../../rtc_base:checks", + "../../../rtc_base:copy_on_write_buffer", + "../../../rtc_base:gunit_helpers", + "../../../rtc_base:logging", + "../../../rtc_base:rtc_base_tests_utils", + "../../../rtc_base:socket_address", + "../../../rtc_base:stringutils", + "../../../rtc_base:timeutils", + "../../../test:test_support", + "../common:handover_testing", + "../common:internal_types", + "../packet:chunk", + "../packet:error_cause", + "../packet:parameter", + "../packet:sctp_packet", + "../packet:tlv_trait", + "../public:socket", + "../public:types", + "../public:utils", + "../rx:data_tracker", + "../rx:reassembly_queue", + "../testing:data_generator", + "../testing:testing_macros", + "../timer", + "../timer:task_queue_timeout", + "../tx:mock_send_queue", + "../tx:retransmission_queue", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/flags:flag", + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] + sources = [ + "dcsctp_socket_network_test.cc", + "dcsctp_socket_test.cc", + "heartbeat_handler_test.cc", + "packet_sender_test.cc", + "state_cookie_test.cc", + "stream_reset_handler_test.cc", + ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/socket/DEPS b/third_party/libwebrtc/net/dcsctp/socket/DEPS new file mode 100644 index 0000000000..d4966290e3 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/DEPS @@ -0,0 +1,5 @@ +specific_include_rules = { + "dcsctp_socket_network_test.cc": [ + "+call", + ] +} diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc new file mode 100644 index 0000000000..123526e782 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.cc @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/callback_deferrer.h" + +#include "api/make_ref_counted.h" + +namespace dcsctp { +namespace { +// A wrapper around the move-only DcSctpMessage, to let it be captured in a +// lambda. +class MessageDeliverer { + public: + explicit MessageDeliverer(DcSctpMessage&& message) + : state_(rtc::make_ref_counted<State>(std::move(message))) {} + + void Deliver(DcSctpSocketCallbacks& c) { + // Really ensure that it's only called once. + RTC_DCHECK(!state_->has_delivered); + state_->has_delivered = true; + c.OnMessageReceived(std::move(state_->message)); + } + + private: + struct State : public rtc::RefCountInterface { + explicit State(DcSctpMessage&& m) + : has_delivered(false), message(std::move(m)) {} + bool has_delivered; + DcSctpMessage message; + }; + rtc::scoped_refptr<State> state_; +}; +} // namespace + +void CallbackDeferrer::Prepare() { + RTC_DCHECK(!prepared_); + prepared_ = true; +} + +void CallbackDeferrer::TriggerDeferred() { + // Need to swap here. The client may call into the library from within a + // callback, and that might result in adding new callbacks to this instance, + // and the vector can't be modified while iterated on. + RTC_DCHECK(prepared_); + std::vector<std::function<void(DcSctpSocketCallbacks & cb)>> deferred; + deferred.swap(deferred_); + prepared_ = false; + + for (auto& cb : deferred) { + cb(underlying_); + } +} + +SendPacketStatus CallbackDeferrer::SendPacketWithStatus( + rtc::ArrayView<const uint8_t> data) { + // Will not be deferred - call directly. + return underlying_.SendPacketWithStatus(data); +} + +std::unique_ptr<Timeout> CallbackDeferrer::CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) { + // Will not be deferred - call directly. + return underlying_.CreateTimeout(precision); +} + +TimeMs CallbackDeferrer::TimeMillis() { + // Will not be deferred - call directly. + return underlying_.TimeMillis(); +} + +uint32_t CallbackDeferrer::GetRandomInt(uint32_t low, uint32_t high) { + // Will not be deferred - call directly. + return underlying_.GetRandomInt(low, high); +} + +void CallbackDeferrer::OnMessageReceived(DcSctpMessage message) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [deliverer = MessageDeliverer(std::move(message))]( + DcSctpSocketCallbacks& cb) mutable { deliverer.Deliver(cb); }); +} + +void CallbackDeferrer::OnError(ErrorKind error, absl::string_view message) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { + cb.OnError(error, message); + }); +} + +void CallbackDeferrer::OnAborted(ErrorKind error, absl::string_view message) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [error, message = std::string(message)](DcSctpSocketCallbacks& cb) { + cb.OnAborted(error, message); + }); +} + +void CallbackDeferrer::OnConnected() { + RTC_DCHECK(prepared_); + deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnConnected(); }); +} + +void CallbackDeferrer::OnClosed() { + RTC_DCHECK(prepared_); + deferred_.emplace_back([](DcSctpSocketCallbacks& cb) { cb.OnClosed(); }); +} + +void CallbackDeferrer::OnConnectionRestarted() { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [](DcSctpSocketCallbacks& cb) { cb.OnConnectionRestarted(); }); +} + +void CallbackDeferrer::OnStreamsResetFailed( + rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [streams = std::vector<StreamID>(outgoing_streams.begin(), + outgoing_streams.end()), + reason = std::string(reason)](DcSctpSocketCallbacks& cb) { + cb.OnStreamsResetFailed(streams, reason); + }); +} + +void CallbackDeferrer::OnStreamsResetPerformed( + rtc::ArrayView<const StreamID> outgoing_streams) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [streams = std::vector<StreamID>(outgoing_streams.begin(), + outgoing_streams.end())]( + DcSctpSocketCallbacks& cb) { cb.OnStreamsResetPerformed(streams); }); +} + +void CallbackDeferrer::OnIncomingStreamsReset( + rtc::ArrayView<const StreamID> incoming_streams) { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [streams = std::vector<StreamID>(incoming_streams.begin(), + incoming_streams.end())]( + DcSctpSocketCallbacks& cb) { cb.OnIncomingStreamsReset(streams); }); +} + +void CallbackDeferrer::OnBufferedAmountLow(StreamID stream_id) { + RTC_DCHECK(prepared_); + deferred_.emplace_back([stream_id](DcSctpSocketCallbacks& cb) { + cb.OnBufferedAmountLow(stream_id); + }); +} + +void CallbackDeferrer::OnTotalBufferedAmountLow() { + RTC_DCHECK(prepared_); + deferred_.emplace_back( + [](DcSctpSocketCallbacks& cb) { cb.OnTotalBufferedAmountLow(); }); +} + +void CallbackDeferrer::OnLifecycleMessageExpired(LifecycleId lifecycle_id, + bool maybe_delivered) { + // Will not be deferred - call directly. + underlying_.OnLifecycleMessageExpired(lifecycle_id, maybe_delivered); +} +void CallbackDeferrer::OnLifecycleMessageFullySent(LifecycleId lifecycle_id) { + // Will not be deferred - call directly. + underlying_.OnLifecycleMessageFullySent(lifecycle_id); +} +void CallbackDeferrer::OnLifecycleMessageDelivered(LifecycleId lifecycle_id) { + // Will not be deferred - call directly. + underlying_.OnLifecycleMessageDelivered(lifecycle_id); +} +void CallbackDeferrer::OnLifecycleEnd(LifecycleId lifecycle_id) { + // Will not be deferred - call directly. + underlying_.OnLifecycleEnd(lifecycle_id); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h new file mode 100644 index 0000000000..1c35dda6cf --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/callback_deferrer.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_ +#define NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_ + +#include <cstdint> +#include <functional> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "api/ref_counted_base.h" +#include "api/scoped_refptr.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" + +namespace dcsctp { +// Defers callbacks until they can be safely triggered. +// +// There are a lot of callbacks from the dcSCTP library to the client, +// such as when messages are received or streams are closed. When the client +// receives these callbacks, the client is expected to be able to call into the +// library - from within the callback. For example, sending a reply message when +// a certain SCTP message has been received, or to reconnect when the connection +// was closed for any reason. This means that the dcSCTP library must always be +// in a consistent and stable state when these callbacks are delivered, and to +// ensure that's the case, callbacks are not immediately delivered from where +// they originate, but instead queued (deferred) by this class. At the end of +// any public API method that may result in callbacks, they are triggered and +// then delivered. +// +// There are a number of exceptions, which is clearly annotated in the API. +class CallbackDeferrer : public DcSctpSocketCallbacks { + public: + class ScopedDeferrer { + public: + explicit ScopedDeferrer(CallbackDeferrer& callback_deferrer) + : callback_deferrer_(callback_deferrer) { + callback_deferrer_.Prepare(); + } + + ~ScopedDeferrer() { callback_deferrer_.TriggerDeferred(); } + + private: + CallbackDeferrer& callback_deferrer_; + }; + + explicit CallbackDeferrer(DcSctpSocketCallbacks& underlying) + : underlying_(underlying) {} + + // Implementation of DcSctpSocketCallbacks + SendPacketStatus SendPacketWithStatus( + rtc::ArrayView<const uint8_t> data) override; + std::unique_ptr<Timeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) override; + TimeMs TimeMillis() override; + uint32_t GetRandomInt(uint32_t low, uint32_t high) override; + void OnMessageReceived(DcSctpMessage message) override; + void OnError(ErrorKind error, absl::string_view message) override; + void OnAborted(ErrorKind error, absl::string_view message) override; + void OnConnected() override; + void OnClosed() override; + void OnConnectionRestarted() override; + void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason) override; + void OnStreamsResetPerformed( + rtc::ArrayView<const StreamID> outgoing_streams) override; + void OnIncomingStreamsReset( + rtc::ArrayView<const StreamID> incoming_streams) override; + void OnBufferedAmountLow(StreamID stream_id) override; + void OnTotalBufferedAmountLow() override; + + void OnLifecycleMessageExpired(LifecycleId lifecycle_id, + bool maybe_delivered) override; + void OnLifecycleMessageFullySent(LifecycleId lifecycle_id) override; + void OnLifecycleMessageDelivered(LifecycleId lifecycle_id) override; + void OnLifecycleEnd(LifecycleId lifecycle_id) override; + + private: + void Prepare(); + void TriggerDeferred(); + + DcSctpSocketCallbacks& underlying_; + bool prepared_ = false; + std::vector<std::function<void(DcSctpSocketCallbacks& cb)>> deferred_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_CALLBACK_DEFERRER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/capabilities.h b/third_party/libwebrtc/net/dcsctp/socket/capabilities.h new file mode 100644 index 0000000000..c6d3692b2d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/capabilities.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_CAPABILITIES_H_ +#define NET_DCSCTP_SOCKET_CAPABILITIES_H_ + +namespace dcsctp { +// Indicates what the association supports, meaning that both parties +// support it and that feature can be used. +struct Capabilities { + // RFC3758 Partial Reliability Extension + bool partial_reliability = false; + // RFC8260 Stream Schedulers and User Message Interleaving + bool message_interleaving = false; + // RFC6525 Stream Reconfiguration + bool reconfig = false; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_CAPABILITIES_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/context.h b/third_party/libwebrtc/net/dcsctp/socket/context.h new file mode 100644 index 0000000000..eca5b9e4fb --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/context.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_CONTEXT_H_ +#define NET_DCSCTP_SOCKET_CONTEXT_H_ + +#include <cstdint> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// A set of helper methods used by handlers to e.g. send packets. +// +// Implemented by the TransmissionControlBlock. +class Context { + public: + virtual ~Context() = default; + + // Indicates if a connection has been established. + virtual bool is_connection_established() const = 0; + + // Returns this side's initial TSN value. + virtual TSN my_initial_tsn() const = 0; + + // Returns the peer's initial TSN value. + virtual TSN peer_initial_tsn() const = 0; + + // Returns the socket callbacks. + virtual DcSctpSocketCallbacks& callbacks() const = 0; + + // Observes a measured RTT value, in milliseconds. + virtual void ObserveRTT(DurationMs rtt_ms) = 0; + + // Returns the current Retransmission Timeout (rto) value, in milliseconds. + virtual DurationMs current_rto() const = 0; + + // Increments the transmission error counter, given a human readable reason. + virtual bool IncrementTxErrorCounter(absl::string_view reason) = 0; + + // Clears the transmission error counter. + virtual void ClearTxErrorCounter() = 0; + + // Returns true if there have been too many retransmission errors. + virtual bool HasTooManyTxErrors() const = 0; + + // Returns a PacketBuilder, filled in with the correct verification tag. + virtual SctpPacket::Builder PacketBuilder() const = 0; + + // Builds the packet from `builder` and sends it. + virtual void Send(SctpPacket::Builder& builder) = 0; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_CONTEXT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc new file mode 100644 index 0000000000..53838193ec --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.cc @@ -0,0 +1,1752 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/dcsctp_socket.h" + +#include <algorithm> +#include <cstdint> +#include <limits> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" +#include "net/dcsctp/packet/chunk_validators.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/error_cause/cookie_received_while_shutting_down_cause.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/no_user_data_cause.h" +#include "net/dcsctp/packet/error_cause/out_of_resource_error_cause.h" +#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h" +#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/state_cookie_parameter.h" +#include "net/dcsctp/packet/parameter/supported_extensions_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/callback_deferrer.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/heartbeat_handler.h" +#include "net/dcsctp/socket/state_cookie.h" +#include "net/dcsctp/socket/stream_reset_handler.h" +#include "net/dcsctp/socket/transmission_control_block.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" +#include "rtc_base/strings/string_format.h" + +namespace dcsctp { +namespace { + +// https://tools.ietf.org/html/rfc4960#section-5.1 +constexpr uint32_t kMinVerificationTag = 1; +constexpr uint32_t kMaxVerificationTag = std::numeric_limits<uint32_t>::max(); + +// https://tools.ietf.org/html/rfc4960#section-3.3.2 +constexpr uint32_t kMinInitialTsn = 0; +constexpr uint32_t kMaxInitialTsn = std::numeric_limits<uint32_t>::max(); + +Capabilities GetCapabilities(const DcSctpOptions& options, + const Parameters& parameters) { + Capabilities capabilities; + absl::optional<SupportedExtensionsParameter> supported_extensions = + parameters.get<SupportedExtensionsParameter>(); + + if (options.enable_partial_reliability) { + capabilities.partial_reliability = + parameters.get<ForwardTsnSupportedParameter>().has_value(); + if (supported_extensions.has_value()) { + capabilities.partial_reliability |= + supported_extensions->supports(ForwardTsnChunk::kType); + } + } + + if (options.enable_message_interleaving && supported_extensions.has_value()) { + capabilities.message_interleaving = + supported_extensions->supports(IDataChunk::kType) && + supported_extensions->supports(IForwardTsnChunk::kType); + } + if (supported_extensions.has_value() && + supported_extensions->supports(ReConfigChunk::kType)) { + capabilities.reconfig = true; + } + return capabilities; +} + +void AddCapabilityParameters(const DcSctpOptions& options, + Parameters::Builder& builder) { + std::vector<uint8_t> chunk_types = {ReConfigChunk::kType}; + + if (options.enable_partial_reliability) { + builder.Add(ForwardTsnSupportedParameter()); + chunk_types.push_back(ForwardTsnChunk::kType); + } + if (options.enable_message_interleaving) { + chunk_types.push_back(IDataChunk::kType); + chunk_types.push_back(IForwardTsnChunk::kType); + } + builder.Add(SupportedExtensionsParameter(std::move(chunk_types))); +} + +TieTag MakeTieTag(DcSctpSocketCallbacks& cb) { + uint32_t tie_tag_upper = + cb.GetRandomInt(0, std::numeric_limits<uint32_t>::max()); + uint32_t tie_tag_lower = + cb.GetRandomInt(1, std::numeric_limits<uint32_t>::max()); + return TieTag(static_cast<uint64_t>(tie_tag_upper) << 32 | + static_cast<uint64_t>(tie_tag_lower)); +} + +SctpImplementation DeterminePeerImplementation( + rtc::ArrayView<const uint8_t> cookie) { + if (cookie.size() > 8) { + absl::string_view magic(reinterpret_cast<const char*>(cookie.data()), 8); + if (magic == "dcSCTP00") { + return SctpImplementation::kDcsctp; + } + if (magic == "KAME-BSD") { + return SctpImplementation::kUsrSctp; + } + } + return SctpImplementation::kOther; +} +} // namespace + +DcSctpSocket::DcSctpSocket(absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr<PacketObserver> packet_observer, + const DcSctpOptions& options) + : log_prefix_(std::string(log_prefix) + ": "), + packet_observer_(std::move(packet_observer)), + options_(options), + callbacks_(callbacks), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return callbacks_.CreateTimeout(precision); + }), + t1_init_(timer_manager_.CreateTimer( + "t1-init", + absl::bind_front(&DcSctpSocket::OnInitTimerExpiry, this), + TimerOptions(options.t1_init_timeout, + TimerBackoffAlgorithm::kExponential, + options.max_init_retransmits))), + t1_cookie_(timer_manager_.CreateTimer( + "t1-cookie", + absl::bind_front(&DcSctpSocket::OnCookieTimerExpiry, this), + TimerOptions(options.t1_cookie_timeout, + TimerBackoffAlgorithm::kExponential, + options.max_init_retransmits))), + t2_shutdown_(timer_manager_.CreateTimer( + "t2-shutdown", + absl::bind_front(&DcSctpSocket::OnShutdownTimerExpiry, this), + TimerOptions(options.t2_shutdown_timeout, + TimerBackoffAlgorithm::kExponential, + options.max_retransmissions))), + packet_sender_(callbacks_, + absl::bind_front(&DcSctpSocket::OnSentPacket, this)), + send_queue_(log_prefix_, + &callbacks_, + options_.max_send_buffer_size, + options_.mtu, + options_.default_stream_priority, + options_.total_buffered_amount_low_threshold) {} + +std::string DcSctpSocket::log_prefix() const { + return log_prefix_ + "[" + std::string(ToString(state_)) + "] "; +} + +bool DcSctpSocket::IsConsistent() const { + if (tcb_ != nullptr && tcb_->reassembly_queue().HasMessages()) { + return false; + } + switch (state_) { + case State::kClosed: + return (tcb_ == nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kCookieWait: + return (tcb_ == nullptr && t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kCookieEchoed: + return (tcb_ != nullptr && !t1_init_->is_running() && + t1_cookie_->is_running() && !t2_shutdown_->is_running() && + tcb_->has_cookie_echo_chunk()); + case State::kEstablished: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kShutdownPending: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kShutdownSent: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && t2_shutdown_->is_running()); + case State::kShutdownReceived: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && !t2_shutdown_->is_running()); + case State::kShutdownAckSent: + return (tcb_ != nullptr && !t1_init_->is_running() && + !t1_cookie_->is_running() && t2_shutdown_->is_running()); + } +} + +constexpr absl::string_view DcSctpSocket::ToString(DcSctpSocket::State state) { + switch (state) { + case DcSctpSocket::State::kClosed: + return "CLOSED"; + case DcSctpSocket::State::kCookieWait: + return "COOKIE_WAIT"; + case DcSctpSocket::State::kCookieEchoed: + return "COOKIE_ECHOED"; + case DcSctpSocket::State::kEstablished: + return "ESTABLISHED"; + case DcSctpSocket::State::kShutdownPending: + return "SHUTDOWN_PENDING"; + case DcSctpSocket::State::kShutdownSent: + return "SHUTDOWN_SENT"; + case DcSctpSocket::State::kShutdownReceived: + return "SHUTDOWN_RECEIVED"; + case DcSctpSocket::State::kShutdownAckSent: + return "SHUTDOWN_ACK_SENT"; + } +} + +void DcSctpSocket::SetState(State state, absl::string_view reason) { + if (state_ != state) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Socket state changed from " + << ToString(state_) << " to " << ToString(state) + << " due to " << reason; + state_ = state; + } +} + +void DcSctpSocket::SendInit() { + Parameters::Builder params_builder; + AddCapabilityParameters(options_, params_builder); + InitChunk init(/*initiate_tag=*/connect_params_.verification_tag, + /*a_rwnd=*/options_.max_receiver_window_buffer_size, + options_.announced_maximum_outgoing_streams, + options_.announced_maximum_incoming_streams, + connect_params_.initial_tsn, params_builder.Build()); + SctpPacket::Builder b(VerificationTag(0), options_); + b.Add(init); + packet_sender_.Send(b); +} + +void DcSctpSocket::MakeConnectionParameters() { + VerificationTag new_verification_tag( + callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); + TSN initial_tsn(callbacks_.GetRandomInt(kMinInitialTsn, kMaxInitialTsn)); + connect_params_.initial_tsn = initial_tsn; + connect_params_.verification_tag = new_verification_tag; +} + +void DcSctpSocket::Connect() { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (state_ == State::kClosed) { + MakeConnectionParameters(); + RTC_DLOG(LS_INFO) + << log_prefix() + << rtc::StringFormat( + "Connecting. my_verification_tag=%08x, my_initial_tsn=%u", + *connect_params_.verification_tag, *connect_params_.initial_tsn); + SendInit(); + t1_init_->Start(); + SetState(State::kCookieWait, "Connect called"); + } else { + RTC_DLOG(LS_WARNING) << log_prefix() + << "Called Connect on a socket that is not closed"; + } + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::CreateTransmissionControlBlock( + const Capabilities& capabilities, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag) { + metrics_.uses_message_interleaving = capabilities.message_interleaving; + tcb_ = std::make_unique<TransmissionControlBlock>( + timer_manager_, log_prefix_, options_, capabilities, callbacks_, + send_queue_, my_verification_tag, my_initial_tsn, peer_verification_tag, + peer_initial_tsn, a_rwnd, tie_tag, packet_sender_, + [this]() { return state_ == State::kEstablished; }); + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Created TCB: " << tcb_->ToString(); +} + +void DcSctpSocket::RestoreFromState(const DcSctpSocketHandoverState& state) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (state_ != State::kClosed) { + callbacks_.OnError(ErrorKind::kUnsupportedOperation, + "Only closed socket can be restored from state"); + } else { + if (state.socket_state == + DcSctpSocketHandoverState::SocketState::kConnected) { + VerificationTag my_verification_tag = + VerificationTag(state.my_verification_tag); + connect_params_.verification_tag = my_verification_tag; + + Capabilities capabilities; + capabilities.partial_reliability = state.capabilities.partial_reliability; + capabilities.message_interleaving = + state.capabilities.message_interleaving; + capabilities.reconfig = state.capabilities.reconfig; + + send_queue_.RestoreFromState(state); + + CreateTransmissionControlBlock( + capabilities, my_verification_tag, TSN(state.my_initial_tsn), + VerificationTag(state.peer_verification_tag), + TSN(state.peer_initial_tsn), static_cast<size_t>(0), + TieTag(state.tie_tag)); + + tcb_->RestoreFromState(state); + + SetState(State::kEstablished, "restored from handover state"); + callbacks_.OnConnected(); + } + } + + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::Shutdown() { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (tcb_ != nullptr) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon receipt of the SHUTDOWN primitive from its upper layer, the + // endpoint enters the SHUTDOWN-PENDING state and remains there until all + // outstanding data has been acknowledged by its peer." + + // TODO(webrtc:12739): Remove this check, as it just hides the problem that + // the socket can transition from ShutdownSent to ShutdownPending, or + // ShutdownAckSent to ShutdownPending which is illegal. + if (state_ != State::kShutdownSent && state_ != State::kShutdownAckSent) { + SetState(State::kShutdownPending, "Shutdown called"); + t1_init_->Stop(); + t1_cookie_->Stop(); + MaybeSendShutdownOrAck(); + } + } else { + // Connection closed before even starting to connect, or during the initial + // connection phase. There is no outstanding data, so the socket can just + // be closed (stopping any connection timers, if any), as this is the + // client's intention, by calling Shutdown. + InternalClose(ErrorKind::kNoError, ""); + } + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::Close() { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (state_ != State::kClosed) { + if (tcb_ != nullptr) { + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(AbortChunk(/*filled_in_verification_tag=*/true, + Parameters::Builder() + .Add(UserInitiatedAbortCause("Close called")) + .Build())); + packet_sender_.Send(b); + } + InternalClose(ErrorKind::kNoError, ""); + } else { + RTC_DLOG(LS_INFO) << log_prefix() << "Called Close on a closed socket"; + } + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::CloseConnectionBecauseOfTooManyTransmissionErrors() { + packet_sender_.Send(tcb_->PacketBuilder().Add(AbortChunk( + true, Parameters::Builder() + .Add(UserInitiatedAbortCause("Too many retransmissions")) + .Build()))); + InternalClose(ErrorKind::kTooManyRetries, "Too many retransmissions"); +} + +void DcSctpSocket::InternalClose(ErrorKind error, absl::string_view message) { + if (state_ != State::kClosed) { + t1_init_->Stop(); + t1_cookie_->Stop(); + t2_shutdown_->Stop(); + tcb_ = nullptr; + + if (error == ErrorKind::kNoError) { + callbacks_.OnClosed(); + } else { + callbacks_.OnAborted(error, message); + } + SetState(State::kClosed, message); + } + // This method's purpose is to abort/close and make it consistent by ensuring + // that e.g. all timers really are stopped. + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::SetStreamPriority(StreamID stream_id, + StreamPriority priority) { + RTC_DCHECK_RUN_ON(&thread_checker_); + send_queue_.SetStreamPriority(stream_id, priority); +} +StreamPriority DcSctpSocket::GetStreamPriority(StreamID stream_id) const { + RTC_DCHECK_RUN_ON(&thread_checker_); + return send_queue_.GetStreamPriority(stream_id); +} + +SendStatus DcSctpSocket::Send(DcSctpMessage message, + const SendOptions& send_options) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + LifecycleId lifecycle_id = send_options.lifecycle_id; + + if (message.payload().empty()) { + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Unable to send empty message"); + return SendStatus::kErrorMessageEmpty; + } + if (message.payload().size() > options_.max_message_size) { + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Unable to send too large message"); + return SendStatus::kErrorMessageTooLarge; + } + if (state_ == State::kShutdownPending || state_ == State::kShutdownSent || + state_ == State::kShutdownReceived || state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "An endpoint should reject any new data request from its upper layer + // if it is in the SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, or + // SHUTDOWN-ACK-SENT state." + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } + callbacks_.OnError(ErrorKind::kWrongSequence, + "Unable to send message as the socket is shutting down"); + return SendStatus::kErrorShuttingDown; + } + if (send_queue_.IsFull()) { + if (lifecycle_id.IsSet()) { + callbacks_.OnLifecycleEnd(lifecycle_id); + } + callbacks_.OnError(ErrorKind::kResourceExhaustion, + "Unable to send message as the send queue is full"); + return SendStatus::kErrorResourceExhaustion; + } + + TimeMs now = callbacks_.TimeMillis(); + ++metrics_.tx_messages_count; + send_queue_.Add(now, std::move(message), send_options); + if (tcb_ != nullptr) { + tcb_->SendBufferedPackets(now); + } + + RTC_DCHECK(IsConsistent()); + return SendStatus::kSuccess; +} + +ResetStreamsStatus DcSctpSocket::ResetStreams( + rtc::ArrayView<const StreamID> outgoing_streams) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (tcb_ == nullptr) { + callbacks_.OnError(ErrorKind::kWrongSequence, + "Can't reset streams as the socket is not connected"); + return ResetStreamsStatus::kNotConnected; + } + if (!tcb_->capabilities().reconfig) { + callbacks_.OnError(ErrorKind::kUnsupportedOperation, + "Can't reset streams as the peer doesn't support it"); + return ResetStreamsStatus::kNotSupported; + } + + tcb_->stream_reset_handler().ResetStreams(outgoing_streams); + MaybeSendResetStreamsRequest(); + + RTC_DCHECK(IsConsistent()); + return ResetStreamsStatus::kPerformed; +} + +SocketState DcSctpSocket::state() const { + RTC_DCHECK_RUN_ON(&thread_checker_); + switch (state_) { + case State::kClosed: + return SocketState::kClosed; + case State::kCookieWait: + case State::kCookieEchoed: + return SocketState::kConnecting; + case State::kEstablished: + return SocketState::kConnected; + case State::kShutdownPending: + case State::kShutdownSent: + case State::kShutdownReceived: + case State::kShutdownAckSent: + return SocketState::kShuttingDown; + } +} + +void DcSctpSocket::SetMaxMessageSize(size_t max_message_size) { + RTC_DCHECK_RUN_ON(&thread_checker_); + options_.max_message_size = max_message_size; +} + +size_t DcSctpSocket::buffered_amount(StreamID stream_id) const { + RTC_DCHECK_RUN_ON(&thread_checker_); + return send_queue_.buffered_amount(stream_id); +} + +size_t DcSctpSocket::buffered_amount_low_threshold(StreamID stream_id) const { + RTC_DCHECK_RUN_ON(&thread_checker_); + return send_queue_.buffered_amount_low_threshold(stream_id); +} + +void DcSctpSocket::SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) { + RTC_DCHECK_RUN_ON(&thread_checker_); + send_queue_.SetBufferedAmountLowThreshold(stream_id, bytes); +} + +absl::optional<Metrics> DcSctpSocket::GetMetrics() const { + RTC_DCHECK_RUN_ON(&thread_checker_); + + if (tcb_ == nullptr) { + return absl::nullopt; + } + + Metrics metrics = metrics_; + metrics.cwnd_bytes = tcb_->cwnd(); + metrics.srtt_ms = tcb_->current_srtt().value(); + size_t packet_payload_size = + options_.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; + metrics.unack_data_count = + tcb_->retransmission_queue().outstanding_items() + + (send_queue_.total_buffered_amount() + packet_payload_size - 1) / + packet_payload_size; + metrics.peer_rwnd_bytes = tcb_->retransmission_queue().rwnd(); + + return metrics; +} + +void DcSctpSocket::MaybeSendShutdownOnPacketReceived(const SctpPacket& packet) { + if (state_ == State::kShutdownSent) { + bool has_data_chunk = + std::find_if(packet.descriptors().begin(), packet.descriptors().end(), + [](const SctpPacket::ChunkDescriptor& descriptor) { + return descriptor.type == DataChunk::kType; + }) != packet.descriptors().end(); + if (has_data_chunk) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "While in the SHUTDOWN-SENT state, the SHUTDOWN sender MUST immediately + // respond to each received packet containing one or more DATA chunks with + // a SHUTDOWN chunk and restart the T2-shutdown timer."" + SendShutdown(); + t2_shutdown_->set_duration(tcb_->current_rto()); + t2_shutdown_->Start(); + } + } +} + +void DcSctpSocket::MaybeSendResetStreamsRequest() { + absl::optional<ReConfigChunk> reconfig = + tcb_->stream_reset_handler().MakeStreamResetRequest(); + if (reconfig.has_value()) { + SctpPacket::Builder builder = tcb_->PacketBuilder(); + builder.Add(*reconfig); + packet_sender_.Send(builder); + } +} + +bool DcSctpSocket::ValidatePacket(const SctpPacket& packet) { + const CommonHeader& header = packet.common_header(); + VerificationTag my_verification_tag = + tcb_ != nullptr ? tcb_->my_verification_tag() : VerificationTag(0); + + if (header.verification_tag == VerificationTag(0)) { + if (packet.descriptors().size() == 1 && + packet.descriptors()[0].type == InitChunk::kType) { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "When an endpoint receives an SCTP packet with the Verification Tag + // set to 0, it should verify that the packet contains only an INIT chunk. + // Otherwise, the receiver MUST silently discard the packet."" + return true; + } + callbacks_.OnError( + ErrorKind::kParseFailed, + "Only a single INIT chunk can be present in packets sent on " + "verification_tag = 0"); + return false; + } + + if (packet.descriptors().size() == 1 && + packet.descriptors()[0].type == AbortChunk::kType) { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "The receiver of an ABORT MUST accept the packet if the Verification + // Tag field of the packet matches its own tag and the T bit is not set OR + // if it is set to its peer's tag and the T bit is set in the Chunk Flags. + // Otherwise, the receiver MUST silently discard the packet and take no + // further action." + bool t_bit = (packet.descriptors()[0].flags & 0x01) != 0; + if (t_bit && tcb_ == nullptr) { + // Can't verify the tag - assume it's okey. + return true; + } + if ((!t_bit && header.verification_tag == my_verification_tag) || + (t_bit && header.verification_tag == tcb_->peer_verification_tag())) { + return true; + } + callbacks_.OnError(ErrorKind::kParseFailed, + "ABORT chunk verification tag was wrong"); + return false; + } + + if (packet.descriptors()[0].type == InitAckChunk::kType) { + if (header.verification_tag == connect_params_.verification_tag) { + return true; + } + callbacks_.OnError( + ErrorKind::kParseFailed, + rtc::StringFormat( + "Packet has invalid verification tag: %08x, expected %08x", + *header.verification_tag, *connect_params_.verification_tag)); + return false; + } + + if (packet.descriptors()[0].type == CookieEchoChunk::kType) { + // Handled in chunk handler (due to RFC 4960, section 5.2.4). + return true; + } + + if (packet.descriptors().size() == 1 && + packet.descriptors()[0].type == ShutdownCompleteChunk::kType) { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "The receiver of a SHUTDOWN COMPLETE shall accept the packet if the + // Verification Tag field of the packet matches its own tag and the T bit is + // not set OR if it is set to its peer's tag and the T bit is set in the + // Chunk Flags. Otherwise, the receiver MUST silently discard the packet + // and take no further action." + bool t_bit = (packet.descriptors()[0].flags & 0x01) != 0; + if (t_bit && tcb_ == nullptr) { + // Can't verify the tag - assume it's okey. + return true; + } + if ((!t_bit && header.verification_tag == my_verification_tag) || + (t_bit && header.verification_tag == tcb_->peer_verification_tag())) { + return true; + } + callbacks_.OnError(ErrorKind::kParseFailed, + "SHUTDOWN_COMPLETE chunk verification tag was wrong"); + return false; + } + + // https://tools.ietf.org/html/rfc4960#section-8.5 + // "When receiving an SCTP packet, the endpoint MUST ensure that the value + // in the Verification Tag field of the received SCTP packet matches its own + // tag. If the received Verification Tag value does not match the receiver's + // own tag value, the receiver shall silently discard the packet and shall not + // process it any further..." + if (header.verification_tag == my_verification_tag) { + return true; + } + + callbacks_.OnError( + ErrorKind::kParseFailed, + rtc::StringFormat( + "Packet has invalid verification tag: %08x, expected %08x", + *header.verification_tag, *my_verification_tag)); + return false; +} + +void DcSctpSocket::HandleTimeout(TimeoutID timeout_id) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + timer_manager_.HandleTimeout(timeout_id); + + if (tcb_ != nullptr && tcb_->HasTooManyTxErrors()) { + // Tearing down the TCB has to be done outside the handlers. + CloseConnectionBecauseOfTooManyTransmissionErrors(); + } + + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::ReceivePacket(rtc::ArrayView<const uint8_t> data) { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + ++metrics_.rx_packets_count; + + if (packet_observer_ != nullptr) { + packet_observer_->OnReceivedPacket(callbacks_.TimeMillis(), data); + } + + absl::optional<SctpPacket> packet = + SctpPacket::Parse(data, options_.disable_checksum_verification); + if (!packet.has_value()) { + // https://tools.ietf.org/html/rfc4960#section-6.8 + // "The default procedure for handling invalid SCTP packets is to + // silently discard them." + callbacks_.OnError(ErrorKind::kParseFailed, + "Failed to parse received SCTP packet"); + RTC_DCHECK(IsConsistent()); + return; + } + + if (RTC_DLOG_IS_ON) { + for (const auto& descriptor : packet->descriptors()) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received " + << DebugConvertChunkToString(descriptor.data); + } + } + + if (!ValidatePacket(*packet)) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Packet failed verification tag check - dropping"; + RTC_DCHECK(IsConsistent()); + return; + } + + MaybeSendShutdownOnPacketReceived(*packet); + + for (const auto& descriptor : packet->descriptors()) { + if (!Dispatch(packet->common_header(), descriptor)) { + break; + } + } + + if (tcb_ != nullptr) { + tcb_->data_tracker().ObservePacketEnd(); + tcb_->MaybeSendSack(); + } + + RTC_DCHECK(IsConsistent()); +} + +void DcSctpSocket::DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload) { + auto packet = SctpPacket::Parse(payload); + RTC_DCHECK(packet.has_value()); + + for (const auto& desc : packet->descriptors()) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Sent " + << DebugConvertChunkToString(desc.data); + } +} + +bool DcSctpSocket::Dispatch(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + switch (descriptor.type) { + case DataChunk::kType: + HandleData(header, descriptor); + break; + case InitChunk::kType: + HandleInit(header, descriptor); + break; + case InitAckChunk::kType: + HandleInitAck(header, descriptor); + break; + case SackChunk::kType: + HandleSack(header, descriptor); + break; + case HeartbeatRequestChunk::kType: + HandleHeartbeatRequest(header, descriptor); + break; + case HeartbeatAckChunk::kType: + HandleHeartbeatAck(header, descriptor); + break; + case AbortChunk::kType: + HandleAbort(header, descriptor); + break; + case ErrorChunk::kType: + HandleError(header, descriptor); + break; + case CookieEchoChunk::kType: + HandleCookieEcho(header, descriptor); + break; + case CookieAckChunk::kType: + HandleCookieAck(header, descriptor); + break; + case ShutdownChunk::kType: + HandleShutdown(header, descriptor); + break; + case ShutdownAckChunk::kType: + HandleShutdownAck(header, descriptor); + break; + case ShutdownCompleteChunk::kType: + HandleShutdownComplete(header, descriptor); + break; + case ReConfigChunk::kType: + HandleReconfig(header, descriptor); + break; + case ForwardTsnChunk::kType: + HandleForwardTsn(header, descriptor); + break; + case IDataChunk::kType: + HandleIData(header, descriptor); + break; + case IForwardTsnChunk::kType: + HandleIForwardTsn(header, descriptor); + break; + default: + return HandleUnrecognizedChunk(descriptor); + } + return true; +} + +bool DcSctpSocket::HandleUnrecognizedChunk( + const SctpPacket::ChunkDescriptor& descriptor) { + bool report_as_error = (descriptor.type & 0x40) != 0; + bool continue_processing = (descriptor.type & 0x80) != 0; + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received unknown chunk: " + << static_cast<int>(descriptor.type); + if (report_as_error) { + rtc::StringBuilder sb; + sb << "Received unknown chunk of type: " + << static_cast<int>(descriptor.type) << " with report-error bit set"; + callbacks_.OnError(ErrorKind::kParseFailed, sb.str()); + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Unknown chunk, with type indicating it should be reported."; + + // https://tools.ietf.org/html/rfc4960#section-3.2 + // "... report in an ERROR chunk using the 'Unrecognized Chunk Type' + // cause." + if (tcb_ != nullptr) { + // Need TCB - this chunk must be sent with a correct verification tag. + packet_sender_.Send(tcb_->PacketBuilder().Add( + ErrorChunk(Parameters::Builder() + .Add(UnrecognizedChunkTypeCause(std::vector<uint8_t>( + descriptor.data.begin(), descriptor.data.end()))) + .Build()))); + } + } + if (!continue_processing) { + // https://tools.ietf.org/html/rfc4960#section-3.2 + // "Stop processing this SCTP packet and discard it, do not process any + // further chunks within it." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Unknown chunk, with type indicating not to " + "process any further chunks"; + } + + return continue_processing; +} + +absl::optional<DurationMs> DcSctpSocket::OnInitTimerExpiry() { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t1_init_->name() + << " has expired: " << t1_init_->expiration_count() + << "/" << t1_init_->options().max_restarts.value_or(-1); + RTC_DCHECK(state_ == State::kCookieWait); + + if (t1_init_->is_running()) { + SendInit(); + } else { + InternalClose(ErrorKind::kTooManyRetries, "No INIT_ACK received"); + } + RTC_DCHECK(IsConsistent()); + return absl::nullopt; +} + +absl::optional<DurationMs> DcSctpSocket::OnCookieTimerExpiry() { + // https://tools.ietf.org/html/rfc4960#section-4 + // "If the T1-cookie timer expires, the endpoint MUST retransmit COOKIE + // ECHO and restart the T1-cookie timer without changing state. This MUST + // be repeated up to 'Max.Init.Retransmits' times. After that, the endpoint + // MUST abort the initialization process and report the error to the SCTP + // user." + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t1_cookie_->name() + << " has expired: " << t1_cookie_->expiration_count() + << "/" + << t1_cookie_->options().max_restarts.value_or(-1); + + RTC_DCHECK(state_ == State::kCookieEchoed); + + if (t1_cookie_->is_running()) { + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); + } else { + InternalClose(ErrorKind::kTooManyRetries, "No COOKIE_ACK received"); + } + + RTC_DCHECK(IsConsistent()); + return absl::nullopt; +} + +absl::optional<DurationMs> DcSctpSocket::OnShutdownTimerExpiry() { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Timer " << t2_shutdown_->name() + << " has expired: " << t2_shutdown_->expiration_count() + << "/" + << t2_shutdown_->options().max_restarts.value_or(-1); + + if (!t2_shutdown_->is_running()) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "An endpoint should limit the number of retransmissions of the SHUTDOWN + // chunk to the protocol parameter 'Association.Max.Retrans'. If this + // threshold is exceeded, the endpoint should destroy the TCB..." + + packet_sender_.Send(tcb_->PacketBuilder().Add( + AbortChunk(true, Parameters::Builder() + .Add(UserInitiatedAbortCause( + "Too many retransmissions of SHUTDOWN")) + .Build()))); + + InternalClose(ErrorKind::kTooManyRetries, "No SHUTDOWN_ACK received"); + RTC_DCHECK(IsConsistent()); + return absl::nullopt; + } + + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If the timer expires, the endpoint must resend the SHUTDOWN with the + // updated last sequential TSN received from its peer." + SendShutdown(); + RTC_DCHECK(IsConsistent()); + return tcb_->current_rto(); +} + +void DcSctpSocket::OnSentPacket(rtc::ArrayView<const uint8_t> packet, + SendPacketStatus status) { + // The packet observer is invoked even if the packet was failed to be sent, to + // indicate an attempt was made. + if (packet_observer_ != nullptr) { + packet_observer_->OnSentPacket(callbacks_.TimeMillis(), packet); + } + + if (status == SendPacketStatus::kSuccess) { + if (RTC_DLOG_IS_ON) { + DebugPrintOutgoing(packet); + } + + // The heartbeat interval timer is restarted for every sent packet, to + // fire when the outgoing channel is inactive. + if (tcb_ != nullptr) { + tcb_->heartbeat_handler().RestartTimer(); + } + + ++metrics_.tx_packets_count; + } +} + +bool DcSctpSocket::ValidateHasTCB() { + if (tcb_ != nullptr) { + return true; + } + + callbacks_.OnError( + ErrorKind::kNotConnected, + "Received unexpected commands on socket that is not connected"); + return false; +} + +void DcSctpSocket::ReportFailedToParseChunk(int chunk_type) { + rtc::StringBuilder sb; + sb << "Failed to parse chunk of type: " << chunk_type; + callbacks_.OnError(ErrorKind::kParseFailed, sb.str()); +} + +void DcSctpSocket::HandleData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<DataChunk> chunk = DataChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleDataCommon(*chunk); + } +} + +void DcSctpSocket::HandleIData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<IDataChunk> chunk = IDataChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleDataCommon(*chunk); + } +} + +void DcSctpSocket::HandleDataCommon(AnyDataChunk& chunk) { + TSN tsn = chunk.tsn(); + AnyDataChunk::ImmediateAckFlag immediate_ack = chunk.options().immediate_ack; + Data data = std::move(chunk).extract(); + + if (data.payload.empty()) { + // Empty DATA chunks are illegal. + packet_sender_.Send(tcb_->PacketBuilder().Add( + ErrorChunk(Parameters::Builder().Add(NoUserDataCause(tsn)).Build()))); + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Received DATA chunk with no user data"); + return; + } + + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Handle DATA, queue_size=" + << tcb_->reassembly_queue().queued_bytes() + << ", water_mark=" + << tcb_->reassembly_queue().watermark_bytes() + << ", full=" << tcb_->reassembly_queue().is_full() + << ", above=" + << tcb_->reassembly_queue().is_above_watermark(); + + if (tcb_->reassembly_queue().is_full()) { + // If the reassembly queue is full, there is nothing that can be done. The + // specification only allows dropping gap-ack-blocks, and that's not + // likely to help as the socket has been trying to fill gaps since the + // watermark was reached. + packet_sender_.Send(tcb_->PacketBuilder().Add(AbortChunk( + true, Parameters::Builder().Add(OutOfResourceErrorCause()).Build()))); + InternalClose(ErrorKind::kResourceExhaustion, + "Reassembly Queue is exhausted"); + return; + } + + if (tcb_->reassembly_queue().is_above_watermark()) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Is above high watermark"; + // If the reassembly queue is above its high watermark, only accept data + // chunks that increase its cumulative ack tsn in an attempt to fill gaps + // to deliver messages. + if (!tcb_->data_tracker().will_increase_cum_ack_tsn(tsn)) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Rejected data because of exceeding watermark"; + tcb_->data_tracker().ForceImmediateSack(); + return; + } + } + + if (!tcb_->data_tracker().IsTSNValid(tsn)) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Rejected data because of failing TSN validity"; + return; + } + + if (tcb_->data_tracker().Observe(tsn, immediate_ack)) { + tcb_->reassembly_queue().MaybeResetStreamsDeferred( + tcb_->data_tracker().last_cumulative_acked_tsn()); + tcb_->reassembly_queue().Add(tsn, std::move(data)); + DeliverReassembledMessages(); + } +} + +void DcSctpSocket::HandleInit(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<InitChunk> chunk = InitChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + if (chunk->initiate_tag() == VerificationTag(0) || + chunk->nbr_outbound_streams() == 0 || chunk->nbr_inbound_streams() == 0) { + // https://tools.ietf.org/html/rfc4960#section-3.3.2 + // "If the value of the Initiate Tag in a received INIT chunk is found + // to be 0, the receiver MUST treat it as an error and close the + // association by transmitting an ABORT." + + // "A receiver of an INIT with the OS value set to 0 SHOULD abort the + // association." + + // "A receiver of an INIT with the MIS value of 0 SHOULD abort the + // association." + + packet_sender_.Send( + SctpPacket::Builder(VerificationTag(0), options_) + .Add(AbortChunk( + /*filled_in_verification_tag=*/false, + Parameters::Builder() + .Add(ProtocolViolationCause("INIT malformed")) + .Build()))); + InternalClose(ErrorKind::kProtocolViolation, "Received invalid INIT"); + return; + } + + if (state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If an endpoint is in the SHUTDOWN-ACK-SENT state and receives an + // INIT chunk (e.g., if the SHUTDOWN COMPLETE was lost) with source and + // destination transport addresses (either in the IP addresses or in the + // INIT chunk) that belong to this association, it should discard the INIT + // chunk and retransmit the SHUTDOWN ACK chunk." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init indicating lost ShutdownComplete"; + SendShutdownAck(); + return; + } + + TieTag tie_tag(0); + if (state_ == State::kClosed) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init in closed state (normal)"; + + MakeConnectionParameters(); + } else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) { + // https://tools.ietf.org/html/rfc4960#section-5.2.1 + // "This usually indicates an initialization collision, i.e., each + // endpoint is attempting, at about the same time, to establish an + // association with the other endpoint. Upon receipt of an INIT in the + // COOKIE-WAIT state, an endpoint MUST respond with an INIT ACK using the + // same parameters it sent in its original INIT chunk (including its + // Initiate Tag, unchanged). When responding, the endpoint MUST send the + // INIT ACK back to the same address that the original INIT (sent by this + // endpoint) was sent." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init indicating simultaneous connections"; + } else { + RTC_DCHECK(tcb_ != nullptr); + // https://tools.ietf.org/html/rfc4960#section-5.2.2 + // "The outbound SCTP packet containing this INIT ACK MUST carry a + // Verification Tag value equal to the Initiate Tag found in the + // unexpected INIT. And the INIT ACK MUST contain a new Initiate Tag + // (randomly generated; see Section 5.3.1). Other parameters for the + // endpoint SHOULD be copied from the existing parameters of the + // association (e.g., number of outbound streams) into the INIT ACK and + // cookie." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received Init indicating restarted connection"; + // Create a new verification tag - different from the previous one. + for (int tries = 0; tries < 10; ++tries) { + connect_params_.verification_tag = VerificationTag( + callbacks_.GetRandomInt(kMinVerificationTag, kMaxVerificationTag)); + if (connect_params_.verification_tag != tcb_->my_verification_tag()) { + break; + } + } + + // Make the initial TSN make a large jump, so that there is no overlap + // with the old and new association. + connect_params_.initial_tsn = + TSN(*tcb_->retransmission_queue().next_tsn() + 1000000); + tie_tag = tcb_->tie_tag(); + } + + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << rtc::StringFormat( + "Proceeding with connection. my_verification_tag=%08x, " + "my_initial_tsn=%u, peer_verification_tag=%08x, " + "peer_initial_tsn=%u", + *connect_params_.verification_tag, *connect_params_.initial_tsn, + *chunk->initiate_tag(), *chunk->initial_tsn()); + + Capabilities capabilities = GetCapabilities(options_, chunk->parameters()); + + SctpPacket::Builder b(chunk->initiate_tag(), options_); + Parameters::Builder params_builder = + Parameters::Builder().Add(StateCookieParameter( + StateCookie(chunk->initiate_tag(), chunk->initial_tsn(), + chunk->a_rwnd(), tie_tag, capabilities) + .Serialize())); + AddCapabilityParameters(options_, params_builder); + + InitAckChunk init_ack(/*initiate_tag=*/connect_params_.verification_tag, + options_.max_receiver_window_buffer_size, + options_.announced_maximum_outgoing_streams, + options_.announced_maximum_incoming_streams, + connect_params_.initial_tsn, params_builder.Build()); + b.Add(init_ack); + packet_sender_.Send(b); +} + +void DcSctpSocket::HandleInitAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<InitAckChunk> chunk = InitAckChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + if (state_ != State::kCookieWait) { + // https://tools.ietf.org/html/rfc4960#section-5.2.3 + // "If an INIT ACK is received by an endpoint in any state other than + // the COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received INIT_ACK in unexpected state"; + return; + } + + auto cookie = chunk->parameters().get<StateCookieParameter>(); + if (!cookie.has_value()) { + packet_sender_.Send( + SctpPacket::Builder(connect_params_.verification_tag, options_) + .Add(AbortChunk( + /*filled_in_verification_tag=*/false, + Parameters::Builder() + .Add(ProtocolViolationCause("INIT-ACK malformed")) + .Build()))); + InternalClose(ErrorKind::kProtocolViolation, + "InitAck chunk doesn't contain a cookie"); + return; + } + Capabilities capabilities = GetCapabilities(options_, chunk->parameters()); + t1_init_->Stop(); + + metrics_.peer_implementation = DeterminePeerImplementation(cookie->data()); + + // If the connection is re-established (peer restarted, but re-used old + // connection), make sure that all message identifiers are reset and any + // partly sent message is re-sent in full. The same is true when the socket + // is closed and later re-opened, which never happens in WebRTC, but is a + // valid operation on the SCTP level. Note that in case of handover, the + // send queue is already re-configured, and shouldn't be reset. + send_queue_.Reset(); + + CreateTransmissionControlBlock(capabilities, connect_params_.verification_tag, + connect_params_.initial_tsn, + chunk->initiate_tag(), chunk->initial_tsn(), + chunk->a_rwnd(), MakeTieTag(callbacks_)); + + SetState(State::kCookieEchoed, "INIT_ACK received"); + + // The connection isn't fully established just yet. + tcb_->SetCookieEchoChunk(CookieEchoChunk(cookie->data())); + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); + t1_cookie_->Start(); +} + +void DcSctpSocket::HandleCookieEcho( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<CookieEchoChunk> chunk = + CookieEchoChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + absl::optional<StateCookie> cookie = + StateCookie::Deserialize(chunk->cookie()); + if (!cookie.has_value()) { + callbacks_.OnError(ErrorKind::kParseFailed, "Failed to parse state cookie"); + return; + } + + if (tcb_ != nullptr) { + if (!HandleCookieEchoWithTCB(header, *cookie)) { + return; + } + } else { + if (header.verification_tag != connect_params_.verification_tag) { + callbacks_.OnError( + ErrorKind::kParseFailed, + rtc::StringFormat( + "Received CookieEcho with invalid verification tag: %08x, " + "expected %08x", + *header.verification_tag, *connect_params_.verification_tag)); + return; + } + } + + // The init timer can be running on simultaneous connections. + t1_init_->Stop(); + t1_cookie_->Stop(); + if (state_ != State::kEstablished) { + if (tcb_ != nullptr) { + tcb_->ClearCookieEchoChunk(); + } + SetState(State::kEstablished, "COOKIE_ECHO received"); + callbacks_.OnConnected(); + } + + if (tcb_ == nullptr) { + // If the connection is re-established (peer restarted, but re-used old + // connection), make sure that all message identifiers are reset and any + // partly sent message is re-sent in full. The same is true when the socket + // is closed and later re-opened, which never happens in WebRTC, but is a + // valid operation on the SCTP level. Note that in case of handover, the + // send queue is already re-configured, and shouldn't be reset. + send_queue_.Reset(); + + CreateTransmissionControlBlock( + cookie->capabilities(), connect_params_.verification_tag, + connect_params_.initial_tsn, cookie->initiate_tag(), + cookie->initial_tsn(), cookie->a_rwnd(), MakeTieTag(callbacks_)); + } + + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(CookieAckChunk()); + + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "A COOKIE ACK chunk may be bundled with any pending DATA chunks (and/or + // SACK chunks), but the COOKIE ACK chunk MUST be the first chunk in the + // packet." + tcb_->SendBufferedPackets(b, callbacks_.TimeMillis()); +} + +bool DcSctpSocket::HandleCookieEchoWithTCB(const CommonHeader& header, + const StateCookie& cookie) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Handling CookieEchoChunk with TCB. local_tag=" + << *tcb_->my_verification_tag() + << ", peer_tag=" << *header.verification_tag + << ", tcb_tag=" << *tcb_->peer_verification_tag() + << ", cookie_tag=" << *cookie.initiate_tag() + << ", local_tie_tag=" << *tcb_->tie_tag() + << ", peer_tie_tag=" << *cookie.tie_tag(); + // https://tools.ietf.org/html/rfc4960#section-5.2.4 + // "Handle a COOKIE ECHO when a TCB Exists" + if (header.verification_tag != tcb_->my_verification_tag() && + tcb_->peer_verification_tag() != cookie.initiate_tag() && + cookie.tie_tag() == tcb_->tie_tag()) { + // "A) In this case, the peer may have restarted." + if (state_ == State::kShutdownAckSent) { + // "If the endpoint is in the SHUTDOWN-ACK-SENT state and recognizes + // that the peer has restarted ... it MUST NOT set up a new association + // but instead resend the SHUTDOWN ACK and send an ERROR chunk with a + // "Cookie Received While Shutting Down" error cause to its peer." + SctpPacket::Builder b(cookie.initiate_tag(), options_); + b.Add(ShutdownAckChunk()); + b.Add(ErrorChunk(Parameters::Builder() + .Add(CookieReceivedWhileShuttingDownCause()) + .Build())); + packet_sender_.Send(b); + callbacks_.OnError(ErrorKind::kWrongSequence, + "Received COOKIE-ECHO while shutting down"); + return false; + } + + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received COOKIE-ECHO indicating a restarted peer"; + + tcb_ = nullptr; + callbacks_.OnConnectionRestarted(); + } else if (header.verification_tag == tcb_->my_verification_tag() && + tcb_->peer_verification_tag() != cookie.initiate_tag()) { + // TODO(boivie): Handle the peer_tag == 0? + // "B) In this case, both sides may be attempting to start an + // association at about the same time, but the peer endpoint started its + // INIT after responding to the local endpoint's INIT." + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Received COOKIE-ECHO indicating simultaneous connections"; + tcb_ = nullptr; + } else if (header.verification_tag != tcb_->my_verification_tag() && + tcb_->peer_verification_tag() == cookie.initiate_tag() && + cookie.tie_tag() == TieTag(0)) { + // "C) In this case, the local endpoint's cookie has arrived late. + // Before it arrived, the local endpoint sent an INIT and received an + // INIT ACK and finally sent a COOKIE ECHO with the peer's same tag but + // a new tag of its own. The cookie should be silently discarded. The + // endpoint SHOULD NOT change states and should leave any timers + // running." + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Received COOKIE-ECHO indicating a late COOKIE-ECHO. Discarding"; + return false; + } else if (header.verification_tag == tcb_->my_verification_tag() && + tcb_->peer_verification_tag() == cookie.initiate_tag()) { + // "D) When both local and remote tags match, the endpoint should enter + // the ESTABLISHED state, if it is in the COOKIE-ECHOED state. It + // should stop any cookie timer that may be running and send a COOKIE + // ACK." + RTC_DLOG(LS_VERBOSE) + << log_prefix() + << "Received duplicate COOKIE-ECHO, probably because of peer not " + "receiving COOKIE-ACK and retransmitting COOKIE-ECHO. Continuing."; + } + return true; +} + +void DcSctpSocket::HandleCookieAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<CookieAckChunk> chunk = CookieAckChunk::Parse(descriptor.data); + if (!ValidateParseSuccess(chunk)) { + return; + } + + if (state_ != State::kCookieEchoed) { + // https://tools.ietf.org/html/rfc4960#section-5.2.5 + // "At any state other than COOKIE-ECHOED, an endpoint should silently + // discard a received COOKIE ACK chunk." + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received COOKIE_ACK not in COOKIE_ECHOED state"; + return; + } + + // RFC 4960, Errata ID: 4400 + t1_cookie_->Stop(); + tcb_->ClearCookieEchoChunk(); + SetState(State::kEstablished, "COOKIE_ACK received"); + tcb_->SendBufferedPackets(callbacks_.TimeMillis()); + callbacks_.OnConnected(); +} + +void DcSctpSocket::DeliverReassembledMessages() { + if (tcb_->reassembly_queue().HasMessages()) { + for (auto& message : tcb_->reassembly_queue().FlushMessages()) { + ++metrics_.rx_messages_count; + callbacks_.OnMessageReceived(std::move(message)); + } + } +} + +void DcSctpSocket::HandleSack(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<SackChunk> chunk = SackChunk::Parse(descriptor.data); + + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + TimeMs now = callbacks_.TimeMillis(); + SackChunk sack = ChunkValidators::Clean(*std::move(chunk)); + + if (tcb_->retransmission_queue().HandleSack(now, sack)) { + MaybeSendShutdownOrAck(); + // Receiving an ACK may make the socket go into fast recovery mode. + // https://datatracker.ietf.org/doc/html/rfc4960#section-7.2.4 + // "Determine how many of the earliest (i.e., lowest TSN) DATA chunks + // marked for retransmission will fit into a single packet, subject to + // constraint of the path MTU of the destination transport address to + // which the packet is being sent. Call this value K. Retransmit those K + // DATA chunks in a single packet. When a Fast Retransmit is being + // performed, the sender SHOULD ignore the value of cwnd and SHOULD NOT + // delay retransmission for this single packet." + tcb_->MaybeSendFastRetransmit(); + + // Receiving an ACK will decrease outstanding bytes (maybe now below + // cwnd?) or indicate packet loss that may result in sending FORWARD-TSN. + tcb_->SendBufferedPackets(now); + } else { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Dropping out-of-order SACK with TSN " + << *sack.cumulative_tsn_ack(); + } + } +} + +void DcSctpSocket::HandleHeartbeatRequest( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<HeartbeatRequestChunk> chunk = + HeartbeatRequestChunk::Parse(descriptor.data); + + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + tcb_->heartbeat_handler().HandleHeartbeatRequest(*std::move(chunk)); + } +} + +void DcSctpSocket::HandleHeartbeatAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<HeartbeatAckChunk> chunk = + HeartbeatAckChunk::Parse(descriptor.data); + + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + tcb_->heartbeat_handler().HandleHeartbeatAck(*std::move(chunk)); + } +} + +void DcSctpSocket::HandleAbort(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<AbortChunk> chunk = AbortChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk)) { + std::string error_string = ErrorCausesToString(chunk->error_causes()); + if (tcb_ == nullptr) { + // https://tools.ietf.org/html/rfc4960#section-3.3.7 + // "If an endpoint receives an ABORT with a format error or no TCB is + // found, it MUST silently discard it." + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received ABORT (" << error_string + << ") on a connection with no TCB. Ignoring"; + return; + } + + RTC_DLOG(LS_WARNING) << log_prefix() << "Received ABORT (" << error_string + << ") - closing connection."; + InternalClose(ErrorKind::kPeerReported, error_string); + } +} + +void DcSctpSocket::HandleError(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<ErrorChunk> chunk = ErrorChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk)) { + std::string error_string = ErrorCausesToString(chunk->error_causes()); + if (tcb_ == nullptr) { + RTC_DLOG(LS_VERBOSE) << log_prefix() << "Received ERROR (" << error_string + << ") on a connection with no TCB. Ignoring"; + return; + } + + RTC_DLOG(LS_WARNING) << log_prefix() << "Received ERROR: " << error_string; + callbacks_.OnError(ErrorKind::kPeerReported, + "Peer reported error: " + error_string); + } +} + +void DcSctpSocket::HandleReconfig( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<ReConfigChunk> chunk = ReConfigChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + tcb_->stream_reset_handler().HandleReConfig(*std::move(chunk)); + // Handling this response may result in outgoing stream resets finishing + // (either successfully or with failure). If there still are pending streams + // that were waiting for this request to finish, continue resetting them. + MaybeSendResetStreamsRequest(); + } +} + +void DcSctpSocket::HandleShutdown( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + if (!ValidateParseSuccess(ShutdownChunk::Parse(descriptor.data))) { + return; + } + + if (state_ == State::kClosed) { + return; + } else if (state_ == State::kCookieWait || state_ == State::kCookieEchoed) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If a SHUTDOWN is received in the COOKIE-WAIT or COOKIE ECHOED state, + // the SHUTDOWN chunk SHOULD be silently discarded." + } else if (state_ == State::kShutdownSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If an endpoint is in the SHUTDOWN-SENT state and receives a + // SHUTDOWN chunk from its peer, the endpoint shall respond immediately + // with a SHUTDOWN ACK to its peer, and move into the SHUTDOWN-ACK-SENT + // state restarting its T2-shutdown timer." + SendShutdownAck(); + SetState(State::kShutdownAckSent, "SHUTDOWN received"); + } else if (state_ == State::kShutdownAckSent) { + // TODO(webrtc:12739): This condition should be removed and handled by the + // next (state_ != State::kShutdownReceived). + return; + } else if (state_ != State::kShutdownReceived) { + RTC_DLOG(LS_VERBOSE) << log_prefix() + << "Received SHUTDOWN - shutting down the socket"; + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon reception of the SHUTDOWN, the peer endpoint shall enter the + // SHUTDOWN-RECEIVED state, stop accepting new data from its SCTP user, + // and verify, by checking the Cumulative TSN Ack field of the chunk, that + // all its outstanding DATA chunks have been received by the SHUTDOWN + // sender." + SetState(State::kShutdownReceived, "SHUTDOWN received"); + MaybeSendShutdownOrAck(); + } +} + +void DcSctpSocket::HandleShutdownAck( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + if (!ValidateParseSuccess(ShutdownAckChunk::Parse(descriptor.data))) { + return; + } + + if (state_ == State::kShutdownSent || state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon the receipt of the SHUTDOWN ACK, the SHUTDOWN sender shall stop + // the T2-shutdown timer, send a SHUTDOWN COMPLETE chunk to its peer, and + // remove all record of the association." + + // "If an endpoint is in the SHUTDOWN-ACK-SENT state and receives a + // SHUTDOWN ACK, it shall stop the T2-shutdown timer, send a SHUTDOWN + // COMPLETE chunk to its peer, and remove all record of the association." + + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(ShutdownCompleteChunk(/*tag_reflected=*/false)); + packet_sender_.Send(b); + InternalClose(ErrorKind::kNoError, ""); + } else { + // https://tools.ietf.org/html/rfc4960#section-8.5.1 + // "If the receiver is in COOKIE-ECHOED or COOKIE-WAIT state + // the procedures in Section 8.4 SHOULD be followed; in other words, it + // should be treated as an Out Of The Blue packet." + + // https://tools.ietf.org/html/rfc4960#section-8.4 + // "If the packet contains a SHUTDOWN ACK chunk, the receiver + // should respond to the sender of the OOTB packet with a SHUTDOWN + // COMPLETE. When sending the SHUTDOWN COMPLETE, the receiver of the OOTB + // packet must fill in the Verification Tag field of the outbound packet + // with the Verification Tag received in the SHUTDOWN ACK and set the T + // bit in the Chunk Flags to indicate that the Verification Tag is + // reflected." + + SctpPacket::Builder b(header.verification_tag, options_); + b.Add(ShutdownCompleteChunk(/*tag_reflected=*/true)); + packet_sender_.Send(b); + } +} + +void DcSctpSocket::HandleShutdownComplete( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + if (!ValidateParseSuccess(ShutdownCompleteChunk::Parse(descriptor.data))) { + return; + } + + if (state_ == State::kShutdownAckSent) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Upon reception of the SHUTDOWN COMPLETE chunk, the endpoint will + // verify that it is in the SHUTDOWN-ACK-SENT state; if it is not, the + // chunk should be discarded. If the endpoint is in the SHUTDOWN-ACK-SENT + // state, the endpoint should stop the T2-shutdown timer and remove all + // knowledge of the association (and thus the association enters the + // CLOSED state)." + InternalClose(ErrorKind::kNoError, ""); + } +} + +void DcSctpSocket::HandleForwardTsn( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<ForwardTsnChunk> chunk = + ForwardTsnChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleForwardTsnCommon(*chunk); + } +} + +void DcSctpSocket::HandleIForwardTsn( + const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor) { + absl::optional<IForwardTsnChunk> chunk = + IForwardTsnChunk::Parse(descriptor.data); + if (ValidateParseSuccess(chunk) && ValidateHasTCB()) { + HandleForwardTsnCommon(*chunk); + } +} + +void DcSctpSocket::HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk) { + if (!tcb_->capabilities().partial_reliability) { + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(AbortChunk(/*filled_in_verification_tag=*/true, + Parameters::Builder() + .Add(ProtocolViolationCause( + "I-FORWARD-TSN received, but not indicated " + "during connection establishment")) + .Build())); + packet_sender_.Send(b); + + callbacks_.OnError(ErrorKind::kProtocolViolation, + "Received a FORWARD_TSN without announced peer support"); + return; + } + tcb_->data_tracker().HandleForwardTsn(chunk.new_cumulative_tsn()); + tcb_->reassembly_queue().Handle(chunk); + // A forward TSN - for ordered streams - may allow messages to be + // delivered. + DeliverReassembledMessages(); + + // Processing a FORWARD_TSN might result in sending a SACK. + tcb_->MaybeSendSack(); +} + +void DcSctpSocket::MaybeSendShutdownOrAck() { + if (tcb_->retransmission_queue().outstanding_bytes() != 0) { + return; + } + + if (state_ == State::kShutdownPending) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "Once all its outstanding data has been acknowledged, the endpoint + // shall send a SHUTDOWN chunk to its peer including in the Cumulative TSN + // Ack field the last sequential TSN it has received from the peer. It + // shall then start the T2-shutdown timer and enter the SHUTDOWN-SENT + // state."" + + SendShutdown(); + t2_shutdown_->set_duration(tcb_->current_rto()); + t2_shutdown_->Start(); + SetState(State::kShutdownSent, "No more outstanding data"); + } else if (state_ == State::kShutdownReceived) { + // https://tools.ietf.org/html/rfc4960#section-9.2 + // "If the receiver of the SHUTDOWN has no more outstanding DATA + // chunks, the SHUTDOWN receiver MUST send a SHUTDOWN ACK and start a + // T2-shutdown timer of its own, entering the SHUTDOWN-ACK-SENT state. If + // the timer expires, the endpoint must resend the SHUTDOWN ACK." + + SendShutdownAck(); + SetState(State::kShutdownAckSent, "No more outstanding data"); + } +} + +void DcSctpSocket::SendShutdown() { + SctpPacket::Builder b = tcb_->PacketBuilder(); + b.Add(ShutdownChunk(tcb_->data_tracker().last_cumulative_acked_tsn())); + packet_sender_.Send(b); +} + +void DcSctpSocket::SendShutdownAck() { + packet_sender_.Send(tcb_->PacketBuilder().Add(ShutdownAckChunk())); + t2_shutdown_->set_duration(tcb_->current_rto()); + t2_shutdown_->Start(); +} + +HandoverReadinessStatus DcSctpSocket::GetHandoverReadiness() const { + RTC_DCHECK_RUN_ON(&thread_checker_); + HandoverReadinessStatus status; + if (state_ != State::kClosed && state_ != State::kEstablished) { + status.Add(HandoverUnreadinessReason::kWrongConnectionState); + } + status.Add(send_queue_.GetHandoverReadiness()); + if (tcb_) { + status.Add(tcb_->GetHandoverReadiness()); + } + return status; +} + +absl::optional<DcSctpSocketHandoverState> +DcSctpSocket::GetHandoverStateAndClose() { + RTC_DCHECK_RUN_ON(&thread_checker_); + CallbackDeferrer::ScopedDeferrer deferrer(callbacks_); + + if (!GetHandoverReadiness().IsReady()) { + return absl::nullopt; + } + + DcSctpSocketHandoverState state; + + if (state_ == State::kClosed) { + state.socket_state = DcSctpSocketHandoverState::SocketState::kClosed; + } else if (state_ == State::kEstablished) { + state.socket_state = DcSctpSocketHandoverState::SocketState::kConnected; + tcb_->AddHandoverState(state); + send_queue_.AddHandoverState(state); + InternalClose(ErrorKind::kNoError, "handover"); + } + + return std::move(state); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h new file mode 100644 index 0000000000..157c515d65 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket.h @@ -0,0 +1,298 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_ +#define NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "api/array_view.h" +#include "api/sequence_checker.h" +#include "net/dcsctp/packet/chunk/abort_chunk.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/init_ack_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_ack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_complete_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/packet_observer.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/callback_deferrer.h" +#include "net/dcsctp/socket/packet_sender.h" +#include "net/dcsctp/socket/state_cookie.h" +#include "net/dcsctp/socket/transmission_control_block.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_error_counter.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "net/dcsctp/tx/rr_send_queue.h" + +namespace dcsctp { + +// DcSctpSocket represents a single SCTP socket, to be used over DTLS. +// +// Every dcSCTP is completely isolated from any other socket. +// +// This class manages all packet and chunk dispatching and mainly handles the +// connection sequences (connect, close, shutdown, etc) as well as managing +// the Transmission Control Block (tcb). +// +// This class is thread-compatible. +class DcSctpSocket : public DcSctpSocketInterface { + public: + // Instantiates a DcSctpSocket, which interacts with the world through the + // `callbacks` interface and is configured using `options`. + // + // For debugging, `log_prefix` will prefix all debug logs, and a + // `packet_observer` can be attached to e.g. dump sent and received packets. + DcSctpSocket(absl::string_view log_prefix, + DcSctpSocketCallbacks& callbacks, + std::unique_ptr<PacketObserver> packet_observer, + const DcSctpOptions& options); + + DcSctpSocket(const DcSctpSocket&) = delete; + DcSctpSocket& operator=(const DcSctpSocket&) = delete; + + // Implementation of `DcSctpSocketInterface`. + void ReceivePacket(rtc::ArrayView<const uint8_t> data) override; + void HandleTimeout(TimeoutID timeout_id) override; + void Connect() override; + void RestoreFromState(const DcSctpSocketHandoverState& state) override; + void Shutdown() override; + void Close() override; + SendStatus Send(DcSctpMessage message, + const SendOptions& send_options) override; + ResetStreamsStatus ResetStreams( + rtc::ArrayView<const StreamID> outgoing_streams) override; + SocketState state() const override; + const DcSctpOptions& options() const override { return options_; } + void SetMaxMessageSize(size_t max_message_size) override; + void SetStreamPriority(StreamID stream_id, StreamPriority priority) override; + StreamPriority GetStreamPriority(StreamID stream_id) const override; + size_t buffered_amount(StreamID stream_id) const override; + size_t buffered_amount_low_threshold(StreamID stream_id) const override; + void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + absl::optional<Metrics> GetMetrics() const override; + HandoverReadinessStatus GetHandoverReadiness() const override; + absl::optional<DcSctpSocketHandoverState> GetHandoverStateAndClose() override; + SctpImplementation peer_implementation() const override { + return metrics_.peer_implementation; + } + // Returns this socket's verification tag, or zero if not yet connected. + VerificationTag verification_tag() const { + return tcb_ != nullptr ? tcb_->my_verification_tag() : VerificationTag(0); + } + + private: + // Parameter proposals valid during the connect phase. + struct ConnectParameters { + TSN initial_tsn = TSN(0); + VerificationTag verification_tag = VerificationTag(0); + }; + + // Detailed state (separate from SocketState, which is the public state). + enum class State { + kClosed, + kCookieWait, + // TCB valid in these: + kCookieEchoed, + kEstablished, + kShutdownPending, + kShutdownSent, + kShutdownReceived, + kShutdownAckSent, + }; + + // Returns the log prefix used for debug logging. + std::string log_prefix() const; + + bool IsConsistent() const; + static constexpr absl::string_view ToString(DcSctpSocket::State state); + + void CreateTransmissionControlBlock(const Capabilities& capabilities, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag); + + // Changes the socket state, given a `reason` (for debugging/logging). + void SetState(State state, absl::string_view reason); + // Fills in `connect_params` with random verification tag and initial TSN. + void MakeConnectionParameters(); + // Closes the association. Note that the TCB will not be valid past this call. + void InternalClose(ErrorKind error, absl::string_view message); + // Closes the association, because of too many retransmission errors. + void CloseConnectionBecauseOfTooManyTransmissionErrors(); + // Timer expiration handlers + absl::optional<DurationMs> OnInitTimerExpiry(); + absl::optional<DurationMs> OnCookieTimerExpiry(); + absl::optional<DurationMs> OnShutdownTimerExpiry(); + void OnSentPacket(rtc::ArrayView<const uint8_t> packet, + SendPacketStatus status); + // Sends SHUTDOWN or SHUTDOWN-ACK if the socket is shutting down and if all + // outstanding data has been acknowledged. + void MaybeSendShutdownOrAck(); + // If the socket is shutting down, responds SHUTDOWN to any incoming DATA. + void MaybeSendShutdownOnPacketReceived(const SctpPacket& packet); + // If there are streams pending to be reset, send a request to reset them. + void MaybeSendResetStreamsRequest(); + // Sends a INIT chunk. + void SendInit(); + // Sends a SHUTDOWN chunk. + void SendShutdown(); + // Sends a SHUTDOWN-ACK chunk. + void SendShutdownAck(); + // Validates the SCTP packet, as a whole - not the validity of individual + // chunks within it, as that's done in the different chunk handlers. + bool ValidatePacket(const SctpPacket& packet); + // Parses `payload`, which is a serialized packet that is just going to be + // sent and prints all chunks. + void DebugPrintOutgoing(rtc::ArrayView<const uint8_t> payload); + // Called whenever there may be reassembled messages, and delivers those. + void DeliverReassembledMessages(); + // Returns true if there is a TCB, and false otherwise (and reports an error). + bool ValidateHasTCB(); + + // Returns true if the parsing of a chunk of type `T` succeeded. If it didn't, + // it reports an error and returns false. + template <class T> + bool ValidateParseSuccess(const absl::optional<T>& c) { + if (c.has_value()) { + return true; + } + + ReportFailedToParseChunk(T::kType); + return false; + } + + // Reports failing to have parsed a chunk with the provided `chunk_type`. + void ReportFailedToParseChunk(int chunk_type); + // Called when unknown chunks are received. May report an error. + bool HandleUnrecognizedChunk(const SctpPacket::ChunkDescriptor& descriptor); + + // Will dispatch more specific chunk handlers. + bool Dispatch(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming DATA chunks. + void HandleData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming I-DATA chunks. + void HandleIData(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Common handler for DATA and I-DATA chunks. + void HandleDataCommon(AnyDataChunk& chunk); + // Handles incoming INIT chunks. + void HandleInit(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming INIT-ACK chunks. + void HandleInitAck(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming SACK chunks. + void HandleSack(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming HEARTBEAT chunks. + void HandleHeartbeatRequest(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming HEARTBEAT-ACK chunks. + void HandleHeartbeatAck(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming ABORT chunks. + void HandleAbort(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming ERROR chunks. + void HandleError(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming COOKIE-ECHO chunks. + void HandleCookieEcho(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles receiving COOKIE-ECHO when there already is a TCB. The return value + // indicates if the processing should continue. + bool HandleCookieEchoWithTCB(const CommonHeader& header, + const StateCookie& cookie); + // Handles incoming COOKIE-ACK chunks. + void HandleCookieAck(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming SHUTDOWN chunks. + void HandleShutdown(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming SHUTDOWN-ACK chunks. + void HandleShutdownAck(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming FORWARD-TSN chunks. + void HandleForwardTsn(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming I-FORWARD-TSN chunks. + void HandleIForwardTsn(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Handles incoming RE-CONFIG chunks. + void HandleReconfig(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + // Common handled for FORWARD-TSN/I-FORWARD-TSN. + void HandleForwardTsnCommon(const AnyForwardTsnChunk& chunk); + // Handles incoming SHUTDOWN-COMPLETE chunks + void HandleShutdownComplete(const CommonHeader& header, + const SctpPacket::ChunkDescriptor& descriptor); + + const std::string log_prefix_; + const std::unique_ptr<PacketObserver> packet_observer_; + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_; + Metrics metrics_; + DcSctpOptions options_; + + // Enqueues callbacks and dispatches them just before returning to the caller. + CallbackDeferrer callbacks_; + + TimerManager timer_manager_; + const std::unique_ptr<Timer> t1_init_; + const std::unique_ptr<Timer> t1_cookie_; + const std::unique_ptr<Timer> t2_shutdown_; + + // Packets that failed to be sent, but should be retried. + PacketSender packet_sender_; + + // The actual SendQueue implementation. As data can be sent on a socket before + // the connection is established, this component is not in the TCB. + RRSendQueue send_queue_; + + // Contains verification tag and initial TSN between having sent the INIT + // until the connection is established (there is no TCB at this point). + ConnectParameters connect_params_; + // The socket state. + State state_ = State::kClosed; + // If the connection is established, contains a transmission control block. + std::unique_ptr<TransmissionControlBlock> tcb_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_DCSCTP_SOCKET_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc new file mode 100644 index 0000000000..f097bfa095 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_network_test.cc @@ -0,0 +1,518 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include <cstdint> +#include <deque> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/pending_task_safety_flag.h" +#include "api/task_queue/task_queue_base.h" +#include "api/test/create_network_emulation_manager.h" +#include "api/test/network_emulation_manager.h" +#include "api/units/time_delta.h" +#include "call/simulated_network.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/socket/dcsctp_socket.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/timer/task_queue_timeout.h" +#include "rtc_base/copy_on_write_buffer.h" +#include "rtc_base/gunit.h" +#include "rtc_base/logging.h" +#include "rtc_base/socket_address.h" +#include "rtc_base/strings/string_format.h" +#include "rtc_base/time_utils.h" +#include "test/gmock.h" + +#if !defined(WEBRTC_ANDROID) && defined(NDEBUG) && \ + !defined(THREAD_SANITIZER) && !defined(MEMORY_SANITIZER) +#define DCSCTP_NDEBUG_TEST(t) t +#else +// In debug mode, and when MSAN or TSAN sanitizers are enabled, these tests are +// too expensive to run due to extensive consistency checks that iterate on all +// outstanding chunks. Same with low-end Android devices, which have +// difficulties with these tests. +#define DCSCTP_NDEBUG_TEST(t) DISABLED_##t +#endif + +namespace dcsctp { +namespace { +using ::testing::AllOf; +using ::testing::Ge; +using ::testing::Le; +using ::testing::SizeIs; + +constexpr StreamID kStreamId(1); +constexpr PPID kPpid(53); +constexpr size_t kSmallPayloadSize = 10; +constexpr size_t kLargePayloadSize = 10000; +constexpr size_t kHugePayloadSize = 262144; +constexpr size_t kBufferedAmountLowThreshold = kLargePayloadSize * 2; +constexpr webrtc::TimeDelta kPrintBandwidthDuration = + webrtc::TimeDelta::Seconds(1); +constexpr webrtc::TimeDelta kBenchmarkRuntime(webrtc::TimeDelta::Seconds(10)); +constexpr webrtc::TimeDelta kAWhile(webrtc::TimeDelta::Seconds(1)); + +inline int GetUniqueSeed() { + static int seed = 0; + return ++seed; +} + +DcSctpOptions MakeOptionsForTest() { + DcSctpOptions options; + + // Throughput numbers are affected by the MTU. Ensure it's constant. + options.mtu = 1200; + + // By disabling the heartbeat interval, there will no timers at all running + // when the socket is idle, which makes it easy to just continue the test + // until there are no more scheduled tasks. Note that it _will_ run for longer + // than necessary as timers aren't cancelled when they are stopped (as that's + // not supported), but it's still simulated time and passes quickly. + options.heartbeat_interval = DurationMs(0); + return options; +} + +// When doing throughput tests, knowing what each actor should do. +enum class ActorMode { + kAtRest, + kThroughputSender, + kThroughputReceiver, + kLimitedRetransmissionSender, +}; + +// An abstraction around EmulatedEndpoint, representing a bound socket that +// will send its packet to a given destination. +class BoundSocket : public webrtc::EmulatedNetworkReceiverInterface { + public: + void Bind(webrtc::EmulatedEndpoint* endpoint) { + endpoint_ = endpoint; + uint16_t port = endpoint->BindReceiver(0, this).value(); + source_address_ = + rtc::SocketAddress(endpoint_->GetPeerLocalAddress(), port); + } + + void SetDestination(const BoundSocket& socket) { + dest_address_ = socket.source_address_; + } + + void SetReceiver(std::function<void(rtc::CopyOnWriteBuffer)> receiver) { + receiver_ = std::move(receiver); + } + + void SendPacket(rtc::ArrayView<const uint8_t> data) { + endpoint_->SendPacket(source_address_, dest_address_, + rtc::CopyOnWriteBuffer(data.data(), data.size())); + } + + private: + // Implementation of `webrtc::EmulatedNetworkReceiverInterface`. + void OnPacketReceived(webrtc::EmulatedIpPacket packet) override { + receiver_(std::move(packet.data)); + } + + std::function<void(rtc::CopyOnWriteBuffer)> receiver_; + webrtc::EmulatedEndpoint* endpoint_ = nullptr; + rtc::SocketAddress source_address_; + rtc::SocketAddress dest_address_; +}; + +// Sends at a constant rate but with random packet sizes. +class SctpActor : public DcSctpSocketCallbacks { + public: + SctpActor(absl::string_view name, + BoundSocket& emulated_socket, + const DcSctpOptions& sctp_options) + : log_prefix_(std::string(name) + ": "), + thread_(rtc::Thread::Current()), + emulated_socket_(emulated_socket), + timeout_factory_( + *thread_, + [this]() { return TimeMillis(); }, + [this](dcsctp::TimeoutID timeout_id) { + sctp_socket_.HandleTimeout(timeout_id); + }), + random_(GetUniqueSeed()), + sctp_socket_(name, *this, nullptr, sctp_options), + last_bandwidth_printout_(TimeMs(TimeMillis())) { + emulated_socket.SetReceiver([this](rtc::CopyOnWriteBuffer buf) { + // The receiver will be executed on the NetworkEmulation task queue, but + // the dcSCTP socket is owned by `thread_` and is not thread-safe. + thread_->PostTask([this, buf] { this->sctp_socket_.ReceivePacket(buf); }); + }); + } + + void PrintBandwidth() { + TimeMs now = TimeMillis(); + DurationMs duration = now - last_bandwidth_printout_; + + double bitrate_mbps = + static_cast<double>(received_bytes_ * 8) / *duration / 1000; + RTC_LOG(LS_INFO) << log_prefix() + << rtc::StringFormat("Received %0.2f Mbps", bitrate_mbps); + + received_bitrate_mbps_.push_back(bitrate_mbps); + received_bytes_ = 0; + last_bandwidth_printout_ = now; + // Print again in a second. + if (mode_ == ActorMode::kThroughputReceiver) { + thread_->PostDelayedTask( + SafeTask(safety_.flag(), [this] { PrintBandwidth(); }), + kPrintBandwidthDuration); + } + } + + void SendPacket(rtc::ArrayView<const uint8_t> data) override { + emulated_socket_.SendPacket(data); + } + + std::unique_ptr<Timeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) override { + return timeout_factory_.CreateTimeout(precision); + } + + TimeMs TimeMillis() override { return TimeMs(rtc::TimeMillis()); } + + uint32_t GetRandomInt(uint32_t low, uint32_t high) override { + return random_.Rand(low, high); + } + + void OnMessageReceived(DcSctpMessage message) override { + received_bytes_ += message.payload().size(); + last_received_message_ = std::move(message); + } + + void OnError(ErrorKind error, absl::string_view message) override { + RTC_LOG(LS_WARNING) << log_prefix() << "Socket error: " << ToString(error) + << "; " << message; + } + + void OnAborted(ErrorKind error, absl::string_view message) override { + RTC_LOG(LS_ERROR) << log_prefix() << "Socket abort: " << ToString(error) + << "; " << message; + } + + void OnConnected() override {} + + void OnClosed() override {} + + void OnConnectionRestarted() override {} + + void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason) override {} + + void OnStreamsResetPerformed( + rtc::ArrayView<const StreamID> outgoing_streams) override {} + + void OnIncomingStreamsReset( + rtc::ArrayView<const StreamID> incoming_streams) override {} + + void NotifyOutgoingMessageBufferEmpty() override {} + + void OnBufferedAmountLow(StreamID stream_id) override { + if (mode_ == ActorMode::kThroughputSender) { + std::vector<uint8_t> payload(kHugePayloadSize); + sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)), + SendOptions()); + + } else if (mode_ == ActorMode::kLimitedRetransmissionSender) { + while (sctp_socket_.buffered_amount(kStreamId) < + kBufferedAmountLowThreshold * 2) { + SendOptions send_options; + send_options.max_retransmissions = 0; + sctp_socket_.Send( + DcSctpMessage(kStreamId, kPpid, + std::vector<uint8_t>(kLargePayloadSize)), + send_options); + + send_options.max_retransmissions = absl::nullopt; + sctp_socket_.Send( + DcSctpMessage(kStreamId, kPpid, + std::vector<uint8_t>(kSmallPayloadSize)), + send_options); + } + } + } + + absl::optional<DcSctpMessage> ConsumeReceivedMessage() { + if (!last_received_message_.has_value()) { + return absl::nullopt; + } + DcSctpMessage ret = *std::move(last_received_message_); + last_received_message_ = absl::nullopt; + return ret; + } + + DcSctpSocket& sctp_socket() { return sctp_socket_; } + + void SetActorMode(ActorMode mode) { + mode_ = mode; + if (mode_ == ActorMode::kThroughputSender) { + sctp_socket_.SetBufferedAmountLowThreshold(kStreamId, + kBufferedAmountLowThreshold); + std::vector<uint8_t> payload(kHugePayloadSize); + sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)), + SendOptions()); + + } else if (mode_ == ActorMode::kLimitedRetransmissionSender) { + sctp_socket_.SetBufferedAmountLowThreshold(kStreamId, + kBufferedAmountLowThreshold); + std::vector<uint8_t> payload(kHugePayloadSize); + sctp_socket_.Send(DcSctpMessage(kStreamId, kPpid, std::move(payload)), + SendOptions()); + + } else if (mode == ActorMode::kThroughputReceiver) { + thread_->PostDelayedTask( + SafeTask(safety_.flag(), [this] { PrintBandwidth(); }), + kPrintBandwidthDuration); + } + } + + // Returns the average bitrate, stripping the first `remove_first_n` that + // represent the time it took to ramp up the congestion control algorithm. + double avg_received_bitrate_mbps(size_t remove_first_n = 3) const { + std::vector<double> bitrates = received_bitrate_mbps_; + bitrates.erase(bitrates.begin(), bitrates.begin() + remove_first_n); + + double sum = 0; + for (double bitrate : bitrates) { + sum += bitrate; + } + + return sum / bitrates.size(); + } + + private: + std::string log_prefix() const { + rtc::StringBuilder sb; + sb << log_prefix_; + sb << rtc::TimeMillis(); + sb << ": "; + return sb.Release(); + } + + ActorMode mode_ = ActorMode::kAtRest; + const std::string log_prefix_; + rtc::Thread* thread_; + BoundSocket& emulated_socket_; + TaskQueueTimeoutFactory timeout_factory_; + webrtc::Random random_; + DcSctpSocket sctp_socket_; + size_t received_bytes_ = 0; + absl::optional<DcSctpMessage> last_received_message_; + TimeMs last_bandwidth_printout_; + // Per-second received bitrates, in Mbps + std::vector<double> received_bitrate_mbps_; + webrtc::ScopedTaskSafety safety_; +}; + +class DcSctpSocketNetworkTest : public testing::Test { + protected: + DcSctpSocketNetworkTest() + : options_(MakeOptionsForTest()), + emulation_(webrtc::CreateNetworkEmulationManager( + webrtc::TimeMode::kSimulated)) {} + + void MakeNetwork(const webrtc::BuiltInNetworkBehaviorConfig& config) { + webrtc::EmulatedEndpoint* endpoint_a = + emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig()); + webrtc::EmulatedEndpoint* endpoint_z = + emulation_->CreateEndpoint(webrtc::EmulatedEndpointConfig()); + + webrtc::EmulatedNetworkNode* node1 = emulation_->CreateEmulatedNode(config); + webrtc::EmulatedNetworkNode* node2 = emulation_->CreateEmulatedNode(config); + + emulation_->CreateRoute(endpoint_a, {node1}, endpoint_z); + emulation_->CreateRoute(endpoint_z, {node2}, endpoint_a); + + emulated_socket_a_.Bind(endpoint_a); + emulated_socket_z_.Bind(endpoint_z); + + emulated_socket_a_.SetDestination(emulated_socket_z_); + emulated_socket_z_.SetDestination(emulated_socket_a_); + } + + void Sleep(webrtc::TimeDelta duration) { + // Sleep in one-millisecond increments, to let timers expire when expected. + for (int i = 0; i < duration.ms(); ++i) { + emulation_->time_controller()->AdvanceTime(webrtc::TimeDelta::Millis(1)); + } + } + + DcSctpOptions options_; + std::unique_ptr<webrtc::NetworkEmulationManager> emulation_; + BoundSocket emulated_socket_a_; + BoundSocket emulated_socket_z_; +}; + +TEST_F(DcSctpSocketNetworkTest, CanConnectAndShutdown) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + EXPECT_THAT(sender.sctp_socket().state(), SocketState::kClosed); + + sender.sctp_socket().Connect(); + Sleep(kAWhile); + EXPECT_THAT(sender.sctp_socket().state(), SocketState::kConnected); + + sender.sctp_socket().Shutdown(); + Sleep(kAWhile); + EXPECT_THAT(sender.sctp_socket().state(), SocketState::kClosed); +} + +TEST_F(DcSctpSocketNetworkTest, CanSendLargeMessage) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + pipe_config.queue_delay_ms = 30; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + constexpr size_t kPayloadSize = 100 * 1024; + + std::vector<uint8_t> payload(kPayloadSize); + sender.sctp_socket().Send(DcSctpMessage(kStreamId, kPpid, payload), + SendOptions()); + + Sleep(kAWhile); + + ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage message, + receiver.ConsumeReceivedMessage()); + + EXPECT_THAT(message.payload(), SizeIs(kPayloadSize)); + + sender.sctp_socket().Shutdown(); + Sleep(kAWhile); +} + +TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithLowBandwidth) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + pipe_config.queue_delay_ms = 30; + pipe_config.link_capacity_kbps = 1000; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + sender.SetActorMode(ActorMode::kThroughputSender); + receiver.SetActorMode(ActorMode::kThroughputReceiver); + + Sleep(kBenchmarkRuntime); + sender.SetActorMode(ActorMode::kAtRest); + receiver.SetActorMode(ActorMode::kAtRest); + + Sleep(kAWhile); + + sender.sctp_socket().Shutdown(); + + Sleep(kAWhile); + + // Verify that the bitrates are in the range of 0.5-1.0 Mbps. + double bitrate = receiver.avg_received_bitrate_mbps(); + EXPECT_THAT(bitrate, AllOf(Ge(0.5), Le(1.0))); +} + +TEST_F(DcSctpSocketNetworkTest, + DCSCTP_NDEBUG_TEST(CanSendMessagesReliablyWithMediumBandwidth)) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + pipe_config.queue_delay_ms = 30; + pipe_config.link_capacity_kbps = 18000; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + sender.SetActorMode(ActorMode::kThroughputSender); + receiver.SetActorMode(ActorMode::kThroughputReceiver); + + Sleep(kBenchmarkRuntime); + sender.SetActorMode(ActorMode::kAtRest); + receiver.SetActorMode(ActorMode::kAtRest); + + Sleep(kAWhile); + + sender.sctp_socket().Shutdown(); + + Sleep(kAWhile); + + // Verify that the bitrates are in the range of 16-18 Mbps. + double bitrate = receiver.avg_received_bitrate_mbps(); + EXPECT_THAT(bitrate, AllOf(Ge(16), Le(18))); +} + +TEST_F(DcSctpSocketNetworkTest, CanSendMessagesReliablyWithMuchPacketLoss) { + webrtc::BuiltInNetworkBehaviorConfig config; + config.queue_delay_ms = 30; + config.loss_percent = 1; + MakeNetwork(config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + sender.SetActorMode(ActorMode::kThroughputSender); + receiver.SetActorMode(ActorMode::kThroughputReceiver); + + Sleep(kBenchmarkRuntime); + sender.SetActorMode(ActorMode::kAtRest); + receiver.SetActorMode(ActorMode::kAtRest); + + Sleep(kAWhile); + + sender.sctp_socket().Shutdown(); + + Sleep(kAWhile); + + // TCP calculator gives: 1200 MTU, 60ms RTT and 1% packet loss -> 1.6Mbps. + // This test is doing slightly better (doesn't have any additional header + // overhead etc). Verify that the bitrates are in the range of 1.5-2.5 Mbps. + double bitrate = receiver.avg_received_bitrate_mbps(); + EXPECT_THAT(bitrate, AllOf(Ge(1.5), Le(2.5))); +} + +TEST_F(DcSctpSocketNetworkTest, DCSCTP_NDEBUG_TEST(HasHighBandwidth)) { + webrtc::BuiltInNetworkBehaviorConfig pipe_config; + pipe_config.queue_delay_ms = 30; + MakeNetwork(pipe_config); + + SctpActor sender("A", emulated_socket_a_, options_); + SctpActor receiver("Z", emulated_socket_z_, options_); + sender.sctp_socket().Connect(); + + sender.SetActorMode(ActorMode::kThroughputSender); + receiver.SetActorMode(ActorMode::kThroughputReceiver); + + Sleep(kBenchmarkRuntime); + + sender.SetActorMode(ActorMode::kAtRest); + receiver.SetActorMode(ActorMode::kAtRest); + Sleep(kAWhile); + + sender.sctp_socket().Shutdown(); + Sleep(kAWhile); + + // Verify that the bitrate is in the range of 540-640 Mbps + double bitrate = receiver.avg_received_bitrate_mbps(); + EXPECT_THAT(bitrate, AllOf(Ge(520), Le(640))); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc new file mode 100644 index 0000000000..d1e2d904e0 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/dcsctp_socket_test.cc @@ -0,0 +1,2672 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/dcsctp_socket.h" + +#include <cstdint> +#include <deque> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/packet/chunk/chunk.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/data_common.h" +#include "net/dcsctp/packet/chunk/error_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/init_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/chunk/shutdown_chunk.h" +#include "net/dcsctp/packet/error_cause/error_cause.h" +#include "net/dcsctp/packet/error_cause/unrecognized_chunk_type_cause.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/text_pcap_packet_observer.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +ABSL_FLAG(bool, dcsctp_capture_packets, false, "Print packet capture."); + +namespace dcsctp { +namespace { +using ::testing::_; +using ::testing::AllOf; +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr SendOptions kSendOptions; +constexpr size_t kLargeMessageSize = DcSctpOptions::kMaxSafeMTUSize * 20; +constexpr size_t kSmallMessageSize = 10; +constexpr int kMaxBurstPackets = 4; + +MATCHER_P(HasDataChunkWithStreamId, stream_id, "") { + absl::optional<SctpPacket> packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != DataChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional<DataChunk> dc = + DataChunk::Parse(packet->descriptors()[0].data); + if (!dc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (dc->stream_id() != stream_id) { + *result_listener << "the stream_id is " << *dc->stream_id(); + return false; + } + + return true; +} + +MATCHER_P(HasDataChunkWithPPID, ppid, "") { + absl::optional<SctpPacket> packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != DataChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional<DataChunk> dc = + DataChunk::Parse(packet->descriptors()[0].data); + if (!dc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (dc->ppid() != ppid) { + *result_listener << "the ppid is " << *dc->ppid(); + return false; + } + + return true; +} + +MATCHER_P(HasDataChunkWithSsn, ssn, "") { + absl::optional<SctpPacket> packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != DataChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional<DataChunk> dc = + DataChunk::Parse(packet->descriptors()[0].data); + if (!dc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (dc->ssn() != ssn) { + *result_listener << "the ssn is " << *dc->ssn(); + return false; + } + + return true; +} + +MATCHER_P(HasDataChunkWithMid, mid, "") { + absl::optional<SctpPacket> packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != IDataChunk::kType) { + *result_listener << "the first chunk in the packet is not an i-data chunk"; + return false; + } + + absl::optional<IDataChunk> dc = + IDataChunk::Parse(packet->descriptors()[0].data); + if (!dc.has_value()) { + *result_listener << "The first chunk didn't parse as an i-data chunk"; + return false; + } + + if (dc->message_id() != mid) { + *result_listener << "the mid is " << *dc->message_id(); + return false; + } + + return true; +} + +MATCHER_P(HasSackWithCumAckTsn, tsn, "") { + absl::optional<SctpPacket> packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != SackChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional<SackChunk> sc = + SackChunk::Parse(packet->descriptors()[0].data); + if (!sc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (sc->cumulative_tsn_ack() != tsn) { + *result_listener << "the cum_ack_tsn is " << *sc->cumulative_tsn_ack(); + return false; + } + + return true; +} + +MATCHER(HasSackWithNoGapAckBlocks, "") { + absl::optional<SctpPacket> packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != SackChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional<SackChunk> sc = + SackChunk::Parse(packet->descriptors()[0].data); + if (!sc.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + if (!sc->gap_ack_blocks().empty()) { + *result_listener << "there are gap ack blocks"; + return false; + } + + return true; +} + +MATCHER_P(HasReconfigWithStreams, streams_matcher, "") { + absl::optional<SctpPacket> packet = SctpPacket::Parse(arg); + if (!packet.has_value()) { + *result_listener << "data didn't parse as an SctpPacket"; + return false; + } + + if (packet->descriptors()[0].type != ReConfigChunk::kType) { + *result_listener << "the first chunk in the packet is not a data chunk"; + return false; + } + + absl::optional<ReConfigChunk> reconfig = + ReConfigChunk::Parse(packet->descriptors()[0].data); + if (!reconfig.has_value()) { + *result_listener << "The first chunk didn't parse as a data chunk"; + return false; + } + + const Parameters& parameters = reconfig->parameters(); + if (parameters.descriptors().size() != 1 || + parameters.descriptors()[0].type != + OutgoingSSNResetRequestParameter::kType) { + *result_listener << "Expected the reconfig chunk to have an outgoing SSN " + "reset request parameter"; + return false; + } + + absl::optional<OutgoingSSNResetRequestParameter> p = + OutgoingSSNResetRequestParameter::Parse(parameters.descriptors()[0].data); + testing::Matcher<rtc::ArrayView<const StreamID>> matcher = streams_matcher; + if (!matcher.MatchAndExplain(p->stream_ids(), result_listener)) { + return false; + } + + return true; +} + +TSN AddTo(TSN tsn, int delta) { + return TSN(*tsn + delta); +} + +DcSctpOptions FixupOptions(DcSctpOptions options = {}) { + DcSctpOptions fixup = options; + // To make the interval more predictable in tests. + fixup.heartbeat_interval_include_rtt = false; + fixup.max_burst = kMaxBurstPackets; + return fixup; +} + +std::unique_ptr<PacketObserver> GetPacketObserver(absl::string_view name) { + if (absl::GetFlag(FLAGS_dcsctp_capture_packets)) { + return std::make_unique<TextPcapPacketObserver>(name); + } + return nullptr; +} + +struct SocketUnderTest { + explicit SocketUnderTest(absl::string_view name, + const DcSctpOptions& opts = {}) + : options(FixupOptions(opts)), + cb(name), + socket(name, cb, GetPacketObserver(name), options) {} + + const DcSctpOptions options; + testing::NiceMock<MockDcSctpSocketCallbacks> cb; + DcSctpSocket socket; +}; + +void ExchangeMessages(SocketUnderTest& a, SocketUnderTest& z) { + bool delivered_packet = false; + do { + delivered_packet = false; + std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket(); + if (!packet_from_a.empty()) { + delivered_packet = true; + z.socket.ReceivePacket(std::move(packet_from_a)); + } + std::vector<uint8_t> packet_from_z = z.cb.ConsumeSentPacket(); + if (!packet_from_z.empty()) { + delivered_packet = true; + a.socket.ReceivePacket(std::move(packet_from_z)); + } + } while (delivered_packet); +} + +void RunTimers(SocketUnderTest& s) { + for (;;) { + absl::optional<TimeoutID> timeout_id = s.cb.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + s.socket.HandleTimeout(*timeout_id); + } +} + +void AdvanceTime(SocketUnderTest& a, SocketUnderTest& z, DurationMs duration) { + a.cb.AdvanceTime(duration); + z.cb.AdvanceTime(duration); + + RunTimers(a); + RunTimers(z); +} + +// Calls Connect() on `sock_a_` and make the connection established. +void ConnectSockets(SocketUnderTest& a, SocketUnderTest& z) { + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + + a.socket.Connect(); + // Z reads INIT, INIT_ACK, COOKIE_ECHO, COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +std::unique_ptr<SocketUnderTest> HandoverSocket( + std::unique_ptr<SocketUnderTest> sut) { + EXPECT_EQ(sut->socket.GetHandoverReadiness(), HandoverReadinessStatus()); + + bool is_closed = sut->socket.state() == SocketState::kClosed; + if (!is_closed) { + EXPECT_CALL(sut->cb, OnClosed).Times(1); + } + absl::optional<DcSctpSocketHandoverState> handover_state = + sut->socket.GetHandoverStateAndClose(); + EXPECT_TRUE(handover_state.has_value()); + g_handover_state_transformer_for_test(&*handover_state); + + auto handover_socket = std::make_unique<SocketUnderTest>("H", sut->options); + if (!is_closed) { + EXPECT_CALL(handover_socket->cb, OnConnected).Times(1); + } + handover_socket->socket.RestoreFromState(*handover_state); + return handover_socket; +} + +std::vector<uint32_t> GetReceivedMessagePpids(SocketUnderTest& z) { + std::vector<uint32_t> ppids; + for (;;) { + absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage(); + if (!msg.has_value()) { + break; + } + ppids.push_back(*msg->ppid()); + } + return ppids; +} + +// Test parameter that controls whether to perform handovers during the test. A +// test can have multiple points where it conditionally hands over socket Z. +// Either socket Z will be handed over at all those points or handed over never. +enum class HandoverMode { + kNoHandover, + kPerformHandovers, +}; + +class DcSctpSocketParametrizedTest + : public ::testing::Test, + public ::testing::WithParamInterface<HandoverMode> { + protected: + // Trigger handover for `sut` depending on the current test param. + std::unique_ptr<SocketUnderTest> MaybeHandoverSocket( + std::unique_ptr<SocketUnderTest> sut) { + if (GetParam() == HandoverMode::kPerformHandovers) { + return HandoverSocket(std::move(sut)); + } + return sut; + } + + // Trigger handover for socket Z depending on the current test param. + // Then checks message passing to verify the handed over socket is functional. + void MaybeHandoverSocketAndSendMessage(SocketUnderTest& a, + std::unique_ptr<SocketUnderTest> z) { + if (GetParam() == HandoverMode::kPerformHandovers) { + z = HandoverSocket(std::move(z)); + } + + ExchangeMessages(a, *z); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + ExchangeMessages(a, *z); + + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + } +}; + +INSTANTIATE_TEST_SUITE_P(Handovers, + DcSctpSocketParametrizedTest, + testing::Values(HandoverMode::kNoHandover, + HandoverMode::kPerformHandovers), + [](const auto& test_info) { + return test_info.param == + HandoverMode::kPerformHandovers + ? "WithHandovers" + : "NoHandover"; + }); + +TEST(DcSctpSocketTest, EstablishConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + + a.socket.Connect(); + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, EstablishConnectionWithSetupCollision) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + a.socket.Connect(); + z.socket.Connect(); + + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, ShuttingDownWhileEstablishingConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(0); + EXPECT_CALL(z.cb, OnConnected).Times(1); + a.socket.Connect(); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Drop COOKIE_ACK, just to more easily verify shutdown protocol. + z.cb.ConsumeSentPacket(); + + // As Socket A has received INIT_ACK, it has a TCB and is connected, while + // Socket Z needs to receive COOKIE_ECHO to get there. Socket A still has + // timers running at this point. + EXPECT_EQ(a.socket.state(), SocketState::kConnecting); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); + + // Socket A is now shut down, which should make it stop those timers. + a.socket.Shutdown(); + + EXPECT_CALL(a.cb, OnClosed).Times(1); + EXPECT_CALL(z.cb, OnClosed).Times(1); + + // Z reads SHUTDOWN, produces SHUTDOWN_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads SHUTDOWN_COMPLETE. + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + EXPECT_TRUE(a.cb.ConsumeSentPacket().empty()); + EXPECT_TRUE(z.cb.ConsumeSentPacket().empty()); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + EXPECT_EQ(z.socket.state(), SocketState::kClosed); +} + +TEST(DcSctpSocketTest, EstablishSimultaneousConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + a.socket.Connect(); + + // INIT isn't received by Z, as it wasn't ready yet. + a.cb.ConsumeSentPacket(); + + z.socket.Connect(); + + // A reads INIT, produces INIT_ACK + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // Z reads INIT_ACK, sends COOKIE_ECHO + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // A reads COOKIE_ECHO - establishes connection. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + + // Proceed with the remaining packets. + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, EstablishConnectionLostCookieAck) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(0); + EXPECT_CALL(z.cb, OnConnectionRestarted).Times(0); + + a.socket.Connect(); + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // COOKIE_ACK is lost. + z.cb.ConsumeSentPacket(); + + EXPECT_EQ(a.socket.state(), SocketState::kConnecting); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); + + // This will make A re-send the COOKIE_ECHO + AdvanceTime(a, z, DurationMs(a.options.t1_cookie_timeout)); + + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, ResendInitAndEstablishConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + // INIT is never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType); + + AdvanceTime(a, z, a.options.t1_init_timeout); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, ResendingInitTooManyTimesAborts) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + + // INIT is never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(init_packet.descriptors()[0].type, InitChunk::kType); + + for (int i = 0; i < *a.options.max_init_retransmits; ++i) { + AdvanceTime(a, z, a.options.t1_init_timeout * (1 << i)); + + // INIT is resent + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(resent_init_packet.descriptors()[0].type, InitChunk::kType); + } + + // Another timeout, after the max init retransmits. + EXPECT_CALL(a.cb, OnAborted).Times(1); + AdvanceTime( + a, z, a.options.t1_init_timeout * (1 << *a.options.max_init_retransmits)); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); +} + +TEST(DcSctpSocketTest, ResendCookieEchoAndEstablishConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // COOKIE_ECHO is never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType); + + AdvanceTime(a, z, a.options.t1_init_timeout); + + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); +} + +TEST(DcSctpSocketTest, ResendingCookieEchoTooManyTimesAborts) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // COOKIE_ECHO is never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(init_packet.descriptors()[0].type, CookieEchoChunk::kType); + + for (int i = 0; i < *a.options.max_init_retransmits; ++i) { + AdvanceTime(a, z, a.options.t1_cookie_timeout * (1 << i)); + + // COOKIE_ECHO is resent + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket resent_init_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(resent_init_packet.descriptors()[0].type, CookieEchoChunk::kType); + } + + // Another timeout, after the max init retransmits. + EXPECT_CALL(a.cb, OnAborted).Times(1); + AdvanceTime( + a, z, + a.options.t1_cookie_timeout * (1 << *a.options.max_init_retransmits)); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); +} + +TEST(DcSctpSocketTest, DoesntSendMorePacketsUntilCookieAckHasBeenReceived) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + a.socket.Connect(); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // COOKIE_ECHO is never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet1, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_THAT(cookie_echo_packet1.descriptors(), SizeIs(2)); + EXPECT_EQ(cookie_echo_packet1.descriptors()[0].type, CookieEchoChunk::kType); + EXPECT_EQ(cookie_echo_packet1.descriptors()[1].type, DataChunk::kType); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + // There are DATA chunks in the sent packet (that was lost), which means that + // the T3-RTX timer is running, but as the socket is in kCookieEcho state, it + // will be T1-COOKIE that drives retransmissions, so when the T3-RTX expires, + // nothing should be retransmitted. + ASSERT_TRUE(a.options.rto_initial < a.options.t1_cookie_timeout); + AdvanceTime(a, z, a.options.rto_initial); + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + // When T1-COOKIE expires, both the COOKIE-ECHO and DATA should be present. + AdvanceTime(a, z, a.options.t1_cookie_timeout - a.options.rto_initial); + + // And this COOKIE-ECHO and DATA is also lost - never received by Z. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket cookie_echo_packet2, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_THAT(cookie_echo_packet2.descriptors(), SizeIs(2)); + EXPECT_EQ(cookie_echo_packet2.descriptors()[0].type, CookieEchoChunk::kType); + EXPECT_EQ(cookie_echo_packet2.descriptors()[1].type, DataChunk::kType); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + // COOKIE_ECHO has exponential backoff. + AdvanceTime(a, z, a.options.t1_cookie_timeout * 2); + + // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); + + ExchangeMessages(a, z); + EXPECT_THAT(z.cb.ConsumeReceivedMessage()->payload(), + SizeIs(kLargeMessageSize)); +} + +TEST_P(DcSctpSocketParametrizedTest, ShutdownConnection) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + RTC_LOG(LS_INFO) << "Shutting down"; + + EXPECT_CALL(z->cb, OnClosed).Times(1); + a.socket.Shutdown(); + // Z reads SHUTDOWN, produces SHUTDOWN_ACK + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // A reads SHUTDOWN_ACK, produces SHUTDOWN_COMPLETE + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + // Z reads SHUTDOWN_COMPLETE. + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + EXPECT_EQ(z->socket.state(), SocketState::kClosed); + + z = MaybeHandoverSocket(std::move(z)); + EXPECT_EQ(z->socket.state(), SocketState::kClosed); +} + +TEST(DcSctpSocketTest, ShutdownTimerExpiresTooManyTimeClosesConnection) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + a.socket.Shutdown(); + // Drop first SHUTDOWN packet. + a.cb.ConsumeSentPacket(); + + EXPECT_EQ(a.socket.state(), SocketState::kShuttingDown); + + for (int i = 0; i < *a.options.max_retransmissions; ++i) { + AdvanceTime(a, z, DurationMs(a.options.rto_initial * (1 << i))); + + // Dropping every shutdown chunk. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(packet.descriptors()[0].type, ShutdownChunk::kType); + EXPECT_TRUE(a.cb.ConsumeSentPacket().empty()); + } + // The last expiry, makes it abort the connection. + EXPECT_CALL(a.cb, OnAborted).Times(1); + AdvanceTime(a, z, + a.options.rto_initial * (1 << *a.options.max_retransmissions)); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(packet.descriptors()[0].type, AbortChunk::kType); + EXPECT_TRUE(a.cb.ConsumeSentPacket().empty()); +} + +TEST(DcSctpSocketTest, EstablishConnectionWhileSendingData) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + + // Z reads INIT, produces INIT_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // // A reads INIT_ACK, produces COOKIE_ECHO + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + // // Z reads COOKIE_ECHO, produces COOKIE_ACK + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // // A reads COOKIE_ACK. + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + EXPECT_EQ(a.socket.state(), SocketState::kConnected); + EXPECT_EQ(z.socket.state(), SocketState::kConnected); + + absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); +} + +TEST(DcSctpSocketTest, SendMessageAfterEstablished) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); +} + +TEST_P(DcSctpSocketParametrizedTest, TimeoutResendsPacket) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.cb.ConsumeSentPacket(); + + RTC_LOG(LS_INFO) << "Advancing time"; + AdvanceTime(a, *z, a.options.rto_initial); + + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendALotOfBytesMissedSecondPacket) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + std::vector<uint8_t> payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + // First DATA + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + + // Retransmit and handle the rest + ExchangeMessages(a, *z); + + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload)); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendingHeartbeatAnswersWithAck) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // Inject a HEARTBEAT chunk + SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions()); + uint8_t info[] = {1, 2, 3, 4}; + Parameters::Builder params_builder; + params_builder.Add(HeartbeatInfoParameter(info)); + b.Add(HeartbeatRequestChunk(params_builder.Build())); + a.socket.ReceivePacket(b.Build()); + + // HEARTBEAT_ACK is sent as a reply. Capture it. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket ack_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + ASSERT_THAT(ack_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatAckChunk ack, + HeartbeatAckChunk::Parse(ack_packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, ack.info()); + EXPECT_THAT(info_param.info(), ElementsAre(1, 2, 3, 4)); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, ExpectHeartbeatToBeSent) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + AdvanceTime(a, *z, a.options.heartbeat_interval); + + std::vector<uint8_t> hb_packet_raw = a.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, + SctpPacket::Parse(hb_packet_raw)); + ASSERT_THAT(hb_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk hb, + HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN(HeartbeatInfoParameter info_param, hb.info()); + + // The info is a single 64-bit number. + EXPECT_THAT(hb.info()->info(), SizeIs(8)); + + // Feed it to Sock-z and expect a HEARTBEAT_ACK that will be propagated back. + z->socket.ReceivePacket(hb_packet_raw); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + CloseConnectionAfterTooManyLostHeartbeats) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(z->cb, OnClosed).Times(1); + EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty()); + // Force-close socket Z so that it doesn't interfere from now on. + z->socket.Close(); + + DurationMs time_to_next_hearbeat = a.options.heartbeat_interval; + + for (int i = 0; i < *a.options.max_retransmissions; ++i) { + RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending..."; + AdvanceTime(a, *z, time_to_next_hearbeat); + + // Dropping every heartbeat. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(hb_packet.descriptors()[0].type, HeartbeatRequestChunk::kType); + + RTC_LOG(LS_INFO) << "Letting the heartbeat expire."; + AdvanceTime(a, *z, DurationMs(1000)); + + time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000); + } + + RTC_LOG(LS_INFO) << "Letting HEARTBEAT interval timer expire - sending..."; + AdvanceTime(a, *z, time_to_next_hearbeat); + + // Last heartbeat + EXPECT_THAT(a.cb.ConsumeSentPacket(), Not(IsEmpty())); + + EXPECT_CALL(a.cb, OnAborted).Times(1); + // Should suffice as exceeding RTO + AdvanceTime(a, *z, DurationMs(1000)); + + z = MaybeHandoverSocket(std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, RecoversAfterASuccessfulAck) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), testing::IsEmpty()); + EXPECT_CALL(z->cb, OnClosed).Times(1); + // Force-close socket Z so that it doesn't interfere from now on. + z->socket.Close(); + + DurationMs time_to_next_hearbeat = a.options.heartbeat_interval; + + for (int i = 0; i < *a.options.max_retransmissions; ++i) { + AdvanceTime(a, *z, time_to_next_hearbeat); + + // Dropping every heartbeat. + a.cb.ConsumeSentPacket(); + + RTC_LOG(LS_INFO) << "Letting the heartbeat expire."; + AdvanceTime(a, *z, DurationMs(1000)); + + time_to_next_hearbeat = a.options.heartbeat_interval - DurationMs(1000); + } + + RTC_LOG(LS_INFO) << "Getting the last heartbeat - and acking it"; + AdvanceTime(a, *z, time_to_next_hearbeat); + + std::vector<uint8_t> hb_packet_raw = a.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket hb_packet, + SctpPacket::Parse(hb_packet_raw)); + ASSERT_THAT(hb_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk hb, + HeartbeatRequestChunk::Parse(hb_packet.descriptors()[0].data)); + + SctpPacket::Builder b(a.socket.verification_tag(), a.options); + b.Add(HeartbeatAckChunk(std::move(hb).extract_parameters())); + a.socket.ReceivePacket(b.Build()); + + // Should suffice as exceeding RTO - which will not fire. + EXPECT_CALL(a.cb, OnAborted).Times(0); + AdvanceTime(a, *z, DurationMs(1000)); + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + // Verify that we get new heartbeats again. + RTC_LOG(LS_INFO) << "Expecting a new heartbeat"; + AdvanceTime(a, *z, time_to_next_hearbeat); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket another_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + EXPECT_EQ(another_packet.descriptors()[0].type, HeartbeatRequestChunk::kType); +} + +TEST_P(DcSctpSocketParametrizedTest, ResetStream) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), {}); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Reset the outgoing stream. This will directly send a RE-CONFIG. + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + + // Receiving the packet will trigger a callback, indicating that A has + // reset its stream. It will also send a RE-CONFIG with a response. + EXPECT_CALL(z->cb, OnIncomingStreamsReset).Times(1); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // Receiving a response will trigger a callback. Streams are now reset. + EXPECT_CALL(a.cb, OnStreamsResetPerformed).Times(1); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, ResetStreamWillMakeChunksStartAtZeroSsn) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + std::vector<uint8_t> payload(a.options.mtu - 100); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet1 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0))); + z->socket.ReceivePacket(packet1); + + auto packet2 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1))); + z->socket.ReceivePacket(packet2); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + + absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(1)); + + // Reset the outgoing stream. This will directly send a RE-CONFIG. + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + // RE-CONFIG, req + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // RE-CONFIG, resp + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet3 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0))); + z->socket.ReceivePacket(packet3); + + auto packet4 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1))); + z->socket.ReceivePacket(packet4); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + ResetStreamWillOnlyResetTheRequestedStreams) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + std::vector<uint8_t> payload(a.options.mtu - 100); + + // Send two ordered messages on SID 1 + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + auto packet1 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1))); + EXPECT_THAT(packet1, HasDataChunkWithSsn(SSN(0))); + z->socket.ReceivePacket(packet1); + + auto packet2 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet1, HasDataChunkWithStreamId(StreamID(1))); + EXPECT_THAT(packet2, HasDataChunkWithSsn(SSN(1))); + z->socket.ReceivePacket(packet2); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Do the same, for SID 3 + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + auto packet3 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet3, HasDataChunkWithStreamId(StreamID(3))); + EXPECT_THAT(packet3, HasDataChunkWithSsn(SSN(0))); + z->socket.ReceivePacket(packet3); + auto packet4 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet4, HasDataChunkWithStreamId(StreamID(3))); + EXPECT_THAT(packet4, HasDataChunkWithSsn(SSN(1))); + z->socket.ReceivePacket(packet4); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Receive all messages. + absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + + absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(1)); + + absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg3.has_value()); + EXPECT_EQ(msg3->stream_id(), StreamID(3)); + + absl::optional<DcSctpMessage> msg4 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg4.has_value()); + EXPECT_EQ(msg4->stream_id(), StreamID(3)); + + // Reset SID 1. This will directly send a RE-CONFIG. + a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)})); + // RE-CONFIG, req + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // RE-CONFIG, resp + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Send a message on SID 1 and 3 - SID 1 should not be reset, but 3 should. + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), {}); + + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), payload), {}); + + auto packet5 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet5, HasDataChunkWithStreamId(StreamID(1))); + EXPECT_THAT(packet5, HasDataChunkWithSsn(SSN(2))); // Unchanged. + z->socket.ReceivePacket(packet5); + + auto packet6 = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet6, HasDataChunkWithStreamId(StreamID(3))); + EXPECT_THAT(packet6, HasDataChunkWithSsn(SSN(0))); // Reset. + z->socket.ReceivePacket(packet6); + + // Handle SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, OnePeerReconnects) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnConnectionRestarted).Times(1); + // Let's be evil here - reconnect while a fragmented packet was about to be + // sent. The receiving side should get it in full. + std::vector<uint8_t> payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + // First DATA + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // Create a new association, z2 - and don't use z anymore. + SocketUnderTest z2("Z2"); + z2.socket.Connect(); + + // Retransmit and handle the rest. As there will be some chunks in-flight that + // have the wrong verification tag, those will yield errors. + ExchangeMessages(a, z2); + + absl::optional<DcSctpMessage> msg = z2.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAreArray(payload)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendMessageWithLimitedRtx) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + SendOptions send_options; + send_options.max_retransmissions = 0; + std::vector<uint8_t> payload(a.options.mtu - 100); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options); + + // First DATA + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + // Third DATA + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // Handle SACK for first DATA + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Handle delayed SACK for third DATA + AdvanceTime(a, *z, a.options.delayed_ack_max_timeout); + + // Handle SACK for second DATA + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + // Now the missing data chunk will be marked as nacked, but it might still be + // in-flight and the reported gap could be due to out-of-order delivery. So + // the RetransmissionQueue will not mark it as "to be retransmitted" until + // after the t3-rtx timer has expired. + AdvanceTime(a, *z, a.options.rto_initial); + + // The chunk will be marked as retransmitted, and then as abandoned, which + // will trigger a FORWARD-TSN to be sent. + + // FORWARD-TSN (third) + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + // Which will trigger a SACK + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); + + absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->ppid(), PPID(51)); + + absl::optional<DcSctpMessage> msg2 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->ppid(), PPID(53)); + + absl::optional<DcSctpMessage> msg3 = z->cb.ConsumeReceivedMessage(); + EXPECT_FALSE(msg3.has_value()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendManyFragmentedMessagesWithLimitedRtx) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + SendOptions send_options; + send_options.unordered = IsUnordered(true); + send_options.max_retransmissions = 0; + std::vector<uint8_t> payload(a.options.mtu * 2 - 100 /* margin */); + // Sending first message + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + // Sending second message + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); + // Sending third message + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), send_options); + // Sending fourth message + a.socket.Send(DcSctpMessage(StreamID(1), PPID(54), payload), send_options); + + // First DATA, first fragment + std::vector<uint8_t> packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51))); + z->socket.ReceivePacket(std::move(packet)); + + // First DATA, second fragment (lost) + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(51))); + + // Second DATA, first fragment + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52))); + z->socket.ReceivePacket(std::move(packet)); + + // Second DATA, second fragment (lost) + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(52))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + + // Third DATA, first fragment + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + z->socket.ReceivePacket(std::move(packet)); + + // Third DATA, second fragment (lost) + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(53))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + + // Fourth DATA, first fragment + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + z->socket.ReceivePacket(std::move(packet)); + + // Fourth DATA, second fragment + packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasDataChunkWithPPID(PPID(54))); + EXPECT_THAT(packet, HasDataChunkWithSsn(SSN(0))); + z->socket.ReceivePacket(std::move(packet)); + + ExchangeMessages(a, *z); + + // Let the RTX timer expire, and exchange FORWARD-TSN/SACKs + AdvanceTime(a, *z, a.options.rto_initial); + + ExchangeMessages(a, *z); + + absl::optional<DcSctpMessage> msg1 = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->ppid(), PPID(54)); + + ASSERT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +struct FakeChunkConfig : ChunkConfig { + static constexpr int kType = 0x49; + static constexpr size_t kHeaderSize = 4; + static constexpr int kVariableLengthAlignment = 0; +}; + +class FakeChunk : public Chunk, public TLVTrait<FakeChunkConfig> { + public: + FakeChunk() {} + + FakeChunk(FakeChunk&& other) = default; + FakeChunk& operator=(FakeChunk&& other) = default; + + void SerializeTo(std::vector<uint8_t>& out) const override { + AllocateTLV(out); + } + std::string ToString() const override { return "FAKE"; } +}; + +TEST_P(DcSctpSocketParametrizedTest, ReceivingUnknownChunkRespondsWithError) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // Inject a FAKE chunk + SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions()); + b.Add(FakeChunk()); + a.socket.ReceivePacket(b.Build()); + + // ERROR is sent as a reply. Capture it. + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket reply_packet, + SctpPacket::Parse(a.cb.ConsumeSentPacket())); + ASSERT_THAT(reply_packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + ErrorChunk error, ErrorChunk::Parse(reply_packet.descriptors()[0].data)); + ASSERT_HAS_VALUE_AND_ASSIGN( + UnrecognizedChunkTypeCause cause, + error.error_causes().get<UnrecognizedChunkTypeCause>()); + EXPECT_THAT(cause.unrecognized_chunk(), ElementsAre(0x49, 0x00, 0x00, 0x04)); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, ReceivingErrorChunkReportsAsCallback) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // Inject a ERROR chunk + SctpPacket::Builder b(a.socket.verification_tag(), DcSctpOptions()); + b.Add( + ErrorChunk(Parameters::Builder() + .Add(UnrecognizedChunkTypeCause({0x49, 0x00, 0x00, 0x04})) + .Build())); + + EXPECT_CALL(a.cb, OnError(ErrorKind::kPeerReported, + HasSubstr("Unrecognized Chunk Type"))); + a.socket.ReceivePacket(b.Build()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, PassingHighWatermarkWillOnlyAcceptCumAckTsn) { + SocketUnderTest a("A"); + + constexpr size_t kReceiveWindowBufferSize = 2000; + SocketUnderTest z( + "Z", {.mtu = 3000, + .max_receiver_window_buffer_size = kReceiveWindowBufferSize}); + + EXPECT_CALL(z.cb, OnClosed).Times(0); + EXPECT_CALL(z.cb, OnAborted).Times(0); + + a.socket.Connect(); + std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(init_data)); + ASSERT_HAS_VALUE_AND_ASSIGN( + InitChunk init_chunk, + InitChunk::Parse(init_packet.descriptors()[0].data)); + z.socket.ReceivePacket(init_data); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // Fill up Z2 to the high watermark limit. + constexpr size_t kWatermarkLimit = + kReceiveWindowBufferSize * ReassemblyQueue::kHighWatermarkLimit; + constexpr size_t kRemainingSize = kReceiveWindowBufferSize - kWatermarkLimit; + + TSN tsn = init_chunk.initial_tsn(); + AnyDataChunk::Options opts; + opts.is_beginning = Data::IsBeginning(true); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(kWatermarkLimit + 1), opts)) + .Build()); + + // First DATA will always trigger a SACK. It's not interesting. + EXPECT_THAT(z.cb.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(tsn), HasSackWithNoGapAckBlocks())); + + // This DATA should be accepted - it's advancing cum ack tsn. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 1), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(1), + /*options=*/{})) + .Build()); + + // The receiver might have moved into delayed ack mode. + AdvanceTime(a, z, z.options.rto_initial); + + EXPECT_THAT( + z.cb.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks())); + + // This DATA will not be accepted - it's not advancing cum ack tsn. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(1), + /*options=*/{})) + .Build()); + + // Sack will be sent in IMMEDIATE mode when this is happening. + EXPECT_THAT( + z.cb.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks())); + + // This DATA will not be accepted either. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 4), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(1), + /*options=*/{})) + .Build()); + + // Sack will be sent in IMMEDIATE mode when this is happening. + EXPECT_THAT( + z.cb.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(AddTo(tsn, 1)), HasSackWithNoGapAckBlocks())); + + // This DATA should be accepted, and it fills the reassembly queue. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 2), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(kRemainingSize), + /*options=*/{})) + .Build()); + + // The receiver might have moved into delayed ack mode. + AdvanceTime(a, z, z.options.rto_initial); + + EXPECT_THAT( + z.cb.ConsumeSentPacket(), + AllOf(HasSackWithCumAckTsn(AddTo(tsn, 2)), HasSackWithNoGapAckBlocks())); + + EXPECT_CALL(z.cb, OnAborted(ErrorKind::kResourceExhaustion, _)); + EXPECT_CALL(z.cb, OnClosed).Times(0); + + // This DATA will make the connection close. It's too full now. + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(AddTo(tsn, 3), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(kSmallMessageSize), + /*options=*/{})) + .Build()); +} + +TEST(DcSctpSocketTest, SetMaxMessageSize) { + SocketUnderTest a("A"); + + a.socket.SetMaxMessageSize(42u); + EXPECT_EQ(a.socket.options().max_message_size, 42u); +} + +TEST_P(DcSctpSocketParametrizedTest, SendsMessagesWithLowLifetime) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // Mock that the time always goes forward. + TimeMs now(0); + EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + + // Queue a few small messages with low lifetime, both ordered and unordered, + // and validate that all are delivered. + static constexpr int kIterations = 100; + for (int i = 0; i < kIterations; ++i) { + SendOptions send_options; + send_options.unordered = IsUnordered((i % 2) == 0); + send_options.lifetime = DurationMs(i % 3); // 0, 1, 2 ms + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); + } + + ExchangeMessages(a, *z); + + for (int i = 0; i < kIterations; ++i) { + EXPECT_TRUE(z->cb.ConsumeReceivedMessage().has_value()); + } + + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + // Validate that the sockets really make the time move forward. + EXPECT_GE(*now, kIterations * 2); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + DiscardsMessagesWithLowLifetimeIfMustBuffer) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + SendOptions lifetime_0; + lifetime_0.unordered = IsUnordered(true); + lifetime_0.lifetime = DurationMs(0); + + SendOptions lifetime_1; + lifetime_1.unordered = IsUnordered(true); + lifetime_1.lifetime = DurationMs(1); + + // Mock that the time always goes forward. + TimeMs now(0); + EXPECT_CALL(a.cb, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + EXPECT_CALL(z->cb, TimeMillis).WillRepeatedly([&]() { + now += DurationMs(3); + return now; + }); + + // Fill up the send buffer with a large message. + std::vector<uint8_t> payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), kSendOptions); + + // And queue a few small messages with lifetime=0 or 1 ms - can't be sent. + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), lifetime_0); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {4, 5, 6}), lifetime_1); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {7, 8, 9}), lifetime_0); + + // Handle all that was sent until congestion window got full. + for (;;) { + std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket(); + if (packet_from_a.empty()) { + break; + } + z->socket.ReceivePacket(std::move(packet_from_a)); + } + + // Shouldn't be enough to send that large message. + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + // Exchange the rest of the messages, with the time ever increasing. + ExchangeMessages(a, *z); + + // The large message should be delivered. It was sent reliably. + ASSERT_HAS_VALUE_AND_ASSIGN(DcSctpMessage m1, z->cb.ConsumeReceivedMessage()); + EXPECT_EQ(m1.stream_id(), StreamID(1)); + EXPECT_THAT(m1.payload(), SizeIs(kLargeMessageSize)); + + // But none of the smaller messages. + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, HasReasonableBufferedAmountValues) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + // Sending a small message will directly send it as a single packet, so + // nothing is left in the queue. + EXPECT_EQ(a.socket.buffered_amount(StreamID(1)), 0u); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + // Sending a message will directly start sending a few packets, so the + // buffered amount is not the full message size. + EXPECT_GT(a.socket.buffered_amount(StreamID(1)), 0u); + EXPECT_LT(a.socket.buffered_amount(StreamID(1)), kLargeMessageSize); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, HasDefaultOnBufferedAmountLowValueZero) { + SocketUnderTest a("A"); + EXPECT_EQ(a.socket.buffered_amount_low_threshold(StreamID(1)), 0u); +} + +TEST_P(DcSctpSocketParametrizedTest, + TriggersOnBufferedAmountLowWithDefaultValueZero) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + EXPECT_CALL(a.cb, OnBufferedAmountLow).WillRepeatedly(testing::Return()); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + DoesntTriggerOnBufferedAmountLowIfBelowThreshold) { + static constexpr size_t kMessageSize = 1000; + static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 10; + + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + a.socket.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(0); + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, TriggersOnBufferedAmountMultipleTimes) { + static constexpr size_t kMessageSize = 1000; + static constexpr size_t kBufferedAmountLowThreshold = kMessageSize / 2; + + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + a.socket.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(3); + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(2))).Times(2); + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(2), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + TriggersOnBufferedAmountLowOnlyWhenCrossingThreshold) { + static constexpr size_t kMessageSize = 1000; + static constexpr size_t kBufferedAmountLowThreshold = kMessageSize * 1.5; + + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + a.socket.SetBufferedAmountLowThreshold(StreamID(1), + kBufferedAmountLowThreshold); + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnBufferedAmountLow).Times(0); + + // Add a few messages to fill up the congestion window. When that is full, + // messages will start to be fully buffered. + while (a.socket.buffered_amount(StreamID(1)) <= kBufferedAmountLowThreshold) { + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kMessageSize)), + kSendOptions); + } + size_t initial_buffered = a.socket.buffered_amount(StreamID(1)); + ASSERT_GT(initial_buffered, kBufferedAmountLowThreshold); + + // Start ACKing packets, which will empty the send queue, and trigger the + // callback. + EXPECT_CALL(a.cb, OnBufferedAmountLow(StreamID(1))).Times(1); + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + DoesntTriggerOnTotalBufferAmountLowWhenBelow) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, + TriggersOnTotalBufferAmountLowWhenCrossingThreshold) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(0); + + // Fill up the send queue completely. + for (;;) { + if (a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions) == SendStatus::kErrorResourceExhaustion) { + break; + } + } + + EXPECT_CALL(a.cb, OnTotalBufferedAmountLow).Times(1); + ExchangeMessages(a, *z); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, InitialMetricsAreUnset) { + SocketUnderTest a("A"); + + EXPECT_FALSE(a.socket.GetMetrics().has_value()); +} + +TEST(DcSctpSocketTest, MessageInterleavingMetricsAreSet) { + std::vector<std::pair<bool, bool>> combinations = { + {false, false}, {false, true}, {true, false}, {true, true}}; + for (const auto& [a_enable, z_enable] : combinations) { + DcSctpOptions a_options = {.enable_message_interleaving = a_enable}; + DcSctpOptions z_options = {.enable_message_interleaving = z_enable}; + + SocketUnderTest a("A", a_options); + SocketUnderTest z("Z", z_options); + ConnectSockets(a, z); + + EXPECT_EQ(a.socket.GetMetrics()->uses_message_interleaving, + a_enable && z_enable); + } +} + +TEST(DcSctpSocketTest, RxAndTxPacketMetricsIncrease) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + const size_t initial_a_rwnd = a.options.max_receiver_window_buffer_size * + ReassemblyQueue::kHighWatermarkLimit; + + EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 2u); + EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 2u); + EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 0u); + EXPECT_EQ(a.socket.GetMetrics()->cwnd_bytes, + a.options.cwnd_mtus_initial * a.options.mtu); + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u); + + EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 2u); + EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 0u); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u); + + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK + EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd); + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u); + + EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value()); + + EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 3u); + EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 3u); + EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 1u); + + EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 3u); + EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 1u); + + // Send one more (large - fragmented), and receive the delayed SACK. + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(a.options.mtu * 2 + 1)), + kSendOptions); + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 3u); + + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 1u); + EXPECT_GT(a.socket.GetMetrics()->peer_rwnd_bytes, 0u); + EXPECT_LT(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd); + + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); // DATA + + EXPECT_TRUE(z.cb.ConsumeReceivedMessage().has_value()); + + EXPECT_EQ(a.socket.GetMetrics()->tx_packets_count, 6u); + EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 4u); + EXPECT_EQ(a.socket.GetMetrics()->tx_messages_count, 2u); + + EXPECT_EQ(z.socket.GetMetrics()->rx_packets_count, 6u); + EXPECT_EQ(z.socket.GetMetrics()->rx_messages_count, 2u); + + // Delayed sack + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); // SACK + EXPECT_EQ(a.socket.GetMetrics()->unack_data_count, 0u); + EXPECT_EQ(a.socket.GetMetrics()->rx_packets_count, 5u); + EXPECT_EQ(a.socket.GetMetrics()->peer_rwnd_bytes, initial_a_rwnd); +} + +TEST_P(DcSctpSocketParametrizedTest, UnackDataAlsoIncludesSendQueue) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + size_t payload_bytes = + a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; + + size_t expected_sent_packets = a.options.cwnd_mtus_initial; + + size_t expected_queued_bytes = + kLargeMessageSize - expected_sent_packets * payload_bytes; + + size_t expected_queued_packets = expected_queued_bytes / payload_bytes; + + // Due to alignment, padding etc, it's hard to calculate the exact number, but + // it should be in this range. + EXPECT_GE(a.socket.GetMetrics()->unack_data_count, + expected_sent_packets + expected_queued_packets); + + EXPECT_LE(a.socket.GetMetrics()->unack_data_count, + expected_sent_packets + expected_queued_packets + 2); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, DoesntSendMoreThanMaxBurstPackets) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + for (int i = 0; i < kMaxBurstPackets; ++i) { + std::vector<uint8_t> packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, Not(IsEmpty())); + z->socket.ReceivePacket(std::move(packet)); // DATA + } + + EXPECT_THAT(a.cb.ConsumeSentPacket(), IsEmpty()); + + ExchangeMessages(a, *z); + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST_P(DcSctpSocketParametrizedTest, SendsOnlyLargePackets) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + // A really large message, to ensure that the congestion window is often full. + constexpr size_t kMessageSize = 100000; + a.socket.Send( + DcSctpMessage(StreamID(1), PPID(53), std::vector<uint8_t>(kMessageSize)), + kSendOptions); + + bool delivered_packet = false; + std::vector<size_t> data_packet_sizes; + do { + delivered_packet = false; + std::vector<uint8_t> packet_from_a = a.cb.ConsumeSentPacket(); + if (!packet_from_a.empty()) { + data_packet_sizes.push_back(packet_from_a.size()); + delivered_packet = true; + z->socket.ReceivePacket(std::move(packet_from_a)); + } + std::vector<uint8_t> packet_from_z = z->cb.ConsumeSentPacket(); + if (!packet_from_z.empty()) { + delivered_packet = true; + a.socket.ReceivePacket(std::move(packet_from_z)); + } + } while (delivered_packet); + + size_t packet_payload_bytes = + a.options.mtu - SctpPacket::kHeaderSize - DataChunk::kHeaderSize; + // +1 accounts for padding, and rounding up. + size_t expected_packets = + (kMessageSize + packet_payload_bytes - 1) / packet_payload_bytes + 1; + EXPECT_THAT(data_packet_sizes, SizeIs(expected_packets)); + + // Remove the last size - it will be the remainder. But all other sizes should + // be large. + data_packet_sizes.pop_back(); + + for (size_t size : data_packet_sizes) { + // The 4 is for padding/alignment. + EXPECT_GE(size, a.options.mtu - 4); + } + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, SendMessagesAfterHandover) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + + // Send message before handover to move socket to a not initial state + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + z->cb.ConsumeReceivedMessage(); + + z = HandoverSocket(std::move(z)); + + absl::optional<DcSctpMessage> msg; + + RTC_LOG(LS_INFO) << "Sending A #1"; + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {3, 4}), kSendOptions); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAre(3, 4)); + + RTC_LOG(LS_INFO) << "Sending A #2"; + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {5, 6}), kSendOptions); + z->socket.ReceivePacket(a.cb.ConsumeSentPacket()); + + msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(2)); + EXPECT_THAT(msg->payload(), testing::ElementsAre(5, 6)); + + RTC_LOG(LS_INFO) << "Sending Z #1"; + + z->socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2, 3}), kSendOptions); + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // ack + a.socket.ReceivePacket(z->cb.ConsumeSentPacket()); // data + + msg = a.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->stream_id(), StreamID(1)); + EXPECT_THAT(msg->payload(), testing::ElementsAre(1, 2, 3)); +} + +TEST(DcSctpSocketTest, CanDetectDcsctpImplementation) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp); + + // As A initiated the connection establishment, Z will not receive enough + // information to know about A's implementation + EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kUnknown); +} + +TEST(DcSctpSocketTest, BothCanDetectDcsctpImplementation) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + a.socket.Connect(); + z.socket.Connect(); + + ExchangeMessages(a, z); + + EXPECT_EQ(a.socket.peer_implementation(), SctpImplementation::kDcsctp); + EXPECT_EQ(z.socket.peer_implementation(), SctpImplementation::kDcsctp); +} + +TEST_P(DcSctpSocketParametrizedTest, CanLoseFirstOrderedMessage) { + SocketUnderTest a("A"); + auto z = std::make_unique<SocketUnderTest>("Z"); + + ConnectSockets(a, *z); + z = MaybeHandoverSocket(std::move(z)); + + SendOptions send_options; + send_options.unordered = IsUnordered(false); + send_options.max_retransmissions = 0; + std::vector<uint8_t> payload(a.options.mtu - 100); + + // Send a first message (SID=1, SSN=0) + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), send_options); + + // First DATA is lost, and retransmission timer will delete it. + a.cb.ConsumeSentPacket(); + AdvanceTime(a, *z, a.options.rto_initial); + ExchangeMessages(a, *z); + + // Send a second message (SID=0, SSN=1). + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), send_options); + ExchangeMessages(a, *z); + + // The Z socket should receive the second message, but not the first. + absl::optional<DcSctpMessage> msg = z->cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg.has_value()); + EXPECT_EQ(msg->ppid(), PPID(52)); + + EXPECT_FALSE(z->cb.ConsumeReceivedMessage().has_value()); + + MaybeHandoverSocketAndSendMessage(a, std::move(z)); +} + +TEST(DcSctpSocketTest, ReceiveBothUnorderedAndOrderedWithSameTSN) { + /* This issue was found by fuzzing. */ + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + a.socket.Connect(); + std::vector<uint8_t> init_data = a.cb.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket init_packet, + SctpPacket::Parse(init_data)); + ASSERT_HAS_VALUE_AND_ASSIGN( + InitChunk init_chunk, + InitChunk::Parse(init_packet.descriptors()[0].data)); + z.socket.ReceivePacket(init_data); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + a.socket.ReceivePacket(z.cb.ConsumeSentPacket()); + + // Receive a short unordered message with tsn=INITIAL_TSN+1 + TSN tsn = init_chunk.initial_tsn(); + AnyDataChunk::Options opts; + opts.is_beginning = Data::IsBeginning(true); + opts.is_end = Data::IsEnd(true); + opts.is_unordered = IsUnordered(true); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(10), opts)) + .Build()); + + // Now receive a longer _ordered_ message with [INITIAL_TSN, INITIAL_TSN+1]. + // This isn't allowed as it reuses TSN=53 with different properties, but it + // shouldn't cause any issues. + opts.is_unordered = IsUnordered(false); + opts.is_end = Data::IsEnd(false); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(tsn, StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(10), opts)) + .Build()); + + opts.is_beginning = Data::IsBeginning(false); + opts.is_end = Data::IsEnd(true); + z.socket.ReceivePacket( + SctpPacket::Builder(z.socket.verification_tag(), z.options) + .Add(DataChunk(TSN(*tsn + 1), StreamID(1), SSN(0), PPID(53), + std::vector<uint8_t>(10), opts)) + .Build()); +} + +TEST(DcSctpSocketTest, CloseTwoStreamsAtTheSameTime) { + // Reported as https://crbug.com/1312009. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(2)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(2)))).Times(1); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + ExchangeMessages(a, z); + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)})); + + ExchangeMessages(a, z); +} + +TEST(DcSctpSocketTest, CloseThreeStreamsAtTheSameTime) { + // Similar to CloseTwoStreamsAtTheSameTime, but ensuring that the two + // remaining streams are reset at the same time in the second request. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(z.cb, OnIncomingStreamsReset( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), kSendOptions); + + ExchangeMessages(a, z); + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)})); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)})); + + ExchangeMessages(a, z); +} + +TEST(DcSctpSocketTest, CloseStreamsWithPendingRequest) { + // Checks that stream reset requests are properly paused when they can't be + // immediately reset - i.e. when there is already an ongoing stream reset + // request (and there can only be a single one in-flight). + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + EXPECT_CALL(z.cb, OnIncomingStreamsReset(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(z.cb, OnIncomingStreamsReset( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed(ElementsAre(StreamID(1)))).Times(1); + EXPECT_CALL(a.cb, OnStreamsResetPerformed( + UnorderedElementsAre(StreamID(2), StreamID(3)))) + .Times(1); + + ConnectSockets(a, z); + + SendOptions send_options = {.unordered = IsUnordered(false)}; + + // Send a few ordered messages + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options); + + ExchangeMessages(a, z); + + // Receive these messages + absl::optional<DcSctpMessage> msg1 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg1.has_value()); + EXPECT_EQ(msg1->stream_id(), StreamID(1)); + absl::optional<DcSctpMessage> msg2 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg2.has_value()); + EXPECT_EQ(msg2->stream_id(), StreamID(2)); + absl::optional<DcSctpMessage> msg3 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg3.has_value()); + EXPECT_EQ(msg3->stream_id(), StreamID(3)); + + // Reset the streams - not all at once. + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + + std::vector<uint8_t> packet = a.cb.ConsumeSentPacket(); + EXPECT_THAT(packet, HasReconfigWithStreams(ElementsAre(StreamID(1)))); + z.socket.ReceivePacket(std::move(packet)); + + // Sending more reset requests while this one is ongoing. + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)})); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(3)})); + + ExchangeMessages(a, z); + + // Send a few more ordered messages + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), {1, 2}), send_options); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), send_options); + a.socket.Send(DcSctpMessage(StreamID(3), PPID(53), {1, 2}), send_options); + + ExchangeMessages(a, z); + + // Receive these messages + absl::optional<DcSctpMessage> msg4 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg4.has_value()); + EXPECT_EQ(msg4->stream_id(), StreamID(1)); + absl::optional<DcSctpMessage> msg5 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg5.has_value()); + EXPECT_EQ(msg5->stream_id(), StreamID(2)); + absl::optional<DcSctpMessage> msg6 = z.cb.ConsumeReceivedMessage(); + ASSERT_TRUE(msg6.has_value()); + EXPECT_EQ(msg6->stream_id(), StreamID(3)); +} + +TEST(DcSctpSocketTest, StreamsHaveInitialPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + SocketUnderTest a("A", options); + + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), + options.default_stream_priority); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), + options.default_stream_priority); +} + +TEST(DcSctpSocketTest, CanChangeStreamPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + SocketUnderTest a("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(43)); + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(1)), StreamPriority(43)); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + + a.socket.SetStreamPriority(StreamID(2), StreamPriority(43)); + EXPECT_EQ(a.socket.GetStreamPriority(StreamID(2)), StreamPriority(43)); +} + +TEST_P(DcSctpSocketParametrizedTest, WillHandoverPriority) { + DcSctpOptions options = {.default_stream_priority = StreamPriority(42)}; + auto a = std::make_unique<SocketUnderTest>("A", options); + SocketUnderTest z("Z"); + + ConnectSockets(*a, z); + + a->socket.SetStreamPriority(StreamID(1), StreamPriority(43)); + a->socket.Send(DcSctpMessage(StreamID(2), PPID(53), {1, 2}), kSendOptions); + a->socket.SetStreamPriority(StreamID(2), StreamPriority(43)); + + ExchangeMessages(*a, z); + + a = MaybeHandoverSocket(std::move(a)); + + EXPECT_EQ(a->socket.GetStreamPriority(StreamID(1)), StreamPriority(43)); + EXPECT_EQ(a->socket.GetStreamPriority(StreamID(2)), StreamPriority(43)); +} + +TEST(DcSctpSocketTest, ReconnectSocketWithPendingStreamReset) { + // This is an issue found by fuzzing, and doesn't really make sense in WebRTC + // data channels as a SCTP connection is never ever closed and then + // reconnected. SCTP connections are closed when the peer connection is + // deleted, and then it doesn't do more with SCTP. + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + ConnectSockets(a, z); + + a.socket.ResetStreams(std::vector<StreamID>({StreamID(1)})); + + EXPECT_CALL(z.cb, OnAborted).Times(1); + a.socket.Close(); + + EXPECT_EQ(a.socket.state(), SocketState::kClosed); + + EXPECT_CALL(a.cb, OnConnected).Times(1); + EXPECT_CALL(z.cb, OnConnected).Times(1); + a.socket.Connect(); + ExchangeMessages(a, z); + a.socket.ResetStreams(std::vector<StreamID>({StreamID(2)})); +} + +TEST(DcSctpSocketTest, SmallSentMessagesWithPrioWillArriveInSpecificOrder) { + DcSctpOptions options = {.enable_message_interleaving = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(700)); + a.socket.SetStreamPriority(StreamID(2), StreamPriority(200)); + a.socket.SetStreamPriority(StreamID(3), StreamPriority(100)); + + // Enqueue messages before connecting the socket, to ensure they aren't send + // as soon as Send() is called. + a.socket.Send(DcSctpMessage(StreamID(3), PPID(301), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(101), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(201), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(102), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(103), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + + ConnectSockets(a, z); + ExchangeMessages(a, z); + + std::vector<uint32_t> received_ppids; + for (;;) { + absl::optional<DcSctpMessage> msg = z.cb.ConsumeReceivedMessage(); + if (!msg.has_value()) { + break; + } + received_ppids.push_back(*msg->ppid()); + } + + EXPECT_THAT(received_ppids, ElementsAre(101, 102, 103, 201, 301)); +} + +TEST(DcSctpSocketTest, LargeSentMessagesWithPrioWillArriveInSpecificOrder) { + DcSctpOptions options = {.enable_message_interleaving = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("A", options); + + a.socket.SetStreamPriority(StreamID(1), StreamPriority(700)); + a.socket.SetStreamPriority(StreamID(2), StreamPriority(200)); + a.socket.SetStreamPriority(StreamID(3), StreamPriority(100)); + + // Enqueue messages before connecting the socket, to ensure they aren't send + // as soon as Send() is called. + a.socket.Send(DcSctpMessage(StreamID(3), PPID(301), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(101), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(201), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(102), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + ConnectSockets(a, z); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201, 301)); +} + +TEST(DcSctpSocketTest, MessageWithHigherPrioWillInterruptLowerPrioMessage) { + DcSctpOptions options = {.enable_message_interleaving = true}; + SocketUnderTest a("A", options); + SocketUnderTest z("Z", options); + + ConnectSockets(a, z); + + a.socket.SetStreamPriority(StreamID(2), StreamPriority(128)); + a.socket.Send(DcSctpMessage(StreamID(2), PPID(201), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + // Due to a non-zero initial congestion window, the message will already start + // to send, but will not succeed to be sent completely before filling the + // congestion window or stopping due to reaching how many packets that can be + // sent at once (max burst). The important thing is that the entire message + // doesn't get sent in full. + + // Now enqueue two messages; one small and one large higher priority message. + a.socket.SetStreamPriority(StreamID(1), StreamPriority(512)); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(101), + std::vector<uint8_t>(kSmallMessageSize)), + kSendOptions); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(102), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 201)); +} + +TEST(DcSctpSocketTest, LifecycleEventsAreGeneratedForAckedMessages) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(101), + std::vector<uint8_t>(kLargeMessageSize)), + {.lifecycle_id = LifecycleId(41)}); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(102), + std::vector<uint8_t>(kLargeMessageSize)), + kSendOptions); + + a.socket.Send(DcSctpMessage(StreamID(2), PPID(103), + std::vector<uint8_t>(kLargeMessageSize)), + {.lifecycle_id = LifecycleId(42)}); + + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(41))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(41))); + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(42))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(42))); + ExchangeMessages(a, z); + // In case of delayed ack. + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(101, 102, 103)); +} + +TEST(DcSctpSocketTest, LifecycleEventsForFailMaxRetransmissions) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + std::vector<uint8_t> payload(a.options.mtu - 100); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(1), + }); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(52), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(2), + }); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(53), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(3), + }); + + // First DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(1))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1))); + EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(2), + /*maybe_delivered=*/true)); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(2))); + EXPECT_CALL(a.cb, OnLifecycleMessageDelivered(LifecycleId(3))); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(3))); + ExchangeMessages(a, z); + + // Handle delayed SACK. + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + ExchangeMessages(a, z); + + // The chunk is now NACKed. Let the RTO expire, to discard the message. + AdvanceTime(a, z, a.options.rto_initial); + ExchangeMessages(a, z); + + // Handle delayed SACK. + AdvanceTime(a, z, a.options.delayed_ack_max_timeout); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), ElementsAre(51, 53)); +} + +TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithRetransmitLimit) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + ConnectSockets(a, z); + + // Will not be able to send it in full within the congestion window, but will + // need to wait for SACKs to be received for more fragments to be sent. + std::vector<uint8_t> payload(kLargeMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .max_retransmissions = 0, + .lifecycle_id = LifecycleId(1), + }); + + // First DATA + z.socket.ReceivePacket(a.cb.ConsumeSentPacket()); + // Second DATA (lost) + a.cb.ConsumeSentPacket(); + + EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1))); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty()); +} + +TEST(DcSctpSocketTest, LifecycleEventsForExpiredMessageWithLifetimeLimit) { + SocketUnderTest a("A"); + SocketUnderTest z("Z"); + + // Send it before the socket is connected, to prevent it from being sent too + // quickly. The idea is that it should be expired before even attempting to + // send it in full. + std::vector<uint8_t> payload(kSmallMessageSize); + a.socket.Send(DcSctpMessage(StreamID(1), PPID(51), payload), + { + .lifetime = DurationMs(100), + .lifecycle_id = LifecycleId(1), + }); + + AdvanceTime(a, z, DurationMs(200)); + + EXPECT_CALL(a.cb, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(a.cb, OnLifecycleEnd(LifecycleId(1))); + ConnectSockets(a, z); + ExchangeMessages(a, z); + + EXPECT_THAT(GetReceivedMessagePpids(z), IsEmpty()); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc new file mode 100644 index 0000000000..9588b85b59 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.cc @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/heartbeat_handler.h" + +#include <stddef.h> + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +// This is stored (in serialized form) as HeartbeatInfoParameter sent in +// HeartbeatRequestChunk and received back in HeartbeatAckChunk. It should be +// well understood that this data may be modified by the peer, so it can't +// be trusted. +// +// It currently only stores a timestamp, in millisecond precision, to allow for +// RTT measurements. If that would be manipulated by the peer, it would just +// result in incorrect RTT measurements, which isn't an issue. +class HeartbeatInfo { + public: + static constexpr size_t kBufferSize = sizeof(uint64_t); + static_assert(kBufferSize == 8, "Unexpected buffer size"); + + explicit HeartbeatInfo(TimeMs created_at) : created_at_(created_at) {} + + std::vector<uint8_t> Serialize() { + uint32_t high_bits = static_cast<uint32_t>(*created_at_ >> 32); + uint32_t low_bits = static_cast<uint32_t>(*created_at_); + + std::vector<uint8_t> data(kBufferSize); + BoundedByteWriter<kBufferSize> writer(data); + writer.Store32<0>(high_bits); + writer.Store32<4>(low_bits); + return data; + } + + static absl::optional<HeartbeatInfo> Deserialize( + rtc::ArrayView<const uint8_t> data) { + if (data.size() != kBufferSize) { + RTC_LOG(LS_WARNING) << "Invalid heartbeat info: " << data.size() + << " bytes"; + return absl::nullopt; + } + + BoundedByteReader<kBufferSize> reader(data); + uint32_t high_bits = reader.Load32<0>(); + uint32_t low_bits = reader.Load32<4>(); + + uint64_t created_at = static_cast<uint64_t>(high_bits) << 32 | low_bits; + return HeartbeatInfo(TimeMs(created_at)); + } + + TimeMs created_at() const { return created_at_; } + + private: + const TimeMs created_at_; +}; + +HeartbeatHandler::HeartbeatHandler(absl::string_view log_prefix, + const DcSctpOptions& options, + Context* context, + TimerManager* timer_manager) + : log_prefix_(std::string(log_prefix) + "heartbeat: "), + ctx_(context), + timer_manager_(timer_manager), + interval_duration_(options.heartbeat_interval), + interval_duration_should_include_rtt_( + options.heartbeat_interval_include_rtt), + interval_timer_(timer_manager_->CreateTimer( + "heartbeat-interval", + absl::bind_front(&HeartbeatHandler::OnIntervalTimerExpiry, this), + TimerOptions(interval_duration_, TimerBackoffAlgorithm::kFixed))), + timeout_timer_(timer_manager_->CreateTimer( + "heartbeat-timeout", + absl::bind_front(&HeartbeatHandler::OnTimeoutTimerExpiry, this), + TimerOptions(options.rto_initial, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/0))) { + // The interval timer must always be running as long as the association is up. + RestartTimer(); +} + +void HeartbeatHandler::RestartTimer() { + if (interval_duration_ == DurationMs(0)) { + // Heartbeating has been disabled. + return; + } + + if (interval_duration_should_include_rtt_) { + // The RTT should be used, but it's not easy accessible. The RTO will + // suffice. + interval_timer_->set_duration(interval_duration_ + ctx_->current_rto()); + } else { + interval_timer_->set_duration(interval_duration_); + } + + interval_timer_->Start(); +} + +void HeartbeatHandler::HandleHeartbeatRequest(HeartbeatRequestChunk chunk) { + // https://tools.ietf.org/html/rfc4960#section-8.3 + // "The receiver of the HEARTBEAT should immediately respond with a + // HEARTBEAT ACK that contains the Heartbeat Information TLV, together with + // any other received TLVs, copied unchanged from the received HEARTBEAT + // chunk." + ctx_->Send(ctx_->PacketBuilder().Add( + HeartbeatAckChunk(std::move(chunk).extract_parameters()))); +} + +void HeartbeatHandler::HandleHeartbeatAck(HeartbeatAckChunk chunk) { + timeout_timer_->Stop(); + absl::optional<HeartbeatInfoParameter> info_param = chunk.info(); + if (!info_param.has_value()) { + ctx_->callbacks().OnError( + ErrorKind::kParseFailed, + "Failed to parse HEARTBEAT-ACK; No Heartbeat Info parameter"); + return; + } + absl::optional<HeartbeatInfo> info = + HeartbeatInfo::Deserialize(info_param->info()); + if (!info.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse HEARTBEAT-ACK; Failed to " + "deserialized Heartbeat info parameter"); + return; + } + + TimeMs now = ctx_->callbacks().TimeMillis(); + if (info->created_at() > TimeMs(0) && info->created_at() <= now) { + ctx_->ObserveRTT(now - info->created_at()); + } + + // https://tools.ietf.org/html/rfc4960#section-8.1 + // "The counter shall be reset each time ... a HEARTBEAT ACK is received from + // the peer endpoint." + ctx_->ClearTxErrorCounter(); +} + +absl::optional<DurationMs> HeartbeatHandler::OnIntervalTimerExpiry() { + if (ctx_->is_connection_established()) { + HeartbeatInfo info(ctx_->callbacks().TimeMillis()); + timeout_timer_->set_duration(ctx_->current_rto()); + timeout_timer_->Start(); + RTC_DLOG(LS_INFO) << log_prefix_ << "Sending HEARTBEAT with timeout " + << *timeout_timer_->duration(); + + Parameters parameters = Parameters::Builder() + .Add(HeartbeatInfoParameter(info.Serialize())) + .Build(); + + ctx_->Send(ctx_->PacketBuilder().Add( + HeartbeatRequestChunk(std::move(parameters)))); + } else { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "Will not send HEARTBEAT when connection not established"; + } + return absl::nullopt; +} + +absl::optional<DurationMs> HeartbeatHandler::OnTimeoutTimerExpiry() { + // Note that the timeout timer is not restarted. It will be started again when + // the interval timer expires. + RTC_DCHECK(!timeout_timer_->is_running()); + ctx_->IncrementTxErrorCounter("HEARTBEAT timeout"); + return absl::nullopt; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h new file mode 100644 index 0000000000..14c3109534 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_ +#define NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_ + +#include <stdint.h> + +#include <memory> +#include <string> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" + +namespace dcsctp { + +// HeartbeatHandler handles all logic around sending heartbeats and receiving +// the responses, as well as receiving incoming heartbeat requests. +// +// Heartbeats are sent on idle connections to ensure that the connection is +// still healthy and to measure the RTT. If a number of heartbeats time out, +// the connection will eventually be closed. +class HeartbeatHandler { + public: + HeartbeatHandler(absl::string_view log_prefix, + const DcSctpOptions& options, + Context* context, + TimerManager* timer_manager); + + // Called when the heartbeat interval timer should be restarted. This is + // generally done every time data is sent, which makes the timer expire when + // the connection is idle. + void RestartTimer(); + + // Called on received HeartbeatRequestChunk chunks. + void HandleHeartbeatRequest(HeartbeatRequestChunk chunk); + + // Called on received HeartbeatRequestChunk chunks. + void HandleHeartbeatAck(HeartbeatAckChunk chunk); + + private: + absl::optional<DurationMs> OnIntervalTimerExpiry(); + absl::optional<DurationMs> OnTimeoutTimerExpiry(); + + const std::string log_prefix_; + Context* ctx_; + TimerManager* timer_manager_; + // The time for a connection to be idle before a heartbeat is sent. + const DurationMs interval_duration_; + // Adding RTT to the duration will add some jitter, which is good in + // production, but less good in unit tests, which is why it can be disabled. + const bool interval_duration_should_include_rtt_; + const std::unique_ptr<Timer> interval_timer_; + const std::unique_ptr<Timer> timeout_timer_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_HEARTBEAT_HANDLER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc new file mode 100644 index 0000000000..faa0e3da06 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/heartbeat_handler_test.cc @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/heartbeat_handler.h" + +#include <memory> +#include <utility> +#include <vector> + +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/packet/chunk/heartbeat_ack_chunk.h" +#include "net/dcsctp/packet/chunk/heartbeat_request_chunk.h" +#include "net/dcsctp/packet/parameter/heartbeat_info_parameter.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/socket/mock_context.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SizeIs; + +constexpr DurationMs kHeartbeatInterval = DurationMs(30'000); + +DcSctpOptions MakeOptions(DurationMs heartbeat_interval) { + DcSctpOptions options; + options.heartbeat_interval_include_rtt = false; + options.heartbeat_interval = heartbeat_interval; + return options; +} + +class HeartbeatHandlerTestBase : public testing::Test { + protected: + explicit HeartbeatHandlerTestBase(DurationMs heartbeat_interval) + : options_(MakeOptions(heartbeat_interval)), + context_(&callbacks_), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return callbacks_.CreateTimeout(precision); + }), + handler_("log: ", options_, &context_, &timer_manager_) {} + + void AdvanceTime(DurationMs duration) { + callbacks_.AdvanceTime(duration); + for (;;) { + absl::optional<TimeoutID> timeout_id = callbacks_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + timer_manager_.HandleTimeout(*timeout_id); + } + } + + const DcSctpOptions options_; + NiceMock<MockDcSctpSocketCallbacks> callbacks_; + NiceMock<MockContext> context_; + TimerManager timer_manager_; + HeartbeatHandler handler_; +}; + +class HeartbeatHandlerTest : public HeartbeatHandlerTestBase { + protected: + HeartbeatHandlerTest() : HeartbeatHandlerTestBase(kHeartbeatInterval) {} +}; + +class DisabledHeartbeatHandlerTest : public HeartbeatHandlerTestBase { + protected: + DisabledHeartbeatHandlerTest() : HeartbeatHandlerTestBase(DurationMs(0)) {} +}; + +TEST_F(HeartbeatHandlerTest, HasRunningHeartbeatIntervalTimer) { + AdvanceTime(options_.heartbeat_interval); + + // Validate that a heartbeat request was sent. + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk request, + HeartbeatRequestChunk::Parse(packet.descriptors()[0].data)); + + EXPECT_TRUE(request.info().has_value()); +} + +TEST_F(HeartbeatHandlerTest, RepliesToHeartbeatRequests) { + uint8_t info_data[] = {1, 2, 3, 4, 5}; + HeartbeatRequestChunk request( + Parameters::Builder().Add(HeartbeatInfoParameter(info_data)).Build()); + + handler_.HandleHeartbeatRequest(std::move(request)); + + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatAckChunk response, + HeartbeatAckChunk::Parse(packet.descriptors()[0].data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatInfoParameter param, + response.parameters().get<HeartbeatInfoParameter>()); + + EXPECT_THAT(param.info(), ElementsAre(1, 2, 3, 4, 5)); +} + +TEST_F(HeartbeatHandlerTest, SendsHeartbeatRequestsOnIdleChannel) { + AdvanceTime(options_.heartbeat_interval); + + // Grab the request, and make a response. + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk req, + HeartbeatRequestChunk::Parse(packet.descriptors()[0].data)); + + HeartbeatAckChunk ack(std::move(req).extract_parameters()); + + // Respond a while later. This RTT will be measured by the handler + constexpr DurationMs rtt(313); + + EXPECT_CALL(context_, ObserveRTT(rtt)).Times(1); + + callbacks_.AdvanceTime(rtt); + handler_.HandleHeartbeatAck(std::move(ack)); +} + +TEST_F(HeartbeatHandlerTest, DoesntObserveInvalidHeartbeats) { + AdvanceTime(options_.heartbeat_interval); + + // Grab the request, and make a response. + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + HeartbeatRequestChunk req, + HeartbeatRequestChunk::Parse(packet.descriptors()[0].data)); + + HeartbeatAckChunk ack(std::move(req).extract_parameters()); + + EXPECT_CALL(context_, ObserveRTT).Times(0); + + // Go backwards in time - which make the HEARTBEAT-ACK have an invalid + // timestamp in it, as it will be in the future. + callbacks_.AdvanceTime(DurationMs(-100)); + + handler_.HandleHeartbeatAck(std::move(ack)); +} + +TEST_F(HeartbeatHandlerTest, IncreasesErrorIfNotAckedInTime) { + DurationMs rto(105); + EXPECT_CALL(context_, current_rto).WillOnce(Return(rto)); + AdvanceTime(options_.heartbeat_interval); + + // Validate that a request was sent. + EXPECT_THAT(callbacks_.ConsumeSentPacket(), Not(IsEmpty())); + + EXPECT_CALL(context_, IncrementTxErrorCounter).Times(1); + AdvanceTime(rto); +} + +TEST_F(DisabledHeartbeatHandlerTest, IsReallyDisabled) { + AdvanceTime(options_.heartbeat_interval); + + // Validate that a request was NOT sent. + EXPECT_THAT(callbacks_.ConsumeSentPacket(), IsEmpty()); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/mock_context.h b/third_party/libwebrtc/net/dcsctp/socket/mock_context.h new file mode 100644 index 0000000000..88e71d1b35 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/mock_context.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_ +#define NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_ + +#include <cstdint> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockContext : public Context { + public: + static constexpr TSN MyInitialTsn() { return TSN(990); } + static constexpr TSN PeerInitialTsn() { return TSN(10); } + static constexpr VerificationTag PeerVerificationTag() { + return VerificationTag(0x01234567); + } + + explicit MockContext(MockDcSctpSocketCallbacks* callbacks) + : callbacks_(*callbacks) { + ON_CALL(*this, is_connection_established) + .WillByDefault(testing::Return(true)); + ON_CALL(*this, my_initial_tsn) + .WillByDefault(testing::Return(MyInitialTsn())); + ON_CALL(*this, peer_initial_tsn) + .WillByDefault(testing::Return(PeerInitialTsn())); + ON_CALL(*this, callbacks).WillByDefault(testing::ReturnRef(callbacks_)); + ON_CALL(*this, current_rto).WillByDefault(testing::Return(DurationMs(123))); + ON_CALL(*this, Send).WillByDefault([this](SctpPacket::Builder& builder) { + callbacks_.SendPacketWithStatus(builder.Build()); + }); + } + + MOCK_METHOD(bool, is_connection_established, (), (const, override)); + MOCK_METHOD(TSN, my_initial_tsn, (), (const, override)); + MOCK_METHOD(TSN, peer_initial_tsn, (), (const, override)); + MOCK_METHOD(DcSctpSocketCallbacks&, callbacks, (), (const, override)); + + MOCK_METHOD(void, ObserveRTT, (DurationMs rtt_ms), (override)); + MOCK_METHOD(DurationMs, current_rto, (), (const, override)); + MOCK_METHOD(bool, + IncrementTxErrorCounter, + (absl::string_view reason), + (override)); + MOCK_METHOD(void, ClearTxErrorCounter, (), (override)); + MOCK_METHOD(bool, HasTooManyTxErrors, (), (const, override)); + SctpPacket::Builder PacketBuilder() const override { + return SctpPacket::Builder(PeerVerificationTag(), options_); + } + MOCK_METHOD(void, Send, (SctpPacket::Builder & builder), (override)); + + DcSctpOptions options_; + MockDcSctpSocketCallbacks& callbacks_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_MOCK_CONTEXT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h b/third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h new file mode 100644 index 0000000000..8b2a772fa3 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/mock_dcsctp_socket_callbacks.h @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_ +#define NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_ + +#include <cstdint> +#include <deque> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/timeout.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "rtc_base/logging.h" +#include "rtc_base/random.h" +#include "test/gmock.h" + +namespace dcsctp { + +namespace internal { +// It can be argued if a mocked random number generator should be deterministic +// or if it should be have as a "real" random number generator. In this +// implementation, each instantiation of `MockDcSctpSocketCallbacks` will have +// their `GetRandomInt` return different sequences, but each instantiation will +// always generate the same sequence of random numbers. This to make it easier +// to compare logs from tests, but still to let e.g. two different sockets (used +// in the same test) get different random numbers, so that they don't start e.g. +// on the same sequence number. While that isn't an issue in the protocol, it +// just makes debugging harder as the two sockets would look exactly the same. +// +// In a real implementation of `DcSctpSocketCallbacks` the random number +// generator backing `GetRandomInt` should be seeded externally and correctly. +inline int GetUniqueSeed() { + static int seed = 0; + return ++seed; +} +} // namespace internal + +class MockDcSctpSocketCallbacks : public DcSctpSocketCallbacks { + public: + explicit MockDcSctpSocketCallbacks(absl::string_view name = "") + : log_prefix_(name.empty() ? "" : std::string(name) + ": "), + random_(internal::GetUniqueSeed()), + timeout_manager_([this]() { return now_; }) { + ON_CALL(*this, SendPacketWithStatus) + .WillByDefault([this](rtc::ArrayView<const uint8_t> data) { + sent_packets_.emplace_back( + std::vector<uint8_t>(data.begin(), data.end())); + return SendPacketStatus::kSuccess; + }); + ON_CALL(*this, OnMessageReceived) + .WillByDefault([this](DcSctpMessage message) { + received_messages_.emplace_back(std::move(message)); + }); + + ON_CALL(*this, OnError) + .WillByDefault([this](ErrorKind error, absl::string_view message) { + RTC_LOG(LS_WARNING) + << log_prefix_ << "Socket error: " << ToString(error) << "; " + << message; + }); + ON_CALL(*this, OnAborted) + .WillByDefault([this](ErrorKind error, absl::string_view message) { + RTC_LOG(LS_WARNING) + << log_prefix_ << "Socket abort: " << ToString(error) << "; " + << message; + }); + ON_CALL(*this, TimeMillis).WillByDefault([this]() { return now_; }); + } + + MOCK_METHOD(SendPacketStatus, + SendPacketWithStatus, + (rtc::ArrayView<const uint8_t> data), + (override)); + + std::unique_ptr<Timeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) override { + // The fake timeout manager does not implement |precision|. + return timeout_manager_.CreateTimeout(); + } + + MOCK_METHOD(TimeMs, TimeMillis, (), (override)); + uint32_t GetRandomInt(uint32_t low, uint32_t high) override { + return random_.Rand(low, high); + } + + MOCK_METHOD(void, OnMessageReceived, (DcSctpMessage message), (override)); + MOCK_METHOD(void, + OnError, + (ErrorKind error, absl::string_view message), + (override)); + MOCK_METHOD(void, + OnAborted, + (ErrorKind error, absl::string_view message), + (override)); + MOCK_METHOD(void, OnConnected, (), (override)); + MOCK_METHOD(void, OnClosed, (), (override)); + MOCK_METHOD(void, OnConnectionRestarted, (), (override)); + MOCK_METHOD(void, + OnStreamsResetFailed, + (rtc::ArrayView<const StreamID> outgoing_streams, + absl::string_view reason), + (override)); + MOCK_METHOD(void, + OnStreamsResetPerformed, + (rtc::ArrayView<const StreamID> outgoing_streams), + (override)); + MOCK_METHOD(void, + OnIncomingStreamsReset, + (rtc::ArrayView<const StreamID> incoming_streams), + (override)); + MOCK_METHOD(void, OnBufferedAmountLow, (StreamID stream_id), (override)); + MOCK_METHOD(void, OnTotalBufferedAmountLow, (), (override)); + MOCK_METHOD(void, + OnLifecycleMessageExpired, + (LifecycleId lifecycle_id, bool maybe_delivered), + (override)); + MOCK_METHOD(void, + OnLifecycleMessageFullySent, + (LifecycleId lifecycle_id), + (override)); + MOCK_METHOD(void, + OnLifecycleMessageDelivered, + (LifecycleId lifecycle_id), + (override)); + MOCK_METHOD(void, OnLifecycleEnd, (LifecycleId lifecycle_id), (override)); + + bool HasPacket() const { return !sent_packets_.empty(); } + + std::vector<uint8_t> ConsumeSentPacket() { + if (sent_packets_.empty()) { + return {}; + } + std::vector<uint8_t> ret = std::move(sent_packets_.front()); + sent_packets_.pop_front(); + return ret; + } + absl::optional<DcSctpMessage> ConsumeReceivedMessage() { + if (received_messages_.empty()) { + return absl::nullopt; + } + DcSctpMessage ret = std::move(received_messages_.front()); + received_messages_.pop_front(); + return ret; + } + + void AdvanceTime(DurationMs duration_ms) { now_ = now_ + duration_ms; } + void SetTime(TimeMs now) { now_ = now; } + + absl::optional<TimeoutID> GetNextExpiredTimeout() { + return timeout_manager_.GetNextExpiredTimeout(); + } + + private: + const std::string log_prefix_; + TimeMs now_ = TimeMs(0); + webrtc::Random random_; + FakeTimeoutManager timeout_manager_; + std::deque<std::vector<uint8_t>> sent_packets_; + std::deque<DcSctpMessage> received_messages_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_MOCK_DCSCTP_SOCKET_CALLBACKS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc new file mode 100644 index 0000000000..85392e205d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.cc @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/packet_sender.h" + +#include <utility> +#include <vector> + +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +PacketSender::PacketSender(DcSctpSocketCallbacks& callbacks, + std::function<void(rtc::ArrayView<const uint8_t>, + SendPacketStatus)> on_sent_packet) + : callbacks_(callbacks), on_sent_packet_(std::move(on_sent_packet)) {} + +bool PacketSender::Send(SctpPacket::Builder& builder) { + if (builder.empty()) { + return false; + } + + std::vector<uint8_t> payload = builder.Build(); + + SendPacketStatus status = callbacks_.SendPacketWithStatus(payload); + on_sent_packet_(payload, status); + switch (status) { + case SendPacketStatus::kSuccess: { + return true; + } + case SendPacketStatus::kTemporaryFailure: { + // TODO(boivie): Queue this packet to be retried to be sent later. + return false; + } + + case SendPacketStatus::kError: { + // Nothing that can be done. + return false; + } + } +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/packet_sender.h b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.h new file mode 100644 index 0000000000..7af4d3c47b --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/packet_sender.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_PACKET_SENDER_H_ +#define NET_DCSCTP_SOCKET_PACKET_SENDER_H_ + +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" + +namespace dcsctp { + +// The PacketSender sends packets to the network using the provided callback +// interface. When an attempt to send a packet is made, the `on_sent_packet` +// callback will be triggered. +class PacketSender { + public: + PacketSender(DcSctpSocketCallbacks& callbacks, + std::function<void(rtc::ArrayView<const uint8_t>, + SendPacketStatus)> on_sent_packet); + + // Sends the packet, and returns true if it was sent successfully. + bool Send(SctpPacket::Builder& builder); + + private: + DcSctpSocketCallbacks& callbacks_; + + // Callback that will be triggered for every send attempt, indicating the + // status of the operation. + std::function<void(rtc::ArrayView<const uint8_t>, SendPacketStatus)> + on_sent_packet_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_PACKET_SENDER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc b/third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc new file mode 100644 index 0000000000..079dc36a41 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/packet_sender_test.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/packet_sender.h" + +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::_; + +constexpr VerificationTag kVerificationTag(123); + +class PacketSenderTest : public testing::Test { + protected: + PacketSenderTest() : sender_(callbacks_, on_send_fn_.AsStdFunction()) {} + + SctpPacket::Builder PacketBuilder() const { + return SctpPacket::Builder(kVerificationTag, options_); + } + + DcSctpOptions options_; + testing::NiceMock<MockDcSctpSocketCallbacks> callbacks_; + testing::MockFunction<void(rtc::ArrayView<const uint8_t>, SendPacketStatus)> + on_send_fn_; + PacketSender sender_; +}; + +TEST_F(PacketSenderTest, SendPacketCallsCallback) { + EXPECT_CALL(on_send_fn_, Call(_, SendPacketStatus::kSuccess)); + EXPECT_TRUE(sender_.Send(PacketBuilder().Add(CookieAckChunk()))); + + EXPECT_CALL(callbacks_, SendPacketWithStatus) + .WillOnce(testing::Return(SendPacketStatus::kError)); + EXPECT_CALL(on_send_fn_, Call(_, SendPacketStatus::kError)); + EXPECT_FALSE(sender_.Send(PacketBuilder().Add(CookieAckChunk()))); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc new file mode 100644 index 0000000000..7d04cbb0d7 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.cc @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/state_cookie.h" + +#include <cstdint> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/bounded_byte_reader.h" +#include "net/dcsctp/packet/bounded_byte_writer.h" +#include "net/dcsctp/socket/capabilities.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +// Magic values, which the state cookie is prefixed with. +constexpr uint32_t kMagic1 = 1684230979; +constexpr uint32_t kMagic2 = 1414541360; +constexpr size_t StateCookie::kCookieSize; + +std::vector<uint8_t> StateCookie::Serialize() { + std::vector<uint8_t> cookie; + cookie.resize(kCookieSize); + BoundedByteWriter<kCookieSize> buffer(cookie); + buffer.Store32<0>(kMagic1); + buffer.Store32<4>(kMagic2); + buffer.Store32<8>(*initiate_tag_); + buffer.Store32<12>(*initial_tsn_); + buffer.Store32<16>(a_rwnd_); + buffer.Store32<20>(static_cast<uint32_t>(*tie_tag_ >> 32)); + buffer.Store32<24>(static_cast<uint32_t>(*tie_tag_)); + buffer.Store8<28>(capabilities_.partial_reliability); + buffer.Store8<29>(capabilities_.message_interleaving); + buffer.Store8<30>(capabilities_.reconfig); + return cookie; +} + +absl::optional<StateCookie> StateCookie::Deserialize( + rtc::ArrayView<const uint8_t> cookie) { + if (cookie.size() != kCookieSize) { + RTC_DLOG(LS_WARNING) << "Invalid state cookie: " << cookie.size() + << " bytes"; + return absl::nullopt; + } + + BoundedByteReader<kCookieSize> buffer(cookie); + uint32_t magic1 = buffer.Load32<0>(); + uint32_t magic2 = buffer.Load32<4>(); + if (magic1 != kMagic1 || magic2 != kMagic2) { + RTC_DLOG(LS_WARNING) << "Invalid state cookie; wrong magic"; + return absl::nullopt; + } + + VerificationTag verification_tag(buffer.Load32<8>()); + TSN initial_tsn(buffer.Load32<12>()); + uint32_t a_rwnd = buffer.Load32<16>(); + uint32_t tie_tag_upper = buffer.Load32<20>(); + uint32_t tie_tag_lower = buffer.Load32<24>(); + TieTag tie_tag(static_cast<uint64_t>(tie_tag_upper) << 32 | + static_cast<uint64_t>(tie_tag_lower)); + Capabilities capabilities; + capabilities.partial_reliability = buffer.Load8<28>() != 0; + capabilities.message_interleaving = buffer.Load8<29>() != 0; + capabilities.reconfig = buffer.Load8<30>() != 0; + + return StateCookie(verification_tag, initial_tsn, a_rwnd, tie_tag, + capabilities); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h new file mode 100644 index 0000000000..df4b801397 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_STATE_COOKIE_H_ +#define NET_DCSCTP_SOCKET_STATE_COOKIE_H_ + +#include <cstdint> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/socket/capabilities.h" + +namespace dcsctp { + +// This is serialized as a state cookie and put in INIT_ACK. The client then +// responds with this in COOKIE_ECHO. +// +// NOTE: Expect that the client will modify it to try to exploit the library. +// Do not trust anything in it; no pointers or anything like that. +class StateCookie { + public: + static constexpr size_t kCookieSize = 31; + + StateCookie(VerificationTag initiate_tag, + TSN initial_tsn, + uint32_t a_rwnd, + TieTag tie_tag, + Capabilities capabilities) + : initiate_tag_(initiate_tag), + initial_tsn_(initial_tsn), + a_rwnd_(a_rwnd), + tie_tag_(tie_tag), + capabilities_(capabilities) {} + + // Returns a serialized version of this cookie. + std::vector<uint8_t> Serialize(); + + // Deserializes the cookie, and returns absl::nullopt if that failed. + static absl::optional<StateCookie> Deserialize( + rtc::ArrayView<const uint8_t> cookie); + + VerificationTag initiate_tag() const { return initiate_tag_; } + TSN initial_tsn() const { return initial_tsn_; } + uint32_t a_rwnd() const { return a_rwnd_; } + TieTag tie_tag() const { return tie_tag_; } + const Capabilities& capabilities() const { return capabilities_; } + + private: + const VerificationTag initiate_tag_; + const TSN initial_tsn_; + const uint32_t a_rwnd_; + const TieTag tie_tag_; + const Capabilities capabilities_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_STATE_COOKIE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc new file mode 100644 index 0000000000..870620ba14 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/state_cookie_test.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/state_cookie.h" + +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; + +TEST(StateCookieTest, SerializeAndDeserialize) { + Capabilities capabilities = {/*partial_reliability=*/true, + /*message_interleaving=*/false, + /*reconfig=*/true}; + StateCookie cookie(VerificationTag(123), TSN(456), + /*a_rwnd=*/789, TieTag(101112), capabilities); + std::vector<uint8_t> serialized = cookie.Serialize(); + EXPECT_THAT(serialized, SizeIs(StateCookie::kCookieSize)); + ASSERT_HAS_VALUE_AND_ASSIGN(StateCookie deserialized, + StateCookie::Deserialize(serialized)); + EXPECT_EQ(deserialized.initiate_tag(), VerificationTag(123)); + EXPECT_EQ(deserialized.initial_tsn(), TSN(456)); + EXPECT_EQ(deserialized.a_rwnd(), 789u); + EXPECT_EQ(deserialized.tie_tag(), TieTag(101112)); + EXPECT_TRUE(deserialized.capabilities().partial_reliability); + EXPECT_FALSE(deserialized.capabilities().message_interleaving); + EXPECT_TRUE(deserialized.capabilities().reconfig); +} + +TEST(StateCookieTest, ValidateMagicValue) { + Capabilities capabilities = {/*partial_reliability=*/true, + /*message_interleaving=*/false, + /*reconfig=*/true}; + StateCookie cookie(VerificationTag(123), TSN(456), + /*a_rwnd=*/789, TieTag(101112), capabilities); + std::vector<uint8_t> serialized = cookie.Serialize(); + ASSERT_THAT(serialized, SizeIs(StateCookie::kCookieSize)); + + absl::string_view magic(reinterpret_cast<const char*>(serialized.data()), 8); + EXPECT_EQ(magic, "dcSCTP00"); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc new file mode 100644 index 0000000000..9d86953f44 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.cc @@ -0,0 +1,349 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/stream_reset_handler.h" + +#include <cstdint> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/add_incoming_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/add_outgoing_streams_request_parameter.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/packet/parameter/ssn_tsn_reset_request_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/packet/tlv_trait.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +namespace { +using ResponseResult = ReconfigurationResponseParameter::Result; + +bool DescriptorsAre(const std::vector<ParameterDescriptor>& c, + uint16_t e1, + uint16_t e2) { + return (c[0].type == e1 && c[1].type == e2) || + (c[0].type == e2 && c[1].type == e1); +} + +} // namespace + +bool StreamResetHandler::Validate(const ReConfigChunk& chunk) { + const Parameters& parameters = chunk.parameters(); + + // https://tools.ietf.org/html/rfc6525#section-3.1 + // "Note that each RE-CONFIG chunk holds at least one parameter + // and at most two parameters. Only the following combinations are allowed:" + std::vector<ParameterDescriptor> descriptors = parameters.descriptors(); + if (descriptors.size() == 1) { + if ((descriptors[0].type == OutgoingSSNResetRequestParameter::kType) || + (descriptors[0].type == IncomingSSNResetRequestParameter::kType) || + (descriptors[0].type == SSNTSNResetRequestParameter::kType) || + (descriptors[0].type == AddOutgoingStreamsRequestParameter::kType) || + (descriptors[0].type == AddIncomingStreamsRequestParameter::kType) || + (descriptors[0].type == ReconfigurationResponseParameter::kType)) { + return true; + } + } else if (descriptors.size() == 2) { + if (DescriptorsAre(descriptors, OutgoingSSNResetRequestParameter::kType, + IncomingSSNResetRequestParameter::kType) || + DescriptorsAre(descriptors, AddOutgoingStreamsRequestParameter::kType, + AddIncomingStreamsRequestParameter::kType) || + DescriptorsAre(descriptors, ReconfigurationResponseParameter::kType, + OutgoingSSNResetRequestParameter::kType) || + DescriptorsAre(descriptors, ReconfigurationResponseParameter::kType, + ReconfigurationResponseParameter::kType)) { + return true; + } + } + + RTC_LOG(LS_WARNING) << "Invalid set of RE-CONFIG parameters"; + return false; +} + +absl::optional<std::vector<ReconfigurationResponseParameter>> +StreamResetHandler::Process(const ReConfigChunk& chunk) { + if (!Validate(chunk)) { + return absl::nullopt; + } + + std::vector<ReconfigurationResponseParameter> responses; + + for (const ParameterDescriptor& desc : chunk.parameters().descriptors()) { + switch (desc.type) { + case OutgoingSSNResetRequestParameter::kType: + HandleResetOutgoing(desc, responses); + break; + + case IncomingSSNResetRequestParameter::kType: + HandleResetIncoming(desc, responses); + break; + + case ReconfigurationResponseParameter::kType: + HandleResponse(desc); + break; + } + } + + return responses; +} + +void StreamResetHandler::HandleReConfig(ReConfigChunk chunk) { + absl::optional<std::vector<ReconfigurationResponseParameter>> responses = + Process(chunk); + + if (!responses.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse RE-CONFIG command"); + return; + } + + if (!responses->empty()) { + SctpPacket::Builder b = ctx_->PacketBuilder(); + Parameters::Builder params_builder; + for (const auto& response : *responses) { + params_builder.Add(response); + } + b.Add(ReConfigChunk(params_builder.Build())); + ctx_->Send(b); + } +} + +bool StreamResetHandler::ValidateReqSeqNbr( + ReconfigRequestSN req_seq_nbr, + std::vector<ReconfigurationResponseParameter>& responses) { + if (req_seq_nbr == last_processed_req_seq_nbr_) { + // This has already been performed previously. + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "req=" << *req_seq_nbr + << " already processed"; + responses.push_back(ReconfigurationResponseParameter( + req_seq_nbr, ResponseResult::kSuccessNothingToDo)); + return false; + } + + if (req_seq_nbr != ReconfigRequestSN(*last_processed_req_seq_nbr_ + 1)) { + // Too old, too new, from wrong association etc. + // This is expected to happen when handing over a RTCPeerConnection from one + // server to another. The client will notice this and may decide to close + // old data channels, which may be sent to the wrong (or both) servers + // during a handover. + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "req=" << *req_seq_nbr + << " bad seq_nbr"; + responses.push_back(ReconfigurationResponseParameter( + req_seq_nbr, ResponseResult::kErrorBadSequenceNumber)); + return false; + } + + return true; +} + +void StreamResetHandler::HandleResetOutgoing( + const ParameterDescriptor& descriptor, + std::vector<ReconfigurationResponseParameter>& responses) { + absl::optional<OutgoingSSNResetRequestParameter> req = + OutgoingSSNResetRequestParameter::Parse(descriptor.data); + if (!req.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse Outgoing Reset command"); + return; + } + + if (ValidateReqSeqNbr(req->request_sequence_number(), responses)) { + ResponseResult result; + + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "Reset outgoing streams with req_seq_nbr=" + << *req->request_sequence_number(); + + last_processed_req_seq_nbr_ = req->request_sequence_number(); + result = reassembly_queue_->ResetStreams( + *req, data_tracker_->last_cumulative_acked_tsn()); + if (result == ResponseResult::kSuccessPerformed) { + ctx_->callbacks().OnIncomingStreamsReset(req->stream_ids()); + } + responses.push_back(ReconfigurationResponseParameter( + req->request_sequence_number(), result)); + } +} + +void StreamResetHandler::HandleResetIncoming( + const ParameterDescriptor& descriptor, + std::vector<ReconfigurationResponseParameter>& responses) { + absl::optional<IncomingSSNResetRequestParameter> req = + IncomingSSNResetRequestParameter::Parse(descriptor.data); + if (!req.has_value()) { + ctx_->callbacks().OnError(ErrorKind::kParseFailed, + "Failed to parse Incoming Reset command"); + return; + } + if (ValidateReqSeqNbr(req->request_sequence_number(), responses)) { + responses.push_back(ReconfigurationResponseParameter( + req->request_sequence_number(), ResponseResult::kSuccessNothingToDo)); + last_processed_req_seq_nbr_ = req->request_sequence_number(); + } +} + +void StreamResetHandler::HandleResponse(const ParameterDescriptor& descriptor) { + absl::optional<ReconfigurationResponseParameter> resp = + ReconfigurationResponseParameter::Parse(descriptor.data); + if (!resp.has_value()) { + ctx_->callbacks().OnError( + ErrorKind::kParseFailed, + "Failed to parse Reconfiguration Response command"); + return; + } + + if (current_request_.has_value() && current_request_->has_been_sent() && + resp->response_sequence_number() == current_request_->req_seq_nbr()) { + reconfig_timer_->Stop(); + + switch (resp->result()) { + case ResponseResult::kSuccessNothingToDo: + case ResponseResult::kSuccessPerformed: + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Reset stream success, req_seq_nbr=" + << *current_request_->req_seq_nbr() << ", streams=" + << StrJoin(current_request_->streams(), ",", + [](rtc::StringBuilder& sb, StreamID stream_id) { + sb << *stream_id; + }); + ctx_->callbacks().OnStreamsResetPerformed(current_request_->streams()); + current_request_ = absl::nullopt; + retransmission_queue_->CommitResetStreams(); + break; + case ResponseResult::kInProgress: + RTC_DLOG(LS_VERBOSE) + << log_prefix_ << "Reset stream still pending, req_seq_nbr=" + << *current_request_->req_seq_nbr() << ", streams=" + << StrJoin(current_request_->streams(), ",", + [](rtc::StringBuilder& sb, StreamID stream_id) { + sb << *stream_id; + }); + // Force this request to be sent again, but with new req_seq_nbr. + current_request_->PrepareRetransmission(); + reconfig_timer_->set_duration(ctx_->current_rto()); + reconfig_timer_->Start(); + break; + case ResponseResult::kErrorRequestAlreadyInProgress: + case ResponseResult::kDenied: + case ResponseResult::kErrorWrongSSN: + case ResponseResult::kErrorBadSequenceNumber: + RTC_DLOG(LS_WARNING) + << log_prefix_ << "Reset stream error=" << ToString(resp->result()) + << ", req_seq_nbr=" << *current_request_->req_seq_nbr() + << ", streams=" + << StrJoin(current_request_->streams(), ",", + [](rtc::StringBuilder& sb, StreamID stream_id) { + sb << *stream_id; + }); + ctx_->callbacks().OnStreamsResetFailed(current_request_->streams(), + ToString(resp->result())); + current_request_ = absl::nullopt; + retransmission_queue_->RollbackResetStreams(); + break; + } + } +} + +absl::optional<ReConfigChunk> StreamResetHandler::MakeStreamResetRequest() { + // Only send stream resets if there are streams to reset, and no current + // ongoing request (there can only be one at a time), and if the stream + // can be reset. + if (current_request_.has_value() || + !retransmission_queue_->HasStreamsReadyToBeReset()) { + return absl::nullopt; + } + + current_request_.emplace(TSN(*retransmission_queue_->next_tsn() - 1), + retransmission_queue_->GetStreamsReadyToBeReset()); + reconfig_timer_->set_duration(ctx_->current_rto()); + reconfig_timer_->Start(); + return MakeReconfigChunk(); +} + +ReConfigChunk StreamResetHandler::MakeReconfigChunk() { + // The req_seq_nbr will be empty if the request has never been sent before, + // or if it was sent, but the sender responded "in progress", and then the + // req_seq_nbr will be cleared to re-send with a new number. But if the + // request is re-sent due to timeout (reconfig-timer expiring), the same + // req_seq_nbr will be used. + RTC_DCHECK(current_request_.has_value()); + + if (!current_request_->has_been_sent()) { + current_request_->PrepareToSend(next_outgoing_req_seq_nbr_); + next_outgoing_req_seq_nbr_ = + ReconfigRequestSN(*next_outgoing_req_seq_nbr_ + 1); + } + + Parameters::Builder params_builder = + Parameters::Builder().Add(OutgoingSSNResetRequestParameter( + current_request_->req_seq_nbr(), current_request_->req_seq_nbr(), + current_request_->sender_last_assigned_tsn(), + current_request_->streams())); + + return ReConfigChunk(params_builder.Build()); +} + +void StreamResetHandler::ResetStreams( + rtc::ArrayView<const StreamID> outgoing_streams) { + for (StreamID stream_id : outgoing_streams) { + retransmission_queue_->PrepareResetStream(stream_id); + } +} + +absl::optional<DurationMs> StreamResetHandler::OnReconfigTimerExpiry() { + if (current_request_->has_been_sent()) { + // There is an outstanding request, which timed out while waiting for a + // response. + if (!ctx_->IncrementTxErrorCounter("RECONFIG timeout")) { + // Timed out. The connection will close after processing the timers. + return absl::nullopt; + } + } else { + // There is no outstanding request, but there is a prepared one. This means + // that the receiver has previously responded "in progress", which resulted + // in retrying the request (but with a new req_seq_nbr) after a while. + } + + ctx_->Send(ctx_->PacketBuilder().Add(MakeReconfigChunk())); + return ctx_->current_rto(); +} + +HandoverReadinessStatus StreamResetHandler::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (retransmission_queue_->HasStreamsReadyToBeReset()) { + status.Add(HandoverUnreadinessReason::kPendingStreamReset); + } + if (current_request_.has_value()) { + status.Add(HandoverUnreadinessReason::kPendingStreamResetRequest); + } + return status; +} + +void StreamResetHandler::AddHandoverState(DcSctpSocketHandoverState& state) { + state.rx.last_completed_reset_req_sn = last_processed_req_seq_nbr_.value(); + state.tx.next_reset_req_sn = next_outgoing_req_seq_nbr_.value(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h new file mode 100644 index 0000000000..6e49665538 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler.h @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_ +#define NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "rtc_base/containers/flat_set.h" + +namespace dcsctp { + +// StreamResetHandler handles sending outgoing stream reset requests (to close +// an SCTP stream, which translates to closing a data channel). +// +// It also handles incoming "outgoing stream reset requests", when the peer +// wants to close its data channel. +// +// Resetting streams is an asynchronous operation where the client will request +// a request a stream to be reset, but then it might not be performed exactly at +// this point. First, the sender might need to discard all messages that have +// been enqueued for this stream, or it may select to wait until all have been +// sent. At least, it must wait for the currently sending fragmented message to +// be fully sent, because a stream can't be reset while having received half a +// message. In the stream reset request, the "sender's last assigned TSN" is +// provided, which is simply the TSN for which the receiver should've received +// all messages before this value, before the stream can be reset. Since +// fragments can get lost or sent out-of-order, the receiver of a request may +// not have received all the data just yet, and then it will respond to the +// sender: "In progress". In other words, try again. The sender will then need +// to start a timer and try the very same request again (but with a new sequence +// number) until the receiver successfully performs the operation. +// +// All this can take some time, and may be driven by timers, so the client will +// ultimately be notified using callbacks. +// +// In this implementation, when a stream is reset, the queued but not-yet-sent +// messages will be discarded, but that may change in the future. RFC8831 allows +// both behaviors. +class StreamResetHandler { + public: + StreamResetHandler(absl::string_view log_prefix, + Context* context, + TimerManager* timer_manager, + DataTracker* data_tracker, + ReassemblyQueue* reassembly_queue, + RetransmissionQueue* retransmission_queue, + const DcSctpSocketHandoverState* handover_state = nullptr) + : log_prefix_(std::string(log_prefix) + "reset: "), + ctx_(context), + data_tracker_(data_tracker), + reassembly_queue_(reassembly_queue), + retransmission_queue_(retransmission_queue), + reconfig_timer_(timer_manager->CreateTimer( + "re-config", + absl::bind_front(&StreamResetHandler::OnReconfigTimerExpiry, this), + TimerOptions(DurationMs(0)))), + next_outgoing_req_seq_nbr_( + handover_state + ? ReconfigRequestSN(handover_state->tx.next_reset_req_sn) + : ReconfigRequestSN(*ctx_->my_initial_tsn())), + last_processed_req_seq_nbr_( + handover_state ? ReconfigRequestSN( + handover_state->rx.last_completed_reset_req_sn) + : ReconfigRequestSN(*ctx_->peer_initial_tsn() - 1)) { + } + + // Initiates reset of the provided streams. While there can only be one + // ongoing stream reset request at any time, this method can be called at any + // time and also multiple times. It will enqueue requests that can't be + // directly fulfilled, and will asynchronously process them when any ongoing + // request has completed. + void ResetStreams(rtc::ArrayView<const StreamID> outgoing_streams); + + // Creates a Reset Streams request that must be sent if returned. Will start + // the reconfig timer. Will return absl::nullopt if there is no need to + // create a request (no streams to reset) or if there already is an ongoing + // stream reset request that hasn't completed yet. + absl::optional<ReConfigChunk> MakeStreamResetRequest(); + + // Called when handling and incoming RE-CONFIG chunk. + void HandleReConfig(ReConfigChunk chunk); + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + + private: + // Represents a stream request operation. There can only be one ongoing at + // any time, and a sent request may either succeed, fail or result in the + // receiver signaling that it can't process it right now, and then it will be + // retried. + class CurrentRequest { + public: + CurrentRequest(TSN sender_last_assigned_tsn, std::vector<StreamID> streams) + : req_seq_nbr_(absl::nullopt), + sender_last_assigned_tsn_(sender_last_assigned_tsn), + streams_(std::move(streams)) {} + + // Returns the current request sequence number, if this request has been + // sent (check `has_been_sent` first). Will return 0 if the request is just + // prepared (or scheduled for retransmission) but not yet sent. + ReconfigRequestSN req_seq_nbr() const { + return req_seq_nbr_.value_or(ReconfigRequestSN(0)); + } + + // The sender's last assigned TSN, from the retransmission queue. The + // receiver uses this to know when all data up to this TSN has been + // received, to know when to safely reset the stream. + TSN sender_last_assigned_tsn() const { return sender_last_assigned_tsn_; } + + // The streams that are to be reset. + const std::vector<StreamID>& streams() const { return streams_; } + + // If this request has been sent yet. If not, then it's either because it + // has only been prepared and not yet sent, or because the received couldn't + // apply the request, and then the exact same request will be retried, but + // with a new sequence number. + bool has_been_sent() const { return req_seq_nbr_.has_value(); } + + // If the receiver can't apply the request yet (and answered "In Progress"), + // this will be called to prepare the request to be retransmitted at a later + // time. + void PrepareRetransmission() { req_seq_nbr_ = absl::nullopt; } + + // If the request hasn't been sent yet, this assigns it a request number. + void PrepareToSend(ReconfigRequestSN new_req_seq_nbr) { + req_seq_nbr_ = new_req_seq_nbr; + } + + private: + // If this is set, this request has been sent. If it's not set, the request + // has been prepared, but has not yet been sent. This is typically used when + // the peer responded "in progress" and the same request (but a different + // request number) must be sent again. + absl::optional<ReconfigRequestSN> req_seq_nbr_; + // The sender's (that's us) last assigned TSN, from the retransmission + // queue. + TSN sender_last_assigned_tsn_; + // The streams that are to be reset in this request. + const std::vector<StreamID> streams_; + }; + + // Called to validate an incoming RE-CONFIG chunk. + bool Validate(const ReConfigChunk& chunk); + + // Processes a stream stream reconfiguration chunk and may either return + // absl::nullopt (on protocol errors), or a list of responses - either 0, 1 + // or 2. + absl::optional<std::vector<ReconfigurationResponseParameter>> Process( + const ReConfigChunk& chunk); + + // Creates the actual RE-CONFIG chunk. A request (which set `current_request`) + // must have been created prior. + ReConfigChunk MakeReconfigChunk(); + + // Called to validate the `req_seq_nbr`, that it's the next in sequence. If it + // fails to validate, and returns false, it will also add a response to + // `responses`. + bool ValidateReqSeqNbr( + ReconfigRequestSN req_seq_nbr, + std::vector<ReconfigurationResponseParameter>& responses); + + // Called when this socket receives an outgoing stream reset request. It might + // either be performed straight away, or have to be deferred, and the result + // of that will be put in `responses`. + void HandleResetOutgoing( + const ParameterDescriptor& descriptor, + std::vector<ReconfigurationResponseParameter>& responses); + + // Called when this socket receives an incoming stream reset request. This + // isn't really supported, but a successful response is put in `responses`. + void HandleResetIncoming( + const ParameterDescriptor& descriptor, + std::vector<ReconfigurationResponseParameter>& responses); + + // Called when receiving a response to an outgoing stream reset request. It + // will either commit the stream resetting, if the operation was successful, + // or will schedule a retry if it was deferred. And if it failed, the + // operation will be rolled back. + void HandleResponse(const ParameterDescriptor& descriptor); + + // Expiration handler for the Reconfig timer. + absl::optional<DurationMs> OnReconfigTimerExpiry(); + + const std::string log_prefix_; + Context* ctx_; + DataTracker* data_tracker_; + ReassemblyQueue* reassembly_queue_; + RetransmissionQueue* retransmission_queue_; + const std::unique_ptr<Timer> reconfig_timer_; + + // The next sequence number for outgoing stream requests. + ReconfigRequestSN next_outgoing_req_seq_nbr_; + + // The current stream request operation. + absl::optional<CurrentRequest> current_request_; + + // For incoming requests - last processed request sequence number. + ReconfigRequestSN last_processed_req_seq_nbr_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_STREAM_RESET_HANDLER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc new file mode 100644 index 0000000000..493b4c4bf7 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/stream_reset_handler_test.cc @@ -0,0 +1,802 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/stream_reset_handler.h" + +#include <array> +#include <cstdint> +#include <memory> +#include <type_traits> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/parameter/incoming_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h" +#include "net/dcsctp/packet/parameter/parameter.h" +#include "net/dcsctp/packet/parameter/reconfiguration_response_parameter.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/mock_context.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/data_generator.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/mock_send_queue.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::IsEmpty; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; +using ResponseResult = ReconfigurationResponseParameter::Result; + +constexpr TSN kMyInitialTsn = MockContext::MyInitialTsn(); +constexpr ReconfigRequestSN kMyInitialReqSn = ReconfigRequestSN(*kMyInitialTsn); +constexpr TSN kPeerInitialTsn = MockContext::PeerInitialTsn(); +constexpr ReconfigRequestSN kPeerInitialReqSn = + ReconfigRequestSN(*kPeerInitialTsn); +constexpr uint32_t kArwnd = 131072; +constexpr DurationMs kRto = DurationMs(250); + +constexpr std::array<uint8_t, 4> kShortPayload = {1, 2, 3, 4}; + +MATCHER_P3(SctpMessageIs, stream_id, ppid, expected_payload, "") { + if (arg.stream_id() != stream_id) { + *result_listener << "the stream_id is " << *arg.stream_id(); + return false; + } + + if (arg.ppid() != ppid) { + *result_listener << "the ppid is " << *arg.ppid(); + return false; + } + + if (std::vector<uint8_t>(arg.payload().begin(), arg.payload().end()) != + std::vector<uint8_t>(expected_payload.begin(), expected_payload.end())) { + *result_listener << "the payload is wrong"; + return false; + } + return true; +} + +TSN AddTo(TSN tsn, int delta) { + return TSN(*tsn + delta); +} + +ReconfigRequestSN AddTo(ReconfigRequestSN req_sn, int delta) { + return ReconfigRequestSN(*req_sn + delta); +} + +class StreamResetHandlerTest : public testing::Test { + protected: + StreamResetHandlerTest() + : ctx_(&callbacks_), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return callbacks_.CreateTimeout(precision); + }), + delayed_ack_timer_(timer_manager_.CreateTimer( + "test/delayed_ack", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))), + t3_rtx_timer_(timer_manager_.CreateTimer( + "test/t3_rtx", + []() { return absl::nullopt; }, + TimerOptions(DurationMs(0)))), + data_tracker_(std::make_unique<DataTracker>("log: ", + delayed_ack_timer_.get(), + kPeerInitialTsn)), + reasm_(std::make_unique<ReassemblyQueue>("log: ", + kPeerInitialTsn, + kArwnd)), + retransmission_queue_(std::make_unique<RetransmissionQueue>( + "", + &callbacks_, + kMyInitialTsn, + kArwnd, + producer_, + [](DurationMs rtt_ms) {}, + []() {}, + *t3_rtx_timer_, + DcSctpOptions())), + handler_( + std::make_unique<StreamResetHandler>("log: ", + &ctx_, + &timer_manager_, + data_tracker_.get(), + reasm_.get(), + retransmission_queue_.get())) { + EXPECT_CALL(ctx_, current_rto).WillRepeatedly(Return(kRto)); + } + + void AdvanceTime(DurationMs duration) { + callbacks_.AdvanceTime(kRto); + for (;;) { + absl::optional<TimeoutID> timeout_id = callbacks_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + timer_manager_.HandleTimeout(*timeout_id); + } + } + + // Handles the passed in RE-CONFIG `chunk` and returns the responses + // that are sent in the response RE-CONFIG. + std::vector<ReconfigurationResponseParameter> HandleAndCatchResponse( + ReConfigChunk chunk) { + handler_->HandleReConfig(std::move(chunk)); + + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + if (payload.empty()) { + EXPECT_TRUE(false); + return {}; + } + + std::vector<ReconfigurationResponseParameter> responses; + absl::optional<SctpPacket> p = SctpPacket::Parse(payload); + if (!p.has_value()) { + EXPECT_TRUE(false); + return {}; + } + if (p->descriptors().size() != 1) { + EXPECT_TRUE(false); + return {}; + } + absl::optional<ReConfigChunk> response_chunk = + ReConfigChunk::Parse(p->descriptors()[0].data); + if (!response_chunk.has_value()) { + EXPECT_TRUE(false); + return {}; + } + for (const auto& desc : response_chunk->parameters().descriptors()) { + if (desc.type == ReconfigurationResponseParameter::kType) { + absl::optional<ReconfigurationResponseParameter> response = + ReconfigurationResponseParameter::Parse(desc.data); + if (!response.has_value()) { + EXPECT_TRUE(false); + return {}; + } + responses.emplace_back(*std::move(response)); + } + } + return responses; + } + + void PerformHandover() { + EXPECT_TRUE(handler_->GetHandoverReadiness().IsReady()); + EXPECT_TRUE(data_tracker_->GetHandoverReadiness().IsReady()); + EXPECT_TRUE(reasm_->GetHandoverReadiness().IsReady()); + EXPECT_TRUE(retransmission_queue_->GetHandoverReadiness().IsReady()); + + DcSctpSocketHandoverState state; + handler_->AddHandoverState(state); + data_tracker_->AddHandoverState(state); + reasm_->AddHandoverState(state); + + retransmission_queue_->AddHandoverState(state); + + g_handover_state_transformer_for_test(&state); + + data_tracker_ = std::make_unique<DataTracker>( + "log: ", delayed_ack_timer_.get(), kPeerInitialTsn); + data_tracker_->RestoreFromState(state); + reasm_ = + std::make_unique<ReassemblyQueue>("log: ", kPeerInitialTsn, kArwnd); + reasm_->RestoreFromState(state); + retransmission_queue_ = std::make_unique<RetransmissionQueue>( + "", &callbacks_, kMyInitialTsn, kArwnd, producer_, + [](DurationMs rtt_ms) {}, []() {}, *t3_rtx_timer_, DcSctpOptions(), + /*supports_partial_reliability=*/true, + /*use_message_interleaving=*/false); + retransmission_queue_->RestoreFromState(state); + handler_ = std::make_unique<StreamResetHandler>( + "log: ", &ctx_, &timer_manager_, data_tracker_.get(), reasm_.get(), + retransmission_queue_.get(), &state); + } + + DataGenerator gen_; + NiceMock<MockDcSctpSocketCallbacks> callbacks_; + NiceMock<MockContext> ctx_; + NiceMock<MockSendQueue> producer_; + TimerManager timer_manager_; + std::unique_ptr<Timer> delayed_ack_timer_; + std::unique_ptr<Timer> t3_rtx_timer_; + std::unique_ptr<DataTracker> data_tracker_; + std::unique_ptr<ReassemblyQueue> reasm_; + std::unique_ptr<RetransmissionQueue> retransmission_queue_; + std::unique_ptr<StreamResetHandler> handler_; +}; + +TEST_F(StreamResetHandlerTest, ChunkWithNoParametersReturnsError) { + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + EXPECT_CALL(callbacks_, OnError).Times(1); + handler_->HandleReConfig(ReConfigChunk(Parameters())); +} + +TEST_F(StreamResetHandlerTest, ChunkWithInvalidParametersReturnsError) { + Parameters::Builder builder; + // Two OutgoingSSNResetRequestParameter in a RE-CONFIG is not valid. + builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(1), + ReconfigRequestSN(10), + kPeerInitialTsn, {StreamID(1)})); + builder.Add(OutgoingSSNResetRequestParameter(ReconfigRequestSN(2), + ReconfigRequestSN(10), + kPeerInitialTsn, {StreamID(2)})); + + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + EXPECT_CALL(callbacks_, OnError).Times(1); + handler_->HandleReConfig(ReConfigChunk(builder.Build())); +} + +TEST_F(StreamResetHandlerTest, FailToDeliverWithoutResettingStream) { + reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE")); + reasm_->Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE")); + + data_tracker_->Observe(kPeerInitialTsn); + data_tracker_->Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + gen_.ResetStream(); + reasm_->Add(AddTo(kPeerInitialTsn, 2), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm_->FlushMessages(), IsEmpty()); +} + +TEST_F(StreamResetHandlerTest, ResetStreamsNotDeferred) { + reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE")); + reasm_->Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE")); + + data_tracker_->Observe(kPeerInitialTsn); + data_tracker_->Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1), + {StreamID(1)})); + + std::vector<ReconfigurationResponseParameter> responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + EXPECT_THAT(responses, SizeIs(1)); + EXPECT_EQ(responses[0].result(), ResponseResult::kSuccessPerformed); + + gen_.ResetStream(); + reasm_->Add(AddTo(kPeerInitialTsn, 2), gen_.Ordered({1, 2, 3, 4}, "BE")); + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); +} + +TEST_F(StreamResetHandlerTest, ResetStreamsDeferred) { + DataGeneratorOptions opts; + opts.message_id = MID(0); + reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + + opts.message_id = MID(1); + reasm_->Add(AddTo(kPeerInitialTsn, 1), + gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + + data_tracker_->Observe(kPeerInitialTsn); + data_tracker_->Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 3), + {StreamID(1)})); + + std::vector<ReconfigurationResponseParameter> responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + EXPECT_THAT(responses, SizeIs(1)); + EXPECT_EQ(responses[0].result(), ResponseResult::kInProgress); + + opts.message_id = MID(1); + opts.ppid = PPID(5); + reasm_->Add(AddTo(kPeerInitialTsn, 5), + gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm_->MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 1)); + + opts.message_id = MID(0); + opts.ppid = PPID(4); + reasm_->Add(AddTo(kPeerInitialTsn, 4), + gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm_->MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 1)); + + opts.message_id = MID(3); + opts.ppid = PPID(3); + reasm_->Add(AddTo(kPeerInitialTsn, 3), + gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm_->MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 1)); + + opts.message_id = MID(2); + opts.ppid = PPID(2); + reasm_->Add(AddTo(kPeerInitialTsn, 2), + gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm_->MaybeResetStreamsDeferred(AddTo(kPeerInitialTsn, 5)); + + EXPECT_THAT( + reasm_->FlushMessages(), + UnorderedElementsAre(SctpMessageIs(StreamID(1), PPID(2), kShortPayload), + SctpMessageIs(StreamID(1), PPID(3), kShortPayload), + SctpMessageIs(StreamID(1), PPID(4), kShortPayload), + SctpMessageIs(StreamID(1), PPID(5), kShortPayload))); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingRequestDirectly) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42))); +} + +TEST_F(StreamResetHandlerTest, ResetMultipleStreamsInOneRequest) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(40))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(41))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))).Times(2); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(43))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(44))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + handler_->ResetStreams( + std::vector<StreamID>({StreamID(43), StreamID(44), StreamID(41)})); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42), StreamID(40)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return( + std::vector<StreamID>({StreamID(40), StreamID(41), StreamID(42), + StreamID(43), StreamID(44)}))); + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), + UnorderedElementsAre(StreamID(40), StreamID(41), StreamID(42), + StreamID(43), StreamID(44))); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingRequestDeferred) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()) + .WillOnce(Return(false)) + .WillOnce(Return(false)) + .WillOnce(Return(true)); + + EXPECT_FALSE(handler_->MakeStreamResetRequest().has_value()); + EXPECT_FALSE(handler_->MakeStreamResetRequest().has_value()); + EXPECT_TRUE(handler_->MakeStreamResetRequest().has_value()); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingResettingOnPositiveResponse) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter( + req.request_sequence_number(), ResponseResult::kSuccessPerformed)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams); + EXPECT_CALL(producer_, RollbackResetStreams).Times(0); + + // Processing a response shouldn't result in sending anything. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + handler_->HandleReConfig(std::move(response_reconfig)); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingResetRollbackOnError) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter( + req.request_sequence_number(), ResponseResult::kErrorBadSequenceNumber)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams).Times(0); + EXPECT_CALL(producer_, RollbackResetStreams); + + // Only requests should result in sending responses. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + handler_->HandleReConfig(std::move(response_reconfig)); +} + +TEST_F(StreamResetHandlerTest, SendOutgoingResetRetransmitOnInProgress) { + static constexpr StreamID kStreamToReset = StreamID(42); + + EXPECT_CALL(producer_, PrepareResetStream(kStreamToReset)); + handler_->ResetStreams(std::vector<StreamID>({kStreamToReset})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({kStreamToReset}))); + + absl::optional<ReConfigChunk> reconfig1 = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig1.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req1, + reconfig1->parameters().get<OutgoingSSNResetRequestParameter>()); + + // Simulate that the peer responded "In Progress". + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter(req1.request_sequence_number(), + ResponseResult::kInProgress)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(0); + EXPECT_CALL(producer_, RollbackResetStreams()).Times(0); + + // Processing a response shouldn't result in sending anything. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + handler_->HandleReConfig(std::move(response_reconfig)); + + // Let some time pass, so that the reconfig timer expires, and retries the + // same request. + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(1); + AdvanceTime(kRto); + + std::vector<uint8_t> payload = callbacks_.ConsumeSentPacket(); + ASSERT_FALSE(payload.empty()); + + ASSERT_HAS_VALUE_AND_ASSIGN(SctpPacket packet, SctpPacket::Parse(payload)); + ASSERT_THAT(packet.descriptors(), SizeIs(1)); + ASSERT_HAS_VALUE_AND_ASSIGN( + ReConfigChunk reconfig2, + ReConfigChunk::Parse(packet.descriptors()[0].data)); + + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req2, + reconfig2.parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req2.request_sequence_number(), + AddTo(req1.request_sequence_number(), 1)); + EXPECT_THAT(req2.stream_ids(), UnorderedElementsAre(kStreamToReset)); +} + +TEST_F(StreamResetHandlerTest, ResetWhileRequestIsSentWillQueue) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig1 = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig1.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req1, + reconfig1->parameters().get<OutgoingSSNResetRequestParameter>()); + EXPECT_EQ(req1.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req1.sender_last_assigned_tsn(), + AddTo(retransmission_queue_->next_tsn(), -1)); + EXPECT_THAT(req1.stream_ids(), UnorderedElementsAre(StreamID(42))); + + // Streams reset while the request is in-flight will be queued. + EXPECT_CALL(producer_, PrepareResetStream(StreamID(41))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(43))); + StreamID stream_ids[] = {StreamID(41), StreamID(43)}; + handler_->ResetStreams(stream_ids); + EXPECT_EQ(handler_->MakeStreamResetRequest(), absl::nullopt); + + Parameters::Builder builder; + builder.Add(ReconfigurationResponseParameter( + req1.request_sequence_number(), ResponseResult::kSuccessPerformed)); + ReConfigChunk response_reconfig(builder.Build()); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + EXPECT_CALL(producer_, RollbackResetStreams()).Times(0); + + // Processing a response shouldn't result in sending anything. + EXPECT_CALL(callbacks_, OnError).Times(0); + EXPECT_CALL(callbacks_, SendPacketWithStatus).Times(0); + handler_->HandleReConfig(std::move(response_reconfig)); + + // Response has been processed. A new request can be sent. + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(41), StreamID(43)}))); + + absl::optional<ReConfigChunk> reconfig2 = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig2.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req2, + reconfig2->parameters().get<OutgoingSSNResetRequestParameter>()); + EXPECT_EQ(req2.request_sequence_number(), AddTo(kMyInitialReqSn, 1)); + EXPECT_EQ(req2.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req2.stream_ids(), + UnorderedElementsAre(StreamID(41), StreamID(43))); +} + +TEST_F(StreamResetHandlerTest, SendIncomingResetJustReturnsNothingPerformed) { + Parameters::Builder builder; + builder.Add( + IncomingSSNResetRequestParameter(kPeerInitialReqSn, {StreamID(1)})); + + std::vector<ReconfigurationResponseParameter> responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + ASSERT_THAT(responses, SizeIs(1)); + EXPECT_THAT(responses[0].response_sequence_number(), kPeerInitialReqSn); + EXPECT_THAT(responses[0].result(), ResponseResult::kSuccessNothingToDo); +} + +TEST_F(StreamResetHandlerTest, SendSameRequestTwiceReturnsNothingToDo) { + reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE")); + reasm_->Add(AddTo(kPeerInitialTsn, 1), gen_.Ordered({1, 2, 3, 4}, "BE")); + + data_tracker_->Observe(kPeerInitialTsn); + data_tracker_->Observe(AddTo(kPeerInitialTsn, 1)); + EXPECT_THAT(reasm_->FlushMessages(), + UnorderedElementsAre( + SctpMessageIs(StreamID(1), PPID(53), kShortPayload), + SctpMessageIs(StreamID(1), PPID(53), kShortPayload))); + + Parameters::Builder builder1; + builder1.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1), + {StreamID(1)})); + + std::vector<ReconfigurationResponseParameter> responses1 = + HandleAndCatchResponse(ReConfigChunk(builder1.Build())); + EXPECT_THAT(responses1, SizeIs(1)); + EXPECT_EQ(responses1[0].result(), ResponseResult::kSuccessPerformed); + + Parameters::Builder builder2; + builder2.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), AddTo(kPeerInitialTsn, 1), + {StreamID(1)})); + + std::vector<ReconfigurationResponseParameter> responses2 = + HandleAndCatchResponse(ReConfigChunk(builder2.Build())); + EXPECT_THAT(responses2, SizeIs(1)); + EXPECT_EQ(responses2[0].result(), ResponseResult::kSuccessNothingToDo); +} + +TEST_F(StreamResetHandlerTest, + HandoverIsAllowedOnlyWhenNoStreamIsBeingOrWillBeReset) { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_EQ( + handler_->GetHandoverReadiness(), + HandoverReadinessStatus(HandoverUnreadinessReason::kPendingStreamReset)); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + ASSERT_TRUE(handler_->MakeStreamResetRequest().has_value()); + EXPECT_EQ(handler_->GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kPendingStreamResetRequest)); + + // Reset more streams while the request is in-flight. + EXPECT_CALL(producer_, PrepareResetStream(StreamID(41))); + EXPECT_CALL(producer_, PrepareResetStream(StreamID(43))); + StreamID stream_ids[] = {StreamID(41), StreamID(43)}; + handler_->ResetStreams(stream_ids); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_EQ(handler_->GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kPendingStreamResetRequest) + .Add(HandoverUnreadinessReason::kPendingStreamReset)); + + // Processing a response to first request. + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + handler_->HandleReConfig( + ReConfigChunk(Parameters::Builder() + .Add(ReconfigurationResponseParameter( + kMyInitialReqSn, ResponseResult::kSuccessPerformed)) + .Build())); + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_EQ( + handler_->GetHandoverReadiness(), + HandoverReadinessStatus(HandoverUnreadinessReason::kPendingStreamReset)); + + // Second request can be sent. + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(41), StreamID(43)}))); + + ASSERT_TRUE(handler_->MakeStreamResetRequest().has_value()); + EXPECT_EQ(handler_->GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kPendingStreamResetRequest)); + + // Processing a response to second request. + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + handler_->HandleReConfig(ReConfigChunk( + Parameters::Builder() + .Add(ReconfigurationResponseParameter( + AddTo(kMyInitialReqSn, 1), ResponseResult::kSuccessPerformed)) + .Build())); + + // Seconds response has been processed. No pending resets. + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(false)); + + EXPECT_TRUE(handler_->GetHandoverReadiness().IsReady()); +} + +TEST_F(StreamResetHandlerTest, HandoverInInitialState) { + PerformHandover(); + + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + absl::optional<ReConfigChunk> reconfig = handler_->MakeStreamResetRequest(); + ASSERT_TRUE(reconfig.has_value()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig->parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42))); +} + +TEST_F(StreamResetHandlerTest, HandoverAfterHavingResetOneStream) { + // Reset one stream + { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(42))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(42)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()) + .WillOnce(Return(true)) + .WillOnce(Return(false)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(42)}))); + + ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk reconfig, + handler_->MakeStreamResetRequest()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig.parameters().get<OutgoingSSNResetRequestParameter>()); + EXPECT_EQ(req.request_sequence_number(), kMyInitialReqSn); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(42))); + + EXPECT_CALL(producer_, CommitResetStreams()).Times(1); + handler_->HandleReConfig( + ReConfigChunk(Parameters::Builder() + .Add(ReconfigurationResponseParameter( + req.request_sequence_number(), + ResponseResult::kSuccessPerformed)) + .Build())); + } + + PerformHandover(); + + // Reset another stream after handover + { + EXPECT_CALL(producer_, PrepareResetStream(StreamID(43))); + handler_->ResetStreams(std::vector<StreamID>({StreamID(43)})); + + EXPECT_CALL(producer_, HasStreamsReadyToBeReset()).WillOnce(Return(true)); + EXPECT_CALL(producer_, GetStreamsReadyToBeReset()) + .WillOnce(Return(std::vector<StreamID>({StreamID(43)}))); + + ASSERT_HAS_VALUE_AND_ASSIGN(ReConfigChunk reconfig, + handler_->MakeStreamResetRequest()); + ASSERT_HAS_VALUE_AND_ASSIGN( + OutgoingSSNResetRequestParameter req, + reconfig.parameters().get<OutgoingSSNResetRequestParameter>()); + + EXPECT_EQ(req.request_sequence_number(), + ReconfigRequestSN(kMyInitialReqSn.value() + 1)); + EXPECT_EQ(req.sender_last_assigned_tsn(), + TSN(*retransmission_queue_->next_tsn() - 1)); + EXPECT_THAT(req.stream_ids(), UnorderedElementsAre(StreamID(43))); + } +} + +TEST_F(StreamResetHandlerTest, PerformCloseAfterOneFirstFailing) { + // Inject a stream reset on the first expected TSN (which hasn't been seen). + Parameters::Builder builder; + builder.Add(OutgoingSSNResetRequestParameter( + kPeerInitialReqSn, ReconfigRequestSN(3), kPeerInitialTsn, {StreamID(1)})); + + // The socket is expected to say "in progress" as that TSN hasn't been seen. + std::vector<ReconfigurationResponseParameter> responses = + HandleAndCatchResponse(ReConfigChunk(builder.Build())); + EXPECT_THAT(responses, SizeIs(1)); + EXPECT_EQ(responses[0].result(), ResponseResult::kInProgress); + + // Let the socket receive the TSN. + DataGeneratorOptions opts; + opts.message_id = MID(0); + reasm_->Add(kPeerInitialTsn, gen_.Ordered({1, 2, 3, 4}, "BE", opts)); + reasm_->MaybeResetStreamsDeferred(kPeerInitialTsn); + data_tracker_->Observe(kPeerInitialTsn); + + // And emulate that time has passed, and the peer retries the stream reset, + // but now with an incremented request sequence number. + Parameters::Builder builder2; + builder2.Add(OutgoingSSNResetRequestParameter( + ReconfigRequestSN(*kPeerInitialReqSn + 1), ReconfigRequestSN(3), + kPeerInitialTsn, {StreamID(1)})); + + // This is supposed to be handled well. + std::vector<ReconfigurationResponseParameter> responses2 = + HandleAndCatchResponse(ReConfigChunk(builder2.Build())); + EXPECT_THAT(responses2, SizeIs(1)); + EXPECT_EQ(responses2[0].result(), ResponseResult::kSuccessPerformed); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc new file mode 100644 index 0000000000..d769e26069 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.cc @@ -0,0 +1,314 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/socket/transmission_control_block.h" + +#include <algorithm> +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/reconfig_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/stream_reset_handler.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { + +TransmissionControlBlock::TransmissionControlBlock( + TimerManager& timer_manager, + absl::string_view log_prefix, + const DcSctpOptions& options, + const Capabilities& capabilities, + DcSctpSocketCallbacks& callbacks, + SendQueue& send_queue, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag, + PacketSender& packet_sender, + std::function<bool()> is_connection_established) + : log_prefix_(log_prefix), + options_(options), + timer_manager_(timer_manager), + capabilities_(capabilities), + callbacks_(callbacks), + t3_rtx_(timer_manager_.CreateTimer( + "t3-rtx", + absl::bind_front(&TransmissionControlBlock::OnRtxTimerExpiry, this), + TimerOptions(options.rto_initial, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/absl::nullopt, + options.max_timer_backoff_duration))), + delayed_ack_timer_(timer_manager_.CreateTimer( + "delayed-ack", + absl::bind_front(&TransmissionControlBlock::OnDelayedAckTimerExpiry, + this), + TimerOptions(options.delayed_ack_max_timeout, + TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/0, + /*max_backoff_duration=*/absl::nullopt, + webrtc::TaskQueueBase::DelayPrecision::kHigh))), + my_verification_tag_(my_verification_tag), + my_initial_tsn_(my_initial_tsn), + peer_verification_tag_(peer_verification_tag), + peer_initial_tsn_(peer_initial_tsn), + tie_tag_(tie_tag), + is_connection_established_(std::move(is_connection_established)), + packet_sender_(packet_sender), + rto_(options), + tx_error_counter_(log_prefix, options), + data_tracker_(log_prefix, delayed_ack_timer_.get(), peer_initial_tsn), + reassembly_queue_(log_prefix, + peer_initial_tsn, + options.max_receiver_window_buffer_size, + capabilities.message_interleaving), + retransmission_queue_( + log_prefix, + &callbacks_, + my_initial_tsn, + a_rwnd, + send_queue, + absl::bind_front(&TransmissionControlBlock::ObserveRTT, this), + [this]() { tx_error_counter_.Clear(); }, + *t3_rtx_, + options, + capabilities.partial_reliability, + capabilities.message_interleaving), + stream_reset_handler_(log_prefix, + this, + &timer_manager, + &data_tracker_, + &reassembly_queue_, + &retransmission_queue_), + heartbeat_handler_(log_prefix, options, this, &timer_manager_) { + send_queue.EnableMessageInterleaving(capabilities.message_interleaving); +} + +void TransmissionControlBlock::ObserveRTT(DurationMs rtt) { + DurationMs prev_rto = rto_.rto(); + rto_.ObserveRTT(rtt); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "new rtt=" << *rtt + << ", srtt=" << *rto_.srtt() << ", rto=" << *rto_.rto() + << " (" << *prev_rto << ")"; + t3_rtx_->set_duration(rto_.rto()); + + DurationMs delayed_ack_tmo = + std::min(rto_.rto() * 0.5, options_.delayed_ack_max_timeout); + delayed_ack_timer_->set_duration(delayed_ack_tmo); +} + +absl::optional<DurationMs> TransmissionControlBlock::OnRtxTimerExpiry() { + TimeMs now = callbacks_.TimeMillis(); + RTC_DLOG(LS_INFO) << log_prefix_ << "Timer " << t3_rtx_->name() + << " has expired"; + if (cookie_echo_chunk_.has_value()) { + // In the COOKIE_ECHO state, let the T1-COOKIE timer trigger + // retransmissions, to avoid having two timers doing that. + RTC_DLOG(LS_VERBOSE) << "Not retransmitting as T1-cookie is active."; + } else { + if (IncrementTxErrorCounter("t3-rtx expired")) { + retransmission_queue_.HandleT3RtxTimerExpiry(); + SendBufferedPackets(now); + } + } + return absl::nullopt; +} + +absl::optional<DurationMs> TransmissionControlBlock::OnDelayedAckTimerExpiry() { + data_tracker_.HandleDelayedAckTimerExpiry(); + MaybeSendSack(); + return absl::nullopt; +} + +void TransmissionControlBlock::MaybeSendSack() { + if (data_tracker_.ShouldSendAck(/*also_if_delayed=*/false)) { + SctpPacket::Builder builder = PacketBuilder(); + builder.Add( + data_tracker_.CreateSelectiveAck(reassembly_queue_.remaining_bytes())); + Send(builder); + } +} + +void TransmissionControlBlock::MaybeSendForwardTsn(SctpPacket::Builder& builder, + TimeMs now) { + if (now >= limit_forward_tsn_until_ && + retransmission_queue_.ShouldSendForwardTsn(now)) { + if (capabilities_.message_interleaving) { + builder.Add(retransmission_queue_.CreateIForwardTsn()); + } else { + builder.Add(retransmission_queue_.CreateForwardTsn()); + } + packet_sender_.Send(builder); + // https://datatracker.ietf.org/doc/html/rfc3758 + // "IMPLEMENTATION NOTE: An implementation may wish to limit the number of + // duplicate FORWARD TSN chunks it sends by ... waiting a full RTT before + // sending a duplicate FORWARD TSN." + // "Any delay applied to the sending of FORWARD TSN chunk SHOULD NOT exceed + // 200ms and MUST NOT exceed 500ms". + limit_forward_tsn_until_ = now + std::min(DurationMs(200), rto_.srtt()); + } +} + +void TransmissionControlBlock::MaybeSendFastRetransmit() { + if (!retransmission_queue_.has_data_to_be_fast_retransmitted()) { + return; + } + + // https://datatracker.ietf.org/doc/html/rfc4960#section-7.2.4 + // "Determine how many of the earliest (i.e., lowest TSN) DATA chunks marked + // for retransmission will fit into a single packet, subject to constraint of + // the path MTU of the destination transport address to which the packet is + // being sent. Call this value K. Retransmit those K DATA chunks in a single + // packet. When a Fast Retransmit is being performed, the sender SHOULD + // ignore the value of cwnd and SHOULD NOT delay retransmission for this + // single packet." + + SctpPacket::Builder builder(peer_verification_tag_, options_); + auto chunks = retransmission_queue_.GetChunksForFastRetransmit( + builder.bytes_remaining()); + for (auto& [tsn, data] : chunks) { + if (capabilities_.message_interleaving) { + builder.Add(IDataChunk(tsn, std::move(data), false)); + } else { + builder.Add(DataChunk(tsn, std::move(data), false)); + } + } + packet_sender_.Send(builder); +} + +void TransmissionControlBlock::SendBufferedPackets(SctpPacket::Builder& builder, + TimeMs now) { + for (int packet_idx = 0; + packet_idx < options_.max_burst && retransmission_queue_.can_send_data(); + ++packet_idx) { + // Only add control chunks to the first packet that is sent, if sending + // multiple packets in one go (as allowed by the congestion window). + if (packet_idx == 0) { + if (cookie_echo_chunk_.has_value()) { + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "The COOKIE ECHO chunk can be bundled with any pending outbound DATA + // chunks, but it MUST be the first chunk in the packet..." + RTC_DCHECK(builder.empty()); + builder.Add(*cookie_echo_chunk_); + } + + // https://tools.ietf.org/html/rfc4960#section-6 + // "Before an endpoint transmits a DATA chunk, if any received DATA + // chunks have not been acknowledged (e.g., due to delayed ack), the + // sender should create a SACK and bundle it with the outbound DATA chunk, + // as long as the size of the final SCTP packet does not exceed the + // current MTU." + if (data_tracker_.ShouldSendAck(/*also_if_delayed=*/true)) { + builder.Add(data_tracker_.CreateSelectiveAck( + reassembly_queue_.remaining_bytes())); + } + MaybeSendForwardTsn(builder, now); + absl::optional<ReConfigChunk> reconfig = + stream_reset_handler_.MakeStreamResetRequest(); + if (reconfig.has_value()) { + builder.Add(*reconfig); + } + } + + auto chunks = + retransmission_queue_.GetChunksToSend(now, builder.bytes_remaining()); + for (auto& [tsn, data] : chunks) { + if (capabilities_.message_interleaving) { + builder.Add(IDataChunk(tsn, std::move(data), false)); + } else { + builder.Add(DataChunk(tsn, std::move(data), false)); + } + } + + if (!packet_sender_.Send(builder)) { + break; + } + + if (cookie_echo_chunk_.has_value()) { + // https://tools.ietf.org/html/rfc4960#section-5.1 + // "... until the COOKIE ACK is returned the sender MUST NOT send any + // other packets to the peer." + break; + } + } +} + +std::string TransmissionControlBlock::ToString() const { + rtc::StringBuilder sb; + + sb.AppendFormat( + "verification_tag=%08x, last_cumulative_ack=%u, capabilities=", + *peer_verification_tag_, *data_tracker_.last_cumulative_acked_tsn()); + + if (capabilities_.partial_reliability) { + sb << "PR,"; + } + if (capabilities_.message_interleaving) { + sb << "IL,"; + } + if (capabilities_.reconfig) { + sb << "Reconfig,"; + } + + return sb.Release(); +} + +HandoverReadinessStatus TransmissionControlBlock::GetHandoverReadiness() const { + HandoverReadinessStatus status; + status.Add(data_tracker_.GetHandoverReadiness()); + status.Add(stream_reset_handler_.GetHandoverReadiness()); + status.Add(reassembly_queue_.GetHandoverReadiness()); + status.Add(retransmission_queue_.GetHandoverReadiness()); + return status; +} + +void TransmissionControlBlock::AddHandoverState( + DcSctpSocketHandoverState& state) { + state.capabilities.partial_reliability = capabilities_.partial_reliability; + state.capabilities.message_interleaving = capabilities_.message_interleaving; + state.capabilities.reconfig = capabilities_.reconfig; + + state.my_verification_tag = my_verification_tag().value(); + state.peer_verification_tag = peer_verification_tag().value(); + state.my_initial_tsn = my_initial_tsn().value(); + state.peer_initial_tsn = peer_initial_tsn().value(); + state.tie_tag = tie_tag().value(); + + data_tracker_.AddHandoverState(state); + stream_reset_handler_.AddHandoverState(state); + reassembly_queue_.AddHandoverState(state); + retransmission_queue_.AddHandoverState(state); +} + +void TransmissionControlBlock::RestoreFromState( + const DcSctpSocketHandoverState& state) { + data_tracker_.RestoreFromState(state); + retransmission_queue_.RestoreFromState(state); + reassembly_queue_.RestoreFromState(state); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h new file mode 100644 index 0000000000..8e0e9a3ec5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/socket/transmission_control_block.h @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_ +#define NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_ + +#include <cstdint> +#include <functional> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/strings/string_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/rx/data_tracker.h" +#include "net/dcsctp/rx/reassembly_queue.h" +#include "net/dcsctp/socket/capabilities.h" +#include "net/dcsctp/socket/context.h" +#include "net/dcsctp/socket/heartbeat_handler.h" +#include "net/dcsctp/socket/packet_sender.h" +#include "net/dcsctp/socket/stream_reset_handler.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/retransmission_error_counter.h" +#include "net/dcsctp/tx/retransmission_queue.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "net/dcsctp/tx/send_queue.h" + +namespace dcsctp { + +// The TransmissionControlBlock (TCB) represents an open connection to a peer, +// and holds all the resources for that. If the connection is e.g. shutdown, +// closed or restarted, this object will be deleted and/or replaced. +class TransmissionControlBlock : public Context { + public: + TransmissionControlBlock(TimerManager& timer_manager, + absl::string_view log_prefix, + const DcSctpOptions& options, + const Capabilities& capabilities, + DcSctpSocketCallbacks& callbacks, + SendQueue& send_queue, + VerificationTag my_verification_tag, + TSN my_initial_tsn, + VerificationTag peer_verification_tag, + TSN peer_initial_tsn, + size_t a_rwnd, + TieTag tie_tag, + PacketSender& packet_sender, + std::function<bool()> is_connection_established); + + // Implementation of `Context`. + bool is_connection_established() const override { + return is_connection_established_(); + } + TSN my_initial_tsn() const override { return my_initial_tsn_; } + TSN peer_initial_tsn() const override { return peer_initial_tsn_; } + DcSctpSocketCallbacks& callbacks() const override { return callbacks_; } + void ObserveRTT(DurationMs rtt) override; + DurationMs current_rto() const override { return rto_.rto(); } + bool IncrementTxErrorCounter(absl::string_view reason) override { + return tx_error_counter_.Increment(reason); + } + void ClearTxErrorCounter() override { tx_error_counter_.Clear(); } + SctpPacket::Builder PacketBuilder() const override { + return SctpPacket::Builder(peer_verification_tag_, options_); + } + bool HasTooManyTxErrors() const override { + return tx_error_counter_.IsExhausted(); + } + void Send(SctpPacket::Builder& builder) override { + packet_sender_.Send(builder); + } + + // Other accessors + DataTracker& data_tracker() { return data_tracker_; } + ReassemblyQueue& reassembly_queue() { return reassembly_queue_; } + RetransmissionQueue& retransmission_queue() { return retransmission_queue_; } + StreamResetHandler& stream_reset_handler() { return stream_reset_handler_; } + HeartbeatHandler& heartbeat_handler() { return heartbeat_handler_; } + size_t cwnd() const { return retransmission_queue_.cwnd(); } + DurationMs current_srtt() const { return rto_.srtt(); } + + // Returns this socket's verification tag, set in all packet headers. + VerificationTag my_verification_tag() const { return my_verification_tag_; } + // Returns the peer's verification tag, which should be in received packets. + VerificationTag peer_verification_tag() const { + return peer_verification_tag_; + } + // All negotiated supported capabilities. + const Capabilities& capabilities() const { return capabilities_; } + // A 64-bit tie-tag, used to e.g. detect reconnections. + TieTag tie_tag() const { return tie_tag_; } + + // Sends a SACK, if there is a need to. + void MaybeSendSack(); + + // Sends a FORWARD-TSN, if it is needed and allowed (rate-limited). + void MaybeSendForwardTsn(SctpPacket::Builder& builder, TimeMs now); + + // Will be set while the socket is in kCookieEcho state. In this state, there + // can only be a single packet outstanding, and it must contain the COOKIE + // ECHO chunk as the first chunk in that packet, until the COOKIE ACK has been + // received, which will make the socket call `ClearCookieEchoChunk`. + void SetCookieEchoChunk(CookieEchoChunk chunk) { + cookie_echo_chunk_ = std::move(chunk); + } + + // Called when the COOKIE ACK chunk has been received, to allow further + // packets to be sent. + void ClearCookieEchoChunk() { cookie_echo_chunk_ = absl::nullopt; } + + bool has_cookie_echo_chunk() const { return cookie_echo_chunk_.has_value(); } + + void MaybeSendFastRetransmit(); + + // Fills `builder` (which may already be filled with control chunks) with + // other control and data chunks, and sends packets as much as can be + // allowed by the congestion control algorithm. + void SendBufferedPackets(SctpPacket::Builder& builder, TimeMs now); + + // As above, but without passing in a builder. If `cookie_echo_chunk_` is + // present, then only one packet will be sent, with this chunk as the first + // chunk. + void SendBufferedPackets(TimeMs now) { + SctpPacket::Builder builder(peer_verification_tag_, options_); + SendBufferedPackets(builder, now); + } + + // Returns a textual representation of this object, for logging. + std::string ToString() const; + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& handover_state); + + private: + // Will be called when the retransmission timer (t3-rtx) expires. + absl::optional<DurationMs> OnRtxTimerExpiry(); + // Will be called when the delayed ack timer expires. + absl::optional<DurationMs> OnDelayedAckTimerExpiry(); + + const std::string log_prefix_; + const DcSctpOptions options_; + TimerManager& timer_manager_; + // Negotiated capabilities that both peers support. + const Capabilities capabilities_; + DcSctpSocketCallbacks& callbacks_; + // The data retransmission timer, called t3-rtx in SCTP. + const std::unique_ptr<Timer> t3_rtx_; + // Delayed ack timer, which triggers when acks should be sent (when delayed). + const std::unique_ptr<Timer> delayed_ack_timer_; + const VerificationTag my_verification_tag_; + const TSN my_initial_tsn_; + const VerificationTag peer_verification_tag_; + const TSN peer_initial_tsn_; + // Nonce, used to detect reconnections. + const TieTag tie_tag_; + const std::function<bool()> is_connection_established_; + PacketSender& packet_sender_; + // Rate limiting of FORWARD-TSN. Next can be sent at or after this timestamp. + TimeMs limit_forward_tsn_until_ = TimeMs(0); + + RetransmissionTimeout rto_; + RetransmissionErrorCounter tx_error_counter_; + DataTracker data_tracker_; + ReassemblyQueue reassembly_queue_; + RetransmissionQueue retransmission_queue_; + StreamResetHandler stream_reset_handler_; + HeartbeatHandler heartbeat_handler_; + + // Only valid when the socket state == State::kCookieEchoed. In this state, + // the socket must wait for COOKIE ACK to continue sending any packets (not + // including a COOKIE ECHO). So if `cookie_echo_chunk_` is present, the + // SendBufferedChunks will always only just send one packet, with this chunk + // as the first chunk in the packet. + absl::optional<CookieEchoChunk> cookie_echo_chunk_ = absl::nullopt; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_SOCKET_TRANSMISSION_CONTROL_BLOCK_H_ diff --git a/third_party/libwebrtc/net/dcsctp/testing/BUILD.gn b/third_party/libwebrtc/net/dcsctp/testing/BUILD.gn new file mode 100644 index 0000000000..7e005a1f0c --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/testing/BUILD.gn @@ -0,0 +1,33 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("testing_macros") { + testonly = true + sources = [ "testing_macros.h" ] +} + +rtc_library("data_generator") { + testonly = true + deps = [ + "../../../api:array_view", + "../../../rtc_base:checks", + "../common:internal_types", + "../packet:data", + "../public:types", + ] + sources = [ + "data_generator.cc", + "data_generator.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} diff --git a/third_party/libwebrtc/net/dcsctp/testing/data_generator.cc b/third_party/libwebrtc/net/dcsctp/testing/data_generator.cc new file mode 100644 index 0000000000..e4f9f91384 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/testing/data_generator.cc @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/testing/data_generator.h" + +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { +constexpr PPID kPpid = PPID(53); + +Data DataGenerator::Ordered(std::vector<uint8_t> payload, + absl::string_view flags, + const DataGeneratorOptions opts) { + Data::IsBeginning is_beginning(flags.find('B') != std::string::npos); + Data::IsEnd is_end(flags.find('E') != std::string::npos); + + if (is_beginning) { + fsn_ = FSN(0); + } else { + fsn_ = FSN(*fsn_ + 1); + } + MID message_id = opts.message_id.value_or(message_id_); + Data ret = Data(opts.stream_id, SSN(static_cast<uint16_t>(*message_id)), + message_id, fsn_, opts.ppid, std::move(payload), is_beginning, + is_end, IsUnordered(false)); + + if (is_end) { + message_id_ = MID(*message_id + 1); + } + return ret; +} + +Data DataGenerator::Unordered(std::vector<uint8_t> payload, + absl::string_view flags, + const DataGeneratorOptions opts) { + Data::IsBeginning is_beginning(flags.find('B') != std::string::npos); + Data::IsEnd is_end(flags.find('E') != std::string::npos); + + if (is_beginning) { + fsn_ = FSN(0); + } else { + fsn_ = FSN(*fsn_ + 1); + } + MID message_id = opts.message_id.value_or(message_id_); + Data ret = Data(opts.stream_id, SSN(0), message_id, fsn_, kPpid, + std::move(payload), is_beginning, is_end, IsUnordered(true)); + if (is_end) { + message_id_ = MID(*message_id + 1); + } + return ret; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/testing/data_generator.h b/third_party/libwebrtc/net/dcsctp/testing/data_generator.h new file mode 100644 index 0000000000..f917c740a7 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/testing/data_generator.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TESTING_DATA_GENERATOR_H_ +#define NET_DCSCTP_TESTING_DATA_GENERATOR_H_ + +#include <cstdint> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/data.h" + +namespace dcsctp { + +struct DataGeneratorOptions { + StreamID stream_id = StreamID(1); + absl::optional<MID> message_id = absl::nullopt; + PPID ppid = PPID(53); +}; + +// Generates Data with correct sequence numbers, and used only in unit tests. +class DataGenerator { + public: + explicit DataGenerator(MID start_message_id = MID(0)) + : message_id_(start_message_id) {} + + // Generates ordered "data" with the provided `payload` and flags, which can + // contain "B" for setting the "is_beginning" flag, and/or "E" for setting the + // "is_end" flag. + Data Ordered(std::vector<uint8_t> payload, + absl::string_view flags = "", + DataGeneratorOptions opts = {}); + + // Generates unordered "data" with the provided `payload` and flags, which can + // contain "B" for setting the "is_beginning" flag, and/or "E" for setting the + // "is_end" flag. + Data Unordered(std::vector<uint8_t> payload, + absl::string_view flags = "", + DataGeneratorOptions opts = {}); + + // Resets the Message ID identifier - simulating a "stream reset". + void ResetStream() { message_id_ = MID(0); } + + private: + MID message_id_; + FSN fsn_ = FSN(0); +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TESTING_DATA_GENERATOR_H_ diff --git a/third_party/libwebrtc/net/dcsctp/testing/testing_macros.h b/third_party/libwebrtc/net/dcsctp/testing/testing_macros.h new file mode 100644 index 0000000000..5cbdfffdce --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/testing/testing_macros.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TESTING_TESTING_MACROS_H_ +#define NET_DCSCTP_TESTING_TESTING_MACROS_H_ + +#include <utility> + +namespace dcsctp { + +#define DCSCTP_CONCAT_INNER_(x, y) x##y +#define DCSCTP_CONCAT_(x, y) DCSCTP_CONCAT_INNER_(x, y) + +// Similar to ASSERT_OK_AND_ASSIGN, this works with an absl::optional<> instead +// of an absl::StatusOr<>. +#define ASSERT_HAS_VALUE_AND_ASSIGN(lhs, rexpr) \ + auto DCSCTP_CONCAT_(tmp_opt_val__, __LINE__) = rexpr; \ + ASSERT_TRUE(DCSCTP_CONCAT_(tmp_opt_val__, __LINE__).has_value()); \ + lhs = *std::move(DCSCTP_CONCAT_(tmp_opt_val__, __LINE__)); + +} // namespace dcsctp + +#endif // NET_DCSCTP_TESTING_TESTING_MACROS_H_ diff --git a/third_party/libwebrtc/net/dcsctp/timer/BUILD.gn b/third_party/libwebrtc/net/dcsctp/timer/BUILD.gn new file mode 100644 index 0000000000..d3be1ec872 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/timer/BUILD.gn @@ -0,0 +1,74 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_library("timer") { + deps = [ + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:strong_alias", + "../../../rtc_base/containers:flat_map", + "../../../rtc_base/containers:flat_set", + "../public:socket", + "../public:types", + ] + sources = [ + "fake_timeout.h", + "timer.cc", + "timer.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/memory", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("task_queue_timeout") { + deps = [ + "../../../api:array_view", + "../../../api/task_queue:pending_task_safety_flag", + "../../../api/task_queue:task_queue", + "../../../api/units:time_delta", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../public:socket", + "../public:types", + ] + sources = [ + "task_queue_timeout.cc", + "task_queue_timeout.h", + ] +} + +if (rtc_include_tests) { + rtc_library("dcsctp_timer_unittests") { + testonly = true + + defines = [] + deps = [ + ":task_queue_timeout", + ":timer", + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../api/task_queue/test:mock_task_queue_base", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../test:test_support", + "../../../test/time_controller:time_controller", + "../public:socket", + ] + sources = [ + "task_queue_timeout_test.cc", + "timer_test.cc", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/timer/fake_timeout.h b/third_party/libwebrtc/net/dcsctp/timer/fake_timeout.h new file mode 100644 index 0000000000..74ffe5af29 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/timer/fake_timeout.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TIMER_FAKE_TIMEOUT_H_ +#define NET_DCSCTP_TIMER_FAKE_TIMEOUT_H_ + +#include <cstdint> +#include <functional> +#include <limits> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/timeout.h" +#include "rtc_base/checks.h" +#include "rtc_base/containers/flat_set.h" + +namespace dcsctp { + +// A timeout used in tests. +class FakeTimeout : public Timeout { + public: + FakeTimeout(std::function<TimeMs()> get_time, + std::function<void(FakeTimeout*)> on_delete) + : get_time_(std::move(get_time)), on_delete_(std::move(on_delete)) {} + + ~FakeTimeout() override { on_delete_(this); } + + void Start(DurationMs duration_ms, TimeoutID timeout_id) override { + RTC_DCHECK(expiry_ == TimeMs::InfiniteFuture()); + timeout_id_ = timeout_id; + expiry_ = get_time_() + duration_ms; + } + void Stop() override { + RTC_DCHECK(expiry_ != TimeMs::InfiniteFuture()); + expiry_ = TimeMs::InfiniteFuture(); + } + + bool EvaluateHasExpired(TimeMs now) { + if (now >= expiry_) { + expiry_ = TimeMs::InfiniteFuture(); + return true; + } + return false; + } + + TimeoutID timeout_id() const { return timeout_id_; } + + private: + const std::function<TimeMs()> get_time_; + const std::function<void(FakeTimeout*)> on_delete_; + + TimeoutID timeout_id_ = TimeoutID(0); + TimeMs expiry_ = TimeMs::InfiniteFuture(); +}; + +class FakeTimeoutManager { + public: + // The `get_time` function must return the current time, relative to any + // epoch. + explicit FakeTimeoutManager(std::function<TimeMs()> get_time) + : get_time_(std::move(get_time)) {} + + std::unique_ptr<FakeTimeout> CreateTimeout() { + auto timer = std::make_unique<FakeTimeout>( + get_time_, [this](FakeTimeout* timer) { timers_.erase(timer); }); + timers_.insert(timer.get()); + return timer; + } + std::unique_ptr<FakeTimeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision) { + // FakeTimeout does not support implement |precision|. + return CreateTimeout(); + } + + // NOTE: This can't return a vector, as calling EvaluateHasExpired requires + // calling socket->HandleTimeout directly afterwards, as the owning Timer + // still believes it's running, and it needs to be updated to set + // Timer::is_running_ to false before you operate on the Timer or Timeout + // again. + absl::optional<TimeoutID> GetNextExpiredTimeout() { + TimeMs now = get_time_(); + std::vector<TimeoutID> expired_timers; + for (auto& timer : timers_) { + if (timer->EvaluateHasExpired(now)) { + return timer->timeout_id(); + } + } + return absl::nullopt; + } + + private: + const std::function<TimeMs()> get_time_; + webrtc::flat_set<FakeTimeout*> timers_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_TIMER_FAKE_TIMEOUT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/timer/task_queue_timeout.cc b/third_party/libwebrtc/net/dcsctp/timer/task_queue_timeout.cc new file mode 100644 index 0000000000..6c43640d39 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/timer/task_queue_timeout.cc @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/timer/task_queue_timeout.h" + +#include "api/task_queue/pending_task_safety_flag.h" +#include "api/units/time_delta.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +TaskQueueTimeoutFactory::TaskQueueTimeout::TaskQueueTimeout( + TaskQueueTimeoutFactory& parent, + webrtc::TaskQueueBase::DelayPrecision precision) + : parent_(parent), + precision_(precision), + pending_task_safety_flag_(webrtc::PendingTaskSafetyFlag::Create()) {} + +TaskQueueTimeoutFactory::TaskQueueTimeout::~TaskQueueTimeout() { + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + pending_task_safety_flag_->SetNotAlive(); +} + +void TaskQueueTimeoutFactory::TaskQueueTimeout::Start(DurationMs duration_ms, + TimeoutID timeout_id) { + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + RTC_DCHECK(timeout_expiration_ == TimeMs::InfiniteFuture()); + timeout_expiration_ = parent_.get_time_() + duration_ms; + timeout_id_ = timeout_id; + + if (timeout_expiration_ >= posted_task_expiration_) { + // There is already a running task, and it's scheduled to expire sooner than + // the new expiration time. Don't do anything; The `timeout_expiration_` has + // already been updated and if the delayed task _does_ expire and the timer + // hasn't been stopped, that will be noticed in the timeout handler, and the + // task will be re-scheduled. Most timers are stopped before they expire. + return; + } + + if (posted_task_expiration_ != TimeMs::InfiniteFuture()) { + RTC_DLOG(LS_VERBOSE) << "New timeout duration is less than scheduled - " + "ghosting old delayed task."; + // There is already a scheduled delayed task, but its expiration time is + // further away than the new expiration, so it can't be used. It will be + // "killed" by replacing the safety flag. This is not expected to happen + // especially often; Mainly when a timer did exponential backoff and + // later recovered. + pending_task_safety_flag_->SetNotAlive(); + pending_task_safety_flag_ = webrtc::PendingTaskSafetyFlag::Create(); + } + + posted_task_expiration_ = timeout_expiration_; + parent_.task_queue_.PostDelayedTaskWithPrecision( + precision_, + webrtc::SafeTask( + pending_task_safety_flag_, + [timeout_id, this]() { + RTC_DLOG(LS_VERBOSE) << "Timout expired: " << timeout_id.value(); + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + RTC_DCHECK(posted_task_expiration_ != TimeMs::InfiniteFuture()); + posted_task_expiration_ = TimeMs::InfiniteFuture(); + + if (timeout_expiration_ == TimeMs::InfiniteFuture()) { + // The timeout was stopped before it expired. Very common. + } else { + // Note that the timeout might have been restarted, which updated + // `timeout_expiration_` but left the scheduled task running. So + // if it's not quite time to trigger the timeout yet, schedule a + // new delayed task with what's remaining and retry at that point + // in time. + DurationMs remaining = timeout_expiration_ - parent_.get_time_(); + timeout_expiration_ = TimeMs::InfiniteFuture(); + if (*remaining > 0) { + Start(remaining, timeout_id_); + } else { + // It has actually triggered. + RTC_DLOG(LS_VERBOSE) + << "Timout triggered: " << timeout_id.value(); + parent_.on_expired_(timeout_id_); + } + } + }), + webrtc::TimeDelta::Millis(duration_ms.value())); +} + +void TaskQueueTimeoutFactory::TaskQueueTimeout::Stop() { + // As the TaskQueue doesn't support deleting a posted task, just mark the + // timeout as not running. + RTC_DCHECK_RUN_ON(&parent_.thread_checker_); + timeout_expiration_ = TimeMs::InfiniteFuture(); +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/timer/task_queue_timeout.h b/third_party/libwebrtc/net/dcsctp/timer/task_queue_timeout.h new file mode 100644 index 0000000000..faae14464f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/timer/task_queue_timeout.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TIMER_TASK_QUEUE_TIMEOUT_H_ +#define NET_DCSCTP_TIMER_TASK_QUEUE_TIMEOUT_H_ + +#include <memory> +#include <utility> + +#include "api/task_queue/pending_task_safety_flag.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/timeout.h" + +namespace dcsctp { + +// The TaskQueueTimeoutFactory creates `Timeout` instances, which schedules +// itself to be triggered on the provided `task_queue`, which may be a thread, +// an actual TaskQueue or something else which supports posting a delayed task. +// +// Note that each `DcSctpSocket` must have its own `TaskQueueTimeoutFactory`, +// as the `TimeoutID` are not unique among sockets. +// +// This class must outlive any created Timeout that it has created. Note that +// the `DcSctpSocket` will ensure that all Timeouts are deleted when the socket +// is destructed, so this means that this class must outlive the `DcSctpSocket`. +// +// This class, and the timeouts created it, are not thread safe. +class TaskQueueTimeoutFactory { + public: + // The `get_time` function must return the current time, relative to any + // epoch. Whenever a timeout expires, the `on_expired` callback will be + // triggered, and then the client should provided `timeout_id` to + // `DcSctpSocketInterface::HandleTimeout`. + TaskQueueTimeoutFactory(webrtc::TaskQueueBase& task_queue, + std::function<TimeMs()> get_time, + std::function<void(TimeoutID timeout_id)> on_expired) + : task_queue_(task_queue), + get_time_(std::move(get_time)), + on_expired_(std::move(on_expired)) {} + + // Creates an implementation of `Timeout`. + std::unique_ptr<Timeout> CreateTimeout( + webrtc::TaskQueueBase::DelayPrecision precision = + webrtc::TaskQueueBase::DelayPrecision::kLow) { + return std::make_unique<TaskQueueTimeout>(*this, precision); + } + + private: + class TaskQueueTimeout : public Timeout { + public: + TaskQueueTimeout(TaskQueueTimeoutFactory& parent, + webrtc::TaskQueueBase::DelayPrecision precision); + ~TaskQueueTimeout(); + + void Start(DurationMs duration_ms, TimeoutID timeout_id) override; + void Stop() override; + + private: + TaskQueueTimeoutFactory& parent_; + const webrtc::TaskQueueBase::DelayPrecision precision_; + // A safety flag to ensure that posted tasks to the task queue don't + // reference these object when they go out of scope. Note that this safety + // flag will be re-created if the scheduled-but-not-yet-expired task is not + // to be run. This happens when there is a posted delayed task with an + // expiration time _further away_ than what is now the expected expiration + // time. In this scenario, a new delayed task has to be posted with a + // shorter duration and the old task has to be forgotten. + rtc::scoped_refptr<webrtc::PendingTaskSafetyFlag> pending_task_safety_flag_; + // The time when the posted delayed task is set to expire. Will be set to + // the infinite future if there is no such task running. + TimeMs posted_task_expiration_ = TimeMs::InfiniteFuture(); + // The time when the timeout expires. It will be set to the infinite future + // if the timeout is not running/not started. + TimeMs timeout_expiration_ = TimeMs::InfiniteFuture(); + // The current timeout ID that will be reported when expired. + TimeoutID timeout_id_ = TimeoutID(0); + }; + + RTC_NO_UNIQUE_ADDRESS webrtc::SequenceChecker thread_checker_; + webrtc::TaskQueueBase& task_queue_; + const std::function<TimeMs()> get_time_; + const std::function<void(TimeoutID)> on_expired_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TIMER_TASK_QUEUE_TIMEOUT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/timer/task_queue_timeout_test.cc b/third_party/libwebrtc/net/dcsctp/timer/task_queue_timeout_test.cc new file mode 100644 index 0000000000..f360ba7a58 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/timer/task_queue_timeout_test.cc @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/timer/task_queue_timeout.h" + +#include <memory> + +#include "api/task_queue/task_queue_base.h" +#include "api/task_queue/test/mock_task_queue_base.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" +#include "test/time_controller/simulated_time_controller.h" + +namespace dcsctp { +namespace { +using ::testing::_; +using ::testing::MockFunction; +using ::testing::NiceMock; + +class TaskQueueTimeoutTest : public testing::Test { + protected: + TaskQueueTimeoutTest() + : time_controller_(webrtc::Timestamp::Millis(1234)), + task_queue_(time_controller_.GetMainThread()), + factory_( + *task_queue_, + [this]() { + return TimeMs(time_controller_.GetClock()->CurrentTime().ms()); + }, + on_expired_.AsStdFunction()) {} + + void AdvanceTime(DurationMs duration) { + time_controller_.AdvanceTime(webrtc::TimeDelta::Millis(*duration)); + } + + MockFunction<void(TimeoutID)> on_expired_; + webrtc::GlobalSimulatedTimeController time_controller_; + + rtc::Thread* task_queue_; + TaskQueueTimeoutFactory factory_; +}; + +TEST_F(TaskQueueTimeoutTest, StartPostsDelayedTask) { + std::unique_ptr<Timeout> timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(999)); + + EXPECT_CALL(on_expired_, Call(TimeoutID(1))); + AdvanceTime(DurationMs(1)); +} + +TEST_F(TaskQueueTimeoutTest, StopBeforeExpiringDoesntTrigger) { + std::unique_ptr<Timeout> timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(999)); + + timeout->Stop(); + + AdvanceTime(DurationMs(1)); + AdvanceTime(DurationMs(1000)); +} + +TEST_F(TaskQueueTimeoutTest, RestartPrologingTimeoutDuration) { + std::unique_ptr<Timeout> timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(500)); + + timeout->Restart(DurationMs(1000), TimeoutID(2)); + + AdvanceTime(DurationMs(999)); + + EXPECT_CALL(on_expired_, Call(TimeoutID(2))); + AdvanceTime(DurationMs(1)); +} + +TEST_F(TaskQueueTimeoutTest, RestartWithShorterDurationExpiresWhenExpected) { + std::unique_ptr<Timeout> timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(500)); + + timeout->Restart(DurationMs(200), TimeoutID(2)); + + AdvanceTime(DurationMs(199)); + + EXPECT_CALL(on_expired_, Call(TimeoutID(2))); + AdvanceTime(DurationMs(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(1000)); +} + +TEST_F(TaskQueueTimeoutTest, KilledBeforeExpired) { + std::unique_ptr<Timeout> timeout = factory_.CreateTimeout(); + timeout->Start(DurationMs(1000), TimeoutID(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(500)); + + timeout = nullptr; + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTime(DurationMs(1000)); +} + +TEST(TaskQueueTimeoutWithMockTaskQueueTest, CanSetTimeoutPrecisionToLow) { + NiceMock<webrtc::MockTaskQueueBase> mock_task_queue; + EXPECT_CALL(mock_task_queue, PostDelayedTask(_, _)); + TaskQueueTimeoutFactory factory( + mock_task_queue, []() { return TimeMs(1337); }, + [](TimeoutID timeout_id) {}); + std::unique_ptr<Timeout> timeout = + factory.CreateTimeout(webrtc::TaskQueueBase::DelayPrecision::kLow); + timeout->Start(DurationMs(1), TimeoutID(1)); +} + +TEST(TaskQueueTimeoutWithMockTaskQueueTest, CanSetTimeoutPrecisionToHigh) { + NiceMock<webrtc::MockTaskQueueBase> mock_task_queue; + EXPECT_CALL(mock_task_queue, PostDelayedHighPrecisionTask(_, _)); + TaskQueueTimeoutFactory factory( + mock_task_queue, []() { return TimeMs(1337); }, + [](TimeoutID timeout_id) {}); + std::unique_ptr<Timeout> timeout = + factory.CreateTimeout(webrtc::TaskQueueBase::DelayPrecision::kHigh); + timeout->Start(DurationMs(1), TimeoutID(1)); +} + +TEST(TaskQueueTimeoutWithMockTaskQueueTest, TimeoutPrecisionIsLowByDefault) { + NiceMock<webrtc::MockTaskQueueBase> mock_task_queue; + EXPECT_CALL(mock_task_queue, PostDelayedTask(_, _)); + TaskQueueTimeoutFactory factory( + mock_task_queue, []() { return TimeMs(1337); }, + [](TimeoutID timeout_id) {}); + std::unique_ptr<Timeout> timeout = factory.CreateTimeout(); + timeout->Start(DurationMs(1), TimeoutID(1)); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/timer/timer.cc b/third_party/libwebrtc/net/dcsctp/timer/timer.cc new file mode 100644 index 0000000000..bde07638a5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/timer/timer.cc @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/timer/timer.h" + +#include <algorithm> +#include <cstdint> +#include <limits> +#include <memory> +#include <utility> + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "net/dcsctp/public/timeout.h" +#include "rtc_base/checks.h" + +namespace dcsctp { +namespace { +TimeoutID MakeTimeoutId(TimerID timer_id, TimerGeneration generation) { + return TimeoutID(static_cast<uint64_t>(*timer_id) << 32 | *generation); +} + +DurationMs GetBackoffDuration(const TimerOptions& options, + DurationMs base_duration, + int expiration_count) { + switch (options.backoff_algorithm) { + case TimerBackoffAlgorithm::kFixed: + return base_duration; + case TimerBackoffAlgorithm::kExponential: { + int32_t duration_ms = *base_duration; + + while (expiration_count > 0 && duration_ms < *Timer::kMaxTimerDuration) { + duration_ms *= 2; + --expiration_count; + + if (options.max_backoff_duration.has_value() && + duration_ms > **options.max_backoff_duration) { + return *options.max_backoff_duration; + } + } + + return DurationMs(std::min(duration_ms, *Timer::kMaxTimerDuration)); + } + } +} +} // namespace + +constexpr DurationMs Timer::kMaxTimerDuration; + +Timer::Timer(TimerID id, + absl::string_view name, + OnExpired on_expired, + UnregisterHandler unregister_handler, + std::unique_ptr<Timeout> timeout, + const TimerOptions& options) + : id_(id), + name_(name), + options_(options), + on_expired_(std::move(on_expired)), + unregister_handler_(std::move(unregister_handler)), + timeout_(std::move(timeout)), + duration_(options.duration) {} + +Timer::~Timer() { + Stop(); + unregister_handler_(); +} + +void Timer::Start() { + expiration_count_ = 0; + if (!is_running()) { + is_running_ = true; + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Start(duration_, MakeTimeoutId(id_, generation_)); + } else { + // Timer was running - stop and restart it, to make it expire in `duration_` + // from now. + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Restart(duration_, MakeTimeoutId(id_, generation_)); + } +} + +void Timer::Stop() { + if (is_running()) { + timeout_->Stop(); + expiration_count_ = 0; + is_running_ = false; + } +} + +void Timer::Trigger(TimerGeneration generation) { + if (is_running_ && generation == generation_) { + ++expiration_count_; + is_running_ = false; + if (!options_.max_restarts.has_value() || + expiration_count_ <= *options_.max_restarts) { + // The timer should still be running after this triggers. Start a new + // timer. Note that it might be very quickly restarted again, if the + // `on_expired_` callback returns a new duration. + is_running_ = true; + DurationMs duration = + GetBackoffDuration(options_, duration_, expiration_count_); + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Start(duration, MakeTimeoutId(id_, generation_)); + } + + absl::optional<DurationMs> new_duration = on_expired_(); + if (new_duration.has_value() && new_duration != duration_) { + duration_ = new_duration.value(); + if (is_running_) { + // Restart it with new duration. + timeout_->Stop(); + + DurationMs duration = + GetBackoffDuration(options_, duration_, expiration_count_); + generation_ = TimerGeneration(*generation_ + 1); + timeout_->Start(duration, MakeTimeoutId(id_, generation_)); + } + } + } +} + +void TimerManager::HandleTimeout(TimeoutID timeout_id) { + TimerID timer_id(*timeout_id >> 32); + TimerGeneration generation(*timeout_id); + auto it = timers_.find(timer_id); + if (it != timers_.end()) { + it->second->Trigger(generation); + } +} + +std::unique_ptr<Timer> TimerManager::CreateTimer(absl::string_view name, + Timer::OnExpired on_expired, + const TimerOptions& options) { + next_id_ = TimerID(*next_id_ + 1); + TimerID id = next_id_; + // This would overflow after 4 billion timers created, which in SCTP would be + // after 800 million reconnections on a single socket. Ensure this will never + // happen. + RTC_CHECK_NE(*id, std::numeric_limits<uint32_t>::max()); + std::unique_ptr<Timeout> timeout = create_timeout_(options.precision); + RTC_CHECK(timeout != nullptr); + auto timer = absl::WrapUnique(new Timer( + id, name, std::move(on_expired), [this, id]() { timers_.erase(id); }, + std::move(timeout), options)); + timers_[id] = timer.get(); + return timer; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/timer/timer.h b/third_party/libwebrtc/net/dcsctp/timer/timer.h new file mode 100644 index 0000000000..31b496dc81 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/timer/timer.h @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TIMER_TIMER_H_ +#define NET_DCSCTP_TIMER_TIMER_H_ + +#include <stdint.h> + +#include <algorithm> +#include <functional> +#include <map> +#include <memory> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/timeout.h" +#include "rtc_base/strong_alias.h" + +namespace dcsctp { + +using TimerID = webrtc::StrongAlias<class TimerIDTag, uint32_t>; +using TimerGeneration = webrtc::StrongAlias<class TimerGenerationTag, uint32_t>; + +enum class TimerBackoffAlgorithm { + // The base duration will be used for any restart. + kFixed, + // An exponential backoff is used for restarts, with a 2x multiplier, meaning + // that every restart will use a duration that is twice as long as the + // previous. + kExponential, +}; + +struct TimerOptions { + explicit TimerOptions(DurationMs duration) + : TimerOptions(duration, TimerBackoffAlgorithm::kExponential) {} + TimerOptions(DurationMs duration, TimerBackoffAlgorithm backoff_algorithm) + : TimerOptions(duration, backoff_algorithm, absl::nullopt) {} + TimerOptions(DurationMs duration, + TimerBackoffAlgorithm backoff_algorithm, + absl::optional<int> max_restarts) + : TimerOptions(duration, backoff_algorithm, max_restarts, absl::nullopt) { + } + TimerOptions(DurationMs duration, + TimerBackoffAlgorithm backoff_algorithm, + absl::optional<int> max_restarts, + absl::optional<DurationMs> max_backoff_duration) + : TimerOptions(duration, + backoff_algorithm, + max_restarts, + max_backoff_duration, + webrtc::TaskQueueBase::DelayPrecision::kLow) {} + TimerOptions(DurationMs duration, + TimerBackoffAlgorithm backoff_algorithm, + absl::optional<int> max_restarts, + absl::optional<DurationMs> max_backoff_duration, + webrtc::TaskQueueBase::DelayPrecision precision) + : duration(duration), + backoff_algorithm(backoff_algorithm), + max_restarts(max_restarts), + max_backoff_duration(max_backoff_duration), + precision(precision) {} + + // The initial timer duration. Can be overridden with `set_duration`. + const DurationMs duration; + // If the duration should be increased (using exponential backoff) when it is + // restarted. If not set, the same duration will be used. + const TimerBackoffAlgorithm backoff_algorithm; + // The maximum number of times that the timer will be automatically restarted, + // or absl::nullopt if there is no limit. + const absl::optional<int> max_restarts; + // The maximum timeout value for exponential backoff. + const absl::optional<DurationMs> max_backoff_duration; + // The precision of the webrtc::TaskQueueBase used for scheduling. + const webrtc::TaskQueueBase::DelayPrecision precision; +}; + +// A high-level timer (in contrast to the low-level `Timeout` class). +// +// Timers are started and can be stopped or restarted. When a timer expires, +// the provided `on_expired` callback will be triggered. A timer is +// automatically restarted, as long as the number of restarts is below the +// configurable `max_restarts` parameter. The `is_running` property can be +// queried to know if it's still running after having expired. +// +// When a timer is restarted, it will use a configurable `backoff_algorithm` to +// possibly adjust the duration of the next expiry. It is also possible to +// return a new base duration (which is the duration before it's adjusted by the +// backoff algorithm). +class Timer { + public: + // The maximum timer duration - one day. + static constexpr DurationMs kMaxTimerDuration = DurationMs(24 * 3600 * 1000); + + // When expired, the timer handler can optionally return a new duration which + // will be set as `duration` and used as base duration when the timer is + // restarted and as input to the backoff algorithm. + using OnExpired = std::function<absl::optional<DurationMs>()>; + + // TimerManager will have pointers to these instances, so they must not move. + Timer(const Timer&) = delete; + Timer& operator=(const Timer&) = delete; + + ~Timer(); + + // Starts the timer if it's stopped or restarts the timer if it's already + // running. The `expiration_count` will be reset. + void Start(); + + // Stops the timer. This can also be called when the timer is already stopped. + // The `expiration_count` will be reset. + void Stop(); + + // Sets the base duration. The actual timer duration may be larger depending + // on the backoff algorithm. + void set_duration(DurationMs duration) { + duration_ = std::min(duration, kMaxTimerDuration); + } + + // Retrieves the base duration. The actual timer duration may be larger + // depending on the backoff algorithm. + DurationMs duration() const { return duration_; } + + // Returns the number of times the timer has expired. + int expiration_count() const { return expiration_count_; } + + // Returns the timer's options. + const TimerOptions& options() const { return options_; } + + // Returns the name of the timer. + absl::string_view name() const { return name_; } + + // Indicates if this timer is currently running. + bool is_running() const { return is_running_; } + + private: + friend class TimerManager; + using UnregisterHandler = std::function<void()>; + Timer(TimerID id, + absl::string_view name, + OnExpired on_expired, + UnregisterHandler unregister, + std::unique_ptr<Timeout> timeout, + const TimerOptions& options); + + // Called by TimerManager. Will trigger the callback and increment + // `expiration_count`. The timer will automatically be restarted at the + // duration as decided by the backoff algorithm, unless the + // `TimerOptions::max_restarts` has been reached and then it will be stopped + // and `is_running()` will return false. + void Trigger(TimerGeneration generation); + + const TimerID id_; + const std::string name_; + const TimerOptions options_; + const OnExpired on_expired_; + const UnregisterHandler unregister_handler_; + const std::unique_ptr<Timeout> timeout_; + + DurationMs duration_; + + // Increased on each start, and is matched on Trigger, to avoid races. And by + // race, meaning that a timeout - which may be evaluated/expired on a + // different thread while this thread has stopped that timer already. Note + // that the entire socket is not thread-safe, so `TimerManager::HandleTimeout` + // is never executed concurrently with any timer starting/stopping. + // + // This will wrap around after 4 billion timer restarts, and if it wraps + // around, it would just trigger _this_ timer in advance (but it's hard to + // restart it 4 billion times within its duration). + TimerGeneration generation_ = TimerGeneration(0); + bool is_running_ = false; + // Incremented each time time has expired and reset when stopped or restarted. + int expiration_count_ = 0; +}; + +// Creates and manages timers. +class TimerManager { + public: + explicit TimerManager( + std::function<std::unique_ptr<Timeout>( + webrtc::TaskQueueBase::DelayPrecision)> create_timeout) + : create_timeout_(std::move(create_timeout)) {} + + // Creates a timer with name `name` that will expire (when started) after + // `options.duration` and call `on_expired`. There are more `options` that + // affects the behavior. Note that timers are created initially stopped. + std::unique_ptr<Timer> CreateTimer(absl::string_view name, + Timer::OnExpired on_expired, + const TimerOptions& options); + + void HandleTimeout(TimeoutID timeout_id); + + private: + const std::function<std::unique_ptr<Timeout>( + webrtc::TaskQueueBase::DelayPrecision)> + create_timeout_; + std::map<TimerID, Timer*> timers_; + TimerID next_id_ = TimerID(0); +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_TIMER_TIMER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/timer/timer_test.cc b/third_party/libwebrtc/net/dcsctp/timer/timer_test.cc new file mode 100644 index 0000000000..4aebe65b48 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/timer/timer_test.cc @@ -0,0 +1,459 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/timer/timer.h" + +#include <memory> + +#include "absl/types/optional.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/public/timeout.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::Return; + +class TimerTest : public testing::Test { + protected: + TimerTest() + : timeout_manager_([this]() { return now_; }), + manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return timeout_manager_.CreateTimeout(precision); + }) { + ON_CALL(on_expired_, Call).WillByDefault(Return(absl::nullopt)); + } + + void AdvanceTimeAndRunTimers(DurationMs duration) { + now_ = now_ + duration; + + for (;;) { + absl::optional<TimeoutID> timeout_id = + timeout_manager_.GetNextExpiredTimeout(); + if (!timeout_id.has_value()) { + break; + } + manager_.HandleTimeout(*timeout_id); + } + } + + TimeMs now_ = TimeMs(0); + FakeTimeoutManager timeout_manager_; + TimerManager manager_; + testing::MockFunction<absl::optional<DurationMs>()> on_expired_; +}; + +TEST_F(TimerTest, TimerIsInitiallyStopped) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed)); + + EXPECT_FALSE(t1->is_running()); +} + +TEST_F(TimerTest, TimerExpiresAtGivenTime) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + EXPECT_TRUE(t1->is_running()); + + AdvanceTimeAndRunTimers(DurationMs(4000)); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); +} + +TEST_F(TimerTest, TimerReschedulesAfterExpiredWithFixedBackoff) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + EXPECT_EQ(t1->expiration_count(), 0); + + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + EXPECT_EQ(t1->expiration_count(), 1); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Second time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + EXPECT_EQ(t1->expiration_count(), 2); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Third time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + EXPECT_EQ(t1->expiration_count(), 3); +} + +TEST_F(TimerTest, TimerWithNoRestarts) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed, + /*max_restart=*/0)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + EXPECT_FALSE(t1->is_running()); + + // Second time - shouldn't fire + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(5000)); + EXPECT_FALSE(t1->is_running()); +} + +TEST_F(TimerTest, TimerWithOneRestart) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed, + /*max_restart=*/1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Second time - max restart limit reached. + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_FALSE(t1->is_running()); + + // Third time - should not fire. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(5000)); + EXPECT_FALSE(t1->is_running()); +} + +TEST_F(TimerTest, TimerWithTwoRestart) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed, + /*max_restart=*/2)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Second time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_TRUE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Third time + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_FALSE(t1->is_running()); +} + +TEST_F(TimerTest, TimerWithExponentialBackoff) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + + // Fire first time at 5 seconds + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(5000)); + + // Second time at 5*2^1 = 10 seconds later. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(9000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + // Third time at 5*2^2 = 20 seconds later. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(19000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + // Fourth time at 5*2^3 = 40 seconds later. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(39000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); +} + +TEST_F(TimerTest, StartTimerWillStopAndStart) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + + AdvanceTimeAndRunTimers(DurationMs(3000)); + + t1->Start(); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(2000)); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(3000)); +} + +TEST_F(TimerTest, ExpirationCounterWillResetIfStopped) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + + // Fire first time at 5 seconds + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(5000)); + EXPECT_EQ(t1->expiration_count(), 1); + + // Second time at 5*2^1 = 10 seconds later. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(9000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_EQ(t1->expiration_count(), 2); + + t1->Start(); + EXPECT_EQ(t1->expiration_count(), 0); + + // Third time at 5*2^0 = 5 seconds later. + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_EQ(t1->expiration_count(), 1); +} + +TEST_F(TimerTest, StopTimerWillMakeItNotExpire) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + EXPECT_TRUE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4000)); + t1->Stop(); + EXPECT_FALSE(t1->is_running()); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(1000)); +} + +TEST_F(TimerTest, ReturningNewDurationWhenExpired) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(5000), TimerBackoffAlgorithm::kFixed)); + + EXPECT_CALL(on_expired_, Call).Times(0); + t1->Start(); + EXPECT_EQ(t1->duration(), DurationMs(5000)); + + AdvanceTimeAndRunTimers(DurationMs(4000)); + + // Fire first time + EXPECT_CALL(on_expired_, Call).WillOnce(Return(DurationMs(2000))); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_EQ(t1->duration(), DurationMs(2000)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + // Second time + EXPECT_CALL(on_expired_, Call).WillOnce(Return(DurationMs(10000))); + AdvanceTimeAndRunTimers(DurationMs(1000)); + EXPECT_EQ(t1->duration(), DurationMs(10000)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(9000)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); +} + +TEST_F(TimerTest, TimersHaveMaximumDuration) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential)); + + t1->set_duration(DurationMs(2 * *Timer::kMaxTimerDuration)); + EXPECT_EQ(t1->duration(), Timer::kMaxTimerDuration); +} + +TEST_F(TimerTest, TimersHaveMaximumBackoffDuration) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential)); + + t1->Start(); + + int max_exponent = static_cast<int>(log2(*Timer::kMaxTimerDuration / 1000)); + for (int i = 0; i < max_exponent; ++i) { + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000 * (1 << i))); + } + + // Reached the maximum duration. + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(Timer::kMaxTimerDuration); +} + +TEST_F(TimerTest, TimerCanBeStartedFromWithinExpirationHandler) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kFixed)); + + t1->Start(); + + // Start a timer, but don't return any new duration in callback. + EXPECT_CALL(on_expired_, Call).WillOnce([&]() { + EXPECT_TRUE(t1->is_running()); + t1->set_duration(DurationMs(5000)); + t1->Start(); + return absl::nullopt; + }); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4999)); + + // Start a timer, and return any new duration in callback. + EXPECT_CALL(on_expired_, Call).WillOnce([&]() { + EXPECT_TRUE(t1->is_running()); + t1->set_duration(DurationMs(5000)); + t1->Start(); + return absl::make_optional(DurationMs(8000)); + }); + AdvanceTimeAndRunTimers(DurationMs(1)); + + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(7999)); + + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1)); +} + +TEST_F(TimerTest, DurationStaysWithinMaxTimerBackOffDuration) { + std::unique_ptr<Timer> t1 = manager_.CreateTimer( + "t1", on_expired_.AsStdFunction(), + TimerOptions(DurationMs(1000), TimerBackoffAlgorithm::kExponential, + /*max_restarts=*/absl::nullopt, DurationMs(5000))); + + t1->Start(); + + // Initial timeout, 1000 ms + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1000)); + + // Exponential backoff -> 2000 ms + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(1999)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1)); + + // Exponential backoff -> 4000 ms + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(3999)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1)); + + // Limited backoff -> 5000ms + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4999)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1)); + + // ... where it plateaus + EXPECT_CALL(on_expired_, Call).Times(0); + AdvanceTimeAndRunTimers(DurationMs(4999)); + EXPECT_CALL(on_expired_, Call).Times(1); + AdvanceTimeAndRunTimers(DurationMs(1)); +} + +TEST(TimerManagerTest, TimerManagerPassesPrecisionToCreateTimeoutMethod) { + FakeTimeoutManager timeout_manager([&]() { return TimeMs(0); }); + absl::optional<webrtc::TaskQueueBase::DelayPrecision> create_timer_precison; + TimerManager manager([&](webrtc::TaskQueueBase::DelayPrecision precision) { + create_timer_precison = precision; + return timeout_manager.CreateTimeout(precision); + }); + // Default TimerOptions. + manager.CreateTimer( + "test_timer", []() { return absl::optional<DurationMs>(); }, + TimerOptions(DurationMs(123))); + EXPECT_EQ(create_timer_precison, webrtc::TaskQueueBase::DelayPrecision::kLow); + // High precision TimerOptions. + manager.CreateTimer( + "test_timer", []() { return absl::optional<DurationMs>(); }, + TimerOptions(DurationMs(123), TimerBackoffAlgorithm::kExponential, + absl::nullopt, absl::nullopt, + webrtc::TaskQueueBase::DelayPrecision::kHigh)); + EXPECT_EQ(create_timer_precison, + webrtc::TaskQueueBase::DelayPrecision::kHigh); + // Low precision TimerOptions. + manager.CreateTimer( + "test_timer", []() { return absl::optional<DurationMs>(); }, + TimerOptions(DurationMs(123), TimerBackoffAlgorithm::kExponential, + absl::nullopt, absl::nullopt, + webrtc::TaskQueueBase::DelayPrecision::kLow)); + EXPECT_EQ(create_timer_precison, webrtc::TaskQueueBase::DelayPrecision::kLow); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/BUILD.gn b/third_party/libwebrtc/net/dcsctp/tx/BUILD.gn new file mode 100644 index 0000000000..3cb7df4cc2 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/BUILD.gn @@ -0,0 +1,208 @@ +# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +import("../../../webrtc.gni") + +rtc_source_set("send_queue") { + deps = [ + "../../../api:array_view", + "../common:internal_types", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../public:types", + ] + sources = [ "send_queue.h" ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] +} + +rtc_library("rr_send_queue") { + deps = [ + ":send_queue", + ":stream_scheduler", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base/containers:flat_map", + "../common:str_join", + "../packet:data", + "../public:socket", + "../public:types", + ] + sources = [ + "rr_send_queue.cc", + "rr_send_queue.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("stream_scheduler") { + deps = [ + ":send_queue", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:strong_alias", + "../../../rtc_base/containers:flat_set", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + ] + sources = [ + "stream_scheduler.cc", + "stream_scheduler.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("retransmission_error_counter") { + deps = [ + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../public:types", + ] + sources = [ + "retransmission_error_counter.cc", + "retransmission_error_counter.h", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/strings" ] +} + +rtc_library("retransmission_timeout") { + deps = [ + "../../../rtc_base:checks", + "../public:types", + ] + sources = [ + "retransmission_timeout.cc", + "retransmission_timeout.h", + ] +} + +rtc_library("outstanding_data") { + deps = [ + ":retransmission_timeout", + ":send_queue", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../common:math", + "../common:sequence_numbers", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../public:types", + "../timer", + ] + sources = [ + "outstanding_data.cc", + "outstanding_data.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +rtc_library("retransmission_queue") { + deps = [ + ":outstanding_data", + ":retransmission_timeout", + ":send_queue", + "../../../api:array_view", + "../../../rtc_base:checks", + "../../../rtc_base:logging", + "../../../rtc_base:stringutils", + "../common:math", + "../common:sequence_numbers", + "../common:str_join", + "../packet:chunk", + "../packet:data", + "../public:socket", + "../public:types", + "../timer", + ] + sources = [ + "retransmission_queue.cc", + "retransmission_queue.h", + ] + absl_deps = [ + "//third_party/abseil-cpp/absl/algorithm:container", + "//third_party/abseil-cpp/absl/strings", + "//third_party/abseil-cpp/absl/types:optional", + ] +} + +if (rtc_include_tests) { + rtc_source_set("mock_send_queue") { + testonly = true + deps = [ + ":send_queue", + "../../../api:array_view", + "../../../test:test_support", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ "mock_send_queue.h" ] + } + + rtc_library("dcsctp_tx_unittests") { + testonly = true + + deps = [ + ":mock_send_queue", + ":outstanding_data", + ":retransmission_error_counter", + ":retransmission_queue", + ":retransmission_timeout", + ":rr_send_queue", + ":send_queue", + ":stream_scheduler", + "../../../api:array_view", + "../../../api/task_queue:task_queue", + "../../../rtc_base:checks", + "../../../rtc_base:gunit_helpers", + "../../../test:test_support", + "../common:handover_testing", + "../common:math", + "../common:sequence_numbers", + "../packet:chunk", + "../packet:data", + "../packet:sctp_packet", + "../public:socket", + "../public:types", + "../socket:mock_callbacks", + "../socket:mock_callbacks", + "../testing:data_generator", + "../testing:testing_macros", + "../timer", + ] + absl_deps = [ "//third_party/abseil-cpp/absl/types:optional" ] + sources = [ + "outstanding_data_test.cc", + "retransmission_error_counter_test.cc", + "retransmission_queue_test.cc", + "retransmission_timeout_test.cc", + "rr_send_queue_test.cc", + "stream_scheduler_test.cc", + ] + } +} diff --git a/third_party/libwebrtc/net/dcsctp/tx/mock_send_queue.h b/third_party/libwebrtc/net/dcsctp/tx/mock_send_queue.h new file mode 100644 index 0000000000..0c8f5d141d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/mock_send_queue.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_MOCK_SEND_QUEUE_H_ +#define NET_DCSCTP_TX_MOCK_SEND_QUEUE_H_ + +#include <cstdint> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/tx/send_queue.h" +#include "test/gmock.h" + +namespace dcsctp { + +class MockSendQueue : public SendQueue { + public: + MockSendQueue() { + ON_CALL(*this, Produce).WillByDefault([](TimeMs now, size_t max_size) { + return absl::nullopt; + }); + } + + MOCK_METHOD(absl::optional<SendQueue::DataToSend>, + Produce, + (TimeMs now, size_t max_size), + (override)); + MOCK_METHOD(bool, + Discard, + (IsUnordered unordered, StreamID stream_id, MID message_id), + (override)); + MOCK_METHOD(void, PrepareResetStream, (StreamID stream_id), (override)); + MOCK_METHOD(bool, HasStreamsReadyToBeReset, (), (const, override)); + MOCK_METHOD(std::vector<StreamID>, GetStreamsReadyToBeReset, (), (override)); + MOCK_METHOD(void, CommitResetStreams, (), (override)); + MOCK_METHOD(void, RollbackResetStreams, (), (override)); + MOCK_METHOD(void, Reset, (), (override)); + MOCK_METHOD(size_t, buffered_amount, (StreamID stream_id), (const, override)); + MOCK_METHOD(size_t, total_buffered_amount, (), (const, override)); + MOCK_METHOD(size_t, + buffered_amount_low_threshold, + (StreamID stream_id), + (const, override)); + MOCK_METHOD(void, + SetBufferedAmountLowThreshold, + (StreamID stream_id, size_t bytes), + (override)); + MOCK_METHOD(void, EnableMessageInterleaving, (bool enabled), (override)); +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_MOCK_SEND_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.cc b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.cc new file mode 100644 index 0000000000..4f1e863056 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.cc @@ -0,0 +1,543 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/outstanding_data.h" + +#include <algorithm> +#include <set> +#include <utility> +#include <vector> + +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/public/types.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +// The number of times a packet must be NACKed before it's retransmitted. +// See https://tools.ietf.org/html/rfc4960#section-7.2.4 +constexpr uint8_t kNumberOfNacksForRetransmission = 3; + +// Returns how large a chunk will be, serialized, carrying the data +size_t OutstandingData::GetSerializedChunkSize(const Data& data) const { + return RoundUpTo4(data_chunk_header_size_ + data.size()); +} + +void OutstandingData::Item::Ack() { + if (lifecycle_ != Lifecycle::kAbandoned) { + lifecycle_ = Lifecycle::kActive; + } + ack_state_ = AckState::kAcked; +} + +OutstandingData::Item::NackAction OutstandingData::Item::Nack( + bool retransmit_now) { + ack_state_ = AckState::kNacked; + ++nack_count_; + if (!should_be_retransmitted() && !is_abandoned() && + (retransmit_now || nack_count_ >= kNumberOfNacksForRetransmission)) { + // Nacked enough times - it's considered lost. + if (num_retransmissions_ < *max_retransmissions_) { + lifecycle_ = Lifecycle::kToBeRetransmitted; + return NackAction::kRetransmit; + } + Abandon(); + return NackAction::kAbandon; + } + return NackAction::kNothing; +} + +void OutstandingData::Item::MarkAsRetransmitted() { + lifecycle_ = Lifecycle::kActive; + ack_state_ = AckState::kUnacked; + + nack_count_ = 0; + ++num_retransmissions_; +} + +void OutstandingData::Item::Abandon() { + lifecycle_ = Lifecycle::kAbandoned; +} + +bool OutstandingData::Item::has_expired(TimeMs now) const { + return expires_at_ <= now; +} + +bool OutstandingData::IsConsistent() const { + size_t actual_outstanding_bytes = 0; + size_t actual_outstanding_items = 0; + + std::set<UnwrappedTSN> combined_to_be_retransmitted; + combined_to_be_retransmitted.insert(to_be_retransmitted_.begin(), + to_be_retransmitted_.end()); + combined_to_be_retransmitted.insert(to_be_fast_retransmitted_.begin(), + to_be_fast_retransmitted_.end()); + + std::set<UnwrappedTSN> actual_combined_to_be_retransmitted; + for (const auto& [tsn, item] : outstanding_data_) { + if (item.is_outstanding()) { + actual_outstanding_bytes += GetSerializedChunkSize(item.data()); + ++actual_outstanding_items; + } + + if (item.should_be_retransmitted()) { + actual_combined_to_be_retransmitted.insert(tsn); + } + } + + if (outstanding_data_.empty() && + next_tsn_ != last_cumulative_tsn_ack_.next_value()) { + return false; + } + + return actual_outstanding_bytes == outstanding_bytes_ && + actual_outstanding_items == outstanding_items_ && + actual_combined_to_be_retransmitted == combined_to_be_retransmitted; +} + +void OutstandingData::AckChunk(AckInfo& ack_info, + std::map<UnwrappedTSN, Item>::iterator iter) { + if (!iter->second.is_acked()) { + size_t serialized_size = GetSerializedChunkSize(iter->second.data()); + ack_info.bytes_acked += serialized_size; + if (iter->second.is_outstanding()) { + outstanding_bytes_ -= serialized_size; + --outstanding_items_; + } + if (iter->second.should_be_retransmitted()) { + RTC_DCHECK(to_be_fast_retransmitted_.find(iter->first) == + to_be_fast_retransmitted_.end()); + to_be_retransmitted_.erase(iter->first); + } + iter->second.Ack(); + ack_info.highest_tsn_acked = + std::max(ack_info.highest_tsn_acked, iter->first); + } +} + +OutstandingData::AckInfo OutstandingData::HandleSack( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + bool is_in_fast_recovery) { + OutstandingData::AckInfo ack_info(cumulative_tsn_ack); + // Erase all items up to cumulative_tsn_ack. + RemoveAcked(cumulative_tsn_ack, ack_info); + + // ACK packets reported in the gap ack blocks + AckGapBlocks(cumulative_tsn_ack, gap_ack_blocks, ack_info); + + // NACK and possibly mark for retransmit chunks that weren't acked. + NackBetweenAckBlocks(cumulative_tsn_ack, gap_ack_blocks, is_in_fast_recovery, + ack_info); + + RTC_DCHECK(IsConsistent()); + return ack_info; +} + +void OutstandingData::RemoveAcked(UnwrappedTSN cumulative_tsn_ack, + AckInfo& ack_info) { + auto first_unacked = outstanding_data_.upper_bound(cumulative_tsn_ack); + + for (auto iter = outstanding_data_.begin(); iter != first_unacked; ++iter) { + AckChunk(ack_info, iter); + if (iter->second.lifecycle_id().IsSet()) { + RTC_DCHECK(iter->second.data().is_end); + if (iter->second.is_abandoned()) { + ack_info.abandoned_lifecycle_ids.push_back(iter->second.lifecycle_id()); + } else { + ack_info.acked_lifecycle_ids.push_back(iter->second.lifecycle_id()); + } + } + } + + outstanding_data_.erase(outstanding_data_.begin(), first_unacked); + last_cumulative_tsn_ack_ = cumulative_tsn_ack; +} + +void OutstandingData::AckGapBlocks( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + AckInfo& ack_info) { + // Mark all non-gaps as ACKED (but they can't be removed) as (from RFC) + // "SCTP considers the information carried in the Gap Ack Blocks in the + // SACK chunk as advisory.". Note that when NR-SACK is supported, this can be + // handled differently. + + for (auto& block : gap_ack_blocks) { + auto start = outstanding_data_.lower_bound( + UnwrappedTSN::AddTo(cumulative_tsn_ack, block.start)); + auto end = outstanding_data_.upper_bound( + UnwrappedTSN::AddTo(cumulative_tsn_ack, block.end)); + for (auto iter = start; iter != end; ++iter) { + AckChunk(ack_info, iter); + } + } +} + +void OutstandingData::NackBetweenAckBlocks( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + bool is_in_fast_recovery, + OutstandingData::AckInfo& ack_info) { + // Mark everything between the blocks as NACKED/TO_BE_RETRANSMITTED. + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "Mark the DATA chunk(s) with three miss indications for retransmission." + // "For each incoming SACK, miss indications are incremented only for + // missing TSNs prior to the highest TSN newly acknowledged in the SACK." + // + // What this means is that only when there is a increasing stream of data + // received and there are new packets seen (since last time), packets that are + // in-flight and between gaps should be nacked. This means that SCTP relies on + // the T3-RTX-timer to re-send packets otherwise. + UnwrappedTSN max_tsn_to_nack = ack_info.highest_tsn_acked; + if (is_in_fast_recovery && cumulative_tsn_ack > last_cumulative_tsn_ack_) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "If an endpoint is in Fast Recovery and a SACK arrives that advances + // the Cumulative TSN Ack Point, the miss indications are incremented for + // all TSNs reported missing in the SACK." + max_tsn_to_nack = UnwrappedTSN::AddTo( + cumulative_tsn_ack, + gap_ack_blocks.empty() ? 0 : gap_ack_blocks.rbegin()->end); + } + + UnwrappedTSN prev_block_last_acked = cumulative_tsn_ack; + for (auto& block : gap_ack_blocks) { + UnwrappedTSN cur_block_first_acked = + UnwrappedTSN::AddTo(cumulative_tsn_ack, block.start); + for (auto iter = outstanding_data_.upper_bound(prev_block_last_acked); + iter != outstanding_data_.lower_bound(cur_block_first_acked); ++iter) { + if (iter->first <= max_tsn_to_nack) { + ack_info.has_packet_loss |= + NackItem(iter->first, iter->second, /*retransmit_now=*/false, + /*do_fast_retransmit=*/!is_in_fast_recovery); + } + } + prev_block_last_acked = UnwrappedTSN::AddTo(cumulative_tsn_ack, block.end); + } + + // Note that packets are not NACKED which are above the highest gap-ack-block + // (or above the cumulative ack TSN if no gap-ack-blocks) as only packets + // up until the highest_tsn_acked (see above) should be considered when + // NACKing. +} + +bool OutstandingData::NackItem(UnwrappedTSN tsn, + Item& item, + bool retransmit_now, + bool do_fast_retransmit) { + if (item.is_outstanding()) { + outstanding_bytes_ -= GetSerializedChunkSize(item.data()); + --outstanding_items_; + } + + switch (item.Nack(retransmit_now)) { + case Item::NackAction::kNothing: + return false; + case Item::NackAction::kRetransmit: + if (do_fast_retransmit) { + to_be_fast_retransmitted_.insert(tsn); + } else { + to_be_retransmitted_.insert(tsn); + } + RTC_DLOG(LS_VERBOSE) << *tsn.Wrap() << " marked for retransmission"; + break; + case Item::NackAction::kAbandon: + AbandonAllFor(item); + break; + } + return true; +} + +void OutstandingData::AbandonAllFor(const Item& item) { + // Erase all remaining chunks from the producer, if any. + if (discard_from_send_queue_(item.data().is_unordered, item.data().stream_id, + item.data().message_id)) { + // There were remaining chunks to be produced for this message. Since the + // receiver may have already received all chunks (up till now) for this + // message, we can't just FORWARD-TSN to the last fragment in this + // (abandoned) message and start sending a new message, as the receiver will + // then see a new message before the end of the previous one was seen (or + // skipped over). So create a new fragment, representing the end, that the + // received will never see as it is abandoned immediately and used as cum + // TSN in the sent FORWARD-TSN. + UnwrappedTSN tsn = next_tsn_; + next_tsn_.Increment(); + Data message_end(item.data().stream_id, item.data().ssn, + item.data().message_id, item.data().fsn, item.data().ppid, + std::vector<uint8_t>(), Data::IsBeginning(false), + Data::IsEnd(true), item.data().is_unordered); + Item& added_item = + outstanding_data_ + .emplace(std::piecewise_construct, std::forward_as_tuple(tsn), + std::forward_as_tuple(std::move(message_end), TimeMs(0), + MaxRetransmits::NoLimit(), + TimeMs::InfiniteFuture(), + LifecycleId::NotSet())) + .first->second; + // The added chunk shouldn't be included in `outstanding_bytes`, so set it + // as acked. + added_item.Ack(); + RTC_DLOG(LS_VERBOSE) << "Adding unsent end placeholder for message at tsn=" + << *tsn.Wrap(); + } + + for (auto& [tsn, other] : outstanding_data_) { + if (!other.is_abandoned() && + other.data().stream_id == item.data().stream_id && + other.data().is_unordered == item.data().is_unordered && + other.data().message_id == item.data().message_id) { + RTC_DLOG(LS_VERBOSE) << "Marking chunk " << *tsn.Wrap() + << " as abandoned"; + if (other.should_be_retransmitted()) { + to_be_fast_retransmitted_.erase(tsn); + to_be_retransmitted_.erase(tsn); + } + other.Abandon(); + } + } +} + +std::vector<std::pair<TSN, Data>> OutstandingData::ExtractChunksThatCanFit( + std::set<UnwrappedTSN>& chunks, + size_t max_size) { + std::vector<std::pair<TSN, Data>> result; + + for (auto it = chunks.begin(); it != chunks.end();) { + UnwrappedTSN tsn = *it; + auto elem = outstanding_data_.find(tsn); + RTC_DCHECK(elem != outstanding_data_.end()); + Item& item = elem->second; + RTC_DCHECK(item.should_be_retransmitted()); + RTC_DCHECK(!item.is_outstanding()); + RTC_DCHECK(!item.is_abandoned()); + RTC_DCHECK(!item.is_acked()); + + size_t serialized_size = GetSerializedChunkSize(item.data()); + if (serialized_size <= max_size) { + item.MarkAsRetransmitted(); + result.emplace_back(tsn.Wrap(), item.data().Clone()); + max_size -= serialized_size; + outstanding_bytes_ += serialized_size; + ++outstanding_items_; + it = chunks.erase(it); + } else { + ++it; + } + // No point in continuing if the packet is full. + if (max_size <= data_chunk_header_size_) { + break; + } + } + return result; +} + +std::vector<std::pair<TSN, Data>> +OutstandingData::GetChunksToBeFastRetransmitted(size_t max_size) { + std::vector<std::pair<TSN, Data>> result = + ExtractChunksThatCanFit(to_be_fast_retransmitted_, max_size); + + // https://datatracker.ietf.org/doc/html/rfc4960#section-7.2.4 + // "Those TSNs marked for retransmission due to the Fast-Retransmit algorithm + // that did not fit in the sent datagram carrying K other TSNs are also marked + // as ineligible for a subsequent Fast Retransmit. However, as they are + // marked for retransmission they will be retransmitted later on as soon as + // cwnd allows." + if (!to_be_fast_retransmitted_.empty()) { + to_be_retransmitted_.insert(to_be_fast_retransmitted_.begin(), + to_be_fast_retransmitted_.end()); + to_be_fast_retransmitted_.clear(); + } + + RTC_DCHECK(IsConsistent()); + return result; +} + +std::vector<std::pair<TSN, Data>> OutstandingData::GetChunksToBeRetransmitted( + size_t max_size) { + // Chunks scheduled for fast retransmission must be sent first. + RTC_DCHECK(to_be_fast_retransmitted_.empty()); + return ExtractChunksThatCanFit(to_be_retransmitted_, max_size); +} + +void OutstandingData::ExpireOutstandingChunks(TimeMs now) { + for (const auto& [tsn, item] : outstanding_data_) { + // Chunks that are nacked can be expired. Care should be taken not to expire + // unacked (in-flight) chunks as they might have been received, but the SACK + // is either delayed or in-flight and may be received later. + if (item.is_abandoned()) { + // Already abandoned. + } else if (item.is_nacked() && item.has_expired(now)) { + RTC_DLOG(LS_VERBOSE) << "Marking nacked chunk " << *tsn.Wrap() + << " and message " << *item.data().message_id + << " as expired"; + AbandonAllFor(item); + } else { + // A non-expired chunk. No need to iterate any further. + break; + } + } + RTC_DCHECK(IsConsistent()); +} + +UnwrappedTSN OutstandingData::highest_outstanding_tsn() const { + return outstanding_data_.empty() ? last_cumulative_tsn_ack_ + : outstanding_data_.rbegin()->first; +} + +absl::optional<UnwrappedTSN> OutstandingData::Insert( + const Data& data, + TimeMs time_sent, + MaxRetransmits max_retransmissions, + TimeMs expires_at, + LifecycleId lifecycle_id) { + UnwrappedTSN tsn = next_tsn_; + next_tsn_.Increment(); + + // All chunks are always padded to be even divisible by 4. + size_t chunk_size = GetSerializedChunkSize(data); + outstanding_bytes_ += chunk_size; + ++outstanding_items_; + auto it = outstanding_data_ + .emplace(std::piecewise_construct, std::forward_as_tuple(tsn), + std::forward_as_tuple(data.Clone(), time_sent, + max_retransmissions, expires_at, + lifecycle_id)) + .first; + + if (it->second.has_expired(time_sent)) { + // No need to send it - it was expired when it was in the send + // queue. + RTC_DLOG(LS_VERBOSE) << "Marking freshly produced chunk " + << *it->first.Wrap() << " and message " + << *it->second.data().message_id << " as expired"; + AbandonAllFor(it->second); + RTC_DCHECK(IsConsistent()); + return absl::nullopt; + } + + RTC_DCHECK(IsConsistent()); + return tsn; +} + +void OutstandingData::NackAll() { + for (auto& [tsn, item] : outstanding_data_) { + if (!item.is_acked()) { + NackItem(tsn, item, /*retransmit_now=*/true, + /*do_fast_retransmit=*/false); + } + } + RTC_DCHECK(IsConsistent()); +} + +absl::optional<DurationMs> OutstandingData::MeasureRTT(TimeMs now, + UnwrappedTSN tsn) const { + auto it = outstanding_data_.find(tsn); + if (it != outstanding_data_.end() && !it->second.has_been_retransmitted()) { + // https://tools.ietf.org/html/rfc4960#section-6.3.1 + // "Karn's algorithm: RTT measurements MUST NOT be made using + // packets that were retransmitted (and thus for which it is ambiguous + // whether the reply was for the first instance of the chunk or for a + // later instance)" + return now - it->second.time_sent(); + } + return absl::nullopt; +} + +std::vector<std::pair<TSN, OutstandingData::State>> +OutstandingData::GetChunkStatesForTesting() const { + std::vector<std::pair<TSN, State>> states; + states.emplace_back(last_cumulative_tsn_ack_.Wrap(), State::kAcked); + for (const auto& [tsn, item] : outstanding_data_) { + State state; + if (item.is_abandoned()) { + state = State::kAbandoned; + } else if (item.should_be_retransmitted()) { + state = State::kToBeRetransmitted; + } else if (item.is_acked()) { + state = State::kAcked; + } else if (item.is_outstanding()) { + state = State::kInFlight; + } else { + state = State::kNacked; + } + + states.emplace_back(tsn.Wrap(), state); + } + return states; +} + +bool OutstandingData::ShouldSendForwardTsn() const { + if (!outstanding_data_.empty()) { + auto it = outstanding_data_.begin(); + return it->first == last_cumulative_tsn_ack_.next_value() && + it->second.is_abandoned(); + } + return false; +} + +ForwardTsnChunk OutstandingData::CreateForwardTsn() const { + std::map<StreamID, SSN> skipped_per_ordered_stream; + UnwrappedTSN new_cumulative_ack = last_cumulative_tsn_ack_; + + for (const auto& [tsn, item] : outstanding_data_) { + if ((tsn != new_cumulative_ack.next_value()) || !item.is_abandoned()) { + break; + } + new_cumulative_ack = tsn; + if (!item.data().is_unordered && + item.data().ssn > skipped_per_ordered_stream[item.data().stream_id]) { + skipped_per_ordered_stream[item.data().stream_id] = item.data().ssn; + } + } + + std::vector<ForwardTsnChunk::SkippedStream> skipped_streams; + skipped_streams.reserve(skipped_per_ordered_stream.size()); + for (const auto& [stream_id, ssn] : skipped_per_ordered_stream) { + skipped_streams.emplace_back(stream_id, ssn); + } + return ForwardTsnChunk(new_cumulative_ack.Wrap(), std::move(skipped_streams)); +} + +IForwardTsnChunk OutstandingData::CreateIForwardTsn() const { + std::map<std::pair<IsUnordered, StreamID>, MID> skipped_per_stream; + UnwrappedTSN new_cumulative_ack = last_cumulative_tsn_ack_; + + for (const auto& [tsn, item] : outstanding_data_) { + if ((tsn != new_cumulative_ack.next_value()) || !item.is_abandoned()) { + break; + } + new_cumulative_ack = tsn; + std::pair<IsUnordered, StreamID> stream_id = + std::make_pair(item.data().is_unordered, item.data().stream_id); + + if (item.data().message_id > skipped_per_stream[stream_id]) { + skipped_per_stream[stream_id] = item.data().message_id; + } + } + + std::vector<IForwardTsnChunk::SkippedStream> skipped_streams; + skipped_streams.reserve(skipped_per_stream.size()); + for (const auto& [stream, message_id] : skipped_per_stream) { + skipped_streams.emplace_back(stream.first, stream.second, message_id); + } + + return IForwardTsnChunk(new_cumulative_ack.Wrap(), + std::move(skipped_streams)); +} + +void OutstandingData::ResetSequenceNumbers(UnwrappedTSN next_tsn, + UnwrappedTSN last_cumulative_tsn) { + RTC_DCHECK(outstanding_data_.empty()); + RTC_DCHECK(next_tsn_ == last_cumulative_tsn_ack_.next_value()); + RTC_DCHECK(next_tsn == last_cumulative_tsn.next_value()); + next_tsn_ = next_tsn; + last_cumulative_tsn_ack_ = last_cumulative_tsn; +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.h b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.h new file mode 100644 index 0000000000..6b4b7121fb --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data.h @@ -0,0 +1,350 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_OUTSTANDING_DATA_H_ +#define NET_DCSCTP_TX_OUTSTANDING_DATA_H_ + +#include <map> +#include <set> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +// This class keeps track of outstanding data chunks (sent, not yet acked) and +// handles acking, nacking, rescheduling and abandoning. +class OutstandingData { + public: + // State for DATA chunks (message fragments) in the queue - used in tests. + enum class State { + // The chunk has been sent but not received yet (from the sender's point of + // view, as no SACK has been received yet that reference this chunk). + kInFlight, + // A SACK has been received which explicitly marked this chunk as missing - + // it's now NACKED and may be retransmitted if NACKED enough times. + kNacked, + // A chunk that will be retransmitted when possible. + kToBeRetransmitted, + // A SACK has been received which explicitly marked this chunk as received. + kAcked, + // A chunk whose message has expired or has been retransmitted too many + // times (RFC3758). It will not be retransmitted anymore. + kAbandoned, + }; + + // Contains variables scoped to a processing of an incoming SACK. + struct AckInfo { + explicit AckInfo(UnwrappedTSN cumulative_tsn_ack) + : highest_tsn_acked(cumulative_tsn_ack) {} + + // Bytes acked by increasing cumulative_tsn_ack and gap_ack_blocks. + size_t bytes_acked = 0; + + // Indicates if this SACK indicates that packet loss has occurred. Just + // because a packet is missing in the SACK doesn't necessarily mean that + // there is packet loss as that packet might be in-flight and received + // out-of-order. But when it has been reported missing consecutive times, it + // will eventually be considered "lost" and this will be set. + bool has_packet_loss = false; + + // Highest TSN Newly Acknowledged, an SCTP variable. + UnwrappedTSN highest_tsn_acked; + + // The set of lifecycle IDs that were acked using cumulative_tsn_ack. + std::vector<LifecycleId> acked_lifecycle_ids; + // The set of lifecycle IDs that were acked, but had been abandoned. + std::vector<LifecycleId> abandoned_lifecycle_ids; + }; + + OutstandingData( + size_t data_chunk_header_size, + UnwrappedTSN next_tsn, + UnwrappedTSN last_cumulative_tsn_ack, + std::function<bool(IsUnordered, StreamID, MID)> discard_from_send_queue) + : data_chunk_header_size_(data_chunk_header_size), + next_tsn_(next_tsn), + last_cumulative_tsn_ack_(last_cumulative_tsn_ack), + discard_from_send_queue_(std::move(discard_from_send_queue)) {} + + AckInfo HandleSack( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + bool is_in_fast_recovery); + + // Returns as many of the chunks that are eligible for fast retransmissions + // and that would fit in a single packet of `max_size`. The eligible chunks + // that didn't fit will be marked for (normal) retransmission and will not be + // returned if this method is called again. + std::vector<std::pair<TSN, Data>> GetChunksToBeFastRetransmitted( + size_t max_size); + + // Given `max_size` of space left in a packet, which chunks can be added to + // it? + std::vector<std::pair<TSN, Data>> GetChunksToBeRetransmitted(size_t max_size); + + size_t outstanding_bytes() const { return outstanding_bytes_; } + + // Returns the number of DATA chunks that are in-flight. + size_t outstanding_items() const { return outstanding_items_; } + + // Given the current time `now_ms`, expire and abandon outstanding (sent at + // least once) chunks that have a limited lifetime. + void ExpireOutstandingChunks(TimeMs now); + + bool empty() const { return outstanding_data_.empty(); } + + bool has_data_to_be_fast_retransmitted() const { + return !to_be_fast_retransmitted_.empty(); + } + + bool has_data_to_be_retransmitted() const { + return !to_be_retransmitted_.empty() || !to_be_fast_retransmitted_.empty(); + } + + UnwrappedTSN last_cumulative_tsn_ack() const { + return last_cumulative_tsn_ack_; + } + + UnwrappedTSN next_tsn() const { return next_tsn_; } + + UnwrappedTSN highest_outstanding_tsn() const; + + // Schedules `data` to be sent, with the provided partial reliability + // parameters. Returns the TSN if the item was actually added and scheduled to + // be sent, and absl::nullopt if it shouldn't be sent. + absl::optional<UnwrappedTSN> Insert( + const Data& data, + TimeMs time_sent, + MaxRetransmits max_retransmissions = MaxRetransmits::NoLimit(), + TimeMs expires_at = TimeMs::InfiniteFuture(), + LifecycleId lifecycle_id = LifecycleId::NotSet()); + + // Nacks all outstanding data. + void NackAll(); + + // Creates a FORWARD-TSN chunk. + ForwardTsnChunk CreateForwardTsn() const; + + // Creates an I-FORWARD-TSN chunk. + IForwardTsnChunk CreateIForwardTsn() const; + + // Given the current time and a TSN, it returns the measured RTT between when + // the chunk was sent and now. It takes into acccount Karn's algorithm, so if + // the chunk has ever been retransmitted, it will return absl::nullopt. + absl::optional<DurationMs> MeasureRTT(TimeMs now, UnwrappedTSN tsn) const; + + // Returns the internal state of all queued chunks. This is only used in + // unit-tests. + std::vector<std::pair<TSN, State>> GetChunkStatesForTesting() const; + + // Returns true if the next chunk that is not acked by the peer has been + // abandoned, which means that a FORWARD-TSN should be sent. + bool ShouldSendForwardTsn() const; + + // Sets the next TSN to be used. This is used in handover. + void ResetSequenceNumbers(UnwrappedTSN next_tsn, + UnwrappedTSN last_cumulative_tsn); + + private: + // A fragmented message's DATA chunk while in the retransmission queue, and + // its associated metadata. + class Item { + public: + enum class NackAction { + kNothing, + kRetransmit, + kAbandon, + }; + + Item(Data data, + TimeMs time_sent, + MaxRetransmits max_retransmissions, + TimeMs expires_at, + LifecycleId lifecycle_id) + : time_sent_(time_sent), + max_retransmissions_(max_retransmissions), + expires_at_(expires_at), + lifecycle_id_(lifecycle_id), + data_(std::move(data)) {} + + Item(const Item&) = delete; + Item& operator=(const Item&) = delete; + + TimeMs time_sent() const { return time_sent_; } + + const Data& data() const { return data_; } + + // Acks an item. + void Ack(); + + // Nacks an item. If it has been nacked enough times, or if `retransmit_now` + // is set, it might be marked for retransmission. If the item has reached + // its max retransmission value, it will instead be abandoned. The action + // performed is indicated as return value. + NackAction Nack(bool retransmit_now); + + // Prepares the item to be retransmitted. Sets it as outstanding and + // clears all nack counters. + void MarkAsRetransmitted(); + + // Marks this item as abandoned. + void Abandon(); + + bool is_outstanding() const { return ack_state_ == AckState::kUnacked; } + bool is_acked() const { return ack_state_ == AckState::kAcked; } + bool is_nacked() const { return ack_state_ == AckState::kNacked; } + bool is_abandoned() const { return lifecycle_ == Lifecycle::kAbandoned; } + + // Indicates if this chunk should be retransmitted. + bool should_be_retransmitted() const { + return lifecycle_ == Lifecycle::kToBeRetransmitted; + } + // Indicates if this chunk has ever been retransmitted. + bool has_been_retransmitted() const { return num_retransmissions_ > 0; } + + // Given the current time, and the current state of this DATA chunk, it will + // indicate if it has expired (SCTP Partial Reliability Extension). + bool has_expired(TimeMs now) const; + + LifecycleId lifecycle_id() const { return lifecycle_id_; } + + private: + enum class Lifecycle : uint8_t { + // The chunk is alive (sent, received, etc) + kActive, + // The chunk is scheduled to be retransmitted, and will then transition to + // become active. + kToBeRetransmitted, + // The chunk has been abandoned. This is a terminal state. + kAbandoned + }; + enum class AckState : uint8_t { + // The chunk is in-flight. + kUnacked, + // The chunk has been received and acknowledged. + kAcked, + // The chunk has been nacked and is possibly lost. + kNacked + }; + + // NOTE: This data structure has been optimized for size, by ordering fields + // to avoid unnecessary padding. + + // When the packet was sent, and placed in this queue. + const TimeMs time_sent_; + // If the message was sent with a maximum number of retransmissions, this is + // set to that number. The value zero (0) means that it will never be + // retransmitted. + const MaxRetransmits max_retransmissions_; + + // Indicates the life cycle status of this chunk. + Lifecycle lifecycle_ = Lifecycle::kActive; + // Indicates the presence of this chunk, if it's in flight (Unacked), has + // been received (Acked) or is possibly lost (Nacked). + AckState ack_state_ = AckState::kUnacked; + + // The number of times the DATA chunk has been nacked (by having received a + // SACK which doesn't include it). Will be cleared on retransmissions. + uint8_t nack_count_ = 0; + // The number of times the DATA chunk has been retransmitted. + uint16_t num_retransmissions_ = 0; + + // At this exact millisecond, the item is considered expired. If the message + // is not to be expired, this is set to the infinite future. + const TimeMs expires_at_; + + // An optional lifecycle id, which may only be set for the last fragment. + const LifecycleId lifecycle_id_; + + // The actual data to send/retransmit. + const Data data_; + }; + + // Returns how large a chunk will be, serialized, carrying the data + size_t GetSerializedChunkSize(const Data& data) const; + + // Given a `cumulative_tsn_ack` from an incoming SACK, will remove those items + // in the retransmission queue up until this value and will update `ack_info` + // by setting `bytes_acked_by_cumulative_tsn_ack`. + void RemoveAcked(UnwrappedTSN cumulative_tsn_ack, AckInfo& ack_info); + + // Will mark the chunks covered by the `gap_ack_blocks` from an incoming SACK + // as "acked" and update `ack_info` by adding new TSNs to `added_tsns`. + void AckGapBlocks(UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + AckInfo& ack_info); + + // Mark chunks reported as "missing", as "nacked" or "to be retransmitted" + // depending how many times this has happened. Only packets up until + // `ack_info.highest_tsn_acked` (highest TSN newly acknowledged) are + // nacked/retransmitted. The method will set `ack_info.has_packet_loss`. + void NackBetweenAckBlocks( + UnwrappedTSN cumulative_tsn_ack, + rtc::ArrayView<const SackChunk::GapAckBlock> gap_ack_blocks, + bool is_in_fast_recovery, + OutstandingData::AckInfo& ack_info); + + // Process the acknowledgement of the chunk referenced by `iter` and updates + // state in `ack_info` and the object's state. + void AckChunk(AckInfo& ack_info, std::map<UnwrappedTSN, Item>::iterator iter); + + // Helper method to process an incoming nack of an item and perform the + // correct operations given the action indicated when nacking an item (e.g. + // retransmitting or abandoning). The return value indicate if an action was + // performed, meaning that packet loss was detected and acted upon. If + // `do_fast_retransmit` is set and if the item has been nacked sufficiently + // many times so that it should be retransmitted, this will schedule it to be + // "fast retransmitted". This is only done just before going into fast + // recovery. + bool NackItem(UnwrappedTSN tsn, + Item& item, + bool retransmit_now, + bool do_fast_retransmit); + + // Given that a message fragment, `item` has been abandoned, abandon all other + // fragments that share the same message - both never-before-sent fragments + // that are still in the SendQueue and outstanding chunks. + void AbandonAllFor(const OutstandingData::Item& item); + + std::vector<std::pair<TSN, Data>> ExtractChunksThatCanFit( + std::set<UnwrappedTSN>& chunks, + size_t max_size); + + bool IsConsistent() const; + + // The size of the data chunk (DATA/I-DATA) header that is used. + const size_t data_chunk_header_size_; + // Next TSN to used. + UnwrappedTSN next_tsn_; + // The last cumulative TSN ack number. + UnwrappedTSN last_cumulative_tsn_ack_; + // Callback when to discard items from the send queue. + std::function<bool(IsUnordered, StreamID, MID)> discard_from_send_queue_; + + std::map<UnwrappedTSN, Item> outstanding_data_; + // The number of bytes that are in-flight (sent but not yet acked or nacked). + size_t outstanding_bytes_ = 0; + // The number of DATA chunks that are in-flight (sent but not yet acked or + // nacked). + size_t outstanding_items_ = 0; + // Data chunks that are eligible for fast retransmission. + std::set<UnwrappedTSN> to_be_fast_retransmitted_; + // Data chunks that are to be retransmitted. + std::set<UnwrappedTSN> to_be_retransmitted_; +}; +} // namespace dcsctp +#endif // NET_DCSCTP_TX_OUTSTANDING_DATA_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/outstanding_data_test.cc b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data_test.cc new file mode 100644 index 0000000000..cdca40cfef --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/outstanding_data_test.cc @@ -0,0 +1,591 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/outstanding_data.h" + +#include <vector> + +#include "absl/types/optional.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/testing/data_generator.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; +using State = ::dcsctp::OutstandingData::State; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::Return; +using ::testing::StrictMock; + +constexpr TimeMs kNow(42); + +class OutstandingDataTest : public testing::Test { + protected: + OutstandingDataTest() + : gen_(MID(42)), + buf_(DataChunk::kHeaderSize, + unwrapper_.Unwrap(TSN(10)), + unwrapper_.Unwrap(TSN(9)), + on_discard_.AsStdFunction()) {} + + UnwrappedTSN::Unwrapper unwrapper_; + DataGenerator gen_; + StrictMock<MockFunction<bool(IsUnordered, StreamID, MID)>> on_discard_; + OutstandingData buf_; +}; + +TEST_F(OutstandingDataTest, HasInitialState) { + EXPECT_TRUE(buf_.empty()); + EXPECT_EQ(buf_.outstanding_bytes(), 0u); + EXPECT_EQ(buf_.outstanding_items(), 0u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(10)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(9)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked))); + EXPECT_FALSE(buf_.ShouldSendForwardTsn()); +} + +TEST_F(OutstandingDataTest, InsertChunk) { + ASSERT_HAS_VALUE_AND_ASSIGN(UnwrappedTSN tsn, + buf_.Insert(gen_.Ordered({1}, "BE"), kNow)); + + EXPECT_EQ(tsn.Wrap(), TSN(10)); + + EXPECT_EQ(buf_.outstanding_bytes(), DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(buf_.outstanding_items(), 1u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(11)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(10)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(OutstandingDataTest, AcksSingleChunk) { + buf_.Insert(gen_.Ordered({1}, "BE"), kNow); + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(10)), {}, false); + + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(10)); + EXPECT_FALSE(ack.has_packet_loss); + + EXPECT_EQ(buf_.outstanding_bytes(), 0u); + EXPECT_EQ(buf_.outstanding_items(), 0u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(10)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(11)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(10)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked))); +} + +TEST_F(OutstandingDataTest, AcksPreviousChunkDoesntUpdate) { + buf_.Insert(gen_.Ordered({1}, "BE"), kNow); + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), {}, false); + + EXPECT_EQ(buf_.outstanding_bytes(), DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(buf_.outstanding_items(), 1u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(11)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(10)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(OutstandingDataTest, AcksAndNacksWithGapAckBlocks) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow); + buf_.Insert(gen_.Ordered({1}, "E"), kNow); + + std::vector<SackChunk::GapAckBlock> gab = {SackChunk::GapAckBlock(2, 2)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab, false); + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(11)); + EXPECT_FALSE(ack.has_packet_loss); + + EXPECT_EQ(buf_.outstanding_bytes(), 0u); + EXPECT_EQ(buf_.outstanding_items(), 0u); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(12)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(11)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked))); +} + +TEST_F(OutstandingDataTest, NacksThreeTimesWithSameTsnDoesntRetransmit) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow); + buf_.Insert(gen_.Ordered({1}, "E"), kNow); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked))); +} + +TEST_F(OutstandingDataTest, NacksThreeTimesResultsInRetransmission) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, "E"), kNow); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(13)); + EXPECT_TRUE(ack.has_packet_loss); + + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked))); + + EXPECT_THAT(buf_.GetChunksToBeFastRetransmitted(1000), + ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(buf_.GetChunksToBeRetransmitted(1000), IsEmpty()); +} + +TEST_F(OutstandingDataTest, NacksThreeTimesResultsInAbandoning) { + static constexpr MaxRetransmits kMaxRetransmissions(0); + buf_.Insert(gen_.Ordered({1}, "B"), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, "E"), kNow, kMaxRetransmissions); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(13)); + EXPECT_TRUE(ack.has_packet_loss); + + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(14)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAbandoned))); +} + +TEST_F(OutstandingDataTest, NacksThreeTimesResultsInAbandoningWithPlaceholder) { + static constexpr MaxRetransmits kMaxRetransmissions(0); + buf_.Insert(gen_.Ordered({1}, "B"), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(true)); + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_EQ(ack.bytes_acked, DataChunk::kHeaderSize + RoundUpTo4(1)); + EXPECT_EQ(ack.highest_tsn_acked.Wrap(), TSN(13)); + EXPECT_TRUE(ack.has_packet_loss); + + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(15)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAbandoned), // + Pair(TSN(14), State::kAbandoned))); +} + +TEST_F(OutstandingDataTest, ExpiresChunkBeforeItIsInserted) { + static constexpr TimeMs kExpiresAt = kNow + DurationMs(1); + EXPECT_TRUE(buf_.Insert(gen_.Ordered({1}, "B"), kNow, + MaxRetransmits::NoLimit(), kExpiresAt) + .has_value()); + EXPECT_TRUE(buf_.Insert(gen_.Ordered({1}, ""), kNow + DurationMs(0), + MaxRetransmits::NoLimit(), kExpiresAt) + .has_value()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + EXPECT_FALSE(buf_.Insert(gen_.Ordered({1}, "E"), kNow + DurationMs(1), + MaxRetransmits::NoLimit(), kExpiresAt) + .has_value()); + + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_EQ(buf_.last_cumulative_tsn_ack().Wrap(), TSN(9)); + EXPECT_EQ(buf_.next_tsn().Wrap(), TSN(13)); + EXPECT_EQ(buf_.highest_outstanding_tsn().Wrap(), TSN(12)); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), + Pair(TSN(12), State::kAbandoned))); +} + +TEST_F(OutstandingDataTest, CanGenerateForwardTsn) { + static constexpr MaxRetransmits kMaxRetransmissions(0); + buf_.Insert(gen_.Ordered({1}, "B"), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, ""), kNow, kMaxRetransmissions); + buf_.Insert(gen_.Ordered({1}, "E"), kNow, kMaxRetransmissions); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + buf_.NackAll(); + + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), + Pair(TSN(12), State::kAbandoned))); + + EXPECT_TRUE(buf_.ShouldSendForwardTsn()); + ForwardTsnChunk chunk = buf_.CreateForwardTsn(); + EXPECT_EQ(chunk.new_cumulative_tsn(), TSN(12)); +} + +TEST_F(OutstandingDataTest, AckWithGapBlocksFromRFC4960Section334) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, ""), kNow); + buf_.Insert(gen_.Ordered({1}, "E"), kNow); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + testing::ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight))); + + std::vector<SackChunk::GapAckBlock> gab = {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 5)}; + buf_.HandleSack(unwrapper_.Unwrap(TSN(12)), gab, false); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kNacked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kNacked), // + Pair(TSN(17), State::kAcked))); +} + +TEST_F(OutstandingDataTest, MeasureRTT) { + buf_.Insert(gen_.Ordered({1}, "BE"), kNow); + buf_.Insert(gen_.Ordered({1}, "BE"), kNow + DurationMs(1)); + buf_.Insert(gen_.Ordered({1}, "BE"), kNow + DurationMs(2)); + + static constexpr DurationMs kDuration(123); + ASSERT_HAS_VALUE_AND_ASSIGN( + DurationMs duration, + buf_.MeasureRTT(kNow + kDuration, unwrapper_.Unwrap(TSN(11)))); + + EXPECT_EQ(duration, kDuration - DurationMs(1)); +} + +TEST_F(OutstandingDataTest, MustRetransmitBeforeGettingNackedAgain) { + // This test case verifies that a chunk that has been nacked, and scheduled to + // be retransmitted, doesn't get nacked again until it has been actually sent + // on the wire. + + static constexpr MaxRetransmits kOneRetransmission(1); + for (int tsn = 10; tsn <= 20; ++tsn) { + buf_.Insert(gen_.Ordered({1}, tsn == 10 ? "B" + : tsn == 20 ? "E" + : ""), + kNow, kOneRetransmission); + } + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_TRUE(ack.has_packet_loss); + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + // Don't call GetChunksToBeRetransmitted yet - simulate that the congestion + // window doesn't allow it to be retransmitted yet. It does however get more + // SACKs indicating packet loss. + + std::vector<SackChunk::GapAckBlock> gab4 = {SackChunk::GapAckBlock(2, 5)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab4, false).has_packet_loss); + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab5 = {SackChunk::GapAckBlock(2, 6)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab5, false).has_packet_loss); + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab6 = {SackChunk::GapAckBlock(2, 7)}; + OutstandingData::AckInfo ack2 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab6, false); + + EXPECT_FALSE(ack2.has_packet_loss); + EXPECT_TRUE(buf_.has_data_to_be_retransmitted()); + + // Now it's retransmitted. + EXPECT_THAT(buf_.GetChunksToBeFastRetransmitted(1000), + ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(buf_.GetChunksToBeRetransmitted(1000), IsEmpty()); + + // And obviously lost, as it will get NACKed and abandoned. + std::vector<SackChunk::GapAckBlock> gab7 = {SackChunk::GapAckBlock(2, 8)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab7, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab8 = {SackChunk::GapAckBlock(2, 9)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab8, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + + std::vector<SackChunk::GapAckBlock> gab9 = {SackChunk::GapAckBlock(2, 10)}; + OutstandingData::AckInfo ack3 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab9, false); + + EXPECT_TRUE(ack3.has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); +} + +TEST_F(OutstandingDataTest, CanAbandonChunksMarkedForFastRetransmit) { + // This test is a bit convoluted, and can't really happen with a well behaving + // client, but this was found by fuzzers. This test will verify that a message + // that was both marked as "to be fast retransmitted" and "abandoned" at the + // same time doesn't cause any consistency issues. + + // Add chunks 10-14, but chunk 11 has zero retransmissions. When chunk 10 and + // 11 are NACKed three times, chunk 10 will be marked for retransmission, but + // chunk 11 will be abandoned, which also abandons chunk 10, as it's part of + // the same message. + buf_.Insert(gen_.Ordered({1}, "B"), kNow); // 10 + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); // 11 + buf_.Insert(gen_.Ordered({1}, ""), kNow); // 12 + buf_.Insert(gen_.Ordered({1}, ""), kNow); // 13 + buf_.Insert(gen_.Ordered({1}, "E"), kNow); // 14 + + // ACK 9, 12 + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(3, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + // ACK 9, 12, 13 + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(3, 4)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + + // ACK 9, 12, 13, 14 + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(3, 5)}; + OutstandingData::AckInfo ack = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_TRUE(ack.has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + EXPECT_THAT(buf_.GetChunksToBeFastRetransmitted(1000), IsEmpty()); + EXPECT_THAT(buf_.GetChunksToBeRetransmitted(1000), IsEmpty()); +} + +TEST_F(OutstandingDataTest, LifecyleReturnsAckedItemsInAckInfo) { + buf_.Insert(gen_.Ordered({1}, "BE"), kNow, MaxRetransmits::NoLimit(), + TimeMs::InfiniteFuture(), LifecycleId(42)); + buf_.Insert(gen_.Ordered({1}, "BE"), kNow, MaxRetransmits::NoLimit(), + TimeMs::InfiniteFuture(), LifecycleId(43)); + buf_.Insert(gen_.Ordered({1}, "BE"), kNow, MaxRetransmits::NoLimit(), + TimeMs::InfiniteFuture(), LifecycleId(44)); + + OutstandingData::AckInfo ack1 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(11)), {}, false); + + EXPECT_THAT(ack1.acked_lifecycle_ids, + ElementsAre(LifecycleId(42), LifecycleId(43))); + + OutstandingData::AckInfo ack2 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(12)), {}, false); + + EXPECT_THAT(ack2.acked_lifecycle_ids, ElementsAre(LifecycleId(44))); +} + +TEST_F(OutstandingDataTest, LifecycleReturnsAbandonedNackedThreeTimes) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, "E"), kNow, MaxRetransmits(0), + TimeMs::InfiniteFuture(), LifecycleId(42)); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 2)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab2 = {SackChunk::GapAckBlock(2, 3)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab2, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + std::vector<SackChunk::GapAckBlock> gab3 = {SackChunk::GapAckBlock(2, 4)}; + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + OutstandingData::AckInfo ack1 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab3, false); + EXPECT_TRUE(ack1.has_packet_loss); + EXPECT_THAT(ack1.abandoned_lifecycle_ids, IsEmpty()); + + // This will generate a FORWARD-TSN, which is acked + EXPECT_TRUE(buf_.ShouldSendForwardTsn()); + ForwardTsnChunk chunk = buf_.CreateForwardTsn(); + EXPECT_EQ(chunk.new_cumulative_tsn(), TSN(13)); + + OutstandingData::AckInfo ack2 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(13)), {}, false); + EXPECT_FALSE(ack2.has_packet_loss); + EXPECT_THAT(ack2.abandoned_lifecycle_ids, ElementsAre(LifecycleId(42))); +} + +TEST_F(OutstandingDataTest, LifecycleReturnsAbandonedAfterT3rtxExpired) { + buf_.Insert(gen_.Ordered({1}, "B"), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, ""), kNow, MaxRetransmits(0)); + buf_.Insert(gen_.Ordered({1}, "E"), kNow, MaxRetransmits(0), + TimeMs::InfiniteFuture(), LifecycleId(42)); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + testing::ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + std::vector<SackChunk::GapAckBlock> gab1 = {SackChunk::GapAckBlock(2, 4)}; + EXPECT_FALSE( + buf_.HandleSack(unwrapper_.Unwrap(TSN(9)), gab1, false).has_packet_loss); + EXPECT_FALSE(buf_.has_data_to_be_retransmitted()); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + testing::ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked))); + + // T3-rtx triggered. + EXPECT_CALL(on_discard_, Call(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + buf_.NackAll(); + + EXPECT_THAT(buf_.GetChunkStatesForTesting(), + testing::ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAbandoned))); + + // This will generate a FORWARD-TSN, which is acked + EXPECT_TRUE(buf_.ShouldSendForwardTsn()); + ForwardTsnChunk chunk = buf_.CreateForwardTsn(); + EXPECT_EQ(chunk.new_cumulative_tsn(), TSN(13)); + + OutstandingData::AckInfo ack2 = + buf_.HandleSack(unwrapper_.Unwrap(TSN(13)), {}, false); + EXPECT_FALSE(ack2.has_packet_loss); + EXPECT_THAT(ack2.abandoned_lifecycle_ids, ElementsAre(LifecycleId(42))); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.cc new file mode 100644 index 0000000000..44b20ba2c2 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_error_counter.h" + +#include "absl/strings/string_view.h" +#include "rtc_base/logging.h" + +namespace dcsctp { +bool RetransmissionErrorCounter::Increment(absl::string_view reason) { + ++counter_; + if (limit_.has_value() && counter_ > limit_.value()) { + RTC_DLOG(LS_INFO) << log_prefix_ << reason + << ", too many retransmissions, counter=" << counter_; + return false; + } + + RTC_DLOG(LS_VERBOSE) << log_prefix_ << reason << ", new counter=" << counter_ + << ", max=" << limit_.value_or(-1); + return true; +} + +void RetransmissionErrorCounter::Clear() { + if (counter_ > 0) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "recovered from counter=" << counter_; + counter_ = 0; + } +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.h b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.h new file mode 100644 index 0000000000..18af3d3c4f --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RETRANSMISSION_ERROR_COUNTER_H_ +#define NET_DCSCTP_TX_RETRANSMISSION_ERROR_COUNTER_H_ + +#include <functional> +#include <string> +#include <utility> + +#include "absl/strings/string_view.h" +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +// The RetransmissionErrorCounter is a simple counter with a limit, and when +// the limit is exceeded, the counter is exhausted and the connection will +// be closed. It's incremented on retransmission errors, such as the T3-RTX +// timer expiring, but also missing heartbeats and stream reset requests. +class RetransmissionErrorCounter { + public: + RetransmissionErrorCounter(absl::string_view log_prefix, + const DcSctpOptions& options) + : log_prefix_(std::string(log_prefix) + "rtx-errors: "), + limit_(options.max_retransmissions) {} + + // Increments the retransmission timer. If the maximum error count has been + // reached, `false` will be returned. + bool Increment(absl::string_view reason); + bool IsExhausted() const { return limit_.has_value() && counter_ > *limit_; } + + // Clears the retransmission errors. + void Clear(); + + // Returns its current value + int value() const { return counter_; } + + private: + const std::string log_prefix_; + const absl::optional<int> limit_; + int counter_ = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RETRANSMISSION_ERROR_COUNTER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter_test.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter_test.cc new file mode 100644 index 0000000000..67bbc0bec5 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_error_counter_test.cc @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_error_counter.h" + +#include "net/dcsctp/public/dcsctp_options.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +TEST(RetransmissionErrorCounterTest, HasInitialValue) { + DcSctpOptions options; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_EQ(counter.value(), 0); +} + +TEST(RetransmissionErrorCounterTest, ReturnsFalseAtMaximumValue) { + DcSctpOptions options; + options.max_retransmissions = 5; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_TRUE(counter.Increment("test")); // 2 + EXPECT_TRUE(counter.Increment("test")); // 3 + EXPECT_TRUE(counter.Increment("test")); // 4 + EXPECT_TRUE(counter.Increment("test")); // 5 + EXPECT_FALSE(counter.Increment("test")); // Too many retransmissions +} + +TEST(RetransmissionErrorCounterTest, CanHandleZeroRetransmission) { + DcSctpOptions options; + options.max_retransmissions = 0; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_FALSE(counter.Increment("test")); // One is too many. +} + +TEST(RetransmissionErrorCounterTest, IsExhaustedAtMaximum) { + DcSctpOptions options; + options.max_retransmissions = 3; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_TRUE(counter.Increment("test")); // 2 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_TRUE(counter.Increment("test")); // 3 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_FALSE(counter.Increment("test")); // Too many retransmissions + EXPECT_TRUE(counter.IsExhausted()); + EXPECT_FALSE(counter.Increment("test")); // One after too many + EXPECT_TRUE(counter.IsExhausted()); +} + +TEST(RetransmissionErrorCounterTest, ClearingCounter) { + DcSctpOptions options; + options.max_retransmissions = 3; + RetransmissionErrorCounter counter("log: ", options); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_TRUE(counter.Increment("test")); // 2 + counter.Clear(); + EXPECT_TRUE(counter.Increment("test")); // 1 + EXPECT_TRUE(counter.Increment("test")); // 2 + EXPECT_TRUE(counter.Increment("test")); // 3 + EXPECT_FALSE(counter.IsExhausted()); + EXPECT_FALSE(counter.Increment("test")); // Too many retransmissions + EXPECT_TRUE(counter.IsExhausted()); +} + +TEST(RetransmissionErrorCounterTest, CanBeLimitless) { + DcSctpOptions options; + options.max_retransmissions = absl::nullopt; + RetransmissionErrorCounter counter("log: ", options); + for (int i = 0; i < 100; ++i) { + EXPECT_TRUE(counter.Increment("test")); + EXPECT_FALSE(counter.IsExhausted()); + } +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.cc new file mode 100644 index 0000000000..36e2a859ba --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.cc @@ -0,0 +1,611 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_queue.h" + +#include <algorithm> +#include <cstdint> +#include <functional> +#include <iterator> +#include <map> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/outstanding_data.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" +#include "rtc_base/strings/string_builder.h" + +namespace dcsctp { +namespace { + +// Allow sending only slightly less than an MTU, to account for headers. +constexpr float kMinBytesRequiredToSendFactor = 0.9; +} // namespace + +RetransmissionQueue::RetransmissionQueue( + absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, + TSN my_initial_tsn, + size_t a_rwnd, + SendQueue& send_queue, + std::function<void(DurationMs rtt)> on_new_rtt, + std::function<void()> on_clear_retransmission_counter, + Timer& t3_rtx, + const DcSctpOptions& options, + bool supports_partial_reliability, + bool use_message_interleaving) + : callbacks_(*callbacks), + options_(options), + min_bytes_required_to_send_(options.mtu * kMinBytesRequiredToSendFactor), + partial_reliability_(supports_partial_reliability), + log_prefix_(std::string(log_prefix) + "tx: "), + data_chunk_header_size_(use_message_interleaving + ? IDataChunk::kHeaderSize + : DataChunk::kHeaderSize), + on_new_rtt_(std::move(on_new_rtt)), + on_clear_retransmission_counter_( + std::move(on_clear_retransmission_counter)), + t3_rtx_(t3_rtx), + cwnd_(options_.cwnd_mtus_initial * options_.mtu), + rwnd_(a_rwnd), + // https://tools.ietf.org/html/rfc4960#section-7.2.1 + // "The initial value of ssthresh MAY be arbitrarily high (for + // example, implementations MAY use the size of the receiver advertised + // window)."" + ssthresh_(rwnd_), + partial_bytes_acked_(0), + send_queue_(send_queue), + outstanding_data_( + data_chunk_header_size_, + tsn_unwrapper_.Unwrap(my_initial_tsn), + tsn_unwrapper_.Unwrap(TSN(*my_initial_tsn - 1)), + [this](IsUnordered unordered, StreamID stream_id, MID message_id) { + return send_queue_.Discard(unordered, stream_id, message_id); + }) {} + +bool RetransmissionQueue::IsConsistent() const { + return true; +} + +// Returns how large a chunk will be, serialized, carrying the data +size_t RetransmissionQueue::GetSerializedChunkSize(const Data& data) const { + return RoundUpTo4(data_chunk_header_size_ + data.size()); +} + +void RetransmissionQueue::MaybeExitFastRecovery( + UnwrappedTSN cumulative_tsn_ack) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "When a SACK acknowledges all TSNs up to and including this [fast + // recovery] exit point, Fast Recovery is exited." + if (fast_recovery_exit_tsn_.has_value() && + cumulative_tsn_ack >= *fast_recovery_exit_tsn_) { + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "exit_point=" << *fast_recovery_exit_tsn_->Wrap() + << " reached - exiting fast recovery"; + fast_recovery_exit_tsn_ = absl::nullopt; + } +} + +void RetransmissionQueue::HandleIncreasedCumulativeTsnAck( + size_t outstanding_bytes, + size_t total_bytes_acked) { + // Allow some margin for classifying as fully utilized, due to e.g. that too + // small packets (less than kMinimumFragmentedPayload) are not sent + + // overhead. + bool is_fully_utilized = outstanding_bytes + options_.mtu >= cwnd_; + size_t old_cwnd = cwnd_; + if (phase() == CongestionAlgorithmPhase::kSlowStart) { + if (is_fully_utilized && !is_in_fast_recovery()) { + // https://tools.ietf.org/html/rfc4960#section-7.2.1 + // "Only when these three conditions are met can the cwnd be + // increased; otherwise, the cwnd MUST not be increased. If these + // conditions are met, then cwnd MUST be increased by, at most, the + // lesser of 1) the total size of the previously outstanding DATA + // chunk(s) acknowledged, and 2) the destination's path MTU." + cwnd_ += std::min(total_bytes_acked, options_.mtu); + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "SS increase cwnd=" << cwnd_ + << " (" << old_cwnd << ")"; + } + } else if (phase() == CongestionAlgorithmPhase::kCongestionAvoidance) { + // https://tools.ietf.org/html/rfc4960#section-7.2.2 + // "Whenever cwnd is greater than ssthresh, upon each SACK arrival + // that advances the Cumulative TSN Ack Point, increase + // partial_bytes_acked by the total number of bytes of all new chunks + // acknowledged in that SACK including chunks acknowledged by the new + // Cumulative TSN Ack and by Gap Ack Blocks." + size_t old_pba = partial_bytes_acked_; + partial_bytes_acked_ += total_bytes_acked; + + if (partial_bytes_acked_ >= cwnd_ && is_fully_utilized) { + // https://tools.ietf.org/html/rfc4960#section-7.2.2 + // "When partial_bytes_acked is equal to or greater than cwnd and + // before the arrival of the SACK the sender had cwnd or more bytes of + // data outstanding (i.e., before arrival of the SACK, flightsize was + // greater than or equal to cwnd), increase cwnd by MTU, and reset + // partial_bytes_acked to (partial_bytes_acked - cwnd)." + + // Errata: https://datatracker.ietf.org/doc/html/rfc8540#section-3.12 + partial_bytes_acked_ -= cwnd_; + cwnd_ += options_.mtu; + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "CA increase cwnd=" << cwnd_ + << " (" << old_cwnd << ") ssthresh=" << ssthresh_ + << ", pba=" << partial_bytes_acked_ << " (" + << old_pba << ")"; + } else { + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "CA unchanged cwnd=" << cwnd_ + << " (" << old_cwnd << ") ssthresh=" << ssthresh_ + << ", pba=" << partial_bytes_acked_ << " (" + << old_pba << ")"; + } + } +} + +void RetransmissionQueue::HandlePacketLoss(UnwrappedTSN highest_tsn_acked) { + if (!is_in_fast_recovery()) { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "If not in Fast Recovery, adjust the ssthresh and cwnd of the + // destination address(es) to which the missing DATA chunks were last + // sent, according to the formula described in Section 7.2.3." + size_t old_cwnd = cwnd_; + size_t old_pba = partial_bytes_acked_; + ssthresh_ = std::max(cwnd_ / 2, options_.cwnd_mtus_min * options_.mtu); + cwnd_ = ssthresh_; + partial_bytes_acked_ = 0; + + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "packet loss detected (not fast recovery). cwnd=" + << cwnd_ << " (" << old_cwnd + << "), ssthresh=" << ssthresh_ + << ", pba=" << partial_bytes_acked_ << " (" << old_pba + << ")"; + + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "If not in Fast Recovery, enter Fast Recovery and mark the highest + // outstanding TSN as the Fast Recovery exit point." + fast_recovery_exit_tsn_ = outstanding_data_.highest_outstanding_tsn(); + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "fast recovery initiated with exit_point=" + << *fast_recovery_exit_tsn_->Wrap(); + } else { + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "While in Fast Recovery, the ssthresh and cwnd SHOULD NOT change for + // any destinations due to a subsequent Fast Recovery event (i.e., one + // SHOULD NOT reduce the cwnd further due to a subsequent Fast Retransmit)." + RTC_DLOG(LS_VERBOSE) << log_prefix_ + << "packet loss detected (fast recovery). No changes."; + } +} + +void RetransmissionQueue::UpdateReceiverWindow(uint32_t a_rwnd) { + rwnd_ = outstanding_data_.outstanding_bytes() >= a_rwnd + ? 0 + : a_rwnd - outstanding_data_.outstanding_bytes(); +} + +void RetransmissionQueue::StartT3RtxTimerIfOutstandingData() { + // Note: Can't use `outstanding_bytes()` as that one doesn't count chunks to + // be retransmitted. + if (outstanding_data_.empty()) { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Whenever all outstanding data sent to an address have been + // acknowledged, turn off the T3-rtx timer of that address. + // Note: Already stopped in `StopT3RtxTimerOnIncreasedCumulativeTsnAck`." + } else { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Whenever a SACK is received that acknowledges the DATA chunk + // with the earliest outstanding TSN for that address, restart the T3-rtx + // timer for that address with its current RTO (if there is still + // outstanding data on that address)." + // "Whenever a SACK is received missing a TSN that was previously + // acknowledged via a Gap Ack Block, start the T3-rtx for the destination + // address to which the DATA chunk was originally transmitted if it is not + // already running." + if (!t3_rtx_.is_running()) { + t3_rtx_.Start(); + } + } +} + +bool RetransmissionQueue::IsSackValid(const SackChunk& sack) const { + // https://tools.ietf.org/html/rfc4960#section-6.2.1 + // "If Cumulative TSN Ack is less than the Cumulative TSN Ack Point, + // then drop the SACK. Since Cumulative TSN Ack is monotonically increasing, + // a SACK whose Cumulative TSN Ack is less than the Cumulative TSN Ack Point + // indicates an out-of- order SACK." + // + // Note: Important not to drop SACKs with identical TSN to that previously + // received, as the gap ack blocks or dup tsn fields may have changed. + UnwrappedTSN cumulative_tsn_ack = + tsn_unwrapper_.PeekUnwrap(sack.cumulative_tsn_ack()); + if (cumulative_tsn_ack < outstanding_data_.last_cumulative_tsn_ack()) { + // https://tools.ietf.org/html/rfc4960#section-6.2.1 + // "If Cumulative TSN Ack is less than the Cumulative TSN Ack Point, + // then drop the SACK. Since Cumulative TSN Ack is monotonically + // increasing, a SACK whose Cumulative TSN Ack is less than the Cumulative + // TSN Ack Point indicates an out-of- order SACK." + return false; + } else if (cumulative_tsn_ack > outstanding_data_.highest_outstanding_tsn()) { + return false; + } + return true; +} + +bool RetransmissionQueue::HandleSack(TimeMs now, const SackChunk& sack) { + if (!IsSackValid(sack)) { + return false; + } + + UnwrappedTSN old_last_cumulative_tsn_ack = + outstanding_data_.last_cumulative_tsn_ack(); + size_t old_outstanding_bytes = outstanding_data_.outstanding_bytes(); + size_t old_rwnd = rwnd_; + UnwrappedTSN cumulative_tsn_ack = + tsn_unwrapper_.Unwrap(sack.cumulative_tsn_ack()); + + if (sack.gap_ack_blocks().empty()) { + UpdateRTT(now, cumulative_tsn_ack); + } + + // Exit fast recovery before continuing processing, in case it needs to go + // into fast recovery again due to new reported packet loss. + MaybeExitFastRecovery(cumulative_tsn_ack); + + OutstandingData::AckInfo ack_info = outstanding_data_.HandleSack( + cumulative_tsn_ack, sack.gap_ack_blocks(), is_in_fast_recovery()); + + // Add lifecycle events for delivered messages. + for (LifecycleId lifecycle_id : ack_info.acked_lifecycle_ids) { + RTC_DLOG(LS_VERBOSE) << "Triggering OnLifecycleMessageDelivered(" + << lifecycle_id.value() << ")"; + callbacks_.OnLifecycleMessageDelivered(lifecycle_id); + callbacks_.OnLifecycleEnd(lifecycle_id); + } + for (LifecycleId lifecycle_id : ack_info.abandoned_lifecycle_ids) { + RTC_DLOG(LS_VERBOSE) << "Triggering OnLifecycleMessageExpired(" + << lifecycle_id.value() << ", true)"; + callbacks_.OnLifecycleMessageExpired(lifecycle_id, + /*maybe_delivered=*/true); + callbacks_.OnLifecycleEnd(lifecycle_id); + } + + // Update of outstanding_data_ is now done. Congestion control remains. + UpdateReceiverWindow(sack.a_rwnd()); + + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Received SACK, cum_tsn_ack=" + << *cumulative_tsn_ack.Wrap() << " (" + << *old_last_cumulative_tsn_ack.Wrap() + << "), outstanding_bytes=" + << outstanding_data_.outstanding_bytes() << " (" + << old_outstanding_bytes << "), rwnd=" << rwnd_ << " (" + << old_rwnd << ")"; + + if (cumulative_tsn_ack > old_last_cumulative_tsn_ack) { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Whenever a SACK is received that acknowledges the DATA chunk + // with the earliest outstanding TSN for that address, restart the T3-rtx + // timer for that address with its current RTO (if there is still + // outstanding data on that address)." + // Note: It may be started again in a bit further down. + t3_rtx_.Stop(); + + HandleIncreasedCumulativeTsnAck(old_outstanding_bytes, + ack_info.bytes_acked); + } + + if (ack_info.has_packet_loss) { + HandlePacketLoss(ack_info.highest_tsn_acked); + } + + // https://tools.ietf.org/html/rfc4960#section-8.2 + // "When an outstanding TSN is acknowledged [...] the endpoint shall clear + // the error counter ..." + if (ack_info.bytes_acked > 0) { + on_clear_retransmission_counter_(); + } + + StartT3RtxTimerIfOutstandingData(); + RTC_DCHECK(IsConsistent()); + return true; +} + +void RetransmissionQueue::UpdateRTT(TimeMs now, + UnwrappedTSN cumulative_tsn_ack) { + // RTT updating is flawed in SCTP, as explained in e.g. Pedersen J, Griwodz C, + // Halvorsen P (2006) Considerations of SCTP retransmission delays for thin + // streams. + // Due to delayed acknowledgement, the SACK may be sent much later which + // increases the calculated RTT. + // TODO(boivie): Consider occasionally sending DATA chunks with I-bit set and + // use only those packets for measurement. + + absl::optional<DurationMs> rtt = + outstanding_data_.MeasureRTT(now, cumulative_tsn_ack); + + if (rtt.has_value()) { + on_new_rtt_(*rtt); + } +} + +void RetransmissionQueue::HandleT3RtxTimerExpiry() { + size_t old_cwnd = cwnd_; + size_t old_outstanding_bytes = outstanding_bytes(); + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "For the destination address for which the timer expires, adjust + // its ssthresh with rules defined in Section 7.2.3 and set the cwnd <- MTU." + ssthresh_ = std::max(cwnd_ / 2, 4 * options_.mtu); + cwnd_ = 1 * options_.mtu; + // Errata: https://datatracker.ietf.org/doc/html/rfc8540#section-3.11 + partial_bytes_acked_ = 0; + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "For the destination address for which the timer expires, set RTO + // <- RTO * 2 ("back off the timer"). The maximum value discussed in rule C7 + // above (RTO.max) may be used to provide an upper bound to this doubling + // operation." + + // Already done by the Timer implementation. + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "Determine how many of the earliest (i.e., lowest TSN) outstanding + // DATA chunks for the address for which the T3-rtx has expired will fit into + // a single packet" + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "Note: Any DATA chunks that were sent to the address for which the + // T3-rtx timer expired but did not fit in one MTU (rule E3 above) should be + // marked for retransmission and sent as soon as cwnd allows (normally, when a + // SACK arrives)." + outstanding_data_.NackAll(); + + // https://tools.ietf.org/html/rfc4960#section-6.3.3 + // "Start the retransmission timer T3-rtx on the destination address + // to which the retransmission is sent, if rule R1 above indicates to do so." + + // Already done by the Timer implementation. + + RTC_DLOG(LS_INFO) << log_prefix_ << "t3-rtx expired. new cwnd=" << cwnd_ + << " (" << old_cwnd << "), ssthresh=" << ssthresh_ + << ", outstanding_bytes " << outstanding_bytes() << " (" + << old_outstanding_bytes << ")"; + RTC_DCHECK(IsConsistent()); +} + +std::vector<std::pair<TSN, Data>> +RetransmissionQueue::GetChunksForFastRetransmit(size_t bytes_in_packet) { + RTC_DCHECK(outstanding_data_.has_data_to_be_fast_retransmitted()); + RTC_DCHECK(IsDivisibleBy4(bytes_in_packet)); + std::vector<std::pair<TSN, Data>> to_be_sent; + size_t old_outstanding_bytes = outstanding_bytes(); + + to_be_sent = + outstanding_data_.GetChunksToBeFastRetransmitted(bytes_in_packet); + RTC_DCHECK(!to_be_sent.empty()); + + // https://tools.ietf.org/html/rfc4960#section-7.2.4 + // "4) Restart the T3-rtx timer only if ... the endpoint is retransmitting + // the first outstanding DATA chunk sent to that address." + if (to_be_sent[0].first == + outstanding_data_.last_cumulative_tsn_ack().next_value().Wrap()) { + RTC_DLOG(LS_VERBOSE) + << log_prefix_ + << "First outstanding DATA to be retransmitted - restarting T3-RTX"; + t3_rtx_.Stop(); + } + + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Every time a DATA chunk is sent to any address (including a + // retransmission), if the T3-rtx timer of that address is not running, + // start it running so that it will expire after the RTO of that address." + if (!t3_rtx_.is_running()) { + t3_rtx_.Start(); + } + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Fast-retransmitting TSN " + << StrJoin(to_be_sent, ",", + [&](rtc::StringBuilder& sb, + const std::pair<TSN, Data>& c) { + sb << *c.first; + }) + << " - " + << absl::c_accumulate( + to_be_sent, 0, + [&](size_t r, const std::pair<TSN, Data>& d) { + return r + GetSerializedChunkSize(d.second); + }) + << " bytes. outstanding_bytes=" << outstanding_bytes() + << " (" << old_outstanding_bytes << ")"; + + RTC_DCHECK(IsConsistent()); + return to_be_sent; +} + +std::vector<std::pair<TSN, Data>> RetransmissionQueue::GetChunksToSend( + TimeMs now, + size_t bytes_remaining_in_packet) { + // Chunks are always padded to even divisible by four. + RTC_DCHECK(IsDivisibleBy4(bytes_remaining_in_packet)); + + std::vector<std::pair<TSN, Data>> to_be_sent; + size_t old_outstanding_bytes = outstanding_bytes(); + size_t old_rwnd = rwnd_; + + // Calculate the bandwidth budget (how many bytes that is + // allowed to be sent), and fill that up first with chunks that are + // scheduled to be retransmitted. If there is still budget, send new chunks + // (which will have their TSN assigned here.) + size_t max_bytes = + RoundDownTo4(std::min(max_bytes_to_send(), bytes_remaining_in_packet)); + + to_be_sent = outstanding_data_.GetChunksToBeRetransmitted(max_bytes); + max_bytes -= absl::c_accumulate(to_be_sent, 0, + [&](size_t r, const std::pair<TSN, Data>& d) { + return r + GetSerializedChunkSize(d.second); + }); + + while (max_bytes > data_chunk_header_size_) { + RTC_DCHECK(IsDivisibleBy4(max_bytes)); + absl::optional<SendQueue::DataToSend> chunk_opt = + send_queue_.Produce(now, max_bytes - data_chunk_header_size_); + if (!chunk_opt.has_value()) { + break; + } + + size_t chunk_size = GetSerializedChunkSize(chunk_opt->data); + max_bytes -= chunk_size; + rwnd_ -= chunk_size; + + absl::optional<UnwrappedTSN> tsn = outstanding_data_.Insert( + chunk_opt->data, now, + partial_reliability_ ? chunk_opt->max_retransmissions + : MaxRetransmits::NoLimit(), + partial_reliability_ ? chunk_opt->expires_at : TimeMs::InfiniteFuture(), + chunk_opt->lifecycle_id); + + if (tsn.has_value()) { + if (chunk_opt->lifecycle_id.IsSet()) { + RTC_DCHECK(chunk_opt->data.is_end); + callbacks_.OnLifecycleMessageFullySent(chunk_opt->lifecycle_id); + } + to_be_sent.emplace_back(tsn->Wrap(), std::move(chunk_opt->data)); + } + } + + if (!to_be_sent.empty()) { + // https://tools.ietf.org/html/rfc4960#section-6.3.2 + // "Every time a DATA chunk is sent to any address (including a + // retransmission), if the T3-rtx timer of that address is not running, + // start it running so that it will expire after the RTO of that address." + if (!t3_rtx_.is_running()) { + t3_rtx_.Start(); + } + RTC_DLOG(LS_VERBOSE) << log_prefix_ << "Sending TSN " + << StrJoin(to_be_sent, ",", + [&](rtc::StringBuilder& sb, + const std::pair<TSN, Data>& c) { + sb << *c.first; + }) + << " - " + << absl::c_accumulate( + to_be_sent, 0, + [&](size_t r, const std::pair<TSN, Data>& d) { + return r + GetSerializedChunkSize(d.second); + }) + << " bytes. outstanding_bytes=" << outstanding_bytes() + << " (" << old_outstanding_bytes << "), cwnd=" << cwnd_ + << ", rwnd=" << rwnd_ << " (" << old_rwnd << ")"; + } + RTC_DCHECK(IsConsistent()); + return to_be_sent; +} + +bool RetransmissionQueue::can_send_data() const { + return cwnd_ < options_.avoid_fragmentation_cwnd_mtus * options_.mtu || + max_bytes_to_send() >= min_bytes_required_to_send_; +} + +bool RetransmissionQueue::ShouldSendForwardTsn(TimeMs now) { + if (!partial_reliability_) { + return false; + } + outstanding_data_.ExpireOutstandingChunks(now); + bool ret = outstanding_data_.ShouldSendForwardTsn(); + RTC_DCHECK(IsConsistent()); + return ret; +} + +size_t RetransmissionQueue::max_bytes_to_send() const { + size_t left = outstanding_bytes() >= cwnd_ ? 0 : cwnd_ - outstanding_bytes(); + + if (outstanding_bytes() == 0) { + // https://datatracker.ietf.org/doc/html/rfc4960#section-6.1 + // ... However, regardless of the value of rwnd (including if it is 0), the + // data sender can always have one DATA chunk in flight to the receiver if + // allowed by cwnd (see rule B, below). + return left; + } + + return std::min(rwnd(), left); +} + +void RetransmissionQueue::PrepareResetStream(StreamID stream_id) { + // TODO(boivie): These calls are now only affecting the send queue. The + // packet buffer can also change behavior - for example draining the chunk + // producer and eagerly assign TSNs so that an "Outgoing SSN Reset Request" + // can be sent quickly, with a known `sender_last_assigned_tsn`. + send_queue_.PrepareResetStream(stream_id); +} +bool RetransmissionQueue::HasStreamsReadyToBeReset() const { + return send_queue_.HasStreamsReadyToBeReset(); +} +void RetransmissionQueue::CommitResetStreams() { + send_queue_.CommitResetStreams(); +} +void RetransmissionQueue::RollbackResetStreams() { + send_queue_.RollbackResetStreams(); +} + +HandoverReadinessStatus RetransmissionQueue::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (!outstanding_data_.empty()) { + status.Add(HandoverUnreadinessReason::kRetransmissionQueueOutstandingData); + } + if (fast_recovery_exit_tsn_.has_value()) { + status.Add(HandoverUnreadinessReason::kRetransmissionQueueFastRecovery); + } + if (outstanding_data_.has_data_to_be_retransmitted()) { + status.Add(HandoverUnreadinessReason::kRetransmissionQueueNotEmpty); + } + return status; +} + +void RetransmissionQueue::AddHandoverState(DcSctpSocketHandoverState& state) { + state.tx.next_tsn = next_tsn().value(); + state.tx.rwnd = rwnd_; + state.tx.cwnd = cwnd_; + state.tx.ssthresh = ssthresh_; + state.tx.partial_bytes_acked = partial_bytes_acked_; +} + +void RetransmissionQueue::RestoreFromState( + const DcSctpSocketHandoverState& state) { + // Validate that the component is in pristine state. + RTC_DCHECK(outstanding_data_.empty()); + RTC_DCHECK(!t3_rtx_.is_running()); + RTC_DCHECK(partial_bytes_acked_ == 0); + + cwnd_ = state.tx.cwnd; + rwnd_ = state.tx.rwnd; + ssthresh_ = state.tx.ssthresh; + partial_bytes_acked_ = state.tx.partial_bytes_acked; + + outstanding_data_.ResetSequenceNumbers( + tsn_unwrapper_.Unwrap(TSN(state.tx.next_tsn)), + tsn_unwrapper_.Unwrap(TSN(state.tx.next_tsn - 1))); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.h b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.h new file mode 100644 index 0000000000..830c0b346d --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue.h @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RETRANSMISSION_QUEUE_H_ +#define NET_DCSCTP_TX_RETRANSMISSION_QUEUE_H_ + +#include <cstdint> +#include <functional> +#include <map> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/sequence_numbers.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_handover_state.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/outstanding_data.h" +#include "net/dcsctp/tx/retransmission_timeout.h" +#include "net/dcsctp/tx/send_queue.h" + +namespace dcsctp { + +// The RetransmissionQueue manages all DATA/I-DATA chunks that are in-flight and +// schedules them to be retransmitted if necessary. Chunks are retransmitted +// when they have been lost for a number of consecutive SACKs, or when the +// retransmission timer, `t3_rtx` expires. +// +// As congestion control is tightly connected with the state of transmitted +// packets, that's also managed here to limit the amount of data that is +// in-flight (sent, but not yet acknowledged). +class RetransmissionQueue { + public: + static constexpr size_t kMinimumFragmentedPayload = 10; + using State = OutstandingData::State; + // Creates a RetransmissionQueue which will send data using `my_initial_tsn` + // (or a value from `DcSctpSocketHandoverState` if given) as the first TSN + // to use for sent fragments. It will poll data from `send_queue`. When SACKs + // are received, it will estimate the RTT, and call `on_new_rtt`. When an + // outstanding chunk has been ACKed, it will call + // `on_clear_retransmission_counter` and will also use `t3_rtx`, which is the + // SCTP retransmission timer to manage retransmissions. + RetransmissionQueue(absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, + TSN my_initial_tsn, + size_t a_rwnd, + SendQueue& send_queue, + std::function<void(DurationMs rtt)> on_new_rtt, + std::function<void()> on_clear_retransmission_counter, + Timer& t3_rtx, + const DcSctpOptions& options, + bool supports_partial_reliability = true, + bool use_message_interleaving = false); + + // Handles a received SACK. Returns true if the `sack` was processed and + // false if it was discarded due to received out-of-order and not relevant. + bool HandleSack(TimeMs now, const SackChunk& sack); + + // Handles an expired retransmission timer. + void HandleT3RtxTimerExpiry(); + + bool has_data_to_be_fast_retransmitted() const { + return outstanding_data_.has_data_to_be_fast_retransmitted(); + } + + // Returns a list of chunks to "fast retransmit" that would fit in one SCTP + // packet with `bytes_in_packet` bytes available. The current value + // of `cwnd` is ignored. + std::vector<std::pair<TSN, Data>> GetChunksForFastRetransmit( + size_t bytes_in_packet); + + // Returns a list of chunks to send that would fit in one SCTP packet with + // `bytes_remaining_in_packet` bytes available. This may be further limited by + // the congestion control windows. Note that `ShouldSendForwardTSN` must be + // called prior to this method, to abandon expired chunks, as this method will + // not expire any chunks. + std::vector<std::pair<TSN, Data>> GetChunksToSend( + TimeMs now, + size_t bytes_remaining_in_packet); + + // Returns the internal state of all queued chunks. This is only used in + // unit-tests. + std::vector<std::pair<TSN, OutstandingData::State>> GetChunkStatesForTesting() + const { + return outstanding_data_.GetChunkStatesForTesting(); + } + + // Returns the next TSN that will be allocated for sent DATA chunks. + TSN next_tsn() const { return outstanding_data_.next_tsn().Wrap(); } + + // Returns the size of the congestion window, in bytes. This is the number of + // bytes that may be in-flight. + size_t cwnd() const { return cwnd_; } + + // Overrides the current congestion window size. + void set_cwnd(size_t cwnd) { cwnd_ = cwnd; } + + // Returns the current receiver window size. + size_t rwnd() const { return rwnd_; } + + // Returns the number of bytes of packets that are in-flight. + size_t outstanding_bytes() const { + return outstanding_data_.outstanding_bytes(); + } + + // Returns the number of DATA chunks that are in-flight. + size_t outstanding_items() const { + return outstanding_data_.outstanding_items(); + } + + // Indicates if the congestion control algorithm allows data to be sent. + bool can_send_data() const; + + // Given the current time `now`, it will evaluate if there are chunks that + // have expired and that need to be discarded. It returns true if a + // FORWARD-TSN should be sent. + bool ShouldSendForwardTsn(TimeMs now); + + // Creates a FORWARD-TSN chunk. + ForwardTsnChunk CreateForwardTsn() const { + return outstanding_data_.CreateForwardTsn(); + } + + // Creates an I-FORWARD-TSN chunk. + IForwardTsnChunk CreateIForwardTsn() const { + return outstanding_data_.CreateIForwardTsn(); + } + + // See the SendQueue for a longer description of these methods related + // to stream resetting. + void PrepareResetStream(StreamID stream_id); + bool HasStreamsReadyToBeReset() const; + std::vector<StreamID> GetStreamsReadyToBeReset() const { + return send_queue_.GetStreamsReadyToBeReset(); + } + void CommitResetStreams(); + void RollbackResetStreams(); + + HandoverReadinessStatus GetHandoverReadiness() const; + + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); + + private: + enum class CongestionAlgorithmPhase { + kSlowStart, + kCongestionAvoidance, + }; + + bool IsConsistent() const; + + // Returns how large a chunk will be, serialized, carrying the data + size_t GetSerializedChunkSize(const Data& data) const; + + // Indicates if the congestion control algorithm is in "fast recovery". + bool is_in_fast_recovery() const { + return fast_recovery_exit_tsn_.has_value(); + } + + // Indicates if the provided SACK is valid given what has previously been + // received. If it returns false, the SACK is most likely a duplicate of + // something already seen, so this returning false doesn't necessarily mean + // that the SACK is illegal. + bool IsSackValid(const SackChunk& sack) const; + + // When a SACK chunk is received, this method will be called which _may_ call + // into the `RetransmissionTimeout` to update the RTO. + void UpdateRTT(TimeMs now, UnwrappedTSN cumulative_tsn_ack); + + // If the congestion control is in "fast recovery mode", this may be exited + // now. + void MaybeExitFastRecovery(UnwrappedTSN cumulative_tsn_ack); + + // If chunks have been ACKed, stop the retransmission timer. + void StopT3RtxTimerOnIncreasedCumulativeTsnAck( + UnwrappedTSN cumulative_tsn_ack); + + // Update the congestion control algorithm given as the cumulative ack TSN + // value has increased, as reported in an incoming SACK chunk. + void HandleIncreasedCumulativeTsnAck(size_t outstanding_bytes, + size_t total_bytes_acked); + // Update the congestion control algorithm, given as packet loss has been + // detected, as reported in an incoming SACK chunk. + void HandlePacketLoss(UnwrappedTSN highest_tsn_acked); + // Update the view of the receiver window size. + void UpdateReceiverWindow(uint32_t a_rwnd); + // If there is data sent and not ACKED, ensure that the retransmission timer + // is running. + void StartT3RtxTimerIfOutstandingData(); + + // Returns the current congestion control algorithm phase. + CongestionAlgorithmPhase phase() const { + return (cwnd_ <= ssthresh_) + ? CongestionAlgorithmPhase::kSlowStart + : CongestionAlgorithmPhase::kCongestionAvoidance; + } + + // Returns the number of bytes that may be sent in a single packet according + // to the congestion control algorithm. + size_t max_bytes_to_send() const; + + DcSctpSocketCallbacks& callbacks_; + const DcSctpOptions options_; + // The minimum bytes required to be available in the congestion window to + // allow packets to be sent - to avoid sending too small packets. + const size_t min_bytes_required_to_send_; + // If the peer supports RFC3758 - SCTP Partial Reliability Extension. + const bool partial_reliability_; + const std::string log_prefix_; + // The size of the data chunk (DATA/I-DATA) header that is used. + const size_t data_chunk_header_size_; + // Called when a new RTT measurement has been done + const std::function<void(DurationMs rtt)> on_new_rtt_; + // Called when a SACK has been seen that cleared the retransmission counter. + const std::function<void()> on_clear_retransmission_counter_; + // The retransmission counter. + Timer& t3_rtx_; + // Unwraps TSNs + UnwrappedTSN::Unwrapper tsn_unwrapper_; + + // Congestion Window. Number of bytes that may be in-flight (sent, not acked). + size_t cwnd_; + // Receive Window. Number of bytes available in the receiver's RX buffer. + size_t rwnd_; + // Slow Start Threshold. See RFC4960. + size_t ssthresh_; + // Partial Bytes Acked. See RFC4960. + size_t partial_bytes_acked_; + // If set, fast recovery is enabled until this TSN has been cumulative + // acked. + absl::optional<UnwrappedTSN> fast_recovery_exit_tsn_ = absl::nullopt; + + // The send queue. + SendQueue& send_queue_; + // All the outstanding data chunks that are in-flight and that have not been + // cumulative acked. Note that it also contains chunks that have been acked in + // gap ack blocks. + OutstandingData outstanding_data_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RETRANSMISSION_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue_test.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue_test.cc new file mode 100644 index 0000000000..e62c030bfa --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_queue_test.cc @@ -0,0 +1,1593 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_queue.h" + +#include <cstddef> +#include <cstdint> +#include <functional> +#include <memory> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "api/task_queue/task_queue_base.h" +#include "net/dcsctp/common/handover_testing.h" +#include "net/dcsctp/common/math.h" +#include "net/dcsctp/packet/chunk/data_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/forward_tsn_common.h" +#include "net/dcsctp/packet/chunk/iforward_tsn_chunk.h" +#include "net/dcsctp/packet/chunk/sack_chunk.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/data_generator.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/timer/fake_timeout.h" +#include "net/dcsctp/timer/timer.h" +#include "net/dcsctp/tx/mock_send_queue.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::MockFunction; +using State = ::dcsctp::RetransmissionQueue::State; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::NiceMock; +using ::testing::Pair; +using ::testing::Return; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr uint32_t kArwnd = 100000; +constexpr uint32_t kMaxMtu = 1191; + +DcSctpOptions MakeOptions() { + DcSctpOptions options; + options.mtu = kMaxMtu; + return options; +} + +class RetransmissionQueueTest : public testing::Test { + protected: + RetransmissionQueueTest() + : options_(MakeOptions()), + gen_(MID(42)), + timeout_manager_([this]() { return now_; }), + timer_manager_([this](webrtc::TaskQueueBase::DelayPrecision precision) { + return timeout_manager_.CreateTimeout(precision); + }), + timer_(timer_manager_.CreateTimer( + "test/t3_rtx", + []() { return absl::nullopt; }, + TimerOptions(options_.rto_initial))) {} + + std::function<SendQueue::DataToSend(TimeMs, size_t)> CreateChunk() { + return [this](TimeMs now, size_t max_size) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }; + } + + std::vector<TSN> GetTSNsForFastRetransmit(RetransmissionQueue& queue) { + std::vector<TSN> tsns; + for (const auto& elem : queue.GetChunksForFastRetransmit(10000)) { + tsns.push_back(elem.first); + } + return tsns; + } + + std::vector<TSN> GetSentPacketTSNs(RetransmissionQueue& queue) { + std::vector<TSN> tsns; + for (const auto& elem : queue.GetChunksToSend(now_, 10000)) { + tsns.push_back(elem.first); + } + return tsns; + } + + RetransmissionQueue CreateQueue(bool supports_partial_reliability = true, + bool use_message_interleaving = false) { + return RetransmissionQueue( + "", &callbacks_, TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), + on_clear_retransmission_counter_.AsStdFunction(), *timer_, options_, + supports_partial_reliability, use_message_interleaving); + } + + std::unique_ptr<RetransmissionQueue> CreateQueueByHandover( + RetransmissionQueue& queue) { + EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus()); + DcSctpSocketHandoverState state; + queue.AddHandoverState(state); + g_handover_state_transformer_for_test(&state); + auto queue2 = std::make_unique<RetransmissionQueue>( + "", &callbacks_, TSN(10), kArwnd, producer_, on_rtt_.AsStdFunction(), + on_clear_retransmission_counter_.AsStdFunction(), *timer_, options_, + /*supports_partial_reliability=*/true, + /*use_message_interleaving=*/false); + queue2->RestoreFromState(state); + return queue2; + } + + MockDcSctpSocketCallbacks callbacks_; + DcSctpOptions options_; + DataGenerator gen_; + TimeMs now_ = TimeMs(0); + FakeTimeoutManager timeout_manager_; + TimerManager timer_manager_; + NiceMock<MockFunction<void(DurationMs rtt_ms)>> on_rtt_; + NiceMock<MockFunction<void()>> on_clear_retransmission_counter_; + NiceMock<MockSendQueue> producer_; + std::unique_ptr<Timer> timer_; +}; + +TEST_F(RetransmissionQueueTest, InitialAckedPrevTsn) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, SendOneChunk) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(10))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, SendOneChunkAndAck) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(10))); + + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, SendThreeChunksAndAckTwo) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12))); + + queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, AckWithGapBlocksFromRFC4960Section334) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 5)}, + {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kNacked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kNacked), // + Pair(TSN(17), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, ResendPacketsWhenNackedThreeTimes) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + // Send more chunks, but leave some as gaps to force retransmission after + // three NACKs. + + // Send 18 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(18))); + + // Ack 12, 14-15, 17-18 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 6)}, + {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kNacked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kNacked), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked))); + + // Send 19 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(19))); + + // Ack 12, 14-15, 17-19 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 7)}, + {})); + + // Send 20 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(20))); + + // Ack 12, 14-15, 17-20 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 8)}, + {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kToBeRetransmitted), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kToBeRetransmitted), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kAcked), // + Pair(TSN(20), State::kAcked))); + + // This will trigger "fast retransmit" mode and only chunks 13 and 16 will be + // resent right now. The send queue will not even be queried. + EXPECT_CALL(producer_, Produce).Times(0); + + EXPECT_THAT(GetTSNsForFastRetransmit(queue), + testing::ElementsAre(TSN(13), TSN(16))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kAcked), // + Pair(TSN(20), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, RestartsT3RtxOnRetransmitFirstOutstandingTSN) { + // Verifies that if fast retransmit is retransmitting the first outstanding + // TSN, it will also restart T3-RTX. + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + static constexpr TimeMs kStartTime(100000); + now_ = kStartTime; + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12))); + + // Ack 10, 12, after 100ms. + now_ += DurationMs(100); + queue.HandleSack( + now_, SackChunk(TSN(10), kArwnd, {SackChunk::GapAckBlock(2, 2)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kNacked), // + Pair(TSN(12), State::kAcked))); + + // Send 13 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(13))); + + // Ack 10, 12-13, after 100ms. + now_ += DurationMs(100); + queue.HandleSack( + now_, SackChunk(TSN(10), kArwnd, {SackChunk::GapAckBlock(2, 3)}, {})); + + // Send 14 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), testing::ElementsAre(TSN(14))); + + // Ack 10, 12-14, after 100 ms. + now_ += DurationMs(100); + queue.HandleSack( + now_, SackChunk(TSN(10), kArwnd, {SackChunk::GapAckBlock(2, 4)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kToBeRetransmitted), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked))); + + // This will trigger "fast retransmit" mode and only chunks 13 and 16 will be + // resent right now. The send queue will not even be queried. + EXPECT_CALL(producer_, Produce).Times(0); + + EXPECT_THAT(GetTSNsForFastRetransmit(queue), testing::ElementsAre(TSN(11))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked))); + + // Verify that the timer was really restarted when fast-retransmitting. The + // timeout is `options_.rto_initial`, so advance the time just before that. + now_ += options_.rto_initial - DurationMs(1); + EXPECT_FALSE(timeout_manager_.GetNextExpiredTimeout().has_value()); + + // And ensure it really is running. + now_ += DurationMs(1); + ASSERT_HAS_VALUE_AND_ASSIGN(TimeoutID timeout, + timeout_manager_.GetNextExpiredTimeout()); + // An expired timeout has to be handled (asserts validate this). + timer_manager_.HandleTimeout(timeout); +} + +TEST_F(RetransmissionQueueTest, CanOnlyProduceTwoPacketsButWantsToSendThree) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }) + .WillOnce([this](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, RetransmitsOnT3Expiry) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered({1, 2, 3, 4}, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + // Will force chunks to be retransmitted + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + + std::vector<std::pair<TSN, Data>> chunks_to_rtx = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_rtx, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, LimitedRetransmissionOnlyWithRfc3758Support) { + RetransmissionQueue queue = + CreateQueue(/*supports_partial_reliability=*/false); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + // Will force chunks to be retransmitted + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); +} // namespace dcsctp + +TEST_F(RetransmissionQueueTest, LimitsRetransmissionsAsUdp) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + // Will force chunks to be retransmitted + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(1); + + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned))); + + std::vector<std::pair<TSN, Data>> chunks_to_rtx = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_rtx, testing::IsEmpty()); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned))); +} + +TEST_F(RetransmissionQueueTest, LimitsRetransmissionsToThreeSends) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(3); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + + // Retransmission 1 + queue.HandleT3RtxTimerExpiry(); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), SizeIs(1)); + + // Retransmission 2 + queue.HandleT3RtxTimerExpiry(); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), SizeIs(1)); + + // Retransmission 3 + queue.HandleT3RtxTimerExpiry(); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), SizeIs(1)); + + // Retransmission 4 - not allowed. + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(1); + queue.HandleT3RtxTimerExpiry(); + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), IsEmpty()); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned))); +} + +TEST_F(RetransmissionQueueTest, RetransmitsWhenSendBufferIsFullT3Expiry) { + RetransmissionQueue queue = CreateQueue(); + static constexpr size_t kCwnd = 1200; + queue.set_cwnd(kCwnd); + EXPECT_EQ(queue.cwnd(), kCwnd); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + std::vector<uint8_t> payload(1000); + EXPECT_CALL(producer_, Produce) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1500); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + EXPECT_EQ(queue.outstanding_bytes(), payload.size() + DataChunk::kHeaderSize); + EXPECT_EQ(queue.outstanding_items(), 1u); + + // Will force chunks to be retransmitted + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted))); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + std::vector<std::pair<TSN, Data>> chunks_to_rtx = + queue.GetChunksToSend(now_, 1500); + EXPECT_THAT(chunks_to_rtx, ElementsAre(Pair(TSN(10), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight))); + EXPECT_EQ(queue.outstanding_bytes(), payload.size() + DataChunk::kHeaderSize); + EXPECT_EQ(queue.outstanding_items(), 1u); +} + +TEST_F(RetransmissionQueueTest, ProducesValidForwardTsn) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Send and ack first chunk (TSN 10) + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight))); + + // Chunk 10 is acked, but the remaining are lost + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(true)); + + queue.HandleT3RtxTimerExpiry(); + + // NOTE: The TSN=13 represents the end fragment. + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + ForwardTsnChunk forward_tsn = queue.CreateForwardTsn(); + EXPECT_EQ(forward_tsn.new_cumulative_tsn(), TSN(13)); + EXPECT_THAT(forward_tsn.skipped_streams(), + UnorderedElementsAre( + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(42)))); +} + +TEST_F(RetransmissionQueueTest, ProducesValidForwardTsnWhenFullySent) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "E")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Send and ack first chunk (TSN 10) + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight))); + + // Chunk 10 is acked, but the remaining are lost + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + ForwardTsnChunk forward_tsn = queue.CreateForwardTsn(); + EXPECT_EQ(forward_tsn.new_cumulative_tsn(), TSN(12)); + EXPECT_THAT(forward_tsn.skipped_streams(), + UnorderedElementsAre( + ForwardTsnChunk::SkippedStream(StreamID(1), SSN(42)))); +} + +TEST_F(RetransmissionQueueTest, ProducesValidIForwardTsn) { + RetransmissionQueue queue = CreateQueue(/*use_message_interleaving=*/true); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(1); + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B", opts)); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(2); + SendQueue::DataToSend dts(gen_.Unordered({1, 2, 3, 4}, "B", opts)); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(3); + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "B", opts)); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + DataGeneratorOptions opts; + opts.stream_id = StreamID(4); + SendQueue::DataToSend dts(gen_.Ordered({13, 14, 15, 16}, "B", opts)); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _), Pair(TSN(13), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + // Chunk 13 is acked, but the remaining are lost + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(4, 4)}, {})); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kNacked), // + Pair(TSN(12), State::kNacked), // + Pair(TSN(13), State::kAcked))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(true)); + EXPECT_CALL(producer_, Discard(IsUnordered(true), StreamID(2), MID(42))) + .WillOnce(Return(true)); + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(3), MID(42))) + .WillOnce(Return(true)); + + queue.HandleT3RtxTimerExpiry(); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned), // + Pair(TSN(13), State::kAcked), + // Representing end fragments of stream 1-3 + Pair(TSN(14), State::kAbandoned), // + Pair(TSN(15), State::kAbandoned), // + Pair(TSN(16), State::kAbandoned))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + IForwardTsnChunk forward_tsn1 = queue.CreateIForwardTsn(); + EXPECT_EQ(forward_tsn1.new_cumulative_tsn(), TSN(12)); + EXPECT_THAT( + forward_tsn1.skipped_streams(), + UnorderedElementsAre(IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(1), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(true), StreamID(2), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(3), MID(42)))); + + // When TSN 13 is acked, the placeholder "end fragments" must be skipped as + // well. + + // A receiver is more likely to ack TSN 13, but do it incrementally. + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {})); + + EXPECT_CALL(producer_, Discard).Times(0); + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + queue.HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {})); + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAbandoned), // + Pair(TSN(15), State::kAbandoned), // + Pair(TSN(16), State::kAbandoned))); + + IForwardTsnChunk forward_tsn2 = queue.CreateIForwardTsn(); + EXPECT_EQ(forward_tsn2.new_cumulative_tsn(), TSN(16)); + EXPECT_THAT( + forward_tsn2.skipped_streams(), + UnorderedElementsAre(IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(1), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(true), StreamID(2), MID(42)), + IForwardTsnChunk::SkippedStream( + IsUnordered(false), StreamID(3), MID(42)))); +} + +TEST_F(RetransmissionQueueTest, MeasureRTT) { + RetransmissionQueue queue = CreateQueue(/*use_message_interleaving=*/true); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + now_ = now_ + DurationMs(123); + + EXPECT_CALL(on_rtt_, Call(DurationMs(123))).Times(1); + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); +} + +TEST_F(RetransmissionQueueTest, ValidateCumTsnAtRest) { + RetransmissionQueue queue = CreateQueue(/*use_message_interleaving=*/true); + + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(8), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(9), kArwnd, {}, {}))); + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {}))); +} + +TEST_F(RetransmissionQueueTest, ValidateCumTsnAckOnInflightData) { + RetransmissionQueue queue = CreateQueue(); + + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(8), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(9), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(14), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(15), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(16), kArwnd, {}, {}))); + EXPECT_TRUE(queue.HandleSack(now_, SackChunk(TSN(17), kArwnd, {}, {}))); + EXPECT_FALSE(queue.HandleSack(now_, SackChunk(TSN(18), kArwnd, {}, {}))); +} + +TEST_F(RetransmissionQueueTest, HandleGapAckBlocksMatchingNoInflightData) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + // Ack 9, 20-25. This is an invalid SACK, but should still be handled. + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(11, 16)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, HandleInvalidGapAckBlocks) { + RetransmissionQueue queue = CreateQueue(); + + // Nothing produced - nothing in retransmission queue + + // Ack 9, 12-13 + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(3, 4)}, {})); + + // Gap ack blocks are just ignore. + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked))); +} + +TEST_F(RetransmissionQueueTest, GapAckBlocksDoNotMoveCumTsnAck) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), + testing::ElementsAre(TSN(10), TSN(11), TSN(12), TSN(13), TSN(14), + TSN(15), TSN(16), TSN(17))); + + // Ack 9, 10-14. This is actually an invalid ACK as the first gap can't be + // adjacent to the cum-tsn-ack, but it's not strictly forbidden. However, the + // cum-tsn-ack should not move, as the gap-ack-blocks are just advisory. + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(1, 5)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAcked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, StaysWithinAvailableSize) { + RetransmissionQueue queue = CreateQueue(); + + // See SctpPacketTest::ReturnsCorrectSpaceAvailableToStayWithinMTU for the + // magic numbers in this test. + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t size) { + EXPECT_EQ(size, 1176 - DataChunk::kHeaderSize); + + std::vector<uint8_t> payload(183); + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }) + .WillOnce([this](TimeMs, size_t size) { + EXPECT_EQ(size, 976 - DataChunk::kHeaderSize); + + std::vector<uint8_t> payload(957); + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1188 - 12); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _))); +} + +TEST_F(RetransmissionQueueTest, AccountsNackedAbandonedChunksAsNotOutstanding) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({9, 10, 11, 12}, "")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Send and ack first chunk (TSN 10) + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight))); + EXPECT_EQ(queue.outstanding_bytes(), (16 + 4) * 3u); + EXPECT_EQ(queue.outstanding_items(), 3u); + + // Mark the message as lost. + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(1); + queue.HandleT3RtxTimerExpiry(); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAbandoned), // + Pair(TSN(12), State::kAbandoned))); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + // Now ACK those, one at a time. + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {})); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); + + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, {}, {})); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.outstanding_items(), 0u); +} + +TEST_F(RetransmissionQueueTest, ExpireFromSendQueueWhenPartiallySent) { + RetransmissionQueue queue = CreateQueue(); + DataGeneratorOptions options; + options.stream_id = StreamID(17); + options.message_id = MID(42); + TimeMs test_start = now_; + EXPECT_CALL(producer_, Produce) + .WillOnce([&](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "B", options)); + dts.expires_at = TimeMs(test_start + DurationMs(10)); + return dts; + }) + .WillOnce([&](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({5, 6, 7, 8}, "", options)); + dts.expires_at = TimeMs(test_start + DurationMs(10)); + return dts; + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 24); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(17), MID(42))) + .WillOnce(Return(true)); + now_ += DurationMs(100); + + EXPECT_THAT(queue.GetChunksToSend(now_, 24), IsEmpty()); + + EXPECT_THAT( + queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // Initial TSN + Pair(TSN(10), State::kAbandoned), // Produced + Pair(TSN(11), State::kAbandoned), // Produced and expired + Pair(TSN(12), State::kAbandoned))); // Placeholder end +} + +TEST_F(RetransmissionQueueTest, LimitsRetransmissionsOnlyWhenNackedThreeTimes) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(0); + return dts; + }) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), + Pair(TSN(12), _), Pair(TSN(13), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 2)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 3)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 4)}, {})); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); +} + +TEST_F(RetransmissionQueueTest, AbandonsRtxLimit2WhenNackedNineTimes) { + // This is a fairly long test. + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce([this](TimeMs, size_t) { + SendQueue::DataToSend dts(gen_.Ordered({1, 2, 3, 4}, "BE")); + dts.max_retransmissions = MaxRetransmits(2); + return dts; + }) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1000); + EXPECT_THAT(chunks_to_send, + ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), Pair(TSN(12), _), + Pair(TSN(13), _), Pair(TSN(14), _), Pair(TSN(15), _), + Pair(TSN(16), _), Pair(TSN(17), _), Pair(TSN(18), _), + Pair(TSN(19), _))); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), // + Pair(TSN(13), State::kInFlight), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight), // + Pair(TSN(18), State::kInFlight), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .Times(0); + + // Ack TSN [11 to 13] - three nacks for TSN(10), which will retransmit it. + for (int tsn = 11; tsn <= 13; ++tsn) { + queue.HandleSack( + now_, + SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, (tsn - 9))}, {})); + } + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kInFlight), // + Pair(TSN(15), State::kInFlight), // + Pair(TSN(16), State::kInFlight), // + Pair(TSN(17), State::kInFlight), // + Pair(TSN(18), State::kInFlight), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_THAT(queue.GetChunksForFastRetransmit(1000), + ElementsAre(Pair(TSN(10), _))); + + // Ack TSN [14 to 16] - three more nacks - second and last retransmission. + for (int tsn = 14; tsn <= 16; ++tsn) { + queue.HandleSack( + now_, + SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, (tsn - 9))}, {})); + } + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kToBeRetransmitted), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kAcked), // + Pair(TSN(17), State::kInFlight), // + Pair(TSN(18), State::kInFlight), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), ElementsAre(Pair(TSN(10), _))); + + // Ack TSN [17 to 18] + for (int tsn = 17; tsn <= 18; ++tsn) { + queue.HandleSack( + now_, + SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, (tsn - 9))}, {})); + } + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kNacked), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kAcked), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kInFlight))); + + EXPECT_FALSE(queue.ShouldSendForwardTsn(now_)); + + // Ack TSN 19 - three more nacks for TSN 10, no more retransmissions. + EXPECT_CALL(producer_, Discard(IsUnordered(false), StreamID(1), MID(42))) + .WillOnce(Return(false)); + queue.HandleSack( + now_, SackChunk(TSN(9), kArwnd, {SackChunk::GapAckBlock(2, 10)}, {})); + + EXPECT_THAT(queue.GetChunksToSend(now_, 1000), IsEmpty()); + + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kAbandoned), // + Pair(TSN(11), State::kAcked), // + Pair(TSN(12), State::kAcked), // + Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kAcked), // + Pair(TSN(15), State::kAcked), // + Pair(TSN(16), State::kAcked), // + Pair(TSN(17), State::kAcked), // + Pair(TSN(18), State::kAcked), // + Pair(TSN(19), State::kAcked))); + + EXPECT_TRUE(queue.ShouldSendForwardTsn(now_)); +} + +TEST_F(RetransmissionQueueTest, CwndRecoversWhenAcking) { + RetransmissionQueue queue = CreateQueue(); + static constexpr size_t kCwnd = 1200; + queue.set_cwnd(kCwnd); + EXPECT_EQ(queue.cwnd(), kCwnd); + + std::vector<uint8_t> payload(1000); + EXPECT_CALL(producer_, Produce) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 1500); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + size_t serialized_size = payload.size() + DataChunk::kHeaderSize; + EXPECT_EQ(queue.outstanding_bytes(), serialized_size); + + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_EQ(queue.cwnd(), kCwnd + serialized_size); +} + +// Verifies that it doesn't produce tiny packets, when getting close to +// the full congestion window. +TEST_F(RetransmissionQueueTest, OnlySendsLargePacketsOnLargeCongestionWindow) { + RetransmissionQueue queue = CreateQueue(); + size_t intial_cwnd = options_.avoid_fragmentation_cwnd_mtus * options_.mtu; + queue.set_cwnd(intial_cwnd); + EXPECT_EQ(queue.cwnd(), intial_cwnd); + + // Fill the congestion window almost - leaving 500 bytes. + size_t chunk_size = intial_cwnd - 500; + EXPECT_CALL(producer_, Produce) + .WillOnce([chunk_size, this](TimeMs, size_t) { + return SendQueue::DataToSend( + gen_.Ordered(std::vector<uint8_t>(chunk_size), "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_TRUE(queue.can_send_data()); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 10000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + // To little space left - will not send more. + EXPECT_FALSE(queue.can_send_data()); + + // But when the first chunk is acked, it will continue. + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + + EXPECT_TRUE(queue.can_send_data()); + EXPECT_EQ(queue.outstanding_bytes(), 0u); + EXPECT_EQ(queue.cwnd(), intial_cwnd + kMaxMtu); +} + +TEST_F(RetransmissionQueueTest, AllowsSmallFragmentsOnSmallCongestionWindow) { + RetransmissionQueue queue = CreateQueue(); + size_t intial_cwnd = + options_.avoid_fragmentation_cwnd_mtus * options_.mtu - 1; + queue.set_cwnd(intial_cwnd); + EXPECT_EQ(queue.cwnd(), intial_cwnd); + + // Fill the congestion window almost - leaving 500 bytes. + size_t chunk_size = intial_cwnd - 500; + EXPECT_CALL(producer_, Produce) + .WillOnce([chunk_size, this](TimeMs, size_t) { + return SendQueue::DataToSend( + gen_.Ordered(std::vector<uint8_t>(chunk_size), "BE")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_TRUE(queue.can_send_data()); + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 10000); + EXPECT_THAT(chunks_to_send, ElementsAre(Pair(TSN(10), _))); + + // With congestion window under limit, allow small packets to be created. + EXPECT_TRUE(queue.can_send_data()); +} + +TEST_F(RetransmissionQueueTest, ReadyForHandoverWhenHasNoOutstandingData) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1)); + EXPECT_EQ( + queue.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kRetransmissionQueueOutstandingData)); + + queue.HandleSack(now_, SackChunk(TSN(10), kArwnd, {}, {})); + EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus()); +} + +TEST_F(RetransmissionQueueTest, ReadyForHandoverWhenNothingToRetransmit) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(8)); + EXPECT_EQ( + queue.GetHandoverReadiness(), + HandoverReadinessStatus( + HandoverUnreadinessReason::kRetransmissionQueueOutstandingData)); + + // Send more chunks, but leave some chunks unacked to force retransmission + // after three NACKs. + + // Send 18 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1)); + + // Ack 12, 14-15, 17-18 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 6)}, + {})); + + // Send 19 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1)); + + // Ack 12, 14-15, 17-19 + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 7)}, + {})); + + // Send 20 + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(1)); + + // Ack 12, 14-15, 17-20 + // This will trigger "fast retransmit" mode and only chunks 13 and 16 will be + // resent right now. The send queue will not even be queried. + queue.HandleSack(now_, SackChunk(TSN(12), kArwnd, + {SackChunk::GapAckBlock(2, 3), + SackChunk::GapAckBlock(5, 8)}, + {})); + EXPECT_EQ( + queue.GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kRetransmissionQueueOutstandingData) + .Add(HandoverUnreadinessReason::kRetransmissionQueueFastRecovery) + .Add(HandoverUnreadinessReason::kRetransmissionQueueNotEmpty)); + + // Send "fast retransmit" mode chunks + EXPECT_CALL(producer_, Produce).Times(0); + EXPECT_THAT(GetTSNsForFastRetransmit(queue), SizeIs(2)); + EXPECT_EQ( + queue.GetHandoverReadiness(), + HandoverReadinessStatus() + .Add(HandoverUnreadinessReason::kRetransmissionQueueOutstandingData) + .Add(HandoverUnreadinessReason::kRetransmissionQueueFastRecovery)); + + // Ack 20 to confirm the retransmission + queue.HandleSack(now_, SackChunk(TSN(20), kArwnd, {}, {})); + EXPECT_EQ(queue.GetHandoverReadiness(), HandoverReadinessStatus()); +} + +TEST_F(RetransmissionQueueTest, HandoverTest) { + RetransmissionQueue queue = CreateQueue(); + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(queue), SizeIs(2)); + queue.HandleSack(now_, SackChunk(TSN(11), kArwnd, {}, {})); + + std::unique_ptr<RetransmissionQueue> handedover_queue = + CreateQueueByHandover(queue); + + EXPECT_CALL(producer_, Produce) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillOnce(CreateChunk()) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + EXPECT_THAT(GetSentPacketTSNs(*handedover_queue), + testing::ElementsAre(TSN(12), TSN(13), TSN(14))); + + handedover_queue->HandleSack(now_, SackChunk(TSN(13), kArwnd, {}, {})); + EXPECT_THAT(handedover_queue->GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(13), State::kAcked), // + Pair(TSN(14), State::kInFlight))); +} + +TEST_F(RetransmissionQueueTest, CanAlwaysSendOnePacket) { + RetransmissionQueue queue = CreateQueue(); + + // A large payload - enough to not fit two DATA in same packet. + size_t mtu = RoundDownTo4(options_.mtu); + std::vector<uint8_t> payload(mtu - 100); + + EXPECT_CALL(producer_, Produce) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "B")); + }) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "")); + }) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "")); + }) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "")); + }) + .WillOnce([this, payload](TimeMs, size_t) { + return SendQueue::DataToSend(gen_.Ordered(payload, "E")); + }) + .WillRepeatedly([](TimeMs, size_t) { return absl::nullopt; }); + + // Produce all chunks and put them in the retransmission queue. + std::vector<std::pair<TSN, Data>> chunks_to_send = + queue.GetChunksToSend(now_, 5 * mtu); + EXPECT_THAT(chunks_to_send, + ElementsAre(Pair(TSN(10), _), Pair(TSN(11), _), Pair(TSN(12), _), + Pair(TSN(13), _), Pair(TSN(14), _))); + EXPECT_THAT(queue.GetChunkStatesForTesting(), + ElementsAre(Pair(TSN(9), State::kAcked), // + Pair(TSN(10), State::kInFlight), // + Pair(TSN(11), State::kInFlight), // + Pair(TSN(12), State::kInFlight), + Pair(TSN(13), State::kInFlight), + Pair(TSN(14), State::kInFlight))); + + // Ack 12, and report an empty receiver window (the peer obviously has a + // tiny receive window). + queue.HandleSack( + now_, SackChunk(TSN(9), /*rwnd=*/0, {SackChunk::GapAckBlock(3, 3)}, {})); + + // Force TSN 10 to be retransmitted. + queue.HandleT3RtxTimerExpiry(); + + // Even if the receiver window is empty, it will allow TSN 10 to be sent. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), ElementsAre(Pair(TSN(10), _))); + + // But not more than that, as there now is outstanding data. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), IsEmpty()); + + // Don't ack any new data, and still have receiver window zero. + queue.HandleSack( + now_, SackChunk(TSN(9), /*rwnd=*/0, {SackChunk::GapAckBlock(3, 3)}, {})); + + // There is in-flight data, so new data should not be allowed to be send since + // the receiver window is full. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), IsEmpty()); + + // Ack that packet (no more in-flight data), but still report an empty + // receiver window. + queue.HandleSack( + now_, SackChunk(TSN(10), /*rwnd=*/0, {SackChunk::GapAckBlock(2, 2)}, {})); + + // Then TSN 11 can be sent, as there is no in-flight data. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), ElementsAre(Pair(TSN(11), _))); + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), IsEmpty()); + + // Ack and recover the receiver window + queue.HandleSack(now_, SackChunk(TSN(12), /*rwnd=*/5 * mtu, {}, {})); + + // That will unblock sending remaining chunks. + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), ElementsAre(Pair(TSN(13), _))); + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), ElementsAre(Pair(TSN(14), _))); + EXPECT_THAT(queue.GetChunksToSend(now_, mtu), IsEmpty()); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.cc new file mode 100644 index 0000000000..aa2863f931 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.cc @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_timeout.h" + +#include <algorithm> +#include <cstdint> + +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +RetransmissionTimeout::RetransmissionTimeout(const DcSctpOptions& options) + : min_rto_(*options.rto_min), + max_rto_(*options.rto_max), + max_rtt_(*options.rtt_max), + min_rtt_variance_(*options.min_rtt_variance), + rto_(*options.rto_initial) {} + +void RetransmissionTimeout::ObserveRTT(DurationMs measured_rtt) { + const int32_t rtt = *measured_rtt; + + // Unrealistic values will be skipped. If a wrongly measured (or otherwise + // corrupt) value was processed, it could change the state in a way that would + // take a very long time to recover. + if (rtt < 0 || rtt > max_rtt_) { + return; + } + + // From https://tools.ietf.org/html/rfc4960#section-6.3.1, but avoiding + // floating point math by implementing algorithm from "V. Jacobson: Congestion + // avoidance and control", but adapted for SCTP. + if (first_measurement_) { + scaled_srtt_ = rtt << kRttShift; + scaled_rtt_var_ = (rtt / 2) << kRttVarShift; + first_measurement_ = false; + } else { + int32_t rtt_diff = rtt - (scaled_srtt_ >> kRttShift); + scaled_srtt_ += rtt_diff; + if (rtt_diff < 0) { + rtt_diff = -rtt_diff; + } + rtt_diff -= (scaled_rtt_var_ >> kRttVarShift); + scaled_rtt_var_ += rtt_diff; + } + + if (scaled_rtt_var_ < min_rtt_variance_) { + scaled_rtt_var_ = min_rtt_variance_; + } + + rto_ = (scaled_srtt_ >> kRttShift) + scaled_rtt_var_; + + // Clamp RTO between min and max. + rto_ = std::min(std::max(rto_, min_rto_), max_rto_); +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.h b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.h new file mode 100644 index 0000000000..7cbcc6fcc9 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RETRANSMISSION_TIMEOUT_H_ +#define NET_DCSCTP_TX_RETRANSMISSION_TIMEOUT_H_ + +#include <cstdint> +#include <functional> + +#include "net/dcsctp/public/dcsctp_options.h" + +namespace dcsctp { + +// Manages updating of the Retransmission Timeout (RTO) SCTP variable, which is +// used directly as the base timeout for T3-RTX and for other timers, such as +// delayed ack. +// +// When a round-trip-time (RTT) is calculated (outside this class), `Observe` +// is called, which calculates the retransmission timeout (RTO) value. The RTO +// value will become larger if the RTT is high and/or the RTT values are varying +// a lot, which is an indicator of a bad connection. +class RetransmissionTimeout { + public: + static constexpr int kRttShift = 3; + static constexpr int kRttVarShift = 2; + explicit RetransmissionTimeout(const DcSctpOptions& options); + + // To be called when a RTT has been measured, to update the RTO value. + void ObserveRTT(DurationMs measured_rtt); + + // Returns the Retransmission Timeout (RTO) value, in milliseconds. + DurationMs rto() const { return DurationMs(rto_); } + + // Returns the smoothed RTT value, in milliseconds. + DurationMs srtt() const { return DurationMs(scaled_srtt_ >> kRttShift); } + + private: + const int32_t min_rto_; + const int32_t max_rto_; + const int32_t max_rtt_; + const int32_t min_rtt_variance_; + // If this is the first measurement + bool first_measurement_ = true; + // Smoothed Round-Trip Time, shifted by kRttShift + int32_t scaled_srtt_ = 0; + // Round-Trip Time Variation, shifted by kRttVarShift + int32_t scaled_rtt_var_ = 0; + // Retransmission Timeout + int32_t rto_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RETRANSMISSION_TIMEOUT_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout_test.cc b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout_test.cc new file mode 100644 index 0000000000..f3b20a86ba --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/retransmission_timeout_test.cc @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/retransmission_timeout.h" + +#include "net/dcsctp/public/dcsctp_options.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { + +constexpr DurationMs kMaxRtt = DurationMs(8'000); +constexpr DurationMs kInitialRto = DurationMs(200); +constexpr DurationMs kMaxRto = DurationMs(800); +constexpr DurationMs kMinRto = DurationMs(120); +constexpr DurationMs kMinRttVariance = DurationMs(220); + +DcSctpOptions MakeOptions() { + DcSctpOptions options; + options.rtt_max = kMaxRtt; + options.rto_initial = kInitialRto; + options.rto_max = kMaxRto; + options.rto_min = kMinRto; + options.min_rtt_variance = kMinRttVariance; + return options; +} + +TEST(RetransmissionTimeoutTest, HasValidInitialRto) { + RetransmissionTimeout rto_(MakeOptions()); + EXPECT_EQ(rto_.rto(), kInitialRto); +} + +TEST(RetransmissionTimeoutTest, NegativeValuesDoNotAffectRTO) { + RetransmissionTimeout rto_(MakeOptions()); + // Initial negative value + rto_.ObserveRTT(DurationMs(-10)); + EXPECT_EQ(rto_.rto(), kInitialRto); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + // Subsequent negative value + rto_.ObserveRTT(DurationMs(-10)); + EXPECT_EQ(*rto_.rto(), 372); +} + +TEST(RetransmissionTimeoutTest, TooLargeValuesDoNotAffectRTO) { + RetransmissionTimeout rto_(MakeOptions()); + // Initial too large value + rto_.ObserveRTT(kMaxRtt + DurationMs(100)); + EXPECT_EQ(rto_.rto(), kInitialRto); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + // Subsequent too large value + rto_.ObserveRTT(kMaxRtt + DurationMs(100)); + EXPECT_EQ(*rto_.rto(), 372); +} + +TEST(RetransmissionTimeoutTest, WillNeverGoBelowMinimumRto) { + RetransmissionTimeout rto_(MakeOptions()); + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(DurationMs(1)); + } + EXPECT_GE(rto_.rto(), kMinRto); +} + +TEST(RetransmissionTimeoutTest, WillNeverGoAboveMaximumRto) { + RetransmissionTimeout rto_(MakeOptions()); + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(kMaxRtt - DurationMs(1)); + // Adding jitter, which would make it RTO be well above RTT. + rto_.ObserveRTT(kMaxRtt - DurationMs(100)); + } + EXPECT_LE(rto_.rto(), kMaxRto); +} + +TEST(RetransmissionTimeoutTest, CalculatesRtoForStableRtt) { + RetransmissionTimeout rto_(MakeOptions()); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + rto_.ObserveRTT(DurationMs(128)); + EXPECT_EQ(*rto_.rto(), 344); + rto_.ObserveRTT(DurationMs(123)); + EXPECT_EQ(*rto_.rto(), 344); + rto_.ObserveRTT(DurationMs(125)); + EXPECT_EQ(*rto_.rto(), 344); + rto_.ObserveRTT(DurationMs(127)); + EXPECT_EQ(*rto_.rto(), 344); +} + +TEST(RetransmissionTimeoutTest, CalculatesRtoForUnstableRtt) { + RetransmissionTimeout rto_(MakeOptions()); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + rto_.ObserveRTT(DurationMs(402)); + EXPECT_EQ(*rto_.rto(), 622); + rto_.ObserveRTT(DurationMs(728)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(89)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(126)); + EXPECT_EQ(*rto_.rto(), 800); +} + +TEST(RetransmissionTimeoutTest, WillStabilizeAfterAWhile) { + RetransmissionTimeout rto_(MakeOptions()); + rto_.ObserveRTT(DurationMs(124)); + rto_.ObserveRTT(DurationMs(402)); + rto_.ObserveRTT(DurationMs(728)); + rto_.ObserveRTT(DurationMs(89)); + rto_.ObserveRTT(DurationMs(126)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 800); + rto_.ObserveRTT(DurationMs(122)); + EXPECT_EQ(*rto_.rto(), 710); + rto_.ObserveRTT(DurationMs(123)); + EXPECT_EQ(*rto_.rto(), 631); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 562); + rto_.ObserveRTT(DurationMs(122)); + EXPECT_EQ(*rto_.rto(), 505); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 454); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 410); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 372); + rto_.ObserveRTT(DurationMs(124)); + EXPECT_EQ(*rto_.rto(), 367); +} + +TEST(RetransmissionTimeoutTest, WillAlwaysStayAboveRTT) { + // In simulations, it's quite common to have a very stable RTT, and having an + // RTO at the same value will cause issues as expiry timers will be scheduled + // to be expire exactly when a packet is supposed to arrive. The RTO must be + // larger than the RTT. In non-simulated environments, this is a non-issue as + // any jitter will increase the RTO. + RetransmissionTimeout rto_(MakeOptions()); + + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(DurationMs(124)); + } + EXPECT_EQ(*rto_.rto(), 344); +} + +TEST(RetransmissionTimeoutTest, CanSpecifySmallerMinimumRttVariance) { + DcSctpOptions options = MakeOptions(); + options.min_rtt_variance = kMinRttVariance - DurationMs(100); + RetransmissionTimeout rto_(options); + + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(DurationMs(124)); + } + EXPECT_EQ(*rto_.rto(), 244); +} + +TEST(RetransmissionTimeoutTest, CanSpecifyLargerMinimumRttVariance) { + DcSctpOptions options = MakeOptions(); + options.min_rtt_variance = kMinRttVariance + DurationMs(100); + RetransmissionTimeout rto_(options); + + for (int i = 0; i < 1000; ++i) { + rto_.ObserveRTT(DurationMs(124)); + } + EXPECT_EQ(*rto_.rto(), 444); +} + +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc new file mode 100644 index 0000000000..b1812f0f8a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.cc @@ -0,0 +1,542 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/rr_send_queue.h" + +#include <cstdint> +#include <deque> +#include <limits> +#include <map> +#include <set> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +RRSendQueue::RRSendQueue(absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, + size_t buffer_size, + size_t mtu, + StreamPriority default_priority, + size_t total_buffered_amount_low_threshold) + : log_prefix_(std::string(log_prefix) + "fcfs: "), + callbacks_(*callbacks), + buffer_size_(buffer_size), + default_priority_(default_priority), + scheduler_(mtu), + total_buffered_amount_( + [this]() { callbacks_.OnTotalBufferedAmountLow(); }) { + total_buffered_amount_.SetLowThreshold(total_buffered_amount_low_threshold); +} + +size_t RRSendQueue::OutgoingStream::bytes_to_send_in_next_message() const { + if (pause_state_ == PauseState::kPaused || + pause_state_ == PauseState::kResetting) { + // The stream has paused (and there is no partially sent message). + return 0; + } + + if (items_.empty()) { + return 0; + } + + return items_.front().remaining_size; +} + +void RRSendQueue::OutgoingStream::AddHandoverState( + DcSctpSocketHandoverState::OutgoingStream& state) const { + state.next_ssn = next_ssn_.value(); + state.next_ordered_mid = next_ordered_mid_.value(); + state.next_unordered_mid = next_unordered_mid_.value(); + state.priority = *scheduler_stream_->priority(); +} + +bool RRSendQueue::IsConsistent() const { + std::set<StreamID> expected_active_streams; + std::set<StreamID> actual_active_streams = + scheduler_.ActiveStreamsForTesting(); + + size_t total_buffered_amount = 0; + for (const auto& [stream_id, stream] : streams_) { + total_buffered_amount += stream.buffered_amount().value(); + if (stream.bytes_to_send_in_next_message() > 0) { + expected_active_streams.emplace(stream_id); + } + } + if (expected_active_streams != actual_active_streams) { + auto fn = [&](rtc::StringBuilder& sb, const auto& p) { sb << *p; }; + RTC_DLOG(LS_ERROR) << "Active streams mismatch, is=[" + << StrJoin(actual_active_streams, ",", fn) + << "], expected=[" + << StrJoin(expected_active_streams, ",", fn) << "]"; + return false; + } + + return total_buffered_amount == total_buffered_amount_.value(); +} + +bool RRSendQueue::OutgoingStream::IsConsistent() const { + size_t bytes = 0; + for (const auto& item : items_) { + bytes += item.remaining_size; + } + return bytes == buffered_amount_.value(); +} + +void RRSendQueue::ThresholdWatcher::Decrease(size_t bytes) { + RTC_DCHECK(bytes <= value_); + size_t old_value = value_; + value_ -= bytes; + + if (old_value > low_threshold_ && value_ <= low_threshold_) { + on_threshold_reached_(); + } +} + +void RRSendQueue::ThresholdWatcher::SetLowThreshold(size_t low_threshold) { + // Betting on https://github.com/w3c/webrtc-pc/issues/2654 being accepted. + if (low_threshold_ < value_ && low_threshold >= value_) { + on_threshold_reached_(); + } + low_threshold_ = low_threshold; +} + +void RRSendQueue::OutgoingStream::Add(DcSctpMessage message, + MessageAttributes attributes) { + bool was_active = bytes_to_send_in_next_message() > 0; + buffered_amount_.Increase(message.payload().size()); + parent_.total_buffered_amount_.Increase(message.payload().size()); + items_.emplace_back(std::move(message), std::move(attributes)); + + if (!was_active) { + scheduler_stream_->MaybeMakeActive(); + } + + RTC_DCHECK(IsConsistent()); +} + +absl::optional<SendQueue::DataToSend> RRSendQueue::OutgoingStream::Produce( + TimeMs now, + size_t max_size) { + RTC_DCHECK(pause_state_ != PauseState::kPaused && + pause_state_ != PauseState::kResetting); + + while (!items_.empty()) { + Item& item = items_.front(); + DcSctpMessage& message = item.message; + + // Allocate Message ID and SSN when the first fragment is sent. + if (!item.message_id.has_value()) { + // Oops, this entire message has already expired. Try the next one. + if (item.attributes.expires_at <= now) { + HandleMessageExpired(item); + items_.pop_front(); + continue; + } + + MID& mid = + item.attributes.unordered ? next_unordered_mid_ : next_ordered_mid_; + item.message_id = mid; + mid = MID(*mid + 1); + } + if (!item.attributes.unordered && !item.ssn.has_value()) { + item.ssn = next_ssn_; + next_ssn_ = SSN(*next_ssn_ + 1); + } + + // Grab the next `max_size` fragment from this message and calculate flags. + rtc::ArrayView<const uint8_t> chunk_payload = + item.message.payload().subview(item.remaining_offset, max_size); + rtc::ArrayView<const uint8_t> message_payload = message.payload(); + Data::IsBeginning is_beginning(chunk_payload.data() == + message_payload.data()); + Data::IsEnd is_end((chunk_payload.data() + chunk_payload.size()) == + (message_payload.data() + message_payload.size())); + + StreamID stream_id = message.stream_id(); + PPID ppid = message.ppid(); + + // Zero-copy the payload if the message fits in a single chunk. + std::vector<uint8_t> payload = + is_beginning && is_end + ? std::move(message).ReleasePayload() + : std::vector<uint8_t>(chunk_payload.begin(), chunk_payload.end()); + + FSN fsn(item.current_fsn); + item.current_fsn = FSN(*item.current_fsn + 1); + buffered_amount_.Decrease(payload.size()); + parent_.total_buffered_amount_.Decrease(payload.size()); + + SendQueue::DataToSend chunk(Data(stream_id, item.ssn.value_or(SSN(0)), + item.message_id.value(), fsn, ppid, + std::move(payload), is_beginning, is_end, + item.attributes.unordered)); + chunk.max_retransmissions = item.attributes.max_retransmissions; + chunk.expires_at = item.attributes.expires_at; + chunk.lifecycle_id = + is_end ? item.attributes.lifecycle_id : LifecycleId::NotSet(); + + if (is_end) { + // The entire message has been sent, and its last data copied to `chunk`, + // so it can safely be discarded. + items_.pop_front(); + + if (pause_state_ == PauseState::kPending) { + RTC_DLOG(LS_VERBOSE) << "Pause state on " << *stream_id + << " is moving from pending to paused"; + pause_state_ = PauseState::kPaused; + } + } else { + item.remaining_offset += chunk_payload.size(); + item.remaining_size -= chunk_payload.size(); + RTC_DCHECK(item.remaining_offset + item.remaining_size == + item.message.payload().size()); + RTC_DCHECK(item.remaining_size > 0); + } + RTC_DCHECK(IsConsistent()); + return chunk; + } + RTC_DCHECK(IsConsistent()); + return absl::nullopt; +} + +void RRSendQueue::OutgoingStream::HandleMessageExpired( + OutgoingStream::Item& item) { + buffered_amount_.Decrease(item.remaining_size); + parent_.total_buffered_amount_.Decrease(item.remaining_size); + if (item.attributes.lifecycle_id.IsSet()) { + RTC_DLOG(LS_VERBOSE) << "Triggering OnLifecycleMessageExpired(" + << item.attributes.lifecycle_id.value() << ", false)"; + + parent_.callbacks_.OnLifecycleMessageExpired(item.attributes.lifecycle_id, + /*maybe_delivered=*/false); + parent_.callbacks_.OnLifecycleEnd(item.attributes.lifecycle_id); + } +} + +bool RRSendQueue::OutgoingStream::Discard(IsUnordered unordered, + MID message_id) { + bool result = false; + if (!items_.empty()) { + Item& item = items_.front(); + if (item.attributes.unordered == unordered && item.message_id.has_value() && + *item.message_id == message_id) { + HandleMessageExpired(item); + items_.pop_front(); + + // Only partially sent messages are discarded, so if a message was + // discarded, then it was the currently sent message. + scheduler_stream_->ForceReschedule(); + + if (pause_state_ == PauseState::kPending) { + pause_state_ = PauseState::kPaused; + scheduler_stream_->MakeInactive(); + } else if (bytes_to_send_in_next_message() == 0) { + scheduler_stream_->MakeInactive(); + } + + // As the item still existed, it had unsent data. + result = true; + } + } + RTC_DCHECK(IsConsistent()); + return result; +} + +void RRSendQueue::OutgoingStream::Pause() { + if (pause_state_ != PauseState::kNotPaused) { + // Already in progress. + return; + } + + bool had_pending_items = !items_.empty(); + + // https://datatracker.ietf.org/doc/html/rfc8831#section-6.7 + // "Closing of a data channel MUST be signaled by resetting the corresponding + // outgoing streams [RFC6525]. This means that if one side decides to close + // the data channel, it resets the corresponding outgoing stream." + // ... "[RFC6525] also guarantees that all the messages are delivered (or + // abandoned) before the stream is reset." + + // A stream is paused when it's about to be reset. In this implementation, + // it will throw away all non-partially send messages - they will be abandoned + // as noted above. This is subject to change. It will however not discard any + // partially sent messages - only whole messages. Partially delivered messages + // (at the time of receiving a Stream Reset command) will always deliver all + // the fragments before actually resetting the stream. + for (auto it = items_.begin(); it != items_.end();) { + if (it->remaining_offset == 0) { + HandleMessageExpired(*it); + it = items_.erase(it); + } else { + ++it; + } + } + + pause_state_ = (items_.empty() || items_.front().remaining_offset == 0) + ? PauseState::kPaused + : PauseState::kPending; + + if (had_pending_items && pause_state_ == PauseState::kPaused) { + RTC_DLOG(LS_VERBOSE) << "Stream " << *stream_id() + << " was previously active, but is now paused."; + scheduler_stream_->MakeInactive(); + } + + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::OutgoingStream::Resume() { + RTC_DCHECK(pause_state_ == PauseState::kResetting); + pause_state_ = PauseState::kNotPaused; + scheduler_stream_->MaybeMakeActive(); + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::OutgoingStream::Reset() { + // This can be called both when an outgoing stream reset has been responded + // to, or when the entire SendQueue is reset due to detecting the peer having + // restarted. The stream may be in any state at this time. + PauseState old_pause_state = pause_state_; + pause_state_ = PauseState::kNotPaused; + next_ordered_mid_ = MID(0); + next_unordered_mid_ = MID(0); + next_ssn_ = SSN(0); + if (!items_.empty()) { + // If this message has been partially sent, reset it so that it will be + // re-sent. + auto& item = items_.front(); + buffered_amount_.Increase(item.message.payload().size() - + item.remaining_size); + parent_.total_buffered_amount_.Increase(item.message.payload().size() - + item.remaining_size); + item.remaining_offset = 0; + item.remaining_size = item.message.payload().size(); + item.message_id = absl::nullopt; + item.ssn = absl::nullopt; + item.current_fsn = FSN(0); + if (old_pause_state == PauseState::kPaused || + old_pause_state == PauseState::kResetting) { + scheduler_stream_->MaybeMakeActive(); + } + } + RTC_DCHECK(IsConsistent()); +} + +bool RRSendQueue::OutgoingStream::has_partially_sent_message() const { + if (items_.empty()) { + return false; + } + return items_.front().message_id.has_value(); +} + +void RRSendQueue::Add(TimeMs now, + DcSctpMessage message, + const SendOptions& send_options) { + RTC_DCHECK(!message.payload().empty()); + // Any limited lifetime should start counting from now - when the message + // has been added to the queue. + + // `expires_at` is the time when it expires. Which is slightly larger than the + // message's lifetime, as the message is alive during its entire lifetime + // (which may be zero). + MessageAttributes attributes = { + .unordered = send_options.unordered, + .max_retransmissions = + send_options.max_retransmissions.has_value() + ? MaxRetransmits(send_options.max_retransmissions.value()) + : MaxRetransmits::NoLimit(), + .expires_at = send_options.lifetime.has_value() + ? now + *send_options.lifetime + DurationMs(1) + : TimeMs::InfiniteFuture(), + .lifecycle_id = send_options.lifecycle_id, + }; + GetOrCreateStreamInfo(message.stream_id()) + .Add(std::move(message), std::move(attributes)); + RTC_DCHECK(IsConsistent()); +} + +bool RRSendQueue::IsFull() const { + return total_buffered_amount() >= buffer_size_; +} + +bool RRSendQueue::IsEmpty() const { + return total_buffered_amount() == 0; +} + +absl::optional<SendQueue::DataToSend> RRSendQueue::Produce(TimeMs now, + size_t max_size) { + return scheduler_.Produce(now, max_size); +} + +bool RRSendQueue::Discard(IsUnordered unordered, + StreamID stream_id, + MID message_id) { + bool has_discarded = + GetOrCreateStreamInfo(stream_id).Discard(unordered, message_id); + + RTC_DCHECK(IsConsistent()); + return has_discarded; +} + +void RRSendQueue::PrepareResetStream(StreamID stream_id) { + GetOrCreateStreamInfo(stream_id).Pause(); + RTC_DCHECK(IsConsistent()); +} + +bool RRSendQueue::HasStreamsReadyToBeReset() const { + for (auto& [unused, stream] : streams_) { + if (stream.IsReadyToBeReset()) { + return true; + } + } + return false; +} +std::vector<StreamID> RRSendQueue::GetStreamsReadyToBeReset() { + RTC_DCHECK(absl::c_count_if(streams_, [](const auto& p) { + return p.second.IsResetting(); + }) == 0); + std::vector<StreamID> ready; + for (auto& [stream_id, stream] : streams_) { + if (stream.IsReadyToBeReset()) { + stream.SetAsResetting(); + ready.push_back(stream_id); + } + } + return ready; +} + +void RRSendQueue::CommitResetStreams() { + RTC_DCHECK(absl::c_count_if(streams_, [](const auto& p) { + return p.second.IsResetting(); + }) > 0); + for (auto& [unused, stream] : streams_) { + if (stream.IsResetting()) { + stream.Reset(); + } + } + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::RollbackResetStreams() { + RTC_DCHECK(absl::c_count_if(streams_, [](const auto& p) { + return p.second.IsResetting(); + }) > 0); + for (auto& [unused, stream] : streams_) { + if (stream.IsResetting()) { + stream.Resume(); + } + } + RTC_DCHECK(IsConsistent()); +} + +void RRSendQueue::Reset() { + // Recalculate buffered amount, as partially sent messages may have been put + // fully back in the queue. + for (auto& [unused, stream] : streams_) { + stream.Reset(); + } + scheduler_.ForceReschedule(); +} + +size_t RRSendQueue::buffered_amount(StreamID stream_id) const { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + return 0; + } + return it->second.buffered_amount().value(); +} + +size_t RRSendQueue::buffered_amount_low_threshold(StreamID stream_id) const { + auto it = streams_.find(stream_id); + if (it == streams_.end()) { + return 0; + } + return it->second.buffered_amount().low_threshold(); +} + +void RRSendQueue::SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) { + GetOrCreateStreamInfo(stream_id).buffered_amount().SetLowThreshold(bytes); +} + +RRSendQueue::OutgoingStream& RRSendQueue::GetOrCreateStreamInfo( + StreamID stream_id) { + auto it = streams_.find(stream_id); + if (it != streams_.end()) { + return it->second; + } + + return streams_ + .emplace( + std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple(this, &scheduler_, stream_id, default_priority_, + [this, stream_id]() { + callbacks_.OnBufferedAmountLow(stream_id); + })) + .first->second; +} + +void RRSendQueue::SetStreamPriority(StreamID stream_id, + StreamPriority priority) { + OutgoingStream& stream = GetOrCreateStreamInfo(stream_id); + + stream.SetPriority(priority); + RTC_DCHECK(IsConsistent()); +} + +StreamPriority RRSendQueue::GetStreamPriority(StreamID stream_id) const { + auto stream_it = streams_.find(stream_id); + if (stream_it == streams_.end()) { + return default_priority_; + } + return stream_it->second.priority(); +} + +HandoverReadinessStatus RRSendQueue::GetHandoverReadiness() const { + HandoverReadinessStatus status; + if (!IsEmpty()) { + status.Add(HandoverUnreadinessReason::kSendQueueNotEmpty); + } + return status; +} + +void RRSendQueue::AddHandoverState(DcSctpSocketHandoverState& state) { + for (const auto& [stream_id, stream] : streams_) { + DcSctpSocketHandoverState::OutgoingStream state_stream; + state_stream.id = stream_id.value(); + stream.AddHandoverState(state_stream); + state.tx.streams.push_back(std::move(state_stream)); + } +} + +void RRSendQueue::RestoreFromState(const DcSctpSocketHandoverState& state) { + for (const DcSctpSocketHandoverState::OutgoingStream& state_stream : + state.tx.streams) { + StreamID stream_id(state_stream.id); + streams_.emplace( + std::piecewise_construct, std::forward_as_tuple(stream_id), + std::forward_as_tuple( + this, &scheduler_, stream_id, StreamPriority(state_stream.priority), + [this, stream_id]() { callbacks_.OnBufferedAmountLow(stream_id); }, + &state_stream)); + } +} +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.h b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.h new file mode 100644 index 0000000000..e9b8cd2081 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue.h @@ -0,0 +1,282 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_RR_SEND_QUEUE_H_ +#define NET_DCSCTP_TX_RR_SEND_QUEUE_H_ + +#include <cstdint> +#include <deque> +#include <map> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" +#include "net/dcsctp/tx/stream_scheduler.h" + +namespace dcsctp { + +// The Round Robin SendQueue holds all messages that the client wants to send, +// but that haven't yet been split into chunks and fully sent on the wire. +// +// As defined in https://datatracker.ietf.org/doc/html/rfc8260#section-3.2, +// it will cycle to send messages from different streams. It will send all +// fragments from one message before continuing with a different message on +// possibly a different stream, until support for message interleaving has been +// implemented. +// +// As messages can be (requested to be) sent before the connection is properly +// established, this send queue is always present - even for closed connections. +// +// The send queue may trigger callbacks: +// * `OnBufferedAmountLow`, `OnTotalBufferedAmountLow` +// These will be triggered as defined in their documentation. +// * `OnLifecycleMessageExpired(/*maybe_delivered=*/false)`, `OnLifecycleEnd` +// These will be triggered when messages have been expired, abandoned or +// discarded from the send queue. If a message is fully produced, meaning +// that the last fragment has been produced, the responsibility to send +// lifecycle events is then transferred to the retransmission queue, which +// is the one asking to produce the message. +class RRSendQueue : public SendQueue { + public: + RRSendQueue(absl::string_view log_prefix, + DcSctpSocketCallbacks* callbacks, + size_t buffer_size, + size_t mtu, + StreamPriority default_priority, + size_t total_buffered_amount_low_threshold); + + // Indicates if the buffer is full. Note that it's up to the caller to ensure + // that the buffer is not full prior to adding new items to it. + bool IsFull() const; + // Indicates if the buffer is empty. + bool IsEmpty() const; + + // Adds the message to be sent using the `send_options` provided. The current + // time should be in `now`. Note that it's the responsibility of the caller to + // ensure that the buffer is not full (by calling `IsFull`) before adding + // messages to it. + void Add(TimeMs now, + DcSctpMessage message, + const SendOptions& send_options = {}); + + // Implementation of `SendQueue`. + absl::optional<DataToSend> Produce(TimeMs now, size_t max_size) override; + bool Discard(IsUnordered unordered, + StreamID stream_id, + MID message_id) override; + void PrepareResetStream(StreamID streams) override; + bool HasStreamsReadyToBeReset() const override; + std::vector<StreamID> GetStreamsReadyToBeReset() override; + void CommitResetStreams() override; + void RollbackResetStreams() override; + void Reset() override; + size_t buffered_amount(StreamID stream_id) const override; + size_t total_buffered_amount() const override { + return total_buffered_amount_.value(); + } + size_t buffered_amount_low_threshold(StreamID stream_id) const override; + void SetBufferedAmountLowThreshold(StreamID stream_id, size_t bytes) override; + void EnableMessageInterleaving(bool enabled) override { + scheduler_.EnableMessageInterleaving(enabled); + } + + void SetStreamPriority(StreamID stream_id, StreamPriority priority); + StreamPriority GetStreamPriority(StreamID stream_id) const; + HandoverReadinessStatus GetHandoverReadiness() const; + void AddHandoverState(DcSctpSocketHandoverState& state); + void RestoreFromState(const DcSctpSocketHandoverState& state); + + private: + struct MessageAttributes { + IsUnordered unordered; + MaxRetransmits max_retransmissions; + TimeMs expires_at; + LifecycleId lifecycle_id; + }; + + // Represents a value and a "low threshold" that when the value reaches or + // goes under the "low threshold", will trigger `on_threshold_reached` + // callback. + class ThresholdWatcher { + public: + explicit ThresholdWatcher(std::function<void()> on_threshold_reached) + : on_threshold_reached_(std::move(on_threshold_reached)) {} + // Increases the value. + void Increase(size_t bytes) { value_ += bytes; } + // Decreases the value and triggers `on_threshold_reached` if it's at or + // below `low_threshold()`. + void Decrease(size_t bytes); + + size_t value() const { return value_; } + size_t low_threshold() const { return low_threshold_; } + void SetLowThreshold(size_t low_threshold); + + private: + const std::function<void()> on_threshold_reached_; + size_t value_ = 0; + size_t low_threshold_ = 0; + }; + + // Per-stream information. + class OutgoingStream : public StreamScheduler::StreamProducer { + public: + OutgoingStream( + RRSendQueue* parent, + StreamScheduler* scheduler, + StreamID stream_id, + StreamPriority priority, + std::function<void()> on_buffered_amount_low, + const DcSctpSocketHandoverState::OutgoingStream* state = nullptr) + : parent_(*parent), + scheduler_stream_(scheduler->CreateStream(this, stream_id, priority)), + next_unordered_mid_(MID(state ? state->next_unordered_mid : 0)), + next_ordered_mid_(MID(state ? state->next_ordered_mid : 0)), + next_ssn_(SSN(state ? state->next_ssn : 0)), + buffered_amount_(std::move(on_buffered_amount_low)) {} + + StreamID stream_id() const { return scheduler_stream_->stream_id(); } + + // Enqueues a message to this stream. + void Add(DcSctpMessage message, MessageAttributes attributes); + + // Implementing `StreamScheduler::StreamProducer`. + absl::optional<SendQueue::DataToSend> Produce(TimeMs now, + size_t max_size) override; + size_t bytes_to_send_in_next_message() const override; + + const ThresholdWatcher& buffered_amount() const { return buffered_amount_; } + ThresholdWatcher& buffered_amount() { return buffered_amount_; } + + // Discards a partially sent message, see `SendQueue::Discard`. + bool Discard(IsUnordered unordered, MID message_id); + + // Pauses this stream, which is used before resetting it. + void Pause(); + + // Resumes a paused stream. + void Resume(); + + bool IsReadyToBeReset() const { + return pause_state_ == PauseState::kPaused; + } + + bool IsResetting() const { return pause_state_ == PauseState::kResetting; } + + void SetAsResetting() { + RTC_DCHECK(pause_state_ == PauseState::kPaused); + pause_state_ = PauseState::kResetting; + } + + // Resets this stream, meaning MIDs and SSNs are set to zero. + void Reset(); + + // Indicates if this stream has a partially sent message in it. + bool has_partially_sent_message() const; + + StreamPriority priority() const { return scheduler_stream_->priority(); } + void SetPriority(StreamPriority priority) { + scheduler_stream_->SetPriority(priority); + } + + void AddHandoverState( + DcSctpSocketHandoverState::OutgoingStream& state) const; + + private: + // Streams are paused before they can be reset. To reset a stream, the + // socket sends an outgoing stream reset command with the TSN of the last + // fragment of the last message, so that receivers and senders can agree on + // when it stopped. And if the send queue is in the middle of sending a + // message, and without fragments not yet sent and without TSNs allocated to + // them, it will keep sending data until that message has ended. + enum class PauseState { + // The stream is not paused, and not scheduled to be reset. + kNotPaused, + // The stream has requested to be reset/paused but is still producing + // fragments of a message that hasn't ended yet. When it does, it will + // transition to the `kPaused` state. + kPending, + // The stream is fully paused and can be reset. + kPaused, + // The stream has been added to an outgoing stream reset request and a + // response from the peer hasn't been received yet. + kResetting, + }; + + // An enqueued message and metadata. + struct Item { + explicit Item(DcSctpMessage msg, MessageAttributes attributes) + : message(std::move(msg)), + attributes(std::move(attributes)), + remaining_offset(0), + remaining_size(message.payload().size()) {} + DcSctpMessage message; + MessageAttributes attributes; + // The remaining payload (offset and size) to be sent, when it has been + // fragmented. + size_t remaining_offset; + size_t remaining_size; + // If set, an allocated Message ID and SSN. Will be allocated when the + // first fragment is sent. + absl::optional<MID> message_id = absl::nullopt; + absl::optional<SSN> ssn = absl::nullopt; + // The current Fragment Sequence Number, incremented for each fragment. + FSN current_fsn = FSN(0); + }; + + bool IsConsistent() const; + void HandleMessageExpired(OutgoingStream::Item& item); + + RRSendQueue& parent_; + + const std::unique_ptr<StreamScheduler::Stream> scheduler_stream_; + + PauseState pause_state_ = PauseState::kNotPaused; + // MIDs are different for unordered and ordered messages sent on a stream. + MID next_unordered_mid_; + MID next_ordered_mid_; + + SSN next_ssn_; + // Enqueued messages, and metadata. + std::deque<Item> items_; + + // The current amount of buffered data. + ThresholdWatcher buffered_amount_; + }; + + bool IsConsistent() const; + OutgoingStream& GetOrCreateStreamInfo(StreamID stream_id); + absl::optional<DataToSend> Produce( + std::map<StreamID, OutgoingStream>::iterator it, + TimeMs now, + size_t max_size); + + const std::string log_prefix_; + DcSctpSocketCallbacks& callbacks_; + const size_t buffer_size_; + const StreamPriority default_priority_; + StreamScheduler scheduler_; + + // The total amount of buffer data, for all streams. + ThresholdWatcher total_buffered_amount_; + + // All streams, and messages added to those. + std::map<StreamID, OutgoingStream> streams_; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_RR_SEND_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue_test.cc b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue_test.cc new file mode 100644 index 0000000000..95416b193a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/rr_send_queue_test.cc @@ -0,0 +1,866 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/rr_send_queue.h" + +#include <cstdint> +#include <type_traits> +#include <vector> + +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_options.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/socket/mock_dcsctp_socket_callbacks.h" +#include "net/dcsctp/testing/testing_macros.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/gunit.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; + +constexpr TimeMs kNow = TimeMs(0); +constexpr StreamID kStreamID(1); +constexpr PPID kPPID(53); +constexpr size_t kMaxQueueSize = 1000; +constexpr StreamPriority kDefaultPriority(10); +constexpr size_t kBufferedAmountLowThreshold = 500; +constexpr size_t kOneFragmentPacketSize = 100; +constexpr size_t kTwoFragmentPacketSize = 101; +constexpr size_t kMtu = 1100; + +class RRSendQueueTest : public testing::Test { + protected: + RRSendQueueTest() + : buf_("log: ", + &callbacks_, + kMaxQueueSize, + kMtu, + kDefaultPriority, + kBufferedAmountLowThreshold) {} + + testing::NiceMock<MockDcSctpSocketCallbacks> callbacks_; + const DcSctpOptions options_; + RRSendQueue buf_; +}; + +TEST_F(RRSendQueueTest, EmptyBuffer) { + EXPECT_TRUE(buf_.IsEmpty()); + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); + EXPECT_FALSE(buf_.IsFull()); +} + +TEST_F(RRSendQueueTest, AddAndGetSingleChunk) { + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, {1, 2, 4, 5, 6})); + + EXPECT_FALSE(buf_.IsEmpty()); + EXPECT_FALSE(buf_.IsFull()); + absl::optional<SendQueue::DataToSend> chunk_opt = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_opt.has_value()); + EXPECT_TRUE(chunk_opt->data.is_beginning); + EXPECT_TRUE(chunk_opt->data.is_end); +} + +TEST_F(RRSendQueueTest, CarveOutBeginningMiddleAndEnd) { + std::vector<uint8_t> payload(60); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_beg = + buf_.Produce(kNow, /*max_size=*/20); + ASSERT_TRUE(chunk_beg.has_value()); + EXPECT_TRUE(chunk_beg->data.is_beginning); + EXPECT_FALSE(chunk_beg->data.is_end); + + absl::optional<SendQueue::DataToSend> chunk_mid = + buf_.Produce(kNow, /*max_size=*/20); + ASSERT_TRUE(chunk_mid.has_value()); + EXPECT_FALSE(chunk_mid->data.is_beginning); + EXPECT_FALSE(chunk_mid->data.is_end); + + absl::optional<SendQueue::DataToSend> chunk_end = + buf_.Produce(kNow, /*max_size=*/20); + ASSERT_TRUE(chunk_end.has_value()); + EXPECT_FALSE(chunk_end->data.is_beginning); + EXPECT_TRUE(chunk_end->data.is_end); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); +} + +TEST_F(RRSendQueueTest, GetChunksFromTwoMessages) { + std::vector<uint8_t> payload(60); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(3), PPID(54), payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(chunk_one->data.ppid, kPPID); + EXPECT_TRUE(chunk_one->data.is_beginning); + EXPECT_TRUE(chunk_one->data.is_end); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_two->data.ppid, PPID(54)); + EXPECT_TRUE(chunk_two->data.is_beginning); + EXPECT_TRUE(chunk_two->data.is_end); +} + +TEST_F(RRSendQueueTest, BufferBecomesFullAndEmptied) { + std::vector<uint8_t> payload(600); + EXPECT_FALSE(buf_.IsFull()); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_FALSE(buf_.IsFull()); + buf_.Add(kNow, DcSctpMessage(StreamID(3), PPID(54), payload)); + EXPECT_TRUE(buf_.IsFull()); + // However, it's still possible to add messages. It's a soft limit, and it + // might be necessary to forcefully add messages due to e.g. external + // fragmentation. + buf_.Add(kNow, DcSctpMessage(StreamID(5), PPID(55), payload)); + EXPECT_TRUE(buf_.IsFull()); + + absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 1000); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(chunk_one->data.ppid, kPPID); + + EXPECT_TRUE(buf_.IsFull()); + + absl::optional<SendQueue::DataToSend> chunk_two = buf_.Produce(kNow, 1000); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_two->data.ppid, PPID(54)); + + EXPECT_FALSE(buf_.IsFull()); + EXPECT_FALSE(buf_.IsEmpty()); + + absl::optional<SendQueue::DataToSend> chunk_three = buf_.Produce(kNow, 1000); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.stream_id, StreamID(5)); + EXPECT_EQ(chunk_three->data.ppid, PPID(55)); + + EXPECT_FALSE(buf_.IsFull()); + EXPECT_TRUE(buf_.IsEmpty()); +} + +TEST_F(RRSendQueueTest, DefaultsToOrderedSend) { + std::vector<uint8_t> payload(20); + + // Default is ordered + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_FALSE(chunk_one->data.is_unordered); + + // Explicitly unordered. + SendOptions opts; + opts.unordered = IsUnordered(true); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload), opts); + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_TRUE(chunk_two->data.is_unordered); +} + +TEST_F(RRSendQueueTest, ProduceWithLifetimeExpiry) { + std::vector<uint8_t> payload(20); + + // Default is no expiry + TimeMs now = kNow; + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload)); + now += DurationMs(1000000); + ASSERT_TRUE(buf_.Produce(now, kOneFragmentPacketSize)); + + SendOptions expires_2_seconds; + expires_2_seconds.lifetime = DurationMs(2000); + + // Add and consume within lifetime + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + now += DurationMs(2000); + ASSERT_TRUE(buf_.Produce(now, kOneFragmentPacketSize)); + + // Add and consume just outside lifetime + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + now += DurationMs(2001); + ASSERT_FALSE(buf_.Produce(now, kOneFragmentPacketSize)); + + // A long time after expiry + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + now += DurationMs(1000000); + ASSERT_FALSE(buf_.Produce(now, kOneFragmentPacketSize)); + + // Expire one message, but produce the second that is not expired. + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_2_seconds); + + SendOptions expires_4_seconds; + expires_4_seconds.lifetime = DurationMs(4000); + + buf_.Add(now, DcSctpMessage(kStreamID, kPPID, payload), expires_4_seconds); + now += DurationMs(2001); + + ASSERT_TRUE(buf_.Produce(now, kOneFragmentPacketSize)); + ASSERT_FALSE(buf_.Produce(now, kOneFragmentPacketSize)); +} + +TEST_F(RRSendQueueTest, DiscardPartialPackets) { + std::vector<uint8_t> payload(120); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(2), PPID(54), payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_FALSE(chunk_one->data.is_end); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + buf_.Discard(IsUnordered(false), chunk_one->data.stream_id, + chunk_one->data.message_id); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_FALSE(chunk_two->data.is_end); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(2)); + + absl::optional<SendQueue::DataToSend> chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_TRUE(chunk_three->data.is_end); + EXPECT_EQ(chunk_three->data.stream_id, StreamID(2)); + ASSERT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize)); + + // Calling it again shouldn't cause issues. + buf_.Discard(IsUnordered(false), chunk_one->data.stream_id, + chunk_one->data.message_id); + ASSERT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize)); +} + +TEST_F(RRSendQueueTest, PrepareResetStreamsDiscardsStream) { + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, {1, 2, 3})); + buf_.Add(kNow, DcSctpMessage(StreamID(2), PPID(54), {1, 2, 3, 4, 5})); + EXPECT_EQ(buf_.total_buffered_amount(), 8u); + + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), 5u); + + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(1))); + buf_.CommitResetStreams(); + buf_.PrepareResetStream(StreamID(2)); + EXPECT_EQ(buf_.total_buffered_amount(), 0u); +} + +TEST_F(RRSendQueueTest, PrepareResetStreamsNotPartialPackets) { + std::vector<uint8_t> payload(120); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 2 * payload.size() - 50); + + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size() - 50); +} + +TEST_F(RRSendQueueTest, EnqueuedItemsArePausedDuringStreamReset) { + std::vector<uint8_t> payload(50); + + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), 0u); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); + + EXPECT_TRUE(buf_.HasStreamsReadyToBeReset()); + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(1))); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); + + buf_.CommitResetStreams(); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 0u); +} + +TEST_F(RRSendQueueTest, PausedStreamsStillSendPartialMessagesUntilEnd) { + constexpr size_t kPayloadSize = 100; + constexpr size_t kFragmentSize = 50; + std::vector<uint8_t> payload(kPayloadSize); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kFragmentSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 2 * kPayloadSize - kFragmentSize); + + // This will stop the second message from being sent. + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), 1 * kPayloadSize - kFragmentSize); + + // Should still produce fragments until end of message. + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kFragmentSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 0ul); + + // But shouldn't produce any more messages as the stream is paused. + EXPECT_FALSE(buf_.Produce(kNow, kFragmentSize).has_value()); +} + +TEST_F(RRSendQueueTest, CommittingResetsSSN) { + std::vector<uint8_t> payload(50); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.ssn, SSN(0)); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.ssn, SSN(1)); + + buf_.PrepareResetStream(StreamID(1)); + + // Buffered + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + EXPECT_TRUE(buf_.HasStreamsReadyToBeReset()); + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(1))); + buf_.CommitResetStreams(); + + absl::optional<SendQueue::DataToSend> chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.ssn, SSN(0)); +} + +TEST_F(RRSendQueueTest, CommittingResetsSSNForPausedStreamsOnly) { + std::vector<uint8_t> payload(50); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, StreamID(1)); + EXPECT_EQ(chunk_one->data.ssn, SSN(0)); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_two->data.ssn, SSN(0)); + + buf_.PrepareResetStream(StreamID(3)); + + // Send two more messages - SID 3 will buffer, SID 1 will send. + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, payload)); + + EXPECT_TRUE(buf_.HasStreamsReadyToBeReset()); + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(3))); + + buf_.CommitResetStreams(); + + absl::optional<SendQueue::DataToSend> chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.stream_id, StreamID(1)); + EXPECT_EQ(chunk_three->data.ssn, SSN(1)); + + absl::optional<SendQueue::DataToSend> chunk_four = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_four.has_value()); + EXPECT_EQ(chunk_four->data.stream_id, StreamID(3)); + EXPECT_EQ(chunk_four->data.ssn, SSN(0)); +} + +TEST_F(RRSendQueueTest, RollBackResumesSSN) { + std::vector<uint8_t> payload(50); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.ssn, SSN(0)); + + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_two.has_value()); + EXPECT_EQ(chunk_two->data.ssn, SSN(1)); + + buf_.PrepareResetStream(StreamID(1)); + + // Buffered + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + + EXPECT_TRUE(buf_.HasStreamsReadyToBeReset()); + EXPECT_THAT(buf_.GetStreamsReadyToBeReset(), + UnorderedElementsAre(StreamID(1))); + buf_.RollbackResetStreams(); + + absl::optional<SendQueue::DataToSend> chunk_three = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_three.has_value()); + EXPECT_EQ(chunk_three->data.ssn, SSN(2)); +} + +TEST_F(RRSendQueueTest, ReturnsFragmentsForOneMessageBeforeMovingToNext) { + std::vector<uint8_t> payload(200); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, payload)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(2)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(2)); +} + +TEST_F(RRSendQueueTest, ReturnsAlsoSmallFragmentsBeforeMovingToNext) { + std::vector<uint8_t> payload(kTwoFragmentPacketSize); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, payload)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(kOneFragmentPacketSize)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, + SizeIs(kTwoFragmentPacketSize - kOneFragmentPacketSize)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk4.data.payload, + SizeIs(kTwoFragmentPacketSize - kOneFragmentPacketSize)); +} + +TEST_F(RRSendQueueTest, WillCycleInRoundRobinFashionBetweenStreams) { + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1))); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(2))); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(3))); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(4))); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector<uint8_t>(5))); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector<uint8_t>(6))); + buf_.Add(kNow, DcSctpMessage(StreamID(4), kPPID, std::vector<uint8_t>(7))); + buf_.Add(kNow, DcSctpMessage(StreamID(4), kPPID, std::vector<uint8_t>(8))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk2.data.payload, SizeIs(3)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(3)); + EXPECT_THAT(chunk3.data.payload, SizeIs(5)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(4)); + EXPECT_THAT(chunk4.data.payload, SizeIs(7)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk5, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk5.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk5.data.payload, SizeIs(2)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk6, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk6.data.stream_id, StreamID(2)); + EXPECT_THAT(chunk6.data.payload, SizeIs(4)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk7, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk7.data.stream_id, StreamID(3)); + EXPECT_THAT(chunk7.data.payload, SizeIs(6)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk8, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk8.data.stream_id, StreamID(4)); + EXPECT_THAT(chunk8.data.payload, SizeIs(8)); +} + +TEST_F(RRSendQueueTest, DoesntTriggerOnBufferedAmountLowWhenSetToZero) { + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 0u); +} + +TEST_F(RRSendQueueTest, TriggersOnBufferedAmountAtZeroLowWhenSent) { + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 1u); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 0u); +} + +TEST_F(RRSendQueueTest, WillRetriggerOnBufferedAmountLowIfAddingMore) { + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1))); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 1u); + + // Should now trigger again, as buffer_amount went above the threshold. + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(1)); +} + +TEST_F(RRSendQueueTest, OnlyTriggersWhenTransitioningFromAboveToBelowOrEqual) { + buf_.SetBufferedAmountLowThreshold(StreamID(1), 1000); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(10))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 10u); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(10)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 0u); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(20))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 20u); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(20)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 0u); +} + +TEST_F(RRSendQueueTest, WillTriggerOnBufferedAmountLowSetAboveZero) { + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + buf_.SetBufferedAmountLowThreshold(StreamID(1), 700); + + std::vector<uint8_t> payload(1000); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, payload)); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 900u); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 800u); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 700u); + + // Doesn't trigger when reducing even further. + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u); +} + +TEST_F(RRSendQueueTest, WillRetriggerOnBufferedAmountLowSetAboveZero) { + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + buf_.SetBufferedAmountLowThreshold(StreamID(1), 700); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(1000))); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, 400)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk1.data.payload, SizeIs(400)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u); + + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(200))); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 800u); + + // Will trigger again, as it went above the limit. + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, 200)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(200)); + EXPECT_EQ(buf_.buffered_amount(StreamID(1)), 600u); +} + +TEST_F(RRSendQueueTest, TriggersOnBufferedAmountLowOnThresholdChanged) { + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(100))); + + // Modifying the threshold, still under buffered_amount, should not trigger. + buf_.SetBufferedAmountLowThreshold(StreamID(1), 50); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 99); + + // When the threshold reaches buffered_amount, it will trigger. + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 100); + + // But not when it's set low again. + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 50); + + // But it will trigger when it overshoots. + EXPECT_CALL(callbacks_, OnBufferedAmountLow(StreamID(1))); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 150); + + // But not when it's set low again. + EXPECT_CALL(callbacks_, OnBufferedAmountLow).Times(0); + buf_.SetBufferedAmountLowThreshold(StreamID(1), 0); +} + +TEST_F(RRSendQueueTest, + OnTotalBufferedAmountLowDoesNotTriggerOnBufferFillingUp) { + EXPECT_CALL(callbacks_, OnTotalBufferedAmountLow).Times(0); + std::vector<uint8_t> payload(kBufferedAmountLowThreshold - 1); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + // Will not trigger if going above but never below. + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, + std::vector<uint8_t>(kOneFragmentPacketSize))); +} + +TEST_F(RRSendQueueTest, TriggersOnTotalBufferedAmountLowWhenCrossing) { + EXPECT_CALL(callbacks_, OnTotalBufferedAmountLow).Times(0); + std::vector<uint8_t> payload(kBufferedAmountLowThreshold); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size()); + + // Reaches it. + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, std::vector<uint8_t>(1))); + + // Drain it a bit - will trigger. + EXPECT_CALL(callbacks_, OnTotalBufferedAmountLow).Times(1); + absl::optional<SendQueue::DataToSend> chunk_two = + buf_.Produce(kNow, kOneFragmentPacketSize); +} + +TEST_F(RRSendQueueTest, WillStayInAStreamAsLongAsThatMessageIsSending) { + buf_.Add(kNow, DcSctpMessage(StreamID(5), kPPID, std::vector<uint8_t>(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk1, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk1.data.stream_id, StreamID(5)); + EXPECT_THAT(chunk1.data.payload, SizeIs(1)); + + // Next, it should pick a different stream. + + buf_.Add(kNow, + DcSctpMessage(StreamID(1), kPPID, + std::vector<uint8_t>(kOneFragmentPacketSize * 2))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk2, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk2.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk2.data.payload, SizeIs(kOneFragmentPacketSize)); + + // It should still stay on the Stream1 now, even if might be tempted to switch + // to this stream, as it's the stream following 5. + buf_.Add(kNow, DcSctpMessage(StreamID(6), kPPID, std::vector<uint8_t>(1))); + + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk3, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk3.data.stream_id, StreamID(1)); + EXPECT_THAT(chunk3.data.payload, SizeIs(kOneFragmentPacketSize)); + + // After stream id 1 is complete, it's time to do stream 6. + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk4, + buf_.Produce(kNow, kOneFragmentPacketSize)); + EXPECT_EQ(chunk4.data.stream_id, StreamID(6)); + EXPECT_THAT(chunk4.data.payload, SizeIs(1)); + + EXPECT_FALSE(buf_.Produce(kNow, kOneFragmentPacketSize).has_value()); +} + +TEST_F(RRSendQueueTest, StreamsHaveInitialPriority) { + EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), kDefaultPriority); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40))); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), kDefaultPriority); +} + +TEST_F(RRSendQueueTest, CanChangeStreamPriority) { + buf_.SetStreamPriority(StreamID(1), StreamPriority(42)); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(1)), StreamPriority(42)); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40))); + buf_.SetStreamPriority(StreamID(2), StreamPriority(42)); + EXPECT_EQ(buf_.GetStreamPriority(StreamID(2)), StreamPriority(42)); +} + +TEST_F(RRSendQueueTest, WillHandoverPriority) { + buf_.SetStreamPriority(StreamID(1), StreamPriority(42)); + + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(40))); + buf_.SetStreamPriority(StreamID(2), StreamPriority(42)); + + DcSctpSocketHandoverState state; + buf_.AddHandoverState(state); + + RRSendQueue q2("log: ", &callbacks_, kMaxQueueSize, kMtu, kDefaultPriority, + kBufferedAmountLowThreshold); + q2.RestoreFromState(state); + EXPECT_EQ(q2.GetStreamPriority(StreamID(1)), StreamPriority(42)); + EXPECT_EQ(q2.GetStreamPriority(StreamID(2)), StreamPriority(42)); +} + +TEST_F(RRSendQueueTest, WillSendMessagesByPrio) { + buf_.EnableMessageInterleaving(true); + buf_.SetStreamPriority(StreamID(1), StreamPriority(10)); + buf_.SetStreamPriority(StreamID(2), StreamPriority(20)); + buf_.SetStreamPriority(StreamID(3), StreamPriority(30)); + + buf_.Add(kNow, DcSctpMessage(StreamID(1), kPPID, std::vector<uint8_t>(40))); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, std::vector<uint8_t>(20))); + buf_.Add(kNow, DcSctpMessage(StreamID(3), kPPID, std::vector<uint8_t>(10))); + std::vector<uint16_t> expected_streams = {3, 2, 2, 1, 1, 1, 1}; + + for (uint16_t stream_num : expected_streams) { + ASSERT_HAS_VALUE_AND_ASSIGN(SendQueue::DataToSend chunk, + buf_.Produce(kNow, 10)); + EXPECT_EQ(chunk.data.stream_id, StreamID(stream_num)); + } + EXPECT_FALSE(buf_.Produce(kNow, 1).has_value()); +} + +TEST_F(RRSendQueueTest, WillSendLifecycleExpireWhenExpiredInSendQueue) { + std::vector<uint8_t> payload(kOneFragmentPacketSize); + buf_.Add(kNow, DcSctpMessage(StreamID(2), kPPID, payload), + SendOptions{.lifetime = DurationMs(1000), + .lifecycle_id = LifecycleId(1)}); + + EXPECT_CALL(callbacks_, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(callbacks_, OnLifecycleEnd(LifecycleId(1))); + EXPECT_FALSE(buf_.Produce(kNow + DurationMs(1001), kOneFragmentPacketSize) + .has_value()); +} + +TEST_F(RRSendQueueTest, WillSendLifecycleExpireWhenDiscardingDuringPause) { + std::vector<uint8_t> payload(120); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload), + SendOptions{.lifecycle_id = LifecycleId(1)}); + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload), + SendOptions{.lifecycle_id = LifecycleId(2)}); + + absl::optional<SendQueue::DataToSend> chunk_one = buf_.Produce(kNow, 50); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_EQ(buf_.total_buffered_amount(), 2 * payload.size() - 50); + + EXPECT_CALL(callbacks_, OnLifecycleMessageExpired(LifecycleId(2), + /*maybe_delivered=*/false)); + EXPECT_CALL(callbacks_, OnLifecycleEnd(LifecycleId(2))); + buf_.PrepareResetStream(StreamID(1)); + EXPECT_EQ(buf_.total_buffered_amount(), payload.size() - 50); +} + +TEST_F(RRSendQueueTest, WillSendLifecycleExpireWhenDiscardingExplicitly) { + std::vector<uint8_t> payload(kOneFragmentPacketSize + 20); + + buf_.Add(kNow, DcSctpMessage(kStreamID, kPPID, payload), + SendOptions{.lifecycle_id = LifecycleId(1)}); + + absl::optional<SendQueue::DataToSend> chunk_one = + buf_.Produce(kNow, kOneFragmentPacketSize); + ASSERT_TRUE(chunk_one.has_value()); + EXPECT_FALSE(chunk_one->data.is_end); + EXPECT_EQ(chunk_one->data.stream_id, kStreamID); + EXPECT_CALL(callbacks_, OnLifecycleMessageExpired(LifecycleId(1), + /*maybe_delivered=*/false)); + EXPECT_CALL(callbacks_, OnLifecycleEnd(LifecycleId(1))); + buf_.Discard(IsUnordered(false), chunk_one->data.stream_id, + chunk_one->data.message_id); +} +} // namespace +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/send_queue.h b/third_party/libwebrtc/net/dcsctp/tx/send_queue.h new file mode 100644 index 0000000000..0b96e9041a --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/send_queue.h @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_SEND_QUEUE_H_ +#define NET_DCSCTP_TX_SEND_QUEUE_H_ + +#include <cstdint> +#include <limits> +#include <utility> +#include <vector> + +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/internal_types.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/types.h" + +namespace dcsctp { + +class SendQueue { + public: + // Container for a data chunk that is produced by the SendQueue + struct DataToSend { + explicit DataToSend(Data data) : data(std::move(data)) {} + // The data to send, including all parameters. + Data data; + + // Partial reliability - RFC3758 + MaxRetransmits max_retransmissions = MaxRetransmits::NoLimit(); + TimeMs expires_at = TimeMs::InfiniteFuture(); + + // Lifecycle - set for the last fragment, and `LifecycleId::NotSet()` for + // all other fragments. + LifecycleId lifecycle_id = LifecycleId::NotSet(); + }; + + virtual ~SendQueue() = default; + + // TODO(boivie): This interface is obviously missing an "Add" function, but + // that is postponed a bit until the story around how to model message + // prioritization, which is important for any advanced stream scheduler, is + // further clarified. + + // Produce a chunk to be sent. + // + // `max_size` refers to how many payload bytes that may be produced, not + // including any headers. + virtual absl::optional<DataToSend> Produce(TimeMs now, size_t max_size) = 0; + + // Discards a partially sent message identified by the parameters `unordered`, + // `stream_id` and `message_id`. The `message_id` comes from the returned + // information when having called `Produce`. A partially sent message means + // that it has had at least one fragment of it returned when `Produce` was + // called prior to calling this method). + // + // This is used when a message has been found to be expired (by the partial + // reliability extension), and the retransmission queue will signal the + // receiver that any partially received message fragments should be skipped. + // This means that any remaining fragments in the Send Queue must be removed + // as well so that they are not sent. + // + // This function returns true if this message had unsent fragments still in + // the queue that were discarded, and false if there were no such fragments. + virtual bool Discard(IsUnordered unordered, + StreamID stream_id, + MID message_id) = 0; + + // Prepares the stream to be reset. This is used to close a WebRTC data + // channel and will be signaled to the other side. + // + // Concretely, it discards all whole (not partly sent) messages in the given + // stream and pauses that stream so that future added messages aren't + // produced until `ResumeStreams` is called. + // + // TODO(boivie): Investigate if it really should discard any message at all. + // RFC8831 only mentions that "[RFC6525] also guarantees that all the messages + // are delivered (or abandoned) before the stream is reset." + // + // This method can be called multiple times to add more streams to be + // reset, and paused while they are resetting. This is the first part of the + // two-phase commit protocol to reset streams, where the caller completes the + // procedure by either calling `CommitResetStreams` or `RollbackResetStreams`. + virtual void PrepareResetStream(StreamID stream_id) = 0; + + // Indicates if there are any streams that are ready to be reset. + virtual bool HasStreamsReadyToBeReset() const = 0; + + // Returns a list of streams that are ready to be included in an outgoing + // stream reset request. Any streams that are returned here must be included + // in an outgoing stream reset request, and there must not be concurrent + // requests. Before calling this method again, you must have called + virtual std::vector<StreamID> GetStreamsReadyToBeReset() = 0; + + // Called to commit to reset the streams returned by + // `GetStreamsReadyToBeReset`. It will reset the stream sequence numbers + // (SSNs) and message identifiers (MIDs) and resume the paused streams. + virtual void CommitResetStreams() = 0; + + // Called to abort the resetting of streams returned by + // `GetStreamsReadyToBeReset`. Will resume the paused streams without + // resetting the stream sequence numbers (SSNs) or message identifiers (MIDs). + // Note that the non-partial messages that were discarded when calling + // `PrepareResetStreams` will not be recovered, to better match the intention + // from the sender to "close the channel". + virtual void RollbackResetStreams() = 0; + + // Resets all message identifier counters (MID, SSN) and makes all partially + // messages be ready to be re-sent in full. This is used when the peer has + // been detected to have restarted and is used to try to minimize the amount + // of data loss. However, data loss cannot be completely guaranteed when a + // peer restarts. + virtual void Reset() = 0; + + // Returns the amount of buffered data. This doesn't include packets that are + // e.g. inflight. + virtual size_t buffered_amount(StreamID stream_id) const = 0; + + // Returns the total amount of buffer data, for all streams. + virtual size_t total_buffered_amount() const = 0; + + // Returns the limit for the `OnBufferedAmountLow` event. Default value is 0. + virtual size_t buffered_amount_low_threshold(StreamID stream_id) const = 0; + + // Sets a limit for the `OnBufferedAmountLow` event. + virtual void SetBufferedAmountLowThreshold(StreamID stream_id, + size_t bytes) = 0; + + // Configures the send queue to support interleaved message sending as + // described in RFC8260. Every send queue starts with this value set as + // disabled, but can later change it when the capabilities of the connection + // have been negotiated. This affects the behavior of the `Produce` method. + virtual void EnableMessageInterleaving(bool enabled) = 0; +}; +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_SEND_QUEUE_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.cc b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.cc new file mode 100644 index 0000000000..d1560a75e4 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.cc @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/stream_scheduler.h" + +#include <algorithm> + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/common/str_join.h" +#include "net/dcsctp/packet/data.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/checks.h" +#include "rtc_base/logging.h" + +namespace dcsctp { + +void StreamScheduler::Stream::SetPriority(StreamPriority priority) { + priority_ = priority; + inverse_weight_ = InverseWeight(priority); +} + +absl::optional<SendQueue::DataToSend> StreamScheduler::Produce( + TimeMs now, + size_t max_size) { + // For non-interleaved streams, avoid rescheduling while still sending a + // message as it needs to be sent in full. For interleaved messaging, + // reschedule for every I-DATA chunk sent. + bool rescheduling = + enable_message_interleaving_ || !currently_sending_a_message_; + + RTC_LOG(LS_VERBOSE) << "Producing data, rescheduling=" << rescheduling + << ", active=" + << StrJoin(active_streams_, ", ", + [&](rtc::StringBuilder& sb, const auto& p) { + sb << *p->stream_id() << "@" + << *p->next_finish_time(); + }); + + RTC_DCHECK(rescheduling || current_stream_ != nullptr); + + absl::optional<SendQueue::DataToSend> data; + while (!data.has_value() && !active_streams_.empty()) { + if (rescheduling) { + auto it = active_streams_.begin(); + current_stream_ = *it; + RTC_DLOG(LS_VERBOSE) << "Rescheduling to stream " + << *current_stream_->stream_id(); + + active_streams_.erase(it); + current_stream_->ForceMarkInactive(); + } else { + RTC_DLOG(LS_VERBOSE) << "Producing from previous stream: " + << *current_stream_->stream_id(); + RTC_DCHECK(absl::c_any_of(active_streams_, [this](const auto* p) { + return p == current_stream_; + })); + } + + data = current_stream_->Produce(now, max_size); + } + + if (!data.has_value()) { + RTC_DLOG(LS_VERBOSE) + << "There is no stream with data; Can't produce any data."; + RTC_DCHECK(IsConsistent()); + + return absl::nullopt; + } + + RTC_DCHECK(data->data.stream_id == current_stream_->stream_id()); + + RTC_DLOG(LS_VERBOSE) << "Producing DATA, type=" + << (data->data.is_unordered ? "unordered" : "ordered") + << "::" + << (*data->data.is_beginning && *data->data.is_end + ? "complete" + : *data->data.is_beginning ? "first" + : *data->data.is_end ? "last" + : "middle") + << ", stream_id=" << *current_stream_->stream_id() + << ", ppid=" << *data->data.ppid + << ", length=" << data->data.payload.size(); + + currently_sending_a_message_ = !*data->data.is_end; + virtual_time_ = current_stream_->current_time(); + + // One side-effect of rescheduling is that the new stream will not be present + // in `active_streams`. + size_t bytes_to_send_next = current_stream_->bytes_to_send_in_next_message(); + if (rescheduling && bytes_to_send_next > 0) { + current_stream_->MakeActive(bytes_to_send_next); + } else if (!rescheduling && bytes_to_send_next == 0) { + current_stream_->MakeInactive(); + } + + RTC_DCHECK(IsConsistent()); + return data; +} + +StreamScheduler::VirtualTime StreamScheduler::Stream::CalculateFinishTime( + size_t bytes_to_send_next) const { + if (parent_.enable_message_interleaving_) { + // Perform weighted fair queuing scheduling. + return VirtualTime(*current_virtual_time_ + + bytes_to_send_next * *inverse_weight_); + } + + // Perform round-robin scheduling by letting the stream have its next virtual + // finish time in the future. It doesn't matter how far into the future, just + // any positive number so that any other stream that has the same virtual + // finish time as this stream gets to produce their data before revisiting + // this stream. + return VirtualTime(*current_virtual_time_ + 1); +} + +absl::optional<SendQueue::DataToSend> StreamScheduler::Stream::Produce( + TimeMs now, + size_t max_size) { + absl::optional<SendQueue::DataToSend> data = producer_.Produce(now, max_size); + + if (data.has_value()) { + VirtualTime new_current = CalculateFinishTime(data->data.payload.size()); + RTC_DLOG(LS_VERBOSE) << "Virtual time changed: " << *current_virtual_time_ + << " -> " << *new_current; + current_virtual_time_ = new_current; + } + + return data; +} + +bool StreamScheduler::IsConsistent() const { + for (Stream* stream : active_streams_) { + if (stream->next_finish_time_ == VirtualTime::Zero()) { + RTC_DLOG(LS_VERBOSE) << "Stream " << *stream->stream_id() + << " is active, but has no next-finish-time"; + return false; + } + } + return true; +} + +void StreamScheduler::Stream::MaybeMakeActive() { + RTC_DLOG(LS_VERBOSE) << "MaybeMakeActive(" << *stream_id() << ")"; + RTC_DCHECK(next_finish_time_ == VirtualTime::Zero()); + size_t bytes_to_send_next = bytes_to_send_in_next_message(); + if (bytes_to_send_next == 0) { + return; + } + + MakeActive(bytes_to_send_next); +} + +void StreamScheduler::Stream::MakeActive(size_t bytes_to_send_next) { + current_virtual_time_ = parent_.virtual_time_; + RTC_DCHECK_GT(bytes_to_send_next, 0); + VirtualTime next_finish_time = CalculateFinishTime( + std::min(bytes_to_send_next, parent_.max_payload_bytes_)); + RTC_DCHECK_GT(*next_finish_time, 0); + RTC_DLOG(LS_VERBOSE) << "Making stream " << *stream_id() + << " active, expiring at " << *next_finish_time; + RTC_DCHECK(next_finish_time_ == VirtualTime::Zero()); + next_finish_time_ = next_finish_time; + RTC_DCHECK(!absl::c_any_of(parent_.active_streams_, + [this](const auto* p) { return p == this; })); + parent_.active_streams_.emplace(this); +} + +void StreamScheduler::Stream::ForceMarkInactive() { + RTC_DLOG(LS_VERBOSE) << "Making stream " << *stream_id() << " inactive"; + RTC_DCHECK(next_finish_time_ != VirtualTime::Zero()); + next_finish_time_ = VirtualTime::Zero(); +} + +void StreamScheduler::Stream::MakeInactive() { + ForceMarkInactive(); + webrtc::EraseIf(parent_.active_streams_, + [&](const auto* s) { return s == this; }); +} + +std::set<StreamID> StreamScheduler::ActiveStreamsForTesting() const { + std::set<StreamID> stream_ids; + for (const auto& stream : active_streams_) { + stream_ids.insert(stream->stream_id()); + } + return stream_ids; +} + +} // namespace dcsctp diff --git a/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.h b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.h new file mode 100644 index 0000000000..9c523edbfc --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler.h @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#ifndef NET_DCSCTP_TX_STREAM_SCHEDULER_H_ +#define NET_DCSCTP_TX_STREAM_SCHEDULER_H_ + +#include <algorithm> +#include <cstdint> +#include <deque> +#include <map> +#include <memory> +#include <queue> +#include <set> +#include <string> +#include <utility> + +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "api/array_view.h" +#include "net/dcsctp/packet/chunk/idata_chunk.h" +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/dcsctp_message.h" +#include "net/dcsctp/public/dcsctp_socket.h" +#include "net/dcsctp/public/types.h" +#include "net/dcsctp/tx/send_queue.h" +#include "rtc_base/containers/flat_set.h" +#include "rtc_base/strong_alias.h" + +namespace dcsctp { + +// A parameterized stream scheduler. Currently, it implements the round robin +// scheduling algorithm using virtual finish time. It is to be used as a part of +// a send queue and will track all active streams (streams that have any data +// that can be sent). +// +// The stream scheduler works with the concept of associating active streams +// with a "virtual finish time", which is the time when a stream is allowed to +// produce data. Streams are ordered by their virtual finish time, and the +// "current virtual time" will advance to the next following virtual finish time +// whenever a chunk is to be produced. +// +// When message interleaving is enabled, the WFQ - Weighted Fair Queueing - +// scheduling algorithm will be used. And when it's not, round-robin scheduling +// will be used instead. +// +// In the round robin scheduling algorithm, a stream's virtual finish time will +// just increment by one (1) after having produced a chunk, which results in a +// round-robin scheduling. +// +// In WFQ scheduling algorithm, a stream's virtual finish time will be defined +// as the number of bytes in the next fragment to be sent, multiplied by the +// inverse of the stream's priority, meaning that a high priority - or a smaller +// fragment - results in a closer virtual finish time, compared to a stream with +// either a lower priority or a larger fragment to be sent. +class StreamScheduler { + private: + class VirtualTime : public webrtc::StrongAlias<class VirtualTimeTag, double> { + public: + constexpr explicit VirtualTime(const UnderlyingType& v) + : webrtc::StrongAlias<class VirtualTimeTag, double>(v) {} + + static constexpr VirtualTime Zero() { return VirtualTime(0); } + }; + class InverseWeight + : public webrtc::StrongAlias<class InverseWeightTag, double> { + public: + constexpr explicit InverseWeight(StreamPriority priority) + : webrtc::StrongAlias<class InverseWeightTag, double>( + 1.0 / std::max(static_cast<double>(*priority), 0.000001)) {} + }; + + public: + class StreamProducer { + public: + virtual ~StreamProducer() = default; + + // Produces a fragment of data to send. The current wall time is specified + // as `now` and should be used to skip chunks with expired limited lifetime. + // The parameter `max_size` specifies the maximum amount of actual payload + // that may be returned. If these constraints prevents the stream from + // sending some data, `absl::nullopt` should be returned. + virtual absl::optional<SendQueue::DataToSend> Produce(TimeMs now, + size_t max_size) = 0; + + // Returns the number of payload bytes that is scheduled to be sent in the + // next enqueued message, or zero if there are no enqueued messages or if + // the stream has been actively paused. + virtual size_t bytes_to_send_in_next_message() const = 0; + }; + + class Stream { + public: + StreamID stream_id() const { return stream_id_; } + + StreamPriority priority() const { return priority_; } + void SetPriority(StreamPriority priority); + + // Will activate the stream _if_ it has any data to send. That is, if the + // callback to `bytes_to_send_in_next_message` returns non-zero. If the + // callback returns zero, the stream will not be made active. + void MaybeMakeActive(); + + // Will remove the stream from the list of active streams, and will not try + // to produce data from it. To make it active again, call `MaybeMakeActive`. + void MakeInactive(); + + // Make the scheduler move to another message, or another stream. This is + // used to abort the scheduler from continuing producing fragments for the + // current message in case it's deleted. + void ForceReschedule() { parent_.ForceReschedule(); } + + private: + friend class StreamScheduler; + + Stream(StreamScheduler* parent, + StreamProducer* producer, + StreamID stream_id, + StreamPriority priority) + : parent_(*parent), + producer_(*producer), + stream_id_(stream_id), + priority_(priority), + inverse_weight_(priority) {} + + // Produces a message from this stream. This will only be called on streams + // that have data. + absl::optional<SendQueue::DataToSend> Produce(TimeMs now, size_t max_size); + + void MakeActive(size_t bytes_to_send_next); + void ForceMarkInactive(); + + VirtualTime current_time() const { return current_virtual_time_; } + VirtualTime next_finish_time() const { return next_finish_time_; } + size_t bytes_to_send_in_next_message() const { + return producer_.bytes_to_send_in_next_message(); + } + + VirtualTime CalculateFinishTime(size_t bytes_to_send_next) const; + + StreamScheduler& parent_; + StreamProducer& producer_; + const StreamID stream_id_; + StreamPriority priority_; + InverseWeight inverse_weight_; + // This outgoing stream's "current" virtual_time. + VirtualTime current_virtual_time_ = VirtualTime::Zero(); + VirtualTime next_finish_time_ = VirtualTime::Zero(); + }; + + // The `mtu` parameter represents the maximum SCTP packet size, which should + // be the same as `DcSctpOptions::mtu`. + explicit StreamScheduler(size_t mtu) + : max_payload_bytes_(mtu - SctpPacket::kHeaderSize - + IDataChunk::kHeaderSize) {} + + std::unique_ptr<Stream> CreateStream(StreamProducer* producer, + StreamID stream_id, + StreamPriority priority) { + return absl::WrapUnique(new Stream(this, producer, stream_id, priority)); + } + + void EnableMessageInterleaving(bool enabled) { + enable_message_interleaving_ = enabled; + } + + // Makes the scheduler stop producing message from the current stream and + // re-evaluates which stream to produce from. + void ForceReschedule() { currently_sending_a_message_ = false; } + + // Produces a fragment of data to send. The current wall time is specified as + // `now` and will be used to skip chunks with expired limited lifetime. The + // parameter `max_size` specifies the maximum amount of actual payload that + // may be returned. If no data can be produced, `absl::nullopt` is returned. + absl::optional<SendQueue::DataToSend> Produce(TimeMs now, size_t max_size); + + std::set<StreamID> ActiveStreamsForTesting() const; + + private: + struct ActiveStreamComparator { + // Ordered by virtual finish time (primary), stream-id (secondary). + bool operator()(Stream* a, Stream* b) const { + VirtualTime a_vft = a->next_finish_time(); + VirtualTime b_vft = b->next_finish_time(); + if (a_vft == b_vft) { + return a->stream_id() < b->stream_id(); + } + return a_vft < b_vft; + } + }; + + bool IsConsistent() const; + + const size_t max_payload_bytes_; + + // The current virtual time, as defined in the WFQ algorithm. + VirtualTime virtual_time_ = VirtualTime::Zero(); + + // The current stream to send chunks from. + Stream* current_stream_ = nullptr; + + bool enable_message_interleaving_ = false; + + // Indicates if the streams is currently sending a message, and should then + // - if message interleaving is not enabled - continue sending from this + // stream until that message has been sent in full. + bool currently_sending_a_message_ = false; + + // The currently active streams, ordered by virtual finish time. + webrtc::flat_set<Stream*, ActiveStreamComparator> active_streams_; +}; + +} // namespace dcsctp + +#endif // NET_DCSCTP_TX_STREAM_SCHEDULER_H_ diff --git a/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler_test.cc b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler_test.cc new file mode 100644 index 0000000000..58f0bc4690 --- /dev/null +++ b/third_party/libwebrtc/net/dcsctp/tx/stream_scheduler_test.cc @@ -0,0 +1,740 @@ +/* + * Copyright (c) 2022 The WebRTC project authors. All Rights Reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ +#include "net/dcsctp/tx/stream_scheduler.h" + +#include <vector> + +#include "net/dcsctp/packet/sctp_packet.h" +#include "net/dcsctp/public/types.h" +#include "test/gmock.h" + +namespace dcsctp { +namespace { +using ::testing::Return; +using ::testing::StrictMock; + +constexpr size_t kMtu = 1000; +constexpr size_t kPayloadSize = 4; + +MATCHER_P(HasDataWithMid, mid, "") { + if (!arg.has_value()) { + *result_listener << "There was no produced data"; + return false; + } + + if (arg->data.message_id != mid) { + *result_listener << "the produced data had mid " << *arg->data.message_id + << " and not the expected " << *mid; + return false; + } + + return true; +} + +std::function<absl::optional<SendQueue::DataToSend>(TimeMs, size_t)> +CreateChunk(StreamID sid, MID mid, size_t payload_size = kPayloadSize) { + return [sid, mid, payload_size](TimeMs now, size_t max_size) { + return SendQueue::DataToSend(Data( + sid, SSN(0), mid, FSN(0), PPID(42), std::vector<uint8_t>(payload_size), + Data::IsBeginning(true), Data::IsEnd(true), IsUnordered(true))); + }; +} + +std::map<StreamID, size_t> GetPacketCounts(StreamScheduler& scheduler, + size_t packets_to_generate) { + std::map<StreamID, size_t> packet_counts; + for (size_t i = 0; i < packets_to_generate; ++i) { + absl::optional<SendQueue::DataToSend> data = + scheduler.Produce(TimeMs(0), kMtu); + if (data.has_value()) { + ++packet_counts[data->data.stream_id]; + } + } + return packet_counts; +} + +class MockStreamProducer : public StreamScheduler::StreamProducer { + public: + MOCK_METHOD(absl::optional<SendQueue::DataToSend>, + Produce, + (TimeMs, size_t), + (override)); + MOCK_METHOD(size_t, bytes_to_send_in_next_message, (), (const, override)); +}; + +class TestStream { + public: + TestStream(StreamScheduler& scheduler, + StreamID stream_id, + StreamPriority priority, + size_t packet_size = kPayloadSize) { + EXPECT_CALL(producer_, Produce) + .WillRepeatedly(CreateChunk(stream_id, MID(0), packet_size)); + EXPECT_CALL(producer_, bytes_to_send_in_next_message) + .WillRepeatedly(Return(packet_size)); + stream_ = scheduler.CreateStream(&producer_, stream_id, priority); + stream_->MaybeMakeActive(); + } + + StreamScheduler::Stream& stream() { return *stream_; } + + private: + StrictMock<MockStreamProducer> producer_; + std::unique_ptr<StreamScheduler::Stream> stream_; +}; + +// A scheduler without active streams doesn't produce data. +TEST(StreamSchedulerTest, HasNoActiveStreams) { + StreamScheduler scheduler(kMtu); + + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Stream properties can be set and retrieved +TEST(StreamSchedulerTest, CanSetAndGetStreamProperties) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer; + auto stream = + scheduler.CreateStream(&producer, StreamID(1), StreamPriority(2)); + + EXPECT_EQ(stream->stream_id(), StreamID(1)); + EXPECT_EQ(stream->priority(), StreamPriority(2)); + + stream->SetPriority(StreamPriority(0)); + EXPECT_EQ(stream->priority(), StreamPriority(0)); +} + +// A scheduler with a single stream produced packets from it. +TEST(StreamSchedulerTest, CanProduceFromSingleStream) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer; + EXPECT_CALL(producer, Produce).WillOnce(CreateChunk(StreamID(1), MID(0))); + EXPECT_CALL(producer, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(0)); + auto stream = + scheduler.CreateStream(&producer, StreamID(1), StreamPriority(2)); + stream->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(0))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Switches between two streams after every packet. +TEST(StreamSchedulerTest, WillRoundRobinBetweenStreams) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200))) + .WillOnce(CreateChunk(StreamID(2), MID(201))) + .WillOnce(CreateChunk(StreamID(2), MID(202))); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(2)); + stream2->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Switches between two streams after every packet, but keeps producing from the +// same stream when a packet contains of multiple fragments. +TEST(StreamSchedulerTest, WillRoundRobinOnlyWhenFinishedProducingChunk) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce([](...) { + return SendQueue::DataToSend( + Data(StreamID(1), SSN(0), MID(101), FSN(0), PPID(42), + std::vector<uint8_t>(4), Data::IsBeginning(true), + Data::IsEnd(false), IsUnordered(true))); + }) + .WillOnce([](...) { + return SendQueue::DataToSend( + Data(StreamID(1), SSN(0), MID(101), FSN(0), PPID(42), + std::vector<uint8_t>(4), Data::IsBeginning(false), + Data::IsEnd(false), IsUnordered(true))); + }) + .WillOnce([](...) { + return SendQueue::DataToSend( + Data(StreamID(1), SSN(0), MID(101), FSN(0), PPID(42), + std::vector<uint8_t>(4), Data::IsBeginning(false), + Data::IsEnd(true), IsUnordered(true))); + }) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200))) + .WillOnce(CreateChunk(StreamID(2), MID(201))) + .WillOnce(CreateChunk(StreamID(2), MID(202))); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(2)); + stream2->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Deactivates a stream before it has finished producing all packets. +TEST(StreamSchedulerTest, StreamsCanBeMadeInactive) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)); // hints that there is a MID(2) coming. + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + + // ... but the stream is made inactive before it can be produced. + stream1->MakeInactive(); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Resumes a paused stream - makes a stream active after inactivating it. +TEST(StreamSchedulerTest, SingleStreamCanBeResumed) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + // Callbacks are setup so that they hint that there is a MID(2) coming... + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) // When making active again + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + + stream1->MakeInactive(); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); + stream1->MaybeMakeActive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Iterates between streams, where one is suddenly paused and later resumed. +TEST(StreamSchedulerTest, WillRoundRobinWithPausedStream) { + StreamScheduler scheduler(kMtu); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200))) + .WillOnce(CreateChunk(StreamID(2), MID(201))) + .WillOnce(CreateChunk(StreamID(2), MID(202))); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(2)); + stream2->MaybeMakeActive(); + + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + stream1->MakeInactive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + stream1->MaybeMakeActive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Verifies that packet counts are evenly distributed in round robin scheduling. +TEST(StreamSchedulerTest, WillDistributeRoundRobinPacketsEvenlyTwoStreams) { + StreamScheduler scheduler(kMtu); + TestStream stream1(scheduler, StreamID(1), StreamPriority(1)); + TestStream stream2(scheduler, StreamID(2), StreamPriority(1)); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 10); + EXPECT_EQ(packet_counts[StreamID(1)], 5U); + EXPECT_EQ(packet_counts[StreamID(2)], 5U); +} + +// Verifies that packet counts are evenly distributed among active streams, +// where a stream is suddenly made inactive, two are added, and then the paused +// stream is resumed. +TEST(StreamSchedulerTest, WillDistributeEvenlyWithPausedAndAddedStreams) { + StreamScheduler scheduler(kMtu); + TestStream stream1(scheduler, StreamID(1), StreamPriority(1)); + TestStream stream2(scheduler, StreamID(2), StreamPriority(1)); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 10); + EXPECT_EQ(packet_counts[StreamID(1)], 5U); + EXPECT_EQ(packet_counts[StreamID(2)], 5U); + + stream2.stream().MakeInactive(); + + TestStream stream3(scheduler, StreamID(3), StreamPriority(1)); + TestStream stream4(scheduler, StreamID(4), StreamPriority(1)); + + std::map<StreamID, size_t> counts2 = GetPacketCounts(scheduler, 15); + EXPECT_EQ(counts2[StreamID(1)], 5U); + EXPECT_EQ(counts2[StreamID(2)], 0U); + EXPECT_EQ(counts2[StreamID(3)], 5U); + EXPECT_EQ(counts2[StreamID(4)], 5U); + + stream2.stream().MaybeMakeActive(); + + std::map<StreamID, size_t> counts3 = GetPacketCounts(scheduler, 20); + EXPECT_EQ(counts3[StreamID(1)], 5U); + EXPECT_EQ(counts3[StreamID(2)], 5U); + EXPECT_EQ(counts3[StreamID(3)], 5U); + EXPECT_EQ(counts3[StreamID(4)], 5U); +} + +// Degrades to fair queuing with streams having identical priority. +TEST(StreamSchedulerTest, WillDoFairQueuingWithSamePriority) { + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + constexpr size_t kSmallPacket = 30; + constexpr size_t kLargePacket = 70; + + StrictMock<MockStreamProducer> callback1; + EXPECT_CALL(callback1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100), kSmallPacket)) + .WillOnce(CreateChunk(StreamID(1), MID(101), kSmallPacket)) + .WillOnce(CreateChunk(StreamID(1), MID(102), kSmallPacket)); + EXPECT_CALL(callback1, bytes_to_send_in_next_message) + .WillOnce(Return(kSmallPacket)) // When making active + .WillOnce(Return(kSmallPacket)) + .WillOnce(Return(kSmallPacket)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&callback1, StreamID(1), StreamPriority(2)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> callback2; + EXPECT_CALL(callback2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200), kLargePacket)) + .WillOnce(CreateChunk(StreamID(2), MID(201), kLargePacket)) + .WillOnce(CreateChunk(StreamID(2), MID(202), kLargePacket)); + EXPECT_CALL(callback2, bytes_to_send_in_next_message) + .WillOnce(Return(kLargePacket)) // When making active + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&callback2, StreamID(2), StreamPriority(2)); + stream2->MaybeMakeActive(); + + // t = 30 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + // t = 60 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + // t = 70 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + // t = 90 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + // t = 140 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + // t = 210 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Will do weighted fair queuing with three streams having different priority. +TEST(StreamSchedulerTest, WillDoWeightedFairQueuingSameSizeDifferentPriority) { + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + StrictMock<MockStreamProducer> callback1; + EXPECT_CALL(callback1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(100))) + .WillOnce(CreateChunk(StreamID(1), MID(101))) + .WillOnce(CreateChunk(StreamID(1), MID(102))); + EXPECT_CALL(callback1, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + // Priority 125 -> allowed to produce every 1000/125 ~= 80 time units. + auto stream1 = + scheduler.CreateStream(&callback1, StreamID(1), StreamPriority(125)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> callback2; + EXPECT_CALL(callback2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(200))) + .WillOnce(CreateChunk(StreamID(2), MID(201))) + .WillOnce(CreateChunk(StreamID(2), MID(202))); + EXPECT_CALL(callback2, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + // Priority 200 -> allowed to produce every 1000/200 ~= 50 time units. + auto stream2 = + scheduler.CreateStream(&callback2, StreamID(2), StreamPriority(200)); + stream2->MaybeMakeActive(); + + StrictMock<MockStreamProducer> callback3; + EXPECT_CALL(callback3, Produce) + .WillOnce(CreateChunk(StreamID(3), MID(300))) + .WillOnce(CreateChunk(StreamID(3), MID(301))) + .WillOnce(CreateChunk(StreamID(3), MID(302))); + EXPECT_CALL(callback3, bytes_to_send_in_next_message) + .WillOnce(Return(kPayloadSize)) // When making active + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(kPayloadSize)) + .WillOnce(Return(0)); + // Priority 500 -> allowed to produce every 1000/500 ~= 20 time units. + auto stream3 = + scheduler.CreateStream(&callback3, StreamID(3), StreamPriority(500)); + stream3->MaybeMakeActive(); + + // t ~= 20 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(300))); + // t ~= 40 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(301))); + // t ~= 50 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + // t ~= 60 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(302))); + // t ~= 80 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + // t ~= 100 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + // t ~= 150 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + // t ~= 160 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + // t ~= 240 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Will do weighted fair queuing with three streams having different priority +// and sending different payload sizes. +TEST(StreamSchedulerTest, WillDoWeightedFairQueuingDifferentSizeAndPriority) { + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + constexpr size_t kSmallPacket = 20; + constexpr size_t kMediumPacket = 50; + constexpr size_t kLargePacket = 70; + + // Stream with priority = 125 -> inverse weight ~=80 + StrictMock<MockStreamProducer> callback1; + EXPECT_CALL(callback1, Produce) + // virtual finish time ~ 0 + 50 * 80 = 4000 + .WillOnce(CreateChunk(StreamID(1), MID(100), kMediumPacket)) + // virtual finish time ~ 4000 + 20 * 80 = 5600 + .WillOnce(CreateChunk(StreamID(1), MID(101), kSmallPacket)) + // virtual finish time ~ 5600 + 70 * 80 = 11200 + .WillOnce(CreateChunk(StreamID(1), MID(102), kLargePacket)); + EXPECT_CALL(callback1, bytes_to_send_in_next_message) + .WillOnce(Return(kMediumPacket)) // When making active + .WillOnce(Return(kSmallPacket)) + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&callback1, StreamID(1), StreamPriority(125)); + stream1->MaybeMakeActive(); + + // Stream with priority = 200 -> inverse weight ~=50 + StrictMock<MockStreamProducer> callback2; + EXPECT_CALL(callback2, Produce) + // virtual finish time ~ 0 + 50 * 50 = 2500 + .WillOnce(CreateChunk(StreamID(2), MID(200), kMediumPacket)) + // virtual finish time ~ 2500 + 70 * 50 = 6000 + .WillOnce(CreateChunk(StreamID(2), MID(201), kLargePacket)) + // virtual finish time ~ 6000 + 20 * 50 = 7000 + .WillOnce(CreateChunk(StreamID(2), MID(202), kSmallPacket)); + EXPECT_CALL(callback2, bytes_to_send_in_next_message) + .WillOnce(Return(kMediumPacket)) // When making active + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(kSmallPacket)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&callback2, StreamID(2), StreamPriority(200)); + stream2->MaybeMakeActive(); + + // Stream with priority = 500 -> inverse weight ~=20 + StrictMock<MockStreamProducer> callback3; + EXPECT_CALL(callback3, Produce) + // virtual finish time ~ 0 + 20 * 20 = 400 + .WillOnce(CreateChunk(StreamID(3), MID(300), kSmallPacket)) + // virtual finish time ~ 400 + 50 * 20 = 1400 + .WillOnce(CreateChunk(StreamID(3), MID(301), kMediumPacket)) + // virtual finish time ~ 1400 + 70 * 20 = 2800 + .WillOnce(CreateChunk(StreamID(3), MID(302), kLargePacket)); + EXPECT_CALL(callback3, bytes_to_send_in_next_message) + .WillOnce(Return(kSmallPacket)) // When making active + .WillOnce(Return(kMediumPacket)) + .WillOnce(Return(kLargePacket)) + .WillOnce(Return(0)); + auto stream3 = + scheduler.CreateStream(&callback3, StreamID(3), StreamPriority(500)); + stream3->MaybeMakeActive(); + + // t ~= 400 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(300))); + // t ~= 1400 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(301))); + // t ~= 2500 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(200))); + // t ~= 2800 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(302))); + // t ~= 4000 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(100))); + // t ~= 5600 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(101))); + // t ~= 6000 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(201))); + // t ~= 7000 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(202))); + // t ~= 11200 + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(102))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} +TEST(StreamSchedulerTest, WillDistributeWFQPacketsInTwoStreamsByPriority) { + // A simple test with two streams of different priority, but sending packets + // of identical size. Verifies that the ratio of sent packets represent their + // priority. + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + TestStream stream1(scheduler, StreamID(1), StreamPriority(100), kPayloadSize); + TestStream stream2(scheduler, StreamID(2), StreamPriority(200), kPayloadSize); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 15); + EXPECT_EQ(packet_counts[StreamID(1)], 5U); + EXPECT_EQ(packet_counts[StreamID(2)], 10U); +} + +TEST(StreamSchedulerTest, WillDistributeWFQPacketsInFourStreamsByPriority) { + // Same as `WillDistributeWFQPacketsInTwoStreamsByPriority` but with more + // streams. + StreamScheduler scheduler(kMtu); + scheduler.EnableMessageInterleaving(true); + + TestStream stream1(scheduler, StreamID(1), StreamPriority(100), kPayloadSize); + TestStream stream2(scheduler, StreamID(2), StreamPriority(200), kPayloadSize); + TestStream stream3(scheduler, StreamID(3), StreamPriority(300), kPayloadSize); + TestStream stream4(scheduler, StreamID(4), StreamPriority(400), kPayloadSize); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 50); + EXPECT_EQ(packet_counts[StreamID(1)], 5U); + EXPECT_EQ(packet_counts[StreamID(2)], 10U); + EXPECT_EQ(packet_counts[StreamID(3)], 15U); + EXPECT_EQ(packet_counts[StreamID(4)], 20U); +} + +TEST(StreamSchedulerTest, WillDistributeFromTwoStreamsFairly) { + // A simple test with two streams of different priority, but sending packets + // of different size. Verifies that the ratio of total packet payload + // represent their priority. + // In this example, + // * stream1 has priority 100 and sends packets of size 8 + // * stream2 has priority 400 and sends packets of size 4 + // With round robin, stream1 would get twice as many payload bytes on the wire + // as stream2, but with WFQ and a 4x priority increase, stream2 should 4x as + // many payload bytes on the wire. That translates to stream2 getting 8x as + // many packets on the wire as they are half as large. + StreamScheduler scheduler(kMtu); + // Enable WFQ scheduler. + scheduler.EnableMessageInterleaving(true); + + TestStream stream1(scheduler, StreamID(1), StreamPriority(100), + /*packet_size=*/8); + TestStream stream2(scheduler, StreamID(2), StreamPriority(400), + /*packet_size=*/4); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 90); + EXPECT_EQ(packet_counts[StreamID(1)], 10U); + EXPECT_EQ(packet_counts[StreamID(2)], 80U); +} + +TEST(StreamSchedulerTest, WillDistributeFromFourStreamsFairly) { + // Same as `WillDistributeWeightedFairFromTwoStreamsFairly` but more + // complicated. + StreamScheduler scheduler(kMtu); + // Enable WFQ scheduler. + scheduler.EnableMessageInterleaving(true); + + TestStream stream1(scheduler, StreamID(1), StreamPriority(100), + /*packet_size=*/10); + TestStream stream2(scheduler, StreamID(2), StreamPriority(200), + /*packet_size=*/10); + TestStream stream3(scheduler, StreamID(3), StreamPriority(200), + /*packet_size=*/20); + TestStream stream4(scheduler, StreamID(4), StreamPriority(400), + /*packet_size=*/30); + + std::map<StreamID, size_t> packet_counts = GetPacketCounts(scheduler, 80); + // 15 packets * 10 bytes = 150 bytes at priority 100. + EXPECT_EQ(packet_counts[StreamID(1)], 15U); + // 30 packets * 10 bytes = 300 bytes at priority 200. + EXPECT_EQ(packet_counts[StreamID(2)], 30U); + // 15 packets * 20 bytes = 300 bytes at priority 200. + EXPECT_EQ(packet_counts[StreamID(3)], 15U); + // 20 packets * 30 bytes = 600 bytes at priority 400. + EXPECT_EQ(packet_counts[StreamID(4)], 20U); +} + +// Sending large messages with small MTU will fragment the messages and produce +// a first fragment not larger than the MTU, and will then not first send from +// the stream with the smallest message, as their first fragment will be equally +// small for both streams. See `LargeMessageWithLargeMtu` for the same test, but +// with a larger MTU. +TEST(StreamSchedulerTest, SendLargeMessageWithSmallMtu) { + StreamScheduler scheduler(100 + SctpPacket::kHeaderSize + + IDataChunk::kHeaderSize); + scheduler.EnableMessageInterleaving(true); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(0), 100)) + .WillOnce(CreateChunk(StreamID(1), MID(0), 100)); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(200)) // When making active + .WillOnce(Return(100)) + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(1)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(1), 100)) + .WillOnce(CreateChunk(StreamID(2), MID(1), 50)); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(150)) // When making active + .WillOnce(Return(50)) + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(1)); + stream2->MaybeMakeActive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(0))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(1))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(1))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(0))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +// Sending large messages with large MTU will not fragment messages and will +// send the message first from the stream that has the smallest message. +TEST(StreamSchedulerTest, SendLargeMessageWithLargeMtu) { + StreamScheduler scheduler(200 + SctpPacket::kHeaderSize + + IDataChunk::kHeaderSize); + scheduler.EnableMessageInterleaving(true); + + StrictMock<MockStreamProducer> producer1; + EXPECT_CALL(producer1, Produce) + .WillOnce(CreateChunk(StreamID(1), MID(0), 200)); + EXPECT_CALL(producer1, bytes_to_send_in_next_message) + .WillOnce(Return(200)) // When making active + .WillOnce(Return(0)); + auto stream1 = + scheduler.CreateStream(&producer1, StreamID(1), StreamPriority(1)); + stream1->MaybeMakeActive(); + + StrictMock<MockStreamProducer> producer2; + EXPECT_CALL(producer2, Produce) + .WillOnce(CreateChunk(StreamID(2), MID(1), 150)); + EXPECT_CALL(producer2, bytes_to_send_in_next_message) + .WillOnce(Return(150)) // When making active + .WillOnce(Return(0)); + auto stream2 = + scheduler.CreateStream(&producer2, StreamID(2), StreamPriority(1)); + stream2->MaybeMakeActive(); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(1))); + EXPECT_THAT(scheduler.Produce(TimeMs(0), kMtu), HasDataWithMid(MID(0))); + EXPECT_EQ(scheduler.Produce(TimeMs(0), kMtu), absl::nullopt); +} + +} // namespace +} // namespace dcsctp |