summaryrefslogtreecommitdiffstats
path: root/library/std/src/sync/mpsc/sync.rs
blob: 733761671a041e75bffa361e8cc906a26559b4fb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
use self::Blocker::*;
/// Synchronous channels/ports
///
/// This channel implementation differs significantly from the asynchronous
/// implementations found next to it (oneshot/stream/share). This is an
/// implementation of a synchronous, bounded buffer channel.
///
/// Each channel is created with some amount of backing buffer, and sends will
/// *block* until buffer space becomes available. A buffer size of 0 is valid,
/// which means that every successful send is paired with a successful recv.
///
/// This flavor of channels defines a new `send_opt` method for channels which
/// is the method by which a message is sent but the thread does not panic if it
/// cannot be delivered.
///
/// Another major difference is that `send()` will *always* hand the data back
/// to the caller if it couldn't be sent. This is possible because it is known
/// deterministically whether or not the data was received.
///
/// Implementation-wise, it can all be summed up with "use a mutex plus some
/// logic". The mutex used here is an OS native mutex, meaning that no user code
/// is run inside of the mutex (to prevent context switching). This
/// implementation shares almost all code for the buffered and unbuffered cases
/// of a synchronous channel. There are a few branches for the unbuffered case,
/// but they're mostly just relevant to blocking senders.
pub use self::Failure::*;

use core::intrinsics::abort;
use core::mem;
use core::ptr;

use crate::sync::atomic::{AtomicUsize, Ordering};
use crate::sync::mpsc::blocking::{self, SignalToken, WaitToken};
use crate::sync::{Mutex, MutexGuard};
use crate::time::Instant;

const MAX_REFCOUNT: usize = (isize::MAX) as usize;

/// Shared state of one synchronous (bounded) channel.
///
/// A count of live channel handles is kept in an atomic; everything else lives
/// in a `State` behind a mutex.
pub struct Packet<T> {
    /// Number of live channel handles: initialized to 1 by `new`, incremented
    /// by `clone_chan` and decremented by `drop_chan`.
    ///
    /// Only field outside of the mutex. Just done for kicks, but mainly because
    /// the other shared channel already had the code implemented.
    channels: AtomicUsize,

    /// All remaining channel state, guarded by a single lock.
    lock: Mutex<State<T>>,
}

// `T: Send` is required because values of type `T` are moved into the packet by
// `send`/`try_send` and moved out again by `recv`/`try_recv`.
// NOTE(review): soundness presumably relies on every access to `State` going
// through `lock` -- confirm when touching these impls.
unsafe impl<T: Send> Send for Packet<T> {}

unsafe impl<T: Send> Sync for Packet<T> {}

/// Everything about the channel that is protected by `Packet::lock`.
struct State<T> {
    disconnected: bool, // set by `drop_chan` (last channel handle gone) or `drop_port`
    queue: Queue,       // senders parked in `acquire_send_slot`, waiting for a free slot
    blocker: Blocker,   // the single thread (if any) parked via `wait`/`wait_timeout_receiver`
    buf: Buffer<T>,     // storage for buffered messages
    cap: usize,         // capacity passed to `Packet::new`; 0 means unbuffered

    /// Points at the `canceled` local of a sender that is parked inside `send`
    /// on a zero-capacity channel. `drop_port` writes `true` through it to tell
    /// that sender it must dequeue its own message from the buffer because it
    /// was never received; `wakeup_senders` simply clears the slot. In both
    /// cases the reference is taken back out before the sender is signalled.
    ///
    /// The `'static` lifetime is fabricated by the `transmute` in `send`, so
    /// this obviously cannot be safely constructed, but it is expected to hold
    /// a valid pointer whenever it is `Some`.
    canceled: Option<&'static mut bool>,
}

// `Queue` stores raw `*mut Node` pointers, which are not `Send`, so the compiler
// will not derive this impl on its own.
// NOTE(review): presumably sound because the state is only ever reached through
// the mutex in `Packet` -- confirm.
unsafe impl<T: Send> Send for State<T> {}

