summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/rtc_tools/data_channel_benchmark/data_channel_benchmark.cc
blob: fa0b6ca9c4f6d51b8e116199b4b5ad1b4b5164cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
/*
 *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 *
 *  Data Channel Benchmarking tool.
 *
 *  Create a server using: ./data_channel_benchmark --server --port 12345
 *  Start the flow of data from the server to a client using:
 *  ./data_channel_benchmark --port 12345 --transfer_size 100 --packet_size 8196
 *  (--transfer_size is expressed in MiB, --packet_size in bytes.)
 *  The throughput is reported on the server console.
 *
 *  The negotiation does not require a 3rd party server and is done over a gRPC
 *  transport. No TURN server is configured, so both peers need to be reachable
 *  using STUN only.
 */
#include <inttypes.h>

#include <charconv>

#include "absl/cleanup/cleanup.h"
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "rtc_base/event.h"
#include "rtc_base/ssl_adapter.h"
#include "rtc_base/thread.h"
#include "rtc_tools/data_channel_benchmark/grpc_signaling.h"
#include "rtc_tools/data_channel_benchmark/peer_connection_client.h"
#include "system_wrappers/include/field_trial.h"

// Command-line flags, read in main(), RunServer() and RunClient() below.

// Logging verbosity; main() turns it into an rtc::LoggingSeverity by
// subtracting it from rtc::LS_NONE, so a higher value means more logs.
ABSL_FLAG(int, verbose, 0, "verbosity level (0-5)");
// Selects RunServer() instead of RunClient() in main().
ABSL_FLAG(bool, server, false, "Server mode");
// Server only: forwarded to GrpcSignalingServerInterface::Create().
ABSL_FLAG(bool, oneshot, true, "Terminate after serving a client");
// Client only: host part of the "address:port" string given to the gRPC
// signaling client.
ABSL_FLAG(std::string, address, "localhost", "Connect to server address");
// Used by both modes: the server passes it to the gRPC signaling server, the
// client appends it to the server address.
ABSL_FLAG(uint16_t, port, 0, "Connect to port (0 for random)");
// Client only: multiplied by 1024 * 1024 in RunClient() to obtain a byte
// count.
ABSL_FLAG(uint64_t, transfer_size, 2, "Transfer size (MiB)");
// Client only: forwarded to the server, which uses it as the byte length of
// each message it sends.
ABSL_FLAG(uint64_t, packet_size, 256 * 1024, "Packet size");
// Passed verbatim to webrtc::field_trial::InitFieldTrialsFromString() in
// main().
ABSL_FLAG(std::string,
          force_fieldtrials,
          "",
          "Field trials control experimental feature code which can be forced. "
          "E.g. running with --force_fieldtrials=WebRTC-FooFeature/Enable/"
          " will assign the group Enable to field trial WebRTC-FooFeature.");

// Benchmark parameters exchanged between the peers. The client serializes
// them with ToString() and sends them as the first (text) message on the data
// channel; the server decodes them with FromString().
struct SetupMessage {
  // Byte length of each message the server sends.
  size_t packet_size = 0;
  // Total number of bytes the server is asked to send.
  size_t transfer_size = 0;

  // Serializes the parameters as "<packet_size>,<transfer_size>" in base 10.
  std::string ToString() const {
    // Two 64-bit decimal numbers (at most 20 digits each) plus the separator
    // and the terminator fit in 64 bytes.
    char buffer[64];
    rtc::SimpleStringBuilder sb(buffer);
    sb << packet_size << "," << transfer_size;

    return sb.str();
  }

  // Parses a "<packet_size>,<transfer_size>" string.
  //
  // The input comes from the remote peer, so it is not trusted: if fewer than
  // two comma-separated fields are present, a zeroed SetupMessage is returned
  // instead of indexing past the end of the split result. A field that is not
  // a valid base-10 number leaves the corresponding member at 0, because
  // std::from_chars does not modify its output on failure.
  static SetupMessage FromString(absl::string_view sv) {
    SetupMessage result;
    auto parameters = rtc::split(sv, ',');
    if (parameters.size() < 2) {
      return result;
    }
    std::from_chars(parameters[0].data(),
                    parameters[0].data() + parameters[0].size(),
                    result.packet_size, 10);
    std::from_chars(parameters[1].data(),
                    parameters[1].data() + parameters[1].size(),
                    result.transfer_size, 10);
    return result;
  }
};

