Support ancillary data such as SCM_RIGHTS for unix domain sockets.
Needed to allow transfer of file descriptors across sockets.
This commit is contained in:
parent
5d19d1e2f3
commit
245b48cf1e
294
rust/src/system/socket.rs
Normal file
294
rust/src/system/socket.rs
Normal file
@ -0,0 +1,294 @@
|
||||
// Copyright 2017 The Chromium OS Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
//! Used to send and receive messages with file descriptors on sockets that accept control messages
|
||||
//! (e.g. Unix domain sockets).
|
||||
|
||||
use std::fs::File;
|
||||
use std::mem::size_of;
|
||||
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
|
||||
use std::os::unix::net::{UnixDatagram, UnixStream};
|
||||
use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned};
|
||||
|
||||
use libc::{
|
||||
c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET,
|
||||
};
|
||||
|
||||
use crate::system::errno::{Error,Result};
|
||||
|
||||
// Each of the following macros performs the same function as their C counterparts. They are each
|
||||
// macros because they are used to size statically allocated arrays.
|
||||
|
||||
macro_rules! CMSG_ALIGN {
|
||||
($len:expr) => {
|
||||
(($len) + size_of::<c_long>() - 1) & !(size_of::<c_long>() - 1)
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! CMSG_SPACE {
|
||||
($len:expr) => {
|
||||
size_of::<cmsghdr>() + CMSG_ALIGN!($len)
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! CMSG_LEN {
|
||||
($len:expr) => {
|
||||
size_of::<cmsghdr>() + ($len)
|
||||
};
|
||||
}
|
||||
|
||||
// This function (macro in the C version) is not used in any compile time constant slots, so is just
|
||||
// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this
|
||||
// module supports.
|
||||
#[allow(non_snake_case)]
|
||||
#[inline(always)]
|
||||
fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd {
|
||||
// Essentially returns a pointer to just past the header.
|
||||
cmsg_buffer.wrapping_offset(1) as *mut RawFd
|
||||
}
|
||||
|
||||
// This function is like CMSG_NEXT, but safer because it reads only from references, although it
|
||||
// does some pointer arithmetic on cmsg_ptr.
|
||||
#[allow(clippy::cast_ptr_alignment)]
|
||||
fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
|
||||
let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr;
|
||||
if next_cmsg
|
||||
.wrapping_offset(1)
|
||||
.wrapping_sub(msghdr.msg_control as usize) as usize
|
||||
> msghdr.msg_controllen
|
||||
{
|
||||
null_mut()
|
||||
} else {
|
||||
next_cmsg
|
||||
}
|
||||
}
|
||||
|
||||
const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);
|
||||
|
||||
enum CmsgBuffer {
|
||||
Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
|
||||
Heap(Box<[cmsghdr]>),
|
||||
}
|
||||
|
||||
impl CmsgBuffer {
|
||||
fn with_capacity(capacity: usize) -> CmsgBuffer {
|
||||
let cap_in_cmsghdr_units =
|
||||
(capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
|
||||
if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
|
||||
CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
|
||||
} else {
|
||||
CmsgBuffer::Heap(
|
||||
vec![
|
||||
cmsghdr {
|
||||
cmsg_len: 0,
|
||||
cmsg_level: 0,
|
||||
cmsg_type: 0,
|
||||
};
|
||||
cap_in_cmsghdr_units
|
||||
]
|
||||
.into_boxed_slice(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn as_mut_ptr(&mut self) -> *mut cmsghdr {
|
||||
match self {
|
||||
CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
|
||||
CmsgBuffer::Heap(a) => a.as_mut_ptr(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn raw_sendmsg(fd: RawFd, out_data: &[u8], out_fds: &[RawFd]) -> Result<usize> {
|
||||
let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len());
|
||||
let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
|
||||
|
||||
let mut iovec = iovec {
|
||||
iov_base: out_data.as_ptr() as *mut c_void,
|
||||
iov_len: out_data.len(),
|
||||
};
|
||||
|
||||
let mut msg = msghdr {
|
||||
msg_name: null_mut(),
|
||||
msg_namelen: 0,
|
||||
msg_iov: &mut iovec as *mut iovec,
|
||||
msg_iovlen: 1,
|
||||
msg_control: null_mut(),
|
||||
msg_controllen: 0,
|
||||
msg_flags: 0,
|
||||
};
|
||||
|
||||
if !out_fds.is_empty() {
|
||||
let cmsg = cmsghdr {
|
||||
cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()),
|
||||
cmsg_level: SOL_SOCKET,
|
||||
cmsg_type: SCM_RIGHTS,
|
||||
};
|
||||
unsafe {
|
||||
// Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr.
|
||||
write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg);
|
||||
// Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len()
|
||||
// file descriptors.
|
||||
copy_nonoverlapping(
|
||||
out_fds.as_ptr(),
|
||||
CMSG_DATA(cmsg_buffer.as_mut_ptr()),
|
||||
out_fds.len(),
|
||||
);
|
||||
}
|
||||
|
||||
msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
|
||||
msg.msg_controllen = cmsg_capacity;
|
||||
}
|
||||
|
||||
// Safe because the msghdr was properly constructed from valid (or null) pointers of the
|
||||
// indicated length and we check the return value.
|
||||
let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
|
||||
|
||||
if write_count == -1 {
|
||||
Err(Error::last_os_error())
|
||||
} else {
|
||||
Ok(write_count as usize)
|
||||
}
|
||||
}
|
||||
|
||||
fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> {
|
||||
let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
|
||||
let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
|
||||
|
||||
let mut iovec = iovec {
|
||||
iov_base: in_data.as_mut_ptr() as *mut c_void,
|
||||
iov_len: in_data.len(),
|
||||
};
|
||||
|
||||
let mut msg = msghdr {
|
||||
msg_name: null_mut(),
|
||||
msg_namelen: 0,
|
||||
msg_iov: &mut iovec as *mut iovec,
|
||||
msg_iovlen: 1,
|
||||
msg_control: null_mut(),
|
||||
msg_controllen: 0,
|
||||
msg_flags: 0,
|
||||
};
|
||||
|
||||
if !in_fds.is_empty() {
|
||||
msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
|
||||
msg.msg_controllen = cmsg_capacity;
|
||||
}
|
||||
|
||||
// Safe because the msghdr was properly constructed from valid (or null) pointers of the
|
||||
// indicated length and we check the return value.
|
||||
let total_read = unsafe { recvmsg(fd, &mut msg, 0) };
|
||||
|
||||
if total_read == -1 {
|
||||
return Err(Error::last_os_error());
|
||||
}
|
||||
|
||||
if total_read == 0 && msg.msg_controllen < size_of::<cmsghdr>() {
|
||||
return Ok((0, 0));
|
||||
}
|
||||
|
||||
let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
|
||||
let mut in_fds_count = 0;
|
||||
while !cmsg_ptr.is_null() {
|
||||
// Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that
|
||||
// that only happens when there is at least sizeof(cmsghdr) space after the pointer to read.
|
||||
let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() };
|
||||
|
||||
if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
|
||||
let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::<RawFd>();
|
||||
unsafe {
|
||||
copy_nonoverlapping(
|
||||
CMSG_DATA(cmsg_ptr),
|
||||
in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(),
|
||||
fd_count,
|
||||
);
|
||||
}
|
||||
in_fds_count += fd_count;
|
||||
}
|
||||
|
||||
cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
|
||||
}
|
||||
|
||||
Ok((total_read as usize, in_fds_count))
|
||||
}
|
||||
|
||||
/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and
|
||||
/// `recvmsg`.
|
||||
pub trait ScmSocket {
|
||||
/// Gets the file descriptor of this socket.
|
||||
fn socket_fd(&self) -> RawFd;
|
||||
|
||||
/// Sends the given data and file descriptor over the socket.
|
||||
///
|
||||
/// On success, returns the number of bytes sent.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - A buffer of data to send on the `socket`.
|
||||
/// * `fd` - A file descriptors to be sent.
|
||||
fn send_with_fd(&self, buf: &[u8], fd: RawFd) -> Result<usize> {
|
||||
self.send_with_fds(buf, &[fd])
|
||||
}
|
||||
|
||||
/// Sends the given data and file descriptors over the socket.
|
||||
///
|
||||
/// On success, returns the number of bytes sent.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - A buffer of data to send on the `socket`.
|
||||
/// * `fds` - A list of file descriptors to be sent.
|
||||
fn send_with_fds(&self, buf: &[u8], fd: &[RawFd]) -> Result<usize> {
|
||||
raw_sendmsg(self.socket_fd(), buf, fd)
|
||||
}
|
||||
|
||||
/// Receives data and potentially a file descriptor from the socket.
|
||||
///
|
||||
/// On success, returns the number of bytes and an optional file descriptor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - A buffer to receive data from the socket.vm
|
||||
fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
|
||||
let mut fd = [0];
|
||||
let (read_count, fd_count) = self.recv_with_fds(buf, &mut fd)?;
|
||||
let file = if fd_count == 0 {
|
||||
None
|
||||
} else {
|
||||
// Safe because the first fd from recv_with_fds is owned by us and valid because this
|
||||
// branch was taken.
|
||||
Some(unsafe { File::from_raw_fd(fd[0]) })
|
||||
};
|
||||
Ok((read_count, file))
|
||||
}
|
||||
|
||||
/// Receives data and file descriptors from the socket.
|
||||
///
|
||||
/// On success, returns the number of bytes and file descriptors received as a tuple
|
||||
/// `(bytes count, files count)`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `buf` - A buffer to receive data from the socket.
|
||||
/// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the
|
||||
/// number of valid file descriptors is indicated by the second element of the
|
||||
/// returned tuple. The caller owns these file descriptors, but they will not be
|
||||
/// closed on drop like a `File`-like type would be. It is recommended that each valid
|
||||
/// file descriptor gets wrapped in a drop type that closes it after this returns.
|
||||
fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> {
|
||||
raw_recvmsg(self.socket_fd(), buf, fds)
|
||||
}
|
||||
}
|
||||
|
||||
impl ScmSocket for UnixDatagram {
|
||||
fn socket_fd(&self) -> RawFd {
|
||||
self.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl ScmSocket for UnixStream {
|
||||
fn socket_fd(&self) -> RawFd {
|
||||
self.as_raw_fd()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user