/// Possible flavors of threads who can be blocked on this channel.
///
/// This describes the single `State::blocker` slot only; senders waiting for
/// buffer space are tracked separately in `State::queue`.
enum Blocker {
    /// A sender parked in `send` on a zero-capacity channel after enqueueing
    /// its message; the token wakes it up.
    BlockedSender(SignalToken),
    /// The receiver parked in `recv` because the buffer was empty; the token
    /// wakes it up.
    BlockedReceiver(SignalToken),
    /// Nobody is parked in the blocker slot.
    NoneBlocked,
}

/// Simple singly linked FIFO used to chain waiting threads together. The queue
/// does not own its nodes: `acquire_send_slot` keeps each `Node` on the waiting
/// thread's stack, so this structure is not safe at all and relies on nodes
/// being dequeued before their owner's stack frame goes away.
struct Queue {
    head: *mut Node, // next node to be dequeued; null when the queue is empty
    tail: *mut Node, // most recently enqueued node; null when the queue is empty
}

/// One entry of `Queue`, representing a single waiting thread.
struct Node {
    /// Token used to wake the waiting thread. Filled in by `Queue::enqueue` and
    /// taken out again by `Queue::dequeue`.
    token: Option<SignalToken>,
    /// Node enqueued after this one, or null if this is the last node.
    next: *mut Node,
}

// The raw `next` pointer makes `Node` `!Send` by default, hence the manual impl.
// NOTE(review): presumably justified by nodes only being linked and unlinked
// while the channel lock is held -- confirm.
unsafe impl Send for Node {}

/// A simple ring-buffer
///
/// The number of slots is fixed by the length of `buf`; `enqueue` and `dequeue`
/// never resize it.
struct Buffer<T> {
    /// Slot storage; a slot is `None` while it holds no message.
    buf: Vec<Option<T>>,
    /// Index of the slot the next `dequeue` reads from.
    start: usize,
    /// Number of messages currently stored.
    size: usize,
}

/// Reasons a receive operation on this channel can fail.
///
/// The type is a plain field-less enum, so it derives the common comparison
/// and copy traits in addition to `Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Failure {
    /// No message is buffered right now (reported by `try_recv`, and by `recv`
    /// when it gave up waiting at its deadline).
    Empty,
    /// No message is buffered and the channel has been flagged disconnected.
    Disconnected,
}

/// Parks the current thread as the channel's blocker.
///
/// A fresh signal token, wrapped by `f`, is stored in `guard.blocker`; the lock
/// is then released and the thread blocks on the matching wait token. Once
/// woken, `lock` is re-acquired and the new guard returned.
///
/// The blocker slot must be `NoneBlocked` on entry.
fn wait<'a, 'b, T>(
    lock: &'a Mutex<State<T>>,
    mut guard: MutexGuard<'b, State<T>>,
    f: fn(SignalToken) -> Blocker,
) -> MutexGuard<'a, State<T>> {
    let (sleeper, waker) = blocking::tokens();

    // Register ourselves; finding anybody else already parked here is a bug.
    let previous = mem::replace(&mut guard.blocker, f(waker));
    if !matches!(previous, NoneBlocked) {
        unreachable!();
    }

    // Release the mutex *before* going to sleep so that others can wake us.
    mem::drop(guard);
    sleeper.wait();

    // Hand the caller a freshly acquired lock.
    lock.lock().unwrap()
}

/// Receiver-only variant of `wait` that gives up at `deadline`.
///
/// `*success` is set to the result of the timed wait. When the wait did not
/// succeed, the receiver's registration is withdrawn again via
/// `abort_selection` after the lock has been re-acquired. The re-acquired
/// guard is returned in either case.
fn wait_timeout_receiver<'a, 'b, T>(
    lock: &'a Mutex<State<T>>,
    deadline: Instant,
    mut guard: MutexGuard<'b, State<T>>,
    success: &mut bool,
) -> MutexGuard<'a, State<T>> {
    let (sleeper, waker) = blocking::tokens();

    // Register as the blocked receiver; the slot has to be free.
    let previous = mem::replace(&mut guard.blocker, BlockedReceiver(waker));
    if !matches!(previous, NoneBlocked) {
        unreachable!();
    }

    // Unlock, then sleep until signalled or until the deadline passes.
    mem::drop(guard);
    let signalled = sleeper.wait_max_until(deadline);
    *success = signalled;

    // Relock, cleaning up our registration if nobody signalled us.
    let mut relocked = lock.lock().unwrap();
    if !signalled {
        abort_selection(&mut relocked);
    }
    relocked
}