// Observer attached to the server side of the benchmark data channel.
//
// It waits for the client's text setup message, then repeatedly sends one
// binary buffer through DataChannelInterface::SendAsync() until the amount
// reported through OnBufferedAmountChange() has brought `remaining_data_`
// down to zero. It also exposes events the caller can block on (setup message
// received, channel closed).
//
// NOTE(review): `remaining_data_`, `total_queued_up_` and `buffer_` are plain
// members touched from the observer callbacks, from the SendAsync completion
// callbacks and from tasks posted to `signaling_thread_`; no synchronization
// is visible in this class. Confirm the threading guarantees of the data
// channel before changing the order of any statement here.
class DataChannelServerObserverImpl : public webrtc::DataChannelObserver {
 public:
  // `dc` and `signaling_thread` are stored as raw pointers; this class never
  // deletes them, so both must outlive the observer.
  explicit DataChannelServerObserverImpl(webrtc::DataChannelInterface* dc,
                                         rtc::Thread* signaling_thread)
      : dc_(dc), signaling_thread_(signaling_thread) {}

  // Logs every state change and signals `closed_event_` once the channel
  // reports kClosed. All other states are ignored.
  void OnStateChange() override {
    RTC_LOG(LS_INFO) << "Server state changed to " << dc_->state();
    switch (dc_->state()) {
      case webrtc::DataChannelInterface::DataState::kOpen:
        break;
      case webrtc::DataChannelInterface::DataState::kClosed:
        closed_event_.Set();
        break;
      default:
        break;
    }
  }

  // Handles incoming messages. Only non-binary messages are processed: they
  // are parsed as a SetupMessage, which (re)initializes `remaining_data_` and
  // wakes up WaitForSetupMessage(). Binary messages are dropped.
  void OnMessage(const webrtc::DataBuffer& buffer) override {
    if (!buffer.binary) {
      std::string setup_message(buffer.data.cdata<char>(), buffer.data.size());
      setup_ = SetupMessage::FromString(setup_message);
      remaining_data_ = setup_.transfer_size;
      setup_message_event_.Set();
    }
  }

  // Accounts `sent_data_size` against the amount still to transfer and, when
  // a buffer was parked by OnSendAsyncComplete() after a failed send, retries
  // it once the channel's buffered amount is at or below the resume
  // threshold.
  //
  // NOTE(review): `remaining_data_` is unsigned, so subtracting more than it
  // holds would wrap around; presumably the reported sizes add up to exactly
  // the transfer size — confirm.
  void OnBufferedAmountChange(uint64_t sent_data_size) override {
    remaining_data_ -= sent_data_size;
    // Allow the transport buffer to be drained before starting again.
    if (buffer_ && dc_->buffered_amount() <= ok_to_resume_sending_threshold_) {
      total_queued_up_ += buffer_->size();
      // The lambda captures the pointer value before `buffer_` is cleared
      // below, so the completion callback still receives the parked buffer.
      dc_->SendAsync(*buffer_, [this, buffer = buffer_](webrtc::RTCError err) {
        OnSendAsyncComplete(err, buffer);
      });
      buffer_ = nullptr;
    }
  }

  // NOTE(review): presumably opts in to receiving the callbacks above on the
  // network thread — confirm against the DataChannelObserver documentation.
  bool IsOkToCallOnTheNetworkThread() override { return true; }

  // Blocks (without timeout) until OnStateChange() has observed kClosed.
  bool WaitForClosedState() { return closed_event_.Wait(rtc::Event::kForever); }

  // Blocks (without timeout) until OnMessage() has received a setup message.
  bool WaitForSetupMessage() {
    return setup_message_event_.Wait(rtc::Event::kForever);
  }

