diff --git a/Justfile b/Justfile
index 781c653..9cd864f 100644
--- a/Justfile
+++ b/Justfile
@@ -20,6 +20,11 @@ test-rc-pingpong-with-cov:
     sleep 2
     cargo llvm-cov --no-report run --example rc_pingpong_split -- -d {{rdma_dev}} -g 1 127.0.0.1
 
+test-ud-pingpong-with-cov:
+    cargo llvm-cov --no-report run --example ud_pingpong_split -- -d {{rdma_dev}} -g 1 &
+    sleep 2
+    cargo llvm-cov --no-report run --example ud_pingpong_split -- -d {{rdma_dev}} -g 1 127.0.0.1
+
 test-cmtime-with-cov:
     cargo llvm-cov --no-report run --example cmtime -- -b {{ip}} &
     sleep 2
diff --git a/examples/ud_pingpong_split.rs b/examples/ud_pingpong_split.rs
new file mode 100644
index 0000000..f8708f9
--- /dev/null
+++ b/examples/ud_pingpong_split.rs
@@ -0,0 +1,538 @@
+#![allow(clippy::too_many_arguments)]
+
+use std::io::{Error, Read, Write};
+use std::net::{IpAddr, Ipv6Addr, SocketAddr, TcpListener, TcpStream};
+use std::str::FromStr;
+use std::sync::Arc;
+
+use anyhow::{Context, Result};
+use byte_unit::{Byte, UnitType};
+use clap::{Parser, ValueEnum};
+use postcard::{from_bytes, to_allocvec};
+use serde::{Deserialize, Serialize};
+use sideway::ibverbs::address::{AddressHandle, AddressHandleAttribute, Gid, GlobalRoutingHeader, GRH_HEADER_LEN};
+use sideway::ibverbs::completion::{
+    CreateCompletionQueueWorkCompletionFlags, ExtendedCompletionQueue, ExtendedWorkCompletion, GenericCompletionQueue,
+    WorkCompletionStatus,
+};
+use sideway::ibverbs::device::{Device, DeviceInfo, DeviceList};
+use sideway::ibverbs::device_context::{DeviceContext, Mtu};
+use sideway::ibverbs::memory_region::MemoryRegion;
+use sideway::ibverbs::protection_domain::ProtectionDomain;
+use sideway::ibverbs::queue_pair::{
+    ExtendedQueuePair, PostSendGuard, QueuePair, QueuePairAttribute, QueuePairState, QueuePairType, SendOperationFlags,
+    SetScatterGatherEntry, WorkRequestFlags,
+};
+use sideway::ibverbs::AccessFlags;
+
+const SEND_WR_ID: u64 = 0;
+const RECV_WR_ID: u64 = 1;
+const DEFAULT_QKEY: u32 = 0x1111_1111;
+
+#[derive(Debug, Parser)]
+#[clap(name = "ud_pingpong", version = "0.1.0")]
+pub struct Args {
+    /// Listen on / connect to port
+    #[clap(long, short = 'p', default_value_t = 18515)]
+    port: u16,
+    /// The IB device to use
+    #[clap(long, short = 'd')]
+    ib_dev: Option<String>,
+    /// The port of IB device
+    #[clap(long, short = 'i', default_value_t = 1)]
+    ib_port: u8,
+    /// The size of message to exchange
+    #[clap(long, short = 's', default_value_t = 1024)]
+    size: u32,
+    /// Path MTU
+    #[clap(long, short = 'm', value_enum, default_value_t = PathMtu(Mtu::Mtu1024))]
+    mtu: PathMtu,
+    /// Numbers of receives to post at a time
+    #[clap(long, short = 'r', default_value_t = 500)]
+    rx_depth: u32,
+    /// Numbers of exchanges
+    #[clap(long, short = 'n', default_value_t = 1000)]
+    iter: u32,
+    /// Service level value
+    #[clap(long, short = 'l', default_value_t = 0)]
+    sl: u8,
+    /// Local port GID index
+    #[clap(long, short = 'g', default_value_t = 0)]
+    gid_idx: u8,
+    /// Get CQE with timestamp
+    #[arg(long, short = 't', default_value_t = false)]
+    ts: bool,
+    /// Print GRH on each received packet
+    #[arg(long, default_value_t = false)]
+    debug_grh: bool,
+    /// If no value provided, start a server and wait for connection, otherwise, connect to server at [host]
+    #[arg(name = "host")]
+    server_ip: Option<String>,
+}
+
+#[derive(Clone, Copy, Debug)]
+struct PathMtu(Mtu);
+
+impl ValueEnum for PathMtu {
+    fn value_variants<'a>() -> &'a [Self] {
+        &[
+            Self(Mtu::Mtu256),
+            Self(Mtu::Mtu512),
+            Self(Mtu::Mtu1024),
+            Self(Mtu::Mtu2048),
+            Self(Mtu::Mtu4096),
+        ]
+    }
+
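+    // Render each MTU variant as the string accepted on the command line (e.g. `-m 1024`).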
+    fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
+        match self.0 {
+            Mtu::Mtu256 => Some(clap::builder::PossibleValue::new("256")),
+            Mtu::Mtu512 => Some(clap::builder::PossibleValue::new("512")),
+            Mtu::Mtu1024 => Some(clap::builder::PossibleValue::new("1024")),
+            Mtu::Mtu2048 => Some(clap::builder::PossibleValue::new("2048")),
+            Mtu::Mtu4096 => Some(clap::builder::PossibleValue::new("4096")),
+        }
+    }
+}
+
+struct PingPongContext {
+    ctx: Arc<DeviceContext>,
+    pd: Arc<ProtectionDomain>,
+    _send_buf: Arc<Vec<u8>>,
+    send_mr: Arc<MemoryRegion>,
+    _recv_buf: Arc<Vec<u8>>,
+    recv_mr: Arc<MemoryRegion>,
+    cq: Arc<ExtendedCompletionQueue>,
+    qp: ExtendedQueuePair,
+    size: u32,
+    completion_timestamp_mask: u64,
+}
+
+impl PingPongContext {
+    fn build(device: &Device, size: u32, rx_depth: u32, ib_port: u8, use_ts: bool) -> Result<Self> {
+        let context: Arc<DeviceContext> = device
+            .open()
+            .with_context(|| format!("Couldn't get context for {}", device.name()))?;
+
+        let attr = context.query_device().with_context(|| "Failed to query device")?;
+
+        let completion_timestamp_mask = if use_ts {
+            match attr.completion_timestamp_mask() {
+                0 => anyhow::bail!("The device isn't completion timestamp capable"),
+                mask => mask,
+            }
+        } else {
+            0
+        };
+
+        let pd = context.alloc_pd().context("Failed to allocate PD")?;
+
+        let send_buf = Arc::new(vec![0; size as usize]);
+        let send_mr = unsafe {
+            pd.reg_mr(
+                send_buf.as_ptr() as usize,
+                send_buf.len(),
+                AccessFlags::LocalWrite | AccessFlags::RemoteWrite,
+            )
+            .context("Failed to register send MR")?
+        };
+
+        let recv_buf = Arc::new(vec![0; size as usize + GRH_HEADER_LEN]);
+        let recv_mr = unsafe {
+            pd.reg_mr(
+                recv_buf.as_ptr() as usize,
+                recv_buf.len(),
+                AccessFlags::LocalWrite | AccessFlags::RemoteWrite,
+            )
+            .context("Failed to register recv MR")?
+        };
+
+        let mut cq_builder = context.create_cq_builder();
+        if use_ts {
+            cq_builder.setup_wc_flags(
+                CreateCompletionQueueWorkCompletionFlags::StandardFlags
+                    | CreateCompletionQueueWorkCompletionFlags::CompletionTimestamp,
+            );
+        }
+        let cq = cq_builder.setup_cqe(rx_depth + 1).build_ex()?;
+
+        let cq_for_qp = GenericCompletionQueue::from(Arc::clone(&cq));
+
+        let mut builder = pd.create_qp_builder();
+
+        let mut qp = builder
+            .setup_qp_type(QueuePairType::UnreliableDatagram)
+            .setup_max_inline_data(0)
+            .setup_send_cq(cq_for_qp.clone())
+            .setup_recv_cq(cq_for_qp)
+            .setup_max_send_wr(1)
+            .setup_max_recv_wr(rx_depth)
+            .setup_send_ops_flags(SendOperationFlags::Send | SendOperationFlags::SendWithImmediate)
+            .build_ex()
+            .context("Failed to create QP")?;
+
+        let mut attr = QueuePairAttribute::new();
+        attr.setup_state(QueuePairState::Init)
+            .setup_pkey_index(0)
+            .setup_port(ib_port)
+            .setup_qkey(DEFAULT_QKEY);
+        qp.modify(&attr).context("Failed to modify QP to INIT")?;
+
+        let mut attr = QueuePairAttribute::new();
+        attr.setup_state(QueuePairState::ReadyToReceive);
+        qp.modify(&attr).context("Failed to modify QP to RTR")?;
+
+        let mut attr = QueuePairAttribute::new();
+        attr.setup_state(QueuePairState::ReadyToSend).setup_sq_psn(1);
+        qp.modify(&attr).context("Failed to modify QP to RTS")?;
+
+        Ok(PingPongContext {
+            ctx: context,
+            pd,
+            _send_buf: send_buf,
+            send_mr,
+            _recv_buf: recv_buf,
+            recv_mr,
+            cq,
+            qp,
+            size,
+            completion_timestamp_mask,
+        })
+    }
+
+    fn post_recv(&mut self, num: u32) -> Result<()> {
+        for _ in 0..num {
+            let mut guard = self.qp.start_post_recv();
+            let lkey = self.recv_mr.lkey();
+            let ptr = self.recv_mr.get_ptr() as u64;
+            let size = self.size + GRH_HEADER_LEN as u32;
+
+            let recv_handle = guard.construct_wr(RECV_WR_ID);
+
+            unsafe {
+                recv_handle.setup_sge(lkey, ptr, size);
+            };
+
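+            // The SGE above spans `size + GRH_HEADER_LEN` bytes because the HCA scatters
+            // a 40-byte Global Routing Header ahead of every UD payload it delivers.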
+            guard.post()?;
+        }
+
+        Ok(())
+    }
+
+    fn create_address_handle(
+        &self, remote_context: &PingPongDestination, ib_port: u8, sl: u8, gid_idx: u8,
+    ) -> Result<AddressHandle> {
+        let mut ah_attr = AddressHandleAttribute::new();
+        ah_attr
+            .setup_dest_lid(remote_context.lid)
+            .setup_port(ib_port)
+            .setup_service_level(sl)
+            .setup_grh_src_gid_index(gid_idx)
+            .setup_grh_dest_gid(&remote_context.gid)
+            .setup_grh_hop_limit(1);
+
+        self.pd
+            .create_ah(&mut ah_attr)
+            .context("Failed to create address handle")
+    }
+
+    fn post_send(&mut self, ah: &AddressHandle, remote_context: &PingPongDestination) -> Result<()> {
+        let mut guard = self.qp.start_post_send();
+
+        let send_handle = guard
+            .construct_wr(SEND_WR_ID, WorkRequestFlags::Signaled)
+            .setup_ud_addr(ah, remote_context.qp_number, remote_context.qkey)
+            .setup_send();
+        unsafe { send_handle.setup_sge(self.send_mr.lkey(), self.send_mr.get_ptr() as u64, self.size) };
+
+        guard.post()?;
+
+        Ok(())
+    }
+
+    #[inline]
+    fn parse_single_work_completion(
+        &self, wc: &ExtendedWorkCompletion, ts_param: &mut TimeStamps, scnt: &mut u32, rcnt: &mut u32,
+        outstanding_send: &mut bool, rout: &mut u32, rx_depth: u32, need_post_recv: &mut bool, to_post_recv: &mut u32,
+        use_ts: bool, debug_grh: bool,
+    ) {
+        if wc.status() != WorkCompletionStatus::Success as u32 {
+            panic!(
+                "Failed status {:#?} ({}) for wr_id {}",
+                Into::<WorkCompletionStatus>::into(wc.status()),
+                wc.status(),
+                wc.wr_id()
+            );
+        }
+        match wc.wr_id() {
+            SEND_WR_ID => {
+                *scnt += 1;
+                *outstanding_send = false;
+            },
+            RECV_WR_ID => {
+                *rcnt += 1;
+                *rout -= 1;
+
+                // Print GRH if debug mode is enabled
+                if debug_grh {
+                    let recv_buf = unsafe {
+                        std::slice::from_raw_parts(
+                            self.recv_mr.get_ptr() as *const u8,
+                            self.size as usize + GRH_HEADER_LEN,
+                        )
+                    };
+                    match GlobalRoutingHeader::new_checked(recv_buf) {
+                        Ok(grh) => {
+                            println!(
+                                "[recv #{}] GRH: version={}, traffic_class={:#04x}, flow_label={:#07x}, \
+                                 payload_len={}, next_hdr={:#04x}, hop_limit={}, src_gid={}, dst_gid={}",
+                                *rcnt,
+                                grh.version(),
+                                grh.traffic_class(),
+                                grh.flow_label(),
+                                grh.payload_length(),
+                                grh.next_header(),
+                                grh.hop_limit(),
+                                grh.source_gid(),
+                                grh.destination_gid()
+                            );
+                        },
+                        Err(e) => {
+                            eprintln!("[recv #{}] Failed to parse GRH: {}", *rcnt, e);
+                        },
+                    }
+                }
+
+                if *rout <= rx_depth / 2 {
+                    *to_post_recv = rx_depth - *rout;
+                    *need_post_recv = true;
+                }
+
+                if use_ts {
+                    let timestamp = wc.completion_timestamp();
+                    if ts_param.last_completion_with_timestamp != 0 {
+                        let delta: u64 = if timestamp >= ts_param.completion_recv_prev_time {
+                            timestamp - ts_param.completion_recv_prev_time
+                        } else {
+                            self.completion_timestamp_mask - ts_param.completion_recv_prev_time + timestamp + 1
+                        };
+
+                        ts_param.completion_recv_max_time_delta = ts_param.completion_recv_max_time_delta.max(delta);
+                        ts_param.completion_recv_min_time_delta = ts_param.completion_recv_min_time_delta.min(delta);
+                        ts_param.completion_recv_total_time_delta += delta;
+                        ts_param.completion_with_time_iters += 1;
+                    }
+
+                    ts_param.completion_recv_prev_time = timestamp;
+                    ts_param.last_completion_with_timestamp = 1;
+                } else {
+                    ts_param.last_completion_with_timestamp = 0;
+                }
+            },
+            _ => {
+                panic!("Unknown wr_id {}", wc.wr_id());
+            },
+        }
+    }
+}
+
+#[derive(Deserialize, Serialize, Debug)]
+struct PingPongDestination {
+    lid: u16,
+    qp_number: u32,
+    qkey: u32,
+    gid: Gid,
+}
+
+#[derive(Debug, Default)]
+struct TimeStamps {
+    completion_recv_max_time_delta: u64,
+    completion_recv_min_time_delta: u64,
+    completion_recv_total_time_delta: u64,
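+    /// Raw device timestamp of the previous receive completion, used to compute per-receive deltas.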
+    completion_recv_prev_time: u64,
+    last_completion_with_timestamp: u32,
+    completion_with_time_iters: u32,
+}
+
+#[allow(clippy::while_let_on_iterator)]
+fn main() -> Result<()> {
+    let args = Args::parse();
+    let mut scnt: u32 = 0;
+    let mut rcnt: u32 = 0;
+    let mut rout: u32 = 0;
+    let rx_depth = if args.iter > args.rx_depth {
+        args.rx_depth
+    } else {
+        args.iter
+    };
+    let mut ts_param = TimeStamps {
+        completion_recv_min_time_delta: u64::MAX,
+        ..Default::default()
+    };
+
+    let device_list = DeviceList::new().with_context(|| "Failed to get IB devices list")?;
+    let device = match args.ib_dev {
+        Some(ib_dev) => device_list
+            .iter()
+            .find(|dev| dev.name().eq(&ib_dev))
+            .with_context(|| format!("IB device {ib_dev} not found"))?,
+        None => device_list.iter().next().with_context(|| "No IB device found")?,
+    };
+
+    let mut ctx = PingPongContext::build(&device, args.size, rx_depth, args.ib_port, args.ts)?;
+
+    let gid = ctx.ctx.query_gid(args.ib_port, args.gid_idx.into())?;
+    let port_attr = ctx.ctx.query_port(args.ib_port)?;
+    let lid = port_attr.lid();
+
+    ctx.post_recv(rx_depth)?;
+    rout += rx_depth;
+
+    println!(
+        " local address: QPN {:#06x}, QKey {DEFAULT_QKEY:#010x}, LID {lid:#06x}, GID {gid}",
+        ctx.qp.qp_number()
+    );
+
+    let mut stream = match args.server_ip {
+        Some(ref ip_str) => {
+            let ip = IpAddr::from_str(ip_str).context("Invalid IP address")?;
+            let server_addr = SocketAddr::from((ip, args.port));
+            TcpStream::connect(server_addr)?
+        },
+        None => {
+            let server_addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, args.port));
+            let listener = TcpListener::bind(server_addr)?;
+            let (stream, _peer_addr) = listener.accept()?;
+            stream
+        },
+    };
+
+    let send_context = |stream: &mut TcpStream, dest: &PingPongDestination| {
+        let msg_buf = to_allocvec(dest).unwrap();
+        let size = msg_buf.len().to_be_bytes();
+        stream.write_all(&size)?;
+        stream.write_all(&msg_buf)?;
+        stream.flush()?;
+
+        Ok::<(), Error>(())
+    };
+
+    let recv_context = |stream: &mut TcpStream, msg_buf: &mut Vec<u8>| {
+        let mut size = usize::to_be_bytes(0);
+        stream.read_exact(&mut size)?;
+        msg_buf.clear();
+        msg_buf.resize(usize::from_be_bytes(size), 0);
+        stream.read_exact(&mut *msg_buf)?;
+        let dest: PingPongDestination = from_bytes(msg_buf).unwrap();
+
+        Ok::<PingPongDestination, Error>(dest)
+    };
+
+    let local_context = PingPongDestination {
+        lid,
+        qp_number: ctx.qp.qp_number(),
+        qkey: DEFAULT_QKEY,
+        gid,
+    };
+    let mut msg_buf = Vec::new();
+    send_context(&mut stream, &local_context)?;
+    let remote_context = recv_context(&mut stream, &mut msg_buf)?;
+    println!(
+        "remote address: QPN {:#06x}, QKey {:#010x}, LID {:#06x}, GID {}",
+        remote_context.qp_number, remote_context.qkey, remote_context.lid, remote_context.gid
+    );
+
+    let remote_ah = ctx.create_address_handle(&remote_context, args.ib_port, args.sl, args.gid_idx)?;
+
+    let clock = quanta::Clock::new();
+    let start_time = clock.now();
+    let mut outstanding_send = false;
+
+    if args.server_ip.is_some() {
+        ctx.post_send(&remote_ah, &remote_context)?;
+        outstanding_send = true;
+    }
+    // poll for the completion
+    loop {
+        let mut need_post_recv = false;
+        let mut to_post_recv = 0;
+        let mut need_post_send = false;
+
+        match ctx.cq.start_poll() {
+            Ok(mut poller) => {
+                while let Some(wc) = poller.next() {
+                    ctx.parse_single_work_completion(
+                        &wc,
+                        &mut ts_param,
+                        &mut scnt,
+                        &mut rcnt,
+                        &mut outstanding_send,
+                        &mut rout,
+                        rx_depth,
+                        &mut need_post_recv,
+                        &mut to_post_recv,
+                        args.ts,
+                        args.debug_grh,
+                    );
+
+                    if scnt < args.iter && !outstanding_send {
+                        need_post_send = true;
+                        outstanding_send = true;
+                    }
+                }
+            },
+            Err(_) => {
+                continue;
+            },
+        }
+
+        if need_post_recv {
+            ctx.post_recv(to_post_recv)?;
+            rout += to_post_recv;
+        }
+
+        if need_post_send {
+            ctx.post_send(&remote_ah, &remote_context)?;
+        }
+
+        if scnt >= args.iter && rcnt >= args.iter {
+            break;
+        }
+    }
+
+    let end_time = clock.now();
+    let time = end_time.duration_since(start_time);
+    let bytes = args.size as u64 * args.iter as u64 * 2;
+    let bytes_per_second = bytes as f64 / time.as_secs_f64();
+    println!(
+        "{} bytes in {:.2} seconds = {:.2}/s",
+        bytes,
+        time.as_secs_f64(),
+        Byte::from_f64(bytes_per_second)
+            .unwrap()
+            .get_appropriate_unit(UnitType::Binary)
+    );
+    println!(
+        "{} iters in {:.2} seconds = {:#.2?}/iter",
+        args.iter,
+        time.as_secs_f64(),
+        time / args.iter
+    );
+
+    if args.ts && ts_param.completion_with_time_iters != 0 {
+        println!(
+            "Max receive completion clock cycles = {}",
+            ts_param.completion_recv_max_time_delta
+        );
+        println!(
+            "Min receive completion clock cycles = {}",
+            ts_param.completion_recv_min_time_delta
+        );
+        println!(
+            "Average receive completion clock cycles = {}",
+            ts_param.completion_recv_total_time_delta as f64 / ts_param.completion_with_time_iters as f64
+        );
+    }
+
+    Ok(())
+}
diff --git a/src/ibverbs/address.rs b/src/ibverbs/address.rs
index e7a1844..a319daa 100644
--- a/src/ibverbs/address.rs
+++ b/src/ibverbs/address.rs
@@ -2,13 +2,42 @@
 //! in RDMA communication, like [`Gid`] and [`AddressHandleAttribute`].
 use libc::IF_NAMESIZE;
 use rdma_mummy_sys::{
-    ibv_ah_attr, ibv_gid, ibv_gid_entry, ibv_global_route, IBV_GID_TYPE_IB, IBV_GID_TYPE_ROCE_V1, IBV_GID_TYPE_ROCE_V2,
+    ibv_ah, ibv_ah_attr, ibv_create_ah, ibv_gid, ibv_gid_entry, ibv_global_route, ibv_grh, IBV_GID_TYPE_IB,
+    IBV_GID_TYPE_ROCE_V1, IBV_GID_TYPE_ROCE_V2,
 };
 use serde::{Deserialize, Serialize};
 use std::ffi::CStr;
 use std::io;
+use std::ptr;
+use std::ptr::NonNull;
+use std::sync::Arc;
 use std::{fmt, mem::MaybeUninit, net::Ipv6Addr};
 
+use super::protection_domain::ProtectionDomain;
+
+/// Error returned by [`ProtectionDomain::create_ah`] for creating a new RDMA Address Handle.
+#[derive(Debug, thiserror::Error)]
+#[error("failed to create address handle")]
+#[non_exhaustive]
+pub struct CreateAddressHandleError(#[from] pub CreateAddressHandleErrorKind);
+
+/// The enum type for [`CreateAddressHandleError`].
+#[derive(Debug, thiserror::Error)]
+#[error(transparent)]
+#[non_exhaustive]
+pub enum CreateAddressHandleErrorKind {
+    Ibverbs(#[from] io::Error),
+}
+
+/// Error returned when constructing a [`GlobalRoutingHeader`] from invalid input.
+#[derive(Debug, thiserror::Error)]
+#[non_exhaustive]
+pub enum GlobalRoutingHeaderError {
+    /// The provided slice is smaller than the GRH header size.
+    #[error("raw slice length {actual} is smaller than required {expected} bytes")]
+    SliceTooSmall { actual: usize, expected: usize },
+}
+
 /// GID is a global identifier for sending packets between different subnets. For RoCEv1 and RoCEv2,
 /// it would correspond to an IP address set up on the ethernet device.
 #[derive(Default, Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
@@ -251,14 +280,256 @@
     }
 }
 
+pub struct AddressHandle {
+    pub(crate) handle: NonNull<ibv_ah>,
+    pub(crate) _pd: Arc<ProtectionDomain>,
+}
+
+impl AddressHandle {
+    pub fn new(pd: Arc<ProtectionDomain>, attr: &mut AddressHandleAttribute) -> Result<Self, CreateAddressHandleError> {
+        let pd_ptr = unsafe { pd.pd() };
+        let ah = unsafe { ibv_create_ah(pd_ptr.as_ptr(), &mut attr.attr as *mut _) };
+
+        Ok(AddressHandle {
+            handle: NonNull::new(ah).ok_or::<CreateAddressHandleError>(
+                CreateAddressHandleErrorKind::Ibverbs(io::Error::last_os_error()).into(),
+            )?,
+            _pd: pd,
+        })
+    }
+
+    /// # Safety
+    ///
+    /// Return the handle of address handle.
+    /// We mark this method unsafe because the lifetime of `ibv_ah` is not associated
+    /// with the return value.
+    pub unsafe fn ah(&self) -> NonNull<ibv_ah> {
+        self.handle
+    }
+}
+
+/// Size in bytes of a Global Routing Header.
+pub const GRH_HEADER_LEN: usize = 40;
+
+/// A read/write wrapper around a Global Routing Header buffer.
+///
+/// This generic wrapper allows zero-copy access to GRH fields directly from a byte buffer,
+/// following the smoltcp packet wrapper pattern.
+#[derive(Debug)]
+pub struct GlobalRoutingHeader<T: AsRef<[u8]>> {
+    buffer: T,
+}
+
+// Field byte offsets within the GRH
+mod grh_field {
+    use std::ops::Range;
+
+    pub const VERSION_TCLASS_FLOW: Range<usize> = 0..4;
+    pub const PAYLOAD_LENGTH: Range<usize> = 4..6;
+    pub const NEXT_HEADER: usize = 6;
+    pub const HOP_LIMIT: usize = 7;
+    pub const SOURCE_GID: Range<usize> = 8..24;
+    pub const DESTINATION_GID: Range<usize> = 24..40;
+}
+
+impl<T: AsRef<[u8]>> GlobalRoutingHeader<T> {
+    /// Create a raw octet buffer with a Global Routing Header structure.
+    ///
+    /// # Panics
+    ///
+    /// Accessor methods will panic if the buffer is smaller than [`GRH_HEADER_LEN`].
+    /// Use [`new_checked`](Self::new_checked) if the buffer length is not guaranteed.
+    pub const fn new_unchecked(buffer: T) -> Self {
+        GlobalRoutingHeader { buffer }
+    }
+
+    /// Shorthand for a combination of [`new_unchecked`](Self::new_unchecked) and
+    /// [`check_len`](Self::check_len).
+    pub fn new_checked(buffer: T) -> Result<Self, GlobalRoutingHeaderError> {
+        let header = Self::new_unchecked(buffer);
+        header.check_len()?;
+        Ok(header)
+    }
+
+    /// Ensure that no accessor method will panic if called.
+    /// Returns `Err` if the buffer is too short.
+    pub fn check_len(&self) -> Result<(), GlobalRoutingHeaderError> {
+        let len = self.buffer.as_ref().len();
+        if len < GRH_HEADER_LEN {
+            return Err(GlobalRoutingHeaderError::SliceTooSmall {
+                actual: len,
+                expected: GRH_HEADER_LEN,
+            });
+        }
+        Ok(())
+    }
+
+    /// Consume the header wrapper, returning the underlying buffer.
+    pub fn into_inner(self) -> T {
+        self.buffer
+    }
+
+    /// Return a reference to the underlying buffer.
+    pub fn buffer(&self) -> &[u8] {
+        self.buffer.as_ref()
+    }
+
+    /// Return a const pointer to the header buffer, suitable for FFI.
+    ///
+    /// The pointer is valid for the lifetime of the underlying buffer.
+    #[inline]
+    pub fn as_ptr(&self) -> *const u8 {
+        self.buffer.as_ref().as_ptr()
+    }
+
+    /// Copy the header bytes into a newly allocated `ibv_grh` structure.
+    ///
+    /// This is useful for FFI calls that require an owned `ibv_grh`.
+    #[inline]
+    pub fn grh(&self) -> ibv_grh {
+        unsafe { ptr::read_unaligned(self.buffer.as_ref().as_ptr().cast::<ibv_grh>()) }
+    }
+
+    /// Get the GRH version number (4 bits).
+    #[inline]
+    pub fn version(&self) -> u8 {
+        (self.buffer.as_ref()[grh_field::VERSION_TCLASS_FLOW.start] >> 4) & 0x0f
+    }
+
+    /// Get the IPv6 traffic class encoded in the header (8 bits).
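+    /// The field straddles bytes 0 and 1: the low nibble of byte 0 carries the upper
+    /// four bits and the high nibble of byte 1 carries the lower four bits.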
+    #[inline]
+    pub fn traffic_class(&self) -> u8 {
+        let data = &self.buffer.as_ref()[grh_field::VERSION_TCLASS_FLOW];
+        ((data[0] & 0x0f) << 4) | ((data[1] >> 4) & 0x0f)
+    }
+
+    /// Get the IPv6-style flow label carried in the header (20 bits).
+    #[inline]
+    pub fn flow_label(&self) -> u32 {
+        let data = &self.buffer.as_ref()[grh_field::VERSION_TCLASS_FLOW];
+        u32::from_be_bytes([0, data[1] & 0x0f, data[2], data[3]])
+    }
+
+    /// Payload length in bytes following the GRH.
+    #[inline]
+    pub fn payload_length(&self) -> u16 {
+        let data = &self.buffer.as_ref()[grh_field::PAYLOAD_LENGTH];
+        u16::from_be_bytes([data[0], data[1]])
+    }
+
+    /// Next header identifier.
+    #[inline]
+    pub fn next_header(&self) -> u8 {
+        self.buffer.as_ref()[grh_field::NEXT_HEADER]
+    }
+
+    /// Hop limit value for the packet.
+    #[inline]
+    pub fn hop_limit(&self) -> u8 {
+        self.buffer.as_ref()[grh_field::HOP_LIMIT]
+    }
+
+    /// Source GID extracted from the header.
+    #[inline]
+    pub fn source_gid(&self) -> Gid {
+        let mut raw = [0u8; 16];
+        raw.copy_from_slice(&self.buffer.as_ref()[grh_field::SOURCE_GID]);
+        Gid { raw }
+    }
+
+    /// Destination GID extracted from the header.
+    #[inline]
+    pub fn destination_gid(&self) -> Gid {
+        let mut raw = [0u8; 16];
+        raw.copy_from_slice(&self.buffer.as_ref()[grh_field::DESTINATION_GID]);
+        Gid { raw }
+    }
+}
+
+impl<T: AsRef<[u8]> + AsMut<[u8]>> GlobalRoutingHeader<T> {
+    /// Return a mutable reference to the underlying buffer.
+    pub fn buffer_mut(&mut self) -> &mut [u8] {
+        self.buffer.as_mut()
+    }
+
+    /// Return a mutable pointer to the header buffer, suitable for FFI.
+    ///
+    /// The pointer is valid for the lifetime of the underlying buffer.
+    #[inline]
+    pub fn as_mut_ptr(&mut self) -> *mut u8 {
+        self.buffer.as_mut().as_mut_ptr()
+    }
+
+    /// Set the GRH version number (4 bits).
+    #[inline]
+    pub fn setup_version(&mut self, value: u8) {
+        let data = &mut self.buffer.as_mut()[grh_field::VERSION_TCLASS_FLOW.start];
+        *data = (*data & 0x0f) | ((value & 0x0f) << 4);
+    }
+
+    /// Set the IPv6 traffic class (8 bits).
+    #[inline]
+    pub fn setup_traffic_class(&mut self, value: u8) {
+        let data = &mut self.buffer.as_mut()[grh_field::VERSION_TCLASS_FLOW];
+        data[0] = (data[0] & 0xf0) | ((value >> 4) & 0x0f);
+        data[1] = (data[1] & 0x0f) | ((value & 0x0f) << 4);
+    }
+
+    /// Set the IPv6-style flow label (20 bits).
+    #[inline]
+    pub fn setup_flow_label(&mut self, value: u32) {
+        let data = &mut self.buffer.as_mut()[grh_field::VERSION_TCLASS_FLOW];
+        let bytes = value.to_be_bytes();
+        data[1] = (data[1] & 0xf0) | (bytes[1] & 0x0f);
+        data[2] = bytes[2];
+        data[3] = bytes[3];
+    }
+
+    /// Set the payload length.
+    #[inline]
+    pub fn setup_payload_length(&mut self, value: u16) {
+        let bytes = value.to_be_bytes();
+        self.buffer.as_mut()[grh_field::PAYLOAD_LENGTH].copy_from_slice(&bytes);
+    }
+
+    /// Set the next header identifier.
+    #[inline]
+    pub fn setup_next_header(&mut self, value: u8) {
+        self.buffer.as_mut()[grh_field::NEXT_HEADER] = value;
+    }
+
+    /// Set the hop limit.
+    #[inline]
+    pub fn setup_hop_limit(&mut self, value: u8) {
+        self.buffer.as_mut()[grh_field::HOP_LIMIT] = value;
+    }
+
+    /// Set the source GID.
+    #[inline]
+    pub fn setup_source_gid(&mut self, gid: &Gid) {
+        self.buffer.as_mut()[grh_field::SOURCE_GID].copy_from_slice(&gid.raw);
+    }
+
+    /// Set the destination GID.
+    #[inline]
+    pub fn setup_destination_gid(&mut self, gid: &Gid) {
+        self.buffer.as_mut()[grh_field::DESTINATION_GID].copy_from_slice(&gid.raw);
+    }
+}
+
 #[cfg(test)]
 mod tests {
-    use crate::ibverbs::address::Gid;
-    use rdma_mummy_sys::ibv_gid;
+    use super::{Gid, GlobalRoutingHeader, GlobalRoutingHeaderError, GRH_HEADER_LEN};
+    use rdma_mummy_sys::{ibv_gid, ibv_grh};
     use rstest::rstest;
     use std::net::Ipv6Addr;
     use std::str::FromStr;
 
+    #[test]
+    fn test_grh_header_len_matches_ibv_grh_size() {
+        assert_eq!(GRH_HEADER_LEN, std::mem::size_of::<ibv_grh>());
+    }
+
     #[rstest]
     #[case("fe80::", true)]
     #[case("fe80::1", true)]
@@ -280,4 +551,105 @@
         let gid = Gid::from(gid_);
         assert_eq!(format!("{gid}"), expected);
     }
+
+    #[test]
+    fn test_global_routing_header_read() {
+        let src_gid = Gid::from(Ipv6Addr::from_str("fe80::1").unwrap());
+        let dst_gid = Gid::from(Ipv6Addr::from_str("fe80::2").unwrap());
+
+        // Build raw bytes: version=6, traffic_class=0x12, flow_label=0x34567
+        let mut raw = vec![0u8; GRH_HEADER_LEN + 4]; // 4 bytes payload
+        raw[0] = 0x61; // version=6, traffic_class high nibble=1
+        raw[1] = 0x23; // traffic_class low nibble=2, flow_label high nibble=3
+        raw[2] = 0x45; // flow_label
+        raw[3] = 0x67; // flow_label
+        raw[4] = 0x00; // payload length high byte
+        raw[5] = 0x04; // payload length low byte = 4
+        raw[6] = 0x1b; // next header
+        raw[7] = 0x40; // hop limit
+        raw[8..24].copy_from_slice(&src_gid.raw);
+        raw[24..40].copy_from_slice(&dst_gid.raw);
+        raw[40..44].copy_from_slice(&[0xde, 0xad, 0xbe, 0xef]); // payload
+
+        let header = GlobalRoutingHeader::new_checked(raw.as_slice()).unwrap();
+        assert_eq!(header.version(), 6);
+        assert_eq!(header.traffic_class(), 0x12);
+        assert_eq!(header.flow_label(), 0x34567);
+        assert_eq!(header.payload_length(), 4);
+        assert_eq!(header.next_header(), 0x1b);
+        assert_eq!(header.hop_limit(), 0x40);
+        assert_eq!(header.source_gid(), src_gid);
+        assert_eq!(header.destination_gid(), dst_gid);
+    }
+
+    #[test]
+    fn test_global_routing_header_write() {
+        let src_gid = Gid::from(Ipv6Addr::from_str("fe80::1").unwrap());
+        let dst_gid = Gid::from(Ipv6Addr::from_str("fe80::2").unwrap());
+
+        let mut raw = vec![0u8; GRH_HEADER_LEN];
+        let mut header = GlobalRoutingHeader::new_checked(raw.as_mut_slice()).unwrap();
+
+        header.setup_version(6);
+        header.setup_traffic_class(0x12);
+        header.setup_flow_label(0x34567);
+        header.setup_payload_length(100);
+        header.setup_next_header(0x1b);
+        header.setup_hop_limit(64);
+        header.setup_source_gid(&src_gid);
+        header.setup_destination_gid(&dst_gid);
+
+        // Verify by reading back
+        let header = GlobalRoutingHeader::new_unchecked(raw.as_slice());
+        assert_eq!(header.version(), 6);
+        assert_eq!(header.traffic_class(), 0x12);
+        assert_eq!(header.flow_label(), 0x34567);
+        assert_eq!(header.payload_length(), 100);
+        assert_eq!(header.next_header(), 0x1b);
+        assert_eq!(header.hop_limit(), 64);
+        assert_eq!(header.source_gid(), src_gid);
+        assert_eq!(header.destination_gid(), dst_gid);
+    }
+
+    #[test]
+    fn test_global_routing_header_check_len_error() {
+        let raw = vec![0u8; GRH_HEADER_LEN - 1];
+        let err = GlobalRoutingHeader::new_checked(raw.as_slice()).unwrap_err();
+
+        match err {
+            GlobalRoutingHeaderError::SliceTooSmall { actual, expected } => {
+                assert_eq!(actual, GRH_HEADER_LEN - 1);
+                assert_eq!(expected, GRH_HEADER_LEN);
+            },
+        }
+    }
+
+    #[test]
+    fn test_global_routing_header_to_ibv_grh() {
+        let src_gid = Gid::from(Ipv6Addr::from_str("fe80::1").unwrap());
+        let dst_gid = Gid::from(Ipv6Addr::from_str("fe80::2").unwrap());
+
+        let mut raw = vec![0u8; GRH_HEADER_LEN];
+        let mut header = GlobalRoutingHeader::new_unchecked(raw.as_mut_slice());
+        header.setup_version(6);
+        header.setup_traffic_class(0);
+        header.setup_flow_label(1);
+        header.setup_payload_length(0x1234);
+        header.setup_next_header(0x1b);
+        header.setup_hop_limit(0x40);
+        header.setup_source_gid(&src_gid);
+        header.setup_destination_gid(&dst_gid);
+
+        let header = GlobalRoutingHeader::new_unchecked(raw.as_slice());
+        let grh = header.grh();
+
+        // Verify the ibv_grh fields
+        assert_eq!(u32::from_be(grh.version_tclass_flow) >> 28, 6); // version
+        assert_eq!(u32::from_be(grh.version_tclass_flow) & 0x000f_ffff, 1); // flow_label
+        assert_eq!(u16::from_be(grh.paylen), 0x1234);
+        assert_eq!(grh.next_hdr, 0x1b);
+        assert_eq!(grh.hop_limit, 0x40);
+        assert_eq!(Gid::from(grh.sgid), src_gid);
+        assert_eq!(Gid::from(grh.dgid), dst_gid);
+    }
 }
diff --git a/src/ibverbs/device_context.rs b/src/ibverbs/device_context.rs
index f1c8802..8e7c307 100644
--- a/src/ibverbs/device_context.rs
+++ b/src/ibverbs/device_context.rs
@@ -404,6 +404,11 @@ impl PortAttr {
         self.attr.phys_state.into()
     }
 
+    /// Get the local identifier (LID) assigned to this port.
+    pub fn lid(&self) -> u16 {
+        self.attr.lid
+    }
+
     /// Get the active link width of this port.
     pub fn active_width(&self) -> PortWidth {
         self.attr.active_width.into()
diff --git a/src/ibverbs/protection_domain.rs b/src/ibverbs/protection_domain.rs
index cc16671..f897e73 100644
--- a/src/ibverbs/protection_domain.rs
+++ b/src/ibverbs/protection_domain.rs
@@ -8,6 +8,7 @@ use std::ptr::NonNull;
 use std::sync::Arc;
 
 use super::{
+    address::{AddressHandle, AddressHandleAttribute, CreateAddressHandleError},
     device_context::DeviceContext,
     memory_region::{MemoryRegion, RegisterMemoryRegionError},
     queue_pair::QueuePairBuilder,
@@ -54,6 +55,13 @@ impl ProtectionDomain {
         Ok(Arc::new(MemoryRegion::reg_mr(Arc::clone(self), ptr, len, access)?))
     }
 
+    /// Create a new address handle on this protection domain.
+    pub fn create_ah(
+        self: &Arc<Self>, attr: &mut AddressHandleAttribute,
+    ) -> Result<AddressHandle, CreateAddressHandleError> {
+        AddressHandle::new(Arc::clone(self), attr)
+    }
+
     /// Create a [`QueuePairBuilder`] for building QPs on this protection domain
     /// later.
     pub fn create_qp_builder(self: &Arc<Self>) -> QueuePairBuilder {
diff --git a/src/ibverbs/queue_pair.rs b/src/ibverbs/queue_pair.rs
index b36e78b..55a0aa7 100644
--- a/src/ibverbs/queue_pair.rs
+++ b/src/ibverbs/queue_pair.rs
@@ -7,7 +7,7 @@ use rdma_mummy_sys::{
     ibv_qp_init_attr_ex, ibv_qp_init_attr_mask, ibv_qp_state, ibv_qp_to_qp_ex, ibv_qp_type, ibv_query_qp, ibv_recv_wr,
     ibv_rx_hash_conf, ibv_send_flags, ibv_send_wr, ibv_sge, ibv_wr_abort, ibv_wr_complete, ibv_wr_opcode,
     ibv_wr_rdma_read, ibv_wr_rdma_write, ibv_wr_rdma_write_imm, ibv_wr_send, ibv_wr_send_imm, ibv_wr_set_inline_data,
-    ibv_wr_set_inline_data_list, ibv_wr_set_sge, ibv_wr_set_sge_list, ibv_wr_start,
+    ibv_wr_set_inline_data_list, ibv_wr_set_sge, ibv_wr_set_sge_list, ibv_wr_set_ud_addr, ibv_wr_start,
 };
 use std::sync::{Arc, LazyLock};
 use std::{
@@ -19,7 +19,7 @@ use std::{
 };
 
 use super::{
-    address::{AddressHandleAttribute, Gid},
+    address::{AddressHandle, AddressHandleAttribute, Gid},
     completion::{CompletionQueue, GenericCompletionQueue},
     device_context::Mtu,
     protection_domain::ProtectionDomain,
@@ -407,7 +407,9 @@ pub trait QueuePair {
 mod private_traits {
     use std::io::IoSlice;
 
+    use crate::ibverbs::address::AddressHandle;
     use rdma_mummy_sys::ibv_sge;
+
     // This is the private part of PostSendGuard, which is a workaround for pub trait
     // not being able to have private functions.
     //
@@ -418,6 +420,8 @@ mod private_traits {
 
         fn setup_send_imm(&mut self, imm_data: u32);
 
+        fn setup_ud_addr(&mut self, ah: &AddressHandle, remote_qpn: u32, remote_qkey: u32);
+
         fn setup_write(&mut self, rkey: u32, remote_addr: u64);
 
         fn setup_write_imm(&mut self, rkey: u32, remote_addr: u64, imm_data: u32);
@@ -911,6 +915,10 @@ impl QueuePairBuilder {
 
         let qp = unsafe { ibv_create_qp_ex((*(attr.pd)).context, &mut attr) };
 
+        if qp.is_null() {
+            return Err(CreateQueuePairErrorKind::Ibverbs(io::Error::last_os_error()).into());
+        }
+
         Ok(ExtendedQueuePair {
             qp_ex: NonNull::new(unsafe { ibv_qp_to_qp_ex(qp) })
                 .ok_or::<CreateQueuePairError>(CreateQueuePairErrorKind::Ibverbs(io::Error::last_os_error()).into())?,
@@ -992,11 +1000,23 @@ impl QueuePairAttribute {
         self
     }
 
+    /// Setup the queue key (QKey) for this [`QueuePair`].
+    pub fn setup_qkey(&mut self, qkey: u32) -> &mut Self {
+        self.attr.qkey = qkey;
+        self.attr_mask |= QueuePairAttributeMask::QueueKey;
+        self
+    }
+
     /// Get the primary physical port number you filled in or queried from [`QueuePair::query`].
     pub fn port(&self) -> u8 {
         self.attr.port_num
     }
 
+    /// Get the queue key (QKey) you filled in or queried from [`QueuePair::query`].
+    pub fn qkey(&self) -> u32 {
+        self.attr.qkey
+    }
+
     /// Setup allowed remote operations for incoming packets. It's either 0 or
     /// the bitwise `OR` of one or more of the following flags.
     ///
@@ -1386,6 +1406,12 @@ impl<'g, G: PostSendGuard> WorkRequestHandle<'g, G> {
         LocalBufferHandle { guard: self.guard }
     }
 
+    pub fn setup_ud_addr(self, ah: &AddressHandle, remote_qpn: u32, remote_qkey: u32) -> Self {
+        let WorkRequestHandle { guard } = self;
+        guard.setup_ud_addr(ah, remote_qpn, remote_qkey);
+        WorkRequestHandle { guard }
+    }
+
     pub fn setup_write(self, rkey: u32, remote_addr: u64) -> LocalBufferHandle<'g, G> {
         self.guard.setup_write(rkey, remote_addr);
         LocalBufferHandle { guard: self.guard }
@@ -1481,6 +1507,12 @@ impl private_traits::PostSendGuard for BasicPostSendGuard<'_> {
         self.wrs.last_mut().unwrap().imm_data_invalidated_rkey_union.imm_data = imm_data;
     }
 
+    fn setup_ud_addr(&mut self, ah: &AddressHandle, remote_qpn: u32, remote_qkey: u32) {
+        self.wrs.last_mut().unwrap().wr.ud.ah = unsafe { ah.ah().as_ptr() };
+        self.wrs.last_mut().unwrap().wr.ud.remote_qpn = remote_qpn;
+        self.wrs.last_mut().unwrap().wr.ud.remote_qkey = remote_qkey;
+    }
+
     fn setup_write(&mut self, rkey: u32, remote_addr: u64) {
         self.wrs.last_mut().unwrap().opcode = WorkRequestOperationType::Write as _;
         self.wrs.last_mut().unwrap().wr.rdma.remote_addr = remote_addr;
@@ -1609,6 +1641,12 @@ impl private_traits::PostSendGuard for ExtendedPostSendGuard<'_> {
         unsafe { ibv_wr_send_imm(self.qp_ex.as_ptr(), imm_data) };
     }
 
+    fn setup_ud_addr(&mut self, ah: &AddressHandle, remote_qpn: u32, remote_qkey: u32) {
+        unsafe {
+            ibv_wr_set_ud_addr(self.qp_ex.as_ptr(), ah.ah().as_ptr(), remote_qpn, remote_qkey);
+        }
+    }
+
     fn setup_write(&mut self, rkey: u32, remote_addr: u64) {
         unsafe { ibv_wr_rdma_write(self.qp_ex.as_ptr(), rkey, remote_addr) };
     }
@@ -1824,6 +1862,13 @@ impl private_traits::PostSendGuard for GenericPostSendGuard<'_> {
         }
     }
 
+    fn setup_ud_addr(&mut self, ah: &AddressHandle, remote_qpn: u32, remote_qkey: u32) {
+        match self {
+            GenericPostSendGuard::Basic(guard) => guard.setup_ud_addr(ah, remote_qpn, remote_qkey),
+            GenericPostSendGuard::Extended(guard) => guard.setup_ud_addr(ah, remote_qpn, remote_qkey),
+        }
+    }
+
     fn setup_write(&mut self, rkey: u32, remote_addr: u64) {
         match self {
             GenericPostSendGuard::Basic(guard) => guard.setup_write(rkey, remote_addr),