diff --git a/Cargo.lock b/Cargo.lock index be64bc53..d0347b6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4288,7 +4288,9 @@ dependencies = [ "paste", "pyo3", "zenoh", + "zenoh-buffers", "zenoh-ext", + "zenoh-shm", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ff161af4..21dfbcc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ crate-type = ["cdylib"] [features] default = ["shared-memory", "zenoh-ext", "zenoh/default"] -shared-memory = ["zenoh/shared-memory"] +shared-memory = ["zenoh/shared-memory", "dep:zenoh-shm"] zenoh-ext = ["dep:zenoh-ext", "zenoh-ext/internal", "zenoh-ext/unstable"] [badges] @@ -43,11 +43,13 @@ maintenance = { status = "actively-developed" } [dependencies] paste = "1.0.14" -pyo3 = { version = "0.25.1", features = ["abi3-py39", "extension-module"] } +pyo3 = { version = "0.25.1", features = ["extension-module"] } zenoh = { version = "1.9.0", git = "https://github.com/eclipse-zenoh/zenoh.git", branch = "main", features = [ "internal", "unstable", ], default-features = false } +zenoh-buffers = { version = "1.9.0", git = "https://github.com/eclipse-zenoh/zenoh.git", branch = "main" } zenoh-ext = { version = "1.9.0", git = "https://github.com/eclipse-zenoh/zenoh.git", branch = "main", features = [ "internal", ], optional = true } +zenoh-shm = { version = "1.9.0", git = "https://github.com/eclipse-zenoh/zenoh.git", branch = "main", optional = true } diff --git a/docs/concepts.rst b/docs/concepts.rst index 1dcf474d..c92c1b4a 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -267,6 +267,154 @@ Example: Using :class:`zenoh.ZBytes` :start-after: [raw_data] :end-before: # [raw_data] +Scatter-gather payloads +~~~~~~~~~~~~~~~~~~~~~~~ + +Use :meth:`zenoh.ZBytes.from_segments` to construct a payload from multiple +Python buffer protocol objects without first joining them into one large +``bytes`` value: + +.. code-block:: python + + payload = zenoh.ZBytes.from_segments( + [header, segment_0, segment_1], + copy=True, + ) + publisher.put(payload) + +``copy=True`` copies each input buffer into Zenoh-owned memory while preserving +separate physical slices where possible. C-contiguous, byte-compatible buffers +are copied directly into Zenoh-owned slices. Set ``require_contiguous=False`` to +explicitly allow a non-contiguous input buffer to be flattened and copied. + +``copy=False`` performs strict zero-copy construction for read-only, +C-contiguous, single-byte Python buffer exporters. This includes ``bytes``, +eligible ``memoryview`` objects, and custom exporters such as serialization +library segment views. Cropped memoryviews are supported if they still describe +one contiguous slice. With shared memory enabled, ``shm.ZShm`` segments are +preserved without copying, and ``shm.ZShmMut`` segments are consumed just like +passing them directly to ``ZBytes``. Generic ``memoryview`` objects are treated +as raw borrowed buffers and do not carry shared-memory identity. ``ZBytes`` +retains each exported buffer view until Zenoh no longer references the payload. + +External buffer pools can attach a lease token to a raw borrowed zero-copy +payload. The lease object must provide ``sink`` and ``lease_id`` attributes. +When Zenoh releases its last borrowed-buffer reference, zenoh-python only +notifies the sink by calling ``lease.sink.release(lease.lease_id)``; the +provider decides how to enqueue, deduplicate, or process that release event: + +.. code-block:: python + + class LeaseState: + def __init__(self, sink, lease_id): + self.sink = sink + self.lease_id = lease_id + + payload = zenoh.ZBytes.from_segments( + [header_view, body_view], + copy=False, + lease=LeaseState(pool_release_sink, slot_id), + ) + +The sink's ``release`` method should be non-blocking or return quickly. Shared +memory buffers have their own lifecycle management, so ``lease`` cannot be used +with ``shm.ZShm`` or ``shm.ZShmMut`` segments. + +Pool-owned shared memory payloads +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Use :class:`zenoh.shm.ZShmPool` when an external writer needs a mutable Python +buffer first, and Zenoh should publish the result as a real shared-memory +payload: + +.. code-block:: python + + import zenoh + + pool = zenoh.shm.ZShmPool( + pool_size=256 * 1024 * 1024, + cuda_pinned=True, + cuda_device=0, + alignment=zenoh.shm.AllocAlignment(12), + ) + + buf = pool.alloc(1024) + memoryview(buf)[:] = b"\0" * 1024 + + payload = pool.seal_to_zbytes([buf]) + publisher.put(payload) + +``ZShmPool`` is a generic shared-memory pool. It does not know about any +particular serialization format, and it is not a strict singleton; create +explicit pool instances for workloads with different sizes, alignments, CUDA +pinning requirements, or devices. + +``ZShmPoolBuf`` implements the writable Python buffer protocol until it is +sealed. :meth:`zenoh.shm.ZShmPool.seal_to_zbytes` accepts only buffers allocated +by the same pool, consumes them, and returns :class:`zenoh.ZBytes` preserving the +Zenoh SHM descriptor. Arbitrary ``memoryview`` objects should continue to use +:meth:`zenoh.ZBytes.from_segments`; those are raw borrowed buffers and do not +carry SHM identity. + +By default, :meth:`zenoh.shm.ZShmPool.seal_to_zbytes` refuses to seal a buffer +while active Python buffer exports exist. The explicit +:meth:`zenoh.shm.ZShmPool.seal_to_zbytes_unchecked` path skips that check for +advanced integrations such as torch or capnp, where the exporter object may +outlive the write phase. This is unsafe: before calling it, the application must +guarantee that all CPU and GPU writers have completed and that no existing +``memoryview``, torch tensor, capnp view, or other alias will write to the +buffer afterwards. Otherwise those writes can race with Zenoh reads or sends and +produce torn payload contents. Keep the returned :class:`zenoh.ZBytes` alive +until any pre-existing aliases are released. + +When ``cuda_pinned=True``, zenoh-python loads ``libcuda.so.1`` through the CUDA +Driver API and registers allocated SHM host memory for CUDA-aware writers. When +``cuda_pinned=False``, no CUDA library is loaded. + +The current Zenoh SHM allocation API does not expose the full backing pool +segment base and length to zenoh-python. CUDA pinning therefore uses a +per-driver page registry: overlapping allocations in the same pool share page +registrations and a page is unregistered only after the last owning +``ZShmPoolBuf`` or sealed ``ZBytes`` is dropped. The residual limitation is that +the whole pool is not pinned eagerly as one segment. + +The read-only flag prevents writes through the exported view but cannot prevent +writes through every alias to the same backing memory. After passing segments +to ``copy=False``, the application must treat their backing memory as immutable +until Zenoh no longer references the payload. Writable buffers, non-contiguous +buffers, and buffers whose items are not one byte wide raise ``RuntimeError`` +instead of silently falling back to copying. Use ``copy=True`` for those +buffers. + +On the receiving side, :meth:`zenoh.ZBytes.segments` returns a tuple of +zero-copy :class:`zenoh.ZBytesSegment` views over the payload's physical slices: + +.. code-block:: python + + physical_slices = sample.payload.segments() + payload_bytes = b"".join(map(bytes, physical_slices)) + +The returned segments remain valid after a subscriber callback returns. Each +segment implements the Python buffer protocol, so consumers that accept buffer +objects can use the segments directly or explicitly create ``memoryview`` +objects: + +.. code-block:: python + + views = sample.payload.memoryviews() + first = memoryview(sample.payload.segments()[0]) + +``bytes(segment)`` copies one segment. ``bytes(payload)`` or +``payload.to_bytes()`` copies the whole payload. If you need the previous +copy-out behavior where each returned memoryview is backed by a new Python +``bytes`` object, use :meth:`zenoh.ZBytes.copied_memoryviews`. + +Physical slice boundaries are an internal memory layout optimization. They are +not application-level frames and may differ from sender-side input boundaries +after routing, fragmentation, or shared-memory conversion. Applications that +need stable frames, such as Cap'n Proto segments, must encode segment lengths or +offsets in a payload header and reconstruct logical segments on receipt. + Serialization and deserialization of basic types and structures is provided in the :mod:`zenoh.ext` module via :func:`zenoh.ext.z_serialize` and :func:`zenoh.ext.z_deserialize`. diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 00000000..5d32ffe3 --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,126 @@ +// +// Copyright (c) 2024 ZettaScale Technology +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// +// Contributors: +// ZettaScale Zenoh Team, +// +use std::{ + ffi::CString, + os::raw::{c_int, c_void}, + ptr, +}; + +use pyo3::{exceptions::PyBufferError, ffi, prelude::*}; + +/// Populate `view` as a read-only, single-byte, C-contiguous buffer over +/// `data`, transferring ownership of `owner` into the exported view so the +/// backing storage stays alive while the consumer holds the buffer. +/// +/// # Safety +/// `data` must remain valid for as long as `owner` keeps the buffer alive, and +/// `view` must point to a valid `Py_buffer` provided by the buffer protocol. +pub(crate) unsafe fn fill_readonly_u8_buffer( + owner: Bound<'_, PyAny>, + data: &[u8], + view: *mut ffi::Py_buffer, + flags: c_int, +) -> PyResult<()> { + if view.is_null() { + return Err(PyBufferError::new_err("view is null")); + } + if flags & ffi::PyBUF_WRITABLE == ffi::PyBUF_WRITABLE { + return Err(PyBufferError::new_err("object is not writable")); + } + + unsafe { + (*view).obj = owner.into_ptr(); + (*view).buf = data.as_ptr() as *mut c_void; + (*view).len = data.len() as ffi::Py_ssize_t; + (*view).readonly = 1; + (*view).itemsize = 1; + (*view).format = if flags & ffi::PyBUF_FORMAT == ffi::PyBUF_FORMAT { + CString::new("B").unwrap().into_raw() + } else { + ptr::null_mut() + }; + (*view).ndim = 1; + (*view).shape = if flags & ffi::PyBUF_ND == ffi::PyBUF_ND { + &mut (*view).len + } else { + ptr::null_mut() + }; + (*view).strides = if flags & ffi::PyBUF_STRIDES == ffi::PyBUF_STRIDES { + &mut (*view).itemsize + } else { + ptr::null_mut() + }; + (*view).suboffsets = ptr::null_mut(); + (*view).internal = ptr::null_mut(); + } + Ok(()) +} + +/// Populate `view` as a writable, single-byte, C-contiguous buffer over `data`. +/// +/// # Safety +/// `data` must remain valid for as long as `owner` keeps the buffer alive, and +/// callers must ensure there is no conflicting mutable access while the Python +/// buffer export exists. +#[cfg(feature = "shared-memory")] +pub(crate) unsafe fn fill_writable_u8_buffer( + owner: Bound<'_, PyAny>, + data: &mut [u8], + view: *mut ffi::Py_buffer, + flags: c_int, +) -> PyResult<()> { + if view.is_null() { + return Err(PyBufferError::new_err("view is null")); + } + + unsafe { + (*view).obj = owner.into_ptr(); + (*view).buf = data.as_mut_ptr() as *mut c_void; + (*view).len = data.len() as ffi::Py_ssize_t; + (*view).readonly = 0; + (*view).itemsize = 1; + (*view).format = if flags & ffi::PyBUF_FORMAT == ffi::PyBUF_FORMAT { + CString::new("B").unwrap().into_raw() + } else { + ptr::null_mut() + }; + (*view).ndim = 1; + (*view).shape = if flags & ffi::PyBUF_ND == ffi::PyBUF_ND { + &mut (*view).len + } else { + ptr::null_mut() + }; + (*view).strides = if flags & ffi::PyBUF_STRIDES == ffi::PyBUF_STRIDES { + &mut (*view).itemsize + } else { + ptr::null_mut() + }; + (*view).suboffsets = ptr::null_mut(); + (*view).internal = ptr::null_mut(); + } + Ok(()) +} + +/// Release the format string allocated by the u8 buffer helpers. +/// +/// # Safety +/// `view` must be a `Py_buffer` previously populated by +/// [`fill_readonly_u8_buffer`] or the shared-memory writable helper. +pub(crate) unsafe fn release_u8_buffer(view: *mut ffi::Py_buffer) { + unsafe { + if !view.is_null() && !(*view).format.is_null() { + drop(CString::from_raw((*view).format)); + } + } +} diff --git a/src/bytes.rs b/src/bytes.rs index c683ce5d..714dd065 100644 --- a/src/bytes.rs +++ b/src/bytes.rs @@ -11,22 +11,432 @@ // Contributors: // ZettaScale Zenoh Team, // -use std::{borrow::Cow, io::Read}; +#[cfg(feature = "shared-memory")] +use std::collections::HashSet; +use std::{ + any::Any, + borrow::Cow, + fmt, + io::Read, + os::raw::{c_int, c_void}, + ptr, slice, + sync::Arc, +}; use pyo3::{ - exceptions::{PyTypeError, PyValueError}, + exceptions::{PyRuntimeError, PyTypeError, PyValueError}, + ffi, prelude::*, - types::{PyByteArray, PyBytes, PyString}, + types::{PyByteArray, PyBytes, PyMemoryView, PyString, PyTuple}, }; +#[cfg(feature = "shared-memory")] +use zenoh_buffers::ZSlice; +use zenoh_buffers::{ZBuf, ZSliceBuffer}; +#[cfg(feature = "shared-memory")] +use zenoh_shm::ShmBufInner; use crate::{ + buffer::{fill_readonly_u8_buffer, release_u8_buffer}, macros::{downcast_or_new, wrapper}, - utils::{IntoPyResult, MapInto}, + utils::{IntoPyResult, IntoPython, IntoRust, MapInto}, }; -wrapper!(zenoh::bytes::ZBytes: Clone, Default); +unsafe extern "C" { + // `PyObject_AsReadBuffer` is part of the stable ABI. Holding a Python + // `memoryview` separately gives the acquired exporter resources a clear + // lifetime even though this legacy API returns only a pointer and length. + #[link_name = "PyObject_AsReadBuffer"] + fn py_object_as_read_buffer( + obj: *mut ffi::PyObject, + buffer: *mut *const c_void, + buffer_len: *mut ffi::Py_ssize_t, + ) -> c_int; +} + +struct BorrowedPyBufferSlice { + _owner: Py, + _lease: Option>, + ptr: *const u8, + len: usize, +} + +impl BorrowedPyBufferSlice { + fn new(buffer: &Bound) -> PyResult { + let mut ptr = ptr::null(); + let mut len = 0; + if unsafe { py_object_as_read_buffer(buffer.as_ptr(), &mut ptr, &mut len) } == -1 { + return Err(PyErr::fetch(buffer.py())); + } + if len < 0 { + Err(PyRuntimeError::new_err( + "buffer exporter returned a negative length", + )) + } else if len > 0 && ptr.is_null() { + Err(PyRuntimeError::new_err( + "buffer exporter returned a null pointer for a non-empty segment", + )) + } else { + Ok(Self { + _owner: buffer.clone().unbind(), + _lease: None, + ptr: ptr.cast(), + len: len as usize, + }) + } + } + + fn with_lease(mut self, lease: Option>) -> Self { + self._lease = lease; + self + } + + fn as_bytes(&self) -> &[u8] { + if self.len == 0 { + &[] + } else { + // SAFETY: `_owner` retains the validated `memoryview`, which owns + // its exporter resources and keeps this contiguous slice valid. + unsafe { slice::from_raw_parts(self.ptr, self.len) } + } + } +} + +struct LeaseState { + sink: Py, + lease_id: Py, +} + +impl LeaseState { + fn new(lease: &Bound) -> PyResult { + let sink = lease + .getattr("sink") + .map_err(|_| PyTypeError::new_err("lease must provide a 'sink' attribute"))?; + let release = sink + .getattr("release") + .map_err(|_| PyTypeError::new_err("lease.sink must provide a 'release' method"))?; + if !release.is_callable() { + return Err(PyTypeError::new_err("lease.sink.release must be callable")); + } + let lease_id = lease + .getattr("lease_id") + .map_err(|_| PyTypeError::new_err("lease must provide a 'lease_id' attribute"))?; + Ok(Self { + sink: sink.unbind(), + lease_id: lease_id.unbind(), + }) + } + + fn into_guard(self) -> Arc { + Arc::new(LeaseGuard { + sink: self.sink, + lease_id: self.lease_id, + }) + } +} + +struct LeaseGuard { + sink: Py, + lease_id: Py, +} + +impl Drop for LeaseGuard { + fn drop(&mut self) { + Python::with_gil(|py| { + let sink = self.sink.bind(py); + if let Err(err) = sink.call_method1("release", (self.lease_id.bind(py),)) { + err.write_unraisable(py, Some(sink)); + } + }); + } +} + +// SAFETY: `_owner` retains the validated `memoryview` while this pointer may be +// read from another thread. +unsafe impl Send for BorrowedPyBufferSlice {} +// SAFETY: The `copy=False` contract requires callers not to mutate the +// exported memory through another alias while Zenoh may reference it. +unsafe impl Sync for BorrowedPyBufferSlice {} + +impl fmt::Debug for BorrowedPyBufferSlice { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter + .debug_struct("BorrowedPyBufferSlice") + .field("len", &self.len) + .finish_non_exhaustive() + } +} + +impl ZSliceBuffer for BorrowedPyBufferSlice { + fn as_slice(&self) -> &[u8] { + self.as_bytes() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } +} + +fn py_buffer_zbytes(buffer: BorrowedPyBufferSlice) -> zenoh::bytes::ZBytes { + ZBuf::from(Arc::new(buffer)).into() +} + +fn single_slice_zbytes(slice: zenoh_buffers::ZSlice) -> zenoh::bytes::ZBytes { + let mut zbuf = ZBuf::empty(); + zbuf.push_zslice(slice); + zbuf.into() +} + +#[cfg(feature = "shared-memory")] +struct CudaRegistrationOwner( + #[allow(dead_code)] Vec, +); + +#[cfg(feature = "shared-memory")] +struct CudaPinnedShmSlice { + // Drop registrations before the SHM clone so cuMemHostUnregister runs + // while the mapping is still alive. + _cuda_owner: Arc, + inner: ShmBufInner, +} + +#[cfg(feature = "shared-memory")] +impl fmt::Debug for CudaPinnedShmSlice { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter + .debug_struct("CudaPinnedShmSlice") + .field("len", &self.inner.as_ref().len()) + .finish_non_exhaustive() + } +} + +#[cfg(feature = "shared-memory")] +impl ZSliceBuffer for CudaPinnedShmSlice { + fn as_slice(&self) -> &[u8] { + self.inner.as_ref() + } + + fn as_any(&self) -> &dyn Any { + &self.inner + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.inner + } +} + +#[cfg(feature = "shared-memory")] +fn attach_cuda_registrations( + inner: zenoh::bytes::ZBytes, + cuda_registrations: Vec, +) -> zenoh::bytes::ZBytes { + if cuda_registrations.is_empty() { + return inner; + } + + let owner = Arc::new(CudaRegistrationOwner(cuda_registrations)); + let mut zbuf = ZBuf::empty(); + for zslice in ZBuf::from(inner).into_zslices() { + if let Some(shm) = zslice.downcast_ref::() { + let mut wrapped = ZSlice::from(CudaPinnedShmSlice { + _cuda_owner: owner.clone(), + inner: shm.clone(), + }); + wrapped.kind = zslice.kind; + zbuf.push_zslice(wrapped); + } else { + zbuf.push_zslice(zslice); + } + } + zbuf.into() +} + +fn physical_segment_zbytes( + zbytes: &zenoh::bytes::ZBytes, +) -> impl Iterator { + let zbuf: ZBuf = zbytes.clone().into(); + zbuf.into_zslices().map(single_slice_zbytes) +} + +fn copied_memoryviews<'py>( + zbytes: &zenoh::bytes::ZBytes, + py: Python<'py>, +) -> PyResult> { + let memoryview = py.import("builtins")?.getattr("memoryview")?; + let views = zbytes + .slices() + .map(|slice| memoryview.call1((PyBytes::new(py, slice),))) + .collect::>>()?; + PyTuple::new(py, views) +} + +#[pyclass] +#[derive(Clone)] +pub(crate) struct ZBytesSegment { + inner: zenoh::bytes::ZBytes, +} + +impl ZBytesSegment { + fn new(inner: zenoh::bytes::ZBytes) -> Self { + Self { inner } + } + + fn as_slice(&self) -> &[u8] { + self.inner.slices().next().unwrap_or(&[]) + } + + fn clone_inner(&self) -> zenoh::bytes::ZBytes { + self.inner.clone() + } +} + +#[pymethods] +impl ZBytesSegment { + fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult> { + PyBytes::new_with(py, self.inner.len(), |bytes| { + self.inner.reader().read_exact(bytes).into_pyres() + }) + } + + #[cfg(feature = "shared-memory")] + fn as_shm(&self) -> Option { + self.inner.as_shm().map(ToOwned::to_owned).map_into() + } + + fn __len__(&self) -> usize { + self.inner.len() + } + + fn __bool__(&self) -> bool { + !self.inner.is_empty() + } + + fn __bytes__<'py>(&self, py: Python<'py>) -> PyResult> { + self.to_bytes(py) + } + + fn __repr__(&self) -> String { + format!("ZBytesSegment({:?})", self.inner) + } + + unsafe fn __getbuffer__( + slf: Bound<'_, Self>, + view: *mut ffi::Py_buffer, + flags: c_int, + ) -> PyResult<()> { + let (ptr, len) = { + let segment = slf.borrow(); + let bytes = segment.as_slice(); + (bytes.as_ptr(), bytes.len()) + }; + let bytes = if len == 0 { + &[] + } else { + // SAFETY: `slf` owns the single-slice ZBytes that keeps the backing + // storage alive for at least as long as the exported buffer. + unsafe { slice::from_raw_parts(ptr, len) } + }; + unsafe { fill_readonly_u8_buffer(slf.into_any(), bytes, view, flags) } + } + + unsafe fn __releasebuffer__(&self, view: *mut ffi::Py_buffer) { + unsafe { release_u8_buffer(view) } + } +} + +#[pyclass] +#[derive(Clone)] +pub(crate) struct ZBytes { + // Drop CUDA registrations before the ZBytes inner so cuMemHostUnregister + // runs while the SHM mapping is still alive. + #[cfg(feature = "shared-memory")] + #[allow(dead_code)] + cuda_registrations: Vec, + pub(crate) inner: zenoh::bytes::ZBytes, +} + +impl Default for ZBytes { + fn default() -> Self { + Self { + #[cfg(feature = "shared-memory")] + cuda_registrations: Vec::new(), + inner: zenoh::bytes::ZBytes::default(), + } + } +} + +impl ZBytes { + #[cfg(feature = "shared-memory")] + pub(crate) fn with_cuda_registrations( + inner: zenoh::bytes::ZBytes, + cuda_registrations: Vec, + ) -> Self { + Self { + cuda_registrations: Vec::new(), + inner: attach_cuda_registrations(inner, cuda_registrations), + } + } +} + +impl From for ZBytes { + fn from(inner: zenoh::bytes::ZBytes) -> Self { + Self { + #[cfg(feature = "shared-memory")] + cuda_registrations: Vec::new(), + inner, + } + } +} + +impl From for zenoh::bytes::ZBytes { + fn from(value: ZBytes) -> Self { + #[cfg(feature = "shared-memory")] + { + return attach_cuda_registrations(value.inner, value.cuda_registrations); + } + #[cfg(not(feature = "shared-memory"))] + { + return value.inner; + } + } +} + +impl IntoRust for ZBytes { + type Into = zenoh::bytes::ZBytes; + + fn into_rust(self) -> Self::Into { + self.into() + } +} + +impl IntoPython for zenoh::bytes::ZBytes { + type Into = ZBytes; + + fn into_python(self) -> Self::Into { + self.into() + } +} + +impl IntoPython for ZBytes { + type Into = ZBytes; + + fn into_python(self) -> Self::Into { + self + } +} + downcast_or_new!(ZBytes); +enum SegmentAction { + Append(zenoh::bytes::ZBytes), + Borrow(BorrowedPyBufferSlice), + #[cfg(feature = "shared-memory")] + MoveShmMut(Py), +} + #[pymethods] impl ZBytes { #[new] @@ -35,19 +445,25 @@ impl ZBytes { return Ok(Self::default()); }; if let Ok(bytes) = obj.downcast::() { - Ok(Self(bytes.to_vec().into())) + Ok(Self::from(zenoh::bytes::ZBytes::from(bytes.to_vec()))) } else if let Ok(bytes) = obj.downcast::() { - Ok(Self(bytes.as_bytes().into())) + Ok(Self::from(zenoh::bytes::ZBytes::from( + bytes.as_bytes().to_vec(), + ))) } else if let Ok(string) = obj.downcast::() { - Ok(Self(string.to_string().into())) + Ok(Self::from(zenoh::bytes::ZBytes::from(string.to_string()))) } else { #[cfg(feature = "shared-memory")] if let Ok(buf) = obj.downcast_exact::() { - return Ok(Self(buf.borrow_mut().take()?.into())); + return Ok(Self::from(zenoh::bytes::ZBytes::from( + buf.borrow_mut().take()?, + ))); } #[cfg(feature = "shared-memory")] if let Ok(buf) = obj.downcast_exact::() { - return Ok(Self(buf.borrow().0.clone().into())); + return Ok(Self::from(zenoh::bytes::ZBytes::from( + buf.borrow().0.clone(), + ))); } Err(PyTypeError::new_err(format!( "expected bytes/str type, found '{}'", @@ -58,28 +474,216 @@ impl ZBytes { fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult> { // Not using `ZBytes::to_bytes` - PyBytes::new_with(py, self.0.len(), |bytes| { - self.0.reader().read_exact(bytes).into_pyres() + PyBytes::new_with(py, self.inner.len(), |bytes| { + self.inner.reader().read_exact(bytes).into_pyres() }) } + #[staticmethod] + #[pyo3(signature = (segments, *, copy = false, require_contiguous = true, lease = None))] + fn from_segments( + segments: &Bound, + copy: bool, + require_contiguous: bool, + lease: Option<&Bound>, + ) -> PyResult { + let py = segments.py(); + let lease = lease.map(LeaseState::new).transpose()?; + if lease.is_some() && copy { + return Err(PyRuntimeError::new_err( + "lease can only be used with copy=False", + )); + } + let memoryview = py.import("builtins")?.getattr("memoryview")?; + let mut actions = Vec::new(); + let mut has_nonempty_raw_borrow = false; + #[cfg(feature = "shared-memory")] + let mut mutable_shm_segments = HashSet::new(); + + for (index, segment) in segments.try_iter()?.enumerate() { + let segment = segment?; + if let Ok(segment) = segment.downcast_exact::() { + if copy { + actions.push(SegmentAction::Append( + segment.borrow().as_slice().to_vec().into(), + )); + } else { + actions.push(SegmentAction::Append(segment.borrow().clone_inner())); + } + continue; + } + + #[cfg(feature = "shared-memory")] + if let Ok(buf) = segment.downcast_exact::() { + if lease.is_some() { + return Err(PyRuntimeError::new_err( + "lease cannot be used with shared-memory segments", + )); + } + if copy { + actions.push(SegmentAction::Append( + buf.borrow().0.as_ref().to_vec().into(), + )); + } else { + actions.push(SegmentAction::Append(buf.borrow().0.clone().into())); + } + continue; + } + + #[cfg(feature = "shared-memory")] + if let Ok(buf) = segment.downcast_exact::() { + if lease.is_some() { + return Err(PyRuntimeError::new_err( + "lease cannot be used with shared-memory segments", + )); + } + if copy { + actions.push(SegmentAction::Append(buf.borrow().get()?.to_vec().into())); + } else { + buf.borrow().get()?; + let ptr = segment.as_ptr() as usize; + if !mutable_shm_segments.insert(ptr) { + return Err(PyRuntimeError::new_err(format!( + "segment {index} repeats the same mutable SHM buffer; \ + zero-copy would need to consume it more than once" + ))); + } + actions.push(SegmentAction::MoveShmMut(buf.clone().unbind())); + } + continue; + } + + if !copy { + let view = memoryview.call1((&segment,)).map_err(|_| { + let type_name = segment + .get_type() + .name() + .map(|name| name.to_string()) + .unwrap_or_else(|_| "".to_string()); + PyRuntimeError::new_err(format!( + "zero-copy requires a read-only, C-contiguous, byte-compatible Python \ + buffer; segment {index} has type '{type_name}'; use copy=True" + )) + })?; + if !view.getattr("readonly")?.extract::()? { + return Err(PyRuntimeError::new_err(format!( + "segment {index} is writable; zero-copy requires a read-only buffer; \ + use copy=True" + ))); + } + if !view.getattr("c_contiguous")?.extract::()? { + return Err(PyRuntimeError::new_err(format!( + "segment {index} is not C-contiguous; zero-copy requires one contiguous \ + byte slice; use copy=True" + ))); + } + if view.getattr("itemsize")?.extract::()? != 1 { + return Err(PyRuntimeError::new_err(format!( + "segment {index} has unsupported item format; zero-copy requires a \ + single-byte buffer; use copy=True" + ))); + } + let buffer = BorrowedPyBufferSlice::new(&view)?; + has_nonempty_raw_borrow |= buffer.len > 0; + actions.push(SegmentAction::Borrow(buffer)); + continue; + } + + let view = memoryview.call1((&segment,)).map_err(|_| { + PyTypeError::new_err(format!( + "segment {index} does not support the Python buffer protocol" + )) + })?; + if view.getattr("itemsize")?.extract::()? != 1 { + return Err(PyTypeError::new_err(format!( + "segment {index} has unsupported item format; \ + expected a byte-compatible buffer" + ))); + } + let c_contiguous = view.getattr("c_contiguous")?.extract::()?; + if require_contiguous && !c_contiguous { + return Err(PyTypeError::new_err(format!( + "segment {index} is not C-contiguous; use require_contiguous=False" + ))); + } + if c_contiguous { + actions.push(SegmentAction::Append( + BorrowedPyBufferSlice::new(&view)? + .as_bytes() + .to_vec() + .into(), + )); + } else { + let bytes = view.call_method0("tobytes")?; + actions.push(SegmentAction::Append( + bytes.downcast::()?.as_bytes().to_vec().into(), + )); + } + } + let lease = if let Some(lease) = lease { + if !has_nonempty_raw_borrow { + return Err(PyRuntimeError::new_err( + "lease requires at least one non-empty ordinary Python zero-copy buffer segment", + )); + } + Some(lease.into_guard()) + } else { + None + }; + let mut writer = zenoh::bytes::ZBytes::writer(); + for action in actions { + match action { + SegmentAction::Append(zbytes) => writer.append(zbytes), + SegmentAction::Borrow(buffer) => { + writer.append(py_buffer_zbytes(buffer.with_lease(lease.clone()))) + } + #[cfg(feature = "shared-memory")] + SegmentAction::MoveShmMut(buf) => { + writer.append(buf.bind(py).borrow_mut().take()?.into()); + } + } + } + Ok(Self::from(writer.finish())) + } + + fn segments<'py>(&self, py: Python<'py>) -> PyResult> { + let segments = physical_segment_zbytes(&self.inner) + .map(|inner| Py::new(py, ZBytesSegment::new(inner))) + .collect::>>()?; + PyTuple::new(py, segments) + } + + fn memoryviews<'py>(&self, py: Python<'py>) -> PyResult> { + let views = physical_segment_zbytes(&self.inner) + .map(|inner| { + let segment = Py::new(py, ZBytesSegment::new(inner))?; + PyMemoryView::from(segment.bind(py).as_any()).map(|view| view.unbind()) + }) + .collect::>>()?; + PyTuple::new(py, views) + } + + fn copied_memoryviews<'py>(&self, py: Python<'py>) -> PyResult> { + copied_memoryviews(&self.inner, py) + } + fn to_string(&self) -> PyResult> { - self.0 + self.inner .try_to_string() .map_err(|_| PyValueError::new_err("not an UTF8 error")) } #[cfg(feature = "shared-memory")] fn as_shm(&self) -> Option { - self.0.as_shm().map(ToOwned::to_owned).map_into() + self.inner.as_shm().map(ToOwned::to_owned).map_into() } fn __len__(&self) -> usize { - self.0.len() + self.inner.len() } fn __bool__(&self) -> bool { - !self.0.is_empty() + !self.inner.is_empty() } fn __bytes__<'py>(&self, py: Python<'py>) -> PyResult> { @@ -91,7 +695,7 @@ impl ZBytes { } fn __eq__(&self, #[pyo3(from_py_with = Self::from_py)] other: Self) -> bool { - self.0 == other.0 + self.inner == other.inner } fn __hash__(&self, py: Python) -> PyResult { @@ -99,7 +703,7 @@ impl ZBytes { } fn __repr__(&self) -> String { - format!("{:?}", self.0) + format!("{:?}", self.inner) } } @@ -245,3 +849,40 @@ impl Encoding { #[classattr] const VIDEO_VP9: Self = Self(zenoh::bytes::Encoding::VIDEO_VP9); } + +#[cfg(all(test, feature = "shared-memory"))] +mod tests { + use super::*; + use crate::cuda_shm::test_support::{fake_driver, take_events, test_lock, Event}; + use zenoh::{ + shm::{PosixShmProviderBackend, ShmProviderBuilder}, + Wait, + }; + + #[test] + fn cuda_registration_survives_conversion_into_rust_zbytes() { + let _guard = test_lock(); + let backend = PosixShmProviderBackend::builder(4096).wait().unwrap(); + let provider = ShmProviderBuilder::backend(backend).wait(); + let layout = provider.alloc_layout(16).unwrap(); + let mut shm = layout.alloc().wait().unwrap(); + shm.as_mut().copy_from_slice(b"cuda-pinned-shm!"); + let ptr = shm.as_ref().as_ptr() as usize; + + let registration = fake_driver() + .register(ptr as *mut u8, shm.as_ref().len()) + .unwrap(); + let page = ptr - (ptr % 4096); + assert_eq!(take_events(), vec![Event::Register(page, 4096)]); + + let py_zbytes = ZBytes::with_cuda_registrations(shm.into(), vec![registration]); + let rust_zbytes: zenoh::bytes::ZBytes = py_zbytes.into(); + + assert_eq!(rust_zbytes.to_bytes().as_ref(), b"cuda-pinned-shm!"); + assert!(rust_zbytes.as_shm().is_some()); + assert_eq!(take_events(), Vec::::new()); + + drop(rust_zbytes); + assert_eq!(take_events(), vec![Event::Unregister(page)]); + } +} diff --git a/src/cuda_shm.rs b/src/cuda_shm.rs new file mode 100644 index 00000000..5b587cea --- /dev/null +++ b/src/cuda_shm.rs @@ -0,0 +1,441 @@ +// +// Copyright (c) 2026 ZettaScale Technology +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +// which is available at https://www.apache.org/licenses/LICENSE-2.0. +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// +// Contributors: +// ZettaScale Zenoh Team, +// +use std::{ + collections::BTreeMap, + ffi::{CStr, CString}, + os::raw::{c_char, c_int, c_uint, c_void}, + ptr, + sync::{Arc, Mutex}, +}; + +use pyo3::{exceptions::PyRuntimeError, prelude::*}; + +type CuResult = c_int; +type CuDevice = c_int; +type CuContext = *mut c_void; + +const CUDA_SUCCESS: CuResult = 0; +const RTLD_NOW: c_int = 2; + +type CuInit = unsafe extern "C" fn(c_uint) -> CuResult; +type CuDeviceGet = unsafe extern "C" fn(*mut CuDevice, c_int) -> CuResult; +type CuDevicePrimaryCtxRetain = unsafe extern "C" fn(*mut CuContext, CuDevice) -> CuResult; +type CuDevicePrimaryCtxRelease = unsafe extern "C" fn(CuDevice) -> CuResult; +type CuCtxSetCurrent = unsafe extern "C" fn(CuContext) -> CuResult; +type CuMemHostRegister = unsafe extern "C" fn(*mut c_void, usize, c_uint) -> CuResult; +type CuMemHostUnregister = unsafe extern "C" fn(*mut c_void) -> CuResult; +type CuGetErrorString = unsafe extern "C" fn(CuResult, *mut *const c_char) -> CuResult; + +#[cfg(target_os = "linux")] +#[link(name = "dl")] +unsafe extern "C" { + fn dlopen(filename: *const c_char, flags: c_int) -> *mut c_void; + fn dlsym(handle: *mut c_void, symbol: *const c_char) -> *mut c_void; + fn sysconf(name: c_int) -> isize; +} + +pub(crate) struct CudaDriver { + _lib: *mut c_void, + device: CuDevice, + context: CuContext, + page_size: usize, + registered_pages: Mutex>, + cu_device_primary_ctx_release: CuDevicePrimaryCtxRelease, + cu_ctx_set_current: CuCtxSetCurrent, + cu_mem_host_register: CuMemHostRegister, + cu_mem_host_unregister: CuMemHostUnregister, + cu_get_error_string: CuGetErrorString, +} + +unsafe impl Send for CudaDriver {} +unsafe impl Sync for CudaDriver {} + +impl CudaDriver { + pub(crate) fn new(device_index: c_int) -> PyResult> { + if device_index < 0 { + return Err(PyRuntimeError::new_err( + "CUDA device index must be non-negative", + )); + } + #[cfg(not(target_os = "linux"))] + { + let _ = device_index; + return Err(PyRuntimeError::new_err( + "CUDA pinned SHM currently requires Linux libcuda.so.1", + )); + } + + #[cfg(target_os = "linux")] + unsafe { + let lib_name = CString::new("libcuda.so.1").unwrap(); + let lib = dlopen(lib_name.as_ptr(), RTLD_NOW); + if lib.is_null() { + return Err(PyRuntimeError::new_err( + "CUDA driver library libcuda.so.1 is unavailable", + )); + } + + let cu_init: CuInit = symbol(lib, "cuInit")?; + let cu_device_get: CuDeviceGet = symbol(lib, "cuDeviceGet")?; + let cu_device_primary_ctx_retain: CuDevicePrimaryCtxRetain = + symbol(lib, "cuDevicePrimaryCtxRetain")?; + let cu_device_primary_ctx_release: CuDevicePrimaryCtxRelease = + symbol(lib, "cuDevicePrimaryCtxRelease")?; + let cu_ctx_set_current: CuCtxSetCurrent = symbol(lib, "cuCtxSetCurrent")?; + let cu_mem_host_register: CuMemHostRegister = symbol(lib, "cuMemHostRegister")?; + let cu_mem_host_unregister: CuMemHostUnregister = symbol(lib, "cuMemHostUnregister")?; + let cu_get_error_string: CuGetErrorString = symbol(lib, "cuGetErrorString")?; + + check(cu_init(0), cu_get_error_string, "cuInit")?; + + let mut device = 0; + check( + cu_device_get(&mut device, device_index), + cu_get_error_string, + "cuDeviceGet", + )?; + + let mut context = ptr::null_mut(); + check( + cu_device_primary_ctx_retain(&mut context, device), + cu_get_error_string, + "cuDevicePrimaryCtxRetain", + )?; + check( + cu_ctx_set_current(context), + cu_get_error_string, + "cuCtxSetCurrent", + )?; + + Ok(Arc::new(Self { + _lib: lib, + device, + context, + page_size: page_size(), + registered_pages: Mutex::new(BTreeMap::new()), + cu_device_primary_ctx_release, + cu_ctx_set_current, + cu_mem_host_register, + cu_mem_host_unregister, + cu_get_error_string, + })) + } + } + + pub(crate) fn register( + self: &Arc, + ptr: *mut u8, + len: usize, + ) -> PyResult { + if len == 0 { + return Err(PyRuntimeError::new_err( + "cannot CUDA-register an empty SHM buffer", + )); + } + if ptr.is_null() { + return Err(PyRuntimeError::new_err( + "cannot CUDA-register a null SHM buffer", + )); + } + // The upstream ShmProvider/ZShmMut API used here does not expose the + // full backing segment base/len. Register page-by-page and share + // overlapping pages within this CUDA driver instead of registering + // potentially overlapping allocation ranges. + let page_size = self.page_size; + let ptr_addr = ptr as usize; + let start = ptr_addr - (ptr_addr % page_size); + let last = ptr_addr + .checked_add(len - 1) + .ok_or_else(|| PyRuntimeError::new_err("CUDA registration range overflow"))?; + let end = (last - (last % page_size)) + .checked_add(page_size) + .ok_or_else(|| PyRuntimeError::new_err("CUDA registration range overflow"))?; + let mut pages = Vec::new(); + let mut page = start; + while page < end { + if let Err(err) = self.acquire_page(page) { + self.release_pages(&pages); + return Err(err); + } + pages.push(page); + page = page + .checked_add(page_size) + .ok_or_else(|| PyRuntimeError::new_err("CUDA registration range overflow"))?; + } + + Ok(CudaRegisteredMemory { + inner: Arc::new(CudaRegisteredMemoryInner { + driver: self.clone(), + pages, + }), + }) + } + + fn acquire_page(&self, page: usize) -> PyResult<()> { + let mut registered_pages = self + .registered_pages + .lock() + .map_err(|_| PyRuntimeError::new_err("CUDA registration registry is poisoned"))?; + if let Some(count) = registered_pages.get_mut(&page) { + *count += 1; + return Ok(()); + } + + unsafe { + check( + (self.cu_ctx_set_current)(self.context), + self.cu_get_error_string, + "cuCtxSetCurrent", + )?; + check( + (self.cu_mem_host_register)(page as *mut c_void, self.page_size, 0), + self.cu_get_error_string, + "cuMemHostRegister", + )?; + } + registered_pages.insert(page, 1); + Ok(()) + } + + fn release_pages(&self, pages: &[usize]) { + let Ok(mut registered_pages) = self.registered_pages.lock() else { + eprintln!("CUDA registration registry is poisoned during ZShmPool cleanup"); + return; + }; + for page in pages { + let Some(count) = registered_pages.get_mut(page) else { + continue; + }; + if *count > 1 { + *count -= 1; + continue; + } + registered_pages.remove(page); + unsafe { + let _ = (self.cu_ctx_set_current)(self.context); + let rc = (self.cu_mem_host_unregister)(*page as *mut c_void); + if rc != CUDA_SUCCESS { + eprintln!( + "CUDA cuMemHostUnregister failed during ZShmPool cleanup: {}", + cuda_error_message(rc, self.cu_get_error_string) + ); + } + } + } + } +} + +impl Drop for CudaDriver { + fn drop(&mut self) { + unsafe { + let _ = (self.cu_ctx_set_current)(self.context); + let _ = (self.cu_device_primary_ctx_release)(self.device); + } + } +} + +#[derive(Clone)] +pub(crate) struct CudaRegisteredMemory { + #[allow(dead_code)] + inner: Arc, +} + +struct CudaRegisteredMemoryInner { + driver: Arc, + pages: Vec, +} + +unsafe impl Send for CudaRegisteredMemory {} +unsafe impl Sync for CudaRegisteredMemory {} + +impl Drop for CudaRegisteredMemoryInner { + fn drop(&mut self) { + self.driver.release_pages(&self.pages); + } +} + +#[cfg(target_os = "linux")] +unsafe fn symbol(lib: *mut c_void, name: &str) -> PyResult { + let c_name = CString::new(name).unwrap(); + let ptr = unsafe { dlsym(lib, c_name.as_ptr()) }; + if ptr.is_null() { + return Err(PyRuntimeError::new_err(format!( + "CUDA driver symbol {name} is unavailable" + ))); + } + Ok(unsafe { std::mem::transmute_copy(&ptr) }) +} + +unsafe fn check( + rc: CuResult, + cu_get_error_string: CuGetErrorString, + operation: &str, +) -> PyResult<()> { + if rc == CUDA_SUCCESS { + Ok(()) + } else { + Err(PyRuntimeError::new_err(format!( + "{operation} failed: {}", + cuda_error_message(rc, cu_get_error_string) + ))) + } +} + +#[cfg(target_os = "linux")] +fn page_size() -> usize { + const SC_PAGESIZE: c_int = 30; + let value = unsafe { sysconf(SC_PAGESIZE) }; + if value > 0 { + value as usize + } else { + 4096 + } +} + +#[cfg(not(target_os = "linux"))] +fn page_size() -> usize { + 4096 +} + +fn cuda_error_message(rc: CuResult, cu_get_error_string: CuGetErrorString) -> String { + let mut ptr = ptr::null(); + let lookup = unsafe { cu_get_error_string(rc, &mut ptr) }; + if lookup == CUDA_SUCCESS && !ptr.is_null() { + let msg = unsafe { CStr::from_ptr(ptr) }.to_string_lossy(); + format!("CUDA error {rc}: {msg}") + } else { + format!("CUDA error {rc}") + } +} + +#[cfg(test)] +pub(crate) mod test_support { + use std::{ + os::raw::c_void, + sync::{Arc, Mutex, OnceLock}, + }; + + use super::*; + + #[derive(Clone, Debug, PartialEq, Eq)] + pub(crate) enum Event { + Register(usize, usize), + Unregister(usize), + } + + static EVENTS: OnceLock>> = OnceLock::new(); + static TEST_LOCK: OnceLock> = OnceLock::new(); + + fn events() -> &'static Mutex> { + EVENTS.get_or_init(|| Mutex::new(Vec::new())) + } + + pub(crate) fn test_lock() -> std::sync::MutexGuard<'static, ()> { + TEST_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap() + } + + pub(crate) fn take_events() -> Vec { + std::mem::take(&mut *events().lock().unwrap()) + } + + unsafe extern "C" fn fake_ctx_set_current(_context: CuContext) -> CuResult { + CUDA_SUCCESS + } + + unsafe extern "C" fn fake_ctx_release(_device: CuDevice) -> CuResult { + CUDA_SUCCESS + } + + unsafe extern "C" fn fake_register( + ptr: *mut c_void, + len: usize, + _flags: c_uint, + ) -> CuResult { + events() + .lock() + .unwrap() + .push(Event::Register(ptr as usize, len)); + CUDA_SUCCESS + } + + unsafe extern "C" fn fake_unregister(ptr: *mut c_void) -> CuResult { + events() + .lock() + .unwrap() + .push(Event::Unregister(ptr as usize)); + CUDA_SUCCESS + } + + unsafe extern "C" fn fake_error_string( + _rc: CuResult, + out: *mut *const c_char, + ) -> CuResult { + unsafe { + *out = c"fake cuda error".as_ptr(); + } + CUDA_SUCCESS + } + + pub(crate) fn fake_driver() -> Arc { + take_events(); + Arc::new(CudaDriver { + _lib: ptr::null_mut(), + device: 0, + context: ptr::null_mut(), + page_size: 4096, + registered_pages: Mutex::new(BTreeMap::new()), + cu_device_primary_ctx_release: fake_ctx_release, + cu_ctx_set_current: fake_ctx_set_current, + cu_mem_host_register: fake_register, + cu_mem_host_unregister: fake_unregister, + cu_get_error_string: fake_error_string, + }) + } +} + +#[cfg(test)] +mod tests { + use super::test_support::{fake_driver, take_events, test_lock, Event}; + + #[test] + fn overlapping_page_registrations_are_shared_until_last_owner_drops() { + let _guard = test_lock(); + let driver = fake_driver(); + let first = driver.register(0x1008usize as *mut u8, 16).unwrap(); + let second = driver.register(0x1080usize as *mut u8, 16).unwrap(); + + assert_eq!(take_events(), vec![Event::Register(0x1000, 4096)]); + + drop(first); + assert_eq!(take_events(), Vec::::new()); + + drop(second); + assert_eq!(take_events(), vec![Event::Unregister(0x1000)]); + } + + #[test] + fn cloned_registration_keeps_pages_pinned_until_last_clone_drops() { + let _guard = test_lock(); + let driver = fake_driver(); + let registration = driver.register(0x2008usize as *mut u8, 16).unwrap(); + assert_eq!(take_events(), vec![Event::Register(0x2000, 4096)]); + + let clone = registration.clone(); + drop(registration); + assert_eq!(take_events(), Vec::::new()); + + drop(clone); + assert_eq!(take_events(), vec![Event::Unregister(0x2000)]); + } +} diff --git a/src/ext.rs b/src/ext.rs index d2de6084..db970f51 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -170,7 +170,7 @@ fn serialize_impl( Ok(()) }; match tp { - SupportedType::ZBytes => serializer.serialize(obj.extract::()?.0), + SupportedType::ZBytes => serializer.serialize(obj.extract::()?.inner), // SAFETY: bytes are immediately copied SupportedType::ByteArray => { serializer.serialize(unsafe { obj.downcast::()?.as_bytes() }) @@ -322,7 +322,10 @@ fn deserialize_impl( }; Ok(match tp { SupportedType::ZBytes => { - ZBytes(deserializer.deserialize::>()?.into()).into_py_any(py)? + ZBytes::from(zenoh::bytes::ZBytes::from( + deserializer.deserialize::>()?, + )) + .into_py_any(py)? } SupportedType::ByteArray => { PyByteArray::new(py, &deserializer.deserialize::>()?).into_py_any(py)? @@ -441,7 +444,7 @@ fn deserialize_collection( #[pyfunction] pub(crate) fn z_deserialize(tp: &Bound, zbytes: &ZBytes) -> PyResult { - let mut deserializer = ZDeserializer::new(&zbytes.0); + let mut deserializer = ZDeserializer::new(&zbytes.inner); deserialize(&mut deserializer, tp).map_err(|err| err.0) } diff --git a/src/lib.rs b/src/lib.rs index 09235cce..b06039d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,9 +13,12 @@ // // TODO https://github.com/eclipse-zenoh/zenoh-python/pull/235#discussion_r1644498390 // mod logging; +mod buffer; mod bytes; mod cancellation; mod config; +#[cfg(feature = "shared-memory")] +mod cuda_shm; #[cfg(feature = "zenoh-ext")] mod ext; mod handlers; @@ -57,7 +60,7 @@ pub(crate) mod zenoh { #[pymodule_export] use crate::{ - bytes::{Encoding, ZBytes}, + bytes::{Encoding, ZBytes, ZBytesSegment}, cancellation::CancellationToken, config::{Config, WhatAmI, WhatAmIMatcher, ZenohId}, handlers::Handler, @@ -106,7 +109,7 @@ pub(crate) mod zenoh { #[pymodule_export] use crate::shm::{ AllocAlignment, BlockOn, Deallocate, Defragment, GarbageCollect, JustAlloc, - MemoryLayout, ShmProvider, ZShm, ZShmMut, + MemoryLayout, ShmProvider, ZShm, ZShmMut, ZShmPool, ZShmPoolBuf, }; } diff --git a/src/shm.rs b/src/shm.rs index 0f067033..7875dd47 100644 --- a/src/shm.rs +++ b/src/shm.rs @@ -1,13 +1,24 @@ -use std::{num::NonZeroUsize, str, sync::Arc}; +use std::{ + collections::HashSet, + num::NonZeroUsize, + os::raw::c_int, + slice, str, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; use pyo3::{ - exceptions::{PyTypeError, PyValueError}, + exceptions::{PyBufferError, PyRuntimeError, PyTypeError, PyValueError}, + ffi, prelude::*, types::{PyByteArray, PyBytes, PySlice, PyString, PyType}, }; use zenoh::shm::{ChunkAllocResult, PosixShmProviderBackend, ShmBuf}; use crate::{ + buffer::{fill_readonly_u8_buffer, fill_writable_u8_buffer, release_u8_buffer}, macros::{downcast_or_new, wrapper, zerror}, utils::{wait, IntoPyResult, MapInto}, }; @@ -165,6 +176,318 @@ impl MemoryLayout { } } +static NEXT_POOL_ID: AtomicUsize = AtomicUsize::new(1); + +struct ZShmPoolInner { + id: usize, + provider: zenoh::shm::ShmProvider, + cuda: Option>, +} + +#[pyclass] +pub(crate) struct ZShmPool { + inner: Arc, +} + +impl ZShmPool { + fn seal_to_zbytes_impl( + &self, + py: Python, + buffers: &Bound, + allow_exports: bool, + method_name: &str, + ) -> PyResult { + let mut handles = Vec::new(); + let mut seen = HashSet::new(); + + for (index, item) in buffers.try_iter()?.enumerate() { + let item = item?; + let buf = item.downcast_exact::().map_err(|_| { + PyTypeError::new_err(format!("segment {index} must be a pool-owned ZShmPoolBuf")) + })?; + let borrow = buf.borrow(); + borrow.check_unsealed()?; + if borrow.pool.id != self.inner.id { + return Err(PyRuntimeError::new_err(format!( + "segment {index} belongs to a different pool" + ))); + } + if !allow_exports && borrow.exports != 0 { + return Err(PyBufferError::new_err(format!( + "segment {index} cannot be sealed while Python buffer exports exist" + ))); + } + if !seen.insert(buf.as_ptr() as usize) { + return Err(PyRuntimeError::new_err(format!( + "segment {index} repeats a duplicate mutable SHM buffer" + ))); + } + handles.push(buf.clone().unbind()); + } + + if handles.is_empty() { + return Err(PyValueError::new_err(format!( + "{method_name} requires a non-empty input" + ))); + } + + let mut writer = zenoh::bytes::ZBytes::writer(); + let mut cuda_registrations = Vec::new(); + for handle in handles { + let mut buf = handle.bind(py).borrow_mut(); + let (shm, cuda_registration) = buf.take(allow_exports)?; + writer.append(shm.into()); + if let Some(cuda_registration) = cuda_registration { + cuda_registrations.push(cuda_registration); + } + } + #[cfg(feature = "shared-memory")] + { + Ok(crate::bytes::ZBytes::with_cuda_registrations( + writer.finish(), + cuda_registrations, + )) + } + #[cfg(not(feature = "shared-memory"))] + { + let _ = cuda_registrations; + Ok(crate::bytes::ZBytes::from(writer.finish())) + } + } +} + +#[pymethods] +impl ZShmPool { + #[new] + #[pyo3(signature = (pool_size = 268435456, *, cuda_pinned = false, cuda_device = 0, alignment = None))] + fn new( + py: Python, + pool_size: usize, + cuda_pinned: bool, + cuda_device: i32, + alignment: Option, + ) -> PyResult { + let layout: zenoh::shm::MemoryLayout = if let Some(alignment) = alignment { + (pool_size, alignment.0).try_into() + } else { + pool_size.try_into() + } + .into_pyres()?; + let provider = wait(py, zenoh::shm::ShmProviderBuilder::default_backend(layout))?; + let cuda = if cuda_pinned { + Some(crate::cuda_shm::CudaDriver::new(cuda_device)?) + } else { + None + }; + + Ok(Self { + inner: Arc::new(ZShmPoolInner { + id: NEXT_POOL_ID.fetch_add(1, Ordering::Relaxed), + provider, + cuda, + }), + }) + } + + #[pyo3(signature = (size, alignment = None))] + fn alloc( + &self, + py: Python, + size: usize, + alignment: Option, + ) -> PyResult { + let layout: zenoh::shm::MemoryLayout = if let Some(alignment) = alignment { + (size, alignment.0).try_into() + } else { + size.try_into() + } + .into_pyres()?; + let builder = self.inner.provider.alloc(layout); + let mut buf: zenoh::shm::ZShmMut = wait(py, builder)?; + let cuda_registration = if let Some(cuda) = &self.inner.cuda { + let bytes = buf.as_mut(); + Some(cuda.register(bytes.as_mut_ptr(), bytes.len())?) + } else { + None + }; + + Ok(ZShmPoolBuf { + cuda_registration, + buf: Some(buf), + pool: self.inner.clone(), + exports: 0, + }) + } + + fn seal_to_zbytes(&self, py: Python, buffers: &Bound) -> PyResult { + self.seal_to_zbytes_impl(py, buffers, false, "seal_to_zbytes") + } + + fn seal_to_zbytes_unchecked( + &self, + py: Python, + buffers: &Bound, + ) -> PyResult { + self.seal_to_zbytes_impl(py, buffers, true, "seal_to_zbytes_unchecked") + } + + #[getter] + fn cuda_pinned(&self) -> bool { + self.inner.cuda.is_some() + } +} + +#[pyclass] +pub(crate) struct ZShmPoolBuf { + cuda_registration: Option, + buf: Option, + pool: Arc, + exports: usize, +} + +impl ZShmPoolBuf { + fn sealed_error() -> PyErr { + zerror!("ZShmPoolBuf has been sealed into ZBytes") + } + + fn get(&self) -> PyResult<&zenoh::shm::ZShmMut> { + self.buf.as_ref().ok_or_else(Self::sealed_error) + } + + fn get_mut(&mut self) -> PyResult<&mut zenoh::shm::ZShmMut> { + self.buf.as_mut().ok_or_else(Self::sealed_error) + } + + fn check_unsealed(&self) -> PyResult<()> { + self.get().map(|_| ()) + } + + fn take( + &mut self, + allow_exports: bool, + ) -> PyResult<( + zenoh::shm::ZShmMut, + Option, + )> { + if !allow_exports && self.exports != 0 { + return Err(PyBufferError::new_err( + "cannot seal ZShmPoolBuf while Python buffer exports exist", + )); + } + let cuda_registration = self.cuda_registration.take(); + let buf = self.buf.take().ok_or_else(Self::sealed_error)?; + Ok((buf, cuda_registration)) + } +} + +#[pymethods] +impl ZShmPoolBuf { + #[getter] + fn ptr(&self) -> PyResult { + Ok(self.get()?.as_ref().as_ptr() as usize) + } + + #[getter] + fn is_sealed(&self) -> bool { + self.buf.is_none() + } + + fn is_valid(&self) -> PyResult { + Ok(self.get()?.is_valid()) + } + + fn __len__(&self) -> PyResult { + Ok(self.get()?.len()) + } + + fn __bytes__<'py>(&self, py: Python<'py>) -> PyResult> { + Ok(PyBytes::new(py, self.get()?)) + } + + fn __str__<'py>(&self, py: Python<'py>) -> PyResult> { + Ok(PyString::new(py, str::from_utf8(self.get()?).into_pyres()?)) + } + + fn __repr__(&self) -> PyResult { + Ok(format!("ZShmPoolBuf({:?})", self.get()?)) + } + + fn __setitem__(&mut self, key: &Bound, value: &Bound) -> PyResult<()> { + if self.exports != 0 { + return Err(PyBufferError::new_err( + "cannot mutate ZShmPoolBuf while Python buffer exports exist", + )); + } + if let Ok(key) = key.extract::() { + if let Ok(value) = value.extract::() { + if let Some(entry) = self.get_mut()?.get_mut(key) { + *entry = value; + return Ok(()); + } + } + } else if let Ok(key) = key.downcast::() { + let slice = self.get_mut()?; + let indices = key.indices(slice.len() as isize)?; + let mut copy_bytes = |b: &[u8]| { + if b.len() != indices.slicelength { + return Err(PyValueError::new_err( + "memoryview assignment: lvalue and rvalue have different structures", + )); + } + let mut target = indices.start; + for byte in b { + slice[target as usize] = *byte; + target += indices.step; + } + Ok(()) + }; + if let Ok(bytes) = value.downcast::() { + return copy_bytes(unsafe { bytes.as_bytes() }); + } else if let Ok(bytes) = value.downcast::() { + return copy_bytes(bytes.as_bytes()); + } + } + Err(PyTypeError::new_err("expected bytes like argument")) + } + + unsafe fn __getbuffer__( + slf: Bound<'_, Self>, + view: *mut ffi::Py_buffer, + flags: c_int, + ) -> PyResult<()> { + let (ptr, len) = { + let mut this = slf.borrow_mut(); + let buf = this.get_mut()?; + let bytes = buf.as_mut(); + let ptr = bytes.as_mut_ptr(); + let len = bytes.len(); + this.exports += 1; + (ptr, len) + }; + let bytes = if len == 0 { + &mut [] + } else { + // SAFETY: `slf` owns the ZShmPoolBuf handle and export tracking + // prevents sealing the buffer while Python holds this view. + unsafe { slice::from_raw_parts_mut(ptr, len) } + }; + match unsafe { fill_writable_u8_buffer(slf.clone().into_any(), bytes, view, flags) } { + Ok(()) => Ok(()), + Err(err) => { + slf.borrow_mut().exports -= 1; + Err(err) + } + } + } + + unsafe fn __releasebuffer__(&mut self, view: *mut ffi::Py_buffer) { + if self.exports > 0 { + self.exports -= 1; + } + unsafe { release_u8_buffer(view) } + } +} + wrapper!(zenoh::shm::ShmProvider); #[pymethods] @@ -225,6 +548,30 @@ impl ZShm { fn __bytes__<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { PyBytes::new(py, &self.0) } + + unsafe fn __getbuffer__( + slf: Bound<'_, Self>, + view: *mut ffi::Py_buffer, + flags: c_int, + ) -> PyResult<()> { + let (ptr, len) = { + let shm = slf.borrow(); + let bytes: &[u8] = shm.0.as_ref(); + (bytes.as_ptr(), bytes.len()) + }; + let bytes = if len == 0 { + &[] + } else { + // SAFETY: `slf` owns the ZShm handle and keeps the mapped SHM + // buffer alive for at least as long as the exported buffer. + unsafe { slice::from_raw_parts(ptr, len) } + }; + unsafe { fill_readonly_u8_buffer(slf.into_any(), bytes, view, flags) } + } + + unsafe fn __releasebuffer__(&self, view: *mut ffi::Py_buffer) { + unsafe { release_u8_buffer(view) } + } } #[pyclass] @@ -233,7 +580,7 @@ pub(crate) struct ZShmMut { } impl ZShmMut { - fn get(&self) -> PyResult<&zenoh::shm::ZShmMut> { + pub(crate) fn get(&self) -> PyResult<&zenoh::shm::ZShmMut> { self.buf .as_ref() .ok_or_else(|| zerror!("ZShmMut has been consumed by ZBytes conversion")) diff --git a/tests/test_session.py b/tests/test_session.py index 14f91880..412a3a0d 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -11,6 +11,7 @@ # Contributors: # ZettaScale Zenoh Team, # +import gc import json import time from typing import List, Tuple @@ -147,17 +148,32 @@ def run_session_pubsub(peer01: Session, peer02: Session): keyexpr = "test_pub/session" msg = "Pub Message".encode() + def make_payload(): + return zenoh.ZBytes.from_segments( + [ + memoryview(bytearray(b"Pub ")).toreadonly(), + memoryview(bytearray(b"Message")).toreadonly(), + ] + ) + + payload = make_payload() + gc.collect() + num_received = 0 num_errors = 0 + retained_segments = None def sub_callback(sample: Sample): nonlocal num_received nonlocal num_errors + nonlocal retained_segments + if retained_segments is None: + retained_segments = sample.payload.segments() if ( sample.key_expr != keyexpr or sample.priority != Priority.DATA_HIGH or sample.congestion_control != CongestionControl.BLOCK - or bytes(sample.payload) != msg + or b"".join(map(bytes, sample.payload.segments())) != msg ): num_errors += 1 num_received += 1 @@ -173,12 +189,15 @@ def sub_callback(sample: Sample): time.sleep(SLEEP) for _ in range(0, MSG_COUNT): - publisher.put("Pub Message") + publisher.put(payload) time.sleep(SLEEP) print(f"[PS][02d] Received on peer02 session. {num_received}/{MSG_COUNT} msgs.") assert num_received == MSG_COUNT assert num_errors == 0 + gc.collect() + assert retained_segments is not None + assert b"".join(map(bytes, retained_segments)) == msg print("[PS][03d] Undeclare publisher on peer01 session") publisher.undeclare() diff --git a/tests/test_shm_pool.py b/tests/test_shm_pool.py new file mode 100644 index 00000000..f4739f96 --- /dev/null +++ b/tests/test_shm_pool.py @@ -0,0 +1,335 @@ +# +# Copyright (c) 2026 ZettaScale Technology +# +# This program and the accompanying materials are made available under the +# terms of the Eclipse Public License 2.0 which is available at +# http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +# which is available at https://www.apache.org/licenses/LICENSE-2.0. +# +# SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +# +# Contributors: +# ZettaScale Zenoh Team, +# +import gc +import time + +import pytest + + +def require_shm(): + return pytest.importorskip("zenoh.shm") + + +def make_cuda_pinned_payload(data: bytes, *, unchecked: bool = False): + torch = pytest.importorskip("torch") + if not torch.cuda.is_available(): + pytest.skip("CUDA is unavailable") + torch.cuda.set_device(0) + torch.empty(1, device="cuda") + shm = require_shm() + try: + pool = shm.ZShmPool(pool_size=4096, cuda_pinned=True, cuda_device=0) + except RuntimeError as exc: + pytest.skip(f"CUDA pinned SHM unavailable: {exc}") + buf = pool.alloc(len(data)) + view = memoryview(buf) + view[:] = data + if unchecked: + payload = pool.seal_to_zbytes_unchecked([buf]) + view.release() + return torch, payload + view.release() + return torch, pool.seal_to_zbytes([buf]) + + +def wait_for_sample(samples): + deadline = time.monotonic() + 5 + while time.monotonic() < deadline: + if samples: + return samples[0] + time.sleep(0.01) + pytest.fail("timed out waiting for sample") + + +def assert_cuda_pinned_zbytes(torch, payload, expected: bytes): + assert bytes(payload) == expected + shm_payload = payload.as_shm() + assert shm_payload is not None + tensor = torch.frombuffer(memoryview(shm_payload), dtype=torch.uint8) + assert tensor.is_pinned() + + +def test_zshm_pool_alloc_exposes_writable_buffer_protocol(): + shm = require_shm() + pool = shm.ZShmPool(pool_size=4096) + + buf = pool.alloc(5) + view = memoryview(buf) + view[:] = b"hello" + + assert not view.readonly + assert view.format == "B" + assert bytes(buf) == b"hello" + assert buf.ptr > 0 + assert len(buf) == 5 + assert buf.is_valid() + assert not buf.is_sealed + + +def test_zshm_pool_seal_to_zbytes_returns_shm_payload(): + shm = require_shm() + pool = shm.ZShmPool(pool_size=4096) + buf = pool.alloc(5) + memoryview(buf)[:] = b"hello" + + payload = pool.seal_to_zbytes([buf]) + + assert bytes(payload) == b"hello" + assert payload.as_shm() is not None + assert buf.is_sealed + with pytest.raises(Exception, match="sealed|consumed"): + buf[0] = 0 + with pytest.raises(Exception, match="sealed|consumed"): + memoryview(buf) + + +def test_zshm_pool_seal_rejects_live_export_then_succeeds_after_release(): + shm = require_shm() + pool = shm.ZShmPool(pool_size=4096) + buf = pool.alloc(5) + view = memoryview(buf) + view[:] = b"hello" + + with pytest.raises(BufferError, match="exports exist"): + pool.seal_to_zbytes([buf]) + + view.release() + payload = pool.seal_to_zbytes([buf]) + + assert bytes(payload) == b"hello" + assert payload.as_shm() is not None + + +def test_zshm_pool_seal_unchecked_allows_live_export(): + shm = require_shm() + pool = shm.ZShmPool(pool_size=4096) + buf = pool.alloc(5) + view = memoryview(buf) + view[:] = b"hello" + + with pytest.raises(BufferError, match="exports exist"): + pool.seal_to_zbytes([buf]) + + payload = pool.seal_to_zbytes_unchecked([buf]) + + assert bytes(payload) == b"hello" + assert payload.as_shm() is not None + assert buf.is_sealed + with pytest.raises(Exception, match="sealed|consumed"): + memoryview(buf) + + view.release() + with pytest.raises(Exception, match="sealed|consumed"): + buf[0] = 0 + + +def test_zshm_pool_buf_supports_stepped_and_reverse_slice_assignment(): + shm = require_shm() + pool = shm.ZShmPool(pool_size=4096) + buf = pool.alloc(6) + buf[:] = b"abcdef" + + buf[::2] = b"XYZ" + assert bytes(buf) == b"XbYdZf" + + buf[::-1] = b"fedcba" + assert bytes(buf) == b"abcdef" + + +def test_zshm_pool_seal_rejects_empty_and_non_pool_buffers(): + shm = require_shm() + pool = shm.ZShmPool(pool_size=4096) + + with pytest.raises(ValueError, match="empty"): + pool.seal_to_zbytes([]) + with pytest.raises(TypeError, match="pool-owned|ZShmPoolBuf"): + pool.seal_to_zbytes([memoryview(b"hello")]) + + +def test_zshm_pool_seal_rejects_duplicate_buffer_without_consuming(): + shm = require_shm() + pool = shm.ZShmPool(pool_size=4096) + buf = pool.alloc(5) + buf[:] = b"hello" + + with pytest.raises(RuntimeError, match="duplicate|repeats"): + pool.seal_to_zbytes([buf, buf]) + + assert bytes(buf) == b"hello" + assert not buf.is_sealed + + +def test_zshm_pool_seal_rejects_buffer_from_another_pool(): + shm = require_shm() + pool_a = shm.ZShmPool(pool_size=4096) + pool_b = shm.ZShmPool(pool_size=4096) + buf = pool_a.alloc(5) + buf[:] = b"hello" + + with pytest.raises(RuntimeError, match="different pool"): + pool_b.seal_to_zbytes([buf]) + + assert bytes(buf) == b"hello" + assert not buf.is_sealed + + +def test_zshm_pool_instances_are_independent(): + shm = require_shm() + small = shm.ZShmPool(pool_size=4096, cuda_pinned=False) + large = shm.ZShmPool(pool_size=8192, cuda_pinned=False) + + left = small.alloc(4) + right = large.alloc(5) + left[:] = b"left" + right[:] = b"right" + + assert bytes(small.seal_to_zbytes([left])) == b"left" + assert bytes(large.seal_to_zbytes([right])) == b"right" + + +def test_zshm_pool_cpu_buffer_is_not_torch_pinned_when_available(): + torch = pytest.importorskip("torch") + shm = require_shm() + pool = shm.ZShmPool(pool_size=4096, cuda_pinned=False) + buf = pool.alloc(16) + + tensor = torch.frombuffer(memoryview(buf), dtype=torch.uint8) + + assert not tensor.is_pinned() + + +def test_zshm_pool_cuda_pinned_buffer_is_torch_pinned_when_available(): + torch = pytest.importorskip("torch") + if not torch.cuda.is_available(): + pytest.skip("CUDA is unavailable") + torch.cuda.set_device(0) + torch.empty(1, device="cuda") + shm = require_shm() + try: + pool = shm.ZShmPool(pool_size=4096, cuda_pinned=True, cuda_device=0) + except RuntimeError as exc: + pytest.skip(f"CUDA pinned SHM unavailable: {exc}") + buf = pool.alloc(16) + + view = memoryview(buf) + tensor = torch.frombuffer(view, dtype=torch.uint8) + + assert tensor.is_pinned() + del tensor + del view + gc.collect() + + payload = pool.seal_to_zbytes([buf]) + assert payload.as_shm() is not None + assert torch.frombuffer(memoryview(payload.as_shm()), dtype=torch.uint8).is_pinned() + + +def test_zshm_pool_cuda_pinned_supports_d2h_copy_when_available(): + torch = pytest.importorskip("torch") + if not torch.cuda.is_available(): + pytest.skip("CUDA is unavailable") + torch.cuda.set_device(0) + shm = require_shm() + try: + pool = shm.ZShmPool(pool_size=4096, cuda_pinned=True, cuda_device=0) + except RuntimeError as exc: + pytest.skip(f"CUDA pinned SHM unavailable: {exc}") + buf = pool.alloc(16) + host = torch.frombuffer(memoryview(buf), dtype=torch.uint8) + device = torch.arange(16, dtype=torch.uint8, device="cuda") + + assert host.is_pinned() + host.copy_(device, non_blocking=True) + torch.cuda.synchronize() + + assert bytes(buf) == bytes(range(16)) + + +def test_zshm_pool_cuda_registration_survives_session_put_wrapper_drop(): + zenoh = pytest.importorskip("zenoh") + torch, payload = make_cuda_pinned_payload(b"session-payload") + samples = [] + key_expr = "test/zshm_pool/session_put" + session = zenoh.open(zenoh.Config()) + subscriber = session.declare_subscriber(key_expr, lambda sample: samples.append(sample)) + try: + session.put(key_expr, payload) + del payload + gc.collect() + + sample = wait_for_sample(samples) + assert_cuda_pinned_zbytes(torch, sample.payload, b"session-payload") + finally: + subscriber.undeclare() + session.close() + + +def test_zshm_pool_cuda_registration_survives_unchecked_session_put_wrapper_drop(): + zenoh = pytest.importorskip("zenoh") + torch, payload = make_cuda_pinned_payload(b"unchecked-session-payload", unchecked=True) + samples = [] + key_expr = "test/zshm_pool/unchecked_session_put" + session = zenoh.open(zenoh.Config()) + subscriber = session.declare_subscriber(key_expr, lambda sample: samples.append(sample)) + try: + session.put(key_expr, payload) + del payload + gc.collect() + + sample = wait_for_sample(samples) + assert_cuda_pinned_zbytes(torch, sample.payload, b"unchecked-session-payload") + finally: + subscriber.undeclare() + session.close() + + +def test_zshm_pool_cuda_registration_survives_publisher_put_wrapper_drop(): + zenoh = pytest.importorskip("zenoh") + torch, payload = make_cuda_pinned_payload(b"publisher-payload") + samples = [] + key_expr = "test/zshm_pool/publisher_put" + session = zenoh.open(zenoh.Config()) + subscriber = session.declare_subscriber(key_expr, lambda sample: samples.append(sample)) + publisher = session.declare_publisher(key_expr) + try: + publisher.put(payload) + del payload + gc.collect() + + sample = wait_for_sample(samples) + assert_cuda_pinned_zbytes(torch, sample.payload, b"publisher-payload") + finally: + publisher.undeclare() + subscriber.undeclare() + session.close() + + +def test_zshm_pool_cuda_registration_survives_attachment_wrapper_drop(): + zenoh = pytest.importorskip("zenoh") + torch, attachment = make_cuda_pinned_payload(b"attachment-payload") + samples = [] + key_expr = "test/zshm_pool/attachment" + session = zenoh.open(zenoh.Config()) + subscriber = session.declare_subscriber(key_expr, lambda sample: samples.append(sample)) + try: + session.put(key_expr, b"body", attachment=attachment) + del attachment + gc.collect() + + sample = wait_for_sample(samples) + assert sample.attachment is not None + assert_cuda_pinned_zbytes(torch, sample.attachment, b"attachment-payload") + finally: + subscriber.undeclare() + session.close() diff --git a/tests/test_zbytes_segments.py b/tests/test_zbytes_segments.py new file mode 100644 index 00000000..5c0f7bd5 --- /dev/null +++ b/tests/test_zbytes_segments.py @@ -0,0 +1,468 @@ +# +# Copyright (c) 2026 ZettaScale Technology +# +# This program and the accompanying materials are made available under the +# terms of the Eclipse Public License 2.0 which is available at +# http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +# which is available at https://www.apache.org/licenses/LICENSE-2.0. +# +# SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +# +import gc +import sys +from array import array + +import pytest + +from zenoh import ZBytes, ZBytesSegment + + +class RecordingLeaseSink: + def __init__(self, *, raise_on_release=False): + self.raise_on_release = raise_on_release + self.released = [] + + def release(self, lease_id): + self.released.append(lease_id) + if self.raise_on_release: + raise RuntimeError("release failed") + + +class LeaseState: + def __init__(self, sink, lease_id): + self.sink = sink + self.lease_id = lease_id + + +def make_lease(lease_id="lease-1", *, raise_on_release=False): + sink = RecordingLeaseSink(raise_on_release=raise_on_release) + return sink, LeaseState(sink, lease_id) + + +@pytest.mark.parametrize( + "segments", + [ + [b"hello", b"world"], + [bytearray(b"hello"), bytearray(b"world")], + [memoryview(b"hello"), memoryview(bytearray(b"world"))], + [array("B", b"hello"), array("b", b"world")], + [b"", b"hello", b"", b"world"], + [bytes([i]) for i in range(256)], + ], +) +def test_from_segments_copies_byte_compatible_buffers(segments): + payload = ZBytes.from_segments(segments, copy=True) + expected = b"helloworld" if len(segments) < 10 else bytes(range(256)) + + assert bytes(payload) == expected + + +def test_from_segments_preserves_owned_segment_layout(): + payload = ZBytes.from_segments([b"hello", b"world"], copy=True) + + assert tuple(map(bytes, payload.segments())) == (b"hello", b"world") + + +def test_from_segments_constructs_zero_copy_payload_from_immutable_bytes(): + def make_payload(): + return ZBytes.from_segments([b"hello", b"world"]) + + payload = make_payload() + gc.collect() + + assert bytes(payload) == b"helloworld" + assert tuple(map(bytes, payload.segments())) == (b"hello", b"world") + + +@pytest.mark.parametrize( + ("segment", "expected"), + [ + (memoryview(b"hello"), b"hello"), + (memoryview(b"hello")[1:], b"ello"), + ], +) +def test_from_segments_constructs_zero_copy_payload_from_bytes_memoryview( + segment, expected +): + payload = ZBytes.from_segments([segment]) + + assert bytes(payload) == expected + + +def test_zero_copy_payload_keeps_python_bytes_owner_alive(): + owner = bytes(bytearray(b"hello")) + initial_refcount = sys.getrefcount(owner) + + payload = ZBytes.from_segments([owner]) + + assert sys.getrefcount(owner) > initial_refcount + del payload + gc.collect() + assert sys.getrefcount(owner) == initial_refcount + + +def test_zero_copy_payload_keeps_readonly_buffer_export_alive(): + owner = bytearray(b"hello") + payload = ZBytes.from_segments([memoryview(owner).toreadonly()]) + + with pytest.raises(BufferError): + owner.extend(b"!") + + assert bytes(payload) == b"hello" + del payload + gc.collect() + owner.extend(b"!") + assert owner == b"hello!" + + +def test_zero_copy_payload_lease_releases_after_payload_drop(): + sink, lease = make_lease("slot-1") + payload = ZBytes.from_segments([b"hello"], lease=lease) + + assert sink.released == [] + + del payload + gc.collect() + + assert sink.released == ["slot-1"] + + +def test_zero_copy_payload_lease_is_shared_by_multiple_segments(): + sink, lease = make_lease("slot-2") + payload = ZBytes.from_segments([b"hello", memoryview(b"world")], lease=lease) + + assert bytes(payload) == b"helloworld" + del payload + gc.collect() + + assert sink.released == ["slot-2"] + + +@pytest.mark.parametrize( + "view_factory", + [ + lambda payload: payload.segments(), + lambda payload: payload.memoryviews(), + ], +) +def test_zero_copy_payload_lease_waits_for_derived_views(view_factory): + sink, lease = make_lease("slot-3") + payload = ZBytes.from_segments([b"hello", b"world"], lease=lease) + views = view_factory(payload) + + del payload + gc.collect() + assert sink.released == [] + + assert b"".join(map(bytes, views)) == b"helloworld" + del views + gc.collect() + + assert sink.released == ["slot-3"] + + +def test_zero_copy_payload_lease_release_error_is_unraisable(monkeypatch): + captured = [] + + def hook(args): + captured.append(args) + + monkeypatch.setattr(sys, "unraisablehook", hook) + sink, lease = make_lease("slot-error", raise_on_release=True) + payload = ZBytes.from_segments([b"hello"], lease=lease) + + del payload + gc.collect() + + assert sink.released == ["slot-error"] + assert len(captured) == 1 + assert isinstance(captured[0].exc_value, RuntimeError) + assert captured[0].object is sink + + +@pytest.mark.parametrize( + "segment", + [ + bytearray(b"hello"), + memoryview(bytearray(b"hello")), + memoryview(b"hello")[::2], + array("I", [1, 2, 3]), + ], +) +def test_from_segments_rejects_unsupported_zero_copy_buffers(segment): + with pytest.raises(RuntimeError, match="segment 0.*use copy=True"): + ZBytes.from_segments([segment]) + + +def test_from_segments_rejects_non_buffer_segment(): + with pytest.raises(TypeError, match="segment 1 does not support"): + ZBytes.from_segments([b"hello", object()], copy=True) + + +def test_from_segments_rejects_non_byte_compatible_segment(): + with pytest.raises(TypeError, match="segment 0 has unsupported item format"): + ZBytes.from_segments([array("I", [1, 2, 3])], copy=True) + + +def test_from_segments_rejects_non_contiguous_segment_by_default(): + segment = memoryview(bytearray(b"hello"))[::2] + + with pytest.raises(TypeError, match="segment 0 is not C-contiguous"): + ZBytes.from_segments([segment], copy=True) + + +def test_from_segments_rejects_lease_with_copy(): + sink, lease = make_lease("slot-copy") + + with pytest.raises(RuntimeError, match="lease can only be used with copy=False"): + ZBytes.from_segments([b"hello"], copy=True, lease=lease) + + assert sink.released == [] + + +@pytest.mark.parametrize("segments", [[], [b""]]) +def test_from_segments_rejects_lease_without_nonempty_raw_borrow(segments): + sink, lease = make_lease("slot-empty") + + with pytest.raises(RuntimeError, match="lease requires at least one non-empty"): + ZBytes.from_segments(segments, lease=lease) + + assert sink.released == [] + + +def test_from_segments_rejects_lease_with_zbytes_segment_only(): + source = ZBytes.from_segments([b"hello"], copy=True) + (segment,) = source.segments() + sink, lease = make_lease("slot-segment") + + with pytest.raises(RuntimeError, match="lease requires at least one non-empty"): + ZBytes.from_segments([segment], lease=lease) + + assert sink.released == [] + + +def test_from_segments_rejects_invalid_lease_interface(): + with pytest.raises(TypeError, match="lease must provide a 'sink' attribute"): + ZBytes.from_segments([b"hello"], lease=object()) + + +def test_from_segments_can_copy_non_contiguous_segment_explicitly(): + segment = memoryview(bytearray(b"hello"))[::2] + + payload = ZBytes.from_segments( + [segment], + copy=True, + require_contiguous=False, + ) + + assert bytes(payload) == b"hlo" + + +def test_segments_return_zero_copy_segment_views_with_independent_lifetimes(): + payload = ZBytes.from_segments([b"hello", b"world"], copy=True) + segments = payload.segments() + + del payload + gc.collect() + + assert isinstance(segments, tuple) + assert all(isinstance(segment, ZBytesSegment) for segment in segments) + assert all(memoryview(segment).readonly for segment in segments) + assert b"".join(map(bytes, segments)) == b"helloworld" + with pytest.raises(TypeError): + memoryview(segments[0])[0] = 0 + + +def test_memoryviews_are_zero_copy_views_over_segments(): + payload = ZBytes.from_segments([b"hello", b"world"], copy=True) + views = payload.memoryviews() + + assert all(isinstance(view, memoryview) for view in views) + assert all(view.readonly for view in views) + assert tuple(map(bytes, views)) == tuple(map(bytes, payload.segments())) + + +def test_memoryviews_keep_segment_owner_alive(): + payload = ZBytes.from_segments([b"hello", b"world"], copy=True) + views = payload.memoryviews() + + del payload + gc.collect() + + assert tuple(map(bytes, views)) == (b"hello", b"world") + + +def test_copied_memoryviews_preserve_old_copy_out_behavior(): + payload = ZBytes.from_segments([b"hello", b"world"], copy=True) + views = payload.copied_memoryviews() + + del payload + gc.collect() + + assert all(isinstance(view, memoryview) for view in views) + assert tuple(map(bytes, views)) == (b"hello", b"world") + + +def test_from_segments_copies_large_payload_without_joining_inputs(): + segment = bytes(1024 * 1024) + payload = ZBytes.from_segments([segment, segment, segment, segment], copy=True) + + assert len(payload) == 4 * 1024 * 1024 + assert sum(map(len, payload.segments())) == len(payload) + + +def test_from_segments_accepts_numpy_uint8_when_available(): + numpy = pytest.importorskip("numpy") + segments = [numpy.array([1, 2, 3], dtype=numpy.uint8)] + + assert bytes(ZBytes.from_segments(segments, copy=True)) == b"\x01\x02\x03" + + +def test_from_segments_accepts_readonly_numpy_uint8_zero_copy_when_available(): + numpy = pytest.importorskip("numpy") + segment = numpy.array([1, 2, 3], dtype=numpy.uint8) + segment.flags.writeable = False + + assert bytes(ZBytes.from_segments([segment])) == b"\x01\x02\x03" + + +def test_from_segments_accepts_zbytes_segment_without_copy(): + source = ZBytes.from_segments([b"hello", b"world"], copy=True) + hello, world = source.segments() + + payload = ZBytes.from_segments([world, hello], copy=False) + + assert bytes(payload) == b"worldhello" + + +def test_from_segments_copies_zbytes_segment(): + source = ZBytes.from_segments([bytearray(b"hello")], copy=True) + (segment,) = source.segments() + + payload = ZBytes.from_segments([segment], copy=True) + + assert bytes(payload) == b"hello" + + +def test_from_segments_accepts_shm_mut_zero_copy_when_available(): + shm = pytest.importorskip("zenoh.shm") + provider = shm.ShmProvider.default_backend(4096) + buf = provider.alloc(5) + buf[:] = b"hello" + + payload = ZBytes.from_segments([buf], copy=False) + + assert bytes(payload) == b"hello" + assert payload.as_shm() is not None + with pytest.raises(Exception, match="consumed"): + bytes(buf) + + +def test_from_segments_rejects_lease_with_shm_mut_when_available(): + shm = pytest.importorskip("zenoh.shm") + provider = shm.ShmProvider.default_backend(4096) + buf = provider.alloc(5) + buf[:] = b"hello" + sink, lease = make_lease("slot-shm-mut") + + with pytest.raises(RuntimeError, match="lease cannot be used with shared-memory"): + ZBytes.from_segments([buf], copy=False, lease=lease) + + assert bytes(buf) == b"hello" + assert sink.released == [] + + +def test_from_segments_accepts_shm_zero_copy_when_available(): + shm = pytest.importorskip("zenoh.shm") + provider = shm.ShmProvider.default_backend(4096) + buf = provider.alloc(5) + buf[:] = b"hello" + original = ZBytes(buf).as_shm() + + payload = ZBytes.from_segments([original], copy=False) + + assert bytes(payload) == b"hello" + assert payload.as_shm() is not None + + +def test_from_segments_rejects_lease_with_shm_when_available(): + shm = pytest.importorskip("zenoh.shm") + provider = shm.ShmProvider.default_backend(4096) + buf = provider.alloc(5) + buf[:] = b"hello" + original = ZBytes(buf).as_shm() + sink, lease = make_lease("slot-shm") + + with pytest.raises(RuntimeError, match="lease cannot be used with shared-memory"): + ZBytes.from_segments([original], copy=False, lease=lease) + + assert sink.released == [] + + +def test_from_segments_accepts_mixed_shm_segments_when_available(): + shm = pytest.importorskip("zenoh.shm") + provider = shm.ShmProvider.default_backend(4096) + buf = provider.alloc(5) + buf[:] = b"frame" + + payload = ZBytes.from_segments([b"h", buf, b"t"], copy=False) + + assert bytes(payload) == b"hframet" + assert payload.as_shm() is None + segments = payload.segments() + assert tuple(map(bytes, segments)) == (b"h", b"frame", b"t") + assert segments[1].as_shm() is not None + + +def test_from_segments_copies_shm_mut_without_consuming_when_available(): + shm = pytest.importorskip("zenoh.shm") + provider = shm.ShmProvider.default_backend(4096) + buf = provider.alloc(5) + buf[:] = b"hello" + + payload = ZBytes.from_segments([buf], copy=True) + + assert bytes(payload) == b"hello" + assert payload.as_shm() is None + assert bytes(buf) == b"hello" + + +def test_from_segments_does_not_partially_consume_shm_mut_on_validation_error(): + shm = pytest.importorskip("zenoh.shm") + provider = shm.ShmProvider.default_backend(4096) + buf = provider.alloc(5) + buf[:] = b"hello" + + with pytest.raises(RuntimeError, match="segment 1.*use copy=True"): + ZBytes.from_segments([buf, bytearray(b"mutable")], copy=False) + + assert bytes(buf) == b"hello" + + +def test_from_segments_rejects_repeated_shm_mut_when_available(): + shm = pytest.importorskip("zenoh.shm") + provider = shm.ShmProvider.default_backend(4096) + buf = provider.alloc(5) + buf[:] = b"hello" + + with pytest.raises(RuntimeError, match="repeats the same mutable SHM"): + ZBytes.from_segments([buf, buf], copy=False) + + assert bytes(buf) == b"hello" + + +def test_zshm_and_shm_segment_export_readonly_memoryview_when_available(): + shm = pytest.importorskip("zenoh.shm") + provider = shm.ShmProvider.default_backend(4096) + buf = provider.alloc(5) + buf[:] = b"hello" + payload = ZBytes.from_segments([buf], copy=False) + zshm = payload.as_shm() + + shm_view = memoryview(zshm) + segment_view = memoryview(payload.segments()[0]) + + assert shm_view.readonly + assert segment_view.readonly + assert bytes(shm_view) == b"hello" + assert bytes(segment_view) == b"hello" diff --git a/zenoh/__init__.pyi b/zenoh/__init__.pyi index 5fe200e5..4bad41c2 100644 --- a/zenoh/__init__.pyi +++ b/zenoh/__init__.pyi @@ -11,7 +11,7 @@ # Contributors: # ZettaScale Zenoh Team, # -from collections.abc import Callable +from collections.abc import Callable, Iterable from datetime import datetime, timedelta from enum import Enum, auto from pathlib import Path @@ -2193,6 +2193,19 @@ class WhatAmIMatcher: _IntoWhatAmIMatcher = WhatAmIMatcher | str +@final +class ZBytesSegment: + """A zero-copy view over one physical ZBytes slice.""" + + @_unstable + def as_shm(self) -> shm.ZShm | None: ... + def to_bytes(self) -> bytes: + """Copy this segment into a Python bytes object.""" + + def __bool__(self) -> bool: ... + def __len__(self) -> int: ... + def __bytes__(self) -> bytes: ... + @final class ZBytes: """ZBytes represents raw bytes data that can be interpreted as strings or byte arrays. @@ -2214,6 +2227,49 @@ class ZBytes: def __new__( cls, bytes: bytearray | bytes | str | shm.ZShm | shm.ZShmMut | None = None ) -> Self: ... + @staticmethod + def from_segments( + segments: Iterable[Any], + *, + copy: bool = False, + require_contiguous: bool = True, + lease: Any | None = None, + ) -> Self: + """Build a payload from Python buffer protocol objects. + + ``copy=True`` copies each input buffer into a separate Zenoh-owned segment. + ``copy=False`` performs strict zero-copy construction for read-only, + C-contiguous, single-byte buffer exporters and raises ``RuntimeError`` for + unsupported buffers. The caller must not mutate their backing memory through + another alias while Zenoh may still reference the payload. + + With shared-memory enabled, ``copy=False`` preserves ``shm.ZShm`` and + consumes ``shm.ZShmMut`` segments. Generic memoryviews are treated as raw + borrowed buffers, not as shared-memory descriptors. + + ``lease`` may be used with ``copy=False`` raw borrowed buffers to bind an + external pool lease to Zenoh's internal payload lifetime. The object must + provide ``lease.sink`` and ``lease.lease_id``; when Zenoh releases the last + borrowed buffer reference, it calls ``lease.sink.release(lease.lease_id)``. + The release method should be non-blocking or return quickly. Shared-memory + segments have their own lifetime management and cannot be combined with a + custom lease. + """ + + def segments(self) -> tuple[ZBytesSegment, ...]: + """Return zero-copy views for the payload's physical slices. + + The returned segment views keep their backing payload memory alive. + Physical slice boundaries are an internal memory layout detail and must not + be used as application-level framing. + """ + + def memoryviews(self) -> tuple[memoryview, ...]: + """Return zero-copy memoryviews for the payload's physical slices.""" + + def copied_memoryviews(self) -> tuple[memoryview, ...]: + """Return memoryviews backed by copied Python bytes for each physical slice.""" + def to_bytes(self) -> bytes: """Return the underlying data as bytes. diff --git a/zenoh/shm.pyi b/zenoh/shm.pyi index 9a3a6c49..8cfb47ae 100644 --- a/zenoh/shm.pyi +++ b/zenoh/shm.pyi @@ -11,7 +11,11 @@ # Contributors: # ZettaScale Zenoh Team, # -from typing import Self, TypeVar, final, overload +from collections.abc import Iterable +from typing import TYPE_CHECKING, Self, TypeVar, final, overload + +if TYPE_CHECKING: + from . import ZBytes _T = TypeVar("_T") @@ -122,10 +126,75 @@ class ShmProvider: _IntoMemoryLayout = MemoryLayout | tuple[int, AllocAlignment] | int +@_unstable +@final +class ZShmPool: + """Explicit shared-memory pool for creating writable SHM payload buffers. + + When ``cuda_pinned`` is true, allocations are additionally registered with + the CUDA Driver API so CUDA-aware writers can treat the host memory as + pinned. Overlapping allocations share page registrations within the pool. + CUDA libraries are not required when ``cuda_pinned`` is false. + """ + + def __new__( + cls, + pool_size: int = 268435456, + *, + cuda_pinned: bool = False, + cuda_device: int = 0, + alignment: AllocAlignment | None = None, + ) -> Self: ... + def alloc(self, size: int, alignment: AllocAlignment | None = None) -> ZShmPoolBuf: + """Allocate a pool-owned mutable SHM buffer.""" + + def seal_to_zbytes(self, buffers: Iterable[ZShmPoolBuf]) -> "ZBytes": + """Consume pool-owned buffers and return a true SHM-backed ZBytes. + + Fails if any buffer still has active Python buffer exports. + """ + + def seal_to_zbytes_unchecked(self, buffers: Iterable[ZShmPoolBuf]) -> "ZBytes": + """Consume pool-owned buffers even when active buffer exports exist. + + Danger: the caller must guarantee all CPU/GPU writers have completed + and no existing memoryview, torch tensor, capnp view, or other alias + will write after this call. Violating that guarantee can race with + Zenoh reads/sends and produce torn payload contents. Keep the returned + ZBytes alive until any pre-existing aliases are released. + """ + + @property + def cuda_pinned(self) -> bool: ... + +@_unstable +@final +class ZShmPoolBuf: + """A mutable buffer allocated by :class:`ZShmPool`. + + It implements the writable Python buffer protocol until sealed. + """ + + @property + def ptr(self) -> int: ... + @property + def is_sealed(self) -> bool: ... + def is_valid(self) -> bool: ... + def __len__(self) -> int: ... + def __bytes__(self) -> bytes: ... + def __str__(self) -> str: ... + @overload + def __setitem__(self, item: int, value: int): ... + @overload + def __setitem__(self, item: slice, value: bytes | bytearray): ... + @_unstable @final class ZShm: - """A SHM buffer""" + """An immutable SHM buffer. + + Implements the Python buffer protocol for read-only memoryviews. + """ def is_valid(self) -> bool: ... def __bytes__(self) -> bytes: ...