  // Kicks off the transfer: allocates a buffer of
  // min(packet_size, remaining_data_) '0' characters and hands it to
  // SendAsync(). Must be called after a setup message with a non-zero
  // transfer size was received, otherwise the RTC_CHECK fails.
  void StartSending() {
    RTC_CHECK(remaining_data_) << "Error: no data to send";
    std::string data(std::min(setup_.packet_size, remaining_data_), '0');
    // Heap-allocated on purpose: the same buffer is passed from completion
    // callback to completion callback and deleted in OnSendAsyncComplete()
    // once the transfer is finished.
    // NOTE(review): the `true` argument presumably marks the buffer as binary
    // — confirm against the webrtc::DataBuffer constructor.
    webrtc::DataBuffer* data_buffer =
        new webrtc::DataBuffer(rtc::CopyOnWriteBuffer(data), true);
    total_queued_up_ = data_buffer->size();
    dc_->SendAsync(*data_buffer,
                   [this, data_buffer = data_buffer](webrtc::RTCError err) {
                     OnSendAsyncComplete(err, data_buffer);
                   });
  }

  // Parameters decoded from the last setup message received.
  const struct SetupMessage& parameters() const { return setup_; }

 private:
  // Completion callback of every SendAsync() call issued by this class.
  //
  // On failure the only accepted error is RESOURCE_EXHAUSTED: the buffer is
  // parked in `buffer_` and resent from OnBufferedAmountChange(). On success
  // a task is posted to the signaling thread that prints the progress and
  // either deletes the buffer (nothing left to send) or sends it again,
  // shrunk to the remaining byte count if needed.
  //
  // NOTE(review): nothing visible in this class deletes a buffer that is
  // still parked in `buffer_` or in flight when the observer goes away.
  void OnSendAsyncComplete(webrtc::RTCError error, webrtc::DataBuffer* buffer) {
    total_queued_up_ -= buffer->size();
    if (!error.ok()) {
      RTC_CHECK_EQ(error.type(), webrtc::RTCErrorType::RESOURCE_EXHAUSTED);
      RTC_CHECK(!buffer_);
      // Buffer saturated. Retry when OnBufferedAmountChange() detects we can.
      buffer_ = buffer;
      return;
    }
    // `remaining_data` is a snapshot taken when the task is posted, not when
    // it runs.
    signaling_thread_->PostTask([this, buffer = buffer,
                                 remaining_data = remaining_data_]() {
      fprintf(stderr, "Progress: %zu / %zu (%zu%%)\n",
              (setup_.transfer_size - remaining_data), setup_.transfer_size,
              (100 - remaining_data * 100 / setup_.transfer_size));

      if (!remaining_data) {
        RTC_CHECK(!total_queued_up_);
        // We're done.
        delete buffer;
        return;
      }

      // Never queue more than what is left to transfer.
      if (remaining_data < buffer->data.size()) {
        buffer->data.SetSize(remaining_data);
      }

      total_queued_up_ += buffer->size();
      dc_->SendAsync(*buffer, [this, buffer = buffer](webrtc::RTCError err) {
        OnSendAsyncComplete(err, buffer);
      });
    });
  }

  // Data channel being observed (not deleted by this class).
  webrtc::DataChannelInterface* const dc_;
  // Thread on which the "send next buffer" tasks are posted.
  rtc::Thread* const signaling_thread_;
  // Set when the channel state becomes kClosed.
  rtc::Event closed_event_;
  // Set when a setup message has been parsed.
  rtc::Event setup_message_event_;
  // Bytes of the requested transfer not yet reported through
  // OnBufferedAmountChange().
  size_t remaining_data_ = 0u;
  // Bytes handed to SendAsync() whose completion callback has not run yet.
  size_t total_queued_up_ = 0u;
  // Parameters received from the client.
  struct SetupMessage setup_;
  // Buffer whose last send failed and is waiting to be retried, if any.
  webrtc::DataBuffer* buffer_ = nullptr;
  // Sending resumes once the buffered amount is at or below half of the
  // channel's maximum send queue size.
  const uint64_t ok_to_resume_sending_threshold_ =
      webrtc::DataChannelInterface::MaxSendQueueSize() / 2;
};