/// Withdraws a receiver's registration from the blocker slot.
///
/// Returns `false` if a `BlockedReceiver` token was found (it is removed and
/// dropped). Returns `true` otherwise, leaving the slot exactly as it was.
fn abort_selection<T>(guard: &mut MutexGuard<'_, State<T>>) -> bool {
    match mem::replace(&mut guard.blocker, NoneBlocked) {
        BlockedReceiver(stale) => {
            mem::drop(stale);
            false
        }
        // Empty slot or a blocked sender: put back whatever we took out.
        other => {
            guard.blocker = other;
            true
        }
    }
}

/// Wakes up a thread, dropping the lock at the correct time
fn wakeup<T>(token: SignalToken, guard: MutexGuard<'_, State<T>>) {
    // We need to be careful to wake up the waiting thread *outside* of the mutex
    // in case it incurs a context switch.
    drop(guard);
    token.signal();
}

impl<T> Packet<T> {
    /// Creates the packet for a channel with the given `capacity`.
    ///
    /// The handle count starts at 1. A zero-capacity channel still gets a
    /// single buffer slot, which `send` uses to hold the message while waiting.
    pub fn new(capacity: usize) -> Packet<T> {
        let slots = if capacity == 0 { 1 } else { capacity };
        let storage = Buffer { buf: (0..slots).map(|_| None).collect(), start: 0, size: 0 };

        let state = State {
            disconnected: false,
            queue: Queue { head: ptr::null_mut(), tail: ptr::null_mut() },
            blocker: NoneBlocked,
            buf: storage,
            cap: capacity,
            canceled: None,
        };

        Packet { channels: AtomicUsize::new(1), lock: Mutex::new(state) }
    }

