summaryrefslogtreecommitdiffstats
path: root/third_party/rust/audioipc2/src/sys/unix/cmsg.rs
blob: 0451f119268d077d1ec15b89c9863613752b7929 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
// Copyright © 2017 Mozilla Foundation
//
// This program is made available under an ISC-style license.  See the
// accompanying file LICENSE for details

use crate::sys::HANDLE_QUEUE_LIMIT;
use bytes::{BufMut, BytesMut};
use libc::{self, cmsghdr};
use std::convert::TryInto;
use std::os::unix::io::RawFd;
use std::{mem, slice};

// Reinterpret a value as the raw native-endian bytes that back it.
trait AsBytes {
    // Borrow the bytes that make up `self`, in memory order.
    fn as_bytes(&self) -> &[u8];
}

impl<'a, T: Sized> AsBytes for &'a [T] {
    fn as_bytes(&self) -> &[u8] {
        // For a slice, `size_of_val` is `len * size_of::<T>()`; `size_of`
        // already rounds each element up to its alignment, so this is the
        // exact extent of the slice in memory.
        // NOTE(review): exposing padding bytes of a padded `T` as `u8` would be
        // unsound; this is only used with padding-free types such as `RawFd`.
        let total_bytes = mem::size_of_val(*self);
        let start = self.as_ptr() as *const u8;
        // SAFETY: `start..start + total_bytes` is exactly the memory owned by
        // the borrowed slice, and the returned lifetime is tied to that borrow.
        unsafe { slice::from_raw_parts(start, total_bytes) }
    }
}

// Encode `handles` as a single SOL_SOCKET/SCM_RIGHTS control message
// appended to `cmsg`.
//
// Exactly `space(handles.len() * size_of::<RawFd>())` bytes are appended:
// the cmsghdr, the descriptors, and any alignment padding (zero-filled).
// `cmsg` is grown if its spare capacity is too small.
//
// Panics if more than `HANDLE_QUEUE_LIMIT` handles are supplied.
pub fn encode_handles(cmsg: &mut BytesMut, handles: &[RawFd]) {
    assert!(handles.len() <= HANDLE_QUEUE_LIMIT);
    let msg = handles.as_bytes();

    let cmsg_space = space(msg.len());
    // `BytesMut::remaining_mut()` is `usize::MAX - len()`, so it says nothing
    // about the spare capacity the raw writes below rely on. Reserve
    // explicitly so `chunk_mut()` is guaranteed to cover `cmsg_space` bytes.
    cmsg.reserve(cmsg_space);

    // Some definitions of cmsghdr contain padding.  Rather
    // than try to keep an up-to-date #cfg list to handle
    // that, just use a pre-zeroed struct to fill out any
    // fields we don't care about.
    let zeroed = unsafe { mem::zeroed() };
    #[allow(clippy::needless_update)]
    // `cmsg_len` is `usize` on some platforms, `u32` on others.
    #[allow(clippy::useless_conversion)]
    let cmsghdr = cmsghdr {
        cmsg_len: len(msg.len()).try_into().unwrap(),
        cmsg_level: libc::SOL_SOCKET,
        cmsg_type: libc::SCM_RIGHTS,
        ..zeroed
    };

    let chunk = cmsg.chunk_mut();
    assert!(chunk.len() >= cmsg_space);
    let cmsghdr_ptr = chunk.as_mut_ptr();

    // SAFETY: `cmsghdr_ptr` points at no fewer than `cmsg_space` writable bytes
    // (checked above). By the CMSG_SPACE contract, both the header and the
    // `msg.len()`-byte payload at CMSG_DATA fit within `cmsg_space`. Every
    // write is a byte copy, so `cmsghdr_ptr` need not be aligned. The whole
    // region is zero-filled first so that every byte handed to `advance_mut`
    // (including alignment padding) is initialized, as that method requires.
    unsafe {
        std::ptr::write_bytes(cmsghdr_ptr, 0, cmsg_space);
        std::ptr::copy_nonoverlapping(
            &cmsghdr as *const _ as *const _,
            cmsghdr_ptr,
            mem::size_of::<cmsghdr>(),
        );
        let cmsg_data_ptr = libc::CMSG_DATA(cmsghdr_ptr as _);
        std::ptr::copy_nonoverlapping(msg.as_ptr(), cmsg_data_ptr, msg.len());
        cmsg.advance_mut(cmsg_space);
    }
}

// Decode `buf` containing a cmsghdr with one or more handle(s).
pub fn decode_handles(buf: &mut BytesMut) -> arrayvec::ArrayVec<RawFd, HANDLE_QUEUE_LIMIT> {
    let mut fds = arrayvec::ArrayVec::<RawFd, HANDLE_QUEUE_LIMIT>::new();

    let cmsghdr_len = len(0);

    if buf.len() < cmsghdr_len {
        // No more entries---not enough data in `buf` for a
        // complete message.
        return fds;
    }

    let cmsg: &cmsghdr = unsafe { &*(buf.as_ptr() as *const _) };
    let cmsg_len = cmsg.cmsg_len as usize;

    match (cmsg.cmsg_level, cmsg.cmsg_type) {
        (libc::SOL_SOCKET, libc::SCM_RIGHTS) => {
            trace!("Found SCM_RIGHTS...");
            let slice = &buf[cmsghdr_len..cmsg_len];
            let slice = unsafe {
                slice::from_raw_parts(
                    slice.as_ptr() as *const _,
                    slice.len() / mem::size_of::<i32>(),
                )
            };
            fds.try_extend_from_slice(slice).unwrap();
        }
        (level, kind) => {
            trace!("Skipping cmsg level, {}, type={}...", level, kind);
        }
    }

    assert!(fds.len() <= HANDLE_QUEUE_LIMIT);
    fds
}

// Total byte length (header included) of a control message whose payload is
// `len` bytes, as reported by the platform's `CMSG_LEN`. Panics if `len` does
// not fit the integer type `CMSG_LEN` accepts.
fn len(len: usize) -> usize {
    let payload = len.try_into().unwrap();
    // SAFETY: CMSG_LEN is a pure size computation; it dereferences nothing.
    let total = unsafe { libc::CMSG_LEN(payload) };
    total as usize
}

// Number of buffer bytes a control message with a `len`-byte payload
// occupies, including trailing alignment padding, as reported by the
// platform's `CMSG_SPACE`. Panics if `len` does not fit the integer type
// `CMSG_SPACE` accepts.
pub fn space(len: usize) -> usize {
    let payload = len.try_into().unwrap();
    // SAFETY: CMSG_SPACE is a pure size computation; it dereferences nothing.
    let reserved = unsafe { libc::CMSG_SPACE(payload) };
    reserved as usize
}