use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; use syn::{parse::Parser, punctuated::Punctuated, Expr, Index, Token}; /// The `stream_select!` macro. pub(crate) fn stream_select(input: TokenStream) -> Result<TokenStream, syn::Error> { let args = Punctuated::<Expr, Token![,]>::parse_terminated.parse2(input)?; if args.len() < 2 { return Ok(quote! { compile_error!("stream select macro needs at least two arguments.") }); } let generic_idents = (0..args.len()).map(|i| format_ident!("_{}", i)).collect::<Vec<_>>(); let field_idents = (0..args.len()).map(|i| format_ident!("__{}", i)).collect::<Vec<_>>(); let field_idents_2 = (0..args.len()).map(|i| format_ident!("___{}", i)).collect::<Vec<_>>(); let field_indices = (0..args.len()).map(Index::from).collect::<Vec<_>>(); let args = args.iter().map(|e| e.to_token_stream()); Ok(quote! { { #[derive(Debug)] struct StreamSelect<#(#generic_idents),*> (#(Option<#generic_idents>),*); enum StreamEnum<#(#generic_idents),*> { #( #generic_idents(#generic_idents) ),*, None, } impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamEnum<#(#generic_idents),*> where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)* { type Item = ITEM; fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> { match self.get_mut() { #( Self::#generic_idents(#generic_idents) => ::std::pin::Pin::new(#generic_idents).poll_next(cx) ),*, Self::None => panic!("StreamEnum::None should never be polled!"), } } } impl<ITEM, #(#generic_idents),*> __futures_crate::stream::Stream for StreamSelect<#(#generic_idents),*> where #(#generic_idents: __futures_crate::stream::Stream<Item=ITEM> + ::std::marker::Unpin,)* { type Item = ITEM; fn poll_next(mut self: ::std::pin::Pin<&mut Self>, cx: &mut __futures_crate::task::Context<'_>) -> __futures_crate::task::Poll<Option<Self::Item>> { let Self(#(ref mut #field_idents),*) = self.get_mut(); #( let mut #field_idents_2 = false; )* let mut any_pending = false; { let mut stream_array = [#(#field_idents.as_mut().map(|f| StreamEnum::#generic_idents(f)).unwrap_or(StreamEnum::None)),*]; __futures_crate::async_await::shuffle(&mut stream_array); for mut s in stream_array { if let StreamEnum::None = s { continue; } else { match __futures_crate::stream::Stream::poll_next(::std::pin::Pin::new(&mut s), cx) { r @ __futures_crate::task::Poll::Ready(Some(_)) => { return r; }, __futures_crate::task::Poll::Pending => { any_pending = true; }, __futures_crate::task::Poll::Ready(None) => { match s { #( StreamEnum::#generic_idents(_) => { #field_idents_2 = true; } ),*, StreamEnum::None => panic!("StreamEnum::None should never be polled!"), } }, } } } } #( if #field_idents_2 { *#field_idents = None; } )* if any_pending { __futures_crate::task::Poll::Pending } else { __futures_crate::task::Poll::Ready(None) } } fn size_hint(&self) -> (usize, Option<usize>) { let mut s = (0, Some(0)); #( if let Some(new_hint) = self.#field_indices.as_ref().map(|s| s.size_hint()) { s.0 += new_hint.0; // We can change this out for `.zip` when the MSRV is 1.46.0 or higher. s.1 = s.1.and_then(|a| new_hint.1.map(|b| a + b)); } )* s } } StreamSelect(#(Some(#args)),*) } }) }