    // wait until a send slot is available, returning locked access to
    // the channel state.
    fn acquire_send_slot(&self) -> MutexGuard<'_, State<T>> {
        let mut node = Node { token: None, next: ptr::null_mut() };
        loop {
            let mut guard = self.lock.lock().unwrap();
            // are we ready to go?
            if guard.disconnected || guard.buf.size() < guard.buf.capacity() {
                return guard;
            }
            // no room; actually block
            let wait_token = guard.queue.enqueue(&mut node);
            drop(guard);
            wait_token.wait();
        }
    }

    /// Sends `t`, blocking until it has been placed in the buffer and, on a
    /// zero-capacity channel, until the hand-off to a receiver is resolved.
    ///
    /// Returns `Err(t)`, handing the value back, if the channel is already
    /// flagged disconnected or (zero-capacity only) the port is dropped while
    /// this sender is waiting for its message to be picked up.
    pub fn send(&self, t: T) -> Result<(), T> {
        let mut guard = self.acquire_send_slot();
        if guard.disconnected {
            return Err(t);
        }
        guard.buf.enqueue(t);

        match mem::replace(&mut guard.blocker, NoneBlocked) {
            // if our capacity is 0, then we need to wait for a receiver to be
            // available to take our data. After waiting, we check again to make
            // sure the port didn't go away in the meantime. If it did, we need
            // to hand back our data.
            NoneBlocked if guard.cap == 0 => {
                let mut canceled = false;
                assert!(guard.canceled.is_none());
                // SAFETY: the `'static` lifetime is fabricated; the flag lives
                // in this stack frame. The reference is only reachable through
                // `guard.canceled`, and both places in this file that wake a
                // `BlockedSender` (`wakeup_senders` and `drop_port`) `take()`
                // it out under the lock before signalling. This relies on
                // `wait` not returning before that signal (presumably
                // guaranteed by `WaitToken::wait` -- confirm).
                guard.canceled = Some(unsafe { mem::transmute(&mut canceled) });
                let mut guard = wait(&self.lock, guard, BlockedSender);
                if canceled { Err(guard.buf.dequeue()) } else { Ok(()) }
            }

            // success, we buffered some data
            NoneBlocked => Ok(()),

            // success, someone's about to receive our buffered data.
            BlockedReceiver(token) => {
                wakeup(token, guard);
                Ok(())
            }

            // Same invariant and same handling as in `try_send`: a sender is
            // never parked in the blocker slot while the buffer has room.
            BlockedSender(..) => unreachable!(),
        }
    }

    /// Attempts to send `t` without blocking.
    ///
    /// Fails with `Disconnected(t)` if the channel is flagged disconnected and
    /// with `Full(t)` if the buffer has no free slot or, on a zero-capacity
    /// channel, if no receiver is currently parked waiting for a message.
    pub fn try_send(&self, t: T) -> Result<(), super::TrySendError<T>> {
        let mut guard = self.lock.lock().unwrap();

        if guard.disconnected {
            return Err(super::TrySendError::Disconnected(t));
        }
        if guard.buf.size() == guard.buf.capacity() {
            return Err(super::TrySendError::Full(t));
        }

        if guard.cap == 0 {
            // With capacity 0 the free slot alone is not enough: the data can
            // only be transferred if a receiver is already waiting for it.
            return match mem::replace(&mut guard.blocker, NoneBlocked) {
                BlockedReceiver(token) => {
                    guard.buf.enqueue(t);
                    wakeup(token, guard);
                    Ok(())
                }
                NoneBlocked => Err(super::TrySendError::Full(t)),
                BlockedSender(..) => unreachable!(),
            };
        }

        // Buffered channel with room to spare: store the message for later
        // retrieval and wake the receiver if it happens to be parked.
        assert!(guard.buf.size() < guard.buf.capacity());
        guard.buf.enqueue(t);
        match mem::replace(&mut guard.blocker, NoneBlocked) {
            NoneBlocked => {}
            BlockedReceiver(token) => wakeup(token, guard),
            BlockedSender(..) => unreachable!(),
        }
        Ok(())
    }

    /// Receives a message from this channel, blocking while the buffer is
    /// empty; with `Some(deadline)` the wait is bounded by that deadline.
    ///
    /// Returns `Err(Disconnected)` when the channel is flagged disconnected
    /// and nothing is buffered, and `Err(Empty)` when a bounded wait ended
    /// without a message.
    ///
    /// When reading this, remember that there can only ever be one receiver at
    /// a time.
    pub fn recv(&self, deadline: Option<Instant>) -> Result<T, Failure> {
        let mut guard = self.lock.lock().unwrap();
        let mut woke_up_after_waiting = false;

        // A single wait is enough (no retry loop): we are the only receiver.
        let must_block = !guard.disconnected && guard.buf.size() == 0;
        if must_block {
            guard = match deadline {
                Some(deadline) => {
                    wait_timeout_receiver(&self.lock, deadline, guard, &mut woke_up_after_waiting)
                }
                None => {
                    let relocked = wait(&self.lock, guard, BlockedReceiver);
                    woke_up_after_waiting = true;
                    relocked
                }
            };
        }

        if guard.buf.size() == 0 {
            // N.B., the channel could have been disconnected while we were
            // waiting, so disconnection has to be reported before emptiness.
            if guard.disconnected {
                return Err(Disconnected);
            }
            // An empty buffer is only legitimate after a timed-out wait.
            assert!(guard.buf.size() > 0 || (deadline.is_some() && !woke_up_after_waiting));
            return Err(Empty);
        }

        // Pick up the data, wake up our neighbors, and carry on.
        let message = guard.buf.dequeue();
        self.wakeup_senders(woke_up_after_waiting, guard);
        Ok(message)
    }

    /// Receives a message without blocking.
    ///
    /// With nothing buffered this returns `Err(Disconnected)` if the channel
    /// is flagged disconnected and `Err(Empty)` otherwise.
    pub fn try_recv(&self) -> Result<T, Failure> {
        let mut state = self.lock.lock().unwrap();

        // Easy case first: nothing to hand out.
        if state.buf.size() == 0 {
            return Err(if state.disconnected { Disconnected } else { Empty });
        }

        // Take the message and be sure to wake up neighbors; we never waited.
        let message = state.buf.dequeue();
        self.wakeup_senders(false, state);
        Ok(message)
    }

    /// Wakes up pending senders after some data has been received.
    ///
    /// * `waited` - whether the receiver blocked to receive the data (`true`)
    ///              or just picked it up on the way out (`false`)
    /// * `guard` - the lock guard held over this channel's lock; it is released
    ///             before any thread is signalled
    fn wakeup_senders(&self, waited: bool, mut guard: MutexGuard<'_, State<T>>) {
        // One sender waiting for buffer space may now proceed.
        let queued = guard.queue.dequeue();

        // On a no-buffer channel (cap == 0) the sender parked in the blocker
        // slot needs an ACK if we did not wait. If we waited, the sender waking
        // us up was already the ACK.
        let mut parked = None;
        if guard.cap == 0 && !waited {
            match mem::replace(&mut guard.blocker, NoneBlocked) {
                BlockedSender(token) => {
                    guard.canceled.take();
                    parked = Some(token);
                }
                NoneBlocked => {}
                BlockedReceiver(..) => unreachable!(),
            }
        }
        drop(guard);

        // Only outside of the lock do we wake up the pending threads.
        if let Some(token) = queued {
            token.signal();
        }
        if let Some(token) = parked {
            token.signal();
        }
    }

    /// Registers one more channel handle on this packet by bumping the handle
    /// count.
    pub fn clone_chan(&self) {
        let previous = self.channels.fetch_add(1, Ordering::SeqCst);

        // Abort rather than let the counter grow past `MAX_REFCOUNT`; see the
        // comments on Arc::clone() for why this is done (for `mem::forget`).
        if previous > MAX_REFCOUNT {
            abort();
        }
    }

    /// Unregisters one channel handle. When it was the last one, the channel
    /// is flagged disconnected and a parked receiver, if any, is woken up.
    pub fn drop_chan(&self) {
        // Only the handle that takes the count from 1 to 0 disconnects.
        let count_before = self.channels.fetch_sub(1, Ordering::SeqCst);
        if count_before != 1 {
            return;
        }

        let mut state = self.lock.lock().unwrap();
        if state.disconnected {
            return;
        }
        state.disconnected = true;

        // Not much else to do other than wake up a receiver if one's there.
        match mem::replace(&mut state.blocker, NoneBlocked) {
            BlockedReceiver(token) => wakeup(token, state),
            NoneBlocked => {}
            BlockedSender(..) => unreachable!(),
        }
    }

    /// Flags the channel as disconnected on behalf of the port and wakes every
    /// sender that is currently waiting on it.
    pub fn drop_port(&self) {
        let mut state = self.lock.lock().unwrap();
        if state.disconnected {
            return;
        }
        state.disconnected = true;

        // With capacity 0 the sender may want its data back after the
        // disconnect, so the buffer is left alone. Otherwise destroying the
        // buffered data is our job now. The binding is kept alive until the end
        // of the function so the data is destroyed *outside* of the lock, which
        // prevents deadlock.
        let _orphaned = match state.cap {
            0 => Vec::new(),
            _ => mem::take(&mut state.buf.buf),
        };

        // Detach all queued senders, leaving an empty queue behind.
        let empty_queue = Queue { head: ptr::null_mut(), tail: ptr::null_mut() };
        let mut senders = mem::replace(&mut state.queue, empty_queue);

        // A sender parked in the blocker slot is told that its send was
        // canceled before it gets woken up.
        let parked = match mem::replace(&mut state.blocker, NoneBlocked) {
            BlockedSender(token) => {
                *state.canceled.take().unwrap() = true;
                Some(token)
            }
            NoneBlocked => None,
            BlockedReceiver(..) => unreachable!(),
        };
        drop(state);

        // Signalling happens only after the lock has been released.
        while let Some(token) = senders.dequeue() {
            token.signal();
        }
        if let Some(token) = parked {
            token.signal();
        }
    }
}

