uds.rs0.00%
1
// Copyright 2025 Google LLC2
//3
// Licensed under the Apache License, Version 2.0 (the "License");4
// you may not use this file except in compliance with the License.5
// You may obtain a copy of the License at6
//7
// https://www.apache.org/licenses/LICENSE-2.08
//9
// Unless required by applicable law or agreed to in writing, software10
// distributed under the License is distributed on an "AS IS" BASIS,11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12
// See the License for the specific language governing permissions and13
// limitations under the License.14
15
use std::io::{ErrorKind, IoSlice, IoSliceMut, Result};16
use std::iter::zip;17
use std::os::fd::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};18
use std::os::unix::net::UnixStream;19
use std::ptr::{null_mut, read_unaligned, write_unaligned};20
21
use crate::ffi;22
23
pub const UDS_MAX_FD: usize = 32;24
25
const CMSG_BUF_LEN: usize =26
unsafe { libc::CMSG_SPACE((UDS_MAX_FD * size_of::<RawFd>()) as u32) } as usize;27
28
pub fn recv_msg_with_fds(29
conn: &UnixStream,30
bufs: &mut [IoSliceMut],31
fds: &mut [Option<OwnedFd>],32
) -> Result<usize> {33
let mut cmsg_buf = [0u64; CMSG_BUF_LEN / size_of::<u64>()];34
let mut uds_msg = libc::msghdr {35
msg_name: null_mut(),36
msg_namelen: 0,37
msg_iov: bufs.as_mut_ptr() as _,38
msg_iovlen: bufs.len() as _,39
msg_control: cmsg_buf.as_mut_ptr() as _,40
msg_controllen: CMSG_BUF_LEN as _,41
msg_flags: 0,42
};43
let flag = libc::MSG_CMSG_CLOEXEC;44
let size = ffi!(unsafe { libc::recvmsg(conn.as_raw_fd(), &mut uds_msg, flag) })?;45
46
if size == 0 {47
let buffer_size = bufs.iter().map(|b| b.len()).sum::<usize>();48
let err = if buffer_size == 0 {49
ErrorKind::InvalidInput50
} else {51
ErrorKind::ConnectionAborted52
};53
return Err(err.into());54
}55
56
if uds_msg.msg_flags & libc::MSG_CTRUNC > 0 {57
return Err(ErrorKind::OutOfMemory.into());58
}59
60
let mut overflow = false;61
let mut cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(&uds_msg) };62
let mut iter = fds.iter_mut();63
while !cmsg_ptr.is_null() {64
let cmsg = unsafe { read_unaligned(cmsg_ptr) };65
if cmsg.cmsg_level != libc::SOL_SOCKET || cmsg.cmsg_type != libc::SCM_RIGHTS {66
continue;67
}68
69
let cmsg_data_ptr = unsafe { libc::CMSG_DATA(cmsg_ptr) } as *const RawFd;70
for i in 0.. {71
let len = unsafe { libc::CMSG_LEN((size_of::<RawFd>() * (i + 1)) as u32) };72
if len > cmsg.cmsg_len as u32 {73
break;74
}75
76
let raw_fd = unsafe { read_unaligned(cmsg_data_ptr.add(i)) };77
let owned_fd = unsafe { OwnedFd::from_raw_fd(raw_fd) };78
if let Some(fd) = iter.next() {79
*fd = Some(owned_fd);80
} else {81
overflow = true;82
}83
}84
cmsg_ptr = unsafe { libc::CMSG_NXTHDR(&uds_msg, cmsg_ptr) };85
}86
87
if overflow {88
Err(ErrorKind::OutOfMemory.into())89
} else {90
Ok(size as usize)91
}92
}93
94
pub fn send_msg_with_fds(conn: &UnixStream, bufs: &[IoSlice], fds: &[BorrowedFd]) -> Result<usize> {95
if fds.len() > UDS_MAX_FD {96
return Err(ErrorKind::OutOfMemory.into());97
}98
99
let mut raw_fds = [0; UDS_MAX_FD];100
for (raw_fd, fd) in zip(&mut raw_fds, fds) {101
*raw_fd = fd.as_raw_fd();102
}103
let fds_size = size_of_val(fds) as u32;104
let buf_len = if fds_size > 0 {105
unsafe { libc::CMSG_SPACE(fds_size) }106
} else {107
0108
} as usize;109
let mut cmsg_buf = [0u64; CMSG_BUF_LEN / size_of::<u64>()];110
let uds_msg = libc::msghdr {111
msg_name: null_mut(),112
msg_namelen: 0,113
msg_iov: bufs.as_ptr() as _,114
msg_iovlen: bufs.len() as _,115
msg_control: cmsg_buf.as_mut_ptr() as _,116
msg_controllen: buf_len as _,117
msg_flags: 0,118
};119
if fds_size > 0 {120
let cmsg = libc::cmsghdr {121
cmsg_level: libc::SOL_SOCKET,122
cmsg_type: libc::SCM_RIGHTS,123
cmsg_len: unsafe { libc::CMSG_LEN(fds_size) } as _,124
};125
let cmsg_ptr = unsafe { libc::CMSG_FIRSTHDR(&uds_msg) };126
unsafe {127
write_unaligned(cmsg_ptr, cmsg);128
write_unaligned(libc::CMSG_DATA(cmsg_ptr) as *mut _, raw_fds);129
}130
}131
let size = ffi!(unsafe { libc::sendmsg(conn.as_raw_fd(), &uds_msg, 0) })?;132
Ok(size as usize)133
}134