conn.rs0.00%
1
// Copyright 2025 Google LLC2
//3
// Licensed under the Apache License, Version 2.0 (the "License");4
// you may not use this file except in compliance with the License.5
// You may obtain a copy of the License at6
//7
// https://www.apache.org/licenses/LICENSE-2.08
//9
// Unless required by applicable law or agreed to in writing, software10
// distributed under the License is distributed on an "AS IS" BASIS,11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.12
// See the License for the specific language governing permissions and13
// limitations under the License.14
15
use std::io::{IoSlice, IoSliceMut, Read, Write};16
use std::os::fd::{AsFd, BorrowedFd, FromRawFd, OwnedFd};17
use std::os::unix::net::UnixStream;18
use std::path::Path;19
20
use snafu::ResultExt;21
use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes};22
23
use crate::ffi;24
use crate::utils::uds::{recv_msg_with_fds, send_msg_with_fds};25
use crate::virtio::vu::bindings::{26
DeviceConfig, FsMap, MAX_CONFIG_SIZE, MemorySingleRegion, Message, MessageFlag, VirtqAddr,27
VirtqState, VuBackMsg, VuFrontMsg,28
};29
use crate::virtio::vu::{Result, error};30
31
fn send<T, R>(32
mut conn: &UnixStream,33
req: u32,34
payload: &T,35
in_: &[u8],36
out: &mut [u8],37
fds: &[BorrowedFd],38
) -> Result<R>39
where40
T: IntoBytes + Immutable,41
R: FromBytes + IntoBytes,42
{43
let vhost_msg = Message {44
request: req,45
flag: MessageFlag::sender(),46
size: (size_of::<T>() + in_.len()) as u32,47
};48
let bufs = [49
IoSlice::new(vhost_msg.as_bytes()),50
IoSlice::new(payload.as_bytes()),51
IoSlice::new(in_),52
];53
let done = send_msg_with_fds(conn, &bufs, fds)?;54
let want = size_of_val(&vhost_msg) + vhost_msg.size as usize;55
if done != want {56
return error::PartialWrite { done, want }.fail();57
}58
59
let mut resp = Message::new_zeroed();60
let mut payload = R::new_zeroed();61
let mut ret_code = u64::MAX;62
let mut bufs = [63
IoSliceMut::new(resp.as_mut_bytes()),64
if size_of::<R>() > 0 {65
IoSliceMut::new(payload.as_mut_bytes())66
} else {67
IoSliceMut::new(ret_code.as_mut_bytes())68
},69
IoSliceMut::new(out),70
];71
let resp_size = bufs[1].len() + bufs[2].len();72
let expect_size = size_of::<Message>() + resp_size;73
74
let size = conn.read_vectored(&mut bufs)?;75
if size != expect_size {76
return error::MsgSize {77
want: expect_size,78
got: size,79
}80
.fail();81
}82
if resp.request != req {83
return error::Response {84
want: req,85
got: resp.request,86
}87
.fail();88
}89
if resp.size as usize != resp_size {90
return error::PayloadSize {91
want: resp_size,92
got: resp.size,93
}94
.fail();95
}96
if size_of::<R>() == 0 && ret_code != 0 {97
return error::RequestErr { ret: ret_code, req }.fail();98
}99
100
Ok(payload)101
}102
103
fn reply<T>(conn: &UnixStream, req: u32, payload: &T, fds: &[BorrowedFd]) -> Result<()>104
where105
T: IntoBytes + Immutable,106
{107
let msg = Message {108
request: req,109
flag: MessageFlag::receiver(),110
size: size_of_val(payload) as _,111
};112
let bufs = [113
IoSlice::new(msg.as_bytes()),114
IoSlice::new(payload.as_bytes()),115
];116
let done = send_msg_with_fds(conn, &bufs, fds)?;117
let want = size_of_val(&msg) + size_of_val(payload);118
if done != want {119
return error::PartialWrite { want, done }.fail();120
}121
Ok(())122
}123
124
fn reply_config(mut conn: &UnixStream, config: &DeviceConfig, buf: &[u8]) -> Result<()> {125
let msg = Message {126
request: VuFrontMsg::GET_CONFIG.raw(),127
flag: MessageFlag::receiver(),128
size: (size_of_val(config) + buf.len()) as _,129
};130
let bufs = [131
IoSlice::new(msg.as_bytes()),132
IoSlice::new(config.as_bytes()),133
IoSlice::new(buf),134
];135
let done = conn.write_vectored(&bufs)?;136
let want = size_of_val(&msg) + msg.size as usize;137
if done != want {138
return error::PartialWrite { want, done }.fail();139
}140
141
Ok(())142
}143
144
fn recv_with_fds<T>(conn: &UnixStream, fds: &mut [Option<OwnedFd>]) -> Result<T>145
where146
T: IntoBytes + Immutable + FromBytes,147
{148
let mut msg = T::new_zeroed();149
let mut bufs = [IoSliceMut::new(msg.as_mut_bytes())];150
let size = recv_msg_with_fds(conn, &mut bufs, fds)?;151
if size != size_of::<T>() {152
error::MsgSize {153
want: size_of::<T>(),154
got: size,155
}156
.fail()157
} else {158
Ok(msg)159
}160
}161
162
fn recv_config(mut conn: &UnixStream, buf: &mut [u8]) -> Result<DeviceConfig> {163
let mut dev_config = DeviceConfig::new_zeroed();164
let mut bufs = [165
IoSliceMut::new(dev_config.as_mut_bytes()),166
IoSliceMut::new(buf),167
];168
let got = conn.read_vectored(&mut bufs)?;169
let want = size_of::<DeviceConfig>() + dev_config.size as usize;170
if got != want {171
return error::PayloadSize {172
want,173
got: got as u32,174
}175
.fail();176
}177
Ok(dev_config)178
}179
180
#[derive(Debug)]181
pub struct VuSession {182
pub conn: UnixStream,183
}184
185
impl VuSession {186
pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {187
let conn = UnixStream::connect(&path).context(error::AccessSocket {188
path: path.as_ref(),189
})?;190
Ok(VuSession { conn })191
}192
193
fn send<T, R>(&self, req: VuFrontMsg, payload: &T, fds: &[BorrowedFd]) -> Result<R>194
where195
T: IntoBytes + Immutable,196
R: FromBytes + IntoBytes,197
{198
send(&self.conn, req.raw(), payload, &[], &mut [], fds)199
}200
201
pub fn recv_payload<T>(&self) -> Result<T>202
where203
T: IntoBytes + Immutable + FromBytes,204
{205
recv_with_fds(&self.conn, &mut [])206
}207
208
pub fn recv_config(&self, buf: &mut [u8]) -> Result<DeviceConfig> {209
recv_config(&self.conn, buf)210
}211
212
pub fn recv_msg(&self, fds: &mut [Option<OwnedFd>]) -> Result<Message> {213
recv_with_fds(&self.conn, fds)214
}215
216
pub fn reply<T: IntoBytes + Immutable>(217
&self,218
req: VuFrontMsg,219
payload: &T,220
fds: &[BorrowedFd],221
) -> Result<()> {222
reply(&self.conn, req.raw(), payload, fds)223
}224
225
pub fn reply_config(&self, config: &DeviceConfig, buf: &[u8]) -> Result<()> {226
reply_config(&self.conn, config, buf)227
}228
229
pub fn get_features(&self) -> Result<u64> {230
self.send(VuFrontMsg::GET_FEATURES, &(), &[])231
}232
233
pub fn set_features(&self, payload: &u64) -> Result<()> {234
self.send(VuFrontMsg::SET_FEATURES, payload, &[])235
}236
237
pub fn get_protocol_features(&self) -> Result<u64> {238
self.send(VuFrontMsg::GET_PROTOCOL_FEATURES, &(), &[])239
}240
241
pub fn set_protocol_features(&self, payload: &u64) -> Result<u64> {242
self.send(VuFrontMsg::SET_PROTOCOL_FEATURES, payload, &[])243
}244
245
pub fn set_owner(&self) -> Result<()> {246
self.send(VuFrontMsg::SET_OWNER, &(), &[])247
}248
249
pub fn set_virtq_num(&self, payload: &VirtqState) -> Result<()> {250
self.send(VuFrontMsg::SET_VIRTQ_NUM, payload, &[])251
}252
253
pub fn set_virtq_addr(&self, payload: &VirtqAddr) -> Result<()> {254
self.send(VuFrontMsg::SET_VIRTQ_ADDR, payload, &[])255
}256
257
pub fn set_virtq_base(&self, payload: &VirtqState) -> Result<()> {258
self.send(VuFrontMsg::SET_VIRTQ_BASE, payload, &[])259
}260
261
pub fn get_config(&self, payload: &DeviceConfig, buf: &mut [u8]) -> Result<DeviceConfig> {262
let in_ = [0; MAX_CONFIG_SIZE];263
let len = buf.len();264
let req = VuFrontMsg::GET_CONFIG.raw();265
send(&self.conn, req, payload, &in_[..len], buf, &[])266
}267
268
pub fn set_config(&self, payload: &DeviceConfig, buf: &[u8]) -> Result<()>269
where270
DeviceConfig: IntoBytes,271
{272
let req = VuFrontMsg::SET_CONFIG.raw();273
send(&self.conn, req, payload, buf, &mut [], &[])274
}275
276
pub fn get_virtq_base(&self, payload: &VirtqState) -> Result<VirtqState> {277
self.send(VuFrontMsg::GET_VIRTQ_BASE, payload, &[])278
}279
280
pub fn get_queue_num(&self) -> Result<u64> {281
self.send(VuFrontMsg::GET_QUEUE_NUM, &(), &[])282
}283
284
pub fn set_virtq_kick(&self, payload: &u64, fd: BorrowedFd) -> Result<()> {285
self.send(VuFrontMsg::SET_VIRTQ_KICK, payload, &[fd])286
}287
288
pub fn set_virtq_call(&self, payload: &u64, fd: BorrowedFd) -> Result<()> {289
self.send(VuFrontMsg::SET_VIRTQ_CALL, payload, &[fd])290
}291
292
pub fn set_virtq_err(&self, payload: &u64, fd: BorrowedFd) -> Result<()> {293
self.send(VuFrontMsg::SET_VIRTQ_ERR, payload, &[fd])294
}295
296
pub fn set_virtq_enable(&self, payload: &VirtqState) -> Result<()> {297
self.send(VuFrontMsg::SET_VIRTQ_ENABLE, payload, &[])298
}299
300
pub fn set_status(&self, payload: &u64) -> Result<()> {301
self.send(VuFrontMsg::SET_STATUS, payload, &[])302
}303
304
pub fn get_status(&self) -> Result<u64> {305
self.send(VuFrontMsg::GET_STATUS, &(), &[])306
}307
308
pub fn add_mem_region(&self, payload: &MemorySingleRegion, fd: BorrowedFd) -> Result<()> {309
self.send(VuFrontMsg::ADD_MEM_REG, payload, &[fd])310
}311
312
pub fn remove_mem_region(&self, payload: &MemorySingleRegion) -> Result<()> {313
self.send(VuFrontMsg::REM_MEM_REG, payload, &[])314
}315
316
fn set_backend_req_fd(&self, fd: BorrowedFd) -> Result<()> {317
self.send(VuFrontMsg::SET_BACKEND_REQ_FD, &(), &[fd])318
}319
320
pub fn create_channel(&self) -> Result<VuChannel> {321
let mut socket_fds = [0; 2];322
ffi!(unsafe {323
libc::socketpair(libc::PF_UNIX, libc::SOCK_STREAM, 0, socket_fds.as_mut_ptr())324
})?;325
let channel = unsafe { UnixStream::from_raw_fd(socket_fds[0]) };326
let peer = unsafe { OwnedFd::from_raw_fd(socket_fds[1]) };327
self.set_backend_req_fd(peer.as_fd())?;328
Ok(VuChannel { conn: channel })329
}330
}331
332
#[derive(Debug)]333
pub struct VuChannel {334
pub conn: UnixStream,335
}336
337
impl VuChannel {338
pub fn recv_payload<T>(&self) -> Result<T>339
where340
T: IntoBytes + Immutable + FromBytes,341
{342
recv_with_fds(&self.conn, &mut [])343
}344
345
pub fn recv_msg(&self, fds: &mut [Option<OwnedFd>]) -> Result<Message> {346
recv_with_fds(&self.conn, fds)347
}348
349
pub fn reply<T: IntoBytes + Immutable>(350
&self,351
req: VuBackMsg,352
payload: &T,353
fds: &[BorrowedFd],354
) -> Result<()> {355
reply(&self.conn, req.raw(), payload, fds)356
}357
358
fn send<T, R>(&self, req: VuBackMsg, payload: &T, fds: &[BorrowedFd]) -> Result<R>359
where360
T: IntoBytes + Immutable,361
R: FromBytes + IntoBytes,362
{363
send(&self.conn, req.raw(), payload, &[], &mut [], fds)364
}365
366
pub fn fs_map(&self, payload: &FsMap, fds: &[BorrowedFd]) -> Result<()> {367
self.send(VuBackMsg::SHARED_OBJECT_ADD, payload, fds)368
}369
370
pub fn fs_unmap(&self, payload: &FsMap) -> Result<()> {371
self.send(VuBackMsg::SHARED_OBJECT_REMOVE, payload, &[])372
}373
}374