impl<T> Drop for Packet<T> {
    /// Checks that the packet is being destroyed in a quiescent state; any
    /// violated check panics.
    fn drop(&mut self) {
        // Every channel handle must have been released through `drop_chan`.
        assert_eq!(self.channels.load(Ordering::SeqCst), 0);
        let mut guard = self.lock.lock().unwrap();
        // No sender may still be linked into the wait queue ...
        assert!(guard.queue.dequeue().is_none());
        // ... and no parked sender may still have its cancel flag registered.
        assert!(guard.canceled.is_none());
    }
}

////////////////////////////////////////////////////////////////////////////////
// Buffer, a simple ring buffer backed by Vec<Option<T>>
////////////////////////////////////////////////////////////////////////////////

impl<T> Buffer<T> {
    /// Stores `t` in the slot following the newest element.
    ///
    /// Panics if that slot is still occupied, i.e. if the buffer was full.
    fn enqueue(&mut self, t: T) {
        let pos = (self.start + self.size) % self.buf.len();
        self.size += 1;
        let prev = self.buf[pos].replace(t);
        assert!(prev.is_none());
    }

    /// Removes and returns the oldest element, advancing `start` by one slot
    /// (wrapping around at the end of the storage).
    ///
    /// Panics if the slot at `start` is empty.
    fn dequeue(&mut self) -> T {
        self.size -= 1;
        let oldest = self.start;
        self.start = (oldest + 1) % self.buf.len();
        mem::take(&mut self.buf[oldest]).unwrap()
    }

