path: root/third_party/rust/smawk/tests/complexity.rs
#![cfg(feature = "ndarray")]

use ndarray::{Array1, Array2};
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use smawk::online_column_minima;

mod random_monge;
use random_monge::random_monge_matrix;

#[derive(Debug)]
struct LinRegression {
    alpha: f64,
    beta: f64,
    r_squared: f64,
}

/// Square an expression. Works equally well for floats and for ndarray
/// arrays (where the multiplication is elementwise).
macro_rules! squared {
    ($x:expr) => {
        $x * $x
    };
}

/// Compute the mean of a 1-dimensional array.
macro_rules! mean {
    ($a:expr) => {
        $a.mean().expect("Mean of empty array")
    };
}

/// Compute a simple linear regression from the list of values.
///
/// See <https://en.wikipedia.org/wiki/Simple_linear_regression>.
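///
/// Concretely, the code below computes the slope, intercept, and
/// coefficient of determination as
///
///     β  = Σ(xᵢ − x̄)(yᵢ − ȳ) / Σ(xᵢ − x̄)²
///     α  = ȳ − β·x̄
///     r² = (mean(x·y) − x̄·ȳ)² / ((mean(x²) − x̄²) · (mean(y²) − ȳ²))
///
/// where x̄ and ȳ denote the sample means.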
fn linear_regression(values: &[(usize, i32)]) -> LinRegression {
    let xs = values.iter().map(|&(x, _)| x as f64).collect::<Array1<_>>();
    let ys = values.iter().map(|&(_, y)| y as f64).collect::<Array1<_>>();

    let xs_mean = mean!(&xs);
    let ys_mean = mean!(&ys);
    let xs_ys_mean = mean!(&xs * &ys);

    let cov_xs_ys = ((&xs - xs_mean) * (&ys - ys_mean)).sum();
    let var_xs = squared!(&xs - xs_mean).sum();

    let beta = cov_xs_ys / var_xs;
    let alpha = ys_mean - beta * xs_mean;
    let r_squared = squared!(xs_ys_mean - xs_mean * ys_mean)
        / ((mean!(&xs * &xs) - squared!(xs_mean)) * (mean!(&ys * &ys) - squared!(ys_mean)));

    LinRegression {
        alpha,
        beta,
        r_squared,
    }
}
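
// A small self-check for the regression helper above (an illustrative
// sketch, not part of the upstream smawk test suite; the test name and
// data set are ours): points lying exactly on y = 3x + 2 should give
// back that slope and intercept with r² = 1.
#[test]
fn linear_regression_recovers_exact_line() {
    // Sample y = 3x + 2 at x = 0, 1, ..., 4.
    let data: Vec<(usize, i32)> = (0..5usize).map(|x| (x, 3 * x as i32 + 2)).collect();
    let fit = linear_regression(&data);
    assert!((fit.beta - 3.0).abs() < 1e-9, "slope should be 3, got {}", fit.beta);
    assert!((fit.alpha - 2.0).abs() < 1e-9, "intercept should be 2, got {}", fit.alpha);
    assert!((fit.r_squared - 1.0).abs() < 1e-9, "r² should be 1, got {}", fit.r_squared);
}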

/// Check that the number of matrix accesses in `online_column_minima`
/// grows as O(*n*) for an *n* ✕ *n* matrix.
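///
/// The test fits a straight line to the (size, access count) pairs and
/// requires the coefficient of determination r² of that fit to exceed
/// 0.95; a clearly super-linear access count would be expected to drag
/// r² below that threshold.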
#[test]
fn online_linear_complexity() {
    let mut rng = ChaCha20Rng::seed_from_u64(0);
    let mut data = vec![];

    for &size in &[1, 2, 3, 4, 5, 10, 15, 20, 30, 40, 50, 60, 70, 80, 90, 100] {
        let matrix: Array2<i32> = random_monge_matrix(size, size, &mut rng);
        let count = std::cell::RefCell::new(0);
        online_column_minima(0, size, |_, i, j| {
            *count.borrow_mut() += 1;
            matrix[[i, j]]
        });
        data.push((size, count.into_inner()));
    }

    let lin_reg = linear_regression(&data);
    assert!(
        lin_reg.r_squared > 0.95,
        "r² = {:.4} is lower than expected for a linear fit\nData points: {:?}\n{:?}",
        lin_reg.r_squared,
        data,
        lin_reg
    );
}