// Observer attached to the client side of the benchmark data channel.
//
// It counts the bytes of every message received and exposes two events the
// caller can block on: the channel reaching the open state, and the byte
// count reaching the threshold given at construction.
class DataChannelClientObserverImpl : public webrtc::DataChannelObserver {
 public:
  // `dc` is kept as a raw pointer and is never deleted by this class.
  // `bytes_received_threshold` is the byte count that triggers
  // WaitForBytesReceivedThreshold().
  explicit DataChannelClientObserverImpl(webrtc::DataChannelInterface* dc,
                                         uint64_t bytes_received_threshold)
      : dc_(dc), bytes_received_threshold_(bytes_received_threshold) {}

  // Logs the new state and signals the open event when the channel is open.
  void OnStateChange() override {
    RTC_LOG(LS_INFO) << "Client state changed to " << dc_->state();
    if (dc_->state() == webrtc::DataChannelInterface::DataState::kOpen) {
      open_event_.Set();
    }
  }

  // Adds the size of the message to the running total and signals the
  // threshold event once the total is at or above the threshold.
  void OnMessage(const webrtc::DataBuffer& buffer) override {
    bytes_received_ += buffer.data.size();
    if (bytes_received_ < bytes_received_threshold_) {
      return;
    }
    bytes_received_event_.Set();
  }

  // The client never sends bulk data, so there is nothing to do here.
  void OnBufferedAmountChange(uint64_t sent_data_size) override {}

  bool IsOkToCallOnTheNetworkThread() override { return true; }

  // Blocks (without timeout) until the channel has been seen in the open
  // state.
  bool WaitForOpenState() { return open_event_.Wait(rtc::Event::kForever); }

  // Blocks (without timeout) until the received byte count reaches the
  // desired value.
  bool WaitForBytesReceivedThreshold() {
    return bytes_received_event_.Wait(rtc::Event::kForever);
  }

 private:
  // Data channel being observed.
  webrtc::DataChannelInterface* const dc_;
  // Signaled by OnStateChange() on kOpen.
  rtc::Event open_event_;
  // Signaled by OnMessage() once enough bytes were received.
  rtc::Event bytes_received_event_;
  // Number of bytes to receive before signaling `bytes_received_event_`.
  const uint64_t bytes_received_threshold_;
  // Running total of received bytes.
  uint64_t bytes_received_ = 0u;
};

int RunServer() {
  bool oneshot = absl::GetFlag(FLAGS_oneshot);
  uint16_t port = absl::GetFlag(FLAGS_port);

  auto signaling_thread = rtc::Thread::Create();
  signaling_thread->Start();
  {
    auto factory = webrtc::PeerConnectionClient::CreateDefaultFactory(
        signaling_thread.get());

    auto grpc_server = webrtc::GrpcSignalingServerInterface::Create(
        [factory = rtc::scoped_refptr<webrtc::PeerConnectionFactoryInterface>(
             factory),
         signaling_thread =
             signaling_thread.get()](webrtc::SignalingInterface* signaling) {
          webrtc::PeerConnectionClient client(factory.get(), signaling);
          client.StartPeerConnection();
          auto peer_connection = client.peerConnection();

          // Set up the data channel
          auto dc_or_error =
              peer_connection->CreateDataChannelOrError("benchmark", nullptr);
          RTC_CHECK(dc_or_error.ok());
          auto data_channel = dc_or_error.MoveValue();
          auto data_channel_observer =
              std::make_unique<DataChannelServerObserverImpl>(
                  data_channel.get(), signaling_thread);
          data_channel->RegisterObserver(data_channel_observer.get());
          absl::Cleanup unregister_observer(
              [data_channel] { data_channel->UnregisterObserver(); });

          // Wait for a first message from the remote peer.
          // It configures how much data should be sent and how big the packets
          // should be.
          // First message is "packet_size,transfer_size".
          data_channel_observer->WaitForSetupMessage();

          // Wait for the sender and receiver peers to stabilize (send all ACKs)
          // This makes it easier to isolate the sending part when profiling.
          absl::SleepFor(absl::Seconds(1));

          auto begin_time = webrtc::Clock::GetRealTimeClock()->CurrentTime();

          data_channel_observer->StartSending();

          // Receiver signals the data channel close event when it has received
          // all the data it requested.
          data_channel_observer->WaitForClosedState();

          auto end_time = webrtc::Clock::GetRealTimeClock()->CurrentTime();
          auto duration_ms = (end_time - begin_time).ms<size_t>();
          double throughput =
              (data_channel_observer->parameters().transfer_size / 1024. /
               1024.) /
              (duration_ms / 1000.);
          printf("Elapsed time: %zums %gMiB/s\n", duration_ms, throughput);
        },
        port, oneshot);
    grpc_server->Start();

    printf("Server listening on port %d\n", grpc_server->SelectedPort());
    grpc_server->Wait();
  }

  signaling_thread->Stop();
  return 0;
}