    /// Number of elements currently stored.
    fn size(&self) -> usize {
        self.size
    }

    /// Total number of slots.
    fn capacity(&self) -> usize {
        self.buf.len()
    }
}

////////////////////////////////////////////////////////////////////////////////
// Queue, a simple queue to enqueue threads with (stack-allocated nodes)
////////////////////////////////////////////////////////////////////////////////

impl Queue {
    /// Appends `node` to the back of the queue and returns the wait token whose
    /// matching signal token has been stored in the node.
    ///
    /// The queue keeps a raw pointer to `node`, so the caller has to keep the
    /// node alive until it has been dequeued again.
    fn enqueue(&mut self, node: &mut Node) -> WaitToken {
        let (wait_token, signal_token) = blocking::tokens();
        node.token = Some(signal_token);
        node.next = ptr::null_mut();

        let raw: *mut Node = node;
        if self.tail.is_null() {
            // Empty queue: the new node is the front as well as the back.
            self.head = raw;
        } else {
            // SAFETY: a non-null `tail` was stored by an earlier `enqueue`.
            // NOTE(review): its validity relies on the owning thread keeping
            // the node alive while it is linked -- confirm.
            unsafe {
                (*self.tail).next = raw;
            }
        }
        self.tail = raw;

        wait_token
    }

    /// Unlinks the node at the front of the queue and returns its signal token,
    /// or `None` if the queue is empty.
    fn dequeue(&mut self) -> Option<SignalToken> {
        let front = self.head;
        if front.is_null() {
            return None;
        }

        // SAFETY: `front` is non-null and was linked in by `enqueue`; see the
        // liveness note there.
        let rest = unsafe { (*front).next };
        self.head = rest;
        if rest.is_null() {
            // That was the last node, so the queue is empty again.
            self.tail = ptr::null_mut();
        }

        // SAFETY: as above; the node is fully detached before its token is
        // handed out.
        unsafe {
            (*front).next = ptr::null_mut();
            Some((*front).token.take().unwrap())
        }
    }
}