Alioth Code Coverage

conn.rs0.00%

1// Copyright 2025 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::io::{IoSlice, IoSliceMut, Read, Write};
16use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd};
17use std::os::unix::net::UnixStream;
18use std::path::Path;
19
20use snafu::ResultExt;
21use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes};
22
23use crate::ffi;
24use crate::utils::uds::{recv_msg_with_fds, send_msg_with_fds};
25use crate::virtio::vu::bindings::{
26 DeviceConfig, FsMap, MAX_CONFIG_SIZE, MemorySingleRegion, Message, MessageFlag, VirtqAddr,
27 VirtqState, VuBackMsg, VuFrontMsg,
28};
29use crate::virtio::vu::{Result, error};
30
31fn send<T, R>(
32 mut conn: &UnixStream,
33 req: u32,
34 payload: &T,
35 in_: &[u8],
36 out: &mut [u8],
37 fds: &[BorrowedFd],
38) -> Result<R>
39where
40 T: IntoBytes + Immutable,
41 R: FromBytes + IntoBytes,
42{
43 let vhost_msg = Message {
44 request: req,
45 flag: MessageFlag::sender(),
46 size: (size_of::<T>() + in_.len()) as u32,
47 };
48 let bufs = [
49 IoSlice::new(vhost_msg.as_bytes()),
50 IoSlice::new(payload.as_bytes()),
51 IoSlice::new(in_),
52 ];
53 let done = send_msg_with_fds(conn, &bufs, fds)?;
54 let want = size_of_val(&vhost_msg) + vhost_msg.size as usize;
55 if done != want {
56 return error::PartialWrite { done, want }.fail();
57 }
58
59 let mut resp = Message::new_zeroed();
60 let mut payload = R::new_zeroed();
61 let mut ret_code = u64::MAX;
62 let mut bufs = [
63 IoSliceMut::new(resp.as_mut_bytes()),
64 if size_of::<R>() > 0 {
65 IoSliceMut::new(payload.as_mut_bytes())
66 } else {
67 IoSliceMut::new(ret_code.as_mut_bytes())
68 },
69 IoSliceMut::new(out),
70 ];
71 let resp_size = bufs[1].len() + bufs[2].len();
72 let expect_size = size_of::<Message>() + resp_size;
73
74 let size = conn.read_vectored(&mut bufs)?;
75 if size != expect_size {
76 return error::MsgSize {
77 want: expect_size,
78 got: size,
79 }
80 .fail();
81 }
82 if resp.request != req {
83 return error::Response {
84 want: req,
85 got: resp.request,
86 }
87 .fail();
88 }
89 if resp.size as usize != resp_size {
90 return error::PayloadSize {
91 want: resp_size,
92 got: resp.size,
93 }
94 .fail();
95 }
96 if size_of::<R>() == 0 && ret_code != 0 {
97 return error::RequestErr { ret: ret_code, req }.fail();
98 }
99
100 Ok(payload)
101}
102
103fn reply<T>(conn: &UnixStream, req: u32, payload: &T, fds: &[BorrowedFd]) -> Result<()>
104where
105 T: IntoBytes + Immutable,
106{
107 let msg = Message {
108 request: req,
109 flag: MessageFlag::receiver(),
110 size: size_of_val(payload) as _,
111 };
112 let bufs = [
113 IoSlice::new(msg.as_bytes()),
114 IoSlice::new(payload.as_bytes()),
115 ];
116 let done = send_msg_with_fds(conn, &bufs, fds)?;
117 let want = size_of_val(&msg) + size_of_val(payload);
118 if done != want {
119 return error::PartialWrite { want, done }.fail();
120 }
121 Ok(())
122}
123
124fn reply_config(mut conn: &UnixStream, config: &DeviceConfig, buf: &[u8]) -> Result<()> {
125 let msg = Message {
126 request: VuFrontMsg::GET_CONFIG.raw(),
127 flag: MessageFlag::receiver(),
128 size: (size_of_val(config) + buf.len()) as _,
129 };
130 let bufs = [
131 IoSlice::new(msg.as_bytes()),
132 IoSlice::new(config.as_bytes()),
133 IoSlice::new(buf),
134 ];
135 let done = conn.write_vectored(&bufs)?;
136 let want = size_of_val(&msg) + msg.size as usize;
137 if done != want {
138 return error::PartialWrite { want, done }.fail();
139 }
140
141 Ok(())
142}
143
144fn recv_with_fds<T>(conn: &UnixStream, fds: &mut [Option<OwnedFd>]) -> Result<T>
145where
146 T: IntoBytes + Immutable + FromBytes,
147{
148 let mut msg = T::new_zeroed();
149 let mut bufs = [IoSliceMut::new(msg.as_mut_bytes())];
150 let size = recv_msg_with_fds(conn, &mut bufs, fds)?;
151 if size != size_of::<T>() {
152 error::MsgSize {
153 want: size_of::<T>(),
154 got: size,
155 }
156 .fail()
157 } else {
158 Ok(msg)
159 }
160}
161
162fn recv_config(mut conn: &UnixStream, buf: &mut [u8]) -> Result<DeviceConfig> {
163 let mut dev_config = DeviceConfig::new_zeroed();
164 let mut bufs = [
165 IoSliceMut::new(dev_config.as_mut_bytes()),
166 IoSliceMut::new(buf),
167 ];
168 let got = conn.read_vectored(&mut bufs)?;
169 let want = size_of::<DeviceConfig>() + dev_config.size as usize;
170 if got != want {
171 return error::PayloadSize {
172 want,
173 got: got as u32,
174 }
175 .fail();
176 }
177 Ok(dev_config)
178}
179
180#[derive(Debug)]
181pub struct VuSession {
182 pub conn: UnixStream,
183}
184
185impl VuSession {
186 pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
187 let conn = UnixStream::connect(&path).context(error::AccessSocket {
188 path: path.as_ref(),
189 })?;
190 Ok(VuSession { conn })
191 }
192
193 fn send<T, R>(&self, req: VuFrontMsg, payload: &T, fds: &[BorrowedFd]) -> Result<R>
194 where
195 T: IntoBytes + Immutable,
196 R: FromBytes + IntoBytes,
197 {
198 send(&self.conn, req.raw(), payload, &[], &mut [], fds)
199 }
200
201 pub fn recv_payload<T>(&self) -> Result<T>
202 where
203 T: IntoBytes + Immutable + FromBytes,
204 {
205 recv_with_fds(&self.conn, &mut [])
206 }
207
208 pub fn recv_config(&self, buf: &mut [u8]) -> Result<DeviceConfig> {
209 recv_config(&self.conn, buf)
210 }
211
212 pub fn recv_msg(&self, fds: &mut [Option<OwnedFd>]) -> Result<Message> {
213 recv_with_fds(&self.conn, fds)
214 }
215
216 pub fn reply<T: IntoBytes + Immutable>(
217 &self,
218 req: VuFrontMsg,
219 payload: &T,
220 fds: &[BorrowedFd],
221 ) -> Result<()> {
222 reply(&self.conn, req.raw(), payload, fds)
223 }
224
225 pub fn reply_config(&self, config: &DeviceConfig, buf: &[u8]) -> Result<()> {
226 reply_config(&self.conn, config, buf)
227 }
228
229 pub fn get_features(&self) -> Result<u64> {
230 self.send(VuFrontMsg::GET_FEATURES, &(), &[])
231 }
232
233 pub fn set_features(&self, payload: &u64) -> Result<()> {
234 self.send(VuFrontMsg::SET_FEATURES, payload, &[])
235 }
236
237 pub fn get_protocol_features(&self) -> Result<u64> {
238 self.send(VuFrontMsg::GET_PROTOCOL_FEATURES, &(), &[])
239 }
240
241 pub fn set_protocol_features(&self, payload: &u64) -> Result<u64> {
242 self.send(VuFrontMsg::SET_PROTOCOL_FEATURES, payload, &[])
243 }
244
245 pub fn set_owner(&self) -> Result<()> {
246 self.send(VuFrontMsg::SET_OWNER, &(), &[])
247 }
248
249 pub fn set_virtq_num(&self, payload: &VirtqState) -> Result<()> {
250 self.send(VuFrontMsg::SET_VIRTQ_NUM, payload, &[])
251 }
252
253 pub fn set_virtq_addr(&self, payload: &VirtqAddr) -> Result<()> {
254 self.send(VuFrontMsg::SET_VIRTQ_ADDR, payload, &[])
255 }
256
257 pub fn set_virtq_base(&self, payload: &VirtqState) -> Result<()> {
258 self.send(VuFrontMsg::SET_VIRTQ_BASE, payload, &[])
259 }
260
261 pub fn get_config(&self, payload: &DeviceConfig, buf: &mut [u8]) -> Result<DeviceConfig> {
262 let in_ = [0; MAX_CONFIG_SIZE];
263 let len = buf.len();
264 let req = VuFrontMsg::GET_CONFIG.raw();
265 send(&self.conn, req, payload, &in_[..len], buf, &[])
266 }
267
268 pub fn set_config(&self, payload: &DeviceConfig, buf: &[u8]) -> Result<()>
269 where
270 DeviceConfig: IntoBytes,
271 {
272 let req = VuFrontMsg::SET_CONFIG.raw();
273 send(&self.conn, req, payload, buf, &mut [], &[])
274 }
275
276 pub fn get_virtq_base(&self, payload: &VirtqState) -> Result<VirtqState> {
277 self.send(VuFrontMsg::GET_VIRTQ_BASE, payload, &[])
278 }
279
280 pub fn get_queue_num(&self) -> Result<u64> {
281 self.send(VuFrontMsg::GET_QUEUE_NUM, &(), &[])
282 }
283
284 pub fn set_virtq_kick(&self, payload: &u64, fd: BorrowedFd) -> Result<()> {
285 self.send(VuFrontMsg::SET_VIRTQ_KICK, payload, &[fd])
286 }
287
288 pub fn set_virtq_call(&self, payload: &u64, fd: BorrowedFd) -> Result<()> {
289 self.send(VuFrontMsg::SET_VIRTQ_CALL, payload, &[fd])
290 }
291
292 pub fn set_virtq_err(&self, payload: &u64, fd: BorrowedFd) -> Result<()> {
293 self.send(VuFrontMsg::SET_VIRTQ_ERR, payload, &[fd])
294 }
295
296 pub fn set_virtq_enable(&self, payload: &VirtqState) -> Result<()> {
297 self.send(VuFrontMsg::SET_VIRTQ_ENABLE, payload, &[])
298 }
299
300 pub fn set_status(&self, payload: &u64) -> Result<()> {
301 self.send(VuFrontMsg::SET_STATUS, payload, &[])
302 }
303
304 pub fn get_status(&self) -> Result<u64> {
305 self.send(VuFrontMsg::GET_STATUS, &(), &[])
306 }
307
308 pub fn add_mem_region(&self, payload: &MemorySingleRegion, fd: BorrowedFd) -> Result<()> {
309 self.send(VuFrontMsg::ADD_MEM_REG, payload, &[fd])
310 }
311
312 pub fn remove_mem_region(&self, payload: &MemorySingleRegion) -> Result<()> {
313 self.send(VuFrontMsg::REM_MEM_REG, payload, &[])
314 }
315
316 fn set_backend_req_fd(&self, fd: BorrowedFd) -> Result<()> {
317 self.send(VuFrontMsg::SET_BACKEND_REQ_FD, &(), &[fd])
318 }
319
320 pub fn create_channel(&self) -> Result<VuChannel> {
321 let mut socket_fds = [0; 2];
322 ffi!(unsafe {
323 libc::socketpair(libc::PF_UNIX, libc::SOCK_STREAM, 0, socket_fds.as_mut_ptr())
324 })?;
325 let channel = unsafe { UnixStream::from_raw_fd(socket_fds[0]) };
326 let peer = unsafe { OwnedFd::from_raw_fd(socket_fds[1]) };
327 self.set_backend_req_fd(peer.as_fd())?;
328 Ok(VuChannel { conn: channel })
329 }
330}
331
332#[derive(Debug)]
333pub struct VuChannel {
334 pub conn: UnixStream,
335}
336
337impl VuChannel {
338 pub fn recv_payload<T>(&self) -> Result<T>
339 where
340 T: IntoBytes + Immutable + FromBytes,
341 {
342 recv_with_fds(&self.conn, &mut [])
343 }
344
345 pub fn recv_msg(&self, fds: &mut [Option<OwnedFd>]) -> Result<Message> {
346 recv_with_fds(&self.conn, fds)
347 }
348
349 pub fn reply<T: IntoBytes + Immutable>(
350 &self,
351 req: VuBackMsg,
352 payload: &T,
353 fds: &[BorrowedFd],
354 ) -> Result<()> {
355 reply(&self.conn, req.raw(), payload, fds)
356 }
357
358 fn send<T, R>(&self, req: VuBackMsg, payload: &T, fds: &[BorrowedFd]) -> Result<R>
359 where
360 T: IntoBytes + Immutable,
361 R: FromBytes + IntoBytes,
362 {
363 send(&self.conn, req.raw(), payload, &[], &mut [], fds)
364 }
365
366 pub fn fs_map(&self, payload: &FsMap, fds: &[BorrowedFd]) -> Result<()> {
367 self.send(VuBackMsg::SHARED_OBJECT_ADD, payload, fds)
368 }
369
370 pub fn fs_unmap(&self, payload: &FsMap) -> Result<()> {
371 self.send(VuBackMsg::SHARED_OBJECT_REMOVE, payload, &[])
372 }
373}
374