summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/net/dcsctp/fuzzers
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/libwebrtc/net/dcsctp/fuzzers')
-rw-r--r--third_party/libwebrtc/net/dcsctp/fuzzers/BUILD.gn50
-rw-r--r--third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.cc461
-rw-r--r--third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.h119
-rw-r--r--third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc40
4 files changed, 670 insertions, 0 deletions
diff --git a/third_party/libwebrtc/net/dcsctp/fuzzers/BUILD.gn b/third_party/libwebrtc/net/dcsctp/fuzzers/BUILD.gn
new file mode 100644
index 0000000000..302c828684
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/fuzzers/BUILD.gn
@@ -0,0 +1,50 @@
+# Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+#
+# Use of this source code is governed by a BSD-style license
+# that can be found in the LICENSE file in the root of the source
+# tree. An additional intellectual property rights grant can be found
+# in the file PATENTS. All contributing project authors may
+# be found in the AUTHORS file in the root of the source tree.
+
+import("../../../webrtc.gni")
+
+rtc_library("dcsctp_fuzzers") {
+ testonly = true
+ deps = [
+ "../../../api:array_view",
+ "../../../api/task_queue:task_queue",
+ "../../../rtc_base:checks",
+ "../../../rtc_base:logging",
+ "../common:math",
+ "../packet:chunk",
+ "../packet:error_cause",
+ "../packet:parameter",
+ "../public:socket",
+ "../public:types",
+ "../socket:dcsctp_socket",
+ ]
+ sources = [
+ "dcsctp_fuzzers.cc",
+ "dcsctp_fuzzers.h",
+ ]
+}
+
+if (rtc_include_tests) {
+ rtc_library("dcsctp_fuzzers_unittests") {
+ testonly = true
+
+ deps = [
+ ":dcsctp_fuzzers",
+ "../../../api:array_view",
+ "../../../rtc_base:checks",
+ "../../../rtc_base:gunit_helpers",
+ "../../../rtc_base:logging",
+ "../../../test:test_support",
+ "../packet:sctp_packet",
+ "../public:socket",
+ "../socket:dcsctp_socket",
+ "../testing:testing_macros",
+ ]
+ sources = [ "dcsctp_fuzzers_test.cc" ]
+ }
+}
diff --git a/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.cc b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.cc
new file mode 100644
index 0000000000..e8fcacffa0
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.cc
@@ -0,0 +1,461 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "net/dcsctp/common/math.h"
+#include "net/dcsctp/packet/chunk/cookie_ack_chunk.h"
+#include "net/dcsctp/packet/chunk/cookie_echo_chunk.h"
+#include "net/dcsctp/packet/chunk/data_chunk.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_chunk.h"
+#include "net/dcsctp/packet/chunk/forward_tsn_common.h"
+#include "net/dcsctp/packet/chunk/shutdown_chunk.h"
+#include "net/dcsctp/packet/error_cause/protocol_violation_cause.h"
+#include "net/dcsctp/packet/error_cause/user_initiated_abort_cause.h"
+#include "net/dcsctp/packet/parameter/forward_tsn_supported_parameter.h"
+#include "net/dcsctp/packet/parameter/outgoing_ssn_reset_request_parameter.h"
+#include "net/dcsctp/packet/parameter/state_cookie_parameter.h"
+#include "net/dcsctp/public/dcsctp_message.h"
+#include "net/dcsctp/public/types.h"
+#include "net/dcsctp/socket/dcsctp_socket.h"
+#include "net/dcsctp/socket/state_cookie.h"
+#include "rtc_base/logging.h"
+
+namespace dcsctp {
+namespace dcsctp_fuzzers {
+namespace {
+static constexpr int kRandomValue = FuzzerCallbacks::kRandomValue;
+static constexpr size_t kMinInputLength = 5;
+static constexpr size_t kMaxInputLength = 1024;
+
+// A starting state for the socket, when fuzzing.
+enum class StartingState : int {
+ kConnectNotCalled,
+ // When socket initiating Connect
+ kConnectCalled,
+ kReceivedInitAck,
+ kReceivedCookieAck,
+ // When socket initiating Shutdown
+ kShutdownCalled,
+ kReceivedShutdownAck,
+ // When peer socket initiated Connect
+ kReceivedInit,
+ kReceivedCookieEcho,
+ // When peer initiated Shutdown
+ kReceivedShutdown,
+ kReceivedShutdownComplete,
+ kNumberOfStates,
+};
+
+// State about the current fuzzing iteration
+class FuzzState {
+ public:
+ explicit FuzzState(rtc::ArrayView<const uint8_t> data) : data_(data) {}
+
+ uint8_t GetByte() {
+ uint8_t value = 0;
+ if (offset_ < data_.size()) {
+ value = data_[offset_];
+ ++offset_;
+ }
+ return value;
+ }
+
+ TSN GetNextTSN() { return TSN(tsn_++); }
+ MID GetNextMID() { return MID(mid_++); }
+
+ bool empty() const { return offset_ >= data_.size(); }
+
+ private:
+ uint32_t tsn_ = kRandomValue;
+ uint32_t mid_ = 0;
+ rtc::ArrayView<const uint8_t> data_;
+ size_t offset_ = 0;
+};
+
+void SetSocketState(DcSctpSocketInterface& socket,
+ FuzzerCallbacks& socket_cb,
+ StartingState state) {
+ // We'll use another temporary peer socket for the establishment.
+ FuzzerCallbacks peer_cb;
+ DcSctpSocket peer("peer", peer_cb, nullptr, {});
+
+ switch (state) {
+ case StartingState::kConnectNotCalled:
+ return;
+ case StartingState::kConnectCalled:
+ socket.Connect();
+ return;
+ case StartingState::kReceivedInitAck:
+ socket.Connect();
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
+ return;
+ case StartingState::kReceivedCookieAck:
+ socket.Connect();
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
+ return;
+ case StartingState::kShutdownCalled:
+ socket.Connect();
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
+ socket.Shutdown();
+ return;
+ case StartingState::kReceivedShutdownAck:
+ socket.Connect();
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
+ socket.Shutdown();
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_ACK
+ return;
+ case StartingState::kReceivedInit:
+ peer.Connect();
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT
+ return;
+ case StartingState::kReceivedCookieEcho:
+ peer.Connect();
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT_ACK
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ECHO
+ return;
+ case StartingState::kReceivedShutdown:
+ socket.Connect();
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
+ peer.Shutdown();
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN
+ return;
+ case StartingState::kReceivedShutdownComplete:
+ socket.Connect();
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // INIT
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // INIT_ACK
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // COOKIE_ECHO
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // COOKIE_ACK
+ peer.Shutdown();
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN
+ peer.ReceivePacket(socket_cb.ConsumeSentPacket()); // SHUTDOWN_ACK
+ socket.ReceivePacket(peer_cb.ConsumeSentPacket()); // SHUTDOWN_COMPLETE
+ return;
+ case StartingState::kNumberOfStates:
+ RTC_CHECK(false);
+ return;
+ }
+}
+
+void MakeDataChunk(FuzzState& state, SctpPacket::Builder& b) {
+ DataChunk::Options options;
+ options.is_unordered = IsUnordered(state.GetByte() != 0);
+ options.is_beginning = Data::IsBeginning(state.GetByte() != 0);
+ options.is_end = Data::IsEnd(state.GetByte() != 0);
+ b.Add(DataChunk(state.GetNextTSN(), StreamID(state.GetByte()),
+ SSN(state.GetByte()), PPID(53), std::vector<uint8_t>(10),
+ options));
+}
+
+void MakeInitChunk(FuzzState& state, SctpPacket::Builder& b) {
+ Parameters::Builder builder;
+ builder.Add(ForwardTsnSupportedParameter());
+
+ b.Add(InitChunk(VerificationTag(kRandomValue), 10000, 1000, 1000,
+ TSN(kRandomValue), builder.Build()));
+}
+
+void MakeInitAckChunk(FuzzState& state, SctpPacket::Builder& b) {
+ Parameters::Builder builder;
+ builder.Add(ForwardTsnSupportedParameter());
+
+ uint8_t state_cookie[] = {1, 2, 3, 4, 5};
+ Parameters::Builder params_builder =
+ Parameters::Builder().Add(StateCookieParameter(state_cookie));
+
+ b.Add(InitAckChunk(VerificationTag(kRandomValue), 10000, 1000, 1000,
+ TSN(kRandomValue), builder.Build()));
+}
+
+void MakeSackChunk(FuzzState& state, SctpPacket::Builder& b) {
+ std::vector<SackChunk::GapAckBlock> gap_ack_blocks;
+ uint16_t last_end = 0;
+ while (gap_ack_blocks.size() < 20) {
+ uint8_t delta_start = state.GetByte();
+ if (delta_start < 0x80) {
+ break;
+ }
+ uint8_t delta_end = state.GetByte();
+
+ uint16_t start = last_end + delta_start;
+ uint16_t end = start + delta_end;
+ last_end = end;
+ gap_ack_blocks.emplace_back(start, end);
+ }
+
+ TSN cum_ack_tsn(kRandomValue + state.GetByte());
+ b.Add(SackChunk(cum_ack_tsn, 10000, std::move(gap_ack_blocks), {}));
+}
+
+void MakeHeartbeatRequestChunk(FuzzState& state, SctpPacket::Builder& b) {
+ uint8_t info[] = {1, 2, 3, 4, 5};
+ b.Add(HeartbeatRequestChunk(
+ Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build()));
+}
+
+void MakeHeartbeatAckChunk(FuzzState& state, SctpPacket::Builder& b) {
+ std::vector<uint8_t> info(8);
+ b.Add(HeartbeatRequestChunk(
+ Parameters::Builder().Add(HeartbeatInfoParameter(info)).Build()));
+}
+
+void MakeAbortChunk(FuzzState& state, SctpPacket::Builder& b) {
+ b.Add(AbortChunk(
+ /*filled_in_verification_tag=*/true,
+ Parameters::Builder().Add(UserInitiatedAbortCause("Fuzzing")).Build()));
+}
+
+void MakeErrorChunk(FuzzState& state, SctpPacket::Builder& b) {
+ b.Add(ErrorChunk(
+ Parameters::Builder().Add(ProtocolViolationCause("Fuzzing")).Build()));
+}
+
+void MakeCookieEchoChunk(FuzzState& state, SctpPacket::Builder& b) {
+ std::vector<uint8_t> cookie(StateCookie::kCookieSize);
+ b.Add(CookieEchoChunk(cookie));
+}
+
+void MakeCookieAckChunk(FuzzState& state, SctpPacket::Builder& b) {
+ b.Add(CookieAckChunk());
+}
+
+void MakeShutdownChunk(FuzzState& state, SctpPacket::Builder& b) {
+ b.Add(ShutdownChunk(state.GetNextTSN()));
+}
+
+void MakeShutdownAckChunk(FuzzState& state, SctpPacket::Builder& b) {
+ b.Add(ShutdownAckChunk());
+}
+
+void MakeShutdownCompleteChunk(FuzzState& state, SctpPacket::Builder& b) {
+ b.Add(ShutdownCompleteChunk(false));
+}
+
+void MakeReConfigChunk(FuzzState& state, SctpPacket::Builder& b) {
+ std::vector<StreamID> streams = {StreamID(state.GetByte())};
+ Parameters::Builder params_builder =
+ Parameters::Builder().Add(OutgoingSSNResetRequestParameter(
+ ReconfigRequestSN(kRandomValue), ReconfigRequestSN(kRandomValue),
+ state.GetNextTSN(), streams));
+ b.Add(ReConfigChunk(params_builder.Build()));
+}
+
+void MakeForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) {
+ std::vector<ForwardTsnChunk::SkippedStream> skipped_streams;
+ for (;;) {
+ uint8_t stream = state.GetByte();
+ if (skipped_streams.size() > 20 || stream < 0x80) {
+ break;
+ }
+ skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte()));
+ }
+ b.Add(ForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams)));
+}
+
+void MakeIDataChunk(FuzzState& state, SctpPacket::Builder& b) {
+ DataChunk::Options options;
+ options.is_unordered = IsUnordered(state.GetByte() != 0);
+ options.is_beginning = Data::IsBeginning(state.GetByte() != 0);
+ options.is_end = Data::IsEnd(state.GetByte() != 0);
+ b.Add(IDataChunk(state.GetNextTSN(), StreamID(state.GetByte()),
+ state.GetNextMID(), PPID(53), FSN(0),
+ std::vector<uint8_t>(10), options));
+}
+
+void MakeIForwardTsnChunk(FuzzState& state, SctpPacket::Builder& b) {
+ std::vector<ForwardTsnChunk::SkippedStream> skipped_streams;
+ for (;;) {
+ uint8_t stream = state.GetByte();
+ if (skipped_streams.size() > 20 || stream < 0x80) {
+ break;
+ }
+ skipped_streams.emplace_back(StreamID(stream), SSN(state.GetByte()));
+ }
+ b.Add(IForwardTsnChunk(state.GetNextTSN(), std::move(skipped_streams)));
+}
+
+class RandomFuzzedChunk : public Chunk {
+ public:
+ explicit RandomFuzzedChunk(FuzzState& state) : state_(state) {}
+
+ void SerializeTo(std::vector<uint8_t>& out) const override {
+ size_t bytes = state_.GetByte();
+ for (size_t i = 0; i < bytes; ++i) {
+ out.push_back(state_.GetByte());
+ }
+ }
+
+ std::string ToString() const override { return std::string("RANDOM_FUZZED"); }
+
+ private:
+ FuzzState& state_;
+};
+
+void MakeChunkWithRandomContent(FuzzState& state, SctpPacket::Builder& b) {
+ b.Add(RandomFuzzedChunk(state));
+}
+
+std::vector<uint8_t> GeneratePacket(FuzzState& state) {
+ DcSctpOptions options;
+ // Setting a fixed limit to not be dependent on the defaults, which may
+ // change.
+ options.mtu = 2048;
+ SctpPacket::Builder builder(VerificationTag(kRandomValue), options);
+
+ // The largest expected serialized chunk, as created by fuzzers.
+ static constexpr size_t kMaxChunkSize = 256;
+
+ for (int i = 0; i < 5 && builder.bytes_remaining() > kMaxChunkSize; ++i) {
+ switch (state.GetByte()) {
+ case 1:
+ MakeDataChunk(state, builder);
+ break;
+ case 2:
+ MakeInitChunk(state, builder);
+ break;
+ case 3:
+ MakeInitAckChunk(state, builder);
+ break;
+ case 4:
+ MakeSackChunk(state, builder);
+ break;
+ case 5:
+ MakeHeartbeatRequestChunk(state, builder);
+ break;
+ case 6:
+ MakeHeartbeatAckChunk(state, builder);
+ break;
+ case 7:
+ MakeAbortChunk(state, builder);
+ break;
+ case 8:
+ MakeErrorChunk(state, builder);
+ break;
+ case 9:
+ MakeCookieEchoChunk(state, builder);
+ break;
+ case 10:
+ MakeCookieAckChunk(state, builder);
+ break;
+ case 11:
+ MakeShutdownChunk(state, builder);
+ break;
+ case 12:
+ MakeShutdownAckChunk(state, builder);
+ break;
+ case 13:
+ MakeShutdownCompleteChunk(state, builder);
+ break;
+ case 14:
+ MakeReConfigChunk(state, builder);
+ break;
+ case 15:
+ MakeForwardTsnChunk(state, builder);
+ break;
+ case 16:
+ MakeIDataChunk(state, builder);
+ break;
+ case 17:
+ MakeIForwardTsnChunk(state, builder);
+ break;
+ case 18:
+ MakeChunkWithRandomContent(state, builder);
+ break;
+ default:
+ break;
+ }
+ }
+ std::vector<uint8_t> packet = builder.Build();
+ return packet;
+}
+} // namespace
+
+void FuzzSocket(DcSctpSocketInterface& socket,
+ FuzzerCallbacks& cb,
+ rtc::ArrayView<const uint8_t> data) {
+ if (data.size() < kMinInputLength || data.size() > kMaxInputLength) {
+ return;
+ }
+ if (data[0] >= static_cast<int>(StartingState::kNumberOfStates)) {
+ return;
+ }
+
+ // Set the socket in a specified valid starting state
+ SetSocketState(socket, cb, static_cast<StartingState>(data[0]));
+
+ FuzzState state(data.subview(1));
+
+ while (!state.empty()) {
+ switch (state.GetByte()) {
+ case 1:
+ // Generate a valid SCTP packet (based on fuzz data) and "receive it".
+ socket.ReceivePacket(GeneratePacket(state));
+ break;
+ case 2:
+ socket.Connect();
+ break;
+ case 3:
+ socket.Shutdown();
+ break;
+ case 4:
+ socket.Close();
+ break;
+ case 5: {
+ StreamID streams[] = {StreamID(state.GetByte())};
+ socket.ResetStreams(streams);
+ } break;
+ case 6: {
+ uint8_t flags = state.GetByte();
+ SendOptions options;
+ options.unordered = IsUnordered(flags & 0x01);
+ options.max_retransmissions =
+ (flags & 0x02) != 0 ? absl::make_optional(0) : absl::nullopt;
+ options.lifecycle_id = LifecycleId(42);
+ size_t payload_exponent = (flags >> 2) % 16;
+ size_t payload_size = static_cast<size_t>(1) << payload_exponent;
+ socket.Send(DcSctpMessage(StreamID(state.GetByte()), PPID(53),
+ std::vector<uint8_t>(payload_size)),
+ options);
+ break;
+ }
+ case 7: {
+ // Expire an active timeout/timer.
+ uint8_t timeout_idx = state.GetByte();
+ absl::optional<TimeoutID> timeout_id = cb.ExpireTimeout(timeout_idx);
+ if (timeout_id.has_value()) {
+ socket.HandleTimeout(*timeout_id);
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ }
+}
+} // namespace dcsctp_fuzzers
+} // namespace dcsctp
diff --git a/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.h b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.h
new file mode 100644
index 0000000000..90cfa35099
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.h
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#ifndef NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
+#define NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
+
+#include <deque>
+#include <memory>
+#include <set>
+#include <vector>
+
+#include "api/array_view.h"
+#include "api/task_queue/task_queue_base.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+
+namespace dcsctp {
+namespace dcsctp_fuzzers {
+
+// A fake timeout used during fuzzing.
+class FuzzerTimeout : public Timeout {
+ public:
+ explicit FuzzerTimeout(std::set<TimeoutID>& active_timeouts)
+ : active_timeouts_(active_timeouts) {}
+
+ void Start(DurationMs duration_ms, TimeoutID timeout_id) override {
+ // Start is only allowed to be called on stopped or expired timeouts.
+ if (timeout_id_.has_value()) {
+ // It has been started before, but maybe it expired. Ensure that it's not
+ // running at least.
+ RTC_DCHECK(active_timeouts_.find(*timeout_id_) == active_timeouts_.end());
+ }
+ timeout_id_ = timeout_id;
+ RTC_DCHECK(active_timeouts_.insert(timeout_id).second);
+ }
+
+ void Stop() override {
+ // Stop is only allowed to be called on active timeouts. Not stopped or
+ // expired.
+ RTC_DCHECK(timeout_id_.has_value());
+ RTC_DCHECK(active_timeouts_.erase(*timeout_id_) == 1);
+ timeout_id_ = absl::nullopt;
+ }
+
+ // A set of all active timeouts, managed by `FuzzerCallbacks`.
+ std::set<TimeoutID>& active_timeouts_;
+ // If present, the timout is active and will expire reported as `timeout_id`.
+ absl::optional<TimeoutID> timeout_id_;
+};
+
+class FuzzerCallbacks : public DcSctpSocketCallbacks {
+ public:
+ static constexpr int kRandomValue = 42;
+ void SendPacket(rtc::ArrayView<const uint8_t> data) override {
+ sent_packets_.emplace_back(std::vector<uint8_t>(data.begin(), data.end()));
+ }
+ std::unique_ptr<Timeout> CreateTimeout(
+ webrtc::TaskQueueBase::DelayPrecision precision) override {
+ // The fuzzer timeouts don't implement |precision|.
+ return std::make_unique<FuzzerTimeout>(active_timeouts_);
+ }
+ TimeMs TimeMillis() override { return TimeMs(42); }
+ uint32_t GetRandomInt(uint32_t low, uint32_t high) override {
+ return kRandomValue;
+ }
+ void OnMessageReceived(DcSctpMessage message) override {}
+ void OnError(ErrorKind error, absl::string_view message) override {}
+ void OnAborted(ErrorKind error, absl::string_view message) override {}
+ void OnConnected() override {}
+ void OnClosed() override {}
+ void OnConnectionRestarted() override {}
+ void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,
+ absl::string_view reason) override {}
+ void OnStreamsResetPerformed(
+ rtc::ArrayView<const StreamID> outgoing_streams) override {}
+ void OnIncomingStreamsReset(
+ rtc::ArrayView<const StreamID> incoming_streams) override {}
+
+ std::vector<uint8_t> ConsumeSentPacket() {
+ if (sent_packets_.empty()) {
+ return {};
+ }
+ std::vector<uint8_t> ret = sent_packets_.front();
+ sent_packets_.pop_front();
+ return ret;
+ }
+
+ // Given an index among the active timeouts, will expire that one.
+ absl::optional<TimeoutID> ExpireTimeout(size_t index) {
+ if (index < active_timeouts_.size()) {
+ auto it = active_timeouts_.begin();
+ std::advance(it, index);
+ TimeoutID timeout_id = *it;
+ active_timeouts_.erase(it);
+ return timeout_id;
+ }
+ return absl::nullopt;
+ }
+
+ private:
+ // Needs to be ordered, to allow fuzzers to expire timers.
+ std::set<TimeoutID> active_timeouts_;
+ std::deque<std::vector<uint8_t>> sent_packets_;
+};
+
+// Given some fuzzing `data` will send packets to the socket as well as calling
+// API methods.
+void FuzzSocket(DcSctpSocketInterface& socket,
+ FuzzerCallbacks& cb,
+ rtc::ArrayView<const uint8_t> data);
+
+} // namespace dcsctp_fuzzers
+} // namespace dcsctp
+#endif // NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
diff --git a/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc
new file mode 100644
index 0000000000..c7d2cd7c99
--- /dev/null
+++ b/third_party/libwebrtc/net/dcsctp/fuzzers/dcsctp_fuzzers_test.cc
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
+ *
+ * Use of this source code is governed by a BSD-style license
+ * that can be found in the LICENSE file in the root of the source
+ * tree. An additional intellectual property rights grant can be found
+ * in the file PATENTS. All contributing project authors may
+ * be found in the AUTHORS file in the root of the source tree.
+ */
+#include "net/dcsctp/fuzzers/dcsctp_fuzzers.h"
+
+#include "api/array_view.h"
+#include "net/dcsctp/packet/sctp_packet.h"
+#include "net/dcsctp/public/dcsctp_socket.h"
+#include "net/dcsctp/socket/dcsctp_socket.h"
+#include "net/dcsctp/testing/testing_macros.h"
+#include "rtc_base/gunit.h"
+#include "rtc_base/logging.h"
+#include "test/gmock.h"
+
+namespace dcsctp {
+namespace dcsctp_fuzzers {
+namespace {
+
+// This is a testbed where fuzzed data that cause issues can be evaluated and
+// crashes reproduced. Use `xxd -i ./crash-abc` to generate `data` below.
+TEST(DcsctpFuzzersTest, PassesTestbed) {
+ uint8_t data[] = {0x07, 0x09, 0x00, 0x01, 0x11, 0xff, 0xff};
+
+ FuzzerCallbacks cb;
+ DcSctpOptions options;
+ options.disable_checksum_verification = true;
+ DcSctpSocket socket("A", cb, nullptr, options);
+
+ FuzzSocket(socket, cb, data);
+}
+
+} // namespace
+} // namespace dcsctp_fuzzers
+} // namespace dcsctp