From bb826b9c77583665f2b3e5ba1c13e4d3be336208 Mon Sep 17 00:00:00 2001 From: Junliang Hu Date: Thu, 20 Nov 2025 00:04:52 -0500 Subject: [PATCH 1/4] feat(ibverbs): allow getting the lid from PortAttr --- src/ibverbs/device_context.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/ibverbs/device_context.rs b/src/ibverbs/device_context.rs index f1c8802..eac6e49 100644 --- a/src/ibverbs/device_context.rs +++ b/src/ibverbs/device_context.rs @@ -416,6 +416,11 @@ impl PortAttr { } (self.attr.active_speed as u32).into() } + + /// Get the lid of this port. + pub fn lid(&self) -> u16 { + self.attr.lid + } } /// The attributes of an RDMA device that is associated with a context. From 76455a2ecf47115cd42706d6cd8d7d9fab374536 Mon Sep 17 00:00:00 2001 From: Junliang Hu Date: Tue, 2 Dec 2025 16:17:27 -0500 Subject: [PATCH 2/4] chore: better naming --- tests/test_post_send.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_post_send.rs b/tests/test_post_send.rs index 1cc1ad7..2d7a0e8 100644 --- a/tests/test_post_send.rs +++ b/tests/test_post_send.rs @@ -31,7 +31,7 @@ fn main(#[case] use_qp_ex: bool, #[case] use_cq_ex: bool) -> Result<(), Box = vec![0; 64]; let mut recv_data: Vec = vec![0; 64]; - let mr = unsafe { + let send_mr = unsafe { pd.reg_mr( send_data.as_ptr() as _, send_data.len(), @@ -151,7 +151,7 @@ fn main(#[case] use_qp_ex: bool, #[case] use_cq_ex: bool) -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Date: Tue, 2 Dec 2025 23:00:12 -0500 Subject: [PATCH 3/4] refactor(ibverbs): ergonomically treat CompletionQueueEmpty as an empty iterator --- examples/rc_pingpong.rs | 143 +++++++++++------------ examples/rc_pingpong_split.rs | 45 +++---- src/ibverbs/completion.rs | 213 +++++++++++++++++++--------------- tests/test_post_send.rs | 12 +- 4 files changed, 205 insertions(+), 208 deletions(-) diff --git a/examples/rc_pingpong.rs b/examples/rc_pingpong.rs index 524dbf7..e272623 100644 --- a/examples/rc_pingpong.rs +++ b/examples/rc_pingpong.rs @@ -326,89 +326,76 @@ fn main() -> anyhow::Result<()> { // poll for the completion { loop { - match cq.start_poll() { - Ok(mut poller) => { - while let Some(wc) = poller.next() { - if wc.status() != WorkCompletionStatus::Success as u32 { - panic!( - "Failed status {:#?} ({}) for wr_id {}", - Into::::into(wc.status()), - wc.status(), - wc.wr_id() - ); - } - match wc.wr_id() { - SEND_WR_ID => { - scnt += 1; - outstanding_send = false; - }, - RECV_WR_ID => { - rcnt += 1; - rout -= 1; - - // Post more receives if the receive side credit is low - if rout <= rx_depth / 2 { - let to_post = rx_depth - rout; - for _ in 0..to_post { - let mut guard = qp.start_post_recv(); - let recv_handle = guard.construct_wr(RECV_WR_ID); - unsafe { - recv_handle.setup_sge( - recv_mr.lkey(), - recv_data.as_mut_ptr() as _, - args.size, - ); - }; - guard.post().unwrap(); - } - rout += to_post; - } - - if args.ts { - let timestamp = wc.completion_timestamp(); - if ts_param.last_completion_with_timestamp != 0 { - let delta: u64 = if timestamp >= ts_param.completion_recv_prev_time { - timestamp - ts_param.completion_recv_prev_time - } else { - completion_timestamp_mask - ts_param.completion_recv_prev_time - + timestamp - + 1 - }; - - ts_param.completion_recv_max_time_delta = - ts_param.completion_recv_max_time_delta.max(delta); - ts_param.completion_recv_min_time_delta = - ts_param.completion_recv_min_time_delta.min(delta); - ts_param.completion_recv_total_time_delta += delta; - ts_param.completion_with_time_iters += 1; - } - - ts_param.completion_recv_prev_time = timestamp; - ts_param.last_completion_with_timestamp = 1; - } else { - ts_param.last_completion_with_timestamp = 0; - } - }, - _ => { - panic!("Unknown error!"); - }, + for wc in cq.iter()? { + if wc.status() != WorkCompletionStatus::Success as u32 { + panic!( + "Failed status {:#?} ({}) for wr_id {}", + Into::::into(wc.status()), + wc.status(), + wc.wr_id() + ); + } + match wc.wr_id() { + SEND_WR_ID => { + scnt += 1; + outstanding_send = false; + }, + RECV_WR_ID => { + rcnt += 1; + rout -= 1; + + // Post more receives if the receive side credit is low + if rout <= rx_depth / 2 { + let to_post = rx_depth - rout; + for _ in 0..to_post { + let mut guard = qp.start_post_recv(); + let recv_handle = guard.construct_wr(RECV_WR_ID); + unsafe { + recv_handle.setup_sge(recv_mr.lkey(), recv_data.as_mut_ptr() as _, args.size); + }; + guard.post().unwrap(); + } + rout += to_post; } - if scnt < args.iter && !outstanding_send { - // Post another send if we haven't reached the iteration limit - let mut guard = qp.start_post_send(); - let send_handle = guard.construct_wr(SEND_WR_ID, WorkRequestFlags::Signaled).setup_send(); - unsafe { - send_handle.setup_sge(send_mr.lkey(), send_data.as_ptr() as _, args.size); + if args.ts { + let timestamp = wc.completion_timestamp(); + if ts_param.last_completion_with_timestamp != 0 { + let delta: u64 = if timestamp >= ts_param.completion_recv_prev_time { + timestamp - ts_param.completion_recv_prev_time + } else { + completion_timestamp_mask - ts_param.completion_recv_prev_time + timestamp + 1 + }; + + ts_param.completion_recv_max_time_delta = + ts_param.completion_recv_max_time_delta.max(delta); + ts_param.completion_recv_min_time_delta = + ts_param.completion_recv_min_time_delta.min(delta); + ts_param.completion_recv_total_time_delta += delta; + ts_param.completion_with_time_iters += 1; } - guard.post()?; - outstanding_send = true; + + ts_param.completion_recv_prev_time = timestamp; + ts_param.last_completion_with_timestamp = 1; + } else { + ts_param.last_completion_with_timestamp = 0; } + }, + _ => { + panic!("Unknown error!"); + }, + } + + if scnt < args.iter && !outstanding_send { + // Post another send if we haven't reached the iteration limit + let mut guard = qp.start_post_send(); + let send_handle = guard.construct_wr(SEND_WR_ID, WorkRequestFlags::Signaled).setup_send(); + unsafe { + send_handle.setup_sge(send_mr.lkey(), send_data.as_ptr() as _, args.size); } - }, - Err(_) => { - continue; - }, + guard.post()?; + outstanding_send = true; + } } // Check if we're done diff --git a/examples/rc_pingpong_split.rs b/examples/rc_pingpong_split.rs index c0d53ed..16d4712 100644 --- a/examples/rc_pingpong_split.rs +++ b/examples/rc_pingpong_split.rs @@ -457,32 +457,25 @@ fn main() -> anyhow::Result<()> { let mut need_post_send = false; { - match ctx.cq.start_poll() { - Ok(mut poller) => { - while let Some(wc) = poller.next() { - ctx.parse_single_work_completion( - &wc, - &mut ts_param, - &mut scnt, - &mut rcnt, - &mut outstanding_send, - &mut rout, - rx_depth, - &mut need_post_recv, - &mut to_post_recv, - args.ts, - ); - - // Record that we need to post a send later - if scnt < args.iter && !outstanding_send { - need_post_send = true; - outstanding_send = true; - } - } - }, - Err(_) => { - continue; - }, + for wc in ctx.cq.iter()? { + ctx.parse_single_work_completion( + &wc, + &mut ts_param, + &mut scnt, + &mut rcnt, + &mut outstanding_send, + &mut rout, + rx_depth, + &mut need_post_recv, + &mut to_post_recv, + args.ts, + ); + + // Record that we need to post a send later + if scnt < args.iter && !outstanding_send { + need_post_send = true; + outstanding_send = true; + } } } diff --git a/src/ibverbs/completion.rs b/src/ibverbs/completion.rs index 3d9c7d6..0f03364 100644 --- a/src/ibverbs/completion.rs +++ b/src/ibverbs/completion.rs @@ -53,8 +53,6 @@ pub enum CreateCompletionQueueErrorKind { pub enum PollCompletionQueueError { #[error("poll completion queue failed")] Ibverbs(#[from] io::Error), - #[error("completion queue is empty")] - CompletionQueueEmpty, } /// Possible statuses of a Work Completion's corresponding operation. @@ -326,37 +324,8 @@ impl BasicCompletionQueue { /// /// [`ibv_poll_cq`]: https://www.rdmamojo.com/2013/02/15/ibv_poll_cq/ /// - pub fn start_poll(&self) -> Result, PollCompletionQueueError> { - let mut cqes = Vec::::with_capacity(self.poll_batch.get() as _); - - let ret = unsafe { - ibv_poll_cq( - self.cq.as_ptr(), - self.poll_batch.get().try_into().unwrap(), - cqes.as_mut_ptr(), - ) - }; - - unsafe { - match ret { - 0 => Err(PollCompletionQueueError::CompletionQueueEmpty), - err if err < 0 => Err(PollCompletionQueueError::Ibverbs(io::Error::from_raw_os_error(-err))), - res => Ok(BasicPoller { - cq: self.cq(), - wcs: { - cqes.set_len(res as _); - cqes - }, - status: if res < self.poll_batch.get().try_into().unwrap_unchecked() { - BasicCompletionQueueState::Drained - } else { - BasicCompletionQueueState::Ready - }, - current: 0, - _phantom: PhantomData, - }), - } - } + pub fn iter(&self) -> Result, PollCompletionQueueError> { + BasicCompletionQueueIter::new(self.cq, self.poll_batch.get().try_into().unwrap()) } /// Change the polling batch size, note that this won't take effect until your next call to @@ -401,23 +370,8 @@ impl CompletionQueue for ExtendedCompletionQueue { impl ExtendedCompletionQueue { /// Starts to poll Work Completions over this CQ, every [`ExtendedCompletionQueue`] should hold /// only one [`ExtendedPoller`] at the same time. - pub fn start_poll(&self) -> Result, PollCompletionQueueError> { - let ret = unsafe { - ibv_start_poll( - self.cq_ex.as_ptr(), - MaybeUninit::::zeroed().as_mut_ptr(), - ) - }; - - match ret { - 0 => Ok(ExtendedPoller { - cq: self.cq_ex, - is_first: true, - _phantom: PhantomData, - }), - libc::ENOENT => Err(PollCompletionQueueError::CompletionQueueEmpty), - err => Err(PollCompletionQueueError::Ibverbs(io::Error::from_raw_os_error(err))), - } + pub fn iter(&self) -> Result, PollCompletionQueueError> { + ExtendedCompletionQueueIter::new(self.cq_ex) } } @@ -635,33 +589,69 @@ enum BasicCompletionQueueState { // TODO: provide a trait for poller? /// The basic `Poller` that works for [`BasicCompletionQueue`] for getting Work Completions in an /// iterator style. -pub struct BasicPoller<'cq> { +pub struct BasicCompletionQueueIter<'cq> { cq: NonNull, wcs: Vec, status: BasicCompletionQueueState, - current: usize, + next: usize, _phantom: PhantomData<&'cq ()>, } +impl<'cq> BasicCompletionQueueIter<'cq> { + pub fn new(cq: NonNull, num_entries: i32) -> Result { + let mut cqes = Vec::::with_capacity(num_entries as _); + + let ret = unsafe { ibv_poll_cq(cq.as_ptr(), num_entries, cqes.as_mut_ptr()) }; + + match ret { + err if err < 0 => Err(PollCompletionQueueError::Ibverbs(io::Error::from_raw_os_error(-err))), + 0 => Ok(Self { + cq, + wcs: Vec::new(), + status: BasicCompletionQueueState::Empty, + next: 0, + _phantom: PhantomData, + }), + res => Ok(Self { + cq, + wcs: { + // SAFETY: ibv_poll_cq returns the number of valid work completions + unsafe { cqes.set_len(res as _) }; + cqes + }, + status: if res < num_entries { + BasicCompletionQueueState::Drained + } else { + BasicCompletionQueueState::Ready + }, + next: 0, + _phantom: PhantomData, + }), + } + } +} + // TODO: implement BasicPoller with lending iterator for better performance. -impl<'cq> Iterator for BasicPoller<'cq> { +impl<'cq> Iterator for BasicCompletionQueueIter<'cq> { type Item = BasicWorkCompletion<'cq>; fn next(&mut self) -> Option { use BasicCompletionQueueState::*; - let current = self.current; + if self.status == Empty { + return None; + } + + let current = self.next; let len = self.wcs.len(); if (self.status == Ready || self.status == Drained) && current < len { - let wc = unsafe { - BasicWorkCompletion { - wc: *self.wcs.get_unchecked(current), - _phantom: PhantomData, - } - }; - self.current += 1; - return Some(wc); + self.next += 1; + return Some(BasicWorkCompletion { + // SAFETY: current < len, so index current is valid + wc: unsafe { *self.wcs.get_unchecked(current) }, + _phantom: PhantomData, + }); } if self.status == Drained && current >= len { @@ -682,66 +672,97 @@ impl<'cq> Iterator for BasicPoller<'cq> { ) }; - if ret > 0 { - unsafe { - if ret < self.wcs.capacity().try_into().unwrap_unchecked() { + match ret { + res if res > 0 => { + // SAFETY: the capacity is previously checked to be convertable to i32 + if ret < unsafe { self.wcs.capacity().try_into().unwrap_unchecked() } { self.status = Drained; } else { self.status = Ready; } - self.wcs.set_len(ret as usize); - } - let wc = unsafe { - BasicWorkCompletion { - wc: *self.wcs.get_unchecked(0), + // SAFETY: ibv_poll_cq returns the number of valid work completions + unsafe { self.wcs.set_len(ret as usize) }; + self.next = 1; + Some(BasicWorkCompletion { + // SAFETY: ret > 0, so index 0 is valid + wc: unsafe { *self.wcs.get_unchecked(0) }, _phantom: PhantomData, - } - }; - self.current = 1; - Some(wc) - } else { - self.status = Empty; - None + }) + }, + 0 => { + self.status = Empty; + None + }, + _ => { + unreachable!() + }, } } } /// The extended `Poller` that works for [`ExtendedCompletionQueue`] for getting Work Completions in /// an iterator style. -pub struct ExtendedPoller<'cq> { +pub struct ExtendedCompletionQueueIter<'cq> { cq: NonNull, is_first: bool, + is_done: bool, _phantom: PhantomData<&'cq ()>, } -impl Drop for ExtendedPoller<'_> { +impl<'cq> ExtendedCompletionQueueIter<'cq> { + pub fn new(cq: NonNull) -> Result { + let ret = unsafe { ibv_start_poll(cq.as_ptr(), MaybeUninit::::zeroed().as_mut_ptr()) }; + match ret { + 0 => Ok(Self { + cq, + is_first: true, + is_done: false, + _phantom: PhantomData, + }), + libc::ENOENT => Ok(Self { + cq, + is_first: false, + is_done: true, + _phantom: PhantomData, + }), + err => Err(PollCompletionQueueError::Ibverbs(io::Error::from_raw_os_error(err))), + } + } +} + +impl Drop for ExtendedCompletionQueueIter<'_> { fn drop(&mut self) { + if self.is_done { + return; + } unsafe { ibv_end_poll(self.cq.as_ptr()) } } } -impl<'cq> Iterator for ExtendedPoller<'cq> { +impl<'cq> Iterator for ExtendedCompletionQueueIter<'cq> { type Item = ExtendedWorkCompletion<'cq>; fn next(&mut self) -> Option { + if self.is_done { + return None; + } if self.is_first { self.is_first = false; + return Some(ExtendedWorkCompletion { + cq: self.cq, + _phantom: PhantomData, + }); + } + + let ret = unsafe { ibv_next_poll(self.cq.as_ptr()) }; + if ret != 0 { + None + } else { Some(ExtendedWorkCompletion { cq: self.cq, _phantom: PhantomData, }) - } else { - let ret = unsafe { ibv_next_poll(self.cq.as_ptr()) }; - - if ret != 0 { - None - } else { - Some(ExtendedWorkCompletion { - cq: self.cq, - _phantom: PhantomData, - }) - } } } } @@ -777,15 +798,15 @@ impl CompletionQueue for GenericCompletionQueue { /// A unified interface for [`BasicPoller`] and [`ExtendedPoller`], implemented with enum /// dispatching. pub enum GenericPoller<'cq> { - Basic(BasicPoller<'cq>), - Extended(ExtendedPoller<'cq>), + Basic(BasicCompletionQueueIter<'cq>), + Extended(ExtendedCompletionQueueIter<'cq>), } impl GenericCompletionQueue { - pub fn start_poll(&self) -> Result, PollCompletionQueueError> { + pub fn iter(&self) -> Result, PollCompletionQueueError> { match self { - GenericCompletionQueue::Basic(cq) => cq.start_poll().map(GenericPoller::Basic), - GenericCompletionQueue::Extended(cq) => cq.start_poll().map(GenericPoller::Extended), + GenericCompletionQueue::Basic(cq) => cq.iter().map(GenericPoller::Basic), + GenericCompletionQueue::Extended(cq) => cq.iter().map(GenericPoller::Extended), } } } diff --git a/tests/test_post_send.rs b/tests/test_post_send.rs index 2d7a0e8..f2f3282 100644 --- a/tests/test_post_send.rs +++ b/tests/test_post_send.rs @@ -183,8 +183,7 @@ fn main(#[case] use_qp_ex: bool, #[case] use_cq_ex: bool) -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box Date: Tue, 2 Dec 2025 23:58:22 -0500 Subject: [PATCH 4/4] chore: fix a sq/rq typo Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> chore: fix a sq/rq typo Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/test_post_send.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_post_send.rs b/tests/test_post_send.rs index f2f3282..cc1c9a1 100644 --- a/tests/test_post_send.rs +++ b/tests/test_post_send.rs @@ -195,7 +195,7 @@ fn main(#[case] use_qp_ex: bool, #[case] use_cq_ex: bool) -> Result<(), Box Result<(), Box