int RunClient() {
  uint16_t port = absl::GetFlag(FLAGS_port);
  std::string server_address = absl::GetFlag(FLAGS_address);
  size_t transfer_size = absl::GetFlag(FLAGS_transfer_size) * 1024 * 1024;
  size_t packet_size = absl::GetFlag(FLAGS_packet_size);

  auto signaling_thread = rtc::Thread::Create();
  signaling_thread->Start();
  {
    auto factory = webrtc::PeerConnectionClient::CreateDefaultFactory(
        signaling_thread.get());
    auto grpc_client = webrtc::GrpcSignalingClientInterface::Create(
        server_address + ":" + std::to_string(port));
    webrtc::PeerConnectionClient client(factory.get(),
                                        grpc_client->signaling_client());

    std::unique_ptr<DataChannelClientObserverImpl> observer;

    // Set up the callback to receive the data channel from the sender.
    rtc::scoped_refptr<webrtc::DataChannelInterface> data_channel;
    rtc::Event got_data_channel;
    client.SetOnDataChannel(
        [&](rtc::scoped_refptr<webrtc::DataChannelInterface> channel) {
          data_channel = std::move(channel);
          // DataChannel needs an observer to drain the read queue.
          observer = std::make_unique<DataChannelClientObserverImpl>(
              data_channel.get(), transfer_size);
          data_channel->RegisterObserver(observer.get());
          got_data_channel.Set();
        });

    // Connect to the server.
    if (!grpc_client->Start()) {
      fprintf(stderr, "Failed to connect to server\n");
      return 1;
    }

    // Wait for the data channel to be received
    got_data_channel.Wait(rtc::Event::kForever);

    absl::Cleanup unregister_observer(
        [data_channel] { data_channel->UnregisterObserver(); });

    // Send a configuration string to the server to tell it to send
    // 'packet_size' bytes packets and send a total of 'transfer_size' MB.
    observer->WaitForOpenState();
    SetupMessage setup_message = {
        .packet_size = packet_size,
        .transfer_size = transfer_size,
    };
    if (!data_channel->Send(webrtc::DataBuffer(setup_message.ToString()))) {
      fprintf(stderr, "Failed to send parameter string\n");
      return 1;
    }

    // Wait until we have received all the data
    observer->WaitForBytesReceivedThreshold();

    // Close the data channel, signaling to the server we have received
    // all the requested data.
    data_channel->Close();
  }

  signaling_thread->Stop();

  return 0;
}

int main(int argc, char** argv) {
  rtc::InitializeSSL();
  absl::ParseCommandLine(argc, argv);

  // Make sure that higher severity number means more logs by reversing the
  // rtc::LoggingSeverity values.
  auto logging_severity =
      std::max(0, rtc::LS_NONE - absl::GetFlag(FLAGS_verbose));
  rtc::LogMessage::LogToDebug(
      static_cast<rtc::LoggingSeverity>(logging_severity));

  bool is_server = absl::GetFlag(FLAGS_server);
  std::string field_trials = absl::GetFlag(FLAGS_force_fieldtrials);

  webrtc::field_trial::InitFieldTrialsFromString(field_trials.c_str());

  return is_server ? RunServer() : RunClient();
}