summaryrefslogtreecommitdiffstats
path: root/vendor/unarray/src/from_iter.rs
blob: e96275bb2d2fcc197c2f0cf7739803e451e3b290 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
use crate::{mark_initialized, uninit_buf};
use core::iter::FromIterator;

/// A wrapper type to collect an [`Iterator`] into an array
///
/// The wrapped [`Option`] is `Some(array)` when the iterator yields exactly `N` items (stored in
/// iteration order) and `None` otherwise.
///
/// # Examples
///
/// ```
/// # use unarray::*;
/// let iter = vec![1, 2, 3].into_iter();
/// let ArrayFromIter(array) = iter.collect();
///
/// assert_eq!(array, Some([1, 2, 3]));
/// ```
///
/// Since iterators don't carry compile-time information about their length (even
/// [`core::iter::ExactSizeIterator`] only provides this at runtime), collecting may fail if the
/// iterator doesn't yield **exactly** `N` elements:
///
/// ```
/// # use unarray::*;
/// let too_many = vec![1, 2, 3, 4].into_iter();
/// let ArrayFromIter::<i32, 3>(option) = too_many.collect();
/// assert!(option.is_none());
///
/// let too_few = vec![1, 2].into_iter();
/// let ArrayFromIter::<i32, 3>(option) = too_few.collect();
/// assert!(option.is_none());
/// ```
pub struct ArrayFromIter<T, const N: usize>(pub Option<[T; N]>);

impl<T, const N: usize> FromIterator<T> for ArrayFromIter<T, N> {
    /// Collects `iter` into a `[T; N]`.
    ///
    /// Returns `ArrayFromIter(Some(array))` only when the iterator yields exactly `N` items, and
    /// `ArrayFromIter(None)` when it yields fewer or more. At most `N + 1` items are pulled from
    /// the iterator.
    ///
    /// Items that were already moved into the internal buffer are dropped again whenever
    /// collection does not succeed. That includes the case where the iterator itself panics
    /// part-way through, so nothing is leaked during unwinding.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        /// Tracks the initialized prefix of the buffer while collection is in progress and drops
        /// that prefix on every exit path that does not explicitly disarm the guard.
        struct Filled<'a, U> {
            slots: &'a mut [core::mem::MaybeUninit<U>],
            len: usize,
        }

        impl<'a, U> Drop for Filled<'a, U> {
            fn drop(&mut self) {
                let prefix: *mut [core::mem::MaybeUninit<U>] = &mut self.slots[..self.len];
                // SAFETY:
                // Exactly the first `len` slots have been written and none of them has been moved
                // out again, and `MaybeUninit<U>` has the same layout as `U`, so the prefix is a
                // valid `[U]`. Dropping it as a slice also keeps dropping the remaining elements
                // if one element's destructor panics.
                unsafe { core::ptr::drop_in_place(prefix as *mut [U]) };
            }
        }

        let mut buffer = uninit_buf::<T, N>();
        let mut iter = iter.into_iter();
        let mut filled = Filled {
            slots: &mut buffer[..],
            len: 0,
        };

        while filled.len < N {
            match iter.next() {
                Some(item) => {
                    filled.slots[filled.len].write(item);
                    filled.len += 1;
                }
                // Too few items: returning drops `filled`, which releases everything written
                // so far.
                None => return Self(None),
            }
        }

        // The buffer is full, so the iterator must now be exhausted. A surplus item is an error:
        // release the buffered items first, then let the surplus item go out of scope.
        if let Some(_surplus) = iter.next() {
            drop(filled);
            return Self(None);
        }

        // Disarm the guard: ownership of the items moves into the returned array.
        core::mem::forget(filled);

        // SAFETY:
        // The loop above only finishes once all `N` slots have been written, and the guard that
        // could have dropped them has just been forgotten, so every slot holds a live value.
        Self(Some(unsafe { mark_initialized(buffer) }))
    }
}

#[cfg(test)]
mod tests {

    use core::{
        convert::TryInto,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::testing::vec_strategy;
    use proptest::{prop_assert, prop_assert_eq};
    use test_strategy::proptest;

    use super::*;

    /// Size passed to `vec_strategy` in the property tests (presumably the exact length of the
    /// generated vectors; confirm against `crate::testing`).
    const LEN: usize = 100;
    /// Array length one below `LEN`.
    const SHORT_LEN: usize = LEN - 1;
    /// Array length one above `LEN`.
    const LONG_LEN: usize = LEN + 1;

    /// Bumps the shared counter by one each time a value is dropped.
    #[derive(Clone)]
    struct IncrementOnDrop<'a>(&'a AtomicUsize);
    impl Drop for IncrementOnDrop<'_> {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Three items collect into a three-element array, in order.
    #[test]
    fn can_collect_array_from_iter() {
        let collected: ArrayFromIter<_, 3> = vec![1, 2, 3].into_iter().collect();
        assert_eq!(collected.0.unwrap(), [1, 2, 3]);
    }

    /// Collecting three items into an array of four or of two yields `None`.
    #[test]
    fn fails_if_incorrect_number_of_elements() {
        let source = [1, 2, 3];

        let into_four: ArrayFromIter<_, 4> = source.iter().collect();
        assert!(into_four.0.is_none());

        let into_two: ArrayFromIter<_, 2> = source.iter().collect();
        assert!(into_two.0.is_none());
    }

    /// Four items into an array of three: the collection fails and all four are dropped.
    #[test]
    fn doesnt_leak_too_long() {
        let drop_count = AtomicUsize::new(0);
        let items = vec![IncrementOnDrop(&drop_count); 4];
        let ArrayFromIter::<_, 3>(_) = items.into_iter().collect();
        assert_eq!(drop_count.load(Ordering::Relaxed), 4);
    }

    /// Two items into an array of three: the collection fails and both are dropped.
    #[test]
    fn doesnt_leak_too_short() {
        let drop_count = AtomicUsize::new(0);
        let items = vec![IncrementOnDrop(&drop_count); 2];
        let ArrayFromIter::<_, 3>(_) = items.into_iter().collect();
        assert_eq!(drop_count.load(Ordering::Relaxed), 2);
    }

    /// Collecting the generated vector into an array of `SHORT_LEN` must fail.
    #[proptest]
    #[cfg_attr(miri, ignore)]
    fn undersized_proptest(#[strategy(vec_strategy(LEN))] vec: Vec<String>) {
        let collected: ArrayFromIter<String, SHORT_LEN> = vec.into_iter().collect();
        prop_assert!(collected.0.is_none());
    }

    /// Collecting the generated vector into an array of `LONG_LEN` must fail.
    #[proptest]
    #[cfg_attr(miri, ignore)]
    fn oversized_proptest(#[strategy(vec_strategy(LEN))] vec: Vec<String>) {
        let collected: ArrayFromIter<String, LONG_LEN> = vec.into_iter().collect();
        prop_assert!(collected.0.is_none());
    }

    /// Collecting into an array of `LEN` matches the result of `Vec::try_into`.
    #[proptest]
    #[cfg_attr(miri, ignore)]
    fn just_right_proptest(#[strategy(vec_strategy(LEN))] vec: Vec<String>) {
        let expected: [String; LEN] = vec.clone().try_into().unwrap();
        let collected: ArrayFromIter<String, LEN> = vec.into_iter().collect();
        prop_assert_eq!(collected.0.unwrap(), expected);
    }
}