Skip to content

Commit d972cc0

Browse files
HaoZekeLuthaf
andcommitted
Fix tests when running under miri
There were a couple of violations of stacked borrow that where fixed by using `Box<i64>` instead of inline i64 for the shape/stride values. Co-Authored-By: Guillaume Fraux <guillaume.fraux@epfl.ch>
1 parent dc92b57 commit d972cc0

5 files changed

Lines changed: 69 additions & 20 deletions

File tree

.github/workflows/tests.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,16 @@ jobs:
3434
run: |
3535
pip install numpy
3636
cargo test --all-features
37+
38+
miri:
39+
runs-on: ubuntu-latest
40+
steps:
41+
- uses: actions/checkout@v5
42+
43+
- name: Setup toolchain
44+
run: |
45+
rustup update nightly && rustup default nightly
46+
rustup component add miri
47+
48+
- name: Run tests under miri
49+
run: cargo miri test --all-features

src/data_types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ mod tests {
247247

248248
#[test]
249249
#[cfg(feature = "half")]
250+
#[cfg_attr(miri, ignore)]
250251
fn test_half_precision() {
251252
// Test that half::f16 correctly maps to kDLFloat with 16 bits
252253
let dtype = half::f16::get_dlpack_data_type();

src/pyo3.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
//! # Examples
2121
//!
2222
//! ```
23+
//! # #[cfg(miri)] fn main() {}
24+
//! # #[cfg(not(miri))]
25+
//! # fn main() {
2326
//! use pyo3::prelude::*;
2427
//! use pyo3::types::IntoPyDict;
2528
//! use pyo3::ffi::c_str;
@@ -55,6 +58,7 @@
5558
//!
5659
//! assert_eq!(array, ndarray::arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
5760
//! });
61+
//! # }
5862
//! ```
5963
6064
use crate::sys::{self, DLManagedTensorVersioned};
@@ -334,6 +338,7 @@ mod tests {
334338
macro_rules! test_numpy_to_ndarray_via_dlpack_dtype {
335339
($test_name:ident, $rust_type:ty, $np_dtype:expr) => {
336340
#[test]
341+
#[cfg_attr(miri, ignore)]
337342
fn $test_name() -> PyResult<()> {
338343
Python::initialize();
339344
Python::attach(|py| {
@@ -370,6 +375,7 @@ result_capsule = array.__dlpack__()
370375
macro_rules! test_ndarray_to_numpy_via_dlpack_dtype {
371376
($test_name:ident, $rust_type:ty, $np_dtype:expr) => {
372377
#[test]
378+
#[cfg_attr(miri, ignore)]
373379
fn $test_name() -> PyResult<()> {
374380
Python::initialize();
375381
Python::attach(|py| {
@@ -414,6 +420,7 @@ assert np.allclose(array, expected)
414420
test_ndarray_to_numpy_via_dlpack_dtype!(test_to_numpy_i64, i64, "int64");
415421

416422
#[test]
423+
#[cfg_attr(miri, ignore)]
417424
fn test_null_strides_fails_conversion() -> PyResult<()> {
418425
Python::initialize();
419426
Python::attach(|py| {
@@ -456,6 +463,7 @@ assert np.allclose(array, expected)
456463
}
457464

458465
#[test]
466+
#[cfg_attr(miri, ignore)]
459467
fn test_v1_0_null_strides_allowed() -> PyResult<()> {
460468
Python::initialize();
461469
Python::attach(|py| {

src/sync.rs

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@ struct ManagerContextMutex<T> where T: 'static {
2424
#[borrows(array)]
2525
#[covariant]
2626
lock: MutexGuard<'this, Vec<T>>,
27-
shape: i64,
28-
stride: i64,
27+
// Use Box<i64> so that pointers derived via with_*_mut target heap
28+
// memory rather than inline struct fields. This avoids Stacked Borrows
29+
// violations when multiple with_*_mut calls each create exclusive
30+
// reborrows of the ouroboros struct.
31+
shape: Box<i64>,
32+
stride: Box<i64>,
2933
}
3034

3135
unsafe extern "C" fn mutex_deleter_fn<T>(manager: *mut sys::DLManagedTensorVersioned) where T: 'static {
@@ -41,24 +45,24 @@ impl<T> TryFrom<Arc<Mutex<Vec<T>>>> for DLPackTensor where T: GetDLPackDataType
4145
let ctx = ManagerContextMutexBuilder {
4246
array: array,
4347
lock_builder: |array| { array.lock().expect("could not lock the mutex") },
44-
shape: 0,
45-
stride: 1,
48+
shape: Box::new(0),
49+
stride: Box::new(1),
4650
};
4751
let mut ctx = Box::new(ctx.build());
4852

4953
// set the shape after acquiring the lock to avoid deadlocks
5054
let shape = ctx.with_lock(|lock| lock.len() as i64);
51-
ctx.with_shape_mut(|v| *v = shape);
55+
ctx.with_shape_mut(|v| **v = shape);
5256

5357
// extract pointers out of the boxed context to use in the DLPack tensor
5458
let mut shape_ptr = std::ptr::null_mut();
5559
ctx.with_shape_mut(|shape| {
56-
shape_ptr = shape as *mut i64;
60+
shape_ptr = shape.as_mut();
5761
});
5862

5963
let mut stride_ptr = std::ptr::null_mut();
6064
ctx.with_stride_mut(|stride| {
61-
stride_ptr = stride as *mut i64;
65+
stride_ptr = stride.as_mut();
6266
});
6367

6468
let mut data = std::ptr::null_mut();
@@ -102,8 +106,8 @@ struct ManagerContextRwLock<T> where T: 'static {
102106
#[borrows(array)]
103107
#[covariant]
104108
lock: RwLockWriteGuard<'this, Vec<T>>,
105-
shape: i64,
106-
stride: i64,
109+
shape: Box<i64>,
110+
stride: Box<i64>,
107111
}
108112

109113
unsafe extern "C" fn rwlock_deleter_fn<T>(manager: *mut sys::DLManagedTensorVersioned) where T: 'static {
@@ -119,24 +123,24 @@ impl<T> TryFrom<Arc<RwLock<Vec<T>>>> for DLPackTensor where T: GetDLPackDataType
119123
let ctx = ManagerContextRwLockBuilder {
120124
array: array,
121125
lock_builder: move |array| { array.write().expect("could not lock the rwlock") },
122-
shape: 0,
123-
stride: 1,
126+
shape: Box::new(0),
127+
stride: Box::new(1),
124128
};
125129
let mut ctx = Box::new(ctx.build());
126130

127131
// set the shape after acquiring the lock to avoid deadlocks
128132
let shape = ctx.with_lock(|lock| lock.len() as i64);
129-
ctx.with_shape_mut(|v| *v = shape);
133+
ctx.with_shape_mut(|v| **v = shape);
130134

131135
// extract pointers out of the boxed context to use in the DLPack tensor
132136
let mut shape_ptr = std::ptr::null_mut();
133137
ctx.with_shape_mut(|shape| {
134-
shape_ptr = shape as *mut i64;
138+
shape_ptr = shape.as_mut();
135139
});
136140

137141
let mut stride_ptr = std::ptr::null_mut();
138142
ctx.with_stride_mut(|stride| {
139-
stride_ptr = stride as *mut i64;
143+
stride_ptr = stride.as_mut();
140144
});
141145

142146
let mut data = std::ptr::null_mut();
@@ -215,4 +219,27 @@ mod tests {
215219
let lock = data.read().unwrap();
216220
assert_eq!(&*lock, &[1, 42, 3]);
217221
}
222+
223+
// Last-ref tests: the tensor holds the only Arc reference, so dropping
224+
// it actually deallocates the ManagerContext via the deleter function.
225+
226+
#[test]
227+
fn test_mutex_last_arc_ref() {
228+
let data = Arc::new(Mutex::new(vec![1i32, 2, 3]));
229+
230+
let mut tensor: DLPackTensor = data.try_into().unwrap();
231+
let tensor_mut_ref = tensor.as_mut();
232+
let slice: &mut [i32] = tensor_mut_ref.try_into().unwrap();
233+
assert_eq!(slice, &[1, 2, 3]);
234+
}
235+
236+
#[test]
237+
fn test_rwlock_last_arc_ref() {
238+
let data = Arc::new(RwLock::new(vec![1i32, 2, 3]));
239+
240+
let mut tensor: DLPackTensor = data.try_into().unwrap();
241+
let tensor_mut_ref = tensor.as_mut();
242+
let slice: &mut [i32] = tensor_mut_ref.try_into().unwrap();
243+
assert_eq!(slice, &[1, 2, 3]);
244+
}
218245
}

src/vec.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ impl<'a, T> TryFrom<DLPackTensorRefMut<'a>> for &'a mut [T] where T: DLPackPoint
125125

126126
struct ManagerContext<T> {
127127
array: T,
128-
shape: i64,
129-
stride: i64,
128+
shape: Box<i64>,
129+
stride: Box<i64>,
130130
}
131131

132132
unsafe extern "C" fn deleter_fn<T>(manager: *mut sys::DLManagedTensorVersioned) {
@@ -144,12 +144,12 @@ macro_rules! impl_try_from {
144144
let len = value.len();
145145
let mut ctx = Box::new(ManagerContext {
146146
array: value,
147-
shape: len as i64,
148-
stride: 1,
147+
shape: Box::new(len as i64),
148+
stride: Box::new(1),
149149
});
150150

151-
let shape_ptr = &mut ctx.shape;
152-
let stride_ptr = &mut ctx.stride;
151+
let shape_ptr = ctx.shape.as_mut();
152+
let stride_ptr = ctx.stride.as_mut();
153153

154154
let dl_tensor = sys::DLTensor {
155155
data: ctx.array.as_ptr().cast_mut().cast(),

0 commit comments

Comments
 (0)