1
0
Fork 0
mirror of https://github.com/containers/youki synced 2024-06-01 20:36:12 +02:00
youki/crates/libcontainer/src/channel.rs
anti-entropy123 326aea1021 fix lint error
Signed-off-by: anti-entropy123 <1348651580@qq.com>
2023-10-09 10:23:55 +08:00

226 lines
6.5 KiB
Rust

use nix::{
sys::socket::{self, UnixAddr},
unistd::{self},
};
use serde::{Deserialize, Serialize};
use std::{
io::{IoSlice, IoSliceMut},
marker::PhantomData,
os::{fd::AsRawFd, unix::prelude::RawFd},
};
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
#[error("failed unix syscalls")]
Nix(#[from] nix::Error),
#[error("failed serde serialization")]
Serde(#[from] serde_json::Error),
#[error("channel connection broken")]
BrokenChannel,
}
#[derive(Clone)]
pub struct Receiver<T> {
receiver: RawFd,
phantom: PhantomData<T>,
}
#[derive(Clone)]
pub struct Sender<T> {
sender: RawFd,
phantom: PhantomData<T>,
}
impl<T> Sender<T>
where
T: Serialize,
{
fn send_iovec(
&mut self,
iov: &[IoSlice],
fds: Option<&[RawFd]>,
) -> Result<usize, ChannelError> {
let cmsgs = if let Some(fds) = fds {
vec![socket::ControlMessage::ScmRights(fds)]
} else {
vec![]
};
socket::sendmsg::<UnixAddr>(self.sender, iov, &cmsgs, socket::MsgFlags::empty(), None)
.map_err(|e| e.into())
}
fn send_slice_with_len(
&mut self,
data: &[u8],
fds: Option<&[RawFd]>,
) -> Result<usize, ChannelError> {
let len = data.len() as u64;
// Here we prefix the length of the data onto the serialized data.
let iov = [
IoSlice::new(unsafe {
std::slice::from_raw_parts(
(&len as *const u64) as *const u8,
std::mem::size_of::<u64>(),
)
}),
IoSlice::new(data),
];
self.send_iovec(&iov[..], fds)
}
pub fn send(&mut self, object: T) -> Result<(), ChannelError> {
let payload = serde_json::to_vec(&object)?;
self.send_slice_with_len(&payload, None)?;
Ok(())
}
pub fn send_fds(&mut self, object: T, fds: &[RawFd]) -> Result<(), ChannelError> {
let payload = serde_json::to_vec(&object)?;
self.send_slice_with_len(&payload, Some(fds))?;
Ok(())
}
pub fn close(&self) -> Result<(), ChannelError> {
Ok(unistd::close(self.sender)?)
}
}
impl<T> Receiver<T>
where
T: serde::de::DeserializeOwned,
{
fn peek_size_iovec(&mut self) -> Result<u64, ChannelError> {
let mut len: u64 = 0;
let mut iov = [IoSliceMut::new(unsafe {
std::slice::from_raw_parts_mut(
(&mut len as *mut u64) as *mut u8,
std::mem::size_of::<u64>(),
)
})];
let _ =
socket::recvmsg::<UnixAddr>(self.receiver, &mut iov, None, socket::MsgFlags::MSG_PEEK)?;
match len {
0 => Err(ChannelError::BrokenChannel),
_ => Ok(len),
}
}
fn recv_into_iovec<F>(
&mut self,
iov: &mut [IoSliceMut],
) -> Result<(usize, Option<F>), ChannelError>
where
F: Default + AsMut<[RawFd]>,
{
let mut cmsgspace = nix::cmsg_space!(F);
let msg = socket::recvmsg::<UnixAddr>(
self.receiver,
iov,
Some(&mut cmsgspace),
socket::MsgFlags::MSG_CMSG_CLOEXEC,
)?;
// Sending multiple SCM_RIGHTS message will led to platform dependent
// behavior, with some system choose to return EINVAL when sending or
// silently only process the first msg or send all of it. Here we assume
// there is only one SCM_RIGHTS message and will only process the first
// message.
let fds: Option<F> = msg
.cmsgs()
.find_map(|cmsg| {
if let socket::ControlMessageOwned::ScmRights(fds) = cmsg {
Some(fds)
} else {
None
}
})
.map(|fds| {
let mut fds_array: F = Default::default();
<F as AsMut<[RawFd]>>::as_mut(&mut fds_array).clone_from_slice(&fds);
fds_array
});
Ok((msg.bytes, fds))
}
fn recv_into_buf_with_len<F>(&mut self) -> Result<(Vec<u8>, Option<F>), ChannelError>
where
F: Default + AsMut<[RawFd]>,
{
let msg_len = self.peek_size_iovec()?;
let mut len: u64 = 0;
let mut buf = vec![0u8; msg_len as usize];
let (bytes, fds) = {
let mut iov = [
IoSliceMut::new(unsafe {
std::slice::from_raw_parts_mut(
(&mut len as *mut u64) as *mut u8,
std::mem::size_of::<u64>(),
)
}),
IoSliceMut::new(&mut buf),
];
self.recv_into_iovec(&mut iov)?
};
match bytes {
0 => Err(ChannelError::BrokenChannel),
_ => Ok((buf, fds)),
}
}
// Recv the next message of type T.
pub fn recv(&mut self) -> Result<T, ChannelError> {
let (buf, _) = self.recv_into_buf_with_len::<[RawFd; 0]>()?;
Ok(serde_json::from_slice(&buf[..])?)
}
// Works similar to `recv`, but will look for fds sent by SCM_RIGHTS
// message. We use F as as `[RawFd; n]`, where `n` is the number of
// descriptors you want to receive.
pub fn recv_with_fds<F>(&mut self) -> Result<(T, Option<F>), ChannelError>
where
F: Default + AsMut<[RawFd]>,
{
let (buf, fds) = self.recv_into_buf_with_len::<F>()?;
Ok((serde_json::from_slice(&buf[..])?, fds))
}
pub fn close(&self) -> Result<(), ChannelError> {
Ok(unistd::close(self.receiver)?)
}
}
pub fn channel<T>() -> Result<(Sender<T>, Receiver<T>), ChannelError>
where
T: for<'de> Deserialize<'de> + Serialize,
{
let (os_sender, os_receiver) = unix_channel()?;
let receiver = Receiver {
receiver: os_receiver,
phantom: PhantomData,
};
let sender = Sender {
sender: os_sender,
phantom: PhantomData,
};
Ok((sender, receiver))
}
// Use socketpair as the underlying pipe.
fn unix_channel() -> Result<(RawFd, RawFd), ChannelError> {
let (f1, f2) = socket::socketpair(
socket::AddressFamily::Unix,
socket::SockType::SeqPacket,
None,
socket::SockFlag::SOCK_CLOEXEC,
)?;
// It is not straightforward to share the OwnedFd across forks, so we
// treat them as i32. We use ManuallyDrop to keep the connection open.
let f1 = std::mem::ManuallyDrop::new(f1);
let f2 = std::mem::ManuallyDrop::new(f2);
Ok((f1.as_raw_fd(), f2.as_raw_fd()))
}