summaryrefslogtreecommitdiffstats
path: root/third_party/rust/futures-macro/src/stream_select.rs
blob: 9927b5307348dc2302060481a2993078df984bc6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token};

/// Expands the `stream_select!` macro.
///
/// `input` is parsed as a comma-separated list of stream expressions (a
/// trailing comma is accepted). The expansion is a block expression that
/// evaluates to a locally defined `StreamSelect` value owning every stream.
/// Each time it is polled, `StreamSelect`:
///
/// * reorders its not-yet-finished streams with
///   `__futures_crate::async_await::shuffle` (presumably a random
///   permutation so no stream is systematically favoured — confirm),
/// * yields the first item that is ready in that order,
/// * forgets every stream that reported `Poll::Ready(None)` during this poll,
///   so a finished stream is never polled again, and
/// * completes with `Poll::Ready(None)` once every stream has finished.
///
/// All streams must be `Unpin` and share a single `Item` type; this is
/// enforced by the `where` clauses on the generated impls. The generated code
/// refers to the futures crate through the path `__futures_crate`, which must
/// resolve at the expansion site (presumably brought into scope by the
/// user-facing wrapper macro — verify against the caller).
///
/// # Errors
///
/// Returns the `syn` parse error if `input` is not a comma-separated list of
/// expressions. Passing fewer than two streams is *not* reported as `Err`;
/// the expansion is a `compile_error!` invocation instead.
pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> {
    let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?;
    if args.len() < 2 {
        return Ok(quote! {
           compile_error!("stream select macro needs at least two arguments.")
        });
    }
    // Per-stream identifiers. The distinct underscore prefixes keep the three
    // families from colliding with one another:
    // * `_i`:   generic type parameter of stream `i`, also the name of its
    //           `StreamEnum` variant;
    // * `__i`:  `&mut Option<_i>` binding to field `i` inside `poll_next`;
    // * `___i`: flag set when stream `i` returned `Ready(None)` during a poll.
    let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>();
    let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>();
    let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::<Vec<_>>();
    // Tuple-field indices (`self.0`, `self.1`, ...) used by `size_hint`.
    let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>();
    let args = args.iter().map(|e| e.to_token_stream());

    Ok(quote! {
        {
            // A field is `None` once its stream has finished; it is then
            // never polled again.
            #[derive(Debug)]
            struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*);

            // Unifies the different stream types into a single type so that
            // borrows of all live streams fit in one array that can be shuffled.
            // `None` stands in for a stream that has already finished.
            enum StreamEnum<#(#generic_idents),*> {
                #(
                    #generic_idents(#generic_idents)
                ),*,
                None,
            }

            impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*>
            where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
            {
                type Item = ITEM;

                fn poll_next(self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
                    match self.get_mut() {
                        #(
                            Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx)
                        ),*,
                        Self::None => panic!("StreamEnum::None should never be polled!"),
                    }
                }
            }

            impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*>
            where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)*
            {
                type Item = ITEM;

                fn poll_next(self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> {
                    // Default binding modes bind each `__i` as `&mut Option<_i>`.
                    let Self(#(#field_idents),*) = self.get_mut();
                    #(
                        let mut #field_idents_2 = false;
                    )*
                    let mut any_pending = false;
                    // Item from the first stream found ready, if any. It is
                    // stashed instead of returned right away so that the
                    // finished-stream cleanup below always runs; otherwise a
                    // stream that returned `Ready(None)` earlier in this poll
                    // would be polled again next time.
                    let mut ready_item: Option<ITEM> = None;
                    {
                        // `stream_array` holds mutable borrows of the fields, so
                        // fields can only be reset to `None` after this scope.
                        let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*];
                        __futures_crate::async_await::shuffle(&mut stream_array);

                        for mut s in stream_array {
                            if let StreamEnum::None = s {
                                continue;
                            }
                            match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) {
                                __futures_crate::task::Poll::Ready(Some(item)) => {
                                    ready_item = Some(item);
                                    break;
                                },
                                __futures_crate::task::Poll::Pending => {
                                    any_pending = true;
                                },
                                __futures_crate::task::Poll::Ready(None) => {
                                    match s {
                                        #(
                                            StreamEnum::#generic_idents(_) => { #field_idents_2 = true; }
                                        ),*,
                                        StreamEnum::None => panic!("StreamEnum::None should never be polled!"),
                                    }
                                },
                            }
                        }
                    }
                    // Forget every stream that finished during this poll, even
                    // when another stream produced an item.
                    #(
                        if #field_idents_2 {
                            *#field_idents = None;
                        }
                    )*
                    if let Some(item) = ready_item {
                        __futures_crate::task::Poll::Ready(Some(item))
                    } else if any_pending {
                        __futures_crate::task::Poll::Pending
                    } else {
                        __futures_crate::task::Poll::Ready(None)
                    }
                }

                fn size_hint(&self) -> (usize, Option<usize>) {
                    // Sum of the hints of the streams that have not finished.
                    // The lower bound saturates and the upper bound becomes
                    // unknown on overflow, mirroring `Iterator::chain`, so huge
                    // hints such as `(usize::MAX, None)` cannot overflow.
                    let mut lower: usize = 0;
                    let mut upper: Option<usize> = Some(0);
                    #(
                        if let Some(stream) = &self.#field_indices {
                            let (stream_lower, stream_upper) = stream.size_hint();
                            lower = lower.saturating_add(stream_lower);
                            upper = match (upper, stream_upper) {
                                (Some(a), Some(b)) => a.checked_add(b),
                                _ => None,
                            };
                        }
                    )*
                    (lower, upper)
                }
            }

            StreamSelect(#(Some(#args)),*)

        }
    })
}