1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
|
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot};
use tokio_test::assert_ok;
use std::net::{IpAddr, SocketAddr};
/// Generates one `#[tokio::test]` per `(name, target)` pair.
///
/// Each generated test binds a `TcpListener` to `target` and accepts one
/// connection in a spawned task. It also connects a client to the bound
/// address and asserts that the client's local address equals the peer
/// address of the socket the listener accepted.
macro_rules! test_accept {
    ($(($ident:ident, $target:expr),)*) => {
        $(
            #[tokio::test]
            async fn $ident() {
                let mut server = assert_ok!(TcpListener::bind($target).await);
                let bound = server.local_addr().unwrap();

                // The accept side runs in its own task and hands the socket back.
                let (sock_tx, sock_rx) = oneshot::channel();
                tokio::spawn(async move {
                    let (accepted, _peer) = assert_ok!(server.accept().await);
                    assert_ok!(sock_tx.send(accepted));
                });

                let client = assert_ok!(TcpStream::connect(&bound).await);
                let accepted = assert_ok!(sock_rx.await);
                assert_eq!(client.local_addr().unwrap(), accepted.peer_addr().unwrap());
            }
        )*
    };
}
// One accept test per supported bind-target form: plain string, hostname
// string, `SocketAddr`, `(&str, port)` tuple and `(IpAddr, port)` tuple.
test_accept! {
    (ip_str, "127.0.0.1:0"),
    (host_str, "localhost:0"),
    (socket_addr, SocketAddr::from(([127, 0, 0, 1], 0))),
    (str_port_tuple, ("127.0.0.1", 0)),
    (ip_port_tuple, (IpAddr::from([127, 0, 0, 1]), 0)),
}
use std::pin::Pin;
use std::sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
Arc,
};
use std::task::{Context, Poll};
use tokio::stream::{Stream, StreamExt};
/// Stream wrapper that counts how many times `poll_next` is called on it
/// and forwards every poll to the wrapped stream `s`.
struct TrackPolls<S> {
    /// Shared counter, incremented once per `poll_next` call.
    npolls: Arc<AtomicUsize>,
    /// The wrapped stream whose items are passed through unchanged.
    s: S,
}

impl<S: Stream> Stream for TrackPolls<S> {
    type Item = <S as Stream>::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // SAFETY: we only take references to the fields; nothing is moved out
        // of the pinned `TrackPolls`.
        let TrackPolls { npolls, s } = unsafe { self.get_unchecked_mut() };
        npolls.fetch_add(1, SeqCst);
        // SAFETY: `s` is structurally pinned. `TrackPolls` has no `Drop` impl,
        // this impl never moves `s`, and `TrackPolls<S>` is auto-`Unpin` only
        // when `S: Unpin` (`Arc` is always `Unpin`).
        let inner: Pin<&mut S> = unsafe { Pin::new_unchecked(s) };
        inner.poll_next(cx)
    }
}
/// Checks that the listener's `incoming()` stream is not polled more often
/// than necessary. It should be polled once when the accept loop starts, then
/// exactly twice for the single connection made below.
#[tokio::test]
async fn no_extra_poll() {
    let mut server = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
    let server_addr = server.local_addr().unwrap();

    // Carries the shared poll counter from the accept task to the test body.
    let (counter_tx, counter_rx) = oneshot::channel();
    // Signals the test body each time the accept loop yields an item.
    let (notify_tx, mut notify_rx) = mpsc::unbounded_channel();

    tokio::spawn(async move {
        let counter = Arc::new(AtomicUsize::new(0));
        let mut incoming = TrackPolls {
            npolls: Arc::clone(&counter),
            s: server.incoming(),
        };
        assert_ok!(counter_tx.send(counter));
        while let Some(_accepted) = incoming.next().await {
            notify_tx.send(()).unwrap();
        }
    });

    let poll_count = assert_ok!(counter_rx.await);
    tokio::task::yield_now().await;
    // Only the initial poll (which returned `Pending`) should have happened.
    assert_eq!(poll_count.load(SeqCst), 1);

    // Connect and immediately drop the client; the accept still happens.
    drop(assert_ok!(TcpStream::connect(&server_addr).await));
    notify_rx.next().await.unwrap();
    // Two more polls: one that yielded `Some(..)`, then one that returned `Pending`.
    assert_eq!(poll_count.load(SeqCst), 1 + 2);
}
|