summaryrefslogtreecommitdiffstats
path: root/third_party/libwebrtc/rtc_base/numerics/running_statistics.h
blob: bbcc7e2a73ba00017a05e4f788bc217cb9c70e88 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
/*
 *  Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#ifndef API_NUMERICS_RUNNING_STATISTICS_H_
#define API_NUMERICS_RUNNING_STATISTICS_H_

#include <algorithm>
#include <cmath>
#include <limits>

#include "absl/types/optional.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/math_utils.h"

namespace webrtc {
namespace webrtc_impl {

// tl;dr: Robust and efficient online computation of statistics,
//        using Welford's method for variance. [1]
//
// This should be your go-to class if you ever need to compute
// min, max, mean, variance and standard deviation.
// If you need to get percentiles, please use webrtc::SamplesStatsCounter.
//
// Please note RemoveSample() won't affect min and max.
// If you want a full-fledged moving window over N last samples,
// please use webrtc::RollingAccumulator.
//
// The measures return absl::nullopt if no samples were fed (Size() == 0),
// otherwise the returned optional is guaranteed to contain a value.
//
// [1]
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

// The type T is a scalar which must be convertible to double.
// Rationale: we often need greater precision for measures
//            than for the samples themselves.
template <typename T>
class RunningStatistics {
 public:
  // Update stats ////////////////////////////////////////////

  // Adds a value participating in the statistics in O(1) time.
  // Updates min/max and applies one step of Welford's online algorithm to
  // the running mean and to the cumulated sum of squared deviations.
  void AddSample(T sample) {
    max_ = std::max(max_, sample);
    min_ = std::min(min_, sample);
    ++size_;
    // Welford's incremental update.
    const double delta = sample - mean_;
    mean_ += delta / size_;
    const double delta2 = sample - mean_;
    cumul_ += delta * delta2;
  }

  // Removes a previously added value in O(1) time.
  // Nb: This doesn't affect min or max.
  // Calling RemoveSample when Size()==0 is incorrect: it triggers a DCHECK
  // in debug builds and is a no-op otherwise.
  void RemoveSample(T sample) {
    RTC_DCHECK_GT(Size(), 0);
    // In production, just saturate at 0.
    if (Size() == 0) {
      return;
    }
    --size_;
    if (size_ == 0) {
      // The last sample was just removed. The reverse Welford step below
      // would divide by zero (yielding NaN/inf) and poison every subsequent
      // AddSample(), so go back to the pristine empty state instead.
      mean_ = 0;
      cumul_ = 0;
      return;
    }
    // Since samples order doesn't matter, this is the
    // exact reciprocal of Welford's incremental update.
    const double delta = sample - mean_;
    mean_ -= delta / size_;
    const double delta2 = sample - mean_;
    cumul_ -= delta * delta2;
    // `cumul_` is a sum of squares, hence never negative mathematically.
    // Rounding in the subtraction above can leave a tiny negative residue,
    // which would turn GetStandardDeviation() into NaN. A plain comparison
    // is used (rather than std::max) so that a NaN is not silently masked.
    if (cumul_ < 0) {
      cumul_ = 0;
    }
  }

  // Merges other stats, as if samples were added one by one, but in O(1).
  // Merging an empty `other` is a no-op.
  void MergeStatistics(const RunningStatistics<T>& other) {
    if (other.size_ == 0) {
      return;
    }
    max_ = std::max(max_, other.max_);
    min_ = std::min(min_, other.min_);
    const int64_t new_size = size_ + other.size_;
    const double new_mean =
        (mean_ * size_ + other.mean_ * other.size_) / new_size;
    // Each cumulant must be corrected.
    //   * from: sum((x_i - mean_)²)
    //   * to:   sum((x_i - new_mean)²)
    // The correction is size * (new_mean - mean)². It is computed from the
    // difference of the means rather than from the expanded polynomial
    // (new_mean² - 2 * new_mean * mean + mean²), which suffers from
    // catastrophic cancellation when the means are large and close.
    auto delta = [new_mean](const RunningStatistics<T>& stats) {
      const double mean_shift = new_mean - stats.mean_;
      return stats.size_ * mean_shift * mean_shift;
    };
    cumul_ = cumul_ + delta(*this) + other.cumul_ + delta(other);
    mean_ = new_mean;
    size_ = new_size;
  }

  // Get Measures ////////////////////////////////////////////

  // Returns number of samples involved via AddSample() or MergeStatistics(),
  // minus number of times RemoveSample() was called.
  int64_t Size() const { return size_; }

  // Returns minimum among all seen samples, in O(1) time.
  // This isn't affected by RemoveSample().
  // Returns absl::nullopt when Size() == 0.
  absl::optional<T> GetMin() const {
    if (size_ == 0) {
      return absl::nullopt;
    }
    return min_;
  }

  // Returns maximum among all seen samples, in O(1) time.
  // This isn't affected by RemoveSample().
  // Returns absl::nullopt when Size() == 0.
  absl::optional<T> GetMax() const {
    if (size_ == 0) {
      return absl::nullopt;
    }
    return max_;
  }

  // Returns mean in O(1) time, or absl::nullopt when Size() == 0.
  absl::optional<double> GetMean() const {
    if (size_ == 0) {
      return absl::nullopt;
    }
    return mean_;
  }

  // Returns the population variance in O(1) time, or absl::nullopt when
  // Size() == 0.
  // Nb: the sum of squared deviations is divided by Size() (N), not by
  // N - 1, so this is NOT the Bessel-corrected (unbiased) sample variance.
  absl::optional<double> GetVariance() const {
    if (size_ == 0) {
      return absl::nullopt;
    }
    return cumul_ / size_;
  }

  // Returns the population standard deviation (square root of
  // GetVariance()) in O(1) time, or absl::nullopt when Size() == 0.
  absl::optional<double> GetStandardDeviation() const {
    if (size_ == 0) {
      return absl::nullopt;
    }
    return std::sqrt(*GetVariance());
  }

 private:
  int64_t size_ = 0;  // Samples currently accounted for (added - removed).
  // Extremes start at the opposite end of T's range so that the first
  // AddSample() always replaces them.
  T min_ = infinity_or_max<T>();
  T max_ = minus_infinity_or_min<T>();
  double mean_ = 0;
  double cumul_ = 0;  // Variance * size_, sometimes noted m2.
};

}  // namespace webrtc_impl
}  // namespace webrtc

#endif  // API_NUMERICS_RUNNING_STATISTICS_H_