Alioth Code Coverage

io_uring.rs0.00%

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::iter;
16use std::os::fd::{AsFd, AsRawFd};
17use std::sync::Arc;
18use std::sync::mpsc::Receiver;
19use std::thread::JoinHandle;
20
21use io_uring::cqueue::Entry as Cqe;
22use io_uring::squeue::Entry as Sqe;
23use io_uring::{SubmissionQueue, opcode, types};
24
25use crate::hv::IoeventFd;
26use crate::mem::mapped::{Ram, RamBus};
27use crate::sync::notifier::Notifier;
28use crate::virtio::dev::{
29 ActiveBackend, Backend, BackendEvent, Context, StartParam, Virtio, WakeEvent, Worker,
30 WorkerState,
31};
32use crate::virtio::queue::{DescChain, Queue, QueueReg, Status, VirtQueue};
33use crate::virtio::{IrqSender, Result};
34
35pub enum BufferAction {
36 Sqe(Sqe),
37 Written(u32),
38}
39
40pub trait VirtioIoUring: Virtio {
41 fn activate<'m, Q, S, E>(
42 &mut self,
43 feature: u128,
44 ring: &mut ActiveIoUring<'_, '_, 'm, Q, S, E>,
45 ) -> Result<()>
46 where
47 Q: VirtQueue<'m>,
48 S: IrqSender,
49 E: IoeventFd;
50
51 fn handle_desc(&mut self, q_index: u16, chain: &mut DescChain) -> Result<BufferAction>;
52
53 fn complete_desc(&mut self, q_index: u16, chain: &mut DescChain, cqe: &Cqe) -> Result<u32>;
54}
55
56const TOKEN_QUEUE: u64 = 1 << 62;
57const TOKEN_DESCRIPTOR: u64 = (1 << 62) | (1 << 61);
58
59pub struct IoUring {
60 notifier: Arc<Notifier>,
61 notifier_token: u64,
62}
63
64impl IoUring {
65 fn submit_notifier(&self, sq: &mut SubmissionQueue) -> Result<()> {
66 let fd = types::Fd(self.notifier.as_fd().as_raw_fd());
67 let poll = opcode::PollAdd::new(fd, libc::EPOLLIN as _).multi(true);
68 let entry = poll.build().user_data(self.notifier_token);
69 unsafe { sq.push(&entry) }.unwrap();
70 Ok(())
71 }
72
73 pub fn spawn_worker<D, S, E>(
74 dev: D,
75 event_rx: Receiver<WakeEvent<S, E>>,
76 memory: Arc<RamBus>,
77 queue_regs: Arc<[QueueReg]>,
78 ) -> Result<(JoinHandle<()>, Arc<Notifier>)>
79 where
80 D: VirtioIoUring,
81 E: IoeventFd,
82 S: IrqSender,
83 {
84 let notifier = Notifier::new()?;
85 let ring = IoUring {
86 notifier: Arc::new(notifier),
87 notifier_token: 0,
88 };
89 Worker::spawn(dev, ring, event_rx, memory, queue_regs)
90 }
91}
92
93impl BackendEvent for Cqe {
94 fn token(&self) -> u64 {
95 self.user_data()
96 }
97}
98
99const RING_SIZE: u16 = 256;
100const QUEUE_RESERVE_SIZE: u16 = 1;
101
102impl<D> Backend<D> for IoUring
103where
104 D: VirtioIoUring,
105{
106 fn register_notifier(&mut self, token: u64) -> Result<Arc<Notifier>> {
107 self.notifier_token = token;
108 Ok(self.notifier.clone())
109 }
110
111 fn reset(&self, _dev: &mut D) -> Result<()> {
112 Ok(())
113 }
114
115 fn event_loop<'m, S, Q, E>(
116 &mut self,
117 memory: &'m Ram,
118 context: &mut Context<D, S, E>,
119 queues: &mut [Option<Queue<'_, 'm, Q>>],
120 param: &StartParam<S, E>,
121 ) -> Result<()>
122 where
123 S: IrqSender,
124 Q: VirtQueue<'m>,
125 E: IoeventFd,
126 {
127 let submit_counts = iter::repeat_n(0, queues.len()).collect();
128 let mut active_ring = ActiveIoUring {
129 ring: io_uring::IoUring::new(RING_SIZE as u32)?,
130 shared_count: RING_SIZE - 1,
131 irq_sender: &*param.irq_sender,
132 ioeventfds: param.ioeventfds.as_deref().unwrap_or(&[]),
133 mem: memory,
134 queues,
135 submit_counts,
136 };
137 self.submit_notifier(&mut active_ring.ring.submission())?;
138 context.dev.activate(param.feature, &mut active_ring)?;
139
140 if let Some(fds) = &param.ioeventfds {
141 let sq = &mut active_ring.ring.submission();
142 for (index, fd) in fds.iter().enumerate() {
143 if context.dev.ioeventfd_offloaded(index as u16)? {
144 continue;
145 }
146 submit_queue_ioeventfd(index as u16, fd, sq)?;
147 active_ring.shared_count -= QUEUE_RESERVE_SIZE + 1;
148 }
149 }
150
151 'out: loop {
152 active_ring.ring.submit_and_wait(1)?;
153 loop {
154 let Some(entry) = active_ring.ring.completion().next() else {
155 break;
156 };
157 context.handle_event(&entry, &mut active_ring)?;
158 if context.state != WorkerState::Running {
159 break 'out;
160 }
161 }
162 }
163 Ok(())
164 }
165}
166
167pub struct ActiveIoUring<'a, 'r, 'm, Q, S, E>
168where
169 Q: VirtQueue<'m>,
170{
171 ring: io_uring::IoUring,
172 pub queues: &'a mut [Option<Queue<'r, 'm, Q>>],
173 pub irq_sender: &'a S,
174 pub ioeventfds: &'a [E],
175 pub mem: &'m Ram,
176 shared_count: u16,
177 submit_counts: Box<[u16]>,
178}
179
180fn submit_queue_ioeventfd<E>(index: u16, fd: &E, sq: &mut SubmissionQueue) -> Result<()>
181where
182 E: IoeventFd,
183{
184 let token = index as u64 | TOKEN_QUEUE;
185
186 let fd = types::Fd(fd.as_fd().as_raw_fd());
187 let poll = opcode::PollAdd::new(fd, libc::EPOLLIN as _).multi(true);
188 let entry = poll.build().user_data(token);
189 unsafe { sq.push(&entry) }.unwrap();
190 Ok(())
191}
192
193impl<'m, Q, S, E> ActiveIoUring<'_, '_, 'm, Q, S, E>
194where
195 Q: VirtQueue<'m>,
196 S: IrqSender,
197 E: IoeventFd,
198{
199 fn submit_buffers<D>(&mut self, dev: &mut D, q_index: u16) -> Result<()>
200 where
201 D: VirtioIoUring,
202 {
203 let Some(Some(q)) = self.queues.get_mut(q_index as usize) else {
204 log::error!("{}: invalid queue index {q_index}", dev.name());
205 return Ok(());
206 };
207 let submit_count = self.submit_counts.get_mut(q_index as usize).unwrap();
208
209 q.handle_desc(q_index, self.irq_sender, |chain| {
210 if *submit_count >= QUEUE_RESERVE_SIZE && self.shared_count == 0 {
211 log::debug!("{}: queue-{q_index}: no more free entries", dev.name());
212 return Ok(Status::Break);
213 };
214 match dev.handle_desc(q_index, chain)? {
215 BufferAction::Sqe(sqe) => {
216 let buffer_key = ((chain.id() as u64) << 16) | q_index as u64;
217 let sqe = sqe.user_data(buffer_key | TOKEN_DESCRIPTOR);
218 if unsafe { self.ring.submission().push(&sqe) }.is_err() {
219 log::error!("{}: queue-{q_index}: unexpected full queue", dev.name());
220 return Ok(Status::Break);
221 }
222 *submit_count += 1;
223 if *submit_count > QUEUE_RESERVE_SIZE {
224 self.shared_count -= 1;
225 }
226 Ok(Status::Deferred)
227 }
228 BufferAction::Written(len) => Ok(Status::Done { len }),
229 }
230 })
231 }
232}
233
234impl<'m, D, Q, S, E> ActiveBackend<D> for ActiveIoUring<'_, '_, 'm, Q, S, E>
235where
236 D: VirtioIoUring,
237 Q: VirtQueue<'m>,
238 S: IrqSender,
239 E: IoeventFd,
240{
241 type Event = Cqe;
242
243 fn handle_event(&mut self, dev: &mut D, event: &Self::Event) -> Result<()> {
244 let token = event.user_data();
245 if token & TOKEN_DESCRIPTOR == TOKEN_DESCRIPTOR {
246 let buffer_key = token as u32;
247 let q_index = buffer_key as u16;
248 let chain_id = (buffer_key >> 16) as u16;
249 let Some(Some(queue)) = self.queues.get_mut(q_index as usize) else {
250 log::error!("{}: invalid queue index {q_index}", dev.name());
251 return Ok(());
252 };
253 let submit_count = self.submit_counts.get_mut(q_index as usize).unwrap();
254 if *submit_count > QUEUE_RESERVE_SIZE {
255 self.shared_count += 1;
256 }
257 *submit_count -= 1;
258 queue.handle_deferred(chain_id, q_index, self.irq_sender, |chain| {
259 dev.complete_desc(q_index, chain, event)
260 })?;
261
262 self.submit_buffers(dev, q_index)
263 } else if token & TOKEN_QUEUE == TOKEN_QUEUE {
264 let index = token as u16;
265 self.submit_buffers(dev, index)
266 } else {
267 unreachable!()
268 }
269 }
270
271 fn handle_queue(&mut self, dev: &mut D, index: u16) -> Result<()> {
272 self.submit_buffers(dev, index)
273 }
274}
275