diff --git a/rust/src/system/socket.rs b/rust/src/system/socket.rs new file mode 100644 index 0000000..fd97e4b --- /dev/null +++ b/rust/src/system/socket.rs @@ -0,0 +1,294 @@ +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Used to send and receive messages with file descriptors on sockets that accept control messages +//! (e.g. Unix domain sockets). + +use std::fs::File; +use std::mem::size_of; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::net::{UnixDatagram, UnixStream}; +use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; + +use libc::{ + c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET, +}; + +use crate::system::errno::{Error,Result}; + +// Each of the following macros performs the same function as their C counterparts. They are each +// macros because they are used to size statically allocated arrays. + +macro_rules! CMSG_ALIGN { + ($len:expr) => { + (($len) + size_of::() - 1) & !(size_of::() - 1) + }; +} + +macro_rules! CMSG_SPACE { + ($len:expr) => { + size_of::() + CMSG_ALIGN!($len) + }; +} + +macro_rules! CMSG_LEN { + ($len:expr) => { + size_of::() + ($len) + }; +} + +// This function (macro in the C version) is not used in any compile time constant slots, so is just +// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this +// module supports. +#[allow(non_snake_case)] +#[inline(always)] +fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd { + // Essentially returns a pointer to just past the header. + cmsg_buffer.wrapping_offset(1) as *mut RawFd +} + +// This function is like CMSG_NEXT, but safer because it reads only from references, although it +// does some pointer arithmetic on cmsg_ptr. +#[allow(clippy::cast_ptr_alignment)] +fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr { + let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr; + if next_cmsg + .wrapping_offset(1) + .wrapping_sub(msghdr.msg_control as usize) as usize + > msghdr.msg_controllen + { + null_mut() + } else { + next_cmsg + } +} + +const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::() * 32); + +enum CmsgBuffer { + Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]), + Heap(Box<[cmsghdr]>), +} + +impl CmsgBuffer { + fn with_capacity(capacity: usize) -> CmsgBuffer { + let cap_in_cmsghdr_units = + (capacity.checked_add(size_of::()).unwrap() - 1) / size_of::(); + if capacity <= CMSG_BUFFER_INLINE_CAPACITY { + CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]) + } else { + CmsgBuffer::Heap( + vec![ + cmsghdr { + cmsg_len: 0, + cmsg_level: 0, + cmsg_type: 0, + }; + cap_in_cmsghdr_units + ] + .into_boxed_slice(), + ) + } + } + + fn as_mut_ptr(&mut self) -> *mut cmsghdr { + match self { + CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr, + CmsgBuffer::Heap(a) => a.as_mut_ptr(), + } + } +} + +fn raw_sendmsg(fd: RawFd, out_data: &[u8], out_fds: &[RawFd]) -> Result { + let cmsg_capacity = CMSG_SPACE!(size_of::() * out_fds.len()); + let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); + + let mut iovec = iovec { + iov_base: out_data.as_ptr() as *mut c_void, + iov_len: out_data.len(), + }; + + let mut msg = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: &mut iovec as *mut iovec, + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + if !out_fds.is_empty() { + let cmsg = cmsghdr { + cmsg_len: CMSG_LEN!(size_of::() * out_fds.len()), + cmsg_level: SOL_SOCKET, + cmsg_type: SCM_RIGHTS, + }; + unsafe { + // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr. + write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg); + // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len() + // file descriptors. + copy_nonoverlapping( + out_fds.as_ptr(), + CMSG_DATA(cmsg_buffer.as_mut_ptr()), + out_fds.len(), + ); + } + + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; + msg.msg_controllen = cmsg_capacity; + } + + // Safe because the msghdr was properly constructed from valid (or null) pointers of the + // indicated length and we check the return value. + let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) }; + + if write_count == -1 { + Err(Error::last_os_error()) + } else { + Ok(write_count as usize) + } +} + +fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> { + let cmsg_capacity = CMSG_SPACE!(size_of::() * in_fds.len()); + let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); + + let mut iovec = iovec { + iov_base: in_data.as_mut_ptr() as *mut c_void, + iov_len: in_data.len(), + }; + + let mut msg = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: &mut iovec as *mut iovec, + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + if !in_fds.is_empty() { + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; + msg.msg_controllen = cmsg_capacity; + } + + // Safe because the msghdr was properly constructed from valid (or null) pointers of the + // indicated length and we check the return value. + let total_read = unsafe { recvmsg(fd, &mut msg, 0) }; + + if total_read == -1 { + return Err(Error::last_os_error()); + } + + if total_read == 0 && msg.msg_controllen < size_of::() { + return Ok((0, 0)); + } + + let mut cmsg_ptr = msg.msg_control as *mut cmsghdr; + let mut in_fds_count = 0; + while !cmsg_ptr.is_null() { + // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that + // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read. + let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() }; + + if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS { + let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::(); + unsafe { + copy_nonoverlapping( + CMSG_DATA(cmsg_ptr), + in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(), + fd_count, + ); + } + in_fds_count += fd_count; + } + + cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr); + } + + Ok((total_read as usize, in_fds_count)) +} + +/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and +/// `recvmsg`. +pub trait ScmSocket { + /// Gets the file descriptor of this socket. + fn socket_fd(&self) -> RawFd; + + /// Sends the given data and file descriptor over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `buf` - A buffer of data to send on the `socket`. + /// * `fd` - A file descriptors to be sent. + fn send_with_fd(&self, buf: &[u8], fd: RawFd) -> Result { + self.send_with_fds(buf, &[fd]) + } + + /// Sends the given data and file descriptors over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `buf` - A buffer of data to send on the `socket`. + /// * `fds` - A list of file descriptors to be sent. + fn send_with_fds(&self, buf: &[u8], fd: &[RawFd]) -> Result { + raw_sendmsg(self.socket_fd(), buf, fd) + } + + /// Receives data and potentially a file descriptor from the socket. + /// + /// On success, returns the number of bytes and an optional file descriptor. + /// + /// # Arguments + /// + /// * `buf` - A buffer to receive data from the socket.vm + fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option)> { + let mut fd = [0]; + let (read_count, fd_count) = self.recv_with_fds(buf, &mut fd)?; + let file = if fd_count == 0 { + None + } else { + // Safe because the first fd from recv_with_fds is owned by us and valid because this + // branch was taken. + Some(unsafe { File::from_raw_fd(fd[0]) }) + }; + Ok((read_count, file)) + } + + /// Receives data and file descriptors from the socket. + /// + /// On success, returns the number of bytes and file descriptors received as a tuple + /// `(bytes count, files count)`. + /// + /// # Arguments + /// + /// * `buf` - A buffer to receive data from the socket. + /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the + /// number of valid file descriptors is indicated by the second element of the + /// returned tuple. The caller owns these file descriptors, but they will not be + /// closed on drop like a `File`-like type would be. It is recommended that each valid + /// file descriptor gets wrapped in a drop type that closes it after this returns. + fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> { + raw_recvmsg(self.socket_fd(), buf, fds) + } +} + +impl ScmSocket for UnixDatagram { + fn socket_fd(&self) -> RawFd { + self.as_raw_fd() + } +} + +impl ScmSocket for UnixStream { + fn socket_fd(&self) -> RawFd { + self.as_raw_fd() + } +}