From 9bf3f84fae26089555bc8946431c93dbb81bb325 Mon Sep 17 00:00:00 2001 From: Bruce Leidl Date: Wed, 11 Sep 2019 16:20:16 -0400 Subject: [PATCH] virtio-wayland --- rust/src/devices/mod.rs | 8 +- rust/src/devices/virtio_wl/device.rs | 339 +++++++++++++++++++++++++++ rust/src/devices/virtio_wl/mod.rs | 124 ++++++++++ rust/src/devices/virtio_wl/pipe.rs | 89 +++++++ rust/src/devices/virtio_wl/shm.rs | 70 ++++++ rust/src/devices/virtio_wl/socket.rs | 102 ++++++++ rust/src/devices/virtio_wl/vfd.rs | 314 +++++++++++++++++++++++++ 7 files changed, 1043 insertions(+), 3 deletions(-) create mode 100644 rust/src/devices/virtio_wl/device.rs create mode 100644 rust/src/devices/virtio_wl/mod.rs create mode 100644 rust/src/devices/virtio_wl/pipe.rs create mode 100644 rust/src/devices/virtio_wl/shm.rs create mode 100644 rust/src/devices/virtio_wl/socket.rs create mode 100644 rust/src/devices/virtio_wl/vfd.rs diff --git a/rust/src/devices/mod.rs b/rust/src/devices/mod.rs index cb7c4d9..62979db 100644 --- a/rust/src/devices/mod.rs +++ b/rust/src/devices/mod.rs @@ -1,11 +1,13 @@ pub mod serial; pub mod rtc; -pub mod virtio_9p; -pub mod virtio_serial; -pub mod virtio_rng; +mod virtio_9p; +mod virtio_serial; +mod virtio_rng; +mod virtio_wl; mod virtio_block; pub use self::virtio_serial::VirtioSerial; pub use self::virtio_9p::VirtioP9; pub use self::virtio_rng::VirtioRandom; +pub use self::virtio_wl::VirtioWayland; pub use self::virtio_block::VirtioBlock; diff --git a/rust/src/devices/virtio_wl/device.rs b/rust/src/devices/virtio_wl/device.rs new file mode 100644 index 0000000..6b7cbc9 --- /dev/null +++ b/rust/src/devices/virtio_wl/device.rs @@ -0,0 +1,339 @@ +use std::os::unix::io::{AsRawFd,RawFd}; +use std::sync::{RwLock, Arc}; +use std::thread; + +use crate::{vm, system}; +use crate::system::EPoll; +use crate::memory::MemoryManager; +use crate::virtio::{VirtQueue, EventFd, Chain, VirtioBus, VirtioDeviceOps}; + +use crate::devices::virtio_wl::{vfd::VfdManager, consts::*, Error, Result, VfdObject}; + +pub struct VirtioWayland { + feature_bits: u64, +} + +impl VirtioWayland { + fn new() -> Self { + VirtioWayland { feature_bits: 0 } + } + + pub fn create(vbus: &mut VirtioBus) -> vm::Result<()> { + let dev = Arc::new(RwLock::new(VirtioWayland::new())); + vbus.new_virtio_device(VIRTIO_ID_WL, dev) + .set_num_queues(2) + .set_features(VIRTIO_WL_F_TRANS_FLAGS as u64) + .register() + } + + fn transition_flags(&self) -> bool { + self.feature_bits & VIRTIO_WL_F_TRANS_FLAGS as u64 != 0 + } + + fn create_device(memory: MemoryManager, in_vq: VirtQueue, out_vq: VirtQueue, transition: bool) -> Result { + let kill_evt = EventFd::new().map_err(Error::IoEventError)?; + let dev = WaylandDevice::new(memory, in_vq, out_vq, kill_evt, transition)?; + Ok(dev) + } +} + +impl VirtioDeviceOps for VirtioWayland { + fn enable_features(&mut self, bits: u64) -> bool { + self.feature_bits = bits; + true + } + + fn start(&mut self, memory: &MemoryManager, mut queues: Vec) { + thread::spawn({ + let memory = memory.clone(); + let transition = self.transition_flags(); + move || { + let out_vq = queues.pop().unwrap(); + let in_vq = queues.pop().unwrap(); + let mut dev = match Self::create_device(memory.clone(), in_vq, out_vq,transition) { + Err(e) => { + warn!("Error creating virtio wayland device: {}", e); + return; + } + Ok(dev) => dev, + }; + if let Err(e) = dev.run() { + warn!("Error running virtio-wl device: {}", e); + }; + } + }); + } +} + +struct WaylandDevice { + vfd_manager: VfdManager, + out_vq: VirtQueue, + kill_evt: EventFd, +} + +impl WaylandDevice { + const IN_VQ_TOKEN: u64 = 0; + const OUT_VQ_TOKEN:u64 = 1; + const KILL_TOKEN: u64 = 2; + const VFDS_TOKEN: u64 = 3; + + fn new(mm: MemoryManager, in_vq: VirtQueue, out_vq: VirtQueue, kill_evt: EventFd, use_transition: bool) -> Result { + let vfd_manager = VfdManager::new(mm, use_transition, in_vq, "/run/user/1000/wayland-0")?; + Ok(WaylandDevice { + vfd_manager, + out_vq, + kill_evt + }) + } + + pub fn get_vfd(&self, vfd_id: u32) -> Option<&dyn VfdObject> { + self.vfd_manager.get_vfd(vfd_id) + } + + pub fn get_mut_vfd(&mut self, vfd_id: u32) -> Option<&mut dyn VfdObject> { + self.vfd_manager.get_mut_vfd(vfd_id) + } + + fn setup_poll(&mut self) -> system::Result { + let poll = EPoll::new()?; + poll.add_read(self.vfd_manager.in_vq_poll_fd(), Self::IN_VQ_TOKEN as u64)?; + poll.add_read(self.out_vq.ioevent().as_raw_fd(), Self::OUT_VQ_TOKEN as u64)?; + poll.add_read(self.kill_evt.as_raw_fd(), Self::KILL_TOKEN as u64)?; + poll.add_read(self.vfd_manager.poll_fd(), Self::VFDS_TOKEN as u64)?; + Ok(poll) + } + fn run(&mut self) -> Result<()> { + let mut poll = self.setup_poll().map_err(Error::FailedPollContextCreate)?; + + 'poll: loop { + let events = match poll.wait() { + Ok(v) => v, + Err(e) => { + warn!("error waiting for poll events: {}", e); + break; + } + }; + for ev in events.iter() { + match ev.id() { + Self::IN_VQ_TOKEN => { + self.vfd_manager.in_vq_ready()?; + }, + Self::OUT_VQ_TOKEN => { + self.out_vq.ioevent().read().map_err(Error::IoEventError)?; + if let Some(chain) = self.out_vq.next_chain() { + let mut handler = MessageHandler::new(self, chain); + match handler.run() { + Ok(()) => { + }, + Err(err) => { + warn!("virtio_wl: error handling request: {}", err); + if !handler.responded { + let _ = handler.send_err(); + } + }, + } + handler.chain.flush_chain(); + } + }, + Self::KILL_TOKEN => break 'poll, + Self::VFDS_TOKEN => self.vfd_manager.process_poll_events(), + _ => warn!("unexpected poll token value"), + } + }; + } + Ok(()) + } +} + +struct MessageHandler<'a> { + device: &'a mut WaylandDevice, + chain: Chain, + responded: bool, +} + +impl <'a> MessageHandler<'a> { + + fn new(device: &'a mut WaylandDevice, chain: Chain) -> Self { + MessageHandler { device, chain, responded: false } + } + + fn run(&mut self) -> Result<()> { + let msg_type = self.chain.r32()?; + // Flags are always zero + let _flags = self.chain.r32()?; + match msg_type { + VIRTIO_WL_CMD_VFD_NEW => self.cmd_new_alloc(), + VIRTIO_WL_CMD_VFD_CLOSE => self.cmd_close(), + VIRTIO_WL_CMD_VFD_SEND => self.cmd_send(), + VIRTIO_WL_CMD_VFD_NEW_CTX => self.cmd_new_ctx(), + VIRTIO_WL_CMD_VFD_NEW_PIPE => self.cmd_new_pipe(), + v => { + self.send_invalid_command()?; + Err(Error::UnexpectedCommand(v)) + }, + } + } + + fn cmd_new_alloc(&mut self) -> Result<()> { + let id = self.chain.r32()?; + let flags = self.chain.r32()?; + let _pfn = self.chain.r64()?; + let size = self.chain.r32()?; + + match self.device.vfd_manager.create_shm(id, size) { + Ok((pfn,size)) => self.resp_vfd_new(id, flags, pfn, size as u32), + Err(Error::ShmAllocFailed(_)) => self.send_simple_resp(VIRTIO_WL_RESP_OUT_OF_MEMORY), + Err(e) => Err(e), + } + } + + fn resp_vfd_new(&mut self, id: u32, flags: u32, pfn: u64, size: u32) -> Result<()> { + self.chain.w32(VIRTIO_WL_RESP_VFD_NEW)?; + self.chain.w32(0)?; + self.chain.w32(id)?; + self.chain.w32(flags)?; + self.chain.w64(pfn)?; + self.chain.w32(size as u32)?; + self.responded = true; + Ok(()) + } + + fn cmd_close(&mut self) -> Result<()> { + let id = self.chain.r32()?; + self.device.vfd_manager.close_vfd(id)?; + self.send_ok() + } + + fn cmd_send(&mut self) -> Result<()> { + let id = self.chain.r32()?; + + let send_fds = self.read_vfd_ids()?; + let data = self.chain.current_read_slice(); + + let vfd = match self.device.get_mut_vfd(id) { + Some(vfd) => vfd, + None => return self.send_invalid_id(), + }; + + if let Some(fds) = send_fds.as_ref() { + vfd.send_with_fds(data, fds)?; + } else { + vfd.send(data)?; + } + self.send_ok() + } + + fn read_vfd_ids(&mut self) -> Result>> { + let vfd_count = self.chain.r32()? as usize; + if vfd_count > VIRTWL_SEND_MAX_ALLOCS { + return Err(Error::TooManySendVfds(vfd_count)) + } + if vfd_count == 0 { + return Ok(None); + } + + let mut raw_fds = Vec::with_capacity(vfd_count); + for _ in 0..vfd_count { + let vfd_id = self.chain.r32()?; + if let Some(fd) = self.vfd_id_to_raw_fd(vfd_id)? { + raw_fds.push(fd); + } + } + Ok(Some(raw_fds)) + } + + fn vfd_id_to_raw_fd(&mut self, vfd_id: u32) -> Result> { + let vfd = match self.device.get_vfd(vfd_id) { + Some(vfd) => vfd, + None => { + warn!("Received unexpected vfd id 0x{:08x}", vfd_id); + return Ok(None); + } + }; + + if let Some(fd) = vfd.send_fd() { + Ok(Some(fd)) + } else { + self.send_invalid_type()?; + Err(Error::InvalidSendVfd) + } + } + + fn cmd_new_ctx(&mut self) -> Result<()> { + let id = self.chain.r32()?; + if !Self::is_valid_id(id) { + return self.send_invalid_id(); + } + let flags = self.device.vfd_manager.create_socket(id)?; + self.resp_vfd_new(id, flags, 0, 0)?; + Ok(()) + } + + fn cmd_new_pipe(&mut self) -> Result<()> { + let id = self.chain.r32()?; + let flags = self.chain.r32()?; + + if !Self::is_valid_id(id) { + return self.send_invalid_id(); + } + if !Self::valid_new_pipe_flags(flags) { + notify!("invalid flags: 0x{:08}", flags); + return self.send_invalid_flags(); + } + + let is_write = Self::is_flag_set(flags, VIRTIO_WL_VFD_WRITE); + + self.device.vfd_manager.create_pipe(id, is_write)?; + + self.resp_vfd_new(id, 0, 0, 0) + } + + fn valid_new_pipe_flags(flags: u32) -> bool { + // only VFD_READ and VFD_WRITE may be set + if flags & !(VIRTIO_WL_VFD_WRITE|VIRTIO_WL_VFD_READ) != 0 { + return false; + } + let read = Self::is_flag_set(flags, VIRTIO_WL_VFD_READ); + let write = Self::is_flag_set(flags, VIRTIO_WL_VFD_WRITE); + // exactly one of them must be set + !(read && write) && (read || write) + } + + fn is_valid_id(id: u32) -> bool { + id & VFD_ID_HOST_MASK == 0 + } + + fn is_flag_set(flags: u32, bit: u32) -> bool { + flags & bit != 0 + } + + fn send_invalid_flags(&mut self) -> Result<()> { + self.send_simple_resp(VIRTIO_WL_RESP_INVALID_FLAGS) + } + + fn send_invalid_id(&mut self) -> Result<()> { + self.send_simple_resp(VIRTIO_WL_RESP_INVALID_ID) + } + + fn send_invalid_type(&mut self) -> Result<()> { + self.send_simple_resp(VIRTIO_WL_RESP_INVALID_TYPE) + } + + fn send_invalid_command(&mut self) -> Result<()> { + self.send_simple_resp(VIRTIO_WL_RESP_INVALID_CMD) + } + + fn send_ok(&mut self) -> Result<()> { + self.send_simple_resp(VIRTIO_WL_RESP_OK) + } + + fn send_err(&mut self) -> Result<()> { + self.send_simple_resp(VIRTIO_WL_RESP_ERR) + } + + fn send_simple_resp(&mut self, code: u32) -> Result<()> { + self.chain.w32(code)?; + self.responded = true; + Ok(()) + } +} diff --git a/rust/src/devices/virtio_wl/mod.rs b/rust/src/devices/virtio_wl/mod.rs new file mode 100644 index 0000000..ce49985 --- /dev/null +++ b/rust/src/devices/virtio_wl/mod.rs @@ -0,0 +1,124 @@ +use std::os::unix::io::RawFd; +use std::{result, io, fmt}; + +use crate::{vm, system}; +use crate::memory::Error as MemError; +use crate::system::FileDesc; + +mod vfd; +mod shm; +mod pipe; +mod socket; +mod device; + +mod consts { + use std::mem; + + pub const VIRTIO_ID_WL: u16 = 30; + pub const VIRTWL_SEND_MAX_ALLOCS: usize = 28; + pub const VIRTIO_WL_CMD_VFD_NEW: u32 = 256; + pub const VIRTIO_WL_CMD_VFD_CLOSE: u32 = 257; + pub const VIRTIO_WL_CMD_VFD_SEND: u32 = 258; + pub const VIRTIO_WL_CMD_VFD_RECV: u32 = 259; + pub const VIRTIO_WL_CMD_VFD_NEW_CTX: u32 = 260; + pub const VIRTIO_WL_CMD_VFD_NEW_PIPE: u32 = 261; + pub const VIRTIO_WL_CMD_VFD_HUP: u32 = 262; + pub const VIRTIO_WL_RESP_OK: u32 = 4096; + pub const VIRTIO_WL_RESP_VFD_NEW: u32 = 4097; + pub const VIRTIO_WL_RESP_ERR: u32 = 4352; + pub const VIRTIO_WL_RESP_OUT_OF_MEMORY: u32 = 4353; + pub const VIRTIO_WL_RESP_INVALID_ID: u32 = 4354; + pub const VIRTIO_WL_RESP_INVALID_TYPE: u32 = 4355; + pub const VIRTIO_WL_RESP_INVALID_FLAGS: u32 = 4356; + pub const VIRTIO_WL_RESP_INVALID_CMD: u32 = 4357; + + pub const VIRTIO_WL_VFD_WRITE: u32 = 0x1; // Intended to be written by guest + pub const VIRTIO_WL_VFD_READ: u32 = 0x2; // Intended to be read by guest + + pub const VIRTIO_WL_VFD_MAP: u32 = 0x2; + pub const VIRTIO_WL_VFD_CONTROL: u32 = 0x4; + pub const VIRTIO_WL_F_TRANS_FLAGS: u32 = 0x01; + + pub const NEXT_VFD_ID_BASE: u32 = 0x40000000; + pub const VFD_ID_HOST_MASK: u32 = NEXT_VFD_ID_BASE; + + pub const VFD_RECV_HDR_SIZE: usize = 16; + pub const IN_BUFFER_LEN: usize = + 0x1000 - VFD_RECV_HDR_SIZE - VIRTWL_SEND_MAX_ALLOCS * mem::size_of::(); +} + +pub use device::VirtioWayland; +pub type Result = result::Result; + +pub struct VfdRecv { + buf: Vec, + fds: Option>, +} + +impl VfdRecv { + fn new(buf: Vec) -> Self { + VfdRecv { buf, fds: None } + } + fn new_with_fds(buf: Vec, fds: Vec) -> Self { + VfdRecv { buf, fds: Some(fds) } + } +} + +pub trait VfdObject { + fn id(&self) -> u32; + fn send_fd(&self) -> Option { None } + fn poll_fd(&self) -> Option { None } + fn recv(&mut self) -> Result> { Ok(None) } + fn send(&mut self, _data: &[u8]) -> Result<()> { Err(Error::InvalidSendVfd) } + fn send_with_fds(&mut self, _data: &[u8], _fds: &[RawFd]) -> Result<()> { Err(Error::InvalidSendVfd) } + fn flags(&self) -> u32; + fn pfn_and_size(&self) -> Option<(u64, u64)> { None } + fn close(&mut self) -> Result<()>; +} + + +#[derive(Debug)] +pub enum Error { + IoEventError(vm::Error), + ChainIoError(io::Error), + UnexpectedCommand(u32), + ShmAllocFailed(system::Error), + RegisterMemoryFailed(MemError), + CreatePipesFailed(system::Error), + SocketReceive(system::Error), + SocketConnect(io::Error), + PipeReceive(io::Error), + SendVfd(io::Error), + InvalidSendVfd, + TooManySendVfds(usize), + FailedPollContextCreate(system::Error), + FailedPollAdd(system::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use Error::*; + match self { + IoEventError(e) => write!(f, "error reading from ioevent fd: {}", e), + ChainIoError(e) => write!(f, "i/o error on virtio chain operation: {}", e), + UnexpectedCommand(cmd) => write!(f, "unexpected virtio wayland command: {}", cmd), + ShmAllocFailed(e) => write!(f, "failed to allocate shared memory: {}", e), + RegisterMemoryFailed(e) => write!(f, "failed to register memory with hypervisor: {}", e), + CreatePipesFailed(e) => write!(f, "failed to create pipes: {}", e), + SocketReceive(e) => write!(f, "error reading from socket: {}", e), + SocketConnect(e) => write!(f, "error connecting to socket: {}", e), + PipeReceive(e) => write!(f, "error reading from pipe: {}", e), + SendVfd(e) => write!(f, "error writing to vfd: {}", e), + InvalidSendVfd => write!(f, "attempt to send to incorrect vfd type"), + TooManySendVfds(n) => write!(f, "message has too many vfd ids: {}", n), + FailedPollContextCreate(e) => write!(f, "failed creating poll context: {}", e), + FailedPollAdd(e) => write!(f, "failed adding fd to poll context: {}", e), + } + } +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Error::ChainIoError(e) + } +} diff --git a/rust/src/devices/virtio_wl/pipe.rs b/rust/src/devices/virtio_wl/pipe.rs new file mode 100644 index 0000000..cbb1777 --- /dev/null +++ b/rust/src/devices/virtio_wl/pipe.rs @@ -0,0 +1,89 @@ +use std::os::unix::io::{AsRawFd,RawFd}; + +use crate::system::{self,FileDesc}; + +use crate::devices::virtio_wl::{ + consts::{VIRTIO_WL_VFD_WRITE, VIRTIO_WL_VFD_READ, IN_BUFFER_LEN}, + Error, Result, VfdObject, VfdRecv, +}; + + +pub struct VfdPipe { + vfd_id: u32, + flags: u32, + local: Option, + remote: Option, +} + +impl VfdPipe { + + pub fn new(vfd_id: u32, read_pipe: FileDesc, write_pipe: FileDesc, local_write: bool) -> Self { + if local_write { + VfdPipe { vfd_id, local: Some(write_pipe), remote: Some(read_pipe), flags: VIRTIO_WL_VFD_WRITE } + } else { + VfdPipe { vfd_id, local: Some(read_pipe), remote: Some(write_pipe), flags: VIRTIO_WL_VFD_READ} + } + } + + pub fn local_only(vfd_id: u32, local_pipe: FileDesc, flags: u32) -> Self { + VfdPipe { vfd_id, local: Some(local_pipe), remote: None, flags } + } + + pub fn create(vfd_id: u32, local_write: bool) -> Result { + let mut pipe_fds: [libc::c_int; 2] = [-1; 2]; + unsafe { + if libc::pipe2(pipe_fds.as_mut_ptr(), libc::O_CLOEXEC) < 0 { + return Err(Error::CreatePipesFailed(system::Error::last_os_error())); + } + let read_pipe = FileDesc::new(pipe_fds[0]); + let write_pipe = FileDesc::new(pipe_fds[1]); + Ok(Self::new(vfd_id, read_pipe, write_pipe, local_write)) + } + } +} + +impl VfdObject for VfdPipe { + fn id(&self) -> u32 { + self.vfd_id + } + + fn send_fd(&self) -> Option { + self.remote.as_ref().map(|p| p.as_raw_fd()) + } + + fn poll_fd(&self) -> Option { + self.local.as_ref().map(|p| p.as_raw_fd()) + } + + fn recv(&mut self) -> Result> { + if let Some(pipe) = self.local.take() { + let mut buf = vec![0; IN_BUFFER_LEN]; + let len = pipe.read(&mut buf[..IN_BUFFER_LEN]) + .map_err(Error::PipeReceive)?; + buf.truncate(len); + if buf.len() > 0 { + self.local.replace(pipe); + return Ok(Some(VfdRecv::new(buf))); + } + } + Ok(None) + } + + fn send(&mut self, data: &[u8]) -> Result<()> { + if let Some(pipe) = self.local.as_ref() { + pipe.write_all(data).map_err(Error::SendVfd) + } else { + Err(Error::InvalidSendVfd) + } + } + + fn flags(&self) -> u32 { + self.flags + } + + fn close(&mut self) -> Result<()> { + self.local = None; + self.remote = None; + Ok(()) + } +} diff --git a/rust/src/devices/virtio_wl/shm.rs b/rust/src/devices/virtio_wl/shm.rs new file mode 100644 index 0000000..49a9164 --- /dev/null +++ b/rust/src/devices/virtio_wl/shm.rs @@ -0,0 +1,70 @@ +use std::os::unix::io::{AsRawFd,RawFd}; + +use crate::memory::MemoryManager; +use crate::system::MemoryFd; + +use crate::devices::virtio_wl::{ + consts::{VIRTIO_WL_VFD_MAP, VIRTIO_WL_VFD_WRITE}, + Error, Result, VfdObject +}; + +pub struct VfdSharedMemory { + vfd_id: u32, + flags: u32, + mm: MemoryManager, + memfd: Option, + slot: u32, + pfn: u64, +} + +impl VfdSharedMemory { + fn round_to_page_size(n: usize) -> usize { + let mask = 4096 - 1; + (n + mask) & !mask + } + + pub fn new(vfd_id: u32, transition_flags: bool, mm: MemoryManager, memfd: MemoryFd, slot: u32, pfn: u64) -> Self { + let flags = if transition_flags { 0 } else { VIRTIO_WL_VFD_WRITE | VIRTIO_WL_VFD_MAP}; + let memfd = Some(memfd); + VfdSharedMemory { vfd_id, flags, mm, memfd, slot, pfn } + } + + pub fn create(vfd_id: u32, transition_flags: bool, size: u32, mm: &MemoryManager) -> Result { + let size = Self::round_to_page_size(size as usize); + let memfd = MemoryFd::new_memfd(size, true) + .map_err(Error::ShmAllocFailed)?; + let (pfn, slot) = mm.register_device_memory(memfd.as_raw_fd(), size) + .map_err(Error::RegisterMemoryFailed)?; + Ok(Self::new(vfd_id, transition_flags, mm.clone(), memfd, slot, pfn)) + } +} + +impl VfdObject for VfdSharedMemory { + fn id(&self) -> u32 { + self.vfd_id + } + + fn send_fd(&self) -> Option { + self.memfd.as_ref().map(AsRawFd::as_raw_fd) + } + + fn flags(&self) -> u32 { + self.flags + } + + fn pfn_and_size(&self) -> Option<(u64, u64)> { + if let Some(memfd) = self.memfd.as_ref() { + Some((self.pfn, memfd.size() as u64)) + } else { + None + } + } + + fn close(&mut self) -> Result<()> { + if let Some(_) = self.memfd.take() { + self.mm.unregister_device_memory(self.slot) + .map_err(Error::RegisterMemoryFailed)?; + } + Ok(()) + } +} diff --git a/rust/src/devices/virtio_wl/socket.rs b/rust/src/devices/virtio_wl/socket.rs new file mode 100644 index 0000000..634f83f --- /dev/null +++ b/rust/src/devices/virtio_wl/socket.rs @@ -0,0 +1,102 @@ +use std::io::{self,Write}; +use std::path::Path; +use std::os::unix::{net::UnixStream, io::{AsRawFd, RawFd}}; + +use crate::system::{FileDesc,ScmSocket}; +use crate::devices::virtio_wl::{consts:: *, Error, Result, VfdObject, VfdRecv}; + +pub struct VfdSocket { + vfd_id: u32, + flags: u32, + socket: Option, +} + +impl VfdSocket { + pub fn open>(vfd_id: u32, transition_flags: bool, path: P) -> Result { + let flags = if transition_flags { + VIRTIO_WL_VFD_READ | VIRTIO_WL_VFD_WRITE + } else { + VIRTIO_WL_VFD_CONTROL + }; + let socket = UnixStream::connect(path) + .map_err(Error::SocketConnect)?; + socket.set_nonblocking(true) + .map_err(Error::SocketConnect)?; + + Ok(VfdSocket{ + vfd_id, + flags, + socket: Some(socket), + }) + } + fn socket_recv(socket: &mut UnixStream) -> Result<(Vec, Vec)> { + let mut buf = vec![0; IN_BUFFER_LEN]; + let mut fd_buf = [0; VIRTWL_SEND_MAX_ALLOCS]; + let (len, fd_len) = socket.recv_with_fds(&mut buf, &mut fd_buf) + .map_err(Error::SocketReceive)?; + buf.truncate(len); + let files = fd_buf[..fd_len].iter() + .map(|&fd| FileDesc::new(fd)).collect(); + Ok((buf, files)) + } +} +impl VfdObject for VfdSocket { + fn id(&self) -> u32 { + self.vfd_id + } + + fn send_fd(&self) -> Option { + self.socket.as_ref().map(|s| s.as_raw_fd()) + } + + fn poll_fd(&self) -> Option { + self.socket.as_ref().map(|s| s.as_raw_fd()) + } + + fn recv(&mut self) -> Result> { + if let Some(mut sock) = self.socket.take() { + let (buf,files) = Self::socket_recv(&mut sock)?; + if !(buf.is_empty() && files.is_empty()) { + self.socket.replace(sock); + if files.is_empty() { + return Ok(Some(VfdRecv::new(buf))); + } else { + return Ok(Some(VfdRecv::new_with_fds(buf, files))); + } + } + } + Ok(None) + } + + fn send(&mut self, data: &[u8]) -> Result<()> { + if let Some(s) = self.socket.as_mut() { + s.write_all(data).map_err(Error::SendVfd) + } else { + Err(Error::InvalidSendVfd) + } + } + + fn send_with_fds(&mut self, data: &[u8], fds: &[RawFd]) -> Result<()> { + if let Some(s) = self.socket.as_mut() { + s.send_with_fds(data, fds) + .map_err(|_| Error::SendVfd(io::Error::last_os_error()))?; + Ok(()) + } else { + Err(Error::InvalidSendVfd) + } + } + + fn flags(&self) -> u32 { + if self.socket.is_some() { + self.flags + } else { + 0 + } + } + fn close(&mut self) -> Result<()> { + self.socket = None; + Ok(()) + } +} + + diff --git a/rust/src/devices/virtio_wl/vfd.rs b/rust/src/devices/virtio_wl/vfd.rs new file mode 100644 index 0000000..0e7b1f3 --- /dev/null +++ b/rust/src/devices/virtio_wl/vfd.rs @@ -0,0 +1,314 @@ +use std::collections::{HashMap, VecDeque}; +use std::io::{Write, SeekFrom}; +use std::os::unix::io::{AsRawFd,RawFd}; +use std::path::PathBuf; +use std::time::Duration; + +use crate::memory::MemoryManager; +use crate::system::{FileDesc, FileFlags,EPoll,MemoryFd}; +use crate::virtio::{VirtQueue, Chain}; + +use crate::devices::virtio_wl::{ + consts::*, Error, Result, shm::VfdSharedMemory, pipe::VfdPipe, socket::VfdSocket, VfdObject +}; + +pub struct VfdManager { + wayland_path: PathBuf, + mm: MemoryManager, + use_transition_flags: bool, + vfd_map: HashMap>, + next_vfd_id: u32, + poll_ctx: EPoll, + in_vq: VirtQueue, + in_queue_pending: VecDeque, +} + +impl VfdManager { + fn round_to_page_size(n: usize) -> usize { + let mask = 4096 - 1; + (n + mask) & !mask + } + + pub fn new>(mm: MemoryManager, use_transition_flags: bool, in_vq: VirtQueue, wayland_path: P) -> Result { + let poll_ctx = EPoll::new().map_err(Error::FailedPollContextCreate)?; + Ok(VfdManager { + wayland_path: wayland_path.into(), + mm, use_transition_flags, + vfd_map: HashMap::new(), + next_vfd_id: NEXT_VFD_ID_BASE, + poll_ctx, + in_vq, + in_queue_pending: VecDeque::new(), + }) + } + + pub fn get_vfd(&self, vfd_id: u32) -> Option<&dyn VfdObject> { + self.vfd_map.get(&vfd_id).map(|vfd| vfd.as_ref()) + } + + pub fn get_mut_vfd(&mut self, vfd_id: u32) -> Option<&mut dyn VfdObject> { + self.vfd_map.get_mut(&vfd_id) + .map(|v| v.as_mut() as &mut dyn VfdObject) + } + + + pub fn create_pipe(&mut self, vfd_id: u32, is_local_write: bool) -> Result<()> { + let pipe = VfdPipe::create(vfd_id, is_local_write)?; + // XXX unwrap + self.poll_ctx.add_read(pipe.poll_fd().unwrap(), vfd_id as u64) + .map_err(Error::FailedPollAdd)?; + self.vfd_map.insert(vfd_id, Box::new(pipe)); + Ok(()) + } + + pub fn create_shm(&mut self, vfd_id: u32, size: u32) -> Result<(u64,u64)> { + let shm = VfdSharedMemory::create(vfd_id, self.use_transition_flags, size, &self.mm)?; + let (pfn,size) = shm.pfn_and_size().unwrap(); + self.vfd_map.insert(vfd_id, Box::new(shm)); + Ok((pfn,size)) + } + + pub fn create_socket(&mut self, vfd_id: u32) -> Result { + let sock = VfdSocket::open(vfd_id, self.use_transition_flags,&self.wayland_path)?; + self.poll_ctx.add_read(sock.poll_fd().unwrap(), vfd_id as u64) + .map_err(Error::FailedPollAdd)?; + let flags = sock.flags(); + self.vfd_map.insert(vfd_id, Box::new(sock)); + Ok(flags) + + } + + pub fn poll_fd(&self) -> RawFd { + self.poll_ctx.as_raw_fd() + } + + pub fn in_vq_poll_fd(&self) -> RawFd { + self.in_vq.ioevent().as_raw_fd() + } + + pub fn process_poll_events(&mut self) { + let events = match self.poll_ctx.wait_timeout(Duration::from_secs(0)) { + Ok(v) => v.to_owned(), + Err(e) => { + warn!("Failed wait on wayland vfd events: {}", e); + return; + } + }; + for ev in events.iter() { + if ev.is_readable() { + if let Err(e) = self.recv_from_vfd(ev.id() as u32) { + warn!("Error on wayland vfd recv(0x{:08x}): {}", ev.id() as u32, e); + } + } else if ev.is_hangup() { + self.process_hangup_event(ev.id() as u32); + } + + } + + if let Err(e) = self.drain_pending() { + warn!("Error sending pending input: {}", e); + } + } + + fn drain_pending(&mut self) -> Result<()> { + if self.in_queue_pending.is_empty() { + } + while !self.in_queue_pending.is_empty() { + let mut chain = match self.in_vq.next_chain() { + Some(chain) => chain, + None => return Ok(()), + }; + self.send_next_input_message(&mut chain)?; + } + Ok(()) + } + + fn process_hangup_event(&mut self, vfd_id: u32) { + if let Some(vfd) = self.vfd_map.get(&vfd_id) { + if let Some(fd) = vfd.poll_fd() { + if let Err(e) = self.poll_ctx.delete(fd) { + warn!("failed to remove hangup vfd from poll context: {}", e); + } + } + } + self.in_queue_pending.push_back(PendingInput::new_hup(vfd_id)); + } + + fn recv_from_vfd(&mut self, vfd_id: u32) -> Result<()> { + let vfd = match self.vfd_map.get_mut(&vfd_id) { + Some(vfd) => vfd, + None => return Ok(()) + }; + let recv = match vfd.recv()? { + Some(recv) => recv, + None => { + self.in_queue_pending.push_back(PendingInput::new_hup(vfd_id)); + return Ok(()) + } + }; + + if let Some(fds) = recv.fds { + let mut vfd_ids = Vec::new(); + for fd in fds { + let vfd = self.vfd_from_file(self.next_vfd_id, fd)?; + let id = self.add_vfd_device(vfd)?; + vfd_ids.push(id); + } + self.in_queue_pending.push_back(PendingInput::new(vfd_id, Some(recv.buf), Some(vfd_ids))); + } else { + self.in_queue_pending.push_back(PendingInput::new(vfd_id, Some(recv.buf), None)); + } + Ok(()) + } + + fn add_vfd_device(&mut self, vfd: Box) -> Result { + let id = self.next_vfd_id; + if let Some(poll_fd) = vfd.poll_fd() { + self.poll_ctx.add_read(poll_fd, id as u64) + .map_err(Error::FailedPollAdd)?; + } + self.vfd_map.insert(id, vfd); + self.next_vfd_id += 1; + Ok(id) + } + + pub fn in_vq_ready(&mut self) -> Result<()> { + self.in_vq.ioevent().read().map_err(Error::IoEventError)?; + self.drain_pending() + } + + fn send_next_input_message(&mut self, chain: &mut Chain) -> Result<()> { + let pop = match self.in_queue_pending.front_mut() { + Some(msg) => msg.send_message(chain, &self.vfd_map)?, + None => false, + }; + if pop { + self.in_queue_pending.pop_front(); + } + Ok(()) + } + + fn vfd_from_file(&self, vfd_id: u32, fd: FileDesc) -> Result> { + match fd.seek(SeekFrom::End(0)) { + Ok(size) => { + let size = Self::round_to_page_size(size as usize) as u64; + let (pfn,slot) = self.mm.register_device_memory(fd.as_raw_fd(), size as usize) + .map_err(Error::RegisterMemoryFailed)?; + + let memfd = MemoryFd::from_filedesc(fd).map_err(Error::ShmAllocFailed)?; + return Ok(Box::new(VfdSharedMemory::new(vfd_id, self.use_transition_flags,self.mm.clone(), memfd, slot, pfn))); + } + _ => { + let flags = match fd.flags() { + Ok(FileFlags::Read) => VIRTIO_WL_VFD_READ, + Ok(FileFlags::Write) => VIRTIO_WL_VFD_WRITE, + Ok(FileFlags::ReadWrite) =>VIRTIO_WL_VFD_READ | VIRTIO_WL_VFD_WRITE, + _ => 0, + }; + return Ok(Box::new(VfdPipe::local_only(vfd_id, fd, flags))); + } + } + } + + pub fn close_vfd(&mut self, vfd_id: u32) -> Result<()> { + if let Some(mut vfd) = self.vfd_map.remove(&vfd_id) { + vfd.close()?; + } + // XXX remove any matching fds from in_queue_pending + Ok(()) + } +} + +struct PendingInput { + vfd_id: u32, + buf: Option>, + vfds: Option>, + // next index to transmit from vfds vector + vfd_current: usize, +} + +impl PendingInput { + fn new_hup(vfd_id: u32) -> Self { + Self::new(vfd_id, None, None) + } + + fn new(vfd_id: u32, buf: Option>, vfds: Option>) -> Self { + PendingInput { vfd_id, buf, vfds, vfd_current: 0 } + } + + fn is_hup(&self) -> bool { + self.buf.is_none() && self.vfds.is_none() + } + + fn next_vfd(&mut self) -> Option { + if let Some(ref vfds) = self.vfds { + if self.vfd_current < vfds.len() { + let id = vfds[self.vfd_current]; + self.vfd_current += 1; + return Some(id); + } + } + None + } + + fn send_message(&mut self, chain: &mut Chain, vfd_map: &HashMap>) -> Result { + let pop = if self.is_hup() { + self.send_hup_message(chain)?; + true + } else if let Some(id) = self.next_vfd() { + if let Some(vfd) = vfd_map.get(&id) { + self.send_vfd_new_message(chain, vfd.as_ref())?; + } else { + warn!("No VFD found for vfd_id = {}", id) + } + false + + } else { + self.send_recv_message(chain)?; + true + }; + Ok(pop) + } + + fn send_hup_message(&self, chain: &mut Chain) -> Result { + chain.w32(VIRTIO_WL_CMD_VFD_HUP)?; + chain.w32(0)?; + chain.w32(self.vfd_id)?; + chain.flush_chain(); + Ok(true) + } + + fn send_vfd_new_message(&self, chain: &mut Chain, vfd: &dyn VfdObject) -> Result<()> { + chain.w32(VIRTIO_WL_CMD_VFD_NEW)?; + chain.w32(0)?; + chain.w32(vfd.id())?; + chain.w32(vfd.flags())?; + let (pfn, size) = match vfd.pfn_and_size() { + Some(vals) => vals, + None => (0,0), + }; + chain.w64(pfn)?; + chain.w32(size as u32)?; + Ok(()) + } + + fn send_recv_message(&self, chain: &mut Chain) -> Result { + chain.w32(VIRTIO_WL_CMD_VFD_RECV)?; + chain.w32(0)?; + chain.w32(self.vfd_id)?; + if let Some(vfds) = self.vfds.as_ref() { + chain.w32(vfds.len() as u32)?; + for vfd_id in vfds { + chain.w32(*vfd_id)?; + } + } else { + chain.w32(0)?; + } + if let Some(buf) = self.buf.as_ref() { + chain.write_all(buf)?; + } + chain.flush_chain(); + Ok(true) + } +} +