mirror of
https://github.com/containers/youki
synced 2024-06-01 20:36:12 +02:00
326aea1021
Signed-off-by: anti-entropy123 <1348651580@qq.com>
226 lines
6.5 KiB
Rust
226 lines
6.5 KiB
Rust
use nix::{
|
|
sys::socket::{self, UnixAddr},
|
|
unistd::{self},
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::{
|
|
io::{IoSlice, IoSliceMut},
|
|
marker::PhantomData,
|
|
os::{fd::AsRawFd, unix::prelude::RawFd},
|
|
};
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum ChannelError {
|
|
#[error("failed unix syscalls")]
|
|
Nix(#[from] nix::Error),
|
|
#[error("failed serde serialization")]
|
|
Serde(#[from] serde_json::Error),
|
|
#[error("channel connection broken")]
|
|
BrokenChannel,
|
|
}
|
|
#[derive(Clone)]
|
|
pub struct Receiver<T> {
|
|
receiver: RawFd,
|
|
phantom: PhantomData<T>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct Sender<T> {
|
|
sender: RawFd,
|
|
phantom: PhantomData<T>,
|
|
}
|
|
|
|
impl<T> Sender<T>
|
|
where
|
|
T: Serialize,
|
|
{
|
|
fn send_iovec(
|
|
&mut self,
|
|
iov: &[IoSlice],
|
|
fds: Option<&[RawFd]>,
|
|
) -> Result<usize, ChannelError> {
|
|
let cmsgs = if let Some(fds) = fds {
|
|
vec![socket::ControlMessage::ScmRights(fds)]
|
|
} else {
|
|
vec![]
|
|
};
|
|
socket::sendmsg::<UnixAddr>(self.sender, iov, &cmsgs, socket::MsgFlags::empty(), None)
|
|
.map_err(|e| e.into())
|
|
}
|
|
|
|
fn send_slice_with_len(
|
|
&mut self,
|
|
data: &[u8],
|
|
fds: Option<&[RawFd]>,
|
|
) -> Result<usize, ChannelError> {
|
|
let len = data.len() as u64;
|
|
// Here we prefix the length of the data onto the serialized data.
|
|
let iov = [
|
|
IoSlice::new(unsafe {
|
|
std::slice::from_raw_parts(
|
|
(&len as *const u64) as *const u8,
|
|
std::mem::size_of::<u64>(),
|
|
)
|
|
}),
|
|
IoSlice::new(data),
|
|
];
|
|
self.send_iovec(&iov[..], fds)
|
|
}
|
|
|
|
pub fn send(&mut self, object: T) -> Result<(), ChannelError> {
|
|
let payload = serde_json::to_vec(&object)?;
|
|
self.send_slice_with_len(&payload, None)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn send_fds(&mut self, object: T, fds: &[RawFd]) -> Result<(), ChannelError> {
|
|
let payload = serde_json::to_vec(&object)?;
|
|
self.send_slice_with_len(&payload, Some(fds))?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn close(&self) -> Result<(), ChannelError> {
|
|
Ok(unistd::close(self.sender)?)
|
|
}
|
|
}
|
|
|
|
impl<T> Receiver<T>
|
|
where
|
|
T: serde::de::DeserializeOwned,
|
|
{
|
|
fn peek_size_iovec(&mut self) -> Result<u64, ChannelError> {
|
|
let mut len: u64 = 0;
|
|
let mut iov = [IoSliceMut::new(unsafe {
|
|
std::slice::from_raw_parts_mut(
|
|
(&mut len as *mut u64) as *mut u8,
|
|
std::mem::size_of::<u64>(),
|
|
)
|
|
})];
|
|
let _ =
|
|
socket::recvmsg::<UnixAddr>(self.receiver, &mut iov, None, socket::MsgFlags::MSG_PEEK)?;
|
|
match len {
|
|
0 => Err(ChannelError::BrokenChannel),
|
|
_ => Ok(len),
|
|
}
|
|
}
|
|
|
|
fn recv_into_iovec<F>(
|
|
&mut self,
|
|
iov: &mut [IoSliceMut],
|
|
) -> Result<(usize, Option<F>), ChannelError>
|
|
where
|
|
F: Default + AsMut<[RawFd]>,
|
|
{
|
|
let mut cmsgspace = nix::cmsg_space!(F);
|
|
let msg = socket::recvmsg::<UnixAddr>(
|
|
self.receiver,
|
|
iov,
|
|
Some(&mut cmsgspace),
|
|
socket::MsgFlags::MSG_CMSG_CLOEXEC,
|
|
)?;
|
|
|
|
// Sending multiple SCM_RIGHTS message will led to platform dependent
|
|
// behavior, with some system choose to return EINVAL when sending or
|
|
// silently only process the first msg or send all of it. Here we assume
|
|
// there is only one SCM_RIGHTS message and will only process the first
|
|
// message.
|
|
let fds: Option<F> = msg
|
|
.cmsgs()
|
|
.find_map(|cmsg| {
|
|
if let socket::ControlMessageOwned::ScmRights(fds) = cmsg {
|
|
Some(fds)
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.map(|fds| {
|
|
let mut fds_array: F = Default::default();
|
|
<F as AsMut<[RawFd]>>::as_mut(&mut fds_array).clone_from_slice(&fds);
|
|
fds_array
|
|
});
|
|
|
|
Ok((msg.bytes, fds))
|
|
}
|
|
|
|
fn recv_into_buf_with_len<F>(&mut self) -> Result<(Vec<u8>, Option<F>), ChannelError>
|
|
where
|
|
F: Default + AsMut<[RawFd]>,
|
|
{
|
|
let msg_len = self.peek_size_iovec()?;
|
|
let mut len: u64 = 0;
|
|
let mut buf = vec![0u8; msg_len as usize];
|
|
let (bytes, fds) = {
|
|
let mut iov = [
|
|
IoSliceMut::new(unsafe {
|
|
std::slice::from_raw_parts_mut(
|
|
(&mut len as *mut u64) as *mut u8,
|
|
std::mem::size_of::<u64>(),
|
|
)
|
|
}),
|
|
IoSliceMut::new(&mut buf),
|
|
];
|
|
self.recv_into_iovec(&mut iov)?
|
|
};
|
|
|
|
match bytes {
|
|
0 => Err(ChannelError::BrokenChannel),
|
|
_ => Ok((buf, fds)),
|
|
}
|
|
}
|
|
|
|
// Recv the next message of type T.
|
|
pub fn recv(&mut self) -> Result<T, ChannelError> {
|
|
let (buf, _) = self.recv_into_buf_with_len::<[RawFd; 0]>()?;
|
|
Ok(serde_json::from_slice(&buf[..])?)
|
|
}
|
|
|
|
// Works similar to `recv`, but will look for fds sent by SCM_RIGHTS
|
|
// message. We use F as as `[RawFd; n]`, where `n` is the number of
|
|
// descriptors you want to receive.
|
|
pub fn recv_with_fds<F>(&mut self) -> Result<(T, Option<F>), ChannelError>
|
|
where
|
|
F: Default + AsMut<[RawFd]>,
|
|
{
|
|
let (buf, fds) = self.recv_into_buf_with_len::<F>()?;
|
|
Ok((serde_json::from_slice(&buf[..])?, fds))
|
|
}
|
|
|
|
pub fn close(&self) -> Result<(), ChannelError> {
|
|
Ok(unistd::close(self.receiver)?)
|
|
}
|
|
}
|
|
|
|
pub fn channel<T>() -> Result<(Sender<T>, Receiver<T>), ChannelError>
|
|
where
|
|
T: for<'de> Deserialize<'de> + Serialize,
|
|
{
|
|
let (os_sender, os_receiver) = unix_channel()?;
|
|
let receiver = Receiver {
|
|
receiver: os_receiver,
|
|
phantom: PhantomData,
|
|
};
|
|
let sender = Sender {
|
|
sender: os_sender,
|
|
phantom: PhantomData,
|
|
};
|
|
Ok((sender, receiver))
|
|
}
|
|
|
|
// Use socketpair as the underlying pipe.
|
|
fn unix_channel() -> Result<(RawFd, RawFd), ChannelError> {
|
|
let (f1, f2) = socket::socketpair(
|
|
socket::AddressFamily::Unix,
|
|
socket::SockType::SeqPacket,
|
|
None,
|
|
socket::SockFlag::SOCK_CLOEXEC,
|
|
)?;
|
|
// It is not straightforward to share the OwnedFd across forks, so we
|
|
// treat them as i32. We use ManuallyDrop to keep the connection open.
|
|
let f1 = std::mem::ManuallyDrop::new(f1);
|
|
let f2 = std::mem::ManuallyDrop::new(f2);
|
|
|
|
Ok((f1.as_raw_fd(), f2.as_raw_fd()))
|